摘要: 針對邊緣計算環(huán)境中參與聯(lián)邦學(xué)習(xí)的客戶端數(shù)據(jù)資源的有限性,同時局限于使用硬標簽知識訓(xùn)練模型的邊緣聯(lián)邦學(xué)習(xí)算法難以進一步提高模型精度的問題,提出了基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)算法。利用知識蒸餾對軟標簽信息的提取能夠有效提升模型性能的特點,將知識蒸餾技術(shù)引入聯(lián)邦學(xué)習(xí)的模型訓(xùn)練中。在每一輪的聯(lián)邦學(xué)習(xí)模型訓(xùn)練過程中,客戶端將模型參數(shù)和樣本邏輯值一起上傳到邊緣服務(wù)器,服務(wù)器端聚合生成全局模型和全局軟標簽,并一起發(fā)送給客戶端進行下一輪的學(xué)習(xí),使得客戶端在進行本地訓(xùn)練時也能夠得到全局軟標簽知識的指導(dǎo)。同時在模型訓(xùn)練中對利用軟標簽知識和硬標簽知識的占比設(shè)計了動態(tài)調(diào)整機制,使得在聯(lián)邦學(xué)習(xí)中能夠較為合理地利用兩者的知識指導(dǎo)模型訓(xùn)練,實驗結(jié)果驗證了提出的基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)算法能夠有效地提升模型的精度。
關(guān)鍵詞: 邊緣計算; 知識蒸餾; 客戶端; 軟標簽; 硬標簽
中圖分類號: TP391.1
文獻標志碼: A
文章編號: 1671-6841(2025)02-0044-07
DOI: 10.13705/j.issn.1671-6841.2023158
An Edge Federated Learning Algorithm Based on Knowledge Distillation
SHI Ling1,2, HE Changle1,2, CHANG Baofang1,2, WANG Yali1,2, YUAN Peiyan1,2
(1.College of Computer and Information Engineering, Henan Normal University, Xinxiang 453007, China;
2.Engineering Laboratory of Intellectual Business and Internet of Things Technologies, Xinxiang 453007, China)
Abstract: In view of the clients′ limited data resources involved in federated learning in edge computing environment, and the problem that it was difficult to further improve the accuracy of edge federated learning algorithm which used hard label knowledge to train the model, an edge federated learning algorithm based on knowledge distillation was proposed. The extraction of soft label information by knowledge distillation could effectively improve the performance of the model, so the knowledge distillation technology was introduced into the model training of federated learning. In each round of federated learning model training process, the client uploaded the model parameters and samples logic values to the edge server, and the server generated the global model and global soft label together and sent them to the client for the next round of learning, so that the client could also get the guidance of global soft label knowledge during local training. At the same time, a dynamic adjustment mechanism was designed for the proportion of soft label knowledge and hard label knowledge in model training, so that the knowledge of both could be reasonably used to guide model training in federated learning. The experimental results verified that the proposed edge federated learning algorithm based on knowledge distillation could effectively improve the accuracy of the model.
Key words: edge computing; knowledge distillation; client; soft label; hard label
0 引言
機器學(xué)習(xí)發(fā)展十分迅速,在很多方面都取得了較好的效果[1-2]。但是近年來,互聯(lián)網(wǎng)用戶對數(shù)據(jù)隱私安全越來越重視,不同組織機構(gòu)出于種種原因不愿意共享數(shù)據(jù),導(dǎo)致數(shù)據(jù)碎片化、孤島化等相關(guān)問題,嚴重阻礙了人工智能進一步發(fā)展,因而能夠解決這些問題的聯(lián)邦學(xué)習(xí)技術(shù)應(yīng)運而生[3]。聯(lián)邦學(xué)習(xí)只需要用戶使用私有數(shù)據(jù)訓(xùn)練本地模型,然后通過聚合服務(wù)器接收來自各個用戶的模型參數(shù)來更新全局模型,從而協(xié)作訓(xùn)練出性能更優(yōu)的全局模型。聯(lián)邦學(xué)習(xí)由于其獨特的模型訓(xùn)練過程,在最近幾年引起相關(guān)研究人員廣泛的關(guān)注[4]。
在聯(lián)邦學(xué)習(xí)的模型聚合階段進行相關(guān)優(yōu)化是提高聯(lián)邦學(xué)習(xí)效率的一種常用方案[5-6]。文獻[7]在聯(lián)邦學(xué)習(xí)的客戶端本地訓(xùn)練結(jié)束后,全局模型采用異步聚合的方式,客戶端無需額外的等待時間直接上傳本地模型,減少了時間消耗。文獻[8]提出了基于最新系統(tǒng)狀態(tài)自適應(yīng)選擇最優(yōu)全局聚合頻率的控制算法,這是一種云-邊-端協(xié)同的架構(gòu),能有效減少聯(lián)邦學(xué)習(xí)過程的能耗。在聯(lián)邦學(xué)習(xí)中客戶端訓(xùn)練的本地模型對全局模型的性能起到?jīng)Q定性的作用,因而文獻[9]通過排除聯(lián)邦學(xué)習(xí)中一些不相關(guān)的本地模型,減少了其對全局模型聚合的影響,能夠有效地提升全局模型的最終精度。在聯(lián)邦學(xué)習(xí)過程中,對目標函數(shù)重新加權(quán),可以為損失較大的客戶端分配較高的權(quán)重,這在一定程度上增加了權(quán)重分配的公平合理性,文獻[10]采用此種方式,提升了聯(lián)邦學(xué)習(xí)的模型性能。聯(lián)邦學(xué)習(xí)中客戶端本地模型訓(xùn)練的更新速度不同,文獻[11]根據(jù)聚合服務(wù)器上接收到的本地模型參數(shù),依據(jù)其時效性采取了一種更加合理的自適應(yīng)加權(quán)聚合的方式,使模型精度有了進一步的提升。
聯(lián)邦學(xué)習(xí)雖然具有保護用戶數(shù)據(jù)隱私以及解決數(shù)據(jù)孤島等問題的優(yōu)點,但與集中式機器學(xué)習(xí)訓(xùn)練的模型相比,在精度上有一定差距[12]。移動邊緣計算中有很多方法能有效緩解網(wǎng)絡(luò)擁塞,更好地進行聯(lián)邦學(xué)習(xí)[13-14]。但當(dāng)面對邊緣計算環(huán)境時,客戶端上的數(shù)據(jù)會由于不同用戶的使用習(xí)慣以及興趣愛好等,呈現(xiàn)出不平衡的非獨立同分布的狀態(tài),使本就采用分布式學(xué)習(xí)的聯(lián)邦學(xué)習(xí)在精度上進一步下降[15-17]。當(dāng)下提升聯(lián)邦學(xué)習(xí)的性能是研究在邊緣計算環(huán)境部署聯(lián)邦學(xué)習(xí)的當(dāng)務(wù)之急。
在邊緣計算環(huán)境中,客戶端擁有的本地數(shù)據(jù)資源都是有限的,機器學(xué)習(xí)模型的性能也取決于能夠參與訓(xùn)練的數(shù)據(jù)量,在不能獲取更多數(shù)據(jù)資源的時候,可以考慮對現(xiàn)有模型資源的充分利用。知識蒸餾能夠?qū)崿F(xiàn)對現(xiàn)有模型資源的重復(fù)充分利用,將模型中的軟標簽知識蒸餾出來,并用于指導(dǎo)模型進行新的訓(xùn)練,以此提升模型的性能。本文主要研究基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí),主要貢獻包括兩個方面:1) 提出了基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)算法,利用知識蒸餾能從軟標簽中獲取知識的特點,提升邊緣聯(lián)邦學(xué)習(xí)模型的性能。2) 設(shè)計了客戶端分組參與聯(lián)邦學(xué)習(xí)的模型訓(xùn)練方式,同時依據(jù)聯(lián)邦學(xué)習(xí)訓(xùn)練階段對知識蒸餾過程中硬標簽的損失與軟標簽的損失在總損失中的比例設(shè)計了一個動態(tài)調(diào)整機制,提升了模型訓(xùn)練的效率。
1 基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)算法
1.1 基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)
在聯(lián)邦學(xué)習(xí)中,客戶端利用從聚合服務(wù)器下載的全局模型以及本地數(shù)據(jù)集訓(xùn)練本地模型。假設(shè)有K個客戶端參與聯(lián)邦學(xué)習(xí),每個客戶端都有本地數(shù)據(jù)集(xk,yk),客戶端的模型參數(shù)為ω(k),則每個客戶端損失函數(shù)定義為fk(ω)=l(ω;xk,yk),在學(xué)習(xí)率η下,本地模型采用梯度下降法進行訓(xùn)練,Δfk(ω(k))表示客戶端k的模型ω(k)的梯度,則訓(xùn)練過程可表示為ω(k)=ω(k)-ηΔ
fk(ω(k)),然后,服務(wù)器根據(jù)客戶端數(shù)據(jù)量對接收到的參數(shù)進行加權(quán)平均,則在第r輪通信中全局模型參數(shù)為ωr=DkD∑Kkω(k)r,其中:D為數(shù)據(jù)樣本總量;Dk為第k個客戶端的本地數(shù)據(jù)大小。
盡管通過對客戶端進行選擇可以有效地提升模型性能,但是邊緣聯(lián)邦學(xué)習(xí)采用的是基于樣本硬標簽的訓(xùn)練方式,知識映射到模型的渠道單一、效率不高,這些因素阻礙著聯(lián)邦學(xué)習(xí)模型性能的進一步提升。而知識蒸餾技術(shù)在模型訓(xùn)練中引入軟標簽知識來增加模型的知識,提高了模型在訓(xùn)練中獲取知識的效率,從而達到提升模型準確率的目的。
知識蒸餾的訓(xùn)練結(jié)構(gòu)采用的是教師-學(xué)生(teacher-student, T-S)結(jié)構(gòu),簡單來說就是學(xué)生模型可以通過蒸餾獲取到教師模型的知識,神經(jīng)網(wǎng)絡(luò)中輸出層產(chǎn)生的類概率可以表示為Φi(zi)=exp(zi)/∑jexp(zj),其中:zi是第i類邏輯單元值;Φi(zi)是第i類概率。
當(dāng)使用類概率表示知識時,類概率層的負標簽信息就是軟目標知識,而數(shù)據(jù)標簽則稱為硬目標知識。在神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過程中負標簽會被Softmax函數(shù)壓扁接近于零,因此會使軟目標知識的部分信息丟失。為了達到對軟標簽知識的利用,設(shè)溫度系數(shù)為T,控制輸出概率的軟化程度,則
Φi(zi,T)=exp(zi/T)/∑jexp(zi/T)。
在教師-學(xué)生結(jié)構(gòu)模型中,知識蒸餾(knowledge distillation)損失定義為LKD(Φ(tea,T),p(stu,T))=∑i(-Φi(teai,T)log(Φi(stui,T))),
其中:LKD表示知識蒸餾損失;stu和tea分別是學(xué)生模型和教師模型輸出的邏輯單元。
若用y表示硬標簽向量,則學(xué)生模型的損失可以定義為
Lstu(y,p(stu,T))=∑i(-yilog(Φi(stui,T)))。
所以對知識蒸餾的總損失可以定義為Ltotal=λ·LKD+(1-λ)·Lstu,其中λ是超參數(shù)。
如圖1所示,本文把參與聯(lián)邦學(xué)習(xí)的客戶端分為組1和組2,按響應(yīng)時間先選取出組1的客戶端,并保證其性能,從而為整體的聯(lián)邦學(xué)習(xí)效率奠定基礎(chǔ)。接著從剩余客戶端中隨機選取與組1相同數(shù)量的客戶端構(gòu)成組2。組1中的客戶端是聯(lián)邦學(xué)習(xí)的主要參與者,對模型的最終性能起主導(dǎo)作用,在本地訓(xùn)練結(jié)束時,需要上傳本地模型參數(shù)以及樣本logits。組1中的客戶端進行同步訓(xùn)練,即在每一個學(xué)習(xí)輪次中,邊緣服務(wù)器只有在接收到組1所有更新的本地模型參數(shù)和樣本logits之后,才會進行聚合操作[18]。組2的客戶端作為輔助參與者,采用的是異步訓(xùn)練方式,在本地訓(xùn)練結(jié)束時,只需要上傳樣本logits。雖然邊緣計算環(huán)境下客戶端的資源狀態(tài)是動態(tài)的,但總體而言,組2中客戶端的質(zhì)量低于組1,所以在組1中的客戶端全部完成本地訓(xùn)練時,組2中的客戶端可能只有部分完成本地訓(xùn)練,在本文方法中,聯(lián)邦學(xué)習(xí)每輪訓(xùn)練的截止時間以組1中所有客戶端上傳完本地參數(shù)和樣本logits為準。到達截止時間時,邊緣服務(wù)器聚合當(dāng)前所接收的來自組2上傳的樣本logits。組2中客戶端的參與,使得全局軟標簽涉及的范圍更廣,有效提高了整體的訓(xùn)練效率。服務(wù)器對全局模型的聚合為
ωr+1=∑Kk=1(
Dk×ω(k)r+1)/∑Kk=1Dk,
其中:K表示組1中客戶端的數(shù)量;ω(k)r+1表示設(shè)備k在完成第r輪本地訓(xùn)練后上傳的模型參數(shù);Dk表示第k個客戶端本地數(shù)據(jù)集的大小。服務(wù)器端全局軟標簽的聚合為
yr+1=∑Kk=1(Dk×ykr+1)/(∑Kk=1Dk)+Δy+1,
其中:ykr+1表示在組1中,客戶端k在完成第r輪本地訓(xùn)練后上傳的樣本logits的平均值;Δy+1表示在第r輪結(jié)束時來自組2中的客戶端生成的局部軟標簽梯度,用以對全局軟標簽信息的補充修正。
在組2中,假設(shè)時間閾值Te截止時,在第r-1輪接收到m個客戶端上傳的樣本logits,在第r輪接收到n個客戶端上傳的樣本logits,則
Δy+1=∑mk=1Dk×ykr
∑mk=1Dk-
∑nk=1Dk×ykr+1∑nk=1Dk。
客戶端使用梯度下降法更新權(quán)重并生成樣本logits,所用公式為ω(k)r+1=ω(k)r-ηgk,其中:η是本地模型訓(xùn)練的學(xué)習(xí)率;gk=Δfk(ω(k))表示當(dāng)前模型參數(shù)ω(k)的梯度。若客戶端k在第r輪本地訓(xùn)練中產(chǎn)生了D個樣本logits,則最終把樣本logits的平均值ykr+1上傳到邊緣服務(wù)器,其中ykr+1=∑Di=1yki/D。
本文所提算法模型訓(xùn)練的基本思想是使預(yù)測模型既要逼近硬標簽,也要逼近相應(yīng)的軟標簽,則參與聯(lián)邦學(xué)習(xí)的客戶端的本地訓(xùn)練中損失函數(shù)為L(ω)=λF(Ylab,y)+(1-λ)GKL(Ysf‖y),其中:y表示模型預(yù)測;λ是超參數(shù),且λ∈(0,1);Ysf為軟標簽;Ylab為硬標簽;F(Ylab,y)表示交叉熵損失函數(shù);GKL(Ysf‖y)為KL散度損失函數(shù),若客戶端k的本地數(shù)據(jù)集的數(shù)量為nk,則
F(Ylab,y)=-∑nki=1Ylablogexp(yi)∑jexp(yj),
GKL(Ysf‖y)=-∑nki=1Ysflogexp(yi)∑jexp(Yjsf)exp(Yisf)∑jexp(yj)。
因此本算法優(yōu)化函數(shù)可定義為
minω{Φ(ω)∑Kk=1
pk(λFk(ω)+(1-λ)GkL(ω))},
其中:pk=Dk/(∑Kk=1Dk)。從優(yōu)化函數(shù)中可以看出超參數(shù)
λ決定著知識蒸餾過程中硬標簽的損失與軟標簽的損失在總損失中的比例。在模型訓(xùn)練的初始階段以及前期階段,軟標簽中蘊含的知識較少,硬標簽的信息對模型性能提升占據(jù)絕對主導(dǎo)地位,在模型訓(xùn)練的后期,模型已充分利用了硬標簽的知識信息后逐漸開始收斂,此時軟標簽知識的出現(xiàn)將進一步提升模型性能。基于上述考慮,為超參數(shù)λ設(shè)計一個模型訓(xùn)練前期取值較大,之后逐漸減小到最低閾值的動態(tài)取值方法,
λ=max(φ,(R-r)/R),(1)
其中:r是聯(lián)邦學(xué)習(xí)當(dāng)前的通信輪數(shù);R是總通信輪數(shù);φ是設(shè)定的最低比例閾值,用來保證聯(lián)邦學(xué)習(xí)中必要比例的硬標簽知識信息的輸入。
1.2 算法實現(xiàn)
基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)機制的每一輪次學(xué)習(xí)過程如下:① 邊緣服務(wù)器選取當(dāng)前輪次中參與聯(lián)邦學(xué)習(xí)的客戶端,并給這些客戶端下發(fā)當(dāng)前全局模型和當(dāng)前全局軟標簽;② 客戶端在本地數(shù)據(jù)集上利用從邊緣服務(wù)器下載的當(dāng)前全局軟標簽和全局模型采用知識蒸餾方式訓(xùn)練更新本地模型,同時也生成新的樣本logits;③ 客戶端將訓(xùn)練好的本地模型以及新生成的樣本logits上傳到邊緣服務(wù)器端;④ 服務(wù)器對客戶端模型參數(shù)以及樣本logits分別進行聚合操作,從而更新全局模型和全局軟標簽,具體算法如下。
算法1 基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)算法
輸入:客戶端總數(shù)N,每輪時間閾值Te,客戶端批處理大小B,本地訓(xùn)練迭代輪數(shù)E,訓(xùn)練總輪數(shù)R,學(xué)習(xí)率η,最低比例閾值φ,控制參與聯(lián)邦學(xué)習(xí)客戶端數(shù)量的比例系數(shù),客戶端具有硬標簽Ylab的數(shù)據(jù)集D={D1,D2,…,DN}。
輸出:全局模型ωr+1,全局軟標簽yr+1。
選擇客戶端進程(client selection, CS)。
1) 客戶端組1:Kset1←「N×依據(jù)客戶端響應(yīng)時間選?。?/p>
2) 客戶端組2: Kset2←剩余客戶端中隨機選取數(shù)量等于Kset1的客戶端。
服務(wù)器進程(server process, SP)。
1)K←客戶端組1數(shù)量;
2) 初始化全局模型ω以及全局軟標簽y,并將其發(fā)送給選取的客戶端;
3) for each round r=1, 2, … do
4) for 客戶端 Ck∈K do
5) (ωkr+1,ykr+1)←ClientUpdate(ωr,yr)
6) end for
7) ωr+1=ωr+∑Kk=1(Dk×ωkr+1)/∑Kk=1Dk;
8) yr+1=∑Kk=1(Dk×ykr+1)/∑Kk=1Dk+Δy+1;
9)end for
客戶端進程(worker process, WP)。
1)λ←maxφ,R-rR;
2) for epoch i in 1 to E do
3) for batch b in 1 to B do
4) L(ω)=λF(yω,Ylab)+(1-λ)GKL(yω,Ysf)
5) ωr=ω-ηΔL(ω)
6) end for
7) end for
8)yr←Prediction(ω,D)
9)return (ωr,yr)。
2 實驗與結(jié)果分析
2.1 實驗設(shè)置
為了評估本文提出的算法,分別在兩種常見的數(shù)據(jù)集MNIST和CIFAR-10上進行實驗,其中MNIST數(shù)據(jù)集為70 000張被標準化、像素大小為28×28的手寫數(shù)字0~9構(gòu)成的灰度圖像,類別數(shù)為10,訓(xùn)練集和測試集分別包括60 000張和10 000張圖像。CIFAR-10數(shù)據(jù)集為60 000張像素大小32×32的RGB圖像,類別數(shù)為10,訓(xùn)練集和測試集分別包括50 000張和10 000張圖像。在實驗中模擬生成的客戶端數(shù)量為100,同時考慮客戶端數(shù)據(jù)獨立同分布(independent and identically distributed, IID)和非獨立同分布(non-independent and identically distributed, Non-IID)兩種情形。在獨立同分布的實驗狀態(tài)下,分別對MNIST和CIFAR-10的訓(xùn)練數(shù)據(jù)集進行均勻置亂操作,然后分配給每個客戶端,實現(xiàn)對客戶端數(shù)據(jù)集獨立同分布的劃分。在非獨立同分布的實驗狀態(tài)下,分別對MNIST和CIFAR-10數(shù)據(jù)集按照標簽進行排序,然后將這兩個數(shù)據(jù)集分別以300個和250個樣本劃分為一個片區(qū),共分成200個片區(qū),最后為每個客戶端分配兩個不同類的片區(qū),實現(xiàn)客戶端在數(shù)據(jù)集上的非獨立同分布。在實驗中對MNIST數(shù)據(jù)集使用的模型有兩個卷積層,卷積核都為3×3,輸出通道均為8,填充和步長都為1,卷積層后均使用2×2的最大池化層,最后連接一個輸入為392維、輸出為10維的全連接層網(wǎng)絡(luò)。對CIFAR-10數(shù)據(jù)集使用的模型有三個卷積層,卷積核都為5×5,輸出通道分別為16、16和32,填充都為2,步長都為1,卷積層后均使用2×2的最大池化層,最后連接一個輸入為512維、輸出為10維的全連接層網(wǎng)絡(luò),本文算法在與其他聯(lián)邦學(xué)習(xí)算法的對比實驗中,最低比例閾值φ取值0.6。
本文提出的算法用Our algorithm表示,同時為了驗證算法對模型性能的提升,選取了聯(lián)邦學(xué)習(xí)領(lǐng)域中比較經(jīng)典的FedAvg[18]和FedProx[19]算法做參照實驗。FedKD是在FedAvg架構(gòu)下單純地將知識蒸餾技術(shù)與聯(lián)邦學(xué)習(xí)結(jié)合,即在客戶端增加了生成樣本logits的操作,在服務(wù)器端增加了全局軟標簽聚合的操作,其與本文算法相比,沒有設(shè)計動態(tài)取值超參數(shù)λ,并且沒有對客戶端采取分組差異性的全局聚合算法。
2.2 實驗結(jié)果分析
圖2和圖3分別是在獨立同分布和非獨立同分布狀態(tài)下MNIST和CIFAR-10數(shù)據(jù)集對FedAvg、FedProx、FedKD和本文算法訓(xùn)練的模型性能對比。實驗中對于評價模型的性能均采用模型的訓(xùn)練準確率與通信輪數(shù)之間的變化趨勢來反映,其中在獨立同分布狀態(tài)下,MNIST和CIFAR-10數(shù)據(jù)集分別進行了100輪和200輪通信;在非獨立同分布狀態(tài)下,MNIST和CIFAR-10數(shù)據(jù)集分別進行了200輪和300輪通信。
在獨立同分布實驗狀態(tài)下,對于MNIST數(shù)據(jù)集,本文提出的算法最終精度為99.05%,而FedAvg、FedProx和FedKD算法的最終精度分別為97.37%、97.54%和98.17%,與之相比,本文方法在模型最終精度上分別提升了1.7%、1.5%和0.9%。對于CIFAR-10數(shù)據(jù)集,本文算法最終精度為60.31%,而FedAvg、FedProx和FedKD算法的最終精度分別為52.51%、55.93%和59.17%,與之相比本文方法在模型最終精度上分別提升了14.9%、7.8%和1.9%。
在非獨立同分布實驗狀態(tài)下,對于MNIST數(shù)據(jù)集,本文提出的算法最終精度為94.94%,而FedAvg、FedProx和FedKD算法的最終精度分別為92.64%、92.86%和94.56%,與之相比,本文方法在模型最終精度上分別提升了1.2%、1.0%和0.4%。對于CIFAR-10數(shù)據(jù)集,本文算法最終精度為53.99%,而FedAvg、FedProx和FedKD算法的最終精度分別為45.69%、49.05%和53.18%,與之相比本文方法在模型最終精度上分別提升了18.1%、10.1%和1.5%。
對于獨立同分布實驗狀態(tài)和非獨立同分布實驗狀態(tài),與FedAvg、FedProx以及FedKD算法相比,本文提出的算法模型收斂速度更快,能以更少的通信輪次率先達到模型的收斂,同時在模型最終訓(xùn)練的精度上也更高。
為了弄清楚知識蒸餾過程中,硬標簽損失和軟標簽損失對模型最終性能的影響,接下來,我們將研究集中在控制知識蒸餾過程中硬標簽的損失與軟標簽的損失在總損失中所占比例的超參數(shù)λ上。本文在式(1)中對超參數(shù)λ設(shè)計了一個隨著訓(xùn)練輪數(shù)進行動態(tài)調(diào)整,且同時具有最低比例閾值φ的取值機制,因而在接下來的實驗中研究不同最低比例閾值對于模型最終精度的影響。本文設(shè)計的基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)算法中分別選取的最低比例閾值φ從0.1到0.9進行了實驗,得到如下的結(jié)果。
在利用知識蒸餾訓(xùn)練機器學(xué)習(xí)模型的過程中,模型的準確率并不是與最低比例閾值φ成絕對正相關(guān)的。使用MNIST數(shù)據(jù)集進行實驗時,在IID和Non-IID兩種狀態(tài)下,當(dāng)最低比例閾值為0.6時,模型的精度是最高的。這也表明了在模型訓(xùn)練過程中,硬標簽知識對于訓(xùn)練模型而言是不可或缺的,同時設(shè)置超參數(shù)λ的最低閾值是十分有必要的。由于超參數(shù)λ控制著在訓(xùn)練過程中基于軟標簽的損失和基于樣本硬標簽的損失在總損失函數(shù)中所占的比例,所以λ的取值決定著模型性能。由知識蒸餾中軟標簽知識的來源可知,模型的準確率是與軟標簽中具有的知識量成正相關(guān)的。在聯(lián)邦學(xué)習(xí)初期,模型本身的準確率就較低,因而生成的軟標簽知識量也是較低的,此時如果用軟標簽知識指導(dǎo)模型訓(xùn)練,不僅會效率低下,甚至可能會誤導(dǎo)模型的優(yōu)化方向。隨著聯(lián)邦學(xué)習(xí)訓(xùn)練輪次的增加,模型的準確率也得到了提升,此時生成的軟標簽中的知識量也相對豐富了,在模型無法從硬標簽中獲得更多的知識指導(dǎo)時,軟標簽的參與將會進一步促進模型性能的提升。在聯(lián)邦學(xué)習(xí)過程中,隨著訓(xùn)練輪數(shù)的增加,生成的軟標簽知識量也在逐漸豐富,因此對超參數(shù)依據(jù)聯(lián)邦學(xué)習(xí)訓(xùn)練的輪數(shù)采取動態(tài)取值是較為合理的。但是在模型訓(xùn)練中軟標簽知識是無法起到主導(dǎo)作用的,對軟標簽知識在模型訓(xùn)練中所占比例設(shè)置上限是有必要的,當(dāng)硬標簽占比低于0.6,即軟標簽知識占比超過0.4時,模型的準確率下降。因而在聯(lián)邦學(xué)習(xí)訓(xùn)練中合理地選取最低比例閾值,更合理地利用硬標簽和軟標簽知識能夠使得模型取得最佳訓(xùn)練效果。同樣的,使用CIFAR-10數(shù)據(jù)集進行實驗時,在IID和Non-IID兩種狀態(tài)下,當(dāng)最低比例閾值為0.6時,模型的準確率也是最高的。
通過上述實驗分析可知,客戶端在進行本地模型訓(xùn)練時,由于有全局軟標簽的參與,本地模型性能在一定程度上得到了提升,最終提升了全局模型的性能。
3 結(jié)論
本文提出了基于知識蒸餾的邊緣聯(lián)邦學(xué)習(xí)算法,拓展了傳統(tǒng)聯(lián)邦學(xué)習(xí)算法中只能從樣本硬標簽獲取知識訓(xùn)練模型的單一渠道,使得在聯(lián)邦學(xué)習(xí)模型訓(xùn)練的過程中,能夠利用軟標簽中所涵蓋的知識,從而達到進一步提升聯(lián)邦學(xué)習(xí)在邊緣計算環(huán)境中模型性能的目的。與聯(lián)邦學(xué)習(xí)中其他較為經(jīng)典算法的對比實驗表明,本文提出的方法能夠有效地提升模型的性能。在未來的工作中,需要在追求模型性能的同時,兼顧學(xué)習(xí)過程中的總能耗,合理控制聯(lián)邦學(xué)習(xí)的成本。
參考文獻:
[1] 魏明軍, 閆旭文, 紀占林, 等. 基于CNN與LightGBM的入侵檢測研究[J]. 鄭州大學(xué)學(xué)報(理學(xué)版), 2023, 55(6): 35-40.
WEI M J, YAN X W, JI Z L, et al. Research on intrusion detection based on CNN and LightGBM[J]. Journal of Zhengzhou university (natural science edition), 2023, 55(6): 35-40.
[2] 吳宇鑫, 陳知明, 李建軍. 基于半監(jiān)督深度學(xué)習(xí)網(wǎng)絡(luò)的水體分割方法[J]. 鄭州大學(xué)學(xué)報(理學(xué)版), 2023, 55(6): 29-34.
WU Y X, CHEN Z M, LI J J. Semi-supervised deep learning network based water body segmentation method[J]. Journal of Zhengzhou university (natural science edition), 2023, 55(6): 29-34.
[3] BONAWITZ K, EICHNER H, GRIESKAMP W, et al. Towards federated learning at scale: system design [EB/OL]. (2019-03-22) [2023-04-28]. https:∥arxiv.org/abs/1902.01046.
[4] WANG X F, HAN Y W, WANG C Y, et al. In-edge AI: intelligentizing mobile edge computing, caching and communication by federated learning[J]. IEEE network, 2019, 33(5): 156-165.
[5] ZHANG C, XIE Y, BAI H, et al. A survey on federated learning[J]. Knowledge-based systems, 2021, 216: 106775.
[6] YU B, MAO W J, LV Y H, et al. A survey on federated learning in data mining[J]. WIREs data mining and knowledge discovery, 2022, 12(1): e1443.
[7] SPRAGUE M R, JALALIRAD A, SCAVUZZO M, et al. Asynchronous federated learning for geospatial applications[C]∥Joint European Conference on Machine Learning and Knowledge Discovery in Databases. Cham: Springer International Publishing, 2019: 21-28.
[8] WANG S Q, TUOR T, SALONIDIS T, et al. Adaptive federated learning in resource constrained edge computing systems[J]. IEEE journal on selected areas in communications, 2019, 37(6): 1205-1221.
[9] WANG L P, WANG W, LI B. CMFL: mitigating communication overhead for federated learning[C]∥2019 IEEE 39th International Conference on Distributed Computing Systems. Piscataway: IEEE Press, 2019: 954-964.
[10]LI T, SANJABI M, SMITH V. Fair resource allocation in federated learning [EB/OL]. (2020-02-14) [2023-04-28]. https:∥arxiv.org/abs/1905.10497v1.
[11]YOSHIDA N, NISHIO T, MORIKURA M, et al. Hybrid-FL for wireless networks: cooperative learning mechanism using non-IID data[C].(2019-05-17)[2023-04-28]. https:∥arxiv.org/pdf/1905.07210V2.pdf.
[12]ZHAO Y, LI M, LAI L Z, et al. Federated learning with non-iid data [EB/OL]. (2022-07-21) [2023-04-28]. https:∥arxiv.org/abs/1806.00582.
[13]YUAN P Y, ZHAO X Y, CHANG B F, et al. COPO: a context aware and posterior caching scheme in mobile edge computing[C]∥2019 IEEE International Conference on Signal Processing, Communications and Computing. Piscataway: IEEE Press, 2019: 1-5.
[14]YUAN P Y, CAI Y Y. Contact ratio aware mobile edge computing for content offloading[C]∥IEEE International Conference on Parallel and Distributed Systems. Piscataway: IEEE Press, 2019: 520-524.
[15]LIU J, HUANG J Z, ZHOU Y, et al. From distributed machine learning to federated learning: a survey[J]. Knowledge and information systems, 2022, 64(4): 885-917.
[16]HUANG W K, YE M, DU B. Learn from others and be yourself in heterogeneous federated learning[C]∥Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. Piscataway: IEEE Press, 2022: 10143-10153.
[17]CHAI Z, FAYYAZ H, FAYYAZ Z, et al. Towards taming the resource and data heterogeneity in federated learning[C]∥2019 USENIX Conference on Operational Machine Learning. Berkeley: USENIX Association Press, 2019: 19-21.
[18]MCMAHAN H B, MOORE E, RAMAGE D, et al. Communication-efficient learning of deep networks from decentralized data[EB/OL]. (2016-02-17) [2023-04-28]. https:∥arxiv.org/abs/1602.05629.
[19]LI T, SAHU A K, ZAHEER M, et al. Federated optimization in heterogeneous networks[EB/OL]. (2019-04-21) [2023-04-28]. https:∥arxiv.org/abs/1812. 06127.