余 鷹 危 偉 湯 洪 錢 進
(華東交通大學(xué)軟件學(xué)院 南昌 330013)
細粒度圖像識別(fine-grained image recognition,F(xiàn)GIR)是圖像分類的一個重要分支[1],旨在識別相同大類下的不同子類,如不同型號的飛機、不同款式的汽車等.相較于傳統(tǒng)的粗粒度圖像識別,細粒度圖像識別能夠捕捉對象之間細微的差異,在現(xiàn)實生活中有廣泛的需求和應(yīng)用場景,與之相關(guān)的研究任務(wù)包括生物多樣性檢測[2]、智慧交通[3]、智能零售[4]等.如果能夠借助計算機視覺技術(shù),實現(xiàn)低成本的細粒度圖像識別,那么無論對于學(xué)術(shù)界還是工業(yè)界而言,都有著非常重要的意義.然而由于細粒度圖像識別固有的類內(nèi)差異大、類間差異小的特性,以及強監(jiān)督標(biāo)簽獲取困難等問題,使得其成為一項極具挑戰(zhàn)性的任務(wù).
傳統(tǒng)的細粒度圖像識別方法主要使用人工標(biāo)注的圖像標(biāo)簽、對象標(biāo)注框、局部區(qū)域位置等強監(jiān)督信息來輔助模型進行特征學(xué)習(xí).例如Zhang等人[5]提出了一個基于候選區(qū)域(region proposals)的細粒度圖像識別模型,該模型在訓(xùn)練時借助除圖像標(biāo)簽外的對象邊界框、區(qū)域標(biāo)注等額外標(biāo)注信息聚焦局部關(guān)鍵區(qū)域,從而降低背景干擾、提升識別精度.然而,大量精細且類別齊全的人工標(biāo)注信息在現(xiàn)實場景中往往難以獲取.為了避免大量依賴人工標(biāo)注信息,研究人員開始關(guān)注弱監(jiān)督細粒度圖像識別模型,僅利用圖像類別標(biāo)簽來完成模型訓(xùn)練.為了提升捕獲細微差異的能力,弱監(jiān)督細粒度圖像識別模型通常會借助注意力機制、聚類等手段定位最具判別力的關(guān)鍵區(qū)域,從而學(xué)習(xí)到更具區(qū)分性的視覺特征,目前已經(jīng)成為細粒度圖像識別領(lǐng)域的主流方法.
良好的特征表達是細粒度圖像識別的基礎(chǔ),它對模型的泛化性能有著至關(guān)重要的影響.近年來,隨著深度學(xué)習(xí)的興起,深度神經(jīng)網(wǎng)絡(luò)強大的特征自學(xué)習(xí)能力受到研究人員的廣泛關(guān)注.早期,研究人員主要基于卷積神經(jīng)網(wǎng)絡(luò)(convolutional neural network,CNN)提取細粒度圖像特征.例如Ding等人[6]提出了一種注意力金字塔卷積神經(jīng)網(wǎng)絡(luò)(attention pyramid convolutional neural network,AP-CNN)用于弱監(jiān)督細粒度圖像分類.該模型對特征金字塔網(wǎng)絡(luò)(feature pyramid network,F(xiàn)PN)的每層特征圖(feature map)都添加空間注意力和通道注意力,形成了一個自下而上的注意力層級結(jié)構(gòu),從而可以學(xué)習(xí)到增強的高層語義和低層細節(jié)表示.但是,CNN在捕獲全局特征方面具有一定的局限性,因此研究人員將具有良好全局關(guān)系建模能力的Transformer[7]引入計算機視覺領(lǐng)域,以便更好地提取圖像全局特征.例如Vison Transformer(ViT)[8]將Transformer結(jié)構(gòu)應(yīng)用在圖像分類任務(wù)中,利用級聯(lián)自注意力模塊提取長距離的特征依賴,獲得了良好的分類效果.但是細粒度圖像識別需要捕獲局部細微差異,而ViT模型容易忽略局部細節(jié)特征,因此無法直接在細粒度圖像識別任務(wù)上發(fā)揮優(yōu)勢.為了彌補ViT模型無法提取多粒度特征的不足,研究人員常采用基于定位的方法進行細粒度圖像識別.首先以強監(jiān)督或弱監(jiān)督的方式定位有區(qū)分度的局部區(qū)域,然后從這些區(qū)域中提取局部細粒度特征,并將全局和局部特征融合后用于分類.例如,He等人[9]基于ViT模型設(shè)計了弱監(jiān)督細粒度圖像識別模型TransFG,在沒有bounding box標(biāo)注的情況下通過多頭自注意力有效定位判別區(qū)域,提升網(wǎng)絡(luò)捕捉微小差異的能力,并且利用對比學(xué)習(xí)最大化相同類別表示的相似度,從而提升模型的精度.然而TransFG忽略了判別區(qū)域之間的相互依賴關(guān)系以及組成的整體對象結(jié)構(gòu),而這些信息對于模型鑒別性信息的定位和理解是至關(guān)重要的.為此Sun等人[10]提出了SIM-Trans模型,引入結(jié)構(gòu)信息學(xué)習(xí)模塊來挖掘?qū)ο蠓秶鷥?nèi)重要Patch的空間上下文關(guān)系,從而加強外觀信息和結(jié)構(gòu)信息的鑒別性表示學(xué)習(xí).此外,為了解決ViT潛在的問題,Zhang等人[11]提出了自適應(yīng)注意力多尺度融合Transformer(AFTrans),它也可以在沒有bounding box標(biāo)注的情況下通過選擇性注意力收集模塊SACM來定位局部差異性區(qū)域,從而提取更加魯棒的特征.總之,提取既包含全局信息,又包含更具判別性的局部信息的不同粒度特征是提升細粒度圖像識別精度的關(guān)鍵.
此外,現(xiàn)有的細粒度圖像識別方法往往沒有考慮網(wǎng)絡(luò)深層語義知識對淺層網(wǎng)絡(luò)知識學(xué)習(xí)過程的指導(dǎo)作用,如Yu等人[12]提出了可信多粒度信息融合模型TMGIF用于細粒度圖像識別,該模型獲取圖像的多粒度信息表示,并通過對多粒度信息的質(zhì)量評價進行逐步融合.其中多粒度信息可以被看作是從不同網(wǎng)絡(luò)層次抽取的廣義多層次信息,但多粒度信息之間相互獨立,高層次粒度信息無法有效指導(dǎo)低層次粒度信息的學(xué)習(xí),無法有效緩解細粒度圖像識別中可能出現(xiàn)的背景干擾問題.換言之,如果通過深層網(wǎng)絡(luò)所學(xué)到的特征引導(dǎo)淺層網(wǎng)絡(luò)學(xué)習(xí)更加高級的語義知識,不但可以緩解背景干擾的問題,也可以促進深層網(wǎng)絡(luò)提取到更加魯棒的特征,從而更好地完成特征的學(xué)習(xí).知識自蒸餾是一個理想的解決方案,它通過將深層語義信息壓縮到淺層,提升了模型的識別準確率.傳統(tǒng)的知識自蒸餾方法對于傳統(tǒng)圖像的識別是有效的,這主要是因為傳統(tǒng)圖像識別類內(nèi)差異較小、類間差異較大,這讓模型在反向傳播過程中樣本空間上更容易收斂.而細粒度圖像識別具有類內(nèi)差異大、類間差異小的特點,外觀相似圖像可能會有不同的類別標(biāo)簽,導(dǎo)致了模型優(yōu)化存在困難.如果直接將傳統(tǒng)的知識自蒸餾應(yīng)用于細粒度識別,模型性能提升有限.因此,需要針對細粒度圖像識別的特點,設(shè)計一種適用于該任務(wù)的知識自蒸餾方法.此外,也需要考慮到不應(yīng)過多增加額外的模塊,否則會增加網(wǎng)絡(luò)復(fù)雜度,從而不利于模型在現(xiàn)實應(yīng)用中的推廣.
針對上述問題,本文提出了基于弱監(jiān)督信息的多層次知識自蒸餾聯(lián)合多步驟訓(xùn)練的細粒度圖像識別(multi-level knowledge self-distillation with multistep training for fine-grained image recognition, MKSMT)模型.該模型使用多階段層次結(jié)構(gòu)設(shè)計的Swin Transformer[13]作為主干網(wǎng)絡(luò).相比于ViT,Swin Transformer使用了類似特征金字塔的層級結(jié)構(gòu),能夠更好地捕捉不同粒度特征,具有更強的局部特征和全局特征建模能力.MKSMT只在網(wǎng)絡(luò)內(nèi)部提取多層次知識,并采用多步驟訓(xùn)練和知識自蒸餾來優(yōu)化特征學(xué)習(xí)過程.具體而言,它會對Swin Transformer模型包含的4段(Stage)分步驟進行訓(xùn)練。第1步首先獨立地對淺層網(wǎng)絡(luò)(第3段及之前)進行特征學(xué)習(xí);第2步對所有段構(gòu)成的網(wǎng)絡(luò)進行再訓(xùn)練,并將高層知識遷移到淺層網(wǎng)絡(luò)中以進行知識自蒸餾,而優(yōu)化后的淺層網(wǎng)絡(luò)又能進一步幫助深層網(wǎng)絡(luò)更好地提取魯棒的特征.這種分步驟訓(xùn)練的做法可以帶來2個好處:一是分步驟后,每一步中對模型的約束比單步驟訓(xùn)練更少,能夠保證知識自蒸餾的過程順利完成;二是在特征學(xué)習(xí)時,網(wǎng)絡(luò)只需關(guān)心當(dāng)前層次,不需要考慮深層網(wǎng)絡(luò)和淺層網(wǎng)絡(luò)的交互,這使得模型在反向傳播過程中在樣本空間上更容易收斂.
總之,本文的貢獻有3個方面:
1)提出了一種多層次知識自蒸餾聯(lián)合多步驟訓(xùn)練方法用于細粒度圖像識別任務(wù).它通過將特征學(xué)習(xí)和知識自蒸餾分步驟進行,解決了單步驟訓(xùn)練時網(wǎng)絡(luò)優(yōu)化困難的問題.
2)多層次知識自蒸餾聯(lián)合多步驟訓(xùn)練方法只在訓(xùn)練階段使用,且MKSMT沒有在主干網(wǎng)絡(luò)基礎(chǔ)上過多增加模塊.這表明在沒有過多增加模型復(fù)雜度和推理時間的前提下,細粒度圖像識別的準確率得到了有效提升,因此在現(xiàn)實場景中更容易得到推廣.另外,該方法具有通用性,理論上可以被動態(tài)集成在各種計算機視覺任務(wù)中,例如人群計數(shù).
3)MKSMT在3個極具挑戰(zhàn)性的數(shù)據(jù)集:CUB-200-2011、NA-Birds和Stanford Dogs與其它先進的算法對比,取得了最優(yōu)性能或有競爭力的性能,消融實驗進一步證明了MKSMT的有效性.
谷歌于2017年提出了Transformer[7]模型,并在自然語言處理(natural language processing,NLP)任務(wù)中推廣運用.相較于傳統(tǒng)循環(huán)神經(jīng)網(wǎng)絡(luò)(recurrent neural network,RNN),Transformer通過自注意力機制實現(xiàn)雙向語義編碼,具有更強大的全局信息建模能力.ViT[8]是Transformer模型在計算機視覺領(lǐng)域的應(yīng)用,它通過自注意力實現(xiàn)像素間的全局交互,從而獲得良好的全局特征表示.然而,ViT架構(gòu)存在2個設(shè)計缺陷:1)ViT在輸入階段便對圖像進行了高倍率的下采樣,并對后續(xù)的特征圖維持同樣的高倍下采樣率,導(dǎo)致圖像部分細節(jié)信息和空間信息丟失,從而無法提取到足夠精細的特征;2)ViT對于所有的編碼層都輸入固定大小的特征塊,使得模型處理多尺度視覺對象的能力不足.
針對這2個問題,Liu等人[13]提出了Swin Transformer模型,它采用類似CNN中特征金字塔的層次化架構(gòu),在不同階段分別對特征圖進行逐級下采樣,從而得到了多尺度的特征表示.為了降低計算復(fù)雜度和提高長程依賴關(guān)系的捕獲能力,Swin Transformer又提出了基于移動窗口的自注意力機制.具體來說,設(shè)計了一個不重疊的滑動窗口,在每個階段,自注意力的計算只在滑動窗口中進行,然后設(shè)計平移窗口機制,使得不重疊的窗口之間可以進行交互,加強了跨窗口之間的關(guān)聯(lián)性,從而實現(xiàn)對局部特征和全局特征的有效建模.Swin Transformer目前共有4個版本,分別為Swin-T、Swin-S、Swin-B和Swin-L,本文采用Swin-L作為基準模型.
知識蒸餾(knowledge distillation, KD)是2015年由Hinton等人[14]提出的模型壓縮方法,它是一種基于教師-學(xué)生網(wǎng)絡(luò)的訓(xùn)練方式,易于實現(xiàn)且簡單有效,因此迅速在工業(yè)界得到廣泛應(yīng)用[15].以圖像識別為例,知識蒸餾的一些關(guān)鍵要素包括教師模型、學(xué)生模型、經(jīng)過Softmax函數(shù)的映射值、蒸餾溫度.其中教師模型是一個參數(shù)量較大的模型,學(xué)生模型是一個參數(shù)量較小的模型,要求這2個模型在給定任意輸入后得到的輸出經(jīng)過Softmax函數(shù)映射能夠得到各類別的概率值.蒸餾溫度可以調(diào)節(jié)概率分布的熵,讓模型更加關(guān)注負標(biāo)簽值的變化.添加蒸餾溫度后的Softmax函數(shù)的數(shù)學(xué)表達如式(1)所示.
其中,zi表示模型關(guān)于第i個類別的輸出值,T表示蒸餾溫度,pi表示經(jīng)過映射后的第i個類別概率值.根據(jù)式(1)計算得到教師和學(xué)生的概率值pi后,就可以計算2個概率分布的KL散度來衡量二者之間的差異,為了更加準確詳細說明其計算過程,以一個具有N個樣本的C類分類問題為例,其數(shù)學(xué)表達為
其中,pt表示教師網(wǎng)絡(luò)概率分布,ps表示學(xué)生網(wǎng)絡(luò)概率分布.在得到KL散度后計算學(xué)生網(wǎng)絡(luò)的交叉熵損失值,兩者聯(lián)合共同訓(xùn)練學(xué)生網(wǎng)絡(luò).
然而,在傳統(tǒng)的知識蒸餾中,知識是單項轉(zhuǎn)移的,這在很大程度上需要保證教師網(wǎng)絡(luò)足夠好.深度互學(xué)習(xí)(deep mutual learning, DML)[16]嘗試了一種新的思路,讓2個學(xué)生網(wǎng)絡(luò)在訓(xùn)練階段協(xié)同學(xué)習(xí),相互轉(zhuǎn)移知識.除了在最后輸出層進行互學(xué)習(xí)外,還可以在網(wǎng)絡(luò)中間特征層進行.這種策略具有普適性,而且對于模型大小沒有限制.但DML和傳統(tǒng)的知識蒸餾一樣,仍舊需要2個網(wǎng)絡(luò)來工作.
知識自蒸餾[17]打破了傳統(tǒng)知識蒸餾和DML需要2個網(wǎng)絡(luò)的規(guī)則,它只在網(wǎng)絡(luò)內(nèi)部的不同層次間進行知識蒸餾,以深層網(wǎng)絡(luò)為教師,將知識轉(zhuǎn)移到淺層網(wǎng)絡(luò)上,在相同甚至更低計算量情況下,也可以獲得和傳統(tǒng)知識蒸餾一樣的效果甚至超越傳統(tǒng)知識蒸餾方法.
從上述分析可以看出,知識自蒸餾在實際使用中限制最小,它能夠充分挖掘網(wǎng)絡(luò)內(nèi)部不同層次的知識.因此,本文將自蒸餾技術(shù)應(yīng)用于細粒度圖像識別中,但本文采用知識自蒸餾的目的并不是為了壓縮網(wǎng)絡(luò),而為了使網(wǎng)絡(luò)能夠提取到更加魯棒的特征.
MKSMT的網(wǎng)絡(luò)結(jié)構(gòu)如圖1所示,它將層次化設(shè)計的Swin Transformer作為主干網(wǎng)絡(luò),共由4段(Stage)組成,即S1、S2、S3和S4。每個段逐步縮小特征圖的分辨率,像CNN一樣逐層擴大感受野,從而提取出不同尺度特征.
Fig.1 The model architecture of MKSMT圖1 MKSMT模型架構(gòu)
從圖1可見,MKSMT的主干網(wǎng)絡(luò)保留了Swin-L的原始架構(gòu).Patch Partition模塊將輸入圖像中每N×N個相鄰像素劃分成一個塊,然后將每個Patch在通道方向上展平.對于一張維度為H×W×3的圖像,其中,H為圖像的高度,W為圖像的寬度,可以劃分成個Patch,每個Patch在通道方向上展平后的維度為N×N×3.
在S1中,首先由Patch Embedding模塊對劃分后Patch的特征維度進行線性變換,即由原來的維映射為維,C是特征圖的通道維度.然后,輸入堆疊的Swin Transformer Block.
S2,S3,S4的操作相同,均是先輸入Patch Merging模塊進行下采樣,然后輸入重復(fù)堆疊的Swin Transformer Block.Patch Merging模塊將特征圖的高度和寬度減半,同時將深度翻倍,因此S2、S3和S4輸出的特征圖維度分別是8C.
同時,MKSMT集成了多層次知識自蒸餾模塊以便提取更加魯棒的特征,并采用多步驟訓(xùn)練的方法進行模型優(yōu)化.需要特別說明的是,知識自蒸餾模塊僅在訓(xùn)練階段輔助主干網(wǎng)絡(luò)學(xué)習(xí)模型參數(shù),并不會參與預(yù)測階段的工作.MKSMT模型的訓(xùn)練過程分2步完成:第1步先訓(xùn)練S3之前的部分網(wǎng)絡(luò),第2步再訓(xùn)練整個網(wǎng)絡(luò).在第1步訓(xùn)練中,S3的輸出被送入輸出處理模塊(output processing module, OPM),該模塊結(jié)構(gòu)如圖2所示.OPM可以將主干網(wǎng)絡(luò)輸出的不同維度特征歸一化映射為相同維度的特征向量,以便輸出層(output layer)處理.例如,對于維度為4C的數(shù)據(jù),經(jīng)過OPM模塊處理后得到的數(shù)據(jù)維度為K.然后,可以根據(jù)輸出層輸出的分類結(jié)果計算交叉熵損失,并進行反向傳播優(yōu)化模型.在第2步訓(xùn)練中,S3和S4的輸出會同步輸入不同的OPM模塊、輸出層Softmax層,然后S4對應(yīng)的輸出層的輸出將用于計算交叉熵損失,而2個Softmax層的輸出將用于計算知識自蒸餾損失,并用于反向傳播更新網(wǎng)絡(luò)參數(shù).
Fig.2 Structure of OPM圖2 OPM結(jié)構(gòu)
MKSMT采用知識自蒸餾的目的是為了將深層次的知識壓縮到淺層網(wǎng)絡(luò),幫助淺層網(wǎng)絡(luò)更好地捕捉圖像的細節(jié)信息,從而更新淺層網(wǎng)絡(luò)已學(xué)習(xí)到的特征,進而在下一批次數(shù)據(jù)訓(xùn)練時,淺層網(wǎng)絡(luò)將遞進反饋于深層網(wǎng)絡(luò),提高網(wǎng)絡(luò)整體對于目標(biāo)識別的準確率.Zhang等人[17]經(jīng)過實驗證明采用知識自蒸餾訓(xùn)練的模型對參數(shù)擾動更具魯棒性,而在沒有自蒸餾情況下訓(xùn)練的模型對高斯噪聲更加敏感.總之,模型經(jīng)過自蒸餾,可以幫助網(wǎng)絡(luò)更好地理解圖像潛在的細節(jié)信息,從而提高特征的魯棒性和表達能力,最終使得模型的準確率得到提升.
從圖1可見,知識自蒸餾只在訓(xùn)練的第2步使用.首先,將輸入圖像輸入主干網(wǎng)絡(luò)后,將分別從S3和S4得到原始特征和,其中上標(biāo)表示特征來源,下標(biāo)表示多步驟訓(xùn)練的步驟編號.然后,將原始特征和分別輸入不同的OPM模塊和輸出層,得到和,隨后與輸出層相連的Softmax層輸出和.S4對應(yīng)的輸出層結(jié)果將用于計算交叉熵損失值.值得一提的是,這里不再根據(jù)S3輸出層結(jié)果計算相應(yīng)的交叉熵損失值.2個Softmax層的輸出結(jié)果和將用于計算知識自蒸餾損失值,最后進行反向傳播優(yōu)化模型.完整的多層次知識自蒸餾算法流程如算法1所示.
算法1.多層次知識自蒸餾算法
輸入:細粒度圖像訓(xùn)練集;
輸出:根據(jù)交叉熵損失和知識自蒸餾損失優(yōu)化的模型.
① 輸入細粒度圖像,分別從S3和S4獲得原始特征和;
⑥ 根據(jù)2種損失反向傳播更新網(wǎng)絡(luò)參數(shù).
網(wǎng)絡(luò)架構(gòu)由若干個階段組成是進行多步驟訓(xùn)練的前提條件.在CNN網(wǎng)絡(luò)中,每個階段可由一組級聯(lián)的卷積層組成,然后在每個階段結(jié)束時輸出不同尺度的特征圖.而在Transformer架構(gòu)的網(wǎng)絡(luò)中,每個階段則一般由若干個Transformer Block堆疊而成.MKSMT采用Swin Transformer中的Swin-L作為主干網(wǎng)絡(luò),共劃分為4個階段(S1~S4),其中S1、S2和S4只包含2個堆疊的Swin Transformer Block,而S3則由18個Swin Transformer Block堆疊而成.因此,從直觀上來講,S3和S4具有更好的特征表達能力.為此,在進行多步驟訓(xùn)練時,步驟1只訓(xùn)練了S3之前的部分網(wǎng)絡(luò),將S3的輸出特征輸入到OPM模塊后經(jīng)過輸出層處理得到輸出,然后計算交叉熵損失;步驟2是對整個網(wǎng)絡(luò)進行訓(xùn)練,并添加了多層次知識自蒸餾,最終依據(jù)交叉熵損失和知識自蒸餾損失優(yōu)化模型.多步驟訓(xùn)練的算法流程如算法2所示.
算法2.多步驟訓(xùn)練算法
輸入:細粒度圖像;
輸出:根據(jù)多步驟訓(xùn)練更新后的網(wǎng)絡(luò).
① 輸入細粒度圖像,從S3獲得原始特征;
④ 步驟1結(jié)束;
⑤ 根據(jù)交叉熵損失反向傳播更新網(wǎng)絡(luò)參數(shù);
⑥ 輸入細粒度圖像,分別從S3和S4獲得原始特征和;
⑩ 根據(jù)式(2)計算自蒸餾的損失;
? 根據(jù)行⑧和行⑨的損失反向傳播更新網(wǎng)絡(luò)參數(shù);
? 步驟2結(jié)束.
由2.3節(jié)分析可知,MKSMT模型的訓(xùn)練過程由2個步驟構(gòu)成,每一步會計算一個損失,然后進行反向傳播更新網(wǎng)絡(luò)參數(shù).在步驟1中,會根據(jù)S3的輸出計算交叉熵損失.假設(shè)現(xiàn)有N個樣本用于分類,那么交叉熵損失函數(shù)的數(shù)學(xué)表達如式(3)所示.
在步驟2中,損失函數(shù)由2部分構(gòu)成,即S4輸出的交叉熵損失,以及S3和S4這2部分輸出Softmax結(jié)果的KL散度,分別由式(5)和式(6)求得,最終損失函數(shù)結(jié)果如式(7)所示.
其中,Ltotal表示步驟2的總損失,L4表示步驟2中根據(jù)S4輸出得到的交叉熵損失,p3和p4分別表示S3和S4輸出經(jīng)過添加蒸餾溫度后的Softmax的預(yù)測結(jié)果,LKL表示S3和S4輸出的KL散度,λ1、λ2為2部分損失函數(shù)的權(quán)重超參數(shù),在CUB數(shù)據(jù)集上都設(shè)為1.
為了驗證MKSMT模型的有效性,分別在CUB-200-2011[18](CUB)、NA-Birds[19](NAB)和Stanford Dogs[20](DOG)這3個基準數(shù)據(jù)集上進行了實驗.數(shù)據(jù)集詳細信息如表1所示,CUB數(shù)據(jù)集包含200種鳥類,是細粒度圖像識別任務(wù)中使用最廣泛的數(shù)據(jù)集之一.除提供圖像級標(biāo)簽外,該數(shù)據(jù)集還提供對象邊界框、關(guān)鍵區(qū)域標(biāo)注和文本描述等人工標(biāo)注信息.NAB數(shù)據(jù)集是一個具有高質(zhì)量標(biāo)注的大規(guī)模細粒度圖像分類數(shù)據(jù)集,共包含48 562張鳥類圖片,分別屬于550個子類.DOG數(shù)據(jù)集是一個由120種狗組成的大型數(shù)據(jù)集,其中的圖像主要來源于日常生活.
Table 1 Experimental Datasets表1 實驗數(shù)據(jù)集
所有實驗均使用PyTorch 1.6,在4×GTX 2080Ti GPU服務(wù)器上進行,對于輸入圖像的預(yù)處理如表2所示.
Table 2 Image Pre-processing Settings表2 圖像預(yù)處理設(shè)置
除了需要對圖像進行預(yù)處理外,主干網(wǎng)絡(luò)Swin-L也使用了在ImageNet-22K上的預(yù)訓(xùn)練參數(shù),MKSMT的其余超參數(shù)設(shè)置如表3所示.
Table 3 Hyperparameters of MKSMT表3 MKSMT的超參數(shù)
本節(jié)分別展示了MKSMT在3個數(shù)據(jù)集上與其它經(jīng)典模型的實驗對比結(jié)果,模型性能評價指標(biāo)采用準確率.參與比較的經(jīng)典模型有的采用基于CNN架構(gòu)的ResNet或EfficientNet作為主干網(wǎng)絡(luò),而有的則以基于Transformer架構(gòu)的ViT或Swin Transformer作為主干網(wǎng)絡(luò).
在CUB數(shù)據(jù)集上的實驗結(jié)果如表4所示,參與比較的模型來自2種不同網(wǎng)絡(luò)架構(gòu).由表4可見,以CNN為主干網(wǎng)絡(luò)的模型中準確率最高的是TMGIF,但只達到90.7%,分類性能遜色于以ViT為主干網(wǎng)絡(luò)的模型,這說明了CNN提取魯棒性特征能力不足.而MKSMT的性能優(yōu)于參與對比的算法,除了Swin-L本身具有強大的特征提取能力外,還得益于MKSMT能夠?qū)⑸顚哟尉W(wǎng)絡(luò)信息壓縮到淺層次網(wǎng)絡(luò)中,從而使得深層次網(wǎng)絡(luò)能夠提取到更加魯棒的特征,最終提升了算法的準確率.
Table 4 Comparison of Experimental Results on CUB Dataset表4 CUB數(shù)據(jù)集的實驗結(jié)果對比
在NAB數(shù)據(jù)集上的實驗結(jié)果如表5所示,參與對比的模型也采用了2種不同架構(gòu)的主干網(wǎng)絡(luò).從表5可見,以CNN為主干網(wǎng)絡(luò)的模型中準確率最高的是GDSMP-Net,但只達到89.0%,性能不如以ViT為主干網(wǎng)絡(luò)的模型.基于CUB和NAB數(shù)據(jù)集背景復(fù)雜、姿態(tài)多樣等特點,說明CNN提取魯棒性特征的能力不足.而MKSMT優(yōu)于對比算法,使用DenseNet-161為主干網(wǎng)絡(luò)的API-Net模型在精度上不僅更低,其網(wǎng)絡(luò)復(fù)雜度也高于本文提出的MKSMT.MKSMT之所以能獲得優(yōu)異的性能,除了上面提及的優(yōu)點外,還集成了多步驟訓(xùn)練方法,能夠逐步進行特征學(xué)習(xí),提取到更有效的特征,且不至于過高提升網(wǎng)絡(luò)復(fù)雜度.
Table 5 Comparison of Experimental Results on NAB Dataset表5 NAB數(shù)據(jù)集的實驗結(jié)果對比
在DOG數(shù)據(jù)集上的實驗結(jié)果如表6所示,參與對比的模型主干網(wǎng)絡(luò)也采用CNN或Transformer架構(gòu).從表6可見,以CNN架構(gòu)為主干網(wǎng)絡(luò)的算法中準確率最高的是PMG-V2,但只達到89.1%,不如以ViT為主干網(wǎng)絡(luò)的模型,說明了CNN提取魯棒特征能力的不足.基于Swin Transformer的MKSMT雖然其性能略遜色于基于ViT的模型,但優(yōu)于基于CNN的模型.總體來看,MKSMT在一眾網(wǎng)絡(luò)中仍有較強競爭力.另一方面,MKSMT相比于其它網(wǎng)絡(luò),可以在沒有過多增加模型復(fù)雜度的情況下,利用多層次知識自蒸餾和多步驟訓(xùn)練實現(xiàn)性能的提升,并且可以作為通用框架集成到其他方法上,這是其它方法無法做到的.
Table 6 Comparison of Experimental Results on DOG Dataset表6 DOG數(shù)據(jù)集的實驗結(jié)果對比
MKSMT主要改變的是訓(xùn)練策略和方法,目的是為了在提升模型準確率的同時不會增加模型的參數(shù)量.為了展示MKSMT的效率,表7對基準模型和MKSMT的計算復(fù)雜度和參數(shù)量做了總結(jié).可以觀察到,與Swin-L相比,MKSMT的計算量基本不變,參數(shù)量僅增加1.1%,總體準確率提升了1.0%,這也反映了模型的優(yōu)異性.
Table 7 Comparison of Model Complexity表7 模型復(fù)雜度對比
為了進一步說明模型提取的特征具有較強的魯棒性,本文將CUB數(shù)據(jù)集中的圖片分別進行水平和垂直翻轉(zhuǎn),再重新對基準模型和MKSMT進行測試.結(jié)果如表8所示,當(dāng)將原始圖片做水平翻轉(zhuǎn)時,2個模型輸出準確率均變化不大,但將圖片做垂直翻轉(zhuǎn)并輸入模型后,Swin-L的輸出準確率下降近11%,而MKSMT僅下降7%.由此可見,MKSMT能夠提取出更加魯棒的特征,以至于當(dāng)測試數(shù)據(jù)發(fā)生變化時,模型依然能得到較好的分類結(jié)果.
Table 8 Comparison of Model Robustness表8 模型魯棒性對比%
為了說明多層次知識自蒸餾和多步驟訓(xùn)練的有效性,本節(jié)在CUB數(shù)據(jù)集上進行了消融實驗.通過僅使用Swin-L網(wǎng)絡(luò),而不集成知識自蒸餾和多步驟訓(xùn)練方法,只保留1個輸出層,最終得到的識別準確率為91.9%.
在多步驟訓(xùn)練的有效性驗證中,首先利用單步驟訓(xùn)練網(wǎng)絡(luò),在反向傳播時網(wǎng)絡(luò)總損失一部分來自S3和S4對應(yīng)輸出層的交叉熵損失,一部分來自S3和S4后續(xù)Softmax層的蒸餾損失;然后3個步驟訓(xùn)練時,首先根據(jù)S2對應(yīng)的輸出層所計算的損失值更新一次網(wǎng)絡(luò)參數(shù);接著根據(jù)S3對應(yīng)的輸出層所計算的損失再更新一次網(wǎng)絡(luò)參數(shù);最后是S4對應(yīng)的輸出層所計算的損失,加上S2~S4對應(yīng)的Softmax層計算得到的知識自蒸餾損失,再更新一次網(wǎng)絡(luò)參數(shù),實驗結(jié)果如表9所示.
Table 9 Accuracy Comparison of the Multi-step Training表9 多步驟訓(xùn)練準確率對比%
從表9可見,單步驟中S3的準確率與兩步驟中S3的準確率一致,但單步驟S4的準確率比兩步驟S4的準確率低0.4個百分點.這主要是由于單步驟的優(yōu)化目標(biāo)過多,在樣本空間上不容易收斂,也就造成提取到的特征魯棒性不如兩步驟訓(xùn)練方法.采用三步驟訓(xùn)練時,各部分準確率都比同樣來源的更低,這是由于S2所在的網(wǎng)絡(luò)深度不夠,淺層網(wǎng)絡(luò)的表達能力不足,無法幫助深層網(wǎng)絡(luò)提取更加魯棒的特征,而在S2就注入梯度也破壞了后面S3和S4的特征提取過程,導(dǎo)致其準確率過低.因此,考慮到特征提取的有效性,兩步驟的訓(xùn)練策略是最好的.在兩步驟訓(xùn)練的步驟2中,包含了特征學(xué)習(xí)和知識自蒸餾,為了驗證這2個部分是同步執(zhí)行還是異步執(zhí)行更有效,本文做了實驗,實驗結(jié)果如表10所示.
Table 10 Comparison of Output Accuracy of the Different Execution Modes表10 不同執(zhí)行方式的輸出準確率對比%
從表10可見,當(dāng)特征學(xué)習(xí)和知識自蒸餾異步執(zhí)行時,淺層(S3)的輸出識別精度要比同步執(zhí)行時高0.1個百分點,但深層(S4)的輸出識別精度要比同步執(zhí)行時低0.3個百分點,這說明了異步執(zhí)行方式更有利于提升淺層網(wǎng)絡(luò)的特征學(xué)習(xí)能力,而同步執(zhí)行方式則有助于深層網(wǎng)絡(luò)提取到更魯棒的特征.
在知識自蒸餾的有效性驗證中,首先移除知識自蒸餾模塊,僅保留多步驟訓(xùn)練.此時,S3的最終輸出精度為90.9%,S4的最終輸出精度為92.1%,均低于添加了知識自蒸餾模塊時的性能.除了在最終的Softmax層進行知識蒸餾(決策蒸餾),對于OPM的輸出特征也進行了知識蒸餾(特征蒸餾),實驗結(jié)果如表11所示.
Table 11 Comparison of Final Output Accuracy of Output Layer and Feature Distillation表11 輸出層和特征蒸餾最終輸出準確率對比%
從表11可以看出,決策蒸餾在淺層和深層輸出的識別精度都高于特征蒸餾.雖然淺層的輸出精度兩者相差不大,但對于深層而言,特征蒸餾的輸出精度比決策蒸餾低了1.2個百分點,甚至不如基準模型的91.9%,這說明了特征蒸餾對于深層網(wǎng)絡(luò)提取強魯棒性特征是無效的,甚至?xí)鸱醋饔?其原因在于,深層和淺層提取了不同層次的特征,強行增加不同尺度特征之間的相似性并不利于深層網(wǎng)絡(luò)提取魯棒性特征.結(jié)合之前的實驗,也說明了知識自蒸餾是多步驟訓(xùn)練的有機結(jié)合.
蒸餾溫度的設(shè)定也會對結(jié)果產(chǎn)生一定的影響,本節(jié)的蒸餾溫度取值分別為1,2,3,4,7,10,根據(jù)不同蒸餾溫度進行實驗得到的結(jié)果如圖3所示,當(dāng)蒸餾溫度為1時,準確率達到了92.8%,效果是最好的.隨著蒸餾溫度的提升,準確率呈下降趨勢.說明對于模型而言,那些明顯低于平均值的負標(biāo)簽無需過多關(guān)注,負標(biāo)簽中的信息量偏少,蒸餾溫度越低,受負標(biāo)簽的影響會更小.
Fig.3 Effect of distillation temperature on accuracy圖3 蒸餾溫度對準確率的影響
為了更直觀地對比模型改進前后的效果,采用類激活可視化方法Grad-CAM[39]對基準模型和MKSMT模型的熱力圖進行可視化.首先在3個數(shù)據(jù)集的測試集中隨機各選取了一張圖片,然后提取每張圖片在基準模型和MKSMT的S2、S3和S4輸出的特征圖,并以可視化熱力圖的方式展現(xiàn)可判別區(qū)域,可視化結(jié)果如圖4所示,在基準模型的熱力圖中,判別區(qū)域更加分散和零碎.對于復(fù)雜圖片,基準模型更關(guān)注紛雜的背景,在S2輸出的熱力圖尤為明顯;而MKSMT的熱力圖效果有所改善,判別特征更加集中.對比可以看出,經(jīng)過模型的知識自蒸餾,S2的熱力圖能更早地將特征集中在對應(yīng)目標(biāo)上,S3的熱力圖特征區(qū)別于基準模型對應(yīng)層而更加接近S4的熱力圖.這表明,MKSMT能通過深層網(wǎng)絡(luò)所學(xué)到的特征有效引導(dǎo)淺層網(wǎng)絡(luò)學(xué)習(xí)更加高級的語義知識,從而促進深層網(wǎng)絡(luò)提取到更魯棒的特征.
Fig.4 Comparison of heat maps generated by the baseline models and MKSMT圖4 基準模型與MKSMT生成的熱度圖對比
細粒度圖像識別由于其固有的類內(nèi)差異大、類間差異小的特點,已經(jīng)成為計算機視覺領(lǐng)域一個非常具有挑戰(zhàn)性的任務(wù).為了讓網(wǎng)絡(luò)提取更加魯棒的特征,同時又不過度增加模型復(fù)雜度,本文提出了多層次知識自蒸餾聯(lián)合多步驟訓(xùn)練的細粒度圖像識別模型MKSMT.MKSMT能在一個網(wǎng)絡(luò)內(nèi)部提取多層次知識進行知識自蒸餾,達到了層次知識之間交互的目的,并通過多步驟訓(xùn)練的方式來更好地完成特征學(xué)習(xí),最終使網(wǎng)絡(luò)提取到更加魯棒的特征,并且在沒有過多增加模型復(fù)雜度的前提下實現(xiàn)模型性能的提升.
實驗結(jié)果表明,MKSMT在CUB、NAB和DOG數(shù)據(jù)集上的準確率優(yōu)于大部分主流模型,且從理論上來說MKSMT所提出來的多層次知識自蒸餾聯(lián)合多步驟訓(xùn)練的思想具有通用性,能集成到其他任務(wù)中.
未來研究方向主要包括:1)將本文思想擴展到其他具有較大應(yīng)用價值的任務(wù)中,如人群計數(shù)[40].2)在實際場景中,常面臨采集的各類別樣本分布不均的問題,從而在訓(xùn)練模型時,由于數(shù)據(jù)集樣本存在長尾分布,導(dǎo)致模型準確率不佳.最近,Adversarial AutoAugment[41]、RandAugment[42]等工作證明了增強樣本數(shù)據(jù)能夠有效提高模型的性能,因此,下一步研究將結(jié)合細粒度圖像識別的特點,從數(shù)據(jù)增強的角度開展研究工作.
作者貢獻聲明:余鷹負責(zé)論文整體設(shè)計和實驗方案的制定,并參與了論文的撰寫和修訂;危偉負責(zé)文獻整理,撰寫部分論文和完成實驗;湯洪負責(zé)算法的開發(fā)和實驗評估,改進和優(yōu)化網(wǎng)絡(luò)方法,并撰寫論文;錢進指導(dǎo)論文撰寫和參與論文的修訂.