陳 曦,姜 黎
(湘潭大學(xué)物理與光電工程學(xué)院,湖南湘潭 411100)
眾所周知,深度神經(jīng)網(wǎng)絡(luò)具有較強(qiáng)的函數(shù)逼近能力,能夠表征復(fù)雜函數(shù)[1]。最近研究表明,神經(jīng)網(wǎng)絡(luò)的表征能力隨著網(wǎng)絡(luò)深度指數(shù)增長(zhǎng)而增強(qiáng)[2]。在機(jī)器學(xué)習(xí)領(lǐng)域,泛化能力指學(xué)習(xí)到的模型對(duì)未知數(shù)據(jù)的預(yù)測(cè)能力[3]。根據(jù)可能近似正確(probably approximate correct,PAC)理論[4]理解為以e 指數(shù)形式正比于假設(shè)空間的復(fù)雜度,反比于數(shù)據(jù)量。目前提高泛化能力方式有增加數(shù)據(jù)量[5]、正則化[6]、凸優(yōu)化[7],這些方法因?yàn)閷?shí)際條件差異在使用時(shí)有一定的局限性。如今神經(jīng)網(wǎng)絡(luò)在許多領(lǐng)域大放異彩,然而在某些場(chǎng)景中卻不盡如人意[8]。由于神經(jīng)網(wǎng)絡(luò)的泛化問題會(huì)影響其推廣,所以提高神經(jīng)網(wǎng)絡(luò)泛化對(duì)生產(chǎn)生活都極具意義。
提高泛化能力研究目前主要有基于神經(jīng)網(wǎng)絡(luò)剪枝[9]和基于多個(gè)獨(dú)立單元結(jié)合的方法?;谏窠?jīng)網(wǎng)絡(luò)剪枝的方法提高泛化能力效果甚微,其主要作用是減少神經(jīng)網(wǎng)絡(luò)參數(shù)量?;诙鄠€(gè)獨(dú)立單元結(jié)合的研究將多個(gè)相同的子模塊獨(dú)立運(yùn)行,然后再對(duì)子模塊信息進(jìn)行整合,從而提高模型性能。這種方法提高泛化效果較好,但參數(shù)量明顯多于剪枝方法。Li 等[10]提出一種新的循環(huán)神經(jīng)網(wǎng)絡(luò)——獨(dú)立循環(huán)神經(jīng)網(wǎng)絡(luò)方法,即同層的神經(jīng)元相互獨(dú)立,跨層連接;Henaff 等[11]在Entnet 結(jié)構(gòu)中應(yīng)用獨(dú)立的門從每個(gè)記憶單元中讀寫,能夠在bAbI 任務(wù)中有優(yōu)于基準(zhǔn)模型的表現(xiàn);Clemens 等[12]采用激活層控制多個(gè)模塊信息交流,但只有在特定的時(shí)間步才能進(jìn)行信息交流[13]。這些研究未對(duì)交流的信息進(jìn)行篩選,在一定程度上保留了冗余信息,因此影響網(wǎng)絡(luò)的泛化能力;Vaswani[14]提出的Transformer 模型在兩項(xiàng)機(jī)器翻譯任務(wù)中表現(xiàn)遠(yuǎn)優(yōu)于當(dāng)前的最優(yōu)模型,其中提出的注意力機(jī)制能夠極大提高模型的泛化能力。
受上述方法啟發(fā),本文沿用多個(gè)獨(dú)立單元結(jié)合思想,采用多頭注意力以提高并行長(zhǎng)短期記憶網(wǎng)絡(luò)(Long Shortterm Memory,LSTM)模型的泛化能力。多頭注意力根據(jù)當(dāng)前時(shí)間步的輸入和LSTM 狀態(tài)的相關(guān)度進(jìn)行選擇性激活,激活的LSTM 包含當(dāng)前輸入重要的信息。在信息交流時(shí),激活的LSTM 會(huì)讀取其它LSTM 信息(包括未激活LSTM中的信息),未激活的LSTM 則按照原有的狀態(tài)獨(dú)立更新。于是,當(dāng)某個(gè)LSTM 信息被改變后,其它激活的LSTM 中還存有其信息。如此操作即能提取到樣本普遍性特征,增強(qiáng)了魯棒性,與很多提高泛化的研究思想不謀而合。為了驗(yàn)證本文方法可行性,與傳統(tǒng)并行LSTM 進(jìn)行對(duì)比實(shí)驗(yàn),證明本文方法比傳統(tǒng)并行LSTM 更穩(wěn)定、更泛化。將本文方法與3 種相關(guān)研究方法進(jìn)行對(duì)比,結(jié)果表明本文方法比相關(guān)方法能更顯著地提高泛化能力。
當(dāng)輸入為長(zhǎng)序列時(shí),傳統(tǒng)的循環(huán)神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Networks,RNN)會(huì)出現(xiàn)梯度消失和梯度爆炸問題,LSTM 就是為解決該問題而專門設(shè)計(jì)的。LSTM 能夠在長(zhǎng)或者短的序列輸入中保留關(guān)鍵信息[15]。實(shí)踐證明,LSTM性能優(yōu)于傳統(tǒng)RNN。LSTM 狀態(tài)參數(shù)在每個(gè)隱藏節(jié)點(diǎn)是共享的,就是每個(gè)細(xì)胞參數(shù)可以對(duì)整個(gè)反應(yīng)鏈狀態(tài)作出修改,Colah 將這種細(xì)胞狀態(tài)的更新機(jī)制類比為傳送帶。LSTM 內(nèi)部結(jié)構(gòu)如圖1 所示(彩圖掃OSID 碼可見,下同)。
Fig.1 LSTM internal structure圖1 LSTM 內(nèi)部結(jié)構(gòu)
如圖1 所示,LSTM 關(guān)鍵在于細(xì)胞狀態(tài)和整個(gè)穿過細(xì)胞上方的那條水平線,細(xì)胞狀態(tài)在這條水平線上傳遞,只有少量的線性交互[16]。若只有上面那條水平線是無法實(shí)現(xiàn)添加或者刪除信息的,只有通過一種叫做“門”的結(jié)構(gòu)來實(shí)現(xiàn)。門可以控制信息流通,通常是利用非線性激活sig?moid 函數(shù)和點(diǎn)積運(yùn)算實(shí)現(xiàn)。sigmoid 層輸出的每個(gè)元素都是0 和1 之間的實(shí)數(shù),表示讓對(duì)應(yīng)信息通過的比例。比如0 表示“不讓任何信息通過”,1 表示“讓所有信息通過”。LSTM 通過3 個(gè)這樣的門結(jié)構(gòu)實(shí)現(xiàn)信息的保護(hù)和控制,分別為遺忘門、輸入門和輸出門。
遺忘門可以過濾之前計(jì)算出的狀態(tài)向量,然后加入到后續(xù)運(yùn)算中,其數(shù)學(xué)表達(dá)式如下:
遺忘門輸入來自當(dāng)前時(shí)間步的輸入向量xt和上一個(gè)時(shí)間步輸出門的輸出向量ht-1,其中Wf和bf為遺忘門的權(quán)重及偏置向量。經(jīng)過sigmoid 運(yùn)算將結(jié)果映射到[0,1],得到遺忘門的輸出ft。ft控制舊狀態(tài)信息舍棄,可以和上一時(shí)間步的細(xì)胞狀態(tài)進(jìn)行點(diǎn)積運(yùn)算,從而更新舊狀態(tài)。
輸入門則是通過激活函數(shù)控制上一時(shí)間步的狀態(tài)和當(dāng)前輸入信息,然后參與當(dāng)前細(xì)胞狀態(tài)更新,其數(shù)學(xué)表達(dá)式如下:
式(2)表示對(duì)細(xì)胞狀態(tài)進(jìn)行更新,式(3)計(jì)算出一組候選的細(xì)胞狀態(tài)來取代更新細(xì)胞狀態(tài)中的舊值,式(4)將這兩個(gè)向量逐元素相乘,接著與經(jīng)過遺忘門的細(xì)胞狀態(tài)相加,如此完成輸入門更新。
輸出門建立在之前兩個(gè)門基礎(chǔ)上,數(shù)學(xué)表達(dá)式如下:
輸出門的輸出是基于當(dāng)前輸入門更新過的細(xì)胞狀態(tài)。式(5)決定輸出的狀態(tài)信息,式(6)中tanh 層將當(dāng)前細(xì)胞狀態(tài)壓縮到(-1,1)區(qū)間內(nèi),該輸出變量同時(shí)作為下個(gè)單元的ht-1加入到循環(huán)。
對(duì)于單個(gè)注意力模型可以理解為給定查詢向量到一系列鍵值對(duì)的映射,本文查詢向量來自LSTM 的狀態(tài)信息,鍵向量和值向量來自于當(dāng)前輸入。在給定目標(biāo)中查詢某個(gè)元素向量后,通過計(jì)算其和各個(gè)鍵向量的相似度得到每個(gè)查詢向量對(duì)應(yīng)值向量的權(quán)重系數(shù),再經(jīng)過softmax 歸一化,將權(quán)重系數(shù)和相應(yīng)的值向量加權(quán)求和,最終計(jì)算出注意力數(shù)值。所以,本質(zhì)上注意力機(jī)制是對(duì)給定目標(biāo)中元素的值向量進(jìn)行加權(quán)求和,而查詢向量和鍵向量用來計(jì)算對(duì)應(yīng)值向量的權(quán)重系數(shù)[19]。最常用的兩種注意力機(jī)制是加性注意力和點(diǎn)積注意力,本文采用點(diǎn)積注意力,其數(shù)學(xué)表達(dá)式如下:
Q,K,V分別是查詢向量、鍵向量、值向量,d是鍵向量的維數(shù),除以d可以防止softmax 之后的值變得很小。
對(duì)于多頭注意力模型,可以認(rèn)為是結(jié)合多個(gè)單獨(dú)的注意力而成,其數(shù)學(xué)表達(dá)式如下:
其中,Q、K、V經(jīng)過線性變換后輸入到單個(gè)注意力運(yùn)算[11],這里要做h次,也就是所謂的多頭。每次計(jì)算一個(gè)頭,頭之間的參數(shù)不共享,每次Q、K、V進(jìn)行線性變換的權(quán)重參數(shù)W不一樣。接著將h次的注意力運(yùn)算結(jié)果進(jìn)行拼接,最后執(zhí)行線性變換,就可計(jì)算出多頭注意力。
本文首先利用多頭注意力根據(jù)并行LSTM 狀態(tài)信息求出每個(gè)LSTM 的注意力權(quán)重,然后從中挑選出權(quán)重較大的LSTM 進(jìn)行激活,再將激活的LSTM 中的狀態(tài)信息通過多頭注意力按照一定比例進(jìn)行信息交流。雖然采用多個(gè)網(wǎng)絡(luò)結(jié)構(gòu)并行的方法較多,但是結(jié)合多頭注意力激活子網(wǎng)絡(luò)并進(jìn)行信息交流的方法卻沒有,且多次對(duì)比實(shí)驗(yàn)表明本文方法有較強(qiáng)的泛化性和穩(wěn)定性。
神經(jīng)網(wǎng)絡(luò)研究發(fā)現(xiàn),通過增加網(wǎng)絡(luò)層數(shù)可以學(xué)習(xí)到任務(wù)的更高層特征以解決更復(fù)雜的任務(wù)。雖然增加層數(shù)可以提高網(wǎng)絡(luò)性能,但是模型的運(yùn)算成本也大幅增加。為了減少深度神經(jīng)網(wǎng)絡(luò)的訓(xùn)練時(shí)間,基于各種計(jì)算平臺(tái)設(shè)計(jì)的并行神經(jīng)網(wǎng)絡(luò)逐漸成為研究熱點(diǎn)[17]。
對(duì)于神經(jīng)網(wǎng)絡(luò)的并行化主要有數(shù)據(jù)并行和模型并行兩種方法[18]。數(shù)據(jù)并行是當(dāng)數(shù)據(jù)量十分龐大時(shí),將數(shù)據(jù)分成多個(gè)小的子數(shù)據(jù)集,再將各個(gè)子數(shù)據(jù)集在多個(gè)相同模型上并行訓(xùn)練,最后由參數(shù)服務(wù)器完成參數(shù)交換[19];模型并行指將網(wǎng)絡(luò)結(jié)構(gòu)分解到各個(gè)計(jì)算設(shè)備上,依靠設(shè)備間的共同協(xié)作完成訓(xùn)練。本文實(shí)驗(yàn)在Cuda 平臺(tái)上進(jìn)行模型并行訓(xùn)練測(cè)試,并行網(wǎng)絡(luò)中每個(gè)LSTM 就是獨(dú)立的結(jié)構(gòu)單元,如圖2 所示。
Fig.2 Structure of this paper圖2 本文結(jié)構(gòu)
多頭注意力結(jié)合當(dāng)前LSTM 狀態(tài)與輸入的相關(guān)度選擇性激活LSTM,其中綠色框表示已激活的LSTM,藍(lán)色為未激活。在每一時(shí)間步,激活的可從其它LSTM 中讀取信息,未激活的則保持隱藏狀態(tài)不變。最后經(jīng)過神經(jīng)元個(gè)數(shù)為10 的全連接層得出預(yù)測(cè)結(jié)果。本文中LSTM 總個(gè)數(shù)為6,每個(gè)時(shí)間步激活4 個(gè)LSTM,每個(gè)LSTM 的神經(jīng)元個(gè)數(shù)為32。
起初每個(gè)LSTM 是相互獨(dú)立的,初始狀態(tài)也是隨機(jī)的,然后進(jìn)行自身動(dòng)態(tài)更新。經(jīng)過多頭注意力選定與輸入相關(guān)的LSTM 設(shè)置激活,激活的LSTM 讀取其它激活或未激活LSTM 一定比例的信息[20]。本文中每個(gè)激活的LSTM 都可以讀取其它LSTM 中1/10 的信息。因此,不僅能保留當(dāng)前任務(wù)的重要信息,還能通過信息交流提高魯棒性[21]。
設(shè)每個(gè)LSTM 都是相互獨(dú)立的,它們之間沒有信息交流。對(duì)于未激活的LSTM,其隱藏狀態(tài)保持不變,如式(10)所示。
此為第k個(gè)LSTM 在t時(shí)間步的狀態(tài)。模型會(huì)動(dòng)態(tài)地在每個(gè)時(shí)間步挑選出和當(dāng)前輸入相關(guān)的LSTM 激活,激活的LSTM 得到真實(shí)的輸入,未激活則得到由全0 組成的空白輸入。令xt為時(shí)間步t時(shí)的輸入,如果未激活則:
式(11)是將xt在行方向上進(jìn)行連接。
接下來用線性操作建立:
R的每行對(duì)應(yīng)一個(gè)獨(dú)立的LSTM 隱藏狀態(tài)。Wv是將輸入映射到對(duì)應(yīng)的V向量矩陣,Wk是將類似的矩陣輸入映射到K。是將LSTM 從其隱藏狀態(tài)映射到Q。
注意力運(yùn)算結(jié)果如下:
基于上式softmax 計(jì)算的值,在每個(gè)時(shí)間步將較大的softmax 值設(shè)置為1,其余則為0。將這幾個(gè)值與其對(duì)應(yīng)的LSTM 執(zhí)行點(diǎn)積運(yùn)算就完成了激活步驟。未激活LSTM 的梯度保持以往的更新,其狀態(tài)可以被激活的LSTM 讀取。對(duì)于激活的LSTM 將進(jìn)行如下更新:
LSTM 在t時(shí)間步經(jīng)過多頭注意力作用得到下一時(shí)間步的狀態(tài)ht+1。本文方法即按照上述步驟進(jìn)行循環(huán)更新。
本文采用MNIST[22]、Fashion-MNIST[23]、CIFAR10[24]、Animals-10 開源數(shù)據(jù)集進(jìn)行實(shí)驗(yàn)驗(yàn)證。MNIST 是手寫數(shù)字(0-9)數(shù)據(jù)集,F(xiàn)ashion-MNIST 是時(shí)尚穿搭衣物數(shù)據(jù)集,CIFAR10 是常見物體彩色圖片數(shù)據(jù)集,Animals-10 是10類常見動(dòng)物圖片數(shù)據(jù)集,各數(shù)據(jù)集詳情如表1 所示。
本文實(shí)驗(yàn)在Linux 系統(tǒng)下搭建的Pytorch 環(huán)境進(jìn)行,批量大小設(shè)置為100,損失函數(shù)采用交叉熵?fù)p失函數(shù),優(yōu)化函數(shù)采用SGD,學(xué)習(xí)率為0.1,迭代訓(xùn)練1 000 次。
Table 1 Distribution of experimental data sets表1 本文實(shí)驗(yàn)數(shù)據(jù)集分布
實(shí)驗(yàn)中LSTM 總個(gè)數(shù)為6,設(shè)置每一時(shí)間步激活的LSTM 個(gè)數(shù)為4,單個(gè)隱藏層神經(jīng)元為32。對(duì)比實(shí)驗(yàn)中采用4 個(gè)并行的LSTM,單個(gè)隱藏層神經(jīng)元也為32,其它參數(shù)設(shè)置與本文方法相同,這樣的設(shè)置排除神經(jīng)個(gè)數(shù)對(duì)實(shí)驗(yàn)的干擾。4 種數(shù)據(jù)集的對(duì)比實(shí)驗(yàn)如圖3 所示。
Fig.3 Comparison between the proposed method and parallel LSTM training圖3 本文方法與并行LSTM 訓(xùn)練對(duì)比
如圖3 所示,黑色曲線和綠色曲線分別對(duì)應(yīng)本文方法在訓(xùn)練中的準(zhǔn)確率、損失函數(shù)值,紅色曲線和藍(lán)色曲線對(duì)應(yīng)并行的LSTM 準(zhǔn)確率、損失函數(shù)值。在4 種數(shù)據(jù)集上,本文方法均比并行LSTM 的訓(xùn)練準(zhǔn)確率高。兩種方法在MNIST 數(shù)據(jù)集上的訓(xùn)練準(zhǔn)確率差距極小,但是并行LSTM的訓(xùn)練損失值波動(dòng)較大。本文方法在Fashion-MNIST 和CIFAR10 的訓(xùn)練準(zhǔn)確率明顯高于并行LSTM,訓(xùn)練損失值同樣比并行LSTM 穩(wěn)定。在Animals-10 數(shù)據(jù)集上,本文方法的訓(xùn)練準(zhǔn)確率比并行LSTM 有較大提升,訓(xùn)練損失值也更低、更穩(wěn)定。從訓(xùn)練表現(xiàn)來看,采用本文方法的性能優(yōu)于并行LSTM 模型。
通常采用測(cè)試誤差來衡量神經(jīng)網(wǎng)絡(luò)的泛化能力,其中測(cè)試誤差為1 減去測(cè)試準(zhǔn)確率。將本文方法與并行LSTM在4 種數(shù)據(jù)集的測(cè)試誤差進(jìn)行對(duì)比實(shí)驗(yàn)。在測(cè)試集進(jìn)行10 次測(cè)試,計(jì)算出平均測(cè)試誤差,如表2 所示。
由表2 可知,本文方法在4 種數(shù)據(jù)集的測(cè)試誤差均低于并行LSTM。其中,由于MNIST 數(shù)據(jù)集的任務(wù)較為簡(jiǎn)單,兩種方法的測(cè)試誤差僅相差0.35%。Fashion-MNIST 數(shù)據(jù)集和CIFAR10 數(shù)據(jù)集的分類任務(wù)較難,測(cè)試誤差相差約1%,能明顯看出本文方法的泛化能力強(qiáng)于并行LSTM 模型。Animals-10 數(shù)據(jù)集由于任務(wù)較難且訓(xùn)練數(shù)據(jù)較少,導(dǎo)致測(cè)試差距較大,達(dá)到3.03%。實(shí)驗(yàn)表明,本文提出的方法能夠有效提高泛化能力。
Table 2 Comparison of test errors between the proposed method and parallel LSTM表2 本文方法與并行LSTM 測(cè)試誤差對(duì)比(%)
為進(jìn)一步探究本文方法的泛化能力,繼續(xù)在4 種數(shù)據(jù)集上對(duì)本文方法與相關(guān)研究進(jìn)行實(shí)驗(yàn)。對(duì)比的方法有基于門控交流的Entnet[11]方法、基于注意力機(jī)制讀寫信息的RMC[25]方法、基于多個(gè)循環(huán)結(jié)構(gòu)結(jié)合的方法[10]。訓(xùn)練參數(shù)設(shè)置前保持一致,依舊采用測(cè)試誤差作為衡量泛化的指標(biāo)。對(duì)比實(shí)驗(yàn)測(cè)試誤差如表3 所示。
Table 3 Test error comparison between the proposed method and related research表3 本文方法與相關(guān)研究的測(cè)試誤差對(duì)比(%)
由表3 可知,本文方法在4 種分類任務(wù)中都取得了最好成績(jī)。其中在MNIST 數(shù)據(jù)集上,本文方法比RMC 方法測(cè)試誤差低0.02%。由于這個(gè)數(shù)據(jù)集上的分類任務(wù)比較簡(jiǎn)單,所以各種方法差距都很小,并不能明顯看出泛化性能的強(qiáng)弱。在其它分類難度大的數(shù)據(jù)集上,本文方法的測(cè)試誤差分別比次優(yōu)方法低0.21%、0.25%、0.37%。因此本文方法比其它3 種相關(guān)研究更具泛化能力,表明本文方法能提高神經(jīng)網(wǎng)絡(luò)泛化能力。
泛化能力是在真實(shí)場(chǎng)景中依然能夠發(fā)揮出色,對(duì)數(shù)據(jù)的變化具有魯棒性。本實(shí)驗(yàn)使用Python 中的skimage 庫將測(cè)試集的圖片添加高斯噪聲,其中高斯噪聲均值為0,方差為0.01,訓(xùn)練集則保持原有狀態(tài)。然后基于前述的訓(xùn)練模型,在4 種數(shù)據(jù)集上對(duì)比相關(guān)算法的測(cè)試誤差,詳情如表4所示。
Table 4 Test error comparison between the proposed method and related research表4 本文方法與相關(guān)研究測(cè)試誤差對(duì)比(%)
由表4 可知,在加噪情況下本文方法的測(cè)試誤差都是最小。和未加噪情況相比,本文方法測(cè)試誤差的變化值均小于相關(guān)方法,分別比次優(yōu)方法低0.35%、0.62%、0.77%、1.19%。這意味著本文方法對(duì)于數(shù)據(jù)的變化有更強(qiáng)的魯棒性,泛化能力也優(yōu)于相關(guān)方法。綜上所述,本文方法能夠顯著提高神經(jīng)網(wǎng)絡(luò)的泛化能力和穩(wěn)定性。
本文采用多頭注意力以提高神經(jīng)網(wǎng)絡(luò)泛化能力,通過多頭注意力選擇性激活LSTM 進(jìn)行信息交流,保留任務(wù)中普適性信息,從而提高神經(jīng)網(wǎng)絡(luò)的泛化能力。與并行LSTM 網(wǎng)絡(luò)相比,本文方法表現(xiàn)出更強(qiáng)的泛化能力和穩(wěn)定性。與其它相關(guān)方法相比,本文方法的泛化能力也更強(qiáng)。但本文方法參數(shù)量較大,會(huì)耗費(fèi)大量的計(jì)算和內(nèi)存成本。后續(xù)研究方向?yàn)閷⒈疚姆椒ㄍ茝V到簡(jiǎn)單的并行結(jié)構(gòu)中,使其能夠移植到硬件中。