\providecommand\N \providecommand\Q \providecommand\R \providecommand\Z \providecommand\pa \providecommand\br Extra open brace or missing close brace \providecommand\abs \providecommand\norm \providecommand\floor \providecommand\ceil \providecommand\eval \providecommand\pd \providecommand\Lim \providecommand\Prod \providecommand\Sum \providecommand\algoProc \providecommand\algoEndProc \providecommand\algoIf \providecommand\algoEndIf \providecommand\algoEq \providecommand\algoFor \providecommand\algoEndFor \providecommand\algoWhile \providecommand\algoEndWhile \providecommand\algoReturn \providecommand\hash

目標 提出 RNN 使用 BPTT 進行最佳化時遇到的問題,並提出 LSTM 架構進行修正
作者 Sepp Hochreiter, Jürgen Schmidhuber
期刊/會議名稱 Neural Computation
發表時間 1997
論文連結 https://ieeexplore.ieee.org/abstract/document/6795963
書本連結 https://link.springer.com/chapter/10.1007/978-3-642-24797-2_4

\providecommand\opnet \providecommand\opin \providecommand\opout \providecommand\ophid \providecommand\opblk \providecommand\opig \providecommand\opog \providecommand\opseq \providecommand\Loss \providecommand\loss \providecommand\net \providecommand\fnet \providecommand\dfnet \providecommand\din \providecommand\dout \providecommand\dhid \providecommand\dblk \providecommand\nblk \providecommand\tp \providecommand\tf \providecommand\dv \providecommand\blk \providecommand\wig \providecommand\wog \providecommand\whid \providecommand\wblk \providecommand\wout \providecommand\netig \providecommand\fnetig \providecommand\dfnetig \providecommand\netog \providecommand\fnetog \providecommand\dfnetog \providecommand\nethid \providecommand\fnethid \providecommand\dfnethid \providecommand\netout \providecommand\fnetout \providecommand\dfnetout \providecommand\netcell \providecommand\aptr

重點

  • 此篇論文LSTM-2000 都寫錯自己的數學公式,但我的筆記內容主要以正確版本為主
  • 計算 RNN 梯度反向傳播的演算法包含 BPTTRTRL
    • BPTT 全名為 Back-Propagation Throught Time
    • RTRL 全名為 Real Time Recurrent Learning
  • 不論使用 BPTT 或 RTRL,RNN 的梯度都會面臨爆炸消失的問題
    • 梯度爆炸造成神經網路的權重劇烈振盪
    • 梯度消失造成訓練時間慢長
    • 無法解決時間差較長的問題
  • 論文提出 LSTM + RTRL 能夠解決上述問題
    • Backward pass 演算法時間複雜度O(w)w 代表權重
    • Backward pass 演算法空間複雜度也為 O(w),因此沒有輸入長度的限制
    • 此結論必須依靠丟棄部份梯度並使用 RTRL 才能以有效率的辦法解決梯度爆炸消失
  • 使用乘法閘門Multiplicative Gate)學習開啟 / 關閉模型記憶寫入 / 讀取機制
  • LSTM 的閘門單元參數應該讓偏差項(bias term)初始化成負數
    • 閘門偏差項初始化成負數能夠解決內部狀態偏差行為Internal State Drift
    • 閘門偏差項初始化成負數能夠避免模型濫用記憶單元初始值訓練初期梯度過大
    • 如果沒有輸出閘門,則收斂速度會變慢
  • 根據實驗 LSTM 能夠達成以下任務
    • 擁有處理短時間差Short Time Lag)任務的能力
    • 擁有處理長時間差Long Time Lag)任務的能力
    • 能夠處理最長時間差長達 1000 個單位的任務
    • 輸入訊號含有雜訊時也能處理
  • LSTM 的缺點
    • 仍然無法解決 delayed XOR 問題
      • 改成 BPTT 可能可以解決,但計算複雜度變高
      • CEC 在使用 BPTT 後有可能無效,但根據實驗使用 BPTT 時誤差傳遞的過程中很快就消失
    • 在部份任務上無法比 random weight guessing 最佳化速度還要快
      • 例如 500-bit parity
      • 使用 CEC 才導致此後果
      • 但計算效率高,最佳化過程也比較穩定
    • 無法精確的判斷重要訊號的輸入時間
      • 所有使用梯度下降作為最佳演算法的模型都有相同問題
      • 如果精確判斷是很重要的功能,則作者認為需要幫模型引入計數器的功能
  • 當單一字元的出現次數期望值增加時,學習速度會下降
    • 作者認為是常見字詞的出現導致參數開始振盪
  • PyTorch 實作的 LSTM 完全不同
    • 本篇論文的架構定義更為廣義
    • 本篇論文只有輸入閘門Input Gate)跟輸出閘門Output Gate),並沒有使用失憶閘門Forget Gate

傳統的 RNN

計算定義

一個 RNN 模型在 t 時間點的輸入來源共有兩種:

  • 外部輸入External Inputx(t)
    • 輸入維度為 din
    • 使用下標 xj(t) 代表不同的輸入訊號,j{1,,din}
  • 總輸出Total Outputy(t)
    • 輸出維度為 dout
    • 使用下標 yj(t) 代表不同的輸入訊號,j{1,,dout}
    • 注意這裡是使用 t 不是 t1
  • 總共計算 T 個時間點
    • 時間為離散狀態,t 的起始值為 0,結束值為 T1,每次遞增 1
    • 初始化 y(0)=0,輸入為 x(0),,x(T1),輸出為 y(1),,y(T)

令 RNN 模型的參數wRdout×(din+dout),如果我們已經取得 t 時間點的外部輸入 x(t)總輸出 y(t),則我們可以定義 t+1 時間點的計算狀態

(1)neti(t+1)=j=1doutwi,jyj(t)+j=1dinwi,dout+jxj(t)=j=1dout+dinwi,j(y(t)x(t))jnet(t+1)=w(y(t)x(t))
  • neti(t+1) 代表第 t+1 時間的模型內部節點 i 所收到的淨輸入(total input)
    • t 時間點的輸入訊號計算 t+1 時間點的輸出結果
    • 這是早年常見的 RNN 公式表達法
  • wi,j 代表輸入節點 j模型內部節點 i 所連接的權重
    • 輸入節點可以是外部輸入 xj(t) 或是總輸出 yj(t)
    • 總共有 din+dout 個輸入節點,因此 1jdin+dout
    • 總共有 dout 個內部節點,因此 1idout

令模型使用的啟發函數Activation Function)為 f:RdoutRdout,並且內部節點之間無相互連接(Element-wise Activation Function),則我們可以得到 t+1 時間的輸出

(2)yi(t+1)=fi(neti(t+1))y(t+1)=f(net(t+1))
  • 使用下標 fi 是因為每個維度所使用的啟發函數可以不同
  • f 必須要可以微分,當時與 RNN 有關的論文幾乎都是令 fi 為 sigmoid 函數 σ(x)=1/(1+ex)
  • 後續論文分析都是採用 sigmoid 函數,因此我們直接以 σ 表達 fi

計算誤差

如果 t+1 時間點的輸出目標y^(t+1)Rdout,則目標函數最小平方差(Mean Square Error):

(3)lossi(t+1)=12(yi(t+1)y^i(t+1))2loss(t+1)=i=1doutlossi(t+1)

梯度計算

根據 (3) 我們知道 yi(t+1)loss(t+1) 所得梯度為

(4)loss(t+1)yi(t+1)=loss(t+1)lossi(t+1)lossi(t+1)yi(t+1)=1(yi(t+1)y^i(t+1))=yi(t+1)y^i(t+1)

根據 (4) 我們可以推得 neti(t+1)loss(t+1) 所得梯度

(5)loss(t+1)neti(t+1)=loss(t+1)yi(t+1)yi(t+1)neti(t+1)=σ(neti(t+1))(yi(t+1)y^i(t+1))

式子 (5) 就是論文 3.1.1 節的第一條公式。

根據 (5) 我們可以推得 yj(t)loss(t+1) 所得梯度為

(6)loss(t+1)yj(t)=i=1dout[loss(t+1)neti(t+1)neti(t+1)yj(t)]=i=1dout[σ(neti(t+1))(yi(t+1)y^i(t+1))wi,j]

由於第 t 時間點的輸出 y(t) 的計算是由 net(t) 而來(請見 (2)),所以我們也利用 (6) 計算 netj(t)loss(t+1) 所得梯度(注意是 t 不是 t+1

(7)loss(t+1)netj(t)=loss(t+1)yj(t)yj(t)netj(t)=i=1dout[loss(t+1)neti(t+1)wi,jσ(netj(t+1))]=σ(netj(t+1))i=1dout[wi,jloss(t+1)neti(t+1)]

式子 (7) 就是論文 3.1.1 節的最後一條公式。

模型參數 wi,j 對於 loss(t+1) 所得梯度為

(8)loss(t+1)wi,j=k=1doutloss(t+1)netk(t+1)netk(t+1)wi,j=k=1doutloss(t+1)netk(t+1)[j=1dout+din(wk,jwi,j(y(t)x(t))j+wk,j(y(t)x(t))jwi,j)]=k=1doutloss(t+1)netk(t+1)[(y(t)x(t))j+j=1doutwk,jσ(netj(t))netj(t)wi,j]

而在時間點 t+1 進行參數更新的方法為

(9)wi,jwi,jαloss(t+1)wi,j

(9) 就是最常用來最佳化神經網路的梯度下降演算法(Gradient Descent),α 代表學習率(Learning Rate)。

梯度爆炸 / 消失

(7) 式我們可以進一步推得 t 時間點造成的梯度與前次時間點 (t1,t2,) 所得的梯度變化關係。 注意這裡的變化關係指的是梯度與梯度之間的變化率,意即用時間點 t1 的梯度對時間點 t 的梯度算微分。

為了方便計算,我們定義新的符號

(10)ϑktfuture[tpast]=loss(tfuture)netk(tpast)

意思是在過去時間點 tpast 的第 k模型內部節點 netk(tpast) 對於未來時間點 tfuture 貢獻的總誤差 loss(tfuture) 計算所得之梯度

  • 注意是貢獻總誤差所得之梯度
  • 根據時間的限制我們有不等式 0tpasttfuture
  • 節點 k 的數值範圍為 k{1,,dout},見式子 (1)

因此

(11)ϑk0t[t]=loss(t)netk0(t);ϑk1t[t1]=loss(t)netk1(t1)=σ(netk1(t1))(k0=1doutwk0,k1ϑk0t[t]);ϑk2t[t2]=loss(t)netk2(t2)=k1=1dout[loss(t)netk1(t1)netk1(t1)yk2(t2)yk2(t2)netk2(t2)]=k1=1dout[ϑk1t[t1]wk1,k2σ(netk2(t2))]=k1=1dout[σ(netk1(t1))(k0=1doutwk0,k1ϑk0t[t])wk1,k2σ(netk2(t2))]=k1=1doutk0=1dout[wk0,k1wk1,k2σ(netk1(t1))σ(netk2(t2))ϑk0t[t]];ϑk3t[t3]=k2=1dout[loss(t)netk2(t2)netk2(t2)yk3(t3)yk3(t3)netk3(t3)]=k2=1dout[ϑk2t[t2]wk2,k3σ(netk3(t3))]=k2=1dout[k1=1doutk0=1dout[wk0,k1wk1,k2σ(netk1(t1))σ(netk2(t2))ϑk0t[t]]wk2,k3σ(netk3(t3))]=k2=1doutk1=1doutk0=1dout[wk0,k1wk1,k2wk2,k3σ(netk1(t1))σ(netk2(t2))σ(netk3(t3))ϑk0t[t]]=k2=1doutk1=1doutk0=1dout[[q=13wkq1,kqσ(netkq(tq))]ϑk0t[t]]

(11) 我們可以歸納得出 n1 時的公式

(12)ϑknt[tn]=kn1=1doutk0=1dout[[q=1nwkq1,kqσ(netkq(tq))]ϑk0t[t]]

(12) 我們可以看出 ϑknt[tn] 都與 ϑk0t[t] 相關,因此我們將 ϑknt[tn] 想成由 ϑk0t[t] 構成的函數。

現在讓我們固定 k0{1,,dout},我們可以計算 ϑk0t[t] 對於 ϑknt[tn] 的微分,分析梯度在進行反向傳遞過程中的變化率

  • n=1 時,根據 (11) 我們可以推得論文中的 (3.1) 式

    (13)ϑknt[tn]ϑk0t[t]=wk0,k1σ(netk1(t1))
  • n>1 時,根據 (12) 我們可以推得論文中的 (3.2) 式

    (14)ϑknt[tn]ϑk0t[t]=kn1=1doutk1=1doutk0{k0}[q=1nwkq1,kqσ(netkq(tq))]

注意錯誤:論文中的 (3.2) 式不小心把 wlm1lm 寫成 wlmlm1

因此根據 (14),共有 (dout)n1 個連乘積項次進行加總。

根據 (13)(14),如果

(15)|wkq1,kqσ(netkq(tq))|>1.0q=1,,n

梯度變化率成指數 n 增長,直接導致梯度爆炸,參數會進行劇烈的振盪,無法進行順利更新。

而如果

(16)|wkq1,kqσ(netkq(tq))|<1.0q=1,,n

梯度變化率成指數 n 縮小,直接導致梯度消失,誤差收斂速度會變得非常緩慢

(17) 我們知道 σ 最大值為 0.25

(17)σ(x)=11+exσ(x)=ex(1+ex)2=11+exex1+ex=11+ex1+ex11+ex=σ(x)(1σ(x))σ(R)=(0,1)maxxRσ(x)=σ(0)×(1σ(0))=0.5×0.5=0.25

因此當 |wkq1,kq|<4.0 時我們可以發現

(18)|wkq1,kqσ(netkq(tq))|<4.00.25=1.0

所以 (18)(16) 的結論相輔相成:當 wkq1,kq 的絕對值小於 4.0 會造成梯度消失

|wkq1,kq| 我們可以使用 (17) 得到

(19)|netkq1(tq+1)|{σ(netkq1(tq+1))1if netkq1(tq+1)σ(netkq1(tq+1))0if netkq1(tq+1)|σ(netkq1(tq+1))|0|q=1nwkq1,kqσ(netkq(tq))|=|wk0,k1q=2n[σ(netkq1(tq+1))wkq1,kq]σ(netkn(tn))|0

最後一個推論的原理是指數函數的收斂速度比線性函數快

注意錯誤:論文中的推論

|wkq1,kqfkq(netkq(tq))|0

錯誤的,理由是 wkq1,kq 無法對 netkq(tq) 造成影響,作者不小心把時間順序寫反了,但是最後的邏輯仍然正確,理由如 (19) 所示。

注意錯誤:論文中進行了以下函數最大值的推論

flm(netlm(tm)))wlmlm1=σ(netlm(tm))(1σ(netlm(tm)))wlmlml

最大值發生於微分值為 0 的點,即我們想求出滿足以下式子的 wlmlm1

[σ(netlm(tm))(1σ(netlm(tm)))wlmlml]wlmlm1=0

拆解微分式可得

[σ(netlm(tm))(1σ(netlm(tm)))wlmlml]wlmlm1=σ(netlm(tm))netlm(tm)netlm(tm)wlmlm1(1σ(netlm(tm)))wlmlml+σ(netlm(tm))(1σ(netlm(tm)))netlm(tm)netlm(tm)wlmlm1wlmlml+σ(netlm(tm))(1σ(netlm(tm)))wlmlm1wlmlm1=σ(netlm(tm))(1σ(netlm(tm)))2ylm1(tm1)wlmlm1(σ(netlm(tm)))2(1σ(netlm(tm)))ylm1(tm1)wlmlm1+σ(netlm(tm))(1σ(netlm(tm)))=[2(σ(netlm(tm)))33(σ(netlm(tm)))2+σ(netlm(tm))]ylm1(tm1)wlmlm1+σ(netlm(tm))(1σ(netlm(tm)))=σ(netlm(tm))(2σ(netlm(tm))1)(σ(netlm(tm))1)ylm1(tm1)wlmlm1+σ(netlm(tm))(1σ(netlm(tm)))=0

移項後可以得到

σ(netlm(tm))(2σ(netlm(tm))1)(1σ(netlm(tm)))ylm1(tm1)wlmlm1=σ(netlm(tm))(1σ(netlm(tm)))(2σ(netlm(tm))1)ylm1(tm1)wlmlm1=1wlmlm1=1ylm1(tm1)12σ(netlm(tm))1wlmlm1=1ylm1(tm1)coth(netlm(tm)2)

註:推論中使用了以下公式

tanh(x)=2σ(2x)1tanh(x2)=2σ(x)1coth(x2)=1tanh(x2)=12σ(x)1

但公式的前提不對,理由是 wlmlm1 根本不存在,應該改為 wlm1lm(同 (14))。

接著我們可以計算 t 時間點 dout不同節點 netk0(t) 對於同一個 tn 時間點的 netkn(tn) 節點所貢獻的梯度變化總和

(20)k0=1doutϑknt[tn]ϑk0t[t]

由於每個項次都能遭遇梯度消失,因此總和也會遭遇梯度消失

問題觀察

情境 1:模型輸出與內部節點 1-1 對應

假設模型沒有任何輸入,啟發函數 fj 為未知且 t1 時間點的輸出節點 yj(t1) 只與 netj(t) 相連,即

(21)netj(t)=wj,jyj(t1)

則根據式子 (11) 我們可以推得

(22)ϑjt[t1]=wj,jfj(netj(t1))ϑjt[t]

為了不讓梯度 ϑjt[t] 在傳遞的過程消失,作者認為需要強制達成梯度常數(Constant Error Flow)

(23)wj,jfj(netj(t1))=1.0

透過 (23) 的想法讓 (12) 中梯度變化率的連乘積項1.0,因此

  • 不會像 (15) 導致梯度爆炸
  • 不會像 (16) 導致梯度消失

如果 (23) 能夠達成,則積分 (23) 可以得到

(24)wj,jfj(netj(t1))d[netj(t1)]=1.0d[netj(t1)]wj,jfj(netj(t1))=netj(t1)yj(t1)=fj(netj(t1))=netj(t1)wj,j

觀察 (24) 我們可以發現

  • 輸入 netj(t1) 與輸出 fj(netj(t1)) 之間的關係是乘上一個常數項 wj,j
  • 代表函數 fj 其實是一個線性函數

若採用 (24) 的架構設計,我們可以發現每個時間點輸出必須完全相同

yj(t)=fj(netj(t))=fj(wj,jyj(t1))(25)=fj(wj,jnetj(t1)wj,j)=fj(netj(t1))=yj(t1)

這個現象稱為 Constant Error Carousel(簡稱 CEC),而作者設計的 LSTM 架構會完全基於 CEC 進行設計,但我覺得概念比較像 ResNet 的 residual connection。

情境 2:增加外部輸入

(21) 的假設改成每個模型內部節點可以額外接收外部輸入

(26)netj(t)=wj,jyj(t1)+i=1dinwj,ixi(t1)

由於 yj(t1) 的設計功能是保留過去計算所擁有的資訊,在 (26) 的假設中唯一能夠更新資訊的方法只有透過 xi(t1) 配合 wj,i 將新資訊合併進入 netj(t)

但作者認為,在計算的過程中,部份時間點的輸入資訊 xi() 可能是雜訊,因此可以(甚至必須)被忽略。 但這代表與外部輸入相接的參數 wj,i 需要同時達成兩種任務:

  • 加入新資訊:代表 |wj,i|0
  • 忽略新資訊:代表 |wj,i|0

因此無法只靠一個 wj,i 決定輸入的影響,必須有額外能夠理解當前內容 (context-sensitive) 的功能模組幫忙決定是否寫入 xi()

情境 3:輸出回饋到多個節點

(21)(26) 的假設改回正常的模型架構

(27)netj(t)=i=1doutwj,iyi(t1)+i=1dinwj,dout+ixi(t1)

由於 yj(t1) 的設計功能是保留過去計算所擁有的資訊,在 (27) 的假設中唯一能夠讓過去資訊影響未來計算結果的方法只有透過 yi(t1) 配合 wj,din+i 將新資訊合併進入 netj(t)

但作者認為,在計算的過程中,部份時間點的輸出資訊 yi() 可能對預測沒有幫助,因此可以(甚至必須)被忽略。 但這代表與輸出相接的參數 wj,din+i 需要同時達成兩種任務:

  • 保留過去資訊:代表 |wj,din+i|0
  • 忽略過去資訊:代表 |wj,din+i|0

因此無法只靠一個 wj,din+i 決定輸出的影響,必須有額外能夠理解當前內容 (context-sensitive) 的功能模組幫忙決定是否讀取 yi()

LSTM 架構

圖 1:記憶單元內部架構。 符號對應請見下個小節。 圖片來源:論文

圖 1

圖 2:LSTM 全連接架構範例。 線條真的多到讓人看不懂,看我整理過的公式比較好理解。 圖片來源:論文

圖 2

為了解決梯度爆炸 / 消失問題,作者決定以 Constant Error Carousel 為出發點(見 (25)),提出 3 個主要的機制,並將這些機制的合體稱為記憶單元區塊(Memory Cell Blocks)(見圖 1):

  • 乘法輸入閘門(Multiplicative Input Gate):用於決定是否更新記憶單元的內部狀態
  • 乘法輸出閘門(Multiplicative Output Gate):用於決定是否輸出記憶單元的計算結果
  • 自連接線性單元(Central Linear Unit with Fixed Self-connection):概念來自於 CEC(見 (25)),藉此保障梯度不會消失

初始狀態

我們將 (1) 中的計算重新定義,並新增幾個符號:

符號 意義 數值範圍
dhid 隱藏單元的個數 N
dblock 每個記憶單元區塊中記憶單元的個數 Z+
nblock 記憶單元區塊的個數 Z+
  • 因為論文 4.3 節有提到可以完全沒有隱藏單元,因此允許 dhid=0
    • 此論文的後續研究似乎都沒有使用隱藏單元
    • 例如更新 LSTM 架構的主要研究 LSTM-2000LSTM-2002 都沒有使用隱藏單元
  • 根據論文 4.4 節,可以同時擁有 nblock 個不同的記憶單元區塊,因此允許 nblock1

接著我們定義 t 時間點的模型計算狀態:

符號 意義 數值範圍
yhid(t) 隱藏單元(Hidden Units) Rdhid
yig(t) 輸入閘門單元(Input Gate Units) Rnblock
yog(t) 輸出閘門單元(Output Gate Units) Rnblock
yblockk(t) 記憶單元區塊 k輸出 Rdblock
sblockk(t) 記憶單元區塊 k內部狀態 Rdblock
y(t) 模型總輸出 Rdout
  • 以上所有向量全部都初始化成各自維度的零向量,也就是 t=0 時模型所有節點(除了輸入)都是 0
  • 根據論文 4.4 節,可以同時擁有 nblock 個不同的記憶單元
    • 圖 2 模型共有 2 個不同的記憶單元
    • 記憶單元區塊上標 k 的數值範圍為 k{1,,nblock}
  • 同一個記憶單元區塊共享閘門單元,因此 yig(t),yog(t) 的維度為 nblock
  • 根據論文 4.3 節,記憶單元閘門單元隱藏單元都算是隱藏層(Hidden Layer)的一部份
    • 外部輸入會與隱藏層總輸出連接
    • 隱藏層會與總輸出連接(但閘門不會)

All units (except for gate units) in all layers have directed connections (serve as input) to all units in the layer above (or to all higher layers; see experiments 2a and 2b)

計算定義

當我們得到 t 時間點的外部輸入 x(t) 時,我們可以進行以下計算得到 t+1 時間點的總輸出 y(t+1)

(28)D=din+dhid+nblock(2+dblock)(29)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))RD(30)k{1,,nblock}(31)yhid(t+1)=fhid(nethid(t+1))=fhid(whidx~(t))(32)yig(t+1)=fig(netig(t+1))=fig(wigx~(t))(33)yog(t+1)=fog(netog(t+1))=fog(wogx~(t))(34)sblockk(t+1)=sblockk(t)+ykig(t+1)g(netblockk(t+1))=sblockk(t)+ykig(t+1)g(wblockkx~(t))(35)yblockk(t+1)=ykog(t+1)h(sblockk(t+1))(36)y(t+1)=fout(netout(t+1))=fout(wout(x(t)yhid(t+1)yblock1(t+1)yblocknblock(t+1)))

以上就是 LSTM(1997 版本)的計算流程。

  • fhid,fig,fog,fout,g,h 都是 differentiable element-wise activation function,大部份都是 sigmoid 或是 sigmoid 的變形
  • fig,fog 的數值範圍(range)必須限制在 [0,1],才能達成閘門的功能
  • fout 的數值範圍只跟任務有關
  • 論文並沒有給 fhid,g,h 任何數值範圍的限制

論文 4.3 節有提到可以完全沒有隱藏單元,而後續的研究(例如 LSTM-2000LSTM-2002)也完全沒有使用隱藏單元,因此 (31) 可以完全不存在。

  • (29) 中的 yhid(t) 必須去除
  • (36) 中的 yhid(t+1) 必須去除
  • 隱藏單元的設計等同於保留 (1)(2) 的架構,是個不好的設計,因此論文後續在最佳化的過程中動了手腳

根據 (32)(34),在計算完 t+1 時間點的輸入閘門 yig(t+1) 後便可以更新 t+1 時間點的記憶單元內部狀態 sblockk(t+1)

  • 記憶單元淨輸入會與輸入閘門進行相乘,因此稱為乘法輸入閘門
  • 由於 t+1 時間點的資訊有加上 t 時間點的資訊,因此稱為自連接線性單元
  • 同一個記憶單元區塊會共享同一個輸入閘門,因此 (34) 中的乘法是純量乘上向量,這也是 yig(t+1)Rnblock 的理由
  • 當模型認為輸入訊號不重要時,模型應該要關閉輸入閘門,即 ykig(t+1)0
    • 丟棄當前輸入訊號,只以過去資訊進行決策
    • 在此狀態下 t+1 時間點的記憶單元內部狀態t 時間點完全相同,達成 (23)(25),藉此保障梯度不會消失
  • 當模型認為輸入訊號重要時,模型應該要開啟輸入閘門,即 ykig(t+1)1
  • 不論輸入訊號 g(netblockk(t+1)) 的大小,只要 ykig(t+1)0,則輸入訊號完全無法影響接下來的所有計算,LSTM 以此設計避免 (26) 所遇到的困境

根據 (33)(35),在計算完 t+1 時間點的輸出閘門 yog(t+1)記憶單元內部狀態 sblockk(t+1) 後便可以得到 t+1 時間點的記憶單元輸出 yblockk(t+1)

  • 記憶單元啟發值會與輸出閘門進行相乘,因此稱為乘法輸出閘門
  • 同一個記憶單元區塊會共享同一個輸出閘門,因此 (35) 中的乘法是純量乘上向量,這也是 yog(t+1)Rnblock 的理由
  • 當模型認為輸出訊號會導致當前計算錯誤時,模型應該關閉輸出閘門,即 ykog(t+1)0
    • 輸入閘門開啟的狀況下,關閉輸出閘門代表不讓現在時間點的資訊影響當前計算
    • 輸入閘門關閉的狀況下,關閉輸出閘門代表不讓過去時間點的資訊影響當前計算
  • 當模型認為輸出訊號包含重要資訊時,模型應該要開啟輸出閘門,即 ykog(t+1)1
    • 輸入閘門開啟的狀況下,開啟輸出閘門代表讓現在時間點的資訊影響當前計算
    • 輸入閘門關閉的狀況下,開啟輸出閘門代表不讓過去時間點的資訊影響當前計算
  • 不論輸出訊號 h(sblockk(t+1)) 的大小,只要 ykog(t+1)0,則輸出訊號完全無法影響接下來的所有計算,LSTM 以此設計避免 (26)(27) 所遇到的困境
  • PyTorch 實作的 LSTMh(t) 表達的意思是記憶單元輸出 yblockk(t)

根據 (36),得到 t+1 時間點的記憶單元輸出 yblockk(t+1) 後就可以計算 t+1 時間點的模型總輸出 y(t+1)

  • 注意在計算 (36) 時並沒有使用閘門單元,與 (29) 的計算不同
  • 注意 y(t+1)yog 不同
    • y(t+1)總輸出,我的 y(t+1) 是論文中的 yk(t+1)
    • yog(t+1)記憶單元輸出閘門,我的 yog(t+1) 是論文中的 youti(t+1)

根據論文 A.7 式下方的描述,t+1 時間點的總輸出只與 t 時間點的模型狀態不含閘門與總輸出)有關係,所以 (31)(32)(33)(35) 的計算都只是在幫助 t+2 時間點的計算狀態鋪陳

我不確定這是否為作者的筆誤,畢竟附錄中所有分析的數學式都寫的蠻正確的,我認為這裡是筆誤的理由如下:

  • 同個實驗室後續的研究(例如 LSTM-2002)寫的式子不同
  • 至少要傳播兩個時間點才能得到輸出,代表第 1 個時間點的輸出完全無法利用到記憶單元的知識
  • 後續的實驗架構設計中沒有將外部輸入連接到輸出,代表第 1 個時間點的輸出完全依賴模型的初始狀態(常數),非常不合理

因此我決定改用我認為是正確的版本撰寫後續的筆記,即 t+1 時間點的總輸出t 時間點的外部輸入t+1 時間點的計算狀態有關。

注意 (32)(33) 沒有使用偏差項(bias term),但後續的分析會提到可以使用偏差項進行計算缺陷的修正。

參數結構

參數 意義 輸出維度 輸入維度
whid 產生隱藏單元的全連接參數 dhid din+dhid+nblock(2+dblock)
wig 產生輸入閘門的全連接參數 nblock din+dhid+nblock(2+dblock)
wog 產生輸出閘門的全連接參數 nblock din+dhid+nblock(2+dblock)
wblockk 產生第 k記憶單元區塊淨輸入的全連接參數 dblock din+dhid+nblock(2+dblock)
wout 產生輸出的全連接參數 dblock din+dhid+nblockdblock

丟棄部份模型單元的梯度

過去的論文中提出以修改最佳化過程避免 RNN 訓練遇到梯度爆炸 / 消失的問題(例如 Truncated BPTT)。

論文 4.5 節提到最佳化 LSTM 的方法為 RTRL 的變種,主要精神如下:

  • 最佳化的核心思想是確保能夠達成 CEC (見 (25)
  • 使用的手段是要求所有梯度反向傳播的過程在經過記憶單元區塊隱藏單元後便停止傳播
  • 停止傳播導致在完成 t+1 時間點的 forward pass 後梯度可以馬上計算完成(real time 的精神便是來自於此)

首先我們定義新的符號 tr,代表計算梯度的過程會有部份梯度故意被丟棄(設定為 0),並以丟棄結果近似真正的全微分

(37)netia(t+1)yjb(t)tr0where a,b{hid,ig,og,block1,,blocknblock}

所有與隱藏單元淨輸入 netihid(t+1)輸入閘門淨輸入 netiig(t+1)輸出閘門淨輸入 netiog(t+1)記憶單元淨輸入 netiblockk(t+1) 直接相連t 時間點的單元,一律丟棄梯度

  • 注意論文在 A.1.2 節的開頭只提到輸入閘門輸出閘門記憶單元丟棄梯度
  • 但論文在 A.9 式描述可以將隱藏單元的梯度一起丟棄,害我白白推敲公式好幾天

Here it would be possible to use the full gradient without affecting constant error flow through internal states of memory cells.

根據 (37) 我們可以進一步推得

(38)a{hid,ig,og}b{hid,ig,og,block1,,blocknblock}yia(t+1)yjb(t)=yia(t+1)netia(t+1)netia(t+1)yjb(t)0tr0k{1,2,,nblock}yiblockk(t+1)yjb(t)=yiblockk(t+1)ykig(t+1)ykig(t+1)yjb(t)0+yiblockk(t+1)netiblockk(t+1)netiblockk(t+1)yjb(t)0+yiblockk(t+1)ykog(t+1)ykog(t+1)yjb(t)0tr0

由於 yig(t+1),yog(t+1),netblockk(t+1) 並不是直接透過 whid 產生,因此 whid 只能透過參與 t 時間點以前的計算間接t+1 時間點的計算造成影響(見 (31)),這也代表在 (38) 作用的情況下 whid 無法yig(t+1),yog(t+1),netblockk(t+1) 收到任何的梯度

(39)a{ig,og,block1,,blocknblock}b{hid,ig,og,block1,,blocknblock}yia(t+1)wp,qhid=j=din+1din+dhid+nblock(2+dblock)[yia(t+1)yjb(t)0yjb(t)wp,qhid]tr0

相對於總輸出所得剩餘梯度

我們將論文的 A.8 式拆解成 (41)(42)(43)(44)

總輸出參數

δa,bKronecker delta,i.e.,

(40)δa,b={1if a=b0otherwise

由於總輸出 y(t+1) 不會像是 (1)(2) 的方式回饋到模型的計算狀態中,因此總輸出參數 wout總輸出 y(t+1) 計算所得的梯度

(41)i,p{1,,dout}q{1,,din+dhid+nblockdblock}yi(t+1)wp,qout=yi(t+1)netiout(t+1)netiout(t+1)wp,qout=fiout(netiout(t+1))δi,p(x(t)yhid(t+1)yblock1(t+1)yblocknblock(t+1))q
  • (41) 就是論文中 A.8 式的第一個 case
  • 由於 p 可以是任意的輸出節點,因此在 ipwp,qout 對於 yi(t+1) 的梯度為 0

隱藏單元參數

(37)(38)(39) 的作用下,我們可以求得隱藏單元參數 whid丟棄部份梯度後對於總輸出 y(t+1) 計算所得的剩餘梯度

(42)D=din+dhid+nblockdblockx~(t+1)=(x(t)yhid(t+1)yblock1(t+1)yblocknblock(t+1))RDi{1,,dout}p{1,,dhid}q{1,,D}yi(t+1)wp,qhid=yi(t+1)netiout(t+1)netiout(t+1)wp,qhid=fiout(netiout(t+1))j=1D[netiout(t+1)x~j(t+1)x~j(t+1)wp,qhidtr]trfiout(netiout(t+1))wi,poutyphid(t+1)wp,qhid

(42) 就是論文中 A.8 式的最後一個 case。

閘門單元參數

(42),我們可以計算閘門單元參數 wig,wog總輸出 y(t+1) 計算所得的剩餘梯度

(43)D=din+dhid+nblockdblockx~(t+1)=(x(t)yhid(t+1)yblock1(t+1)yblocknblock(t+1))RDi{1,,dout}k{1,,nblock}q{1,,din+dhid+nblock(2+dblock)}yi(t+1)wk,qog=yi(t+1)netiout(t+1)netiout(t+1)wk,qog=fiout(netiout(t+1))j=1D[netiout(t+1)x~j(t+1)x~j(t+1)wk,qogtr]trfiout(netiout(t+1))j=1dblock[wi,din+dhid+(k1)dblock+joutyjblockk(t+1)wk,qog]yi(t+1)wk,qigtrfiout(netiout(t+1))j=1dblock[wi,din+dhid+(k1)dblock+joutyjblockk(t+1)wk,qig]

(43) 就是論文中 A.8 式的第三個 case。

記憶單元淨輸入參數

記憶單元淨輸入參數 wblockk總輸出 y(t+1) 計算所得的剩餘梯度(43) 幾乎相同

(44)D=din+dhid+nblockdblockx~(t+1)=(x(t)yhid(t+1)yblock1(t+1)yblocknblock(t+1))RDi{1,,dout}k{1,,nblock}p{1,,dblock}q{1,,din+dhid+nblock(2+dblock)}yi(t+1)wp,qblockk=yi(t+1)netiout(t+1)netiout(t+1)wp,qblockk=fiout(netiout(t+1))j=1D[netiout(t+1)x~j(t+1)x~j(t+1)wp,qblockktr]trfiout(netiout(t+1))wi,din+dhid+(k1)dblock+poutypblockk(t+1)wp,qblockk

(44) 就是論文中 A.8 式的第二個 case。

相對於隱藏單元所得剩餘梯度

我們將論文的 A.9 式拆解成 (45)(46)(47)

隱藏單元參數

根據 (37)(38) 我們可以得到隱藏單元參數 whid 對於隱藏單元 yhid(t+1) 計算所得剩餘梯度

(45)i,p{1,,dhid}q{1,,din+dhid+nblock(2+dblock)}yihid(t+1)wp,qhid=yihid(t+1)netihid(t+1)netihid(t+1)wp,qhidtrtrfihid(netihid(t+1))δi,p(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))q

閘門單元參數

由於隱藏單元 yhid(t+1) 並不是直接透過閘門參數 wig,wog 產生,因此根據 (37) 我們可以推得 wig,wog 對於 yhid(t+1) 剩餘梯度0

(46)D=din+dhid+nblock(2+dblock)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))RDi{1,,dhid}p{1,,nblock}q{1,,D}yihid(t+1)wp,qog=yihid(t+1)netihid(t+1)j=1D[netihid(t+1)x~j(t)0x~j(t)wp,qog]tr0yihid(t+1)wp,qigtr0

記憶單元淨輸入參數

(46),由於隱藏單元 yhid(t+1) 並不是直接透過記憶單元淨輸入參數 wblockk 產生,因此根據 (37) 我們可以推得 wblockk 對於 yhid(t+1) 剩餘梯度0

(47)D=din+dhid+nblock(2+dblock)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))RDi{1,,dhid}k{1,,nblock}p{1,,dblock}q{1,,D}yihid(t+1)wp,qblockk=yihid(t+1)netihid(t+1)j=1D[netihid(t+1)x~j(t)0x~j(t)wp,qblockk]tr0

相對於記憶單元輸出所得剩餘梯度

我們將論文的 A.13 式拆解成 (48)(49)(50)

閘門單元參數

根據 (37) 我們可以推得閘門單元參數 wig,wog 對於記憶單元輸出 yblockk(t+1) 計算所得剩餘梯度

i{1,,dblock}k,p{1,,nblock}q{1,,din+dhid+nblock(2+dblock)}yiblockk(t+1)wp,qog=yiblockk(t+1)ykog(t+1)ykog(t+1)wp,qog+yiblockk(t+1)siblockk(t+1)siblockk(t+1)wp,qog0(48)trhi(siblockk(t+1))δk,pykog(t+1)wk,qogyiblockk(t+1)wp,qig=yiblockk(t+1)ykog(t+1)ykog(t+1)wp,qig0+yiblockk(t+1)siblockk(t+1)siblockk(t+1)wp,qig(49)trykog(t+1)hi(siblockk(t+1))δk,psiblockk(t+1)wk,qig

記憶單元淨輸入參數

(49),使用 (37) 推得記憶單元淨輸入參數 wblockk 對於記憶單元輸出 yblockk(t+1) 計算所得剩餘梯度(注意 k 可以不等於 k

(50)i,p{1,,dblock}k,k{1,,nblock}q{1,,din+dhid+nblock(2+dblock)}yiblockk(t+1)wp,qblockk=yiblockk(t+1)ykog(t+1)ykog(t+1)wp,qblockk0+yiblockk(t+1)siblockk(t+1)siblockk(t+1)wp,qblockktrykog(t+1)hi(siblockk(t+1))δk,kδi,psiblockk(t+1)wi,qblockk

注意錯誤:論文 A.13 式最後使用加法 δinjl+δcjvl,可能會導致梯度乘上常數 2,因此應該修正成乘法 δinjlδcjvl

相對於閘門單元所得剩餘梯度

我們將論文的 A.10, A.11 式拆解成 (51)(52)

閘門單元參數

根據 (37)(38) 我們可以得到閘門單元參數 wig,wog 對於閘門單元 yig(t+1),yog(t+1) 計算所得剩餘梯度

(51)D=din+dhid+nblock(2+dblock)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock)RDk,p{1,,nblock}q{1,,D}ykig(t+1)[wig;wog]p,q=ykig(t+1)netkig(t+1)netkig(t+1)[wig;wog]p,qtrtrfkig(netkig(t+1))δk,px~q(t)ykog(t+1)[wig;wog]p,qtrδk,pfkog(netkog(t+1))x~q(t)

記憶單元淨輸入參數

由於閘門單元 yig(t+1),yog(t+1) 並不是直接透過記憶單元淨輸入參數 wblockk 產生,因此根據 (37) 我們可以推得 wblockk 對於 yig(t+1),yog(t+1) 剩餘梯度0

(52)D=din+dhid+nblock(2+dblock)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock)RDk{1,,nblock}p{1,,dblock}q{1,,D}ykig(t+1)wp,qblockk=ykig(t+1)netkig(t+1)j=1D[netkig(t+1)x~j(t)0x~j(t)wp,qblockk]tr0ykog(t+1)wp,qblockktr0

相對於記憶單元內部狀態所得剩餘梯度

我們將論文的 A.12 式拆解成 (53)(54)(55)

閘門單元參數

(37) 結合 (51) 我們可以推得閘門單元參數 wig,wog 對於記憶單元內部狀態 sblockk(t+1) 計算所得剩餘梯度

D=din+dhid+nblock(2+dblock)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))RDi{1,,dblock}k,p{1,,nblock}q{1,,D}siblockk(t+1)wp,qog=siblockk(t+1)siblockk(t)siblockk(t)wp,qog0+siblockk(t+1)ykig(t+1)ykig(t+1)wp,qog0+siblockk(t+1)netiblockk(t+1)netiblockk(t+1)wp,qog0(53)tr0siblockk(t+1)wp,qig=siblockk(t+1)siblockk(t)siblockk(t)wp,qig+siblockk(t+1)ykig(t+1)ykig(t+1)wp,qig+siblockk(t+1)netiblockk(t+1)netiblockk(t+1)wp,qig0tr1δk,psiblockk(t)wk,qig+gi(netiblockk(t+1))δk,pykig(t+1)wk,qigtr(54)trδk,p[siblockk(t)wk,qig+gi(netiblockk(t+1))fkig(netkig(t+1))x~q(t)]

記憶單元淨輸入參數

使用 (37) 推得記憶單元淨輸入參數 wblockk 對於記憶單元內部狀態 sblockk(t+1) 計算所得剩餘梯度(注意 k 可以不等於 k

(55)D=din+dhid+nblock(2+dblock)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))RDi,p{1,,dblock}k,k{1,,nblock}q{1,,D}siblockk(t+1)wp,qblockk=siblockk(t+1)siblockk(t)siblockk(t)wp,qblockk+siblockk(t+1)ykig(t+1)ykig(t+1)wp,qblockk0+siblockk(t+1)netiblockk(t+1)netiblockk(t+1)wp,qblockktrδk,kδi,p1siblockk(t)wi,qblockk+δk,kδi,pykig(t+1)gi(netiblockk(t+1))x~q(t)=δk,kδi,p[siblockk(t)wi,qblockk+ykig(t+1)gi(netiblockk(t+1))x~q(t)]

注意錯誤:論文 A.12 式最後使用加法 δinjl+δcjvl,可能會導致梯度乘上常數 2,因此應該修正成乘法 δinjlδcjvl

更新模型參數

總輸出參數

(4) 我們可以觀察出以下結論

(56)x~(t+1)=(x(t)yhid(t+1)yblock1(t+1)yblocknblock(t+1))i{1,,dout}j{1,,din+dhid+nblockdblock}loss(t+1)wi,jout=loss(t+1)lossi(t+1)lossi(t+1)yi(t+1)yi(t+1)wi,jout=(yi(t+1)y^i(t+1))yi(t+1)wi,jout=(yi(t+1)y^i(t+1))fiout(netiout(t+1))x~j(t+1)

隱藏單元參數

(4)(39)(42)(45) 我們可以觀察出以下結論

(57)p{1,,dhid}q{1,,din+dhid+nblock(2+dblock)}loss(t+1)wp,qhid=i=1dout[loss(t+1)lossi(t+1)lossi(t+1)yi(t+1)yi(t+1)wp,qhid]tri=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))wi,poutyphid(t+1)wp,qhid]=i=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))wi,pout]yphid(t+1)wp,qhidtri=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))wi,pout]fphid(netphid(t+1))(x(t)yhid(t)yblock1(t)yblocknblock(t))j

輸出閘門單元參數

(4)(43)(48)(51)(53) 我們可以觀察出以下結論

(58)k{1,,nblock}q{1,,din+dhid+nblock(2+dblock)}loss(t+1)wk,qog=i=1dout[loss(t+1)lossi(t+1)lossi(t+1)yi(t+1)yi(t+1)wk,qog]tri=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))j=1dblock(wi,din+dhid+(k1)dblock+joutyjblockk(t+1)wk,qog)]tri=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))j=1dblock(wi,din+dhid+(k1)dblock+jouthj(sjblockk(t+1))ykog(t+1)wk,qog)]=[i=1dout(yi(t+1)y^i(t+1))fiout(netiout(t+1))(j=1dblockwi,din+dhid+(k1)dblock+jouthj(sjblockk(t+1)))]ykog(t+1)wk,qogtr[i=1dout(yi(t+1)y^i(t+1))fiout(netiout(t+1))(j=1dblockwi,din+dhid+(k1)dblock+jouthj(sjblockk(t+1)))]fkog(netkog(t+1))(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))q

輸入閘門單元參數

(4)(43)(49)(51)(54) 我們可以觀察出以下結論

(59)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))k{1,,nblock}q{1,,din+dhid+nblock(2+dblock)}loss(t+1)wk,qig=i=1dout[loss(t+1)lossi(t+1)lossi(t+1)yi(t+1)yi(t+1)wk,qig]tri=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))j=1dblock(wi,din+dhid+(k1)dblock+joutyjblockk(t+1)wk,qig)]tri=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))j=1dblock(wi,din+dhid+(k1)dblock+joutykog(t+1)hj(sjblockk(t+1))sjblockk(t+1)wk,qig)]=(i=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))j=1dblock(wi,din+dhid+(k1)dblock+jouthj(sjblockk(t+1))sjblockk(t+1)wk,qig)])ykog(t+1)tr(i=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))j=1dblock(wi,din+dhid+(k1)dblock+jouthj(sjblockk(t+1))[sjblockk(t)wk,qig+gj(netjblockk(t+1))fkig(netkig(t+1))x~q(t)])])ykog(t+1)

記憶單元淨輸入參數

(4)(44)(47)(50)(52)(55) 我們可以觀察出以下結論

(60)x~(t)=(x(t)yhid(t)yig(t)yog(t)yblock1(t)yblocknblock(t))k{1,,nblock}p{1,,dblock}q{1,,din+dhid+nblock(2+dblock)}loss(t+1)wp,qblockk=i=1dout[loss(t+1)lossi(t+1)lossi(t+1)yi(t+1)yi(t+1)wp,qblockk]tri=1dout[(yi(t+1)y^i(t+1))fiout(netiout(t+1))wi,din+dhid+(k1)dblock+poutypblockk(t+1)wp,qblockk]=[i=1dout(yi(t+1)y^i(t+1))fiout(netiout(t+1))wi,din+dhid+(k1)dblock+pout]ypblockk(t+1)wp,qblockktr[i=1dout(yi(t+1)y^i(t+1))fiout(netiout(t+1))wi,din+dhid+(k1)dblock+pout]ykog(t+1)hp(spblockk(t+1))spblockk(t+1)wp,qblockk]tr[i=1dout(yi(t+1)y^i(t+1))fiout(netiout(t+1))wi,din+dhid+(k1)dblock+pout]ykog(t+1)hp(spblockk(t+1))[spblockk(t)wp,qblockk+ykig(t+1)gp(netpblockk(t+1))x~q(t)]

架構分析

時間複雜度

假設 t+1 時間點的 forward pass 已經執行完成,則更新 t+1 時間點所有參數時間複雜度

(61)O(dim(whid)+dim(wog)+dim(wig)+nblockdim(wblock1)+dim(wout))
  • (61) 就是論文中的 A.27 式
  • t+1 時間點參數更新需要考慮 t 時間點的計算狀態,請見 (57)(58)(59)(60)
  • 沒有如同 (14)連乘積項,因此不會有梯度消失問題
  • 整個計算過程需要額外紀錄的梯度項次只有 (59)(60) 中的 sjblockk(t)wk,qig,spblockk(t)wp,qblockk
    • 紀錄讓 LSTM 可以隨著 forward pass 的過程即時更新
    • 不需要等到 T 時間點的計算結束,因此不是採用 BPTT 的演算法
    • 即時更新(意思是 t+1 時間點的 forward pass 完成後便可計算 t+1 時間點的誤差梯度)是 RTRL 的主要精神

總共會執行 T+1forward pass,因此更新所有參數所需的總時間複雜度

(62)O(T[dim(whid)+dim(wog)+dim(wig)+nblockdim(wblock1)+dim(wout)])

空間複雜度

我們也可以推得在 t+1 時間點更新所有參數所需的空間複雜度

(63)O(dim(whid)+dim(wog)+dim(wig)+nblockdim(wblock1)+dim(wout))

總共會執行 Tforward pass,但更新所需的總空間複雜度仍然同 (63)

  • 依照時間順序計算梯度,計算完 t+1 時間點的梯度時 t 的資訊便可丟棄
  • 這就是 RTRL 的最大優點

達成梯度常數

根據 (37)(38) 我們可以推得

(64)i{1,,dblock}k{1,,nblock}siblockk(t+1)siblockk(t)=siblockk(t)siblockk(t)+ykig(t+1)siblockk(t)0gi(netiblockk(t+1))+ykig(t+1)gi(netiblockk(t+1))siblockk(t)0tr1

由於丟棄部份梯度的作用,sblockk梯度是模型中唯一進行遞迴(跨過多個時間點)的計算節點。 透過丟棄部份梯度我們從 (64) 可以看出 LSTM 達成 (23) 所設想的情況。

內部狀態偏差行為

觀察 (54)(59),當 h 是 sigmoid 函數時,我們可以發現

  • 如果 sblockk(t+1) 是一個非常大正數,則 hj(sjblockk(t+1)) 會變得非常小
  • 如果 sblockk(t+1) 是一個非常小負數,則 hj(sjblockk(t+1)) 也會變得非常小
  • sblockk(t+1) 極正或極負的情況下,輸入閘門參數 wig梯度消失
  • 此現象稱為內部狀態偏差行為Internal State Drift
  • 同樣的現象也會發生在記憶單元淨輸入參數 wblock1,wblocknblock 身上,請見 (60)
  • 此分析就是論文的 A.39 式改寫而來

解決 Internal State Drift

作者提出可以在 netig 加上偏差項,並在訓練初期將偏差項弄成很小的負數,邏輯如下

(65)big0netig(1)0yig(1)0swblockk(1)=swblockk(0)+yig(1)g(netwblockk(1))=yig(1)g(netwblockk(1))0{swblockk(t+1)≪̸0swblockk(t+1)≫̸0t=0,,T1

根據 (65) 我們就不會得到 sblockk(t) 極正或極負的情況,也就不會出現 Internal State Drift。

雖然這種作法是種模型偏差Model Bias)而且會導致 yig()fkig(netkig()) 變小,但作者認為這些影響比起 Internal State Drift 一點都不重要。

輸出閘門初始化

論文 4.7 節表示,在訓練的初期模型有可能濫用記憶單元的初始值作為計算的常數項(細節請見 (41)),導致模型在訓練的過程中學會完全不紀錄資訊

因此可以將輸出閘門加上偏差項,並初始化成較小的負數(理由類似於 (65)),讓記憶單元在計算初期輸出值為 0,迫使模型只在需要時指派記憶單元進行記憶

如果有多個記憶單元,則可以給予不同的負數,讓模型能夠按照需要依大小順序取得記憶單元(愈大的負數愈容易被取得)。

輸出閘門的優點

在訓練的初期誤差通常比較,導致梯度跟著變,使得模型在訓練初期的參數劇烈振盪。

由於輸出閘門所使用的啟發函數 fog 是 sigmoid,數值範圍是 (0,1),我們可以發現 (59)(60) 的梯度乘積包含 yog,可以避免過大誤差造成的梯度變大

但這些說法並沒有辦法真的保證一定會實現,算是這篇論文說服力比較薄弱的點。

實驗

實驗設計

  • 要測試較長的時間差
    • 資料集不可以出現短時間差
  • 任務要夠難
    • 不可以只靠 random weight guessing 解決
    • 需要比較多的參數或是高計算精度 (sparse in weight space)

控制變因

  • 使用 Online Learning 進行最佳化
    • 意思就是 batch size 為 1
    • 不要被 Online 這個字誤導
  • 使用 sigmoid 作為啟發函數
    • 包含 fout,fhid,fig,fog
  • 資料隨機性
    • 資料生成為隨機
    • 訓練順序為隨機
  • 在每個時間點 t 的計算順序為
    1. 將外部輸入 x(t) 丟入模型
    2. 計算輸入閘門、輸出閘門、記憶單元、隱藏單元
    3. 計算總輸出
  • 訓練初期只使用一個記憶單元,即 nblock=1
    • 如果訓練中發現最佳化做的不好,開始增加記憶單元,即 nblock=nblock+1
    • 一旦記憶單元增加,輸入閘門與輸出閘門也需要跟著增加
    • 這個概念稱為 Sequential Network Construction
  • hblockkgblockk 函數如果沒有特別提及,就是使用 (66)(67) 的定義

hblockk:R[1,1] 函數的定義為

(66)hblockk(x)=21+exp(x)1=2σ(x)1

gblockk:R[2,2] 函數的定義為

(67)gblockk(x)=41+exp(x)2=4σ(x)2

實驗 1:Embedded Reber Grammar

圖 3:Reber Grammar。 一個簡單的有限狀態機,能夠生成的字母包含 BEPSTVX。 圖片來源:論文

圖 3

圖 4:Embedded Reber Grammar。 一個簡單的有限狀態機,包含兩個完全相同的 Reber Grammar,開頭跟結尾只能是 BT…TE 與 BP…PE。 圖片來源:論文

圖 4

任務定義

  • Embedded Reber Grammar 是實驗 RNN 短時間差(Short Time Lag)的基準測試資料集
    • 圖 3 只是 Reber Grammar,真正的生成資料是使用圖 4 的 Embedded Reber Grammar
    • Embedded Reber Grammar 時間差最短只有 9 個單位
    • 傳統 RNN 在此資料集上仍然表現不錯
    • 資料生成為隨機,任何一個分支都有 0.5 的機率被生成
  • 根據圖 3 的架構,生成的第一個字為 B,接著是 T 或 P
    • 因此前兩個字生成 BT 或 BP 的機率各為 0.5
    • 能夠生成的字母包含 BEPSTVX
    • 生成直到產生 E 結束,結尾一定是 SE 或 VE
    • 由於有限狀態機中有 Loop,因此 Reber Grammar 有可能產生任意長度的文字
  • 根據圖 4 的架構,生成的開頭為 BT 或 BP
    • 前兩個字生成 BT 或 BP 的機率各為 0.5
    • 如果生成 BT,則結尾一定要是 TE
    • 如果生成 BP,則結尾一定要是 PE
    • 因此 RNN 模型必須學會記住開頭的 T / P 與結尾搭配,判斷一個文字序列是否由 Embedded Reber Grammar 生成
  • 模型會在每個時間點 t 收到一個字元,並輸出下一個時間點 t+1 會收到的字元
    • 輸入與輸出都是 one-hot vector,維度為 7,每個維度各自代表 BEPSTVX 中的一個字元,取數值最大的維度作為預測結果
    • 模型必須根據 0,1,t1,t 時間點收到的字元預測 t+1 時間點輸出的字元
    • 概念就是 Language Model
  • 資料數
    • 訓練集:256 筆
    • 測試集:256 筆
    • 總共產生 3 組不同的訓練測試集
    • 每組資料集都跑 10 次實驗,每次實驗模型都隨機初始化
    • 總共執行 30 次實驗取平均
  • 評估方法
    • Accuracy

LSTM 架構

參數 數值(或範圍) 備註
din 7  
dhid 0 沒有隱藏單元
(nblock,dblock) {(3,2),(4,1)} 至少有 3 個記憶單元
dout 7  
dim(whid) 0 沒有隱藏單元
dim(wblockk) dblock×[din+nblock(2+dblock)] 全連接隱藏層
dim(wig) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wog) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wout) dout×[nblockdblock] 外部輸入沒有直接連接到總輸出
參數初始化範圍 [0.2,0.2]  
輸出閘門偏差項初始化範圍 {1,2,3,4} 由大到小依序初始化不同記憶單元對應輸出閘門偏差項
Learning rate {0.1,0.2,0.5}  
總參數量 {264,276}  

實驗結果

表格 1:Embedded Reber Grammar 實驗結果。 表格來源:論文

表 1

  • LSTM + 丟棄梯度 + RTRL 在不同的實驗架構中都能解決任務
    • RNN + RTRL 無法完成
    • Elman Net + ELM 無法完成
  • LSTM 收斂速度比其他模型都還要快
  • LSTM 使用的參數數量並沒有比其他的模型多太多
  • 驗證輸出閘門的有效性
    • 當 LSTM 模型記住第二個輸入是 T / P 之後,輸出閘門就會讓後續運算的啟發值接近 0,不讓記憶單元內部狀態影響模型學習簡單的 Reber Grammar
    • 如果沒有輸出閘門,則收斂速度會變慢

實驗 2a:無雜訊長時間差任務

任務定義

定義 p+1 種不同的字元,標記為 V={α,β,c1,c2,,cp1}

定義 2 種長度為 p+1 不同的序列 seq1,seq2,分別為

seq1=α,c1,c2,,cp2,cp1,αseq2=β,c1,c2,,cp2,cp1,β

seq{seq1,seq2},令 seqt 個時間點的字元為 seq(t)V

當給予模型 seq(t) 時,模型要能夠根據 seq(0),seq(1),seq(t1),seq(t) 預測 seq(t+1)

  • 模型需要記住 c1,,cp1 的順序
  • 模型也需要記住開頭的 seq(0)α 還是 β,並利用 seq(0) 的資訊預測 seq(p+1)
  • 根據 p 的大小這個任務可以是時間差或時間差
  • 訓練資料
    • 每次以各 0.5 的機率抽出 seq1,seq2 作為輸入
    • 總共執行 5000000 次抽樣與更新
  • 測試資料
    • 每次以各 0.5 的機率抽出 seq1,seq2 作為輸入
    • 每次錯誤率在 0.25 以下就是成功,反之失敗
    • 總共執行 10000 次成功與失敗的判斷

LSTM 架構

參數 數值(或範圍) 備註
din p+1  
dhid 0 沒有隱藏單元
dblock dout 總輸出就是記憶單元的輸出
nblock 1 當誤差停止下降時,增加記憶單元
dout p+1  
g g(x)=σ(x) Sigmoid 函數
h h(x)=x  
dim(whid) 0 沒有隱藏單元
dim(wblockk) dblock×[din+(1+nblock)dblock] 全連接隱藏層
dim(wig) nblock×[din+(1+nblock)dblock] 全連接隱藏層
dim(wog) 0 沒有輸出閘門
dim(wout) 0 總輸出就是記憶單元的輸出
參數初始化範圍 [0.2,0.2]  
Learning rate 1  
最大更新次數 5000000  

實驗結果

表格 2:無雜訊長時間差任務實驗結果。 表格來源:論文

表 2

  • p=4 時使用 RNN + RTRL 時部份實驗能夠預測序列
    • 序列很短時 RNN 還是有能力完成任務
  • p10 時使用 RNN + RTRL 時直接失敗
  • p=100 時只剩 LSTM 能夠完全完成任務
  • LSTM 收斂速度最快

實驗 2b:有雜訊長時間差任務

實驗設計和 LSTM 的架構與實驗 2a 完全相同,只是序列 seq1,seq2 中除了頭尾之外的字元可以替換成 V 中任意的文字,總長度維持 p+1

  • 此設計目的是為了確保實驗 2a 中的順序性無法被順利壓縮
  • 先創造訓練資料,測試使用與訓練完全相同的資料
  • 仍然只有 LSTM 能夠完全完成任務
  • LSTM 的誤差仍然很快就收斂
    • p=100 時只需要 5680 次更新就能完成任務
    • 代表 LSTM 能夠在有雜訊的情況下正常運作

實驗 2c:有雜訊超長時間差任務

任務定義

實驗設計和 LSTM 的架構與實驗 2a 概念相同,只是 V 增加了兩個字元 b,e,而序列長度可以不同。

生成一個序列的概念如下:

  1. 固定一個正整數 q,代表序列基本長度
  2. c1,,cp1 中隨機抽樣生成長度為 q 的序列 seq
  3. 在序列的開頭補上 bαbβ(機率各為 0.5),讓序列長度變成 q+2
  4. 接著以 0.9 的機率從 c1,,cp1 中挑一個字補在序列 seq 的尾巴,或是以 0.1 的機率補上 e
  5. 如果生成 e 就再補上 αβ(與開頭第二個字元相同)並結束
  6. 如果不是生成 e 則重複步驟 4

假設步驟 4 執行了 k+1 次,則序列長度為 2+q+(k+1)+1=q+k+4。 序列的最短長度為 q+4,長度的期望值為

4+k=0110(910)k(q+k)=4+q10[k=0(910)k]+110[k=0(910)kk]=4+q1010+110100=q+14

其中

[k=0nkxk]x[k=0nkxk]=(0x0+1x1+2x2+3x3++nxn)(0x1+1x2+2x3+3x4++nxn+1)=0x0+1x1+1x2+1x3++1xnnxn+1=[k=0nxk]nxn+1=1xn+11xnxn+1=1xn+1nxn+1+nxn+21x

因此

[k=0nkxk]x[k=0nkxk]=1xn+1nxn+1+nxn+21xk=0nkxk=1xn+1nxn+1+nxn+2(1x)2k=0kxk=1(1x)2 when 0x<1

利用二項式分佈的期望值公式我們可以推得 ciV 出現次數的期望值

k=0110(910)k[i=0q+k(q+ki)(1p1)i(11p1)q+ki]=k=0110(910)kq+kp1=q10(p1)[k=0(910)k]+110(p1)[k=0(910)kk]=qp1+10p1qp1 when q0

訓練誤差只考慮最後一個時間點 seq(2+q+k+2) 的預測結果,必須要跟第 seq(1) 個時間點的輸入相同(概念同實驗 2a)。

測試時會連續執行 10000 次的實驗,預測誤差必須要永遠小於 0.2。 會以 20 次的測試結果取平均。

LSTM 架構

參數 數值(或範圍) 備註
din p+4  
dhid 0 沒有隱藏單元
dblock 1  
nblock 2 作者認為其實只要一個記憶單元就夠了
dout 2 只考慮最後一個時間點的預測誤差,並且預測的可能結果只有 2 種(αβ
dim(whid) 0 沒有隱藏單元
dim(wblockk) dblock×[din+nblock(2+dblock)] 全連接隱藏層
dim(wig) nblock×[din+nblock(2+dblock)] 全連接隱藏層
dim(wog) nblock×[din+nblock(2+dblock)] 全連接隱藏層
dim(wout) dout×[nblockdblock] 外部輸入沒有直接連接到總輸出
參數初始化範圍 [0.2,0.2]  
Learning rate 0.01  

實驗結果

表格 3:有雜訊超長時間差任務實驗結果。 表格來源:論文

表 3

  • 其他方法沒有辦法完成任務,因此不列入表格比較
  • 輸入序列長度可到達 1000
  • 當輸入字元種類與輸入長度一起增加時,訓練時間只會緩慢增加
  • 當單一字元的出現次數期望值增加時,學習速度會下降
    • 作者認為是常見字詞的出現導致參數開始振盪

實驗 3a:Two-Sequence Problem

任務定義

給予一個實數序列 seq,該序列可能隸屬於兩種類別 C1,C2,隸屬機率分別是 0.5

如果 seqC1,則該序列的前 N 個數字都是 1.0,序列的最後一個數字為 1.0。 如果 seqC2,則該序列的前 N 個數字都是 1.0,序列的最後一個數字為 0.0

給定一個常數 T,並從 [T,T+T10] 的區間中隨機挑選一個整數作為序列 seq 的長度 L

LN 時,任何在 seq(N+1),seq(L1) 中的數字都是由常態分佈隨機產生,常態分佈的平均為 0 變異數為 0.2

  • 此任務由 Bengio 提出
  • 作者發現只要用隨機權重猜測(Random Weight Guessing)就能解決,因此在實驗 3c 提出任務的改進版本
  • 訓練分成兩個階段
    • ST1:事先隨機抽取的 256 筆測試資料完全分類正確
    • ST2:達成 ST1 後在 2560 筆測試資料上平均錯誤低於 0.01
  • 實驗結果是執行 10 次實驗的平均值

LSTM 架構

參數 數值(或範圍) 備註
din 1  
dhid 0 沒有隱藏單元
dblock 1  
nblock 3  
dout 1  
dim(whid) 0 沒有隱藏單元
dim(wblockk) dblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wig) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wog) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wout) dout×[nblockdblock] 外部輸入沒有直接連接到總輸出
參數初始化範圍 [0.1,0.1]  
輸入閘門偏差項初始化範圍 {1,3,5} 由大到小依序初始化不同記憶單元對應輸入閘門偏差項
輸出閘門偏差項初始化範圍 {2,4,6} 由大到小依序初始化不同記憶單元對應輸出閘門偏差項
Learning rate 1  

實驗結果

表格 4:Two-Sequence Problem 實驗結果。 表格來源:論文

表 4

  • 偏差項初始化的數值其實不需要這麼準確
  • LSTM 能夠快速解決任務
  • LSTM 在輸入有雜訊(高斯分佈)時仍然能夠正常表現

實驗 3b:Two-Sequence Problem + 雜訊

表格 5:Two-Sequence Problem + 雜訊實驗結果。 表格來源:論文

表 5

實驗設計與 LSTM 完全與實驗 3a 相同,但對於序列 seqN 個實數加上雜訊(與實驗 2a 相同的高斯分佈)。

  • 兩階段訓練稍微做點修改
    • ST1:事先隨機抽取的 256 筆測試資料少於 6 筆資料分類錯誤
    • ST2:達成 ST1 後在 2560 筆測試資料上平均錯誤低於 0.04
  • 結論
    • 增加雜訊導致誤差收斂時間變長
    • 相較於實驗 3a,雖然分類錯誤率上升,但 LSTM 仍然能夠保持較低的分類錯誤率

實驗 3c:強化版 Two-Sequence Problem

表格 6:強化版 Two-Sequence Problem 實驗結果。 表格來源:論文

表 6

實驗設計與 LSTM 完全與實驗 3b 相同,但進行以下修改

  • C1 類別必須輸出 0.2C2 類別必須輸出 0.8
  • 高斯分佈變異數改為 0.1
  • 預測結果與答案絕對誤差大於 0.1 就算分類錯誤
  • 任務目標是所有的預測絕對誤差平均值小於 0.015
  • 兩階段訓練改為一階段
    • 事先隨機抽取的 256 筆測試資料完全分類正確
    • 2560 筆測試資料上絕對誤差平均值小於 0.015
  • Learning rate 改成 0.1
  • 結論
    • 任務變困難導致收斂時間變更長
    • 相較於實驗 3a,雖然分類錯誤率上升,但 LSTM 仍然能夠保持較低的分類錯誤率

實驗 4:加法任務

任務定義

定義一個序列 seq,序列的每個元素都是由兩個實數組合而成,具體的數值範圍如下

seq(t)[1,1]×{1,0,1}t=0,,T

每個時間點的元素的第一個數值都是隨機從 [1,1] 中取出,第二個數值只能是 1,0,1 三個數值的其中一個。

T 為序列的最小長度,則序列 seq 的長度 L 將會落在 [T,T+T/10] 之間。

決定每個時間點的元素的第二個數值的方法如下:

  1. 首先將所有元素的第二個數值初始化成 0
  2. t=0t=L 的第二個數值初始化成 1
  3. t=0,,9 隨機挑選一個時間點,並將該時間點的第二個數值加上 1
  4. 如果前一個步驟剛好挑到 t=0,則 t=0 的第二個數值將會是 0,否則為 1
  5. t=0,,T/21 隨機挑選一個時間點,並只挑選第二個數值仍為 0 的時間點,挑選後將該時間點的第二個數值設為 1

透過上述步驟 seq 最少會包含一個元素其第二個數值為 1,最多會包含二個元素其第二個數值為 1

模型在 L+1 時間點必須輸出所有元素中第二個數值為 1 的元素,其第一個數值的總和,並轉換到 [0,1] 區間的數值,即

y^(L+1)=0.5+14t=0L[1(seq1(t)=1)seq2(t)]

只考慮 L+1 時間點的誤差,誤差必須要低於 0.04 才算預測正確。

  • 模型必須要學會長時間關閉輸入閘門
  • 在實驗中故意對所有參數加上偏差項,實驗內部狀態偏差行為造成的影響
  • 當連續 2000 次的誤差第於 0.04,且平均絕對誤差低於 0.01 時停止訓練
  • 測試資料集包含 2560 筆資料

LSTM 架構

參數 數值(或範圍) 備註
din 2  
dhid 0 沒有隱藏單元
dblock 2  
nblock 2  
dout 1  
dim(whid) 0 沒有隱藏單元
dim(wblockk) dblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wig) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wog) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wout) dout×[nblockdblock+1] 外部輸入沒有直接連接到總輸出,有額外使用偏差項
參數初始化範圍 [0.1,0.1]  
輸入閘門偏差項初始化範圍 {3,6} 由大到小依序初始化不同記憶單元對應輸入閘門偏差項
Learning rate 0.5  

實驗結果

表格 7:加法任務實驗結果。 表格來源:論文

表 7

  • LSTM 能夠達成任務目標
    • 不超過 3 筆以上預測錯誤的資料
  • LSTM 能夠摹擬加法器,具有作為 distributed representation 的能力
  • 能夠儲存時間差至少有 T/2 以上的資訊,因此不會被內部狀態偏差行為影響

實驗 5:乘法任務

任務定義

從 LSTM 的架構上來看實驗 4 的加法任務可以透過 (39) 輕鬆完成,因此實驗 5 的目標是確認模型是否能夠從加法上延伸出乘法的概念,確保實驗 4 並不只是單純因模型架構而解決。

概念與實驗 4 的任務幾乎相同,只做以下修改:

  • 每個時間點的元素第一個數值改為 [0,1] 之間的隨機值
  • L+1 時間點的輸出目標改成
y^(L+1)=0.5+14t=0L[1(seq1(t)=1)seq2(t)]
  • 當連續 2000 筆訓練資料中,不超過 nseq 筆資料的絕對誤差小於 0.04 就停止訓練
  • nseq{13,140}
    • 選擇 140 的理由是模型已經有能力記住資訊,但計算結果不夠精確
    • 選擇 13 的理由是模型能夠精確達成任務

LSTM 架構

與實驗 4 完全相同,只做以下修改:

  • 輸入閘門偏差項改成隨機初始化
  • Learning rate 改為 0.1

實驗結果

表格 8:乘法任務實驗結果。 表格來源:論文

表 8

  • LSTM 能夠達成任務目標
    • nseq=140 時不超過 170 筆以上預測錯誤的資料
    • nseq=13 時不超過 15 筆以上預測錯誤的資料
  • 如果額外使用隱藏單元,則收斂速度會更快
  • LSTM 能夠摹擬乘法器,具有作為 distributed representation 的能力
  • 能夠儲存時間差至少有 T/2 以上的資訊,因此不會被內部狀態偏差行為影響

實驗 6a:Temporal Order with 4 Classes

任務定義

給予一個序列 seq,其長度 L 會落在 [100,110] 之間,序列中的所有元素都來自於集合 V={a,b,c,d,B,E,X,Y}

序列 seq 的開頭必定為 B,最後為 E,剩餘所有的元素都是 a,b,c,d,除了兩個時間點 t1,t2

t1,t2 時間點只能出現 XYt1 時間點會落在 [10,20]t2 時間點會落在 [50,60]

因此根據 X,Y 出現的次數順序共有 4 種不同的類別

C1=XXC2=XYC3=YXC4=YY

模型必須要在 L+1 時間點進行類別預測,誤差只會出現在 L+1 時間點。

  • t1,t2 的最少時間差為 30
  • 模型必須要記住資訊與出現順序
  • 當模型成功預測連續 2000 筆資料,並且預測平均誤差低於 0.1 時便停止訓練
  • 測試資料共有 2560

LSTM 架構

參數 數值(或範圍) 備註
din 8  
dhid 0 沒有隱藏單元
dblock 2  
nblock 2  
dout 4  
dim(whid) 0 沒有隱藏單元
dim(wblockk) dblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wig) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wog) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wout) dout×[nblockdblock+1] 外部輸入沒有直接連接到總輸出,有額外使用偏差項
參數初始化範圍 [0.1,0.1]  
輸入閘門偏差項初始化範圍 {2,4} 由大到小依序初始化不同記憶單元對應輸入閘門偏差項
Learning rate 0.5  

實驗結果

表格 9:Temporal Order with 4 Classes 任務實驗結果。 表格來源:論文

表 9

  • LSTM 的平均誤差低於 0.1

    • 沒有超過 3 筆以上的預測錯誤
  • LSTM 可能使用以下的方法進行解答

    • 擁有 2 個記憶單元時,依照順序記住出現的資訊
    • 只有 1 個記憶單元時,LSTM 可以改成記憶狀態的轉移

實驗 6b:Temporal Order with 8 Classes

任務定義

與實驗 6a 完全相同,只是多了一個 t3 時間點可以出現 X,Y

  • t2 時間點改成落在 [33,43]
  • t3 時間點落在 [66,76]
  • 類別變成 8

LSTM 架構

參數 數值(或範圍) 備註
din 8  
dhid 0 沒有隱藏單元
dblock 2  
nblock 3  
dout 8  
dim(whid) 0 沒有隱藏單元
dim(wblockk) dblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wig) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wog) nblock×[din+nblock(2+dblock)+1] 全連接隱藏層,有額外使用偏差項
dim(wout) dout×[nblockdblock+1] 外部輸入沒有直接連接到總輸出,有額外使用偏差項
參數初始化範圍 [0.1,0.1]  
輸入閘門偏差項初始化範圍 {2,4,6} 由大到小依序初始化不同記憶單元對應輸入閘門偏差項
Learning rate 0.1  

實驗結果

表格 9