訓練大模型時,有時讓它“記性差一點”,反而更聰明!
大語言模型如果不加約束,很容易把訓練數據原封不動地復刻出來。為解決這個問題,來自馬里蘭大學、圖賓根大學和馬普所的研究團隊提出了一個新方法 —— 金魚損失(Goldfish Loss)。

顧名思義,金魚損失就是讓模型像金魚一樣,不去死記每一個細節,而是在損失函數計算時隨機剔除一小部分 token。
由此,模型不再逐字記住訓練集內容,但仍能學會語言規律。
實驗顯示,LLaMA-2 在使用金魚損失后:
記憶化內容顯著減少:模型不再復現訓練數據
下游任務性能幾乎不受影響:仍然能流暢生成文本
用網友的精辟評論概括就是:dropout,但損失函數!

金魚損失的核心理念非常簡單,就是在模型訓練過程中隨機剔除一部分訓練文本中的 tokens,使其不參與損失計算。
這樣一來,當模型在推理階段遇到這些位置時,就只能“猜測”,而不是逐字逐句復現訓練數據的完整序列。
此外,為了保證被剔除 token 的一致性,研究人員設計了一種基于哈希(hashing)的掩碼策略。

那么,這和同樣是防止模型背會的正則化方法有什么不同呢?
以 Dropout 這樣的正則化方法為例,它通過在訓練時“加噪聲”來防止模型過度依賴某些參數,從而提高模型舉一反三的能力。
但這樣做的問題在于:如果只是隨機丟 token,那么,每次看到同一段落時,丟掉的地方不一樣,模型累計幾次就能拼湊出完整段落。
所以,說到底,模型還是靠死記硬背,記住了答案。
相比之下,金魚損失則用哈希掩碼確保每次遇到同一段落,掩蓋位置都一樣,這就從根本上阻止了模型復現完整訓練文本。
接下來,我們來看金魚損失具體是怎么做的。
在傳統的 next-token prediction 中,模型以序列中的下一個真實 token 作為目標,輸出預測分布,并基于該分布計算交叉熵損失。

在金魚損失下,模型雖然也在前向傳播中預測序列里下一個 token。但在計算損失時,會以一定的概率將某些位置的 token 從損失計算里“抹掉”。
也就是說,有些真實的下一個 token 不會作為目標來訓練。

在這里,研究人員采用了簡單的靜態掩碼(static mask),剔除每序列中的第 4 個 token。
更進一步,為了確保模型不會從其他地方學到被掩碼的數據(例如不同的文檔會在不同的網頁中反復出現),研究團隊還提出了一種局部化哈希掩碼(localized hashed mask),使得當相同的前 h 個 token 出現時,掩蓋模式是相同的(可重復)。
實驗測試與結果為了驗證金魚損失確實能防止記憶化,研究團隊設計了兩種實驗場景:
一種是極端場景,通過對少量樣本進行多個訓練周期(即重復)來強烈促使記憶化;
另一種是標準場景,模擬現實模型訓練中使用的批次處理方式。
同時,為了評估模型的記憶化程度,研究采用了以下指標:
RougeL 得分:該指標衡量最長公共(非連續)子序列的長度。得分為 1.0 表示完美記憶。
精確匹配率(Exact Match):該指標衡量正確預測的序列占真實序列的百分比.
實驗表明,在極端場景下,標準訓練導致模型逐字記憶了 100 篇文章中的 84 篇,而金魚損失沒有記憶任何文章。

此外,在標準訓練場景下,金魚損失也明顯減少了模型逐字復現訓練語料庫中目標序列的情況。

但這里可能有個直覺式的反應 —— 如果讓模型“隨機漏學”一些 token,它的能力會不會也隨之降低呢?
對此,研究人員進行了測試:研究表明,金魚損失模型、標準損失模型和對照模型之間的總體性能沒有系統性差異。

需要注意的是,金魚損失的核心在于忽略部分 token 的梯度計算。因此,為了學到足夠的語言模式,模型必須通過更多數據來補償這些空缺,這可能導致計算效率的下降。
參考鏈接
[1]https://arxiv.org/pdf/2406.10209
本文來自微信公眾號:量子位(ID:QbitAI),作者:henry,原標題《大模型“記性差一點”反而更聰明!金魚損失隨機剔除 token,讓 AI 不再死記硬背》
本文鏈接:http://www.www897cc.com/showinfo-45-27337-0.html大模型“記性差一點”反而更聰明:金魚損失隨機剔除 token,讓 AI 不再死記硬背
聲明:本網頁內容旨在傳播知識,若有侵權等問題請及時與本網聯系,我們將在第一時間刪除處理。郵件:2376512515@qq.com