【新智元導(dǎo)讀】DeepMind 最近被 ICML 2024 接收的一篇論文,完完全全暴露了他們背靠谷歌的「豪橫」。一篇文章預(yù)估了這項(xiàng)研究所需的算力和成本,大概是 Llama 3 預(yù)訓(xùn)練的 15%,耗費(fèi)資金可達(dá) 12.9M 美元。
發(fā)一篇頂會(huì)論文,需要多少實(shí)驗(yàn)預(yù)算?
最近,DeepMind 發(fā)表了一項(xiàng)研究,對(duì) LLM 擴(kuò)大規(guī)模時(shí)各種算法和架構(gòu)細(xì)節(jié),比如參數(shù)和優(yōu)化器的選擇,進(jìn)行了廣泛的實(shí)證調(diào)查。
這篇論文已被 ICML 2024 接收。
論文地址:https://arxiv.org/abs/2407.05872
63 頁的論文涵蓋了數(shù)以萬計(jì)的模型,備選方案包括 3 種優(yōu)化器、4 種參數(shù)化方案、幾種對(duì)齊假設(shè)、十多個(gè)學(xué)習(xí)率,以及最高達(dá) 26.8B 的 14 種參數(shù)規(guī)模。
需要進(jìn)行實(shí)驗(yàn)的 4 種參數(shù)化方案
僅僅聽到這些數(shù)字,就不難知道,這項(xiàng)研究必定涉及海量的模型運(yùn)行實(shí)驗(yàn)。
而有一位忠實(shí)讀者,為了測(cè)試自己對(duì)論文內(nèi)容的理解,統(tǒng)計(jì)了其中進(jìn)行的所有實(shí)驗(yàn),并估算出了復(fù)現(xiàn)論文的成本。
將所需算力全部加在一起,林林總總,居然達(dá)到了驚人的 1290 萬美元。
考驗(yàn)基本功的時(shí)刻到了,假如你是研究團(tuán)隊(duì)的 leader,根據(jù)實(shí)驗(yàn)計(jì)劃對(duì)所需算力和成本進(jìn)行預(yù)估是一項(xiàng)必不可少的技能。
那就讓我們跟著這篇博客文章盤一遍,這一千多萬美元,究竟燒在哪里。
Transformer 架構(gòu)信息
論文附錄 C 提供了關(guān)于模型算法和架構(gòu)的各種細(xì)節(jié)設(shè)置,比如使用 decoder-only 架構(gòu)、層歸一化、GeLU 激活函數(shù)、無 dropout、T5 分詞器、批大小為 256、用 FSDP 并行等等。
實(shí)驗(yàn)?zāi)P偷膮?shù)規(guī)模統(tǒng)計(jì)
通過架構(gòu)方面的信息,我們可以大致估算出訓(xùn)練中每個(gè) token 所需的 FLOPS,記為 M。
由于論文沒有描述到任何 GQA / MQA 機(jī)制,所以就假設(shè) Rkv=1,此外還有 lseq=512,Dhead=128,L=8(深度),V=32101(分詞器詞匯量)。
模型總參數(shù)量可以表示為:
因此,就可以得到 M 的計(jì)算公式:
默認(rèn)情況下,每次實(shí)驗(yàn)處理的 token 數(shù)(tokens per experiment, TPE)為 5k(訓(xùn)練步數(shù))×256(批大?。?12(lseq),約為 6.5536e9。
def M(d: int, L=8, l_seq=512, V=32101) -> int: return 6*d * (L*(12*d + l_seq) + V) TPE = 50000 * 256 * 512
對(duì)齊實(shí)驗(yàn)
假設(shè)對(duì)齊實(shí)驗(yàn)中,直接使用了后面的學(xué)習(xí)率掃描得出的最優(yōu)結(jié)果,并沒有單獨(dú)進(jìn)行學(xué)習(xí)率掃描,因此這一步的成本計(jì)算比較簡(jiǎn)單:
def alignment() - int: return 4 * TPE * sum(M(d) for d in [1024,2048,4096]) # >> f'{alignment()3E}' # '3.733E+20' # >> cost_of_run(alignment())[0] # 888.81395400704
如果 H100 每運(yùn)行 1 小時(shí)的花費(fèi)以 3 美元計(jì)算,對(duì)齊實(shí)驗(yàn)的成本大致為 888 美元。
學(xué)習(xí)率
子問題:最佳評(píng)估損失(eval loss)實(shí)驗(yàn)
論文的表 E1 記錄了 6 種模型規(guī)模下,所有可能的優(yōu)化器 × 參數(shù)化方案 × 模型大小 × 實(shí)驗(yàn)設(shè)置的組合,分別進(jìn)行基礎(chǔ)學(xué)習(xí)率掃描,以獲得最佳評(píng)估損失。
總共包括如下幾個(gè)實(shí)驗(yàn)變量:
模型維度 D∈3072,4096,6144,8192,12288,16384
4 種參數(shù)化方案
3 種優(yōu)化器,其中 SGD 僅有 5 個(gè)實(shí)驗(yàn)設(shè)置,Adam 和 Adam+Param Scaling 有 7 個(gè)實(shí)驗(yàn)設(shè)置
假設(shè)這里的實(shí)驗(yàn)都是單獨(dú)進(jìn)行,沒有從其他地方復(fù)制結(jié)果,因此如果全部運(yùn)行一遍,有成本上限預(yù)估:
H = [1,2,4,6,8,12,16,20,24,32,48,64,96,128] D = [h * 128 for h in H] def table_e1() - int: sets_x_optims = 5 + 7 + 7 return 4 * sets_x_optims * TPE * sum(M(d) for d in D[-6]) # >> f'{table_e1()3E}'cost_of_run(table_e1()) # '1.634E+23' # (388955.9991064986 16206.499962770775)
這部分的成本就接近 40 萬美元,雖然仍屬于可接受范圍內(nèi),但對(duì)于大多數(shù)學(xué)術(shù)預(yù)算來說,已經(jīng)算是非常昂貴了。
表 E1 給出了最佳評(píng)估損失,但沒有描述 LR 的掃描策略,每張圖上的點(diǎn)數(shù)也不盡相同。
由于沒有得到論文作者的答復(fù),我們也無法確定具體機(jī)制,因此假設(shè)每個(gè)最佳評(píng)估損失都經(jīng)過了 15 次實(shí)驗(yàn)(目測(cè)發(fā)現(xiàn),每條線的點(diǎn)數(shù)約為 10~15)。
β 參數(shù)
根據(jù)論文 4.2 節(jié)內(nèi)容,學(xué)習(xí)率還涉及到兩個(gè)超參數(shù)的選擇:β 和 γ。
如果僅有 β 參數(shù),則被稱為「LR+default」設(shè)置:
這部分包括 3× 優(yōu)化器,4× 參數(shù)化,加上全局和單層(GlobalLR、Perlayer-fullalign)分別進(jìn)行實(shí)驗(yàn),以及未知的 LR 掃描數(shù)量:
def _only() -> int: return 3*4*2*PpL * TPE * sum(M(d) for d in D) # 7.988E+23 (1902022.3291813303, 79250.93038255542)
從公式就可以看出,成本和下文的 epsilon 實(shí)驗(yàn)類似,都是 200 萬美元。
γ 參數(shù)
相比 β 參數(shù)的實(shí)驗(yàn),這部分有兩個(gè)細(xì)節(jié)差異。
首先,除了 GlobalLR、Perlayer-fullalign 兩種設(shè)置外,還需要加上 Perlayer-noalign 設(shè)置。
其次,僅針對(duì) d=1024=b,進(jìn)行 3D 超參數(shù)搜索 (γ_1,γ_h,γ_L+1),因此有額外的 800 次運(yùn)行。
兩者結(jié)合后的計(jì)算公式為:
這部分的預(yù)估成本與 Adam 的 epsilon 熱力圖實(shí)驗(yàn)接近,約為 320 萬美元。
def gamma_expts() -> int: return 36*TPE * (800*M(1024) + PpL*sum(M(d) for d in D)) # gamma_expts 1.354E+24 (3224397.534237257, 134349.8972598857)
Adam 優(yōu)化器的 Epsilon 參數(shù)
論文 4.3 節(jié)所述的 Epsilon 參數(shù)實(shí)驗(yàn)是計(jì)算量的大頭。
根據(jù)上面的推斷,每次找到最佳評(píng)估損失時(shí)都嘗試過 15 個(gè)不同的學(xué)習(xí)率(points per line),那么圖 6 所示的 epsilon 參數(shù)變化圖耗費(fèi)的計(jì)算量為:
計(jì)算結(jié)果透露出一種簡(jiǎn)潔的昂貴,也就是 200 萬美元的賬單而已。
PpL = 15 # unprincipled estimate def eps_variants() -> int: return 4 * 6 * PpL * TPE * sum(M(d) for d in D) ''' >>> f'{eps_variants():.3E}';cost_of_run(eps_variants()) '7.988E+23' (1902022.3291813303, 79250.93038255542) '''
除了圖 6 左側(cè)的折線圖,還有附錄 F 熱力圖的結(jié)果。
假設(shè)每個(gè)方塊值都是經(jīng)過 13 次學(xué)習(xí)率掃描后得到的結(jié)果,這部分計(jì)算量則為:
結(jié)果發(fā)現(xiàn),僅僅要得到這 8 張熱力圖,成本就是 320 萬美元。而且,由于我們將 LR 掃描數(shù)量建模為常數(shù) 13,這個(gè)數(shù)字可能低于實(shí)際成本。
def eps_heatmaps() - int: # eps-type * eps-val * parameterizations * LR range * ... return 2 * 6 * 4 * 13 * TPE * sum(M(d) for d in D[-6]) ''' >> f'{eps_heatmaps()3E}'cost_of_run(eps_heatmaps()) '1.341E+24' (3193533.466348094 133063.89443117057) '''
權(quán)重衰減
權(quán)重衰減實(shí)驗(yàn)(附錄 G)比較好理解,對(duì) 4× 參數(shù)化方案以及所有參數(shù)進(jìn)行一次基本的 LR 掃描:
比 epsilon 實(shí)驗(yàn)便宜不少,也就是灣區(qū)工程師一年的工資 ——31.7 萬美元。
def weight_decay() -> int: return 4 * PpL * TPE * sum(M(d) for d in D) ''' >>> f'{weight_decay():.3E}'; cost_of_run(weight_decay()) '1.331E+23' (317003.7215302217, 13208.488397092571) '''
Adafactor 優(yōu)化器
這部分實(shí)驗(yàn)在附錄 C3 中有詳細(xì)描述,是為了檢驗(yàn) Adafactor 和 Adam+parameter scaling 是否有相似的寬度縮放機(jī)制。
共有 2×4 張圖,其中每個(gè)優(yōu)化器收集 11 個(gè)數(shù)據(jù)點(diǎn),因此計(jì)算公式為:
賬單上再加 18.8 萬美元。
def adafactor() -> int: return 2*2*4*PpL*TPE*sum(M(d) for d in D[:11]) ''' >>> f'{adafactor():.3E}'; cost_of_run(adafactor()) '7.918E+22' (188532.80765144504, 7855.533652143543) '''
計(jì)算最優(yōu)化
論文嘗試改變注意力頭 H 的數(shù)量,希望找到計(jì)算最優(yōu)化的設(shè)置,但其中涉及步長(zhǎng)和數(shù)據(jù)集的改變,因此這部分不使用公式描述,計(jì)算代碼如下:
def P(d: int, L=8, V=32101) -> int: return 2 * d * (6*L*d + V) def compute_optimal(): indices_50k = (14, 14, 12) return 4*PpL*sum([ TPE * sum(sum( M(d) for d in D[:i] ) for i in indices_50k), 20 * sum(P(d)*M(d) for d in D[:11]) *3, ]) # compute_optim 7.518E+23 (1790104.1799513847, 74587.67416464102)
總結(jié)
將以上各部分實(shí)驗(yàn)的算力和成本匯總在一起:
alignment 3.733E+20 (888.81395400704, 37.033914750293334) table_e1 1.634E+23 (388955.9991064986, 16206.499962770775) eps_variants 7.988E+23 (1902022.3291813303, 79250.93038255542) eps_heatmaps 1.341E+24 (3193533.466348094, 133063.89443117057) _only 7.988E+23 (1902022.3291813303, 79250.93038255542) gamma_expts 1.354E+24 (3224397.534237257, 134349.8972598857) weight_decay 1.331E+23 (317003.7215302217, 13208.488397092571) adafactor 7.918E+22 (188532.80765144504, 7855.533652143543) compute_optim 7.518E+23 (1790104.1799513847, 74587.67416464102)
結(jié)果發(fā)現(xiàn),整篇論文的運(yùn)算量為 5.42e24 FLOPS。
這個(gè)數(shù)字僅僅是 Llama 3 訓(xùn)練計(jì)算量的 15%,如果在 10 萬卡 H100 集群上運(yùn)行,只需要 2 天時(shí)間即可完成所有實(shí)驗(yàn)。
total_flops=5.421E+24 rental price: US$12.9M h100 node months required: 746.9595590938408 (sanity check) D=[128, 256, 512, 768, 1024, 1536, 2048, 2560, 3072, 4096, 6144, 8192, 12288, 16384] (sanity check) model sizes: ['0.00979B', '0.0227B', '0.058B', '0.106B', '0.166B', '0.325B', '0.534B', '0.794B', '1.1B', '1.87B', '4.02B', '6.97B', '15.3B', '26.8B'] (sanity check) M/6P: ['63.4%', '68.5%', '75.3%', '79.7%', '82.8%', '86.8%', '89.3%', '91.0%', '92.2%', '93.9%', '95.7%', '96.7%', '97.7%', '98.3%']
然而,如果不從 LLM 預(yù)訓(xùn)練的標(biāo)準(zhǔn)來衡量,僅把 DeepMind 的這篇論文看做一篇學(xué)術(shù)研究,這個(gè)計(jì)算量就顯得相當(dāng)奢侈了。
如果實(shí)驗(yàn)室僅有 10 張 H100,就根本不可能進(jìn)行這個(gè)量級(jí)的研究。
有 100 張 H100 的大型實(shí)驗(yàn)室,或許能用幾年時(shí)間跑完以上所有實(shí)驗(yàn)。
參考資料:
https://152334h.github.io/blog/scaling-exponents/
https://news.ycombinator.com/item?id=41107721
https://arxiv.org/abs/2407.05872
本文來自微信公眾號(hào):微信公眾號(hào)(ID:null),作者:新智元
廣告聲明:文內(nèi)含有的對(duì)外跳轉(zhuǎn)鏈接(包括不限于超鏈接、二維碼、口令等形式),用于傳遞更多信息,節(jié)省甄選時(shí)間,結(jié)果僅供參考,IT之家所有文章均包含本聲明。