張?zhí)? 郭輝 郭靜純
摘要:針對增量學習存在的災難性遺忘和新任務數(shù)據(jù)逐步積累問題,提出了基于新舊任務之間相似度的樣本重放優(yōu)化學習方法,相似度越高,重放樣本越少。并選擇MINIST數(shù)據(jù)集在卷積神經(jīng)網(wǎng)絡上進行了實驗研究,驗證了該方法的可行性和有效性。
關鍵詞: 增量學習;災難性遺忘;樣本重放;任務相似度
中圖分類號: TP183? ? ? ? 文獻標識碼:A
文章編號:1009-3044(2021)08-0013-03
Abstract: To solve the problem of catastrophic forgetting and gradual accumulation of new task data in incremental learning, an optimal learning approach based on the similarity difference between old and new tasks is proposed. The more similar the tasks are, the less the old samples will be replayed. Moreover, MINIST data set is selected to conduct experimental research on the convolutional neural network, which verifies the feasibility and effectiveness of the method.
Key words:incremental learning; catastrophic forgetting; sample replay; task similarity
隨著深度學習的快速發(fā)展和在圖像、語音等領域的應用,其在單個任務處理方面取得了優(yōu)異的性能。但當它面對多任務增量學習時,常常產(chǎn)生“災難性遺忘”現(xiàn)象[1],即學習新任務時會改變原有的網(wǎng)絡參數(shù),相應的舊任務記憶就會急劇下降甚至完全消失。
樣本重放是緩解災難性遺忘的主要方法之一,包括兩種典型方式:一種通過舊任務的偽樣本生成器保留其信息,如深層生成重放[2]和記憶重放GANs[3],不使用舊任務原始數(shù)據(jù),但GAN模型訓練較復雜;另一種直接選用舊任務的原始數(shù)據(jù)子集,如內存固定的iCaRl[4]及其改進訓練樣本不均衡的增量學習文獻[5],文獻[6]提出一種自動記憶框架,基于樣本參數(shù)化選取具有代表性的舊樣本子集,采用雙層優(yōu)化訓練框架。這些方法均未考慮新舊任務之間的相似度差異:相似度越高,網(wǎng)絡提取的共有信息越多,則對舊任務的回顧應越少。此外,真實環(huán)境下新任務數(shù)據(jù)通常按照時間順序流式到達,新數(shù)據(jù)較少,無法滿足上述方法的需要。針對這些問題,本文提出了一種基于任務相似度的增量學習優(yōu)化方法,根據(jù)兩者之間相似度差異設置不同比例的訓練數(shù)據(jù),避免重復訓練,減少資源占用,加快訓練速度。
1 樣本重放增量學習優(yōu)化方法
增量學習優(yōu)化方法的實現(xiàn)過程主要分為以下三個階段:首先,當新任務到達時,用特征提取器提取新舊類特征,進行相似度差異分析;其次,根據(jù)相似度差異結果,計算新舊任務不同比例的訓練數(shù)據(jù)增量,構建每批次增量訓練數(shù)據(jù)集;最后,進行增量優(yōu)化訓練,實現(xiàn)符合真實場景下的新任務數(shù)據(jù)增量訓練和任務增量學習。
1.1 符號表示
假設增量學習分為1個初始階段和N個新任務的增量階段。在初始階段使用數(shù)據(jù)[D0]進行訓練得到網(wǎng)絡模型[Θ0];在第[i]個增量階段,若有[s]個舊類[X1,X2,...,Xs],新類[Xi,i∈N],模型狀態(tài)為[Θi-1],令[Di?j]、[Dij]、[Dj]分別表示第[i]類第[j]個批次的新增樣本數(shù)據(jù)、前[j]個批次新數(shù)據(jù)和第[j]個批次的新舊訓練數(shù)據(jù)。
1.2 任務相似度分析
根據(jù)假設,新任務數(shù)據(jù)流式到達。當新任務到達時,首先,選取同等數(shù)量的舊任務樣本和首次到達的新任務樣本作為代表性樣本一起訓練特征提取網(wǎng)絡作為特征提取器[φ],通過使用新舊任務的平衡數(shù)據(jù)集,特征提取器可以更均衡地提取新舊任務的樣本特征,使網(wǎng)絡能充分學習新舊任務樣本之間的差異,得到更具有代表性的樣本特征。
對新任務樣本數(shù)據(jù)提取特征后,采用余弦相似度衡量新舊任務之間的相似程度,其值越大,特征越相似,計算公式如下:
1.3 構建增量訓練數(shù)據(jù)集
由于相似度較高的兩個任務,在進行網(wǎng)絡訓練時,相同部分特征已經(jīng)被提取到了,所以對于相似度較高的任務,新舊任務越相似,則越應減少舊任務的重放訓練樣本數(shù)量,減少重復訓練造成的資源浪費;反之,則應增加舊類的數(shù)量,強化舊類知識,減少網(wǎng)絡對于新類的偏向。根據(jù)新舊任務之間的相似度,令每批次重放舊任務的樣本增量為[Doldk?j],其計算公式如下:
1.4 蒸餾損失和分類損失計算
蒸餾損失最早在文獻[7]中提出,在增量學習中適用于文獻[4,6,8],主要用來促使新的模型和舊的模型在舊類上保持相同的預測能力。增量學習損失包括蒸餾損失[LdΘi;Θi-1;x]和衡量分類準確度的交叉熵損失[LcΘi;x]之和,兩者的計算公式分別如下:
1.5 增量優(yōu)化訓練
通過分析不同任務之間的相似性差異,在新任務數(shù)據(jù)流式到達時設置不同比例的新舊數(shù)據(jù)進行增量優(yōu)化訓練,整個的訓練流程總結如下:
算法1 增量優(yōu)化訓練
輸入 1個初始任務(2個類別的分類任務)的數(shù)據(jù)集[D0],N個新增任務(一個類別表示一個任務)的流式數(shù)據(jù)集[Di,i∈N]
輸出 N+1個任務(N+2個類別)的分類性能
(1) 用數(shù)據(jù)[D0]訓練得到網(wǎng)絡模型[Θ0]
(2) 新任務到達,[Di1=500],[Di?j=500],有s個舊類(s的初始值為2)
(3) 新舊類之間進行相似度差異分析,用公式(1)計算新類與每個舊類的余弦相似度[sφXold,φXnew]
(4) 根據(jù)相似度差異結果,用公式(2)計算舊類每批次投放的樣本增量[Doldk?j,k∈s]
(5) 用公式(3)構建第j個批次的訓練數(shù)據(jù)
(6) 進行增量訓練
(7) if各個類別的分類性能達到預期 //測試網(wǎng)絡分類性能
(8) then if 還有未完成的任務 then 返回步驟(2) //繼續(xù)訓練下一個增量任務
(9) else 輸出N+1個任務(N+2個類別)的分類性能 //已經(jīng)完成N+1個任務的增量學習
(10) end if
(11) else then 返回步驟(5) //任務分類準確率沒有達到要求,繼續(xù)訓練
(12) end if
2 實驗研究
選取MNIST數(shù)據(jù)集中的數(shù)字0、1、2在三層卷積神經(jīng)網(wǎng)絡上進行增量學習,以數(shù)字0和1作為初始階段,數(shù)字2為新增類別階段。實驗結果如表1所示。
由表1可知,采用本文方法進行增量學習,在第6批次時的平均準確率為0.9818,比重放全部舊數(shù)據(jù)的準確率0.99稍小,但訓練數(shù)據(jù)量急劇下降,由5923+6741個舊樣本變?yōu)?0+66,顯著提升了訓練效率。以此類推依次完成數(shù)字3-9的增量學習,對比結果如圖1所示。
圖1中折線圖的橫坐標為增量學習的各個階段,縱坐標為平均分類精度,圖中結果表明相較于使用全部的新舊類訓練數(shù)據(jù),使用新的基于任務相似度的增量學習優(yōu)化方法雖然在分類精度上有所下降,但是結果相差不大,能有效緩解災難性遺忘的影響,且所使用的訓練數(shù)據(jù)集要遠小于使用全部的訓練集,減少了訓練量,加快了訓練速度。
3 結論
針對增量學習中的災難性遺忘問題,提出了一種基于新舊任務相似度的樣本重放學習方法,在盡量保持對舊任務記憶的同時著力提升學習效率,據(jù)此選用MINIST數(shù)據(jù)集進行實驗研究,驗證了該方法的可行性與有效性,為緩解災難性遺忘提供了新的解決思路。
參考文獻:
[1] McCloskey M,Cohen N J.Catastrophic interference in connectionist networks:the sequential learning problem[J].Psychology of Learning and Motivation,1989,24:109-165.
[2] Shin H, Lee J K, Kim J, et al. Continual learning with deep generative replay[C]. Advances in Neural Information Processing Systems. Curran Associates: New York, 2017:2991-3000.
[3] Wu C S, Herranz L, Liu X L, et al. Memory Replay GANs: learning to generate images from new categories without forgetting[C].Advances in Neural Information Processing Systems. Curran Associates: New York, 2018: 5962-5972.
[4] Rebuffi S A, Kolesnikov A, Sperl G, et al. iCaRL: Incremental Classifier and Representation Learning[C]. Proc of the IEEE Conf on Computer Vision and Pattern Recognition. Piscataway: IEEE Computer Society, 2017: 5533-5542.
[5] Castro F M, Marin-Jimenez M J, Guil N, et al. End-to-End Incremental Learning[C]. European Conference on Computer Vision. Berlin: Springer, 2018:233-248.
[6] Liu Y Y, Su Y , Liu A A , et al. Mnemonics Training: Multi-Class Incremental Learning Without Forgetting[C]. CVPR, 2020:12242-12251.
[7] Hinton G, Vinyals O, Dean J. Distilling the Knowledge in a Neural Network[J]. Computer Science, 2015, 14(7)38-39.
[8] Zenke F, Poole B, Ganguli S. Continual Learning Through Synaptic Intelligence[C].International Conference on Machine Lea rning. Lille: International Machine Learning Society, 2017:3987-3995.
【通聯(lián)編輯:唐一東】