從零實現BERT網路模型

由於公眾號改版不再按照作者的釋出時間進行推送,為防止各位朋友錯過月來客棧推送的最新文章,大家可以手動將公眾號設定為“星標⭐”以第一時間獲得推送內容,感謝各位~

1 引言

各位朋友大家好,歡迎來到月來客棧,我是掌櫃空字元。

上一篇文章[1]

中掌櫃說到,對於BERT技術實現這部分內容將會分為三個大的部分來進行介紹。第一部分主要介紹BERT的網路結構原理以及MLM和NSP這兩種任務的具體原理;第二部分將主要介紹如何實現BERT以及BERT預訓練模型在下游任務中的使用;第三部分則是介紹如何利用MLM和NSP這兩個任務來訓練BERT模型(可以是從頭開始,也可以是基於開源的BERT預訓練模型開始)。第一部分內容在上一篇文章中已經介紹完了,在本篇文章中掌櫃將開始詳細來介紹第二部分的內容。

以下所有完整實現程式碼均可從倉庫 https://github。com/moon-hotel/BertWithPretrained 中獲取!

2 BERT實現

2。1 BERT網路結構回顧

經過

上一篇文章[1]

的介紹相信大家對於BERT模型的整體結構已經有了一定的瞭解。如圖1所示,本質上來說BERT就是由多個不同的Transformer結構堆疊而來,同時在Embedding部分多加入了一個Segment Embedding。

從零實現BERT網路模型

圖 1。 BERT網路結構圖

進一步,如果將圖1所示的網路結構展開,將會得到如圖2所示的樣子。在接下來的程式碼實現過程中,掌櫃將會以圖2中黑色加粗字型所示的部分為一個類進行實現。

從零實現BERT網路模型

圖 2。 BERT網路模型細節圖

2。2 Input Embedding實現

首先,我們先來看看Input Embedding的實現過程。為了複用之前在介紹Transformer實現時所用到的這部分程式碼,我們直接在這基礎上再加一個Segment Embedding即可。

2。2。1 Token Embedding

Token Embedding算是NLP中將文字表示為向量的一個基本操作,其原理就不再贅述,具體實現如下:

在上述程式碼中,第4行中的是用來指定序列中用於padding處理的索引編號,一般來說預設都是0。在指定後,如果輸入序列中有0,那麼對應位置的向量就會全是0。當然,這一步掌櫃認為不做也可以,因為在計算自主力權重的時候會透過向量來去掉這部分內容,具體可

參見[2]

。第5行程式碼便是用給定的方式來初始化引數,當然這幾乎不會用到。因為不管是在下游任務中微調,還是繼續透過NSL和MLM這兩個任務來訓練模型引數,我們大多數情況下都會在開源的BERT模型引數上進行,而不是從頭再來。

2。2。2 Positional Embedding

對於Positional Embedding來說,其作用便是用來解決自注意力機制不能捕捉到文字序列內部各個位置之間順序的問題。關於這部分內容原理的介紹,可以參加

文章[3]

。不同於Transformer中Positional Embedding的實現方式,在BERT中Positional Embedding並沒有採用固定的變換公式來計算每個位置上的值,而是採用了類似普通Embedding的方式來為每個位置生成一個向量,然後隨著模型一起訓練。

因此,這一操作就限制了在使用預訓練的中文BERT模型時,最大的序列長度只能是512,因為在訓練時只初始化了512個位置向量。

具體地,其實現程式碼如下:

從上述程式碼可以看出,其本質上就是一個普通的Embedding層,只是在這一場景下作者賦予了它另外的含義,即序列中的每一個位置有自己獨屬的向量表示。同時, 在預設配置中,第16行中的值為512。

2。2。3 Segment Embedding

Segment Embedding的原理及目的掌櫃在

上一篇文章中[1]

已經詳細介紹過,總結起來就是為了滿足下游任務中存在需要兩句話同時輸入到模型中的場景,即可以看成是對輸入的兩個序列分別賦予一個位置向量用以區分各自所在的位置。這一點可以和上面的Positional Embedding進行類比。具體地,其實現程式碼如下:

在上述程式碼中,的預設值為2,即只用於區分兩個序列的不同位置。

2。2。4 Bert Embeddings

在完成Token、Positional、Segment Embedding這3個部分的程式碼之後,只需要將每個部分的結果相加即可得到最終的Input Embedding作為模型的輸入,如圖3所示。

圖 3。 BERT輸入圖

具體地,其程式碼實現為:

在上述程式碼中,是傳入的一個配置類,裡面各個類成員就是BERT中對應的模型引數。第12、19、25行程式碼便是用來分別定義圖3中的3部分Embedding。第33行程式碼是用來生成一個預設的位置id,即,在後續可以透過來進行呼叫。

進一步,其前向傳播過程程式碼為:

在上述程式碼中,表示輸入序列的原始token id,即根據詞表對映後的索引,其形狀為;是位置序列,本質就是,其形狀為;用於不同序列之間的分割,例如用於區分前後不同的兩個句子,形狀為。

同時,第9-10程式碼表示當模型輸入的為空時,需要根據輸入序列的長度來生成一個位置序列(其實這部分輸入僅作為內部實現即可,因為它只是的一串數字。同理,第14行程式碼表示當模型輸入僅包含一個序列(如文字分類)且為空時,那麼可以透過15-16行程式碼來生成一個全0向量。第20-23行程式碼則是用來將三部分Embeeding的結果相加。

2。3 BertAttention實現

在實現完Input Embedding部分的程式碼後,下面就可以著手來實現BertEncoder了。如圖4所示,整個BertEncoder由多個BertLayer堆疊形成;而BertLayer又是由BertOutput、BertIntermediate和BertAttention這3個部分組成;同時BertAttention是由BertSelfAttention和BertSelfOutput所構成圖 4。 BertEncoder實現結構圖

接下來,我們就以圖4中從下到上的順序來依次對每個部分進行實現。

2。4 BertAttention實現

對於BertAttention來說,需要明白的是其核心就是在Transformer中所提出來的self-attention機制,也就是圖4中的BertSelfAttention模組;其次再是一個殘差連線和標準化操作。對於BertSelfAttention的實現,其程式碼如下

如上所示所示便是BertSelfAttention的實現程式碼,其對應的就是GoogleResearch[4]程式碼中的方法。正如前面所說,本質上就是Transformer模型中的self-attention模組,具體原理可參見

文章[3]

,這裡就不再贅述。

對於的實現,其主要就是層Dropout、標準化和殘差連線三個操作,程式碼如下:

接下來就是對BertAttention部分的實現,其由和這兩個類構成,程式碼如下:

在上述程式碼中,第8行的就是Input Embedding處理後的結果;第9行的就是同一個batch中不同長度序列的padding資訊,具體可以參加

文章[2]

;第15行就是自注意力機制的輸出結果;第21行便是執行中的3個操作。

2。5 BertLayer實現

根據圖4可知,BertLayer裡面還有和這兩個模組,因此下面先來實現這兩個部分。

對於來說也就是一個普通的全連線層,因此實現起來也非常簡單,程式碼如下:

在上述程式碼中,第6行用來根據指定引數獲取啟用函式。

進一步,對於來說,其包含有其包含有一個全連線層和殘差連線,實現程式碼如下:

在上述程式碼中,第8行裡指的就是模組的輸出,而則是部分的輸出。

在實現完這兩個部分的程式碼後,便可以透過、和這3個部分來實現組合的部分,程式碼如下:

從上述程式碼中可以發現,對於的實現來說其整體邏輯也並不太複雜,就是根據、和這三部分構造而來;同時每個部分輸出後的維度掌櫃也都進行了標註以便大家進行理解。

到此,對於部分的實現就介紹完了,下面繼續來看如何實現BERT。

2。6 BERT模型實現

根據圖2所示可知,BERT主要由和這兩部分構成;而是有多個堆疊所形成,因此需要先實現,程式碼如下:

在上述程式碼中,第5行便是用來定義多個;第18-22行用來迴圈計算多層堆疊後的輸出結果。最後,只需要按需將部分的輸出結果輸入到下游任務即可。

進一步,在將部分的輸出結果輸入到下游任務前,需要將其進行略微的處理,程式碼如下:

在上述程式碼中,第13-14行程式碼用來取輸出的第一個位置(位置),例如在進行文字分類時可以取該位置上的結果進行下一步的分類處理;第15-16行是掌櫃自己加入的一個選項,表示取所有位置的平均值,當然我們也可以根據自己的需要在新增下面新增其它的方式;最後,17-19行就是一個普通的全連線層。

緊接著,基於上述所有實現便可以搭建完成整個BERT的主體結構,程式碼如下:

如上程式碼所示便是整個BERT部分的實現,可以發現在釐清了整個思路後這部分程式碼理解起來就相對容易了。第22-24行便是Embedding後的輸出結果;第25-26行是整個BERT編碼部分的輸出;第27-28行便是處理得到整個BERT網路的輸出。到此,對於整個BERT主體部分的程式碼實現就介紹完了。

以上程式碼的實現均參考自[4] [5] [6],大家有興趣也可以自行閱讀研究。

4 總結

在本篇文章中,掌櫃首先和大家一起回顧了BERT的整個網路結構;然後一步一步從Input Embedding、BertAttention、BertLayer再到BertEncoder來詳細介紹了整個BERT模型的實現。需要提醒各位讀者朋友的是,在閱讀本文的過程中最好是結合著每個部分的輸出結果(包括形狀和意義)來進行理解。在下一篇文章中,掌櫃將會介紹如何在現有程式碼的基礎上,實現一個

基於BERT的文字分類模型

,並同時用開源的預訓練引數來對模型進行初始化。

引用

[4]Google Research https://github。com/google-research/bert

[5]BERT https://huggingface。co/transformers/model_doc/bert。html#bertmodel

[6]https://github。com/codertimo/BERT-pytorch

TAG: BERT程式碼Embedding實現部分