《電子技術(shù)應(yīng)用》
您所在的位置:首頁(yè) > 可編程邏輯 > 解決方案 > OpenAI新研究補(bǔ)齊Transformer短板,將可預(yù)測(cè)序列長(zhǎng)度提高30倍

OpenAI新研究補(bǔ)齊Transformer短板,將可預(yù)測(cè)序列長(zhǎng)度提高30倍

2019-04-24

Transformer是一種強(qiáng)大的序列模型,,但是它所需的時(shí)間和內(nèi)存會(huì)隨著序列長(zhǎng)度出現(xiàn)二階增長(zhǎng)。近日,,OpenAI研究人員開(kāi)發(fā)出了一種深度神經(jīng)網(wǎng)絡(luò)Sparse Transformer,該網(wǎng)絡(luò)在預(yù)測(cè)長(zhǎng)序列方面創(chuàng)造了新紀(jì)錄——無(wú)論預(yù)測(cè)的是文本,、圖像還是聲音,。該神經(jīng)網(wǎng)絡(luò)利用注意力機(jī)制中的一種改進(jìn)算法,可以從長(zhǎng)度可能是之前30倍的序列中提取模式,。

現(xiàn)在,,AI 研究中的一項(xiàng)挑戰(zhàn)是在圖像、視頻或聲音等復(fù)雜數(shù)據(jù)中進(jìn)行長(zhǎng)序列的精細(xì)相關(guān)性建模,。Sparse Transformer 合并了 O(N^2)Transformer 自注意力機(jī)制的 O(N√N(yùn)) 重組以及其他一些改進(jìn),,從而直接用于這些豐富的數(shù)據(jù)類(lèi)型。以前,,這些數(shù)據(jù)上所使用的模型是專(zhuān)為某個(gè)領(lǐng)域制作的,,或者很難將序列擴(kuò)展到包含幾千個(gè)元素。


相比之下,,OpenAI 開(kāi)發(fā)的模型通過(guò)使用數(shù)以百計(jì)的層可以對(duì)包含上萬(wàn)個(gè)元素的序列進(jìn)行建模,,在諸多領(lǐng)域都取得了當(dāng)前最佳的表現(xiàn)。OpenAI 研究人員利用該模型幫助創(chuàng)建能夠更好地理解世界的 AI 系統(tǒng)。


深度注意力


在 Transformer 中,,每一個(gè)輸出元素與輸入元素相連接,,同時(shí)根據(jù)具體情況對(duì)它們之間的權(quán)重進(jìn)行動(dòng)態(tài)計(jì)算,這一過(guò)程被稱(chēng)為「注意力機(jī)制」,。雖然人們相信這使得 Transformer 較那些具有固定連接模式的模型更為靈活,,但實(shí)際操作中需要為每一層和注意力頭創(chuàng)建一個(gè) N×N 注意力矩陣,當(dāng)應(yīng)用于圖像或原始音頻等具有許多元素的數(shù)據(jù)類(lèi)型時(shí)會(huì)消耗大量?jī)?nèi)存,。

微信圖片_20190424221637.jpg

當(dāng)矩陣存儲(chǔ)在內(nèi)存或在逆推計(jì)算過(guò)程中進(jìn)行再計(jì)算時(shí),,深度 Transformer(64 層和 4 個(gè)頭)的注意力內(nèi)存使用情況。作為參考,,用于深度學(xué)習(xí)的標(biāo)準(zhǔn) GPU 內(nèi)存通常是 12-32GB.


減少內(nèi)存消耗的一種方法是在反向傳播過(guò)程中從檢查點(diǎn)處重新計(jì)算注意力矩陣,,這是深度學(xué)習(xí)中的一種成熟的方法,以更多的計(jì)算來(lái)減少內(nèi)存使用,。


當(dāng) Transformer 中的注意力矩陣完成時(shí),,這意味著最大的內(nèi)存消耗將不受層數(shù)的支配,使研究人員訓(xùn)練網(wǎng)絡(luò)的深度大大超過(guò)從前,。在實(shí)際操作中,研究人員發(fā)現(xiàn)在處理 CIFAR-10 等基準(zhǔn)測(cè)試任務(wù)時(shí),,深度達(dá) 128 層的 Transformer 表現(xiàn)出的性能優(yōu)于較淺的網(wǎng)絡(luò),。


為了訓(xùn)練深度更大的模型,研究人員對(duì) transformer 的操作順序進(jìn)行了幾次調(diào)整,,修改了初始化方法,。詳情參見(jiàn)論文。


稀疏注意力


然而,,對(duì)于非常大的輸入來(lái)說(shuō),,甚至計(jì)算單個(gè)注意力矩陣都是不現(xiàn)實(shí)的。因此,,OpenAI 使用了稀疏注意力模式,,在這種模式中,每個(gè)輸出位置僅從輸入位置子集中計(jì)算權(quán)重,。當(dāng)子集相對(duì)于整個(gè)輸入集較小時(shí)(如元素?cái)?shù)量是√N(yùn) 而不是 N),,即使對(duì)于非常長(zhǎng)的序列,注意力計(jì)算也會(huì)變得比較容易,,算法復(fù)雜度為 O(N√N(yùn))而不是 O(N^2),。


為了評(píng)估該方法的可行性,研究人員首先可視化并學(xué)習(xí)了圖像上深度 Transformer 的注意力模式,,發(fā)現(xiàn)其中許多模式表現(xiàn)出了可解釋和結(jié)構(gòu)化的稀疏模式,。以下每幅圖像都顯示了哪個(gè)輸入像素(白色高亮標(biāo)出)由一個(gè)給定的注意力頭處理,以預(yù)測(cè)圖像中的下一個(gè)值。當(dāng)輸入部分集中在小的子集上并顯示出高度規(guī)律性時(shí),,該層就易于稀疏化,。以下是 CIFAR-10 圖像上 128 層模型的樣本:

微信圖片_20190424221702.jpg

左:Layer 19,右:Layer 20,。為一個(gè) 128 層的 CIFAR-10 網(wǎng)絡(luò)的若干層學(xué)習(xí)注意力模式(白色高亮顯示),。這些層學(xué)會(huì)了在兩個(gè)維度上分割注意力。Layer 19 匯總每一行的信息,,Layer 20 按列匯總這些信息,,從而有效分解了全注意力運(yùn)算。

微信圖片_20190424221725.jpg

為獲取位置記憶而訓(xùn)練的層(左:Layer 6,;右:Layer 36),,它們通常關(guān)注類(lèi)似的位置,不管輸入數(shù)據(jù)或時(shí)間步長(zhǎng)如何(Layer 6),。其他層學(xué)習(xí)高度依賴(lài)數(shù)據(jù)的訪問(wèn)模式(Layer 36),。


雖然許多層顯示出稀疏的結(jié)構(gòu),但有些層清晰地顯示出了動(dòng)態(tài)注意力,,這種注意力延伸到整個(gè)圖像,。為了保持網(wǎng)絡(luò)學(xué)習(xí)這種模式的能力,研究人員實(shí)現(xiàn)了注意力矩陣的二維分解,,其中網(wǎng)絡(luò)可以通過(guò)兩步稀疏注意力關(guān)注到所有位置,。

微信圖片_20190424221749.png


第一版 strided attention 大概等同于每個(gè)位置處理自己的行和列,它與以上網(wǎng)絡(luò)學(xué)得的注意力模式類(lèi)似,。(注意,,列注意力可等同于處理轉(zhuǎn)置矩陣的行)。第二版 fixed attention 在最新的列元素之后處理固定列和元素,,研究者認(rèn)為這個(gè)模式對(duì)于數(shù)據(jù)無(wú)法擬合二維結(jié)構(gòu)(如文本)的情況很有用,。


實(shí)驗(yàn)結(jié)果


Sparse Transformer 在 CIFAR-10、Enwik8 和 Imagenet 64 數(shù)據(jù)集上刷新了當(dāng)前最優(yōu)密度估計(jì)分?jǐn)?shù),。

微信圖片_20190424221807.jpg

微信圖片_20190424221829.png

在 CIFAR-10,、Enwik8 和 Imagenet 64 數(shù)據(jù)集上的密度估計(jì)性能(單位為 bits per byte/dim)。M 表示網(wǎng)絡(luò)中使用的參數(shù)(單位為百萬(wàn)),,W 表示網(wǎng)絡(luò)寬度,,L 表示層數(shù),H 表示頭數(shù),。


研究者還發(fā)現(xiàn)稀疏注意力比完整注意力的損失更低,,且速度更快。這可能指向稀疏模式產(chǎn)生的有用歸納偏置,,或者密集注意力的底層優(yōu)化問(wèn)題,。


生成圖像


使用了稀疏注意力的 Transformer 似乎有一種全局結(jié)構(gòu)的概念,,這可以通過(guò)觀察圖像補(bǔ)全(image completion)進(jìn)行定性評(píng)估。下圖可視化了一個(gè)在 64×64 ImageNet 上訓(xùn)練的模型:

微信圖片_20190424221849.jpg


損壞原圖

微信圖片_20190424221907.jpg

修復(fù)圖像

微信圖片_20190424221927.jpg

真實(shí)圖像


研究人員還生成了完全無(wú)條件的樣本,,其中未調(diào)整的 softmax 溫度為 1.0,。這些模型使用最大似然目標(biāo)進(jìn)行訓(xùn)練,其覆蓋了所有的數(shù)據(jù)模式(其中包括可能不存在的數(shù)據(jù)),,而不是增強(qiáng)較小部分?jǐn)?shù)據(jù)的保真度,。從具有未調(diào)整溫度的模型中取樣,研究人員看到了該模型認(rèn)為世界上存在的圖像的完整分布,。因此,,一些樣本看起來(lái)奇奇怪怪的。

微信圖片_20190424221943.jpg

模型示例


生成原始音頻波形


通過(guò)簡(jiǎn)單改變位置嵌入,,稀疏 Transformer 還能用來(lái)生成原始音頻,,而非圖像。隨著深度學(xué)習(xí)擴(kuò)展到新的數(shù)據(jù)類(lèi)型,,用這類(lèi)網(wǎng)絡(luò)來(lái)指定歸納偏置也很容易,。


該模型是在原始的古典音樂(lè)片段上訓(xùn)練的,并使用了稀疏注意力來(lái)生成長(zhǎng)度為 65000 的序列,。這相當(dāng)于大約 5 秒長(zhǎng)的原始音頻,,研究人員在下面的每個(gè)片段中將幾個(gè)樣本連接在一起。

微信圖片_20190424222027.jpg


代碼公布


通常,,實(shí)現(xiàn)稀疏注意力需要將查詢(xún)和關(guān)鍵矩陣分割成塊,,因此為了簡(jiǎn)化實(shí)驗(yàn),OpenAI 實(shí)現(xiàn)了一組塊稀疏核,,這些核在 GPU 上高效地執(zhí)行這些操作。OpenAI 開(kāi)源了這些核并提供了稀疏注意力函數(shù)的示例:

https://github.com/openai/sparse_attention


未來(lái)發(fā)展和限制


本文介紹的稀疏注意力模式只是對(duì)長(zhǎng)序列進(jìn)行高效建模的初步嘗試,。研究人員認(rèn)為,,探索稀疏注意力的不同模式和各種組合非常有用,而且學(xué)習(xí)稀疏模式對(duì)下一代神經(jīng)網(wǎng)絡(luò)架構(gòu)來(lái)說(shuō)也是一個(gè)很重要的研究途徑,。


即使有了上述改進(jìn),,自回歸序列生成對(duì)非常高分辨率圖像和音頻來(lái)說(shuō)仍是不切實(shí)際的。但是,,研究人員介紹的優(yōu)化注意力操作可能有用,,將它與其它方法(如多尺度方法)結(jié)合,可以建模高維數(shù)據(jù),。


論文:Generating Long Sequences with Sparse Transformers


微信圖片_20190424222049.png


論文鏈接:https://d4mucfpksywv.cloudfront.net/Sparse_Transformer/sparse_transformers.pdf


摘要:Transformer 是一種強(qiáng)大的序列模型,,但是它所需的時(shí)間和內(nèi)存會(huì)隨著序列長(zhǎng)度出現(xiàn)二階增長(zhǎng)。這篇論文介紹了注意力矩陣的稀疏因式分解,,可以將其降低到 O(N√N(yùn)),。該研究提出了 a)訓(xùn)練更深網(wǎng)絡(luò)的架構(gòu)和初始化變體,;b)重新計(jì)算注意力矩陣以節(jié)省內(nèi)存;c)用于訓(xùn)練的快速注意力內(nèi)核,。研究者將具備這些變化的網(wǎng)絡(luò)稱(chēng)為 Sparse Transformer,,并證明該網(wǎng)絡(luò)可以使用數(shù)百個(gè)層來(lái)建模成千上萬(wàn)個(gè)時(shí)間步長(zhǎng)的序列。


該網(wǎng)絡(luò)在從原始字節(jié)中建模圖像,、音頻和文本時(shí)使用的是同樣的架構(gòu),,在 Enwik8、CIFAR10 和 ImageNet-64 數(shù)據(jù)集上取得了當(dāng)前最優(yōu)的密度估計(jì)性能,。研究者生成的無(wú)條件樣本展示了全局一致性和極大的多樣性,,并證明原則上可以使用自注意力建模長(zhǎng)度超百萬(wàn)的序列。


本站內(nèi)容除特別聲明的原創(chuàng)文章之外,,轉(zhuǎn)載內(nèi)容只為傳遞更多信息,,并不代表本網(wǎng)站贊同其觀點(diǎn)。轉(zhuǎn)載的所有的文章,、圖片,、音/視頻文件等資料的版權(quán)歸版權(quán)所有權(quán)人所有。本站采用的非本站原創(chuàng)文章及圖片等內(nèi)容無(wú)法一一聯(lián)系確認(rèn)版權(quán)者,。如涉及作品內(nèi)容,、版權(quán)和其它問(wèn)題,請(qǐng)及時(shí)通過(guò)電子郵件或電話(huà)通知我們,,以便迅速采取適當(dāng)措施,,避免給雙方造成不必要的經(jīng)濟(jì)損失。聯(lián)系電話(huà):010-82306118,;郵箱:[email protected],。