趙法川,徐曉輝,宋 濤,郝淼淼,汪 曙,朱偉龍
(河北工業(yè)大學(xué) 電子信息工程學(xué)院,天津 300401)
農(nóng)作物病蟲害種類多、密度大,極易造成作物大量減產(chǎn),嚴(yán)重制約農(nóng)業(yè)生產(chǎn),而快速高效地識別病蟲害是防治的關(guān)鍵。隨著智慧農(nóng)業(yè)的興起與發(fā)展,利用深度學(xué)習(xí)技術(shù)對病蟲害進(jìn)行智能識別以輔助農(nóng)業(yè)生產(chǎn),減少不必要的農(nóng)藥噴施,對保護(hù)生態(tài)環(huán)境、提高農(nóng)作物的品質(zhì),有著十分重要的作用。
隨著數(shù)據(jù)量的增長和算力的提升,深度學(xué)習(xí)發(fā)展迅猛,誕生出CNN、Transformer等特征提取器,催生出一系列模型。比較經(jīng)典的卷積網(wǎng)絡(luò)如VGGNets[1]、ResNets[2]具有識別準(zhǔn)確率高的優(yōu)點,但也存在參數(shù)量大、性能差、難以廣泛應(yīng)用于移動端的問題,因此越來越多的學(xué)者將目光轉(zhuǎn)向輕量級網(wǎng)絡(luò)的研究。李靜等[3]通過遷移學(xué)習(xí)對GoogLeNet的Inception-v4網(wǎng)絡(luò)結(jié)構(gòu)進(jìn)行優(yōu)化,對玉米螟蟲害識別任務(wù)達(dá)到96.44%的準(zhǔn)確率。劉洋等[4]對輕量級網(wǎng)絡(luò)MobileNet和Inception V3進(jìn)行優(yōu)化,在PlantVillage數(shù)據(jù)集上分別達(dá)到95.02%和95.62%的識別準(zhǔn)確率。陸健強(qiáng)等[5]提出一種基于Mixup算法和卷積神經(jīng)網(wǎng)絡(luò)的柑橘黃龍病果實識別模型,對柑橘黃龍病數(shù)據(jù)集的識別準(zhǔn)確率達(dá)到94.29%。邱文杰等[6]通過知識蒸餾得到壓縮模型Distilled-MobileNet,該模型在38種常見病害中達(dá)到了97.62%的分類準(zhǔn)確率,且模型僅為19.83M。輕量級卷積網(wǎng)絡(luò)在作物病蟲害識別中的應(yīng)用研究已經(jīng)頗有成效,但其在模型參數(shù)和性能方面仍有繼續(xù)提升的空間,也許多頭注意力機(jī)制將是一個突破點。
近兩年,在自然語言處理領(lǐng)域大火的Transformer也被成功應(yīng)用到計算機(jī)視覺領(lǐng)域。Dosovitskiy等[7]提出了直接應(yīng)用于圖像塊序列的視覺Transformer (Vision transformer,ViT),在ImageNet-1K上取得了88.55%的準(zhǔn)確率,刷新了該榜單紀(jì)錄。相較于CNN(如ResNet),ViT依靠多頭注意力機(jī)制捕獲圖像塊之間的長距離依賴關(guān)系,因此擁有更大的感受野,能獲取全局信息,但長程的多頭注意力也使得ViT很容易忽略圖像的局部性質(zhì),而剛好CNN能彌補(bǔ)這一點。相較于ViT網(wǎng)絡(luò),CNN的卷積核大多尺寸較小,具有局部特征提取能力,且在現(xiàn)實的工業(yè)部署場景中,執(zhí)行CNN比大多數(shù)現(xiàn)有的ViT都要高效。由此可見,將卷積和多頭注意力混合設(shè)計,有效結(jié)合CNN和ViT的優(yōu)點,可進(jìn)一步提升輕量級作物病蟲害識別模型的性能。
為了將卷積與多頭注意力有效結(jié)合,設(shè)計出高效的作物病蟲害識別方法,本研究提出了一個全新的架構(gòu)M2CNet (Multi-head attention to convolutional neural network)。M2CNet基于層級金字塔結(jié)構(gòu),并引入深度可分離卷積和循環(huán)全連接層進(jìn)行局部特征提取,同時設(shè)計輕量級的全局特征捕捉塊,既提高了性能,也節(jié)省了計算開銷,以期為病蟲害精準(zhǔn)識別提供新的思路,為后續(xù)的邊緣平臺部署和作物病害檢測系統(tǒng)的開發(fā)提供新的解決思路和方案。
1.1.1 CIFAR100數(shù)據(jù)集 CIFAR100由Krizhevsky等[8]收集,圖片主要來自Google和各類搜索引擎。CIFAR100數(shù)據(jù)集有100個類別,每個類別有600張大小為32像素 × 32像素的彩色圖像,其中500張作為訓(xùn)練集,100張作為測試集。這100類被分為20個超類,每個圖像都帶有一個“精細(xì)”標(biāo)簽(它所屬的類)和一個“粗略”標(biāo)簽(它所屬的超類)。將M2CNet應(yīng)用于CIFAR100數(shù)據(jù)集,并與其他模型進(jìn)行效果比較。
1.1.2 PlantVillage數(shù)據(jù)集 PlantVillage由Hughes等[9]創(chuàng)建,在植物病理學(xué)專家的輔助下完成標(biāo)注,目的是幫助解決傳染病導(dǎo)致的作物產(chǎn)量損失問題。該數(shù)據(jù)集包含54 309張圖像,涵蓋南瓜白粉病、桃細(xì)菌性斑點病、櫻桃白粉病、柑橘黃龍病、玉米枯葉病、玉米灰斑病、玉米銹病、番茄二斑葉螨病、番茄葉霉病、番茄斑枯病、番茄早疫病、番茄晚疫病、番茄細(xì)菌性斑點病、番茄花葉病、番茄輪斑病、番茄黃曲葉病、蘋果瘡痂病、蘋果銹病、蘋果黑腐病、蘋果葉焦病、葡萄葉枯病、葡萄黑痘病、葡萄黑腐病、辣椒細(xì)菌性斑點病、馬鈴薯早疫病、馬鈴薯晚疫病共計26種作物疾病。試驗中按數(shù)量8∶2的比例劃分訓(xùn)練集和測試集,PlantVillage將用于檢驗M2CNet在作物病害識別任務(wù)中的表現(xiàn)。
1.1.3 IP102數(shù)據(jù)集 IP102是用于作物害蟲識別的野外構(gòu)建的大規(guī)模數(shù)據(jù)集[10],共有75 222張圖像,涵蓋了102種常見的害蟲,平均每種害蟲737個樣本,這些圖像呈現(xiàn)出自然的長尾分布。病蟲害生命周期有不同階段,例如稻縱卷葉螟在幼蟲時期呈現(xiàn)翠綠色的長條節(jié)狀,而在成蟲時期呈現(xiàn)棕灰色的飛蛾形態(tài),與水稻二化螟類似,因此IP102呈現(xiàn)出類間差異小和類內(nèi)差異大的特點。試驗中同樣按數(shù)量8∶2的比例來劃分訓(xùn)練集和測試集,各類具體害蟲的訓(xùn)練集、測試集包含的圖像數(shù)量匯總?cè)绫?所示。IP102用于檢驗M2CNet在作物蟲害識別任務(wù)中的表現(xiàn)。
表1 IP102 數(shù)據(jù)集害蟲分級分類體系Table 1 Taxonomy of the IP102 dataset on different class levels
本文構(gòu)建了一種識別作物病蟲害的輕量模型-M2CNet,該模型采用金字塔結(jié)構(gòu),降低空間分辨率的同時能夠在不同階段擴(kuò)展通道數(shù)。M2CNet主要開發(fā)了2個重要組件,首先構(gòu)建了局部捕獲塊(Local capture block,LCB),該組件主要由深度可分離卷積和多層循環(huán)全連接構(gòu)成,用來捕捉病蟲害圖片的短距離和細(xì)粒度信息;其次構(gòu)建了輕量級全局捕獲塊(Lightweight global capture block,LGCB),該組件由全局子采樣注意力(Global subsampling attention,GSA)和輕量級前饋網(wǎng)絡(luò)(Lightweight feedforward network)構(gòu)成,用來捕捉病蟲害圖片的長距離和高維信息。模型總體組成如圖1所示,下面將分別介紹局部捕獲塊和輕量級全局捕獲塊。
圖1 M2CNet網(wǎng)絡(luò)總體組成Fig.1 Overall structure of the M2CNet network
1.2.1 局部捕獲塊 局部捕捉塊的結(jié)構(gòu)如圖2所示,其中引入了殘差學(xué)習(xí)[11]的思想,主要由2個連續(xù)的深度可分離卷積[12]和1個多層循環(huán)全連接[13]構(gòu)成。深度可分離卷積由1個深度卷積和1個逐點卷積構(gòu)成,每層卷積后跟隨一個批規(guī)范化[14],由于頻繁地做非線性投影會有害于模型特征的信息傳遞[15],因此這里減少了激活層。深度可分離卷積先從空間維度獲取局部信息,再將獲取的局部信息向通道維度傳遞,最大程度地降低特征的損失;多層循環(huán)全連接由2個偽核為 1×3 和 3×1 的循環(huán)全連接層構(gòu)成,其中也使用了殘差學(xué)習(xí)以避免模型加深時出現(xiàn)的退化現(xiàn)象。多層循環(huán)全連接層通過階梯狀采樣來增大其感受野以更好地集成上下文特征,相比通道全連接有著一步操作就可以同時提取局部信息和融合通道信息的優(yōu)勢。
圖2 局部捕捉塊結(jié)構(gòu)圖Fig.2 Structure diagram of a local snap block
深度可分離卷積和多層循環(huán)全連接的感受野基本相當(dāng),都可以關(guān)注局部信息,但深度可分離卷積更側(cè)重于空間維度,多層循環(huán)全連接更側(cè)重于通道維度。由于圖片數(shù)據(jù)的紋理在空間維度表現(xiàn)更加明顯,因此在局部捕捉塊中采取先空間后通道的思想,深度可分離卷積在前,多層循環(huán)全連接在后,避免特征提取過程中圖片紋理被過度壓縮。
1.2.2 輕量級全局捕獲塊 輕量級全局捕獲塊由多個輕量結(jié)構(gòu)組成,旨在通過更少的參數(shù)來學(xué)習(xí)更魯棒的表征。LGCB最核心的部分是一種特殊的多頭注意力:全局子采樣注意力[16]。圖3是標(biāo)準(zhǔn)多頭注意力與全局子采樣注意力的對比,可以看到全局子采樣注意力多出一個次采樣(Subsampling)結(jié)構(gòu),該結(jié)構(gòu)把特征圖分為多個不重疊的子窗口(s×s),在子窗口上提取代表鍵(K)和值(V),但由于查詢(Q)是全局的,因此注意力仍可以恢復(fù)到全局,這種做法顯著減少了計算量。
圖3 標(biāo)準(zhǔn)多頭注意力(a)與全局子采樣注意力(b)的對比Fig.3 Comparison of standard multi-head attention and global subsampling attention
輕量級全局捕獲塊的整體結(jié)構(gòu)如圖4所示,LGCB首先對輸入特征圖做條件位置編碼[17](Conditional position encoding,CPE),將輸入向量H×W×di映射到高維空間,然后在空間維度展平成向量HW×di,過程中得到了輸入特征圖的位置信息。在全局子采樣注意力階段,輸入特征尺寸為HW×di,次采樣的輸出尺寸為HW/s2×di,其中di為通道維數(shù),s為子窗口的大小,得到Q=HW×di,K=V=HW/s2×di/h,h為多頭注意力頭的數(shù)量,將QKV共同送入多頭注意力。最后經(jīng)過輕量級前饋網(wǎng)絡(luò)[18]將輸入從di降維到di/r,再從di/r升維到di,其中r為降維因子,通常取r=4,該操作用于提升模型容量。簡單地,輕量級全局捕獲塊可以表述如下:
圖4 輕量級全局捕獲塊Fig.4 Lightweight global capture block
式中,Xin表示輸入張量,Norm是層歸一化操作,CPE是條件位置編碼,GSA是全局子采樣注意力,Lightweight FFN是輕量級前饋網(wǎng)絡(luò)。所有這些操作都可以在標(biāo)準(zhǔn)深度學(xué)習(xí)平臺通過常用和高度優(yōu)化的操作來實現(xiàn)。
1.2.3 M2CNet模型架構(gòu) 為滿足不同的邊緣部署需求,本研究提出了3個典型的變體,即M2CNet-S/B/L。架構(gòu)規(guī)范如表2所示,對于歸一化,在局部捕捉塊中使用批歸一化,在輕量級全局捕獲塊中使用層歸一化,對于激活函數(shù)均使用ReLU。
本研究在Ubuntu 20.04系統(tǒng)展開,該系統(tǒng)搭載GeForce RTX 3 090圖形處理器并通過并行計算架構(gòu)CUDA 11.4和CUDNN 8.2.4驅(qū)動,深度學(xué)習(xí)框架選擇PyTorch 1.10.1,編程語言為Python 3.8.5。訓(xùn)練時CIFAR100和IP102的迭代次數(shù)設(shè)為300,PlantVillage的迭代次數(shù)設(shè)為60,批次均為64。學(xué)習(xí)率選擇余弦衰減[19]策略,PlantVillage和CIFAR-100的初始學(xué)習(xí)率設(shè)為0.000 5,IP102的初始學(xué)習(xí)率設(shè)為0.005,前10個迭代次數(shù)學(xué)習(xí)率均使用線性啟動。優(yōu)化器選擇Adamw[20],并將權(quán)重衰減設(shè)置為0.05,在訓(xùn)練中還使用了標(biāo)簽平滑[21]和Mixup[22]數(shù)據(jù)增強(qiáng)來進(jìn)一步探索模型性能。訓(xùn)練時圖像使用224像素×244像素的隨機(jī)裁剪,測試時使用224像素×244像素的中心裁剪。
評價指標(biāo)采用Top1準(zhǔn)確率、Top5準(zhǔn)確率和損失值,Top1準(zhǔn)確率指預(yù)測概率排名第1的類別與實際結(jié)果相符的準(zhǔn)確率,Top5 準(zhǔn)確率是指預(yù)測概率排名前5的類別與實際結(jié)果相符的準(zhǔn)確率。準(zhǔn)確率(Accuracy)和損失值(Loss)的計算公式如下:
式中,TP為真正類,TN為真負(fù)類,F(xiàn)P假正類,F(xiàn)N假負(fù)類;p(xi)代表真實的標(biāo)簽,q(xi)代表預(yù)測的概率。
本研究將M2CNet應(yīng)用于CIFAR100,并與多種模型進(jìn)行了比較,包括許多經(jīng)典的計算量(Floating point operations,F(xiàn)LOPs)小于1G的輕量級卷積網(wǎng)絡(luò),例如ShuffleNets[23-24]、SqueezeNet[25]、MobileNetV2[26]、MobileNetV3[27]、MnasNet[28]、EfficientNet[29],還包括ViT模型MobileViT[30]和大型模型VGG,M2CNet-S/B/L的訓(xùn)練過程見圖5,從圖5可以直觀地看到隨著300次迭代的收斂,M2CNet-S/B/L在訓(xùn)練集和測試集的損失值逐漸降低,直至趨于平穩(wěn)。
圖5 M2CNet-S/B/L在CIFAR100數(shù)據(jù)集的訓(xùn)練過程Fig.5 M2CNet-S/B/L training process in the CIFAR100 dataset
表3是對比結(jié)果,在參數(shù)量和計算量相似的情況下,M2CNet-S/B/L占據(jù)一定優(yōu)勢,且M2CNet-L參數(shù)量和準(zhǔn)確率最優(yōu)。與ShuffleNet系列相比,本研究的M2CNet-S/B比ShuffleNet-V2 1.5/2.0分別在Top1準(zhǔn)確率上實現(xiàn)了4.53、2.53個百分點的提升。與MnasNet系列相比,M2CNet-S/B/L比MnasNet 0.75/1.0/1.3分別在Top1的準(zhǔn)確率上實現(xiàn)了1.89、2.62和1.75個百分點的提升。與MobileNet系列相比,M2CNet-S/B/L分別在Top1準(zhǔn)確率上實現(xiàn)了9.35、6.16和5.12個百分點的提升。由此可見將多頭注意力機(jī)制與卷積結(jié)合可以有效提升卷積模型的性能,例如M2CNet-S與MobileNet-V2、MobileNet-V3-Large參數(shù)量和計算量相似,但其識別精度卻更優(yōu)。在與MobileViT系列的對比中,M2CNet-S/B/L同樣在識別精度上展現(xiàn)出明顯優(yōu)勢。本研究也將M2CNet與大型模型做對比,可以看到M2CNet-L比VGG系列、ResNet 18準(zhǔn)確率更高,而參數(shù)量僅為ResNet 18的一半,是VGG系列的1/20。由此可見M2CNet可以在模型參數(shù)量和準(zhǔn)確率之間保持平衡。
表3 CIFAR100數(shù)據(jù)集模型對比結(jié)果Table 3 Comparison results of CIFAR100 dataset model
為了更好地比較M2CNet-S/B/L的效果,本研究針對每一種變體找到了在參數(shù)量和計算量上相似的對照,即M2CNet-S對應(yīng)MobileViT-XS、MobileViT-XXS、MnasNet 0.75、MobileNet-V2;M2CNet-B對應(yīng)MobileNet-V3-Large、EfficientNet B0、MnasNet 1.0;M2CNet-L對應(yīng)EfficientNet B1、MobileViT-S、MnasNet 1.3。將以上網(wǎng)絡(luò)分別在PlantVillage病害數(shù)據(jù)集和IP102蟲害數(shù)據(jù)集上展開試驗,試驗結(jié)果見圖6。
圖6 病蟲害數(shù)據(jù)集識別結(jié)果Fig.6 Identification results of pest data sets
圖6a是PlantVillage數(shù)據(jù)集識別結(jié)果,可以看到在各組對照中M2CNet-S/B/L分別取得了95.92%、96.82%、97.15%的最大Top1識別準(zhǔn)確率,在參數(shù)量相似的情況下取得了最優(yōu)的結(jié)果。圖6b是IP102數(shù)據(jù)集識別結(jié)果,M2CNet-S/L依然延續(xù)了在PlantVillage上的表現(xiàn),分別取得了67.08%、71.0%的最大Top1準(zhǔn)確率和88.49%、90.50%的最大Top5準(zhǔn)確率。在M2CNet-B對照中MnasNet 1.0取得了69.46%的最大Top1準(zhǔn)確率,超出M2CNet-B 0.47個百分點,不過從整體來看,M2CNet變體在作物病蟲害識別任務(wù)中依然表現(xiàn)出色。M2CNet變體能在對照試驗中取得比其他輕量級網(wǎng)絡(luò)更有競爭力的結(jié)果,分析原因在于融合多頭注意力的M2CNet不僅關(guān)注局部信息,也關(guān)注全局信息,因此能夠靈活應(yīng)對不同特征尺度的變化。
為了進(jìn)一步解釋融合多頭注意力后M2CNet-S/B/L關(guān)注的區(qū)域,這里使用Grad-CAM[31]方法在病害和蟲害的部分?jǐn)?shù)據(jù)集上抽取特征圖進(jìn)行可視化,其可視化結(jié)果如圖7所示??梢钥吹?,由M2CNet-S到M2CNet-B再到M2CNet-L,模型對于分類識別任務(wù)中更有判別性的特征區(qū)域給予了更高的關(guān)注,在一定程度上降低了背景特征的干擾,進(jìn)而提升了模型識別精度。
圖7 網(wǎng)絡(luò)關(guān)注區(qū)域熱力圖Fig.7 Thermal map of the network focus area
本研究為設(shè)計出輕量級作物病蟲害識別方法,將多頭注意力機(jī)制捕捉長距離依賴關(guān)系的能力與卷積神經(jīng)網(wǎng)絡(luò)的局部特征提取能力相結(jié)合,設(shè)計出滿足不同邊緣部署需求的3個變體:M2CNet-S/B/L。
為了驗證M2CNet-S/B/L的特征提取能力,在CIFAR100數(shù)據(jù)集上將其與其他輕量級網(wǎng)絡(luò)展開對比,在參數(shù)量和計算量相似的情況下,M2CNet 3個變體均表現(xiàn)出良好的性能。在與經(jīng)典的輕量級網(wǎng)絡(luò)MobileNet系列比較中,M2CNet-S/B/L在Top1準(zhǔn)確率上分別實現(xiàn)了9.35、6.16和5.12個百分點的增益。
在作物病蟲害數(shù)據(jù)集的試驗中,M2CNet-S/B/L在PlantVillage病害數(shù)據(jù)集上取得了大于99.70%的Top5準(zhǔn)確率和大于95.92%的Top1準(zhǔn)確率,在IP102蟲害數(shù)據(jù)集上取得了大于88.4%的Top5準(zhǔn)確率和大于67.0%的Top1準(zhǔn)確率,且在同級別網(wǎng)絡(luò)的對比中均占有優(yōu)勢,證明M2CNet能夠勝任作物病蟲害識別任務(wù)。
M2CNet網(wǎng)絡(luò)有著參數(shù)量少的優(yōu)點,以M2CNet-S為例,其參數(shù)內(nèi)存僅占用1.8M,對硬件性能(FLOPs)要求僅為0.23G,這極大降低了對硬件平臺的要求,有利于后續(xù)的邊緣平臺部署和作物病害檢測系統(tǒng)的開發(fā)和普及。