Pytorch中模型的儲存與遷移

由於公眾號改版不再按照作者的釋出時間進行推送,為防止各位朋友錯過月來客棧推送的最新文章,大家可以手動將公眾號設定為“

星標

⭐”以第一時間獲得推送內容,感謝各位~

1 引言

各位朋友大家好,歡迎來到月來客棧。今天要和大家介紹的內容是如何在Pytorch框架中對模型進行儲存和載入、以及模型的遷移和再訓練。一般來說,最常見的場景就是模型完成訓練後的推斷過程。一個網路模型在完成訓練後通常都需要對新樣本進行預測,此時就只需要構建模型的前向傳播過程,然後載入已訓練好的引數初始化網路即可。

第2個場景就是模型的再訓練過程。一個模型在一批資料上訓練完成之後需要將其儲存到本地,並且可能過了一段時間後又收集到了一批新的資料,因此這個時候就需要將之前的模型載入進行在新資料上進行增量訓練(或者是在整個資料上進行全量訓練)。

第3個應用場景就是模型的遷移學習。這個時候就是將別人已經訓練好的預模型拿過來,作為你自己網路模型引數的一部分進行初始化。例如:你自己在Bert模型的基礎上加了幾個全連線層來做分類任務,那麼你就需要將原始BERT模型中的引數載入並以此來初始化你的網路中的BERT部分的權重引數。

在接下來的這篇文章中,筆者就以上述3個場景為例來介紹如何利用Pytorch框架來完成上述過程。

2 模型的儲存與複用

在Pytorch中,我們可以透過和來完成上述場景中的主要步驟。下面,筆者將以之前介紹的LeNet5網路模型為例來分別進行介紹。不過在這之前,我們先來看看Pytorch中模型引數的儲存形式。

2。1 檢視網路模型引數

(1)檢視引數

首先定義好LeNet5的網路模型結構,如下程式碼所示:

在定義好LeNet5這個網路結構的類之後,只要我們完成了這個類的例項化操作,那麼網路中對應的權重引數也都完成了初始化的工作,即有了一個初始值。同時,我們可以透過如下方式來訪問:

其輸出的結果為:

可以發現,網路模型中的引數其實是以字典的形式(實質上是模組中的)儲存下來的:

(2)自定義引數字首

同時,這裡值得注意的地方有兩點:①引數名中的和字首是根據你在上面定義時的名字所確定的;②引數名中的數字表示每個中網路層所在的位置。例如將網路結構定義成如下形式:

那麼其引數名則為:

理解了這一點對於後續我們去解析和載入一些預訓練模型很有幫助。

除此之外,對於中的最佳化器等,其同樣有對應的方法來獲取對於的引數,例如:

在介紹完模型引數的檢視方法後,就可以進入到模型複用階段的內容介紹了。

2。2 載入模型進行推斷

(1) 模型儲存

在Pytorch中,對於模型的儲存來說是非常簡單的,通常來說透過如下兩行程式碼便可以實現:

在指定儲存的模型名稱時Pytorch官方建議的字尾為或者(當然也不是強制的)。最後,只需要在合適的地方加入第2行程式碼即可完成模型的儲存。

同時,如果想要在訓練過程中儲存某個條件下的最優模型,那麼應該透過如下方式:

而不是:

因為後者得到只是的引用,它依舊會隨著訓練過程而發生改變。

(2)複用模型進行推斷

在推斷過程中,首先需要完成網路的初始化,然後再載入已有的模型引數來覆蓋網路中的權重引數即可,示例程式碼如下:

在上述程式碼中,4-7行便是用來載入本地模型引數,並用其覆蓋網路模型中原有的引數。這樣,便可以進行後續的推斷工作:

2。3 載入模型進行訓練

在介紹完模型的儲存與複用之後,對於網路的追加訓練就很簡單了。最簡便的一種方式就是在訓練過程中只儲存網路權重,然後在後續進行追加訓練時只載入網路權重引數初始化網路進行訓練即可,示例如下(完整程式碼參見[2]):

這樣,便完成了模型的追加訓練:

除此之外,你也可以在儲存引數的時候,將最佳化器引數、損失值等一同儲存下來,然後在恢復模型的時候連同其它引數一起恢復,示例如下:

載入方式如下:

2。4 載入模型進行遷移

(1)定義新模型

到目前為止,對於前面兩種應用場景的介紹就算完成了,可以發現總體上並不複雜。但是對於第3中場景的應用來說就會略微複雜一點。

假設現在有一個LeNet6網路模型,它是在LeNet5的基礎最後多加了一個全連線層,其定義如下:

接下來,我們需要將在LeNet5上訓練得到的權重引數遷移到LeNet6網路中去。從上面LeNet6的定義可以發現,此時儘管只是多加了一個全連線層,但是倒數第2層引數的維度也發生了變換。因此,對於LeNet6來說只能複用LeNet5網路前面4層的權重引數。

(2)檢視模型引數

在拿到一個模型引數後,首先我們可以將其載入,然檢視相關引數的資訊:

同時,對於LeNet6網路的引數資訊為:

在理清楚了新舊模型的引數後,下面就可以將LeNet5中我們需要的引數給取出來,然後再換到LeNet6的網路中。

(3)模型遷移

雖然本地載入的模型引數(上面的)和模型初始化後的引數(上面的)都是一個字典的形式,但是我們並不能夠直接改變中的權重引數。這裡需要先構造一個然後透過方法來重新初始化網路中的引數。

同時,在這個過程中我們需要篩選掉本地模型中不可複用的部分,具體程式碼如下:

在上述程式碼中,第2行的作用是先複製網路中(LeNet6)原有的引數;第6-9行則是用本地的模型引數(LeNet5)中可以複用的替換掉LeNet6中的對應部分,其中第7行就是判斷可用的條件。同時需要注意的是在不同的情況下篩選的方式可能不一樣,因此具體情況需要具體分析,但是整體邏輯是一樣的。

最後,我們只需要在模型訓練之前呼叫該函式,然後重新初始化LeNet6中的部分權重引數即可[2]:

訓練結果如下:

可以發現,在大約100個batch之後,模型的準確率就提升上來了。

3 總結

在本篇文章中,筆者首先介紹了模型複用的幾種典型場景;然後介紹瞭如何檢視Pytorch模型中的相關引數資訊;接著介紹瞭如何載入模型、如何進行追加訓練以及進行模型的遷移學習等。有了這部分內容的鋪墊,在後續介紹類似BERT這樣的預訓練模型載入時就會容易很多了。

引用

[1] SAVING AND LOADING MODELS https://pytorch。org/tutorials/beginner/saving_loading_models。html

[2] 示例程式碼:https://github。com/moon-hotel/DeepLearningWithMe

TAG: 模型引數載入訓練網路