周安眾, 謝丁峰
(湖南工業(yè)職業(yè)技術(shù)學(xué)院信息工程學(xué)院, 湖南 長沙 410208)
近年來,研究人員為了消除交通流預(yù)測中的時(shí)空因素的影響,采用了各種深度學(xué)習(xí)模型[1]。其中,圖卷積神經(jīng)網(wǎng)絡(luò)(Graph Convolutional Network,GCN)用于建模交通網(wǎng)絡(luò)這類非歐結(jié)構(gòu)數(shù)據(jù)已被證明是有效的[2]。在時(shí)間建模方面,遞歸神經(jīng)網(wǎng)絡(luò)(Recurrent Neural Network,RNN)是短時(shí)交通流預(yù)測(30 min以內(nèi))中的常用方法[3]。然而,對更長時(shí)間(45 min以上)的交通流預(yù)測效果仍不理想[4],具體原因如下。
(1)路網(wǎng)中節(jié)點(diǎn)信息不僅隨相鄰節(jié)點(diǎn)變化,還間接受到更遠(yuǎn)距離節(jié)點(diǎn)(高階節(jié)點(diǎn))的影響。傳統(tǒng)深度學(xué)習(xí)模型以局部特征為中心,需要疊加多層結(jié)構(gòu)提取全局特征,訓(xùn)練時(shí)容易出現(xiàn)梯度彌散問題,不能有效提取長距離空間依賴關(guān)系。
(2)交通流量發(fā)生周期性的變化,如日高峰時(shí)間或其他周期事件,有必要引入長期的歷史時(shí)間依賴關(guān)系捕捉周期特征。長短期記憶網(wǎng)絡(luò)和門控循環(huán)單元常用來提取長距離時(shí)間依賴關(guān)系,但在迭代訓(xùn)練時(shí)可能會(huì)導(dǎo)致誤差累積[5-7]。
綜上所述,本文提出一種結(jié)構(gòu)化圖注意力網(wǎng)絡(luò)模型(Structural Graph Attention Network,SGAN)。一方面,SGAN采用注意力機(jī)制替代了傳統(tǒng)深度學(xué)習(xí)架構(gòu)中的局部卷積核,使其在長距離約束情況下能高效提取時(shí)空特征。另一方面,SGAN在計(jì)算注意力系數(shù)時(shí)利用節(jié)點(diǎn)的結(jié)構(gòu)信息,不僅考慮了節(jié)點(diǎn)之間的特征相關(guān)性,還考慮了節(jié)點(diǎn)的連接方式。
對于交通流預(yù)測任務(wù),建立圖G=(V,ε,A),其中V是圖G中的節(jié)點(diǎn)集,ε是邊集,反映節(jié)點(diǎn)之間的連通性。A∈N×N表示G的鄰接矩陣。同時(shí),建立矩陣X(t)∈N×d作為t時(shí)刻的輸入特征矩陣,其中N表示節(jié)點(diǎn)數(shù),d表示特征的維度。給定圖G中T個(gè)時(shí)間步的輸入特征序列(X=[x(t-T+1),…,x(t)]),通過學(xué)習(xí)映射函數(shù)f預(yù)測未來T′個(gè)時(shí)間步的交通流(Y=[x(t+1),…,x(t+T′)])。G提供了圖的空間信息,T則提供了圖中交通流的時(shí)間信息,學(xué)習(xí)函數(shù)f將輸入的T個(gè)歷史時(shí)間步數(shù)據(jù)映射到未來的T′個(gè)時(shí)間步上。
本文提出的SGAN整體架構(gòu)如圖1所示,由多個(gè)時(shí)空注意力模塊組成,每個(gè)模塊包含一個(gè)空間依賴層(Spatial Dependencies Layer, SDL)和一個(gè)時(shí)間依賴層(Temporal Dependencies Layer, TDL),并加入了殘差結(jié)構(gòu)[8]。歷史周期數(shù)據(jù)經(jīng)過特征轉(zhuǎn)換后作為時(shí)空注意力模塊的輸入,并通過多個(gè)SDL和TDL生成最終輸出。SDL主要關(guān)注節(jié)點(diǎn)在圖上的結(jié)構(gòu)依賴關(guān)系,而TDL則主要關(guān)注節(jié)點(diǎn)在時(shí)間步上的順序依賴關(guān)系。
圖1 SGAN模型的整體架構(gòu)圖Fig.1 The overall architecture of the SGAN model
一般來說,輸入的T個(gè)歷史時(shí)間步與待預(yù)測的T′個(gè)時(shí)間步相鄰,為最近一個(gè)小時(shí)內(nèi)的歷史觀測數(shù)據(jù)。此外,每日和每周的長期歷史觀測數(shù)據(jù)也能幫助預(yù)測未來交通流的周期變化[9],日數(shù)據(jù)設(shè)為待預(yù)測時(shí)間前一天的相同時(shí)間步,周數(shù)據(jù)設(shè)為待預(yù)測時(shí)間前一周的相同時(shí)間步,則最近的小時(shí)數(shù)據(jù)記為Xh,日數(shù)據(jù)記為Xd,周數(shù)據(jù)記為Xw,將三個(gè)數(shù)據(jù)進(jìn)行連接作為SGAN的最終輸入數(shù)據(jù):
X=concat(Xh,Xd,Xw)
(1)
數(shù)據(jù)經(jīng)過連接后包含交通流不同的周期模式,有利于模型提取長期影響下的周期變化特征,然后通過一個(gè)全連接層對其特征進(jìn)行變換,得到一個(gè)高維向量,輸入到時(shí)空注意力模塊進(jìn)行特征提取[10]。
交通網(wǎng)絡(luò)圖中不同位置的節(jié)點(diǎn)因不同的連接方式導(dǎo)致其重要性存在差異,SDL中引入圖注意力機(jī)制自適應(yīng)地計(jì)算每個(gè)節(jié)點(diǎn)的注意力系數(shù),即節(jié)點(diǎn)的重要度[11]。對于節(jié)點(diǎn)vi,對應(yīng)特征為xi,則其輸出特征zi通過公式(2)計(jì)算:
(2)
本質(zhì)上,公式(2)是對vi的鄰居節(jié)點(diǎn)的聚合操作[12]。xj是鄰居節(jié)點(diǎn)特征,W∈dx×dz是待學(xué)習(xí)的參數(shù)矩陣,σ是激活函數(shù),αij是vi相對于vj的注意力系數(shù),由以下公式計(jì)算:
(3)
其中,n是vi的鄰居節(jié)點(diǎn)數(shù)量,eij利用公式(2)中的共享參數(shù)W對特征進(jìn)行變換,再由單層前饋神經(jīng)網(wǎng)絡(luò)函數(shù)a得到:
eij=a(Wxi,Wxj)
(4)
公式(2)通過聚合vi的鄰居節(jié)點(diǎn)特征得到zi,沒有考慮每個(gè)節(jié)點(diǎn)的連接方式。如圖2所示,假設(shè)以A為中心節(jié)點(diǎn),接收來自其鄰居節(jié)點(diǎn)的信息,圖3中D的不同邊數(shù)將導(dǎo)致傳遞到節(jié)點(diǎn)A的信息強(qiáng)度不同。因此,需要考慮中心節(jié)點(diǎn)與其相鄰節(jié)點(diǎn)之間的拓?fù)潢P(guān)系,并對拓?fù)浣Y(jié)構(gòu)進(jìn)行編碼,從而得到圖的空間依賴關(guān)系。
圖2 A節(jié)點(diǎn)接收鄰居節(jié)點(diǎn)傳遞的信息Fig.2 Node A receiving information transmitted by neighboring nodes
圖3 鄰居節(jié)點(diǎn)的結(jié)構(gòu)不同導(dǎo)致信息傳遞強(qiáng)度不同F(xiàn)ig.3 Different information intensities caused by different structures of neighboring nodes
基于以上問題,在注意力機(jī)制中引入結(jié)構(gòu)信息。首先將單位矩陣I添加到鄰接矩陣A中,為每個(gè)節(jié)點(diǎn)創(chuàng)建一個(gè)自循環(huán):
(5)
(6)
(7)
(8)
圖4 SDL提取高階鄰居節(jié)點(diǎn)的結(jié)構(gòu)信息Fig.4 SDL extracting structural information of high-order neighbor
(9)
隨著k的增加,注意力機(jī)制將應(yīng)用于更高階的節(jié)點(diǎn),即更長距離的節(jié)點(diǎn)。結(jié)合公式(2),定義vi與第k階鄰居的注意力機(jī)制如下:
(10)
將提取到的不同階特征連接起來:
(11)
圖5 TDL提取節(jié)點(diǎn)不同時(shí)間步之間的依賴關(guān)系Fig.5 TDL extracting dependencies between nodes at different time steps
(12)
(13)
(14)
為了提高特征的表達(dá)能力,采用多頭注意力[14],由n個(gè)并行的注意力機(jī)制共同輸出,并連接為最終的輸出:
(15)
模型將上述多個(gè)時(shí)空模塊進(jìn)行堆疊,形成多層結(jié)構(gòu),并采用具有ReLU激活函數(shù)的全連接層作為模型的最終輸出:
Y=ReLU(φH+b)∈T′
(16)
公式(16)不僅再次提高了特征表示的能力,還將最終的輸出變換為真實(shí)預(yù)測值的維度,其中T′是用于預(yù)測的時(shí)間步長數(shù)。
本文在以下三個(gè)真實(shí)交通數(shù)據(jù)集上驗(yàn)證SGAN的性能。
(1)METR-LA:該數(shù)據(jù)集包含從美國洛杉磯高速公路上的探測器收集的交通信息,選擇了其中207個(gè)傳感器從2012年3月1日到2012年6月30日采集的數(shù)據(jù)。
(2)PeMS-BAY:該數(shù)據(jù)集由美國加利福尼亞州運(yùn)輸機(jī)構(gòu)(CalTrans)收集,包含325個(gè)傳感器從2017年1月到2017年5月采集的數(shù)據(jù)。
(3)PeMS-S:該數(shù)據(jù)集也是由美國加利福尼亞州運(yùn)輸機(jī)構(gòu)(CalTrans)收集,由228個(gè)傳感器從2012年5月至2012年6月采集的數(shù)據(jù)。
以上所有數(shù)據(jù)集的時(shí)間步長的間隔為5 min,應(yīng)用Z-Score標(biāo)準(zhǔn)化,70%的數(shù)據(jù)用于訓(xùn)練,20%的數(shù)據(jù)用于測試,其余10%的數(shù)據(jù)用于驗(yàn)證。
本文實(shí)驗(yàn)均由谷歌TensorFlow深度學(xué)習(xí)平臺(tái)實(shí)現(xiàn),運(yùn)行在NVIDIA Titan RTX GPU平臺(tái)上,使用RMSprop優(yōu)化器,迭代訓(xùn)練50次。三個(gè)周期數(shù)據(jù)的時(shí)間步長為Th=12,Td=12,Tw=12,歷史時(shí)間步長總和T=36。預(yù)測時(shí)間步長T′=12,通過全連接神經(jīng)網(wǎng)絡(luò)將輸入的原始周期數(shù)據(jù)轉(zhuǎn)換為16維的特征向量,作為第一個(gè)時(shí)空注意力模塊的輸入。整個(gè)架構(gòu)堆疊2個(gè)時(shí)空注意力模塊,每個(gè)模塊采用8頭注意力機(jī)制,高階參數(shù)k為3。
本文采用三個(gè)度量標(biāo)準(zhǔn)用于評估SGAN的性能,包括平均絕對誤差(MAE)、均方根誤差(RMSE)和平均絕對百分比誤差(MAPE),將SGAN與以下代表性的基準(zhǔn)模型進(jìn)行比較。
(1)ARIMA:差分自回歸移動(dòng)平均模型,一種時(shí)間序列分析方法[15]。
(2)FC-LSTM:解決了LSTM只考慮時(shí)序,沒有考慮空間相關(guān)性的問題,使用了帶有全連接層的LSTM模型提取時(shí)空依賴關(guān)系[16]。
(3)STGCN:時(shí)空圖卷積模型,使用CNN卷積的方式提取時(shí)空依賴關(guān)系[17]。
(4)DCRNN:擴(kuò)散卷積循環(huán)神經(jīng)網(wǎng)絡(luò),將交通流建模為擴(kuò)散過程,采用擴(kuò)散卷積提取空間依賴關(guān)系,采用RNN提取時(shí)間依賴關(guān)系[18]。
(5)Graph WaveNet:一種圖神經(jīng)網(wǎng)絡(luò)架構(gòu),采用擴(kuò)散圖卷積提取空間依賴關(guān)系,采用帶空洞的卷積核提取時(shí)間依賴關(guān)系[19]。
(6)ASTGCN:時(shí)空注意力模型,結(jié)合注意力機(jī)制和卷積核同時(shí)捕獲時(shí)空依賴關(guān)系[9]。
(7)SLCNN:結(jié)構(gòu)學(xué)習(xí)卷積模型,將傳統(tǒng)的CNN擴(kuò)展到圖域并學(xué)習(xí)圖的結(jié)構(gòu)信息用于交通流預(yù)測[20]。
(8)NS-SGAN:在SGAN的基礎(chǔ)上,去除了SDL和TDL中的轉(zhuǎn)移矩陣和結(jié)構(gòu)化向量,評估SGAN中結(jié)構(gòu)化設(shè)計(jì)的有效性。
表1和表2分別顯示了在METR-LA和PeMS-BAY數(shù)據(jù)集上,SGAN與其他基準(zhǔn)模型在15 min、30 min和60 min時(shí)的預(yù)測結(jié)果比較,從中可以看出,基于圖結(jié)構(gòu)的STGCN、GraphWaveNet以及本文的SGAN大部分?jǐn)?shù)據(jù)皆優(yōu)于FC-LSTM等非圖結(jié)構(gòu)模型,說明交通網(wǎng)絡(luò)的結(jié)構(gòu)信息對交通預(yù)測至關(guān)重要。與其他時(shí)空模型相比,采用了注意力機(jī)制的SGAN大部分?jǐn)?shù)據(jù)優(yōu)于基于卷積的STGCN和基于遞歸的DCRNN模型,而SGAN采用注意力機(jī)制提取長距離依賴關(guān)系,在60 min的預(yù)測中優(yōu)勢更明顯,表明注意力機(jī)制在長時(shí)預(yù)測方面更有效。同時(shí),SGAN的結(jié)果明顯優(yōu)于去除結(jié)構(gòu)化信息的NS-SGAN,證明了注意力機(jī)制中的結(jié)構(gòu)信息是重要的。
表 1 SGAN和基準(zhǔn)模型在METR-LA上的性能比較
表3比較了SGAN與基準(zhǔn)模型在PeMS-S數(shù)據(jù)集上針對15 min、30 min、45 min和60 min交通流的預(yù)測結(jié)果。采用圖結(jié)構(gòu)和注意力機(jī)制的模型,其表現(xiàn)優(yōu)于包括ARIMA和FC-LSTM在內(nèi)的時(shí)間序列模型,與基于GCN和RNN結(jié)構(gòu)的STGCN、DCRNN模型相比,采用注意力機(jī)制的SGAN效果更優(yōu)。由于ASTGCN只是部分采用注意力機(jī)制,仍然結(jié)合了卷積核進(jìn)行特征提取,因此其無法有效提取長距離依賴關(guān)系,需要多個(gè)獨(dú)立模塊分別提取不同的周期特征,增加了模型的復(fù)雜性。相比Graph WaveNet,SGAN在45 min和60 min的預(yù)測上表現(xiàn)更好,證明了注意機(jī)制對更長時(shí)間交通流預(yù)測的能力;Graph WaveNet由于采用了小感受野的卷積核,因此在提取短距離依賴方面的表現(xiàn)不錯(cuò),在長距離依賴方面不如SGAN有效,特別是SGAN采用完全的注意力機(jī)制,可以直接提取長距離時(shí)空依賴性,不需要通過疊加太深的層增加感受野,從而降低了模型訓(xùn)練的難度。
表 3 在PeMS-S 數(shù)據(jù)集上的實(shí)驗(yàn)結(jié)果對比
本文還分析了高階情況下模型的性能,圖6顯示了在PeMS-BAY數(shù)據(jù)集上MAE取不同k值的結(jié)果比較??梢钥闯?k值越大,數(shù)據(jù)集上的MAE結(jié)果越小。當(dāng)模型考慮更高階的鄰居時(shí),k必須設(shè)置得足夠大,以覆蓋每個(gè)節(jié)點(diǎn)足夠遠(yuǎn)的鄰居。然而,高階信息意味著更多的參數(shù),當(dāng)參數(shù)數(shù)量增加時(shí),訓(xùn)練時(shí)過擬合問題會(huì)導(dǎo)致效果不明顯。此外,對于較大的k值,經(jīng)過多層特征提取后,圖上節(jié)點(diǎn)的信息傳遞會(huì)過于平滑,注意力系數(shù)的權(quán)重會(huì)趨向于平均分配,對節(jié)點(diǎn)失去區(qū)分性。
圖6 PeMS-BAY數(shù)據(jù)集上取不同k值的MAE結(jié)果Fig.6 MAE results of different values of k on PeMS-BAY
為了觀察注意力機(jī)制的影響,從SDL提取出學(xué)習(xí)到的注意力系數(shù)矩陣并可視化。從PeMS-S數(shù)據(jù)集中隨機(jī)選擇的16個(gè)傳感器(節(jié)點(diǎn)),其連接的結(jié)構(gòu)如圖7所示。圖8是當(dāng)k=1時(shí)的注意力系數(shù)熱力圖,點(diǎn)(x,y)處的像素值表示節(jié)點(diǎn)x對節(jié)點(diǎn)y的注意力系數(shù),像素值顏色越深,表示對應(yīng)的注意力系數(shù)權(quán)重越高。由于鄰接矩陣A的對稱性,因此系數(shù)矩陣也具有對稱的形式。當(dāng)k大于1時(shí),觀察到許多較深色的非對角塊,因?yàn)殡A數(shù)增加時(shí),在感受野中有更多的節(jié)點(diǎn),相當(dāng)于可以提取到更長距離的節(jié)點(diǎn)依賴關(guān)系。以節(jié)點(diǎn)11為例,圖8只考慮一階鄰居節(jié)點(diǎn),如節(jié)點(diǎn)2、7、8及其自身。當(dāng)k值增大時(shí),可以從圖9和圖10中看到,節(jié)點(diǎn)11與更遠(yuǎn)處的節(jié)點(diǎn)產(chǎn)生了明顯的注意力系數(shù),即考慮了長距離的空間依賴性。
圖7 隨機(jī)選取節(jié)點(diǎn)形成的結(jié)構(gòu)圖Fig.7 Structure diagram formed by randomly selecting nodes
圖8 k為1時(shí)注意力系數(shù)熱力圖Fig.8 Heat map of attention coefficient when k is 1
圖9 k為2時(shí)注意力系數(shù)熱力圖Fig.9 Heat map of attention coefficient when k is 2
圖10 k為3時(shí)注意力系數(shù)熱力圖Fig.10 Heat map of attention coefficient when k is 3
本文提出了一種基于圖注意力機(jī)制的交通流預(yù)測模型,該模型摒棄了傳統(tǒng)的卷積、遞歸單元,采用注意力機(jī)制能有效地提取交通網(wǎng)絡(luò)圖中的長距離時(shí)空依賴關(guān)系。同時(shí),在注意力機(jī)制中引入了空間結(jié)構(gòu)和時(shí)間結(jié)構(gòu)信息,使模型不僅考慮了節(jié)點(diǎn)之間的特征相關(guān)性,還考慮了節(jié)點(diǎn)的連接方式,在長距離時(shí)空上能有效區(qū)分不同的節(jié)點(diǎn)。在真實(shí)數(shù)據(jù)集上的實(shí)驗(yàn)表明,該模型的預(yù)測結(jié)果明顯優(yōu)于現(xiàn)有模型。