李小夏,李孝安
(西北工業(yè)大學 計算機學院,陜西 西安 710129)
神經(jīng)網(wǎng)絡剪枝算法就是在神經(jīng)網(wǎng)絡訓練好之后,以一定的標準或者允許誤差,刪除不重要的網(wǎng)絡節(jié)點以及網(wǎng)絡中部分節(jié)點連接的算法[1]。通過剪枝算法,能夠使得神經(jīng)網(wǎng)絡的結構得到簡化,計算性能(如計算時間)得到優(yōu)化,同時為神經(jīng)網(wǎng)絡的后期處理,比如神經(jīng)網(wǎng)絡的知識提取等要求網(wǎng)絡功能齊全,結構簡單的操作打下基礎,以防止由于網(wǎng)絡結構復雜,造成計算出現(xiàn)組合爆炸的問題[2]。
當前的剪枝算法大致分為3類:1)權衰減法[3]:權衰減法屬于正則化方法,它通過在網(wǎng)絡目標函數(shù)中引入表示結構復雜性的正則化項來達到降低網(wǎng)絡結構復雜性的目的。2)靈敏度計算方法[4]:靈敏度計算方法是指在網(wǎng)絡進行訓練時或在網(wǎng)絡訓練結束后,計算節(jié)點(輸入節(jié)點以及隱層節(jié)點)或連接權對網(wǎng)絡誤差的貢獻(靈敏度),刪除那些貢獻最小的節(jié)點或權。3)相關性剪枝方法[5]:根據(jù)節(jié)點間相關性或相互作用進行剪枝,也是一種很重要的剪枝方法,最常見的做法是先判斷隱節(jié)點輸出之間的相關性,然后合并具有較大相關性的隱節(jié)點。
文中主要是對相關性剪枝算法進行研究,首先介紹相關性剪枝算法的思想和計算方法,然后提出新的基于誤差傳遞的改進方案,最后通過實驗建立神經(jīng)網(wǎng)絡,并對網(wǎng)絡進行剪枝。將新的剪枝算法獲得的網(wǎng)絡與標準算法剪枝得到的網(wǎng)絡進行對比,新算法下的網(wǎng)絡相比于標準算法的網(wǎng)絡的精度得到提高。
相關性剪枝算法的核心就是查找隱層節(jié)點之間的相關性,合并那些相關性比較高的節(jié)點,刪除那些方差較小的節(jié)點,也就是輸出值幾乎為固定值的節(jié)點,并將去掉這些節(jié)點后引起的改變傳遞到下一層的權值連接以及輸出節(jié)點的偏置值中。
相關性在剪枝過程中實際上就是隱層節(jié)點輸出向量之間的線性相關性。當某兩個節(jié)點之間的線性相關性大于某一標準的時候,實際上可以認為兩個隱層節(jié)點的輸出是線性相關的,比如說第i個節(jié)點和第j節(jié)點,存在:
而下一層第k個節(jié)點的輸出:
M為隱層節(jié)點總數(shù),vm對第m個隱層節(jié)點的輸出,Ok為第k個輸出節(jié)點的輸出,wbk為第k個節(jié)點的偏置值。明顯可以用vi代替vj,重新提取vi系數(shù),代替wki,相應的也可以找出新的wbk。從而可以使用vi代替vj,即表示改變第i個節(jié)點與下一層節(jié)點的連接權值以及下一層節(jié)點的偏置值,就可以刪除掉第j個節(jié)點。
相關性算法要求先算出隱層節(jié)點的輸出結果,即隱層輸出數(shù)值矩陣,選取tansig作為隱層激活函數(shù)時,每個值位于[-1 1]區(qū)間上,然后算出每個隱層節(jié)點輸出之間的相關度以及各節(jié)點的輸出的方差。
第i個隱層節(jié)點和第j個隱層節(jié)點的相關度公式為:
實際上就是計算向量的線性相關度的公式。
各節(jié)點的輸出的方差公式如下:
其中P是隱層輸出的組數(shù),vip是第i個隱層節(jié)點的第p組輸出,vi為第i個隱層節(jié)點的輸出均值。
按以上公式計算后輸出序列呈以下5種情況:
1)兩個隱節(jié)點輸出序列高度正相關;
2)兩個隱節(jié)點輸出序列高度負相關;
3)兩個隱節(jié)點輸出序列相關性不高;
4)某隱節(jié)點輸出序列方差較?。?/p>
5)某隱節(jié)點輸出序列方差較大。
可以處理的是線性相關度高的以及方差較小的隱層節(jié)點。
1)當兩個節(jié)點線性相關度高時,對于第k個輸出節(jié)點:
最后得到:
其中, a,b 的值由式(1)得:
wbk為第k個輸出節(jié)點的偏置值。
2)方差較小的隱層節(jié)點,可以認為輸出為一個固定值,設第i個節(jié)點方差較小則直接用計算式可以得到:
綜上兩種方法,可以將剪枝后的相關改變傳遞到網(wǎng)絡權值以及下一層的偏置值中。
但在算法中,對于Rij和取值在哪個范圍才對之刪除并沒有明確的給出,有的材料上[5-6],給出了:Rij>θ1且>θ2,>θ2時刪除兩個相關節(jié)點中的一個;當<θ2,則認為該節(jié)點的輸出為固定值,可刪除,其中θ1和θ2是預設的閾值,但是到底這兩個值如何給出并沒有給出理論依據(jù),在網(wǎng)絡上也搜索的類似的剪枝算法的課件,同樣在θ1和θ2的取值上也沒有給出明確的方法。
平生沒有求過人的父親,將給我攢好的下學期的所有費用,換成名牌酒和茶葉,趁夜色帶我去校長家。父親拖著那條在車禍中被撞瘸的右腿,走起路來很是艱難。
實際計算時,由于隱層節(jié)點的方差計算完之后會有很多的方差值接近0而不等于0,選取那些接近0的方差,刪除對應的隱層節(jié)點能夠有效的簡化網(wǎng)絡結構,但是如何選擇方差允許的上限θ2就是實際刪除操作中要注意的問題??紤]到如果每次網(wǎng)絡結果允許存在一定的誤差且隱層的實際輸出在這個誤差允許范圍內便認為其值是合理的,那么以誤差傳遞作為基本思想,給出每個隱層節(jié)點的實際值與節(jié)點輸出的均值的差,如果差值在一定的范圍內則認為是合理的、沒有誤差的,并在計算總結所有節(jié)點誤差后,刪除沒有誤差的隱層節(jié)點。于是怎么計算由輸出誤差傳遞到隱層節(jié)點允許的誤差就是至關重要的。
按照誤差傳遞的思想,假設神經(jīng)網(wǎng)絡[7]在當前的精度下,允許誤差下降率為δ,則對于第k個輸出節(jié)點,如期望值為Ak,則此處允許誤差為δ*Ak,相應的每個隱層節(jié)點允許的誤差為 xi,則:
此時可做兩種假設,如:假設每一個隱層節(jié)點的輸出誤差都相等,即:xi=xj,此時重新計算時:
則對于所有的輸出節(jié)點,允許誤差:
此處的x對所有隱層節(jié)點都是相同的。
另一種假設為:每個隱層節(jié)點的輸出乘以權值后誤差相等,即wkixi=wkjxj,此時重新計算時:
則對于所有的輸出節(jié)點,允許誤差:
此處的xi分別對應到第i個隱層節(jié)點。
以兩種方式計算的xi,即為希望求得的各節(jié)點的允許誤差,那么在重新計算隱層節(jié)點輸出方差時,若:
此處vi是第i個節(jié)點輸出的均值。
則令vip-vi=0,計入到方差,計算公式(4)的值,這樣各個節(jié)點的誤差值都能求出。最后將所有方差為0的節(jié)點刪除掉,并根據(jù)公式(9),計算固定值輸出節(jié)點的刪除下的網(wǎng)絡變化,修正權值和偏置值,以簡化網(wǎng)絡的結構。
選取數(shù)據(jù)是UCI機器學習數(shù)據(jù)中的Car Evaluation Data-base,是根據(jù)車的價格和車身配置進行選取的分類問題,將其中 1 728 個樣例的 501~600,1 001~1 100,1 601~1 700 樣例抽取出來作為測試樣例,其余1 428個作為網(wǎng)絡的訓練樣例。
網(wǎng)絡最開始選擇時,從5~40個隱層節(jié)點進行遍歷,網(wǎng)絡權值初始化全為0,偏置值全為1,找到17個隱層節(jié)點時網(wǎng)絡的誤差精度相對不錯,達到0.013。而后開始使用相關性剪枝。 計算時發(fā)現(xiàn)[11 12]節(jié)點的方差為 0,[2,17][4,7,13][6,8]幾個節(jié)點的相關性達到0.94以上。如果單獨刪除11,12節(jié)點,網(wǎng)絡的誤差精度不變,為 0.013;單獨刪除[2,17][4,7,13][6,8]中17,7,13,8節(jié)點,網(wǎng)絡的誤差精度為0.030;使用改進方法,A 設為 1,δ設為 0.01,使 xi=xj,則可刪除節(jié)點還有 9,14,在刪除9,11,12,14時網(wǎng)絡的誤差精度為 0.019,上升0.06,相對于以線性相關刪除 17,7,13,8節(jié)點后的網(wǎng)絡的誤差0.030,新剪枝方法獲得的網(wǎng)絡性能還是不錯的。
以下為5個網(wǎng)絡的重新訓練后,訓練以及測試結果:net_0 原網(wǎng)絡(未剪枝網(wǎng)絡);net_1 刪除 9,11,12,14 節(jié)點的網(wǎng)絡(新算法網(wǎng)絡);net_2 刪除 7,8,13,17 節(jié)點的網(wǎng)絡(刪除線性相關節(jié)點生成的網(wǎng)絡,此處的θ1=0.94);net_3刪除7,11,12,13,14,17 節(jié)點的網(wǎng)絡(標準相關性算法剪枝網(wǎng)絡,此處的 θ1=0.94,θ2=0.94);net_4是直接以前述初始化權值訓練生成的13個節(jié)點的網(wǎng)絡,故而訓練前的精度計算沒有意義。
表1 5種神經(jīng)網(wǎng)絡參數(shù)對比Tab.1 Comparison of five artificial neural networks
在訓練前,對比 net_0 與 net_1、net_2、net_3,新算法下的net_1相較于原網(wǎng)絡,在減少了4個節(jié)點的情況下訓練前的訓練精度以及測試精度都稍微下降,但是比之刪去線性相關節(jié)點的net_2,同樣減少4個節(jié)點的情況下,網(wǎng)絡的樣本精度和測試精度高出近4%,足見采用新的剪切算法對于網(wǎng)絡精度的影響要小得多。而相對于標準剪枝算法下的net_3,net_1多出兩個節(jié)點,訓練前的樣本精度和測試精度都也要高出近4%。
訓練之后,對比 net_1、net_2、net_3、net_4,樣本精度和測試精度最高的是net_2,即刪去相關性節(jié)點的網(wǎng)絡,但是新算法下的net_1與之相差不大,在考慮訓練樣本與測試樣本的選取因素下,兩種方法的訓練后的結果幾乎相差無幾。而net_1與標準算法net_3和直接生成神經(jīng)網(wǎng)絡的net_4相比,網(wǎng)絡無論是樣本精度還是測試精度都要高出3-4%,存在較大優(yōu)勢。
時間消耗上,對比可見,網(wǎng)絡隱層節(jié)點數(shù)越多,訓練的時間開銷越大。而在同樣多隱層節(jié)點數(shù)的幾個網(wǎng)絡中,net_1的時間消耗最少,當然因為機器運行時的實際情況不同,也會對網(wǎng)絡的時間開銷有影響,此條僅參考。
綜上實驗對比,新算法在神經(jīng)網(wǎng)絡直接剪枝后,無論是樣本精度還是測試精度都高于其它的基于相關性剪枝算法所獲得的精度,剪枝效果明顯不錯。重新對網(wǎng)進行訓練后,網(wǎng)絡的樣本精度和測試精度略低于刪去線性相關節(jié)點的網(wǎng)絡,但是相差不大。而比之標準相關性網(wǎng)絡和直接訓練13個節(jié)點的網(wǎng)絡,精度要高得多。訓練時間開銷上,新算法生成的網(wǎng)絡略少。
文中關于相關性算法的改進是研究神經(jīng)網(wǎng)絡知識規(guī)則提取方法過程中發(fā)現(xiàn)的問題并進行研究和改進。首先通過不斷增長隱層節(jié)點的方法選取節(jié)點數(shù),而后通過相關性算法和改進算法對網(wǎng)絡進行了剪枝,以達到簡化網(wǎng)絡結構的目的。實驗表明新算法的網(wǎng)絡剪枝方面不僅能夠盡量減少節(jié)點數(shù),還能保證對網(wǎng)絡精度影響較小,對比于其他方法下的網(wǎng)絡有一定的優(yōu)勢。目前給出的方法思路和計算相對簡單,許多地方有待改進,今后將進一步對之進行研究。
[1]Reed R.Pruning algorithms-a survey[J].IEEE Trans Neural Networks,1993,4(5):740-747.
[2]周志華,陳世福.神經(jīng)網(wǎng)絡規(guī)則抽取[J].計算機研究與發(fā)展,2002:39(4):398-405.
ZHOU Zhi-hua,CHEN Shi-fu.Rule extraction from neural networks[J].Journal of Computer Research and Development,2002,39(4):398-405.
[3]雷素娟.一種改進的RBF神經(jīng)網(wǎng)絡及其在股市中的應用[D].泉州:華僑大學,2009.
[4]朱萬富.基于粗集神經(jīng)網(wǎng)絡的故障診斷專家系統(tǒng)研究[D].青島:中國石油大學(華東),2007.
[5]宋清昆,郝敏.基于改進相關性剪枝算法的B P神經(jīng)網(wǎng)絡的結構優(yōu)化[J].控制理論與應用,2006,25(12):4-6
SONG Qing-kun,HE Min.Structural optimization of BP neural network based on correlation pruning algorithm[J].Control Theory and Applications, 2006, 25(12):4-6.
[6]趙壽玲.BP神經(jīng)網(wǎng)絡結構優(yōu)化方法的研究與應用[D].蘇州:蘇州大學,2010.
[7]杜雙育,袁紅波,王先培.基于BP神經(jīng)網(wǎng)絡和氣象統(tǒng)計的絕緣子污閃預警研究[J].陜西電力,2012(11):8-11,33.
DU Shuang-yu,YUAN Hong-bo,WANG Xian-pei.Early warning of insulator pollution flashover based on BP Network and meteorological statistics[J].Shaanxi Electric Power,2012(11):8-11,33.