田金簫
(西南交通大學(xué) 計算機(jī)與人工智能學(xué)院,成都 611756)
近年來,隨著人工智能技術(shù)的快速發(fā)展和廣泛應(yīng)用,數(shù)據(jù)隱私保護(hù)也得到了密切關(guān)注.歐盟出臺了首個關(guān)于數(shù)據(jù)隱私保護(hù)的法案《通用數(shù)據(jù)保護(hù)條例》(General Data Protection Regulation,GDPR)[1],明確了對數(shù)據(jù)隱私保護(hù)的若干規(guī)定.中國自2017年起實(shí)施的《中華人民共和國網(wǎng)絡(luò)安全法》和《中華人民共和國民法總則》中也對用戶隱私數(shù)據(jù)的使用做出了明確的規(guī)定.在機(jī)器學(xué)習(xí)中,模型的好壞很大程度上依托于建模的數(shù)據(jù).但由于相關(guān)法律法規(guī)的限制,數(shù)據(jù)孤島問題變得十分普遍,導(dǎo)致企業(yè)很難獲取訓(xùn)練數(shù)據(jù).為此,谷歌在2016年提出了聯(lián)邦學(xué)習(xí)的概念.聯(lián)邦學(xué)習(xí)是一種基于分布式機(jī)器學(xué)習(xí)的框架,在這種框架中,多個客戶端在中央服務(wù)器的協(xié)調(diào)下共同訓(xùn)練模型,并保證訓(xùn)練數(shù)據(jù)可以保留在本地,不需要像傳統(tǒng)的機(jī)器學(xué)習(xí)方法一樣將數(shù)據(jù)上傳至中央服務(wù)器[2],從而保護(hù)了用戶隱私.
構(gòu)建一個高性能的聯(lián)邦模型通常需要多輪通信,同時規(guī)模龐大的神經(jīng)網(wǎng)絡(luò)模型,往往包含數(shù)百萬個參數(shù)[3],這導(dǎo)致了巨大的通信開銷.此外,相較于傳統(tǒng)的分布式機(jī)器學(xué)習(xí),聯(lián)邦學(xué)習(xí)還面臨如下問題:
1)客戶端數(shù)據(jù)非獨(dú)立同分布: 在傳統(tǒng)分布式機(jī)器學(xué)習(xí)中的訓(xùn)練數(shù)據(jù)隨機(jī)均勻地分布在客戶端上[4],即遵循獨(dú)立同分布(independent and identically distributed,IID).這在聯(lián)邦學(xué)習(xí)中通常是不成立的,由于用戶的喜好不同,客戶端的數(shù)據(jù)通常是非獨(dú)立同分布(non-IID)的.即客戶端擁有的局部數(shù)據(jù)集不能代表整體數(shù)據(jù)的分布,不同客戶端之間的數(shù)據(jù)分布也不同.
2)數(shù)據(jù)不平衡: 不同的客戶端可能擁有不同的數(shù)據(jù)量.
3)客戶端數(shù)量龐大且不可靠: 參與訓(xùn)練的客戶端為大量的移動設(shè)備,通常大部分客戶端經(jīng)常離線或者處于不可靠的連接上,因此無法確保客戶端參與每一輪的訓(xùn)練.
本文主要研究聯(lián)邦學(xué)習(xí)中的通信效率問題,利用梯度稀疏化的思想減少客戶端與服務(wù)器之間通信的參數(shù)量,并在服務(wù)器聚合時使用投影的方式緩解非獨(dú)立同分布數(shù)據(jù)帶來的影響.經(jīng)過在MNIST 和CIFAR10數(shù)據(jù)集上的實(shí)驗(yàn)證明,本文提出的算法能夠在聯(lián)邦學(xué)習(xí)的約束條件下高效訓(xùn)練模型.
一般來說,減少聯(lián)邦學(xué)習(xí)中的通信開銷有兩種策略,一種是減少訓(xùn)練過程中的通信輪次,另一種是減少每輪傳遞的通信量.減少通信輪次的經(jīng)典方案是聯(lián)邦學(xué)習(xí)中最常用的FedAvg 算法[2],即令客戶端在本地執(zhí)行多輪本地更新,服務(wù)器再進(jìn)行全局聚合,來減少通信輪數(shù).FedAvg 在每次通信中,客戶端需要上傳或下載整個模型,由于聯(lián)邦客戶端通常運(yùn)行在緩慢且不可靠的網(wǎng)絡(luò)連接上,這一要求使得使用FedAvg 訓(xùn)練大型模型變得困難.在實(shí)際應(yīng)用中,FedAvg 算法可以較好地處理非凸問題,但該算法不能很好處理聯(lián)邦學(xué)習(xí)中數(shù)據(jù)non-IID 的情況,在此應(yīng)用場景很可能導(dǎo)致模型不收斂[5].因此針對non-IID 場景,Briggs 等[6]在FedAvg的基礎(chǔ)上引入層次聚類技術(shù),根據(jù)局部更新與全局模型的相似度對客戶端進(jìn)行聚類和分離,以減少總通信輪數(shù).此外Karimireddy 等[7]通過估計服務(wù)器與客戶端更新方向的差異來修正客戶端本地更新的方向,有效地克服了non-IID 問題,能在較少的通信輪次達(dá)到收斂.
另一類方法的核心思想在于減少傳輸?shù)臄?shù)據(jù)量,主要通過量化、稀疏化等一系列方法對模型參數(shù)或者梯度進(jìn)行壓縮.量化通過將元素低精度表示或者映射到預(yù)定義的一組碼字來減少梯度張量中每個元素的位數(shù),例如Dettmers[8]將梯度的32 位浮點(diǎn)數(shù)量化至8 位,SignSGD[9-11]則只保留梯度的符號來更新模型,將負(fù)梯度量化為-1,其余量化為1,實(shí)現(xiàn)了32 倍的壓縮.稀疏化方法通過只上傳部分重要的梯度來進(jìn)行全局模型的更新,如何選擇這些梯度成為該方法的關(guān)鍵.Strom[12]提出使用梯度的大小來衡量其重要性,通過預(yù)先設(shè)立閾值,當(dāng)梯度大于該閾值時對其進(jìn)行上傳.然而在實(shí)際情況中,由于不同的網(wǎng)絡(luò)結(jié)構(gòu)參數(shù)分布差異較大,導(dǎo)致我們無法選擇合適的閾值.因此目前稀疏化方法通常使用Aji 等[13]提出的固定稀疏率,每次傳遞一定比例的最大梯度或每次傳遞前k個最大梯度的Topk 方法[14].上述工作有效地解決了分布式機(jī)器學(xué)習(xí)中的通信開銷問題,針對聯(lián)邦學(xué)習(xí)的訓(xùn)練環(huán)境,Rothchild 等[15]使用了一種特殊的數(shù)據(jù)結(jié)構(gòu)計數(shù)草圖(count sketch)對客戶端梯度進(jìn)行壓縮.Chen 等[16]將神經(jīng)網(wǎng)絡(luò)的不同層分為淺層和深層,并認(rèn)為深層參數(shù)更新頻率低于淺層參數(shù),因此提出了異步更新策略,有效減少了每輪傳遞的參數(shù)量.Haddadpour 等[17]在FedAvg 的基礎(chǔ)上對每輪傳遞的參數(shù)進(jìn)行壓縮,并針對non-IID 場景采用梯度跟蹤技術(shù)對客戶端梯度方向進(jìn)行修正,在收斂速度和準(zhǔn)確率上都取得了較好的效果.
Sattler 等[18]也針對聯(lián)邦學(xué)習(xí)的訓(xùn)練環(huán)境提出了稀疏三元壓縮(sparse ternary compression,STC),該方法在Topk 梯度稀疏化的基礎(chǔ)上進(jìn)行了量化進(jìn)一步減少了通信量,并利用錯誤反饋機(jī)制實(shí)現(xiàn)了客戶端與服務(wù)器之間的雙向壓縮,在聯(lián)邦學(xué)習(xí)場景中表現(xiàn)出了良好的效果.該方法考慮了聯(lián)邦學(xué)習(xí)中客戶端non-IID數(shù)據(jù)的場景,通過利用稀疏的特性以及減少本地訓(xùn)練次數(shù)與服務(wù)器端頻繁通信去減輕non-IID 數(shù)據(jù)帶來的問題,但該方法對non-IID 數(shù)據(jù)的優(yōu)化能力有限.因此本文將在稀疏三元壓縮算法的基礎(chǔ)上,關(guān)注non-IID下的聯(lián)邦場景,提升聯(lián)邦學(xué)習(xí)的通信效率.
常規(guī)的Topk 稀疏方法以全精度傳遞稀疏元素,Sattler 等[19]證明了當(dāng)稀疏化與非零元素的量化相結(jié)合時,可以獲得更高的壓縮增益.如算法1 所示,當(dāng)獲得Topk 稀疏元素Tmasked后,會將其量化為稀疏元素的平均值,因此最后只需要傳遞一個包含值{-μ,0,μ}的三元張量.如果將每一層的梯度看做一個矩陣,那么使用Topk 和稀疏三元壓縮后得到的結(jié)果如圖1 所示,原始梯度是一個稠密矩陣,顏色深淺代表值的大小,通過Topk 方法會得到一個保留較大值的稀疏矩陣,值較小的則置為0,而稀疏三元壓縮則在Topk 的基礎(chǔ)上做了量化,進(jìn)一步提升了壓縮率.
圖1 梯度壓縮效果
算法1.STC[18]: 稀疏三元壓縮算法T∈Rn輸入: 張量,稀疏率p 1.v←topk(|T|)k←max(np,1)2.mask←(|T|≥v)∈{0,1}n 3.Tmasked←mask⊙T 4.μ←1∑ni=1|Tmaskedi|5.T*←μ×sign(Tmasked)6.輸出k
Sattler 等[18]在聯(lián)邦學(xué)習(xí)中使用稀疏三元壓縮對客戶端和服務(wù)器之間通信的梯度進(jìn)行雙向壓縮,并結(jié)合錯誤反饋機(jī)制[20]在客戶端和服務(wù)器保留壓縮前后的誤差累加至下一輪訓(xùn)練過程.
其中,gti為第i個客戶端第t輪訓(xùn)練得到的原始梯度,為壓縮后的梯度,errort為壓縮前后的誤差.該方法取得了與非壓縮算法相似的收斂速度并大大減少了每一輪的通信量,因此本文也將使用稀疏三元壓縮方法進(jìn)行梯度壓縮.
目前在聯(lián)邦學(xué)習(xí)中,我們通常采用平均各個客戶端梯度的方法計算全局模型.當(dāng)不同客戶端數(shù)據(jù)滿足IID 條件時,各客戶端梯度更新方向相近,且聚合后梯度與基于傳統(tǒng)的集中式學(xué)習(xí)獲得的梯度相似性較高.故此方法能獲得全局目標(biāo)函數(shù)的最優(yōu)解.若客戶端數(shù)據(jù)non-IID 且數(shù)據(jù)量差異較大,各客戶端梯度差異性較大,存在相互干擾的情況,導(dǎo)致全局模型收斂速率降低.同時,簡單平均各方梯度易使數(shù)據(jù)量多的客戶端占主導(dǎo)作用,使得全局模型無法較好地處理數(shù)據(jù)量較少的客戶端,最終導(dǎo)致全局模型整體性能低下.
Wang 等[21]提出使用梯度投影處理non-IID 數(shù)據(jù)的問題,服務(wù)器端在進(jìn)行梯度平均之前,通過修改梯度方向減輕non-IID 數(shù)據(jù)帶來的影響.該方法首先對客戶端之間的梯度沖突做出定義,當(dāng)客戶端i的梯度gi和客戶端j的梯度gj滿足gi·gj<0時,則稱為客戶端i和客戶端j之間存在梯度沖突.當(dāng)客戶端之間存在梯度沖突時,梯度方向差異性較大,這時可以通過將一個客戶端的梯度投影到另一個有沖突的客戶端梯度平面上,使用原梯度減去投影來縮小客戶端之間的梯度差異,如式(3)所示:
此外,該方法定義了內(nèi)部沖突和外部沖突,分別對其進(jìn)行投影處理.將參與訓(xùn)練的客戶端之間的梯度沖突定義為內(nèi)部沖突,將客戶端梯度按照訓(xùn)練損失從小到大排序得到并引入?yún)?shù) α來控制每輪參與投影的客戶端數(shù)目.從POt中選擇損失較小的客戶端Sαt迭代的判斷與其他客戶端之間的梯度沖突,并進(jìn)行投影修改梯度方向以緩解內(nèi)部沖突.對于未選擇的損失較大的客戶端則保持原有的梯度,此后進(jìn)行梯度平均得到聚合后的梯度gt,如算法2 所示.
在實(shí)際聯(lián)邦場景中,客戶端non-IID 程度較大,在每輪聚合中,若對所有客戶端統(tǒng)一采用投影方案,則導(dǎo)致訓(xùn)練損失大的客戶端的梯度方向不斷靠近損失小的客戶端.這將導(dǎo)致聚合模型無法學(xué)習(xí)到所有客戶端的信息.但通過調(diào)整參數(shù) α,自適應(yīng)地讓部分訓(xùn)練損失較大的客戶端直接參與最終的聚合階段,有效地緩解了上述問題.
算法2.MitigateInternalConflict[21]: 緩解內(nèi)部沖突算法輸入: 客戶端梯度投影順序,參數(shù)POtα POtS1-αtα 1.服務(wù)器從選擇損失較小的客戶集合參與投影,保留 比例損失較大的客戶端梯度k∈S1-α t 2.for each client in parallel do gpc k ←gtk 3.gti∈POti=1,···,m 4.for each ,do k ·gti<0k≠i 5.if and then gPC||gti||2 gti 6.投影修正客戶端梯度:gPC k ←gPCk -(gti)·gPC k 7.end if 8.end for 9.end for ∑mk=1 gPCk 10.計算聚合梯度:gt←1 m 11.返回聚合梯度gt
由于聯(lián)邦學(xué)習(xí)中客戶端的部分參與和不可靠連接,在第t輪未被選中參與訓(xùn)練的客戶端可能會遭受被全局模型遺忘的風(fēng)險, 因此可以在服務(wù)端保留其最近一次參與訓(xùn)練的梯度根據(jù)它們的近鄰歷史梯度來估計真實(shí)梯度以避免客戶端被遺忘, 如算法3 第6 步所示.第t輪未被選中客戶端的估計梯度gcon與參與更新的客戶端平均后的梯度gt之間的沖突稱為外部沖突, 通過將gt迭代的投影到不同輪次的估計梯度gcon的法平面以緩解外部沖突, 通過參數(shù)τ控制投影的輪次. 具體步驟如算法3 所示.
算法3.MitigateExternalConflict[21]: 緩解外部沖突算法gtGHτ輸入:聚合梯度 ,所有客戶端近鄰歷史梯度,參數(shù)1.for round do gcon←0 t-i,i=τ,τ-1,···,1 2.初始化估計梯度:k=1,2,···,K 3.for each client do tk=t-i 4.if then gt·gtkk <0 5.if then gcon←gcon+gtkk 6.計算未被選中客戶端的估計梯度:7.end if 8.end if 9.end for gt·gcon<0 10.if then 11.對聚合梯度投影修正:12.end if 13.end for gt 14.返回聚合梯度gt←gt- gt·gcon||gcon||2 gcon
鑒于投影能夠有效地處理聯(lián)邦學(xué)習(xí)中的non-IID數(shù)據(jù)問題,因此本文將在稀疏三元壓縮的基礎(chǔ)上,在服務(wù)器端使用投影聚合的方式,進(jìn)一步提高模型的正確率與收斂速度,具體步驟如算法4 所示.
服務(wù)器端接收到客戶端梯度與訓(xùn)練損失后,首先在算法第14 行更新每個客戶端最近一次參與訓(xùn)練的梯度以便在緩解外部沖突時使用,其中K是所有客戶端個數(shù),tK是客戶端最近一次參與訓(xùn)練的輪次.之后在第15 行根據(jù)訓(xùn)練損失的大小對本輪參與訓(xùn)練的客戶端梯度進(jìn)行排序得到其中m是本輪參與訓(xùn)練的客戶端個數(shù).然后依次根據(jù)算法2 中的緩解內(nèi)部沖突算法和算法3 中的緩解外部沖突算法得到聚合梯度gt.算法2 和算法3 的主要作用是對聚合梯度gt的方向進(jìn)行修正以緩解non-IID 問題,因此在第20 行中,保留修正后的聚合梯度gt的方向與原始聚合梯度的大小得到最終的聚合梯度.最后使用與客戶端相同的STC 壓縮算法壓縮聚合梯度并發(fā)送至客戶端.
算法4.基于投影聚合的稀疏三元壓縮算法輸入: 初始化模型w 1.for do 2.服務(wù)器從K 個客戶端隨機(jī)選取m 個客戶端參與訓(xùn)練i=1,···,m t=1,···,T 3.for in parallel do Ci 4.客戶端 :5.從服務(wù)器端下載聚合梯度wti←wt-1i -gˉg 6.)-wti 7.gti←S TC(gti+errort-1,p)8.errort=gti-?gti 9.?gtilti 10.上傳客戶端梯度 和訓(xùn)練損失至服務(wù)器11.end for 12.服務(wù)器器端:?gtilti 13.接收參與訓(xùn)練的客戶端梯度 和訓(xùn)練損失gti←SGD(wti,Datai
GH={?gt11 ,?gt22 ,···,?gtKK 14.更新所有客戶端近鄰歷史梯度信息:POt={?gt1,?gt2,···,?gtm}15.根據(jù)客戶端訓(xùn)練損失對梯度排序:gt←MitigateInternalCon flict(POt,α)16.緩解內(nèi)部沖突:t≥τ}17.if then gt←MitigateExternalCon flict(gt,GH,τ)18.緩解外部沖突:19.end if gt=gt/||gt||*|| 1∑mi ?gti||20.m g=S TC(gt+error,p)21.22.error=gt-g 23.發(fā)送聚合梯度 至客戶端24.end for g
算法4 中的步驟可簡化為圖2,在客戶端,首先接收聚合梯度,然后根據(jù)模型和客戶端數(shù)據(jù)進(jìn)行本地訓(xùn)練得到客戶端梯度,本地訓(xùn)練完成后使用STC 算法壓縮梯度上傳至服務(wù)器,并計算壓縮誤差存儲在本地,在下一輪被選中訓(xùn)練時進(jìn)行梯度修正.
圖2 基于投影聚合的稀疏三元壓縮算法流程
服務(wù)端接收到所有參與訓(xùn)練的客戶端發(fā)送的梯度后判斷客戶端梯度之間是否存在梯度沖突,并依次通過緩解內(nèi)部沖突和外部沖突的算法對梯度方向進(jìn)行修正.最終聚合投影后的梯度生成全局梯度gt,采用STC 算法壓縮全局梯度gt得到發(fā)送至客戶端.該算法實(shí)現(xiàn)了客戶端與服務(wù)器之間的雙向壓縮,并且在服務(wù)器端進(jìn)行投影緩解數(shù)據(jù)異構(gòu)的問題.
本文的實(shí)驗(yàn)使用了MNIST 和CIFAR10 數(shù)據(jù)集.MNIST 數(shù)據(jù)集包含60 000 張訓(xùn)練圖片,10 000 張測試圖片,每張圖片是2 828 的灰度手寫數(shù)字圖像,實(shí)驗(yàn)使用帶有3 個卷積層的CNN 模型對MNIST 進(jìn)行訓(xùn)練.CIFAR10 數(shù)據(jù)集包含50 000 張訓(xùn)練圖片,10 000 張測試圖片,每張圖片是3 232 的RGB 圖像,使用文獻(xiàn)[18]中簡化的VGG11 網(wǎng)絡(luò)進(jìn)行訓(xùn)練.客戶端數(shù)據(jù)集劃分參照文獻(xiàn)[2],首先按照數(shù)據(jù)集的類別進(jìn)行排序,然后將數(shù)據(jù)集劃分為200 個分片,每個客戶端隨機(jī)選擇兩個不會替換的分片來模擬客戶端數(shù)據(jù)非獨(dú)立同分布的場景.實(shí)驗(yàn)中部分參數(shù)設(shè)置如表1 所示.
表1 參數(shù)設(shè)置
我們將本文提出的算法與FedAvg 以及稀疏三元壓縮算法進(jìn)行了對比,圖3 和圖4 是在MNIST 數(shù)據(jù)集上的結(jié)果,圖3 是全局模型在所有客戶端上的平均測試準(zhǔn)確率,圖4 為測試準(zhǔn)確率的方差,其中稀疏三元壓縮以及本文提出的算法在實(shí)驗(yàn)中設(shè)置了0.1 的稀疏率,也就是每輪傳遞10%的參數(shù)進(jìn)行訓(xùn)練,根據(jù)圖1 的實(shí)驗(yàn)結(jié)果可以看到本文提出的算法相較于其他算法收斂速度和收斂精度都略有提升,特別是相較于STC 算法,在相同壓縮率的條件下本文提出的算法大約在第75 輪收斂,而STC 算法在訓(xùn)練過程非常震蕩,并且在大約100 輪才收斂.
圖3 MNISTS 數(shù)據(jù)集測試正確率
圖4 MNISTS 數(shù)據(jù)集測試方差
圖5 和圖6 是在CIFAR10 數(shù)據(jù)集上的測試準(zhǔn)確率和測試方差,稀疏率同樣為0.1,與MNIST 數(shù)據(jù)集相比,在CIFAR10 數(shù)據(jù)集上的訓(xùn)練過程更加震蕩,但是本文提出的算法相較其他算法收斂速度和收斂精度都有大幅度提升,并且訓(xùn)練過程中的震蕩幅度遠(yuǎn)遠(yuǎn)小于FedAvg 和STC 算法,這說明本文的算法是非常有效的.
圖5 CIFAR10 數(shù)據(jù)集平均測試正確率
圖6 CIFAR10 數(shù)據(jù)集測試方差
表2 中記錄了客戶端與服務(wù)器之間每輪通信的參數(shù)大小,通信輪次是達(dá)到固定正確率(MNIST 95%CIFAR10 50% )大約所用的通信輪數(shù),以FedAvg 作為基線算法,本文提出的算法在上傳和下載時都進(jìn)行了壓縮,在MNIST 數(shù)據(jù)集上相較于FedAvg 每輪的通信量減少了45 倍,并且本文的算法在第100 輪時就達(dá)到了指定的正確率,相較于FedAvg 和STC 分別減少了97 和57 個通信輪次,在CIFAR10 數(shù)據(jù)集上每輪的通信量更是減少了47 倍,通信輪次相較于FedAvg 和STC 減少了295 輪和300 輪.
表2 通信開銷計算
本文提出了基于投影聚合的稀疏三元壓縮算法,提升聯(lián)邦學(xué)習(xí)的通信效率.該算法在客戶端和服務(wù)端采用稀疏三元壓縮減少客戶端在每一輪訓(xùn)練過程中上傳和下載的通信量,同時在服務(wù)器端利用梯度投影的方式緩解了由于客戶端數(shù)據(jù)異構(gòu)以及部分參與導(dǎo)致的梯度沖突問題.通過在MNIST 和CIFAR10 數(shù)據(jù)集上的實(shí)驗(yàn)驗(yàn)證,本文提出的算法在通信量、收斂速度和正確率3 個方面都要由于傳統(tǒng)的FedAvg 算法和稀疏三元壓縮算法.由于梯度壓縮會略微改變原始梯度的方向,在未來我們將針對不同的壓縮方法對投影聚合的方式做進(jìn)一步的研究,進(jìn)一步提高算法的有效性.