有時(shí)候,,好的訓(xùn)練「技巧」比蠻力堆參更有效。
現(xiàn)階段,,視覺 transformer(ViT)模型已經(jīng)在圖像分類,、目標(biāo)檢測(cè)與分割等各樣各樣的計(jì)算機(jī)視覺任務(wù)中得到了廣泛應(yīng)用,并可以在視覺表征與識(shí)別中實(shí)現(xiàn) SOTA 結(jié)果,。由于計(jì)算機(jī)視覺模型的性能往往與參數(shù)量和訓(xùn)練時(shí)長(zhǎng)呈正相關(guān),,AI 社區(qū)已經(jīng)實(shí)驗(yàn)了越來越大規(guī)模的 ViT 模型。
但應(yīng)看到,,隨著模型開始超出萬億次浮點(diǎn)運(yùn)算的規(guī)模,該領(lǐng)域已經(jīng)遇到了一些主要的瓶頸,。訓(xùn)練單個(gè)模型可能耗費(fèi)數(shù)月,,需要數(shù)以千塊的 GPU,進(jìn)而增加了加速器需求并導(dǎo)致大規(guī)模 ViT 模型將很多從業(yè)者「排除在外」,。
為了擴(kuò)展 ViT 模型的使用范圍,,Meta AI 的研究者已經(jīng)開發(fā)出了更高效的訓(xùn)練方法。非常重要的一點(diǎn)是對(duì)訓(xùn)練進(jìn)行優(yōu)化以實(shí)現(xiàn)最佳的加速器利用,。但是,,這一過程耗時(shí)費(fèi)力且需要大量的專業(yè)知識(shí)。為了設(shè)置有序的實(shí)驗(yàn),,研究者必須從無數(shù)可能的優(yōu)化方案中進(jìn)行選擇:一次訓(xùn)練過程中執(zhí)行的百萬次運(yùn)算中的任何一個(gè)都有可能受到低效率的影響和阻礙,。
Meta AI 發(fā)現(xiàn),通過將一系列優(yōu)化應(yīng)用到其圖像分類代碼庫 PyCls 中的 ViT 實(shí)現(xiàn),,可以提升計(jì)算和存儲(chǔ)效率,。對(duì)于使用 PyCIs 訓(xùn)練的 ViT 模型,Meta AI 的方法可以提升訓(xùn)練速度和每加速器吞吐量(TFLOPS),。
下圖展示了使用優(yōu)化代碼庫 PyCIs 后每芯片(per chip)加速器吞吐量相較于 V100 基準(zhǔn)的相對(duì)增加,,而 A100 優(yōu)化的加速器吞吐量是 V100 基準(zhǔn)的 4.05 倍。
運(yùn)行原理
Meta AI 首先對(duì) PyCIs 代碼庫進(jìn)行分析以確認(rèn)低訓(xùn)練效率的潛在來源,,最終將注意力放在了對(duì)數(shù)字格式的選擇上,。在默認(rèn)情況下,大多數(shù)應(yīng)用使用 32-bit 單精度浮點(diǎn)格式來表征神經(jīng)網(wǎng)絡(luò)值,。轉(zhuǎn)換至 16-bit 半精度格式(FP16)可以減少模型的內(nèi)存占用和執(zhí)行時(shí)間,,但往往也會(huì)降低準(zhǔn)確率,。
研究者采取了折中方案,即混合精度,。利用它,,系統(tǒng)通過單精度格式執(zhí)行計(jì)算以加速訓(xùn)練并減少內(nèi)存使用,同時(shí)通過單精度存儲(chǔ)結(jié)果以保持準(zhǔn)確率,。他們沒有手動(dòng)地將部分網(wǎng)絡(luò)轉(zhuǎn)換至半精度,,而是實(shí)驗(yàn)了不同模式的自動(dòng)混合精度訓(xùn)練,這樣可以在數(shù)字格式之間自動(dòng)切換,。更高級(jí)模式的自動(dòng)混合精度主要依賴半精度運(yùn)算和模型權(quán)重,。研究者采用的平衡設(shè)置既能大幅度加速訓(xùn)練,同時(shí)也不犧牲準(zhǔn)確率,。
為了使流程更加高效,,研究者充分利用了 FairScale 庫中的完全分片數(shù)據(jù)并行(Fully Sharder Data Parallel, FSDP)訓(xùn)練算法,它在 GPU 上對(duì)參數(shù),、梯度和優(yōu)化器狀態(tài)進(jìn)行分片,。通過 FSDP 算法,研究者可以使用更少的 GPU 構(gòu)建更大量級(jí)的模型,。此外,,研究者還使用了 MTA 優(yōu)化器、一個(gè)池化的 ViT 分類器和一個(gè) batch-second 輸入張量布局來跳過冗余轉(zhuǎn)置運(yùn)算,。
下圖 X 軸為可能的優(yōu)化,,Y 軸為采用 ViT-H/16 訓(xùn)練時(shí)加速器吞吐量相較于分布式數(shù)據(jù)并行(DDP)基準(zhǔn)的相對(duì)增加。
研究者在總 patch 大小為 560 時(shí)實(shí)現(xiàn)了 1.51 倍的加速器吞吐量提升,,以每個(gè)加速器芯片上每秒執(zhí)行的浮點(diǎn)運(yùn)算數(shù)量衡量,。通過將圖像大小從 224 像素增加至 256 像素,他們可以將吞吐量提升至 1.86 倍,。但是,,改變圖像大小意味著超參數(shù)的變化,這會(huì)對(duì)模型的準(zhǔn)確率造成影響,。在完全 FP16 模式下訓(xùn)練時(shí),,相對(duì)吞吐量增加至 2.18 倍。盡管有時(shí)會(huì)降低準(zhǔn)確率,,但在實(shí)驗(yàn)中準(zhǔn)確率降低少于 10%,。
下圖 Y 軸為 epoch 時(shí)間,在整個(gè) ImageNet-1K 數(shù)據(jù)集上一次訓(xùn)練的持續(xù)時(shí)間,。這里專注于現(xiàn)有配置的實(shí)際訓(xùn)練時(shí)間,,這些配置通常使用 224 像素的圖像大小。
Meta AI 的研究者使用優(yōu)化方案,,將 epoch 時(shí)間(在整個(gè) ImageNet-1K 數(shù)據(jù)集上一次訓(xùn)練的持續(xù)時(shí)間)從 0.65 小時(shí)減少到 0.43 小時(shí),。
下圖 X 軸表示特定配置中 A100 GPU 加速器芯片的數(shù)量,,Y 軸表示每芯片 TFLOPS 的絕對(duì)吞吐量。
該研究還討論了不同 GPU 配置的影響,。在每種情況下,,系統(tǒng)都實(shí)現(xiàn)了比分布式數(shù)據(jù)并行(DDP)基線水平更高的吞吐量。隨著芯片數(shù)量的增加,,由于設(shè)備間通信的開銷,,我們可以觀察到吞吐量略有下降。然而,,即使用 64 塊 GPU,,Meta 的系統(tǒng)也比 DDP 基準(zhǔn)快 1.83 倍。
新研究的意義
將 ViT 訓(xùn)練中可實(shí)現(xiàn)的吞吐量翻倍可以有效讓訓(xùn)練集群規(guī)模翻倍,,提高加速器利用率直接減少了 AI 模型的碳排放,。由于最近大模型的發(fā)展帶來了更大模型和更長(zhǎng)訓(xùn)練時(shí)間的趨勢(shì),這種優(yōu)化有望幫助研究領(lǐng)域進(jìn)一步推動(dòng)最先進(jìn)的技術(shù),,縮短周轉(zhuǎn)時(shí)間并提高生產(chǎn)力,。