Transformer直接預測完整數學表示式,推理速度提高多個數量級

機器之心報道

機器之心編輯部

來自 Mata AI、法國索邦大學、巴黎高師的研究者成功讓 Transformer 直接預測出完整的數學表示式。

符號迴歸,即根據觀察函式值來預測函式數學表示式的任務,通常涉及兩步過程:預測表示式的「主幹」並選擇數值常數,然後透過最佳化非凸損失函式來擬合常數。其中用到的方法主要是遺傳程式設計,透過多次迭代子程式實現演算法進化。神經網路最近曾在一次嘗試中預測出正確的表示式主幹,但仍然沒有那麼強大。

在近期的一項研究中,來自 Meta AI(Facebook)、法國索邦大學、巴黎高師的研究者提出了一種 E2E 模型,嘗試一步完成預測,讓 Transformer 直接預測完整的數學表示式,包括其中的常數。隨後透過將預測常數作為已知初始化提供給非凸最佳化器來更新預測常數。

Transformer直接預測完整數學表示式,推理速度提高多個數量級

論文地址:https://arxiv。org/abs/2204。10532

該研究進行消融實驗以表明這種端到端方法產生了更好的結果,有時甚至不需要更新步驟。研究者針對 SRBench 基準測試中的問題評估了該模型,並表明該模型接近 SOTA 遺傳程式設計的效能,推理速度提高了幾個數量級。

方法

Embedder

該模型提供了 N 個輸入點 (x, y) ∈ R^(D+1),每個輸入點被表徵為 d_emb 維度的 3(D + 1) 個 token。隨著 D 和 N 變大,這會導致輸入序列很長(例如,D = 10 和 N = 200 時有 6600 個 token),這對 Transformer 的二次複雜度提出了挑戰。

為了緩解這種情況,該研究提出了一個嵌入器( embedder )來將每個輸入點對映成單一嵌入。嵌入器將空輸入維度填充(pad)到 D_max,然後將 3(D_max+1)d_emb 維向量饋入具有 ReLU 啟用的 2 層全連線前饋網路 (FFN) 中,該網路向下投影到 d_emb 維度,得到的 d_emb 維的 N 個嵌入被饋送到 Transformer。

該研究使用一個序列到序列的 Transformer 架構,它有 16 個 attention head,嵌入維度為 512,總共包含 86M 個引數。像《 ‘Linear algebra with transformers 》研究中一樣,研究者觀察到解決這個問題的最佳架構是不對稱的,解碼器更深:在編碼器中使用 4 層,在解碼器中使用 16 層。該任務的一個顯著特性是 N 個輸入點的排列不變性。為了解釋這種不變性,研究者從編碼器中刪除了位置嵌入。

如下圖 3 所示,編碼器捕獲所考慮函式的最顯著特徵,例如臨界點和週期性,並將專注於區域性細節的短程 head 與捕獲函式全域性的長程 head 混合在一起。

Transformer直接預測完整數學表示式,推理速度提高多個數量級

訓練

該研究使用 Adam 最佳化器最佳化交叉熵損失,在前 10000 步中將學習率從 10^(-7) 提升到 2。10^(-4),然後按照論文《 Attention is all you need 》中的方法將其衰減為步數的平方根倒數(inverse square root)。該研究提供了包含來自同一生成器的 10^4 個樣本的驗證集,並訓練模型,直到驗證集的準確率達到飽和(大約 50 個 epoch 的 3M 個樣本)。

輸入序列長度隨點數 N 顯著變化;為了避免浪費填充,該研究將相似長度的樣本一起批處理,確保一個完整的批處理包含至少 10000 個 token。

實驗結果

該研究不僅評估了域內準確性,也展示了在域外資料集上的結果。

域內效能

表 2 給出了該模型的平均域內結果。如果不進行修正,E2E 模型在低精度預測(R^2 和 Acc_0。1 指標)方面優於在相同協議下訓練的 skeleton 模型,但常數預測中存在的錯誤會導致在高精度(Acc_0。001)下的效能較低。

Transformer直接預測完整數學表示式,推理速度提高多個數量級

修正之後的程式顯著緩解了這個問題,讓 Acc_0。001 提升了三倍,同時其他指標也有所改進。

Transformer直接預測完整數學表示式,推理速度提高多個數量級

圖 4A、B、C 給出了 3 個公式難度指標的消融實驗結果(從左到右):一元運算元的數量、二元運算元的數量和輸入維數。正如人們所預料的那樣,在所有情況下,增加難度係數會降低效能。這可能會讓人認為該模型在輸入維度上不能很好地擴充套件,但實驗表明,與併發方法相比,該模型在域外資料集上的擴充套件效能非常好,如下圖所示。

Transformer直接預測完整數學表示式,推理速度提高多個數量級

圖 4D 顯示了效能與輸入模型的點數 N 之間的關係。在所有情況下,效能都會提高,但 E2E 模型比 skeleton 模型更顯著,這證明大量資料對於準確預測表示式中的常數是非常重要的。

外推和穩健性。如圖 4E 所示,該研究透過改變測試點的規模來檢查模型內插 / 外推的能力:該研究沒有將測試點歸一化為單位方差,而是將它們歸一化為 σ。隨著 σ 的增加,效能會下降,但是即使遠離輸入(σ = 32),外推效能仍然不錯。

最後,如圖 4F 所示,研究者檢查了使用方差 σ 的乘性噪聲(multiplicative noise)對目標 y 的影響:y y(1 + ξ), ξ N (0, ε)。這個結果揭示了一些有趣的事情:如果不進行修正,E2E 模型對噪聲的穩健性不強,實際上在高噪聲下效能比 skeleton 模型差。這顯示了 Transformer 在預測常數時對輸入的敏感程度。修正之後 E2E 模型的穩健性顯著提高,但將常數初始化為估計值的影響較小,因為常數的預測被噪聲破壞了。

感興趣的讀者可以閱讀論文原文,瞭解更多研究細節。

TAG: 模型輸入常數Transformer預測