摘 要:為了提高模型在長尾視覺識(shí)別領(lǐng)域的性能,文章提出了一種多分類器分級(jí)蒸餾框架,該框架包括旋轉(zhuǎn)自監(jiān)督預(yù)訓(xùn)練和多分類器蒸餾。旋轉(zhuǎn)自監(jiān)督預(yù)訓(xùn)練通過預(yù)測圖像旋轉(zhuǎn),平等地考慮每一張圖像,減少模型受到長尾標(biāo)簽的影響。多分類器蒸餾通過三個(gè)專門優(yōu)化的分類器將教師模型的知識(shí)一一對應(yīng)蒸餾到學(xué)生模型。在開源的長尾圖像識(shí)別數(shù)據(jù)集上進(jìn)行了充分實(shí)驗(yàn),并與現(xiàn)有方法進(jìn)行了比較。實(shí)驗(yàn)結(jié)果表明,所提出的方法在長尾圖像視覺識(shí)別方面取得了一定的提升。
關(guān)鍵詞:知識(shí)蒸餾;長尾分布;圖像識(shí)別;深度學(xué)習(xí)模型
中圖分類號(hào):TP183;TP391.4 文獻(xiàn)標(biāo)識(shí)碼:A 文章編號(hào):2096-4706(2024)16-0049-05
Long-tailed Visual Recognition Method Based on Multi-classifier Graded Distillation
Abstract: In order to enhance model performance in the long-tailed visual recognition domain, this paper proposes a multi-classifier graded distillation framework. The framework comprises rotation self-supervised pre-training and multi-classifier distillation. Rotation self-supervised pre-training treats each image equally by predicting image rotations, and mitigates the impact of long-tailed labels on the model. Multi-classifier systematically distills the knowledge from the teacher model to the student model through three specifically optimized classifiers. Extensive experiment results are conducted on open-source long-tailed image recognition datasets, and comparisons are made with existing methods. The experimental results demonstrate that the proposed method achieves notable improvements in long-tailed image visual recognition.
Keywords: knowledge distillation; long-tailed distribution; image recognition; Deep Learning model
0 引 言
不平衡數(shù)據(jù)在現(xiàn)實(shí)世界中是普遍存在的,大規(guī)模的數(shù)據(jù)集往往以一種長尾分布的形式呈現(xiàn)[1]。尤其在安全或健康相關(guān)方面的應(yīng)用,如自動(dòng)駕駛和醫(yī)療診斷[2],數(shù)據(jù)本質(zhì)上是嚴(yán)重失衡的。盡管現(xiàn)代深度學(xué)習(xí)和機(jī)器學(xué)習(xí)技術(shù)已經(jīng)在不同的任務(wù)集上取得了令人印象深刻的成果,但大多數(shù)模型在面對非常罕見或長尾樣本的不均勻數(shù)據(jù)分布時(shí)仍會(huì)遇到困難。如何從這種不平衡數(shù)據(jù)集中獲取到有用信息已然成為當(dāng)下研究的熱點(diǎn)。
處理不平衡數(shù)據(jù)的一個(gè)經(jīng)典方法是數(shù)據(jù)重采樣方法[3-4],這樣做的目的是為了讓類別分布更加平衡,該方法包括對多數(shù)類別的降采樣和對少數(shù)類別的過采樣,但是重采樣技術(shù)在深度學(xué)習(xí)中會(huì)產(chǎn)生一系列的問題,例如過采樣會(huì)導(dǎo)致模型的過擬合,而降采樣會(huì)限制神經(jīng)網(wǎng)絡(luò)的泛化能力。另一種常用的方法是重加權(quán)方法[5-6],該種方法是作用于損失函數(shù)上,對不同類別或不同樣本對應(yīng)的分類損失項(xiàng)賦予不同的權(quán)重。然而,這些方法都犧牲了多數(shù)類的準(zhǔn)確性來補(bǔ)償少數(shù)類。
最近的研究表明,將長尾分類解耦為兩個(gè)階段:表征學(xué)習(xí)和分類器學(xué)習(xí),是一種良好的處理數(shù)據(jù)不平衡的方法[7-8]。Kang等人[9]通過自然(實(shí)例平衡)采樣學(xué)習(xí)高質(zhì)量的表示,并通過類平衡采樣調(diào)整分類器實(shí)現(xiàn)強(qiáng)大的分類性能。Zhou等人[10]提出了一個(gè)雙邊分支網(wǎng)絡(luò),其中一個(gè)分支使用實(shí)例平衡采樣訓(xùn)練,另一個(gè)分支使用類平衡采樣,得出了類似的結(jié)論。解耦學(xué)習(xí)思想被廣泛采用,Cao等人[11]提出了不同的分類器調(diào)整方法,通過調(diào)整Logit鼓勵(lì)增大少數(shù)類與多數(shù)類之間的相對差距。通過向輸出層添加額外的可學(xué)習(xí)層來修改原始Logit。然而上述方法沒有考慮到少數(shù)類別的未被充分代表的特征。
總之,現(xiàn)有的方法要么缺乏一種能學(xué)習(xí)到良好表征的機(jī)制,要么過于復(fù)雜,缺乏很好的泛化性。針對上述方法存在的問題,本文提出一種基于多分類器的知識(shí)蒸餾方法,首先,考慮到直接在不平衡數(shù)據(jù)集中以監(jiān)督學(xué)習(xí)的方式訓(xùn)練一個(gè)網(wǎng)絡(luò)會(huì)產(chǎn)生較差的性能,其原因是不平衡數(shù)據(jù)集的標(biāo)簽信息會(huì)帶來“偏見”,這種偏見使模型不能學(xué)習(xí)到很好的表征,于是我們在進(jìn)行知識(shí)蒸餾之前先對學(xué)生網(wǎng)絡(luò)進(jìn)行自監(jiān)督的預(yù)訓(xùn)練,目的是使學(xué)生網(wǎng)絡(luò)在不平衡數(shù)據(jù)集中學(xué)習(xí)到更好的初始化,一旦網(wǎng)絡(luò)經(jīng)過自我監(jiān)督預(yù)訓(xùn)練產(chǎn)生了良好的初始化,網(wǎng)絡(luò)就可以從訓(xùn)練前的任務(wù)中受益,并最終學(xué)習(xí)到更好的表示。其次,通過知識(shí)蒸餾技術(shù),使用分級(jí)蒸餾損失將教師網(wǎng)絡(luò)中所包含頭部、中部、尾部類的知識(shí)盡可能多地傳遞給學(xué)生網(wǎng)絡(luò)。我們在幾個(gè)長尾基準(zhǔn)數(shù)據(jù)集上進(jìn)行了大量實(shí)驗(yàn),證明了所提出的方法是長尾學(xué)習(xí)場景中有效的學(xué)習(xí)方法。
1 相關(guān)概念
1.1 知識(shí)蒸餾
知識(shí)蒸餾(Knowledge Distillation, KD)是一種將知識(shí)從大的教師模型轉(zhuǎn)移到小的學(xué)生模型的模型壓縮技術(shù),自誕生以來就受到了廣泛關(guān)注。Hinton等人[12]提出將知識(shí)從教師模型的預(yù)測概率分布中提取到學(xué)生模型中,稱為基于Logit的知識(shí)蒸餾。知識(shí)蒸餾引入軟標(biāo)簽,即帶有參數(shù)τ的Softmax函數(shù),以此來軟化概率分布,使概率分布攜帶更多的有用信息,如式(1)所示:
其中,pi為模型第i類的概率分布,zi為模型第i類的輸出結(jié)果,C為類別數(shù),τ為溫度參數(shù),用于調(diào)節(jié)概率分布的平緩程度,τ越大,概率分布就越平均。于是,基于Logit的知識(shí)蒸餾通過對齊學(xué)生模型與教師模型的概率分布以此來將教師模型的知識(shí)傳輸給學(xué)生模型,形式如式(2)所示:
其中,ps和pt分別為學(xué)生模型和教師模型帶溫度參數(shù)τ的經(jīng)過Softmax函數(shù)的概率分布,KL為Kullback-Leible散度損失。
1.2 自監(jiān)督學(xué)習(xí)
自監(jiān)督學(xué)習(xí)[13]近年來取得了顯著進(jìn)展,尤其是在圖像視覺方面。自監(jiān)督方法設(shè)計(jì)各種代理任務(wù)(proxy tasks)來輔助神經(jīng)網(wǎng)絡(luò)學(xué)習(xí),這些任務(wù)可以是預(yù)測圖像上下文或旋轉(zhuǎn)、圖像著色、解決圖像拼圖游戲、最大化全局和局部特征的互信息以及實(shí)例識(shí)別。最近的研究工作表明[14],經(jīng)過自監(jiān)督預(yù)訓(xùn)練初試化的模型可以產(chǎn)生更好的表示,這一研究啟發(fā)了我們,我們將預(yù)測圖片旋轉(zhuǎn)任務(wù)作用于學(xué)生網(wǎng)絡(luò),使其學(xué)習(xí)到一種良好的初始化方法,以至于在知識(shí)蒸餾階段將教師網(wǎng)絡(luò)的知識(shí)轉(zhuǎn)移給學(xué)生網(wǎng)絡(luò)時(shí)學(xué)生能更好地吸收和歸納。
2 相關(guān)方法
2.1 預(yù)定義
我們有n個(gè)圖像X={x1,…,xn}。每個(gè)圖像根據(jù)Y進(jìn)行標(biāo)記Y={y1,…,yn},其中yi∈C為第C類的標(biāo)簽。在本文中,訓(xùn)練集遵循長尾分布。盡管訓(xùn)練集不平衡,但目標(biāo)是準(zhǔn)確識(shí)別所有類,因此我們使用平衡的測試集來評估分類結(jié)果。
2.2 訓(xùn)練教師模型
我們觀察到現(xiàn)有的通過知識(shí)蒸餾解決長尾分布問題的方法,大多都專注于蒸餾方法的改進(jìn),而忽略了對教師模型進(jìn)行詳細(xì)的分析,現(xiàn)有的教師模型僅僅使用普通交叉熵?fù)p失訓(xùn)練網(wǎng)絡(luò),這使得模型的決策邊界嚴(yán)重偏向頭部類,影響知識(shí)蒸餾的效果,基于這一問題我們提出一種多分類器的教師網(wǎng)絡(luò)結(jié)構(gòu),通過額外的分類器來增強(qiáng)尾部類的分類結(jié)果,具體而言,其中一個(gè)主分類器Ch+m+t學(xué)習(xí)識(shí)別頭部類+中部類+尾部類的圖片,另外兩個(gè)分類器Cm+t和Ct分別識(shí)別中部類+尾部類和尾部類的圖片,最終的分類結(jié)果為這三個(gè)分類器的結(jié)果之和,損失函數(shù)如下:
其中(X,Y)為一個(gè)批次中的圖像和標(biāo)簽。(Xh+m+t,Yh+m+t)與由所有類圖像組成的(X,Y)相同。(Xm+t,Ym+t)是(X,Y)的子集,僅包含中部和尾部類的圖像。(Xt,Yt)是(X,Y)的子集,僅包含屬于尾部類的圖像。CE為交叉熵?fù)p失。通過Lbranch使三個(gè)分類器分工明確,分別針對頭+中+尾部,中+尾部,尾部進(jìn)行專門優(yōu)化學(xué)習(xí)。
2.3 知識(shí)蒸餾過程
前段已經(jīng)了解了教師網(wǎng)絡(luò)的訓(xùn)練策略,本段我們將介紹所提知識(shí)蒸餾方法的蒸餾過程,整體框架圖如圖1所示,我們將在本節(jié)具體介紹其中的內(nèi)容。
2.3.1 旋轉(zhuǎn)自監(jiān)督預(yù)訓(xùn)練
在此階段,我們在原始長尾數(shù)據(jù)分布下預(yù)訓(xùn)練學(xué)生網(wǎng)絡(luò)。分類任務(wù)為判斷圖像旋轉(zhuǎn)角度,對比傳統(tǒng)N分類任務(wù),雖然其提供了豐富的語義信息,但它也受到長尾標(biāo)簽的影響。尾部類的樣本可能會(huì)被數(shù)據(jù)豐富的頭部類所淹沒,從而導(dǎo)致表征不足的問題。因此,我們構(gòu)建了平衡的自監(jiān)督分類任務(wù),要求模型預(yù)測圖像旋轉(zhuǎn),旋轉(zhuǎn)角度為{0°,90°,180°,270°},將傳統(tǒng)N類分類任務(wù)轉(zhuǎn)換為四分類任務(wù),它們在不受標(biāo)簽影響的情況下可以平等地考慮每個(gè)圖像。
2.3.2 多分類器蒸餾
知識(shí)蒸餾首先被引入用于通過軟標(biāo)簽將知識(shí)從高性能網(wǎng)絡(luò)(教師模型)轉(zhuǎn)移到小型網(wǎng)絡(luò)(學(xué)生模型)。我們的方法受到知識(shí)蒸餾的啟發(fā),但與之有本質(zhì)區(qū)別。在我們的方法中,學(xué)生模型與教師模型大小是相同的。此外,針對長尾識(shí)別,軟標(biāo)簽中的暗知識(shí)可以通過將知識(shí)從頭部類轉(zhuǎn)移到尾部類從而幫助尾部類更好地進(jìn)行識(shí)別。由于類樣本分布不均勻,我們設(shè)計(jì)了一種基于多分類器的分級(jí)蒸餾方法,將教師網(wǎng)絡(luò)的三個(gè)分類器中包含頭部+中部+尾部,中部+尾部,尾部的知識(shí)一一對應(yīng)蒸餾到學(xué)生網(wǎng)絡(luò)中,分級(jí)蒸餾損失函數(shù)Lclassifier如下所示:
最終學(xué)生模型的損失函數(shù)為:
其中α為超參數(shù)用于平衡兩個(gè)損失項(xiàng)。
3 實(shí)驗(yàn)分析
我們在兩個(gè)開源數(shù)據(jù)集進(jìn)行了一系列實(shí)驗(yàn)來證明所提方法的有效性。我們首先介紹了數(shù)據(jù)集和實(shí)驗(yàn)設(shè)置,然后討論和驗(yàn)證所提方法和現(xiàn)有方法的實(shí)驗(yàn)結(jié)果,最后對所提方法進(jìn)行消融實(shí)驗(yàn)。
3.1 數(shù)據(jù)集和實(shí)驗(yàn)設(shè)置
實(shí)驗(yàn)所用硬件環(huán)境為11th Gen Intel Core i5 2.40 GHz,16 GB內(nèi)存,使用Python編程語言實(shí)現(xiàn),操作系統(tǒng)平臺(tái)為Windows 10。在實(shí)驗(yàn)中,將使用兩個(gè)基準(zhǔn)數(shù)據(jù)集,即CIFAR10-LT和 CIFAR100-LT,來驗(yàn)證本文所提方法的有效性,數(shù)據(jù)集的詳細(xì)信息如表1所示。
原始CIFAR10和CIFAR100都包含6萬張大小為32×32彩色圖片,其中5萬張用于訓(xùn)練,其余用于驗(yàn)證。前者有10個(gè)類,每個(gè)類別有5 000張訓(xùn)練樣本和1 000張測試樣本,后者有100個(gè)類,每個(gè)類別有500張訓(xùn)練樣本和100張測試樣本。CIFAR10-LT和CIFAR100-LT分別為其對應(yīng)長尾版本。本文和文獻(xiàn)[8]的構(gòu)造方法一致,訓(xùn)練集中每個(gè)類別的數(shù)量按照Nc=Nmax×(IR)-c/C進(jìn)行配置,其中,C為數(shù)據(jù)集中類別總數(shù),Nc為第c個(gè)類別所包含的樣本數(shù),Nmax為原始數(shù)據(jù)集中樣本數(shù)量最多的類別所包含的樣本數(shù),在CIFAR10數(shù)據(jù)集中Nmax為5 000,在CIFAR100數(shù)據(jù)集中Nmax為500,IR為不平衡比率。IR可用于描述數(shù)據(jù)集的不平衡程度,定義為訓(xùn)練集中樣本數(shù)最多的類所包含的樣本數(shù)量與樣本數(shù)最少的類所包含樣本數(shù)量之間的比值。在本文中對不同方法基于三種不平衡比率(IR)進(jìn)行驗(yàn)證,IR的取值分別為100、50和10,測試集數(shù)量不變。不同IR下的數(shù)據(jù)訓(xùn)練集樣本分布如圖2和圖3所示。
對于CIFAR10-LT和CIFAR100-LT數(shù)據(jù)集,我們對圖像進(jìn)行預(yù)處理操作,具體操作是從原始圖像或在水平翻轉(zhuǎn)中隨機(jī)裁剪一個(gè)32×32面片,每側(cè)填充4個(gè)像素,并將像素歸一化值為[0,1]。我們采用ResNet-32作為所有實(shí)驗(yàn)的骨干網(wǎng)絡(luò)。采用動(dòng)量為0.9的SGD優(yōu)化器。迭代次數(shù)為200Epoch。初始學(xué)習(xí)率設(shè)為0.1,前五個(gè)Epoch通過線性預(yù)熱進(jìn)行訓(xùn)練。學(xué)習(xí)率在160和180個(gè)Epoch分別衰減0.1。批次大小為128用于所有實(shí)驗(yàn),動(dòng)量衰減率為0.000 5。采用廣泛使用的Top-1分類準(zhǔn)確率作為評估指標(biāo),所報(bào)告的準(zhǔn)確率為模型在相同設(shè)置情形下運(yùn)行三次取平均的結(jié)果。
3.2 實(shí)驗(yàn)結(jié)果
為了驗(yàn)證本文所提方法的有效性,本文與長尾視覺識(shí)別相關(guān)的7種主流方法進(jìn)行對比:CE、CB、LDAM、BBN、BKD、SSD、ResLT。結(jié)果如表2所示。
表2中,CE是使用普通交叉熵?fù)p失訓(xùn)練長尾分布數(shù)據(jù)集,我們將其作為基線,本文所提方法相對CE在數(shù)據(jù)集CIFAR10-LT和CIFAR100-LT上的分類準(zhǔn)確率分別提升了11.71%、10.21%、3.22%、9.98%,9.4%,8.15%。其中BKD、DiVE、SSD與我們一樣使用了知識(shí)蒸餾技術(shù)訓(xùn)練模型,可以看出本文所提方法相對他們在數(shù)據(jù)集CIFAR10-LT、CIFAR100-LT上的分類準(zhǔn)確率有較大的提升。
從圖4中可以看出,所提方法是對頭部類、中部類和尾部類進(jìn)行全面的改進(jìn),對比與CE方法,所提方法可以在不損失頭部類準(zhǔn)確度的情況下大幅度提升中部類和尾部類的準(zhǔn)確度。
圖5研究了不同的溫度參數(shù)τ對于學(xué)生網(wǎng)絡(luò)性能的影響,可以看出當(dāng)溫度很高時(shí)(τ =5)會(huì)導(dǎo)致學(xué)生性能的下降,原因是因?yàn)楦邷貢?huì)增加非正確類的Logit從而影響學(xué)生網(wǎng)絡(luò)預(yù)測的正確性。
我們還研究了自監(jiān)督預(yù)訓(xùn)練學(xué)生模型的有效性。在不平衡比率IR為100的CIFAR100-LT數(shù)據(jù)集上評估結(jié)果,具體而言,根據(jù)訓(xùn)練樣本數(shù)將測試集劃分為3個(gè)部分:head(訓(xùn)練樣本數(shù)≥100)、medium(20<訓(xùn)練樣本數(shù)<100)和tail(訓(xùn)練樣本數(shù)≤20)用于研究自監(jiān)督預(yù)訓(xùn)練方法對不同部分的改進(jìn)效果。結(jié)果如圖6所示。
從圖6中可以看出使用自監(jiān)督旋轉(zhuǎn)預(yù)訓(xùn)練(ssp)可以使學(xué)生模型整體性能提升1.85%,并且對于不同部分均有明顯的提升,如對于head部分有1.89%的改進(jìn),medium部分有1.94%的改進(jìn),tail部分有2.02%的改進(jìn),說明自監(jiān)督預(yù)訓(xùn)練能幫助學(xué)生網(wǎng)絡(luò)更好的識(shí)別不同類的語義信息,并且有助于學(xué)生模型更好地吸收教師傳遞過來的知識(shí)2m2ApfMsPV/2ot1T+zRTuQ==。
4 結(jié) 論
本文針對長尾視覺識(shí)別中尾部類不能被很好識(shí)別的問題,提出了一種基于分類器分級(jí)蒸餾的長尾視覺識(shí)別方法。首先提出一種基于多分類器的教師模型訓(xùn)練方法,可以有效增強(qiáng)教師的教學(xué)能力,然后采用了自監(jiān)督技術(shù)對網(wǎng)絡(luò)進(jìn)行預(yù)訓(xùn)練,最后通過分級(jí)知識(shí)蒸餾將教師模型中有用的信息傳遞給學(xué)生模型,實(shí)驗(yàn)結(jié)果表明,本文所提方法可以有效地提高長尾視覺識(shí)別任務(wù)的準(zhǔn)確性。
參考文獻(xiàn):
[1] HE H B,GARCIA E A. Learning from Imbalanced Data [J].IEEE Transactions on Knowledge & Data Engineering,2009,21(9):1263-1284.
[2] KONG S,RAMANAN D. OpenGAN: Open-Set Recognition via Open Data Generation [C]//2021 IEEE/CVF International Conference on Computer Vision.Montreal:IEEE,2021:793-802.
[3] HAN H,WANG W Y,MAO B H. Borderline-SMOTE: A New Over-Sampling Method in Imbalanced Data Sets Learning [C]//International Conference on Intelligent Computing,ICIC 2005.Hefei:Springer,2005:878-887.
[4] DRUMNOND C,HOLTE R C. Class Imbalance and Cost Sensitivity: Why Under-sampling beats OverSampling [EB/OL].[2024-01-08].https://www.docin.com/p-871518697.html.
[5] CHU P,BIAN X,LIU S P,et al. Feature Space Augmentation for Long-Tailed Data [C]//16th European Conference on Computer Vision.Glasgow:Springer,2020:694-710.
[6] SHEN L,LIN Z C,HUANG Q M. Relay backpropagation for effective learning of deep convolutional neural networks [C]//14th European conference on computer vision.Amsterdam:Springer,2016:467-482.
[7] KHAN S H,HAYAT M,BENNAMOUN M,et al. Cost-Sensitive Learning of Deep Feature Representations From Imbalanced Data [J].IEEE Transactions on Neural Networks and Learning Systems,2018,29(8):3573-3587.
[8] WANG Y X,RAMANAN D,HEBERT M. Learning to Model the Tail [C]//NIPS'17:Proceedings of the 31st International Conference on Neural Information Processing Systems,2017:7032-7042.
[9] KANG B Y,XIE S N,ROHRBACH M,et al. Decoupling Representation and Classifier for Long-Tailed Recognition [J/OL].arXiv:1910.09217[cs.CV].[2024-01-08].https://arxiv.org/abs/1910.09217?context=cs.CV.
[10] ZHOU B Y,CUI Q,WEI X S,et al. Bbn: Bilateral-Branch Network with Cumulative Learning for Long-tailed Visual Recognition [C]//2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition.Seattle:IEEE,2020:9716-9725.
[11] CAO K D,WEI C,GAIDON A,et al. Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss [J/OL].arXiv:1906.07413 [cs.LG].[2024-01-09].https://arxiv.org/abs/1906.07413.
[12] HINTON G,VINYALS O,DEAN J. Distilling the Knowledge in a Neural Network [J/OL].arXiv:1503.02531[stat.ML].[2024-01-09].https://arxiv.org/abs/1503.02531.
[13] GIDARIS S,SINGH P,KOMODAKIS N. Unsupervised Representation Learning by Predicting Image Rotations [J/OL].arXiv:1803.07728 [cs.CV].[2024-01-09].https://arxiv.org/abs/1803.07728v1.
[14] YANG Y Z,XU Z. Rethinking the Value of Labels for Improving Class-Imbalanced Learning [J/OL].arXiv:2006.07529[cs.LG].[2024-01-09].https://arxiv.org/abs/2006.07529?amp=1.