$\newcommand{\field}[1]{\mathbb{#1}}$ $\providecommand{\N}{}$ $\renewcommand{\N}{\field{N}}$ $\providecommand{\Q}{}$ $\renewcommand{\Q}{\field{Q}}$ $\providecommand{\R}{}$ $\renewcommand{\R}{\field{R}}$ $\providecommand{\Z}{}$ $\renewcommand{\Z}{\field{Z}}$ $\providecommand{\pa}{}$ $\renewcommand{\pa}[1]{\left\lparen #1 \right\rparen}$ $\providecommand{\br}{}$ $\renewcommand{\br}[1]{\left\lbrack #1 \right\rbrack}$ $\providecommand{\set}{}$ $\renewcommand{\set}[1]{\left\lbrace #1 \right\rbrace}$ $\providecommand{\abs}{}$ $\renewcommand{\abs}[1]{\left\lvert #1 \right\rvert}$ $\providecommand{\norm}{}$ $\renewcommand{\norm}[1]{\left\lVert #1 \right\rVert}$ $\providecommand{\floor}{}$ $\renewcommand{\floor}[1]{\left\lfloor #1 \right\rfloor}$ $\providecommand{\ceil}{}$ $\renewcommand{\ceil}[1]{\left\lceil #1 \right\rceil}$ $\providecommand{\eval}{}$ $\renewcommand{\eval}[1]{\left. #1 \right\rvert}$ $\providecommand{\pd}{}$ $\renewcommand{\pd}[2]{\dfrac{\partial #1}{\partial #2}}$ $\DeclareMathOperator{\sign}{sign}$ $\DeclareMathOperator{\diag}{diag}$ $\DeclareMathOperator*{\argmax}{argmax}$ $\DeclareMathOperator*{\argmin}{argmin}$ $\providecommand{\Lim}{}$ $\renewcommand{\Lim}{\lim\limits}$ $\providecommand{\Prod}{}$ $\renewcommand{\Prod}{\prod\limits}$ $\providecommand{\Sum}{}$ $\renewcommand{\Sum}{\sum\limits}$ $\DeclareMathOperator{\softmax}{softmax}$ $\DeclareMathOperator{\cat}{concatenate}$ $\providecommand{\algoProc}{}$ $\renewcommand{\algoProc}[1]{\textbf{procedure}\text{ #1}}$ $\providecommand{\algoEndProc}{}$ $\renewcommand{\algoEndProc}{\textbf{end procedure}}$ $\providecommand{\algoIf}{}$ $\renewcommand{\algoIf}[1]{\textbf{if } #1 \textbf{ do}}$ $\providecommand{\algoEndIf}{}$ $\renewcommand{\algoEndIf}{\textbf{end if}}$ $\providecommand{\algoEq}{}$ $\renewcommand{\algoEq}{\leftarrow}$ $\providecommand{\algoFor}{}$ $\renewcommand{\algoFor}[1]{\textbf{for } #1 \textbf{ do}}$ $\providecommand{\algoEndFor}{}$ $\renewcommand{\algoEndFor}{\textbf{end for}}$ $\providecommand{\algoWhile}{}$ $\renewcommand{\algoWhile}[1]{\textbf{while } #1 \textbf{ do}}$ $\providecommand{\algoEndWhile}{}$ $\renewcommand{\algoEndWhile}{\textbf{end while}}$ $\providecommand{\algoReturn}{}$ $\renewcommand{\algoReturn}{\textbf{return }}$ $\providecommand{\hash}{}$ $\renewcommand{\hash}{\unicode{35}}$

目標 提出在 LSTM 上增加 forget gate
作者 Felix A. Gers, Jürgen Schmidhuber, Fred Cummins
期刊/會議名稱 Neural Computation
發表時間 2000
論文連結 https://direct.mit.edu/neco/article-abstract/12/10/2451/6415/Learning-to-Forget-Continual-Prediction-with-LSTM

$\providecommand{\opnet}{}$ $\renewcommand{\opnet}{\operatorname{net}}$ $\providecommand{\opin}{}$ $\renewcommand{\opin}{\operatorname{in}}$ $\providecommand{\opout}{}$ $\renewcommand{\opout}{\operatorname{out}}$ $\providecommand{\opblk}{}$ $\renewcommand{\opblk}{\operatorname{block}}$ $\providecommand{\opfg}{}$ $\renewcommand{\opfg}{\operatorname{fg}}$ $\providecommand{\opig}{}$ $\renewcommand{\opig}{\operatorname{ig}}$ $\providecommand{\opog}{}$ $\renewcommand{\opog}{\operatorname{og}}$ $\providecommand{\opseq}{}$ $\renewcommand{\opseq}{\operatorname{seq}}$ $\providecommand{\oploss}{}$ $\renewcommand{\oploss}{\operatorname{loss}}$ $\providecommand{\net}{}$ $\renewcommand{\net}[2]{\opnet_{#1}(#2)}$ $\providecommand{\fnet}{}$ $\renewcommand{\fnet}[2]{f_{#1}\big(\net{#1}{#2}\big)}$ $\providecommand{\dfnet}{}$ $\renewcommand{\dfnet}[2]{f_{#1}'\big(\net{#1}{#2}\big)}$ $\providecommand{\din}{}$ $\renewcommand{\din}{d_{\opin}}$ $\providecommand{\dout}{}$ $\renewcommand{\dout}{d_{\opout}}$ $\providecommand{\dblk}{}$ $\renewcommand{\dblk}{d_{\opblk}}$ $\providecommand{\nblk}{}$ $\renewcommand{\nblk}{n_{\opblk}}$ $\providecommand{\blk}{}$ $\renewcommand{\blk}[1]{\opblk^{#1}}$ $\providecommand{\wfg}{}$ $\renewcommand{\wfg}{w^{\opfg}}$ $\providecommand{\wig}{}$ $\renewcommand{\wig}{w^{\opig}}$ $\providecommand{\wog}{}$ $\renewcommand{\wog}{w^{\opog}}$ $\providecommand{\wblk}{}$ $\renewcommand{\wblk}[1]{w^{\blk{#1}}}$ $\providecommand{\wout}{}$ $\renewcommand{\wout}{w^{\opout}}$ $\providecommand{\netfg}{}$ $\renewcommand{\netfg}[2]{\opnet_{#1}^{\opfg}(#2)}$ $\providecommand{\fnetfg}{}$ $\renewcommand{\fnetfg}[2]{f_{#1}^{\opfg}\big(\netfg{#1}{#2}\big)}$ $\providecommand{\dfnetfg}{}$ $\renewcommand{\dfnetfg}[2]{f_{#1}^{\opfg}{'}\big(\netfg{#1}{#2}\big)}$ $\providecommand{\netig}{}$ $\renewcommand{\netig}[2]{\opnet_{#1}^{\opig}(#2)}$ $\providecommand{\fnetig}{}$ $\renewcommand{\fnetig}[2]{f_{#1}^{\opig}\big(\netig{#1}{#2}\big)}$ $\providecommand{\dfnetig}{}$ $\renewcommand{\dfnetig}[2]{f_{#1}^{\opig}{'}\big(\netig{#1}{#2}\big)}$ $\providecommand{\netog}{}$ $\renewcommand{\netog}[2]{\opnet_{#1}^{\opog}(#2)}$ $\providecommand{\fnetog}{}$ $\renewcommand{\fnetog}[2]{f_{#1}^{\opog}\big(\netog{#1}{#2}\big)}$ $\providecommand{\dfnetog}{}$ $\renewcommand{\dfnetog}[2]{f_{#1}^{\opog}{'}\big(\netog{#1}{#2}\big)}$ $\providecommand{\netout}{}$ $\renewcommand{\netout}[2]{\opnet_{#1}^{\opout}(#2)}$ $\providecommand{\fnetout}{}$ $\renewcommand{\fnetout}[2]{f_{#1}^{\opout}\big(\netout{#1}{#2}\big)}$ $\providecommand{\dfnetout}{}$ $\renewcommand{\dfnetout}[2]{f_{#1}^{\opout}{'}\big(\netout{#1}{#2}\big)}$ $\providecommand{\netblk}{}$ $\renewcommand{\netblk}[3]{\opnet_{#1}^{\blk{#2}}(#3)}$ $\providecommand{\gnetblk}{}$ $\renewcommand{\gnetblk}[3]{g_{#1}\big(\netblk{#1}{#2}{#3}\big)}$ $\providecommand{\dgnetblk}{}$ $\renewcommand{\dgnetblk}[3]{g_{#1}'\big(\netblk{#1}{#2}{#3}\big)}$ $\providecommand{\hblk}{}$ $\renewcommand{\hblk}[3]{h_{#1}\big(s_{#1}^{\blk{#2}}(#3)\big)}$ $\providecommand{\dhblk}{}$ $\renewcommand{\dhblk}[3]{h_{#1}'\big(s_{#1}^{\blk{#2}}(#3)\big)}$ $\providecommand{\aptr}{}$ $\renewcommand{\aptr}{\approx_{\operatorname{tr}}}$

  • 此篇論文原版 LSTM 都寫錯自己的數學公式,但我的筆記內容主要以正確版本為主,原版 LSTM 可以參考我的筆記
  • 原版 LSTM 沒有遺忘閘門,現今常用的 LSTM 都有遺忘閘門,概念由這篇論文提出
  • 包含多個子序列的連續輸入會讓 LSTM 的記憶單元內部狀態沒有上下界
    • 現實中的大多數資料並不存在好的分割序列演算法,導致輸入給模型的資料通常都包含多個子序列
    • 根據實驗 1 的分析發現記憶單元內部狀態的累積導致預測結果完全錯誤
  • 使用遺忘閘門讓模型學會適當的忘記已經處理過的子序列資訊
    • 當遺忘閘門的偏差項初始化為正數時會保持記憶單元內部狀態,等同於使用原版的 LSTM
    • 因此使用遺忘閘門的 LSTM 能夠達成原版 LSTM 的功能,並額外擁有自動重設記憶單元的機制
  • 這篇模型的理論背景較少,實驗為主的描述居多

原始 LSTM

模型架構

根據原始論文提出的架構如下(這篇論文不使用額外的隱藏單元,因此我們也完全不列出隱藏單元相關的公式)(細節可以參考我的筆記

符號 意義 備註
$\din$ 輸入層的維度 數值範圍為 $\Z^+$
$\dblk$ 記憶單元區塊的維度 數值範圍為 $\Z^+$
$\nblk$ 記憶單元區塊的個數 數值範圍為 $\Z^+$
$\dout$ 輸出層的維度 數值範圍為 $\Z^+$
$T$ 輸入序列的長度 數值範圍為 $\Z^+$

以下所有符號的時間 $t$ 範圍為 $t \in \set{0, \dots, T - 1}$

符號 意義 維度 備註
$x(t)$ 第 $t$ 個時間點的輸入 $\din$  
$y^{\opig}(t)$ 第 $t$ 個時間點的輸入閘門 $\nblk$ $y^{\opig}(0) = 0$,同一個記憶單元區塊共享輸入閘門
$y^{\opog}(t)$ 第 $t$ 個時間點的輸出閘門 $\nblk$ $y^{\opog}(0) = 0$,同一個記憶單元區塊共享輸出閘門
$s^{\blk{k}}(t)$ 第 $t$ 個時間點的第 $k$ 個記憶單元區塊內部狀態 $\dblk$ $s^{\blk{k}}(0) = 0$ 且 $k \in \set{1, \dots, \nblk}$
$y^{\blk{k}}(t)$ 第 $t$ 個時間點的第 $k$ 個記憶單元區塊輸出 $\dblk$ $y^{\blk{k}}(0) = 0$ 且 $k \in \set{1, \dots, \nblk}$
$y(t + 1)$ 第 $t + 1$ 個時間點的輸出 $\dout$ 由 $t$ 時間點的輸入記憶單元輸出透過全連接產生,因此沒有 $y(0)$
$\hat{y}(t + 1)$ 第 $t + 1$ 個時間點的預測目標 $\dout$  
符號 意義 下標範圍
$x_j(t)$ 第 $t$ 個時間點的第 $j$ 個輸入 $j \in \set{1, \dots, \din}$
$y_k^{\opig}(t)$ 第 $t$ 個時間點第 $k$ 個記憶單元區塊的輸入閘門 $k \in \set{1, \dots, \nblk}$
$y_k^{\opog}(t)$ 第 $t$ 個時間點第 $k$ 個記憶單元區塊的輸出閘門 $k \in \set{1, \dots, \nblk}$
$s_i^{\blk{k}}(t)$ 第 $t$ 個時間點的第 $k$ 個記憶單元區塊的第 $i$ 個記憶單元內部狀態 $i \in \set{1, \dots, \dblk}$
$y_i^{\blk{k}}(t)$ 第 $t$ 個時間點的第 $k$ 個記憶單元區塊的第 $i$ 個記憶單元輸出 $i \in \set{1, \dots, \dblk}$
$y_i(t + 1)$ 第 $t + 1$ 個時間點的第 $i$ 個輸出 $i \in \set{1, \dots, \dout}$
$\hat{y}_i(t + 1)$ 第 $t + 1$ 個時間點的第 $i$ 個預測目標 $i \in \set{1, \dots, \dout}$
參數 意義 輸出維度 輸入維度
$\wig$ 產生輸入閘門的全連接參數 $\nblk$ $\din + \nblk \cdot (2 + \dblk)$
$\wog$ 產生輸出閘門的全連接參數 $\nblk$ $\din + \nblk \cdot (2 + \dblk)$
$\wblk{k}$ 產生第 $k$ 個記憶單元區塊淨輸入的全連接參數 $\dblk$ $\din + \nblk \cdot (2 + \dblk)$
$\wout$ 產生輸出的全連接參數 $\dblk$ $\din + \nblk \cdot \dblk$

定義 $\sigma$ 為 sigmoid 函數 $\sigma(x) = \frac{1}{1 + e^{-x}}$

函數 意義 公式 range
$f_k^{\opig}$ 第 $k$ 個輸入閘門的啟發函數 $\sigma$ $[0, 1]$
$f_k^{\opog}$ 第 $k$ 個輸出閘門的啟發函數 $\sigma$ $[0, 1]$
$g_i^{\blk{k}}$ 第 $k$ 個記憶單元區塊中第 $i$ 個記憶單元內部狀態的啟發函數 $4\sigma - 2$ $[-2, 2]$
$h_i^{\blk{k}}$ 第 $k$ 個記憶單元區塊中第 $i$ 個記憶單元輸出的啟發函數 $2\sigma - 1$ $[-1, 1]$
$f_i^{\opout}$ 第 $i$ 個輸出的啟發函數 $\sigma$ $[0, 1]$

在 $t$ 時間點時得到輸入 $x(t)$,產生 $t + 1$ 時間點輸入閘門 $y^{\opig}(t + 1)$ 與輸出閘門 $y^{\opog}(t + 1)$ 的方法如下

\[\begin{align*} \tilde{x}(t) & = \begin{pmatrix} x(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix} \\ y^{\opig}(t + 1) & = f^{\opig}\pa{\opnet^{\opig}(t + 1)} = f^{\opig}\pa{\wig \cdot \tilde{x}(t)} \\ y^{\opog}(t + 1) & = f^{\opog}\pa{\opnet^{\opog}(t + 1)} = f^{\opog}\pa{\wog \cdot \tilde{x}(t)} \end{align*} \tag{1}\label{1}\]

利用 $\eqref{1}$ 產生 $t + 1$ 時間點的記憶單元內部狀態 $s^{\blk{k}}(t + 1)$ 方法如下

\[\begin{align*} \tilde{x}(t) & = \begin{pmatrix} x(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix} \\ k & \in \set{1, \dots, \nblk} \\ \opnet^{\blk{k}}(t + 1) & = \wblk{k} \cdot \tilde{x}(t) \\ s^{\blk{k}}(t + 1) & = s^{\blk{k}}(t) + y_k^{\opig}(t + 1) \cdot g(\opnet^{\blk{k}}(t + 1)) \end{align*} \tag{2}\label{2}\]

注意第 $k$ 個記憶單元區塊內部狀態共享輸入閘門 $y_k^{\opig}(t + 1)$。

利用 $\eqref{1}\eqref{2}$ 產生 $t + 1$ 時間點的記憶單元輸出 $y^{\blk{k}}(t + 1)$ 方法如下

\[\begin{align*} k & \in \set{1, \dots, \nblk} \\ y^{\blk{k}}(t + 1) & = y_k^{\opog}(t + 1) \cdot h\pa{s^{\blk{k}}(t + 1)} \end{align*} \tag{3}\label{3}\]

注意第 $k$ 個記憶單元區塊輸出共享輸出閘門 $y_k^{\opog}(t + 1)$。

產生 $t + 1$ 時間點的輸出是透過 $t$ 時間點的輸入與 $t + 1$ 時間點的記憶單元輸出(見 $\eqref{3}$)而得

\[\begin{align*} \tilde{x}(t + 1) & = \begin{pmatrix} x(t) \\ y^{\blk{1}}(t + 1) \\ \vdots \\ y^{\blk{\nblk}}(t + 1) \end{pmatrix} \\ y(t + 1) & = f^{\opout}(\opnet^{\opout}(t + 1)) = f^{\opout}\pa{\wout \cdot \tilde{x}(t + 1)} \end{align*} \tag{4}\label{4}\]

這篇論文原版 LSTM 的論文 都不小心寫成 $t$ 時間點的記憶單元輸出,在 LSTM-2002 才終於寫對。

最佳化

原始 LSTM 提出與 truncated BPTT 相似的概念,透過 RTRL 進行參數更新,並故意丟棄流出記憶單元的所有梯度,避免梯度爆炸或梯度消失的問題,同時節省更新所需的空間與時間(local in time and space)。(細節可見我的筆記

令 $t \in \set{0, \dots, T - 1}$,最佳化的目標為每個時間點 $t + 1$ 所產生的平方誤差總和最小化

\[\begin{align*} \oploss(t + 1) & = \sum_{i = 1}^{\dout} \oploss_i(t + 1) \\ & = \sum_{i = 1}^{\dout} \frac{1}{2} \big(y_i(t + 1) - \hat{y}_i(t + 1)\big)^2 \end{align*} \tag{5}\label{5}\]

以下我們使用 $\aptr$ 代表丟棄部份梯度後的剩餘梯度

輸出參數的剩餘梯度為

\[\begin{align*} \pd{\oploss(t + 1)}{\wout_{i, j}} & = \pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \pd{\netout{i}{t + 1}}{\wout_{i, j}} \\ & = \big(y_i(t + 1) - \hat{y}_i(t + 1)\big) \cdot \dfnetout{i}{t + 1} \cdot \begin{pmatrix} x(t) \\ y^{\blk{1}}(t + 1) \\ \vdots \\ y^{\blk{\nblk}}(t + 1) \end{pmatrix}_j \end{align*} \tag{6}\label{6}\]

其中 $1 \leq i \leq \dout$ 且 $1 \leq j \leq \din + \nblk \cdot \dblk$。

輸出閘門參數的剩餘梯度為

\[\begin{align*} & \pd{\oploss(t + 1)}{\wog_{k, q}} \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \\ & \quad \pa{\sum_{j = 1}^{\dblk} \pd{\netout{i}{t + 1}}{y_j^{\blk{k}}(t + 1)} \cdot \pd{y_j^{\blk{k}}(t + 1)}{y_k^{\opog}(t + 1)}} \cdot \pd{y_k^{\opog}(t + 1)}{\netog{k}{t + 1}} \cdot \pd{\netog{k}{t + 1}}{\wog_{k, q}}\Bigg] \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\big(y_i(t + 1) - \hat{y}_i(t + 1)\big) \cdot \dfnetout{i}{t + 1} \cdot \\ & \quad \pa{\sum_{j = 1}^{\dblk} \wout_{i, \din + (k - 1) \cdot \dblk + j} \cdot \hblk{j}{k}{t + 1}} \cdot \dfnetog{k}{t + 1} \cdot \begin{pmatrix} x(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix}_q\Bigg] \end{align*} \tag{7}\label{7}\]

其中 $1 \leq k \leq \nblk$ 且 $1 \leq q \leq \din + \nblk \cdot (2 + \dblk)$。

輸入閘門參數的剩餘梯度為

\[\begin{align*} & \pd{\oploss(t + 1)}{\wig_{k, q}} \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \\ & \quad \pa{\sum_{j = 1}^{\dblk} \pd{\netout{i}{t + 1}}{y_j^{\blk{k}}(t + 1)} \cdot \pd{y_j^{\blk{k}}(t + 1)}{s_j^{\blk{k}}(t + 1)} \cdot \pd{s_j^{\blk{k}}(t + 1)}{\wig_{k, q}}}\Bigg] \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \Bigg(\sum_{j = 1}^{\dblk} \pd{\netout{i}{t + 1}}{y_j^{\blk{k}}(t + 1)} \cdot \pd{y_j^{\blk{k}}(t + 1)}{s_j^{\blk{k}}(t + 1)} \cdot \\ & \quad \quad \br{\pd{s_j^{\blk{k}}(t)}{\wig_{k, q}} + \gnetblk{j}{k}{t + 1} \cdot \pd{y_k^{\opig}(t + 1)}{\netig{k}{t + 1}} \cdot \pd{\netig{k}{t + 1}}{\wig_{k, q}}}\Bigg)\Bigg] \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\big(y_i(t + 1) - \hat{y}_i(t + 1)\big) \cdot \dfnetout{i}{t + 1} \cdot \\ & \quad \Bigg(\sum_{j = 1}^{\dblk} \wout_{i, \din + (k - 1) \cdot \dblk + j} \cdot y_k^{\opog}(t + 1) \cdot \dhblk{j}{k}{t + 1} \cdot \\ & \quad \quad \br{\pd{s_j^{\blk{k}}(t)}{\wig_{k, q}} + \gnetblk{j}{k}{t + 1} \cdot \dfnetig{k}{t + 1} \cdot \begin{pmatrix} x(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix}_q}\Bigg)\Bigg] \end{align*} \tag{8}\label{8}\]

其中 $1 \leq k \leq \nblk$ 且 $1 \leq q \leq \din + \nblk \cdot (2 + \dblk)$。

記憶單元淨輸入參數的剩餘梯度為

\[\begin{align*} & \pd{\oploss(t + 1)}{\wblk{k}_{p, q}} \\ & \aptr \sum_{i = 1}^{\dout} \br{\pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \pd{\netout{i}{t + 1}}{y_p^{\blk{k}}(t + 1)} \cdot \pd{y_p^{\blk{k}}(t + 1)}{s_p^{\blk{k}}(t + 1)} \cdot \pd{s_p^{\blk{k}}(t + 1)}{\wblk{k}_{p, q}}} \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \pd{\netout{i}{t + 1}}{y_p^{\blk{k}}(t + 1)} \cdot \pd{y_p^{\blk{k}}(t + 1)}{s_p^{\blk{k}}(t + 1)} \cdot \\ & \quad \quad \pa{\pd{s_p^{\blk{k}}(t)}{\wblk{k}_{p, q}} + y_k^{\opig}(t + 1) \cdot \pd{\gnetblk{j}{k}{t + 1}}{\netblk{j}{k}{t + 1}} \cdot \pd{\netblk{j}{k}{t + 1}}{\wblk{k}_{p, q}}}\Bigg] \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\big(y_i(t + 1) - \hat{y}_i(t + 1)\big) \cdot \dfnetout{i}{t + 1} \cdot \wout_{i, \din + (k - 1) \cdot \dblk + j} \cdot \\ & \quad y_k^{\opog}(t + 1) \cdot \dhblk{j}{k}{t + 1} \cdot \\ & \quad \br{\pd{s_p^{\blk{k}}(t)}{\wblk{k}_{p, q}} + y_k^{\opig}(t + 1) \cdot \dgnetblk{p}{k}{t + 1} \cdot \begin{pmatrix} x(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix}_q}\Bigg] \end{align*} \tag{9}\label{9}\]

其中 $1 \leq k \leq \nblk$, $1 \leq p \leq \dblk$ 且 $1 \leq q \leq \din + \nblk \cdot (2 + \dblk)$。

計算完上述所有參數後使用梯度下降(gradient descent)進行參數更新

\[\begin{align*} \wout_{i, j} & \leftarrow \wout_{i, j} - \alpha \cdot \pd{\oploss(t + 1)}{\wout_{i, j}} \\ \wog_{k, q} & \leftarrow \wog_{k, q} - \alpha \cdot \pd{\oploss(t + 1)}{\wog_{k, q}} \\ \wig_{k, q} & \leftarrow \wig_{k, q} - \alpha \cdot \pd{\oploss(t + 1)}{\wig_{k, q}} \\ \wblk{k}_{p, q} & \leftarrow \wblk{k}_{p, q} - \alpha \cdot \pd{\oploss(t + 1)}{\wblk{k}_{p, q}} \end{align*} \tag{10}\label{10}\]

其中 $\alpha$ 為學習率learning rate)。

由於使用基於 RTRL 的最佳化演算法,因此每個時間點 $t + 1$ 計算完誤差後就可以更新參數。

問題

當一個輸入序列中包含多個獨立的子序列(例如一個文章段落有多個句子),則模型無法知道不同獨立子序列的起始點在哪裡(除非有明確的切斷序列演算法,但實際上不一定存在)。

原始 LSTM 架構假設任意輸入序列都是由單一獨立序列組成,不會包含多個獨立的序列,因此會在每次序列輸入時重設模型的計算狀態 $y^{\opig}(0), y^{\opog}(0), s^{\blk{k}}(0), y^{\blk{k}}(0)$,沒有需要在計算過程中重設計算狀態的需求

但當輸入包含多個獨立的子序列時,且沒有明確的方法辨識不同獨立子序列的起始點時,LSTM 模型就必須要擁有能夠在任意時間點 $t$ 重設計算狀態 $y^{\opig}(t), y^{\opog}(t), s^{\blk{k}}(t), y^{\blk{k}}(t)$ 的功能。

遺忘閘門

模型架構

圖 1:在原始 LSTM 架構上增加遺忘閘門。 圖片來源:論文

圖 1

作者提出在模型中加入遺忘閘門forget gate),概念是讓記憶單元內部狀態能夠進行重設。

首先需要計算遺忘閘門 $y^{\opfg}(t)$,定義如下

\[\begin{align*} \tilde{x}(t) & = \begin{pmatrix} x(t) \\ y^{\opfg}(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix} \\ y^{\opfg}(0) & = 0 \\ y^{\opfg}(t + 1) & = f^{\opfg}\pa{\opnet^{\opfg}(t + 1) = f^{\opfg}\pa{\wfg \cdot \tilde{x}(t)}} \end{align*} \tag{11}\label{11}\]

計算方法與輸入閘門和輸出閘門相同。

而計算過程需要做以下修改

  • $\eqref{1}\eqref{2}$ 中的淨輸入需要加上 $y^{\opfg}(t)$
  • 參數 $\wig, \wog, \wblk{k}$ 的輸入維度都改成 $\din + \nblk \cdot (3 + \dblk)$
  • $\wfg$ 的維度與 $\wig$ 完全相同
  • $f^{\opfg}$ 與 $f^{\opig}$ 的定義完全相同

所謂的遺忘並不是直接設定成 $0$,而是以乘法閘門的形式進行數值重設,因此 $\eqref{2}$ 的計算改成

\[s^{\blk{k}}(t + 1) = y_k^{\opfg}(t + 1) \cdot s^{\blk{k}}(t) + y_k^{\opig}(t + 1) \cdot g(\opnet^{\blk{k}}(t + 1)) \tag{12}\label{12}\]

偏差項

如同原始 LSTM輸入閘門輸出閘門可以使用偏差項(bias term),將偏差項初始化成負數可以讓輸入閘門與輸出閘門在需要的時候才被啟用(細節可以看我的筆記)。

遺忘閘門也可以使用偏差項,但初始化的數值應該為正數,理由是在模型計算前期應該要讓遺忘閘門開啟($y^{\opfg} \approx 1$),讓記憶單元內部狀態的數值能夠進行改變。

注意遺忘閘門只有在關閉($y^{\opfg} \approx 0$)時才能進行遺忘,這個名字取得不是很好。

最佳化

基於原始 LSTM 的最佳化演算法,將流出遺忘閘門的梯度也一起丟棄

\[\begin{align*} \pd{\netfg{k}{t + 1}}{y_{k^{\star}}^{\opfg}(t)} & \aptr 0 && k = 1, \dots, \nblk \\ \pd{\netfg{k}{t + 1}}{y_{k^{\star}}^{\opig}(t)} & \aptr 0 && k^{\star} = 1, \dots, \nblk \\ \pd{\netfg{k}{t + 1}}{y_{k^{\star}}^{\opog}(t)} & \aptr 0 \\ \pd{\netfg{k}{t + 1}}{y_i^{\blk{k^{\star}}}(t)} & \aptr 0 && i = 1, \dots, \dblk \end{align*} \tag{13}\label{13}\]

因此遺忘閘門的參數剩餘梯度為

\[\begin{align*} & \pd{\oploss(t + 1)}{\wfg_{k, q}} \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \\ & \quad \pa{\sum_{j = 1}^{\dblk} \pd{\netout{i}{t + 1}}{y_j^{\blk{k}}(t + 1)} \cdot \pd{y_j^{\blk{k}}(t + 1)}{s_j^{\blk{k}}(t + 1)} \cdot \pd{s_j^{\blk{k}}(t + 1)}{\wfg_{k, q}}}\Bigg] \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\pd{\oploss(t + 1)}{y_i(t + 1)} \cdot \pd{y_i(t + 1)}{\netout{i}{t + 1}} \cdot \Bigg(\sum_{j = 1}^{\dblk} \pd{\netout{i}{t + 1}}{y_j^{\blk{k}}(t + 1)} \cdot \pd{y_j^{\blk{k}}(t + 1)}{s_j^{\blk{k}}(t + 1)} \cdot \\ & \quad \quad \br{y_k^{\opfg}(t + 1) \cdot \pd{s_j^{\blk{k}}(t)}{\wfg_{k, q}} + s_j^{\blk{k}}(t) \cdot \pd{y_k^{\opfg}(t + 1)}{\netfg{k}{t + 1}} \cdot \pd{\netfg{k}{t + 1}}{\wfg_{k, q}}}\Bigg)\Bigg] \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\big(y_i(t + 1) - \hat{y}_i(t + 1)\big) \cdot \dfnetout{i}{t + 1} \cdot \\ & \quad \Bigg(\sum_{j = 1}^{\dblk} \wout_{i, \din + (k - 1) \cdot \dblk + j} \cdot y_k^{\opog}(t + 1) \cdot \dhblk{j}{k}{t + 1} \cdot \\ & \quad \quad \br{y_k^{\opfg}(t + 1) \cdot \pd{s_j^{\blk{k}}(t)}{\wfg_{k, q}} + s_j^{\blk{k}}(t) \cdot \dfnetog{k}{t + 1} \cdot \begin{pmatrix} x(t) \\ y^{\opfg}(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix}_q}\Bigg)\Bigg] \end{align*} \tag{14}\label{14}\]

$\eqref{14}$ 式就是論文的 3.12 式,其中 $1 \leq k \leq \nblk$ 且 $1 \leq q \leq \din + \nblk \cdot (3 + \dblk)$。

由於 $\eqref{12}$ 的修改,$\eqref{9} \eqref{10}$ 最佳化的過程也需要跟著修改。

輸入閘門的參數剩餘梯度改為

\[\begin{align*} & \pd{\oploss(t + 1)}{\wig_{k, q}} \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\big(y_i(t + 1) - \hat{y}_i(t + 1)\big) \cdot \dfnetout{i}{t + 1} \cdot \\ & \quad \Bigg(\sum_{j = 1}^{\dblk} \wout_{i, \din + (k - 1) \cdot \dblk + j} \cdot y_k^{\opog}(t + 1) \cdot \dhblk{j}{k}{t + 1} \cdot \\ & \quad \quad \br{y_k^{\opfg}(t + 1) \cdot \pd{s_j^{\blk{k}}(t)}{\wig_{k, q}} + \gnetblk{j}{k}{t + 1} \cdot \dfnetig{k}{t + 1} \cdot \begin{pmatrix} x(t) \\ y^{\opfg}(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix}_q}\Bigg)\Bigg] \end{align*} \tag{15}\label{15}\]

$\eqref{14}$ 式就是論文的 3.11 式,其中 $1 \leq k \leq \nblk$ 且 $1 \leq q \leq \din + \nblk \cdot (3 + \dblk)$。

記憶單元淨輸入參數的剩餘梯度改為

\[\begin{align*} & \pd{\oploss(t + 1)}{\wblk{k}_{p, q}} \\ & \aptr \sum_{i = 1}^{\dout} \Bigg[\big(y_i(t + 1) - \hat{y}_i(t + 1)\big) \cdot \dfnetout{i}{t + 1} \cdot \wout_{i, \din + (k - 1) \cdot \dblk + j} \cdot \\ & \quad y_k^{\opog}(t + 1) \cdot \dhblk{j}{k}{t + 1} \cdot \\ & \quad \br{y_k^{\opfg}(t + 1) \cdot \pd{s_p^{\blk{k}}(t)}{\wblk{k}_{p, q}} + y_k^{\opig}(t + 1) \cdot \dgnetblk{p}{k}{t + 1} \cdot \begin{pmatrix} x(t) \\ y^{\opfg}(t) \\ y^{\opig}(t) \\ y^{\opog}(t) \\ y^{\blk{1}}(t) \\ \vdots \\ y^{\blk{\nblk}}(t) \end{pmatrix}_q}\Bigg] \end{align*} \tag{16}\label{16}\]

$\eqref{14}$ 式就是論文的 3.10 式,其中 $1 \leq k \leq \nblk$, $1 \leq p \leq \dblk$ 且 $1 \leq q \leq \din + \nblk \cdot (3 + \dblk)$。

注意錯誤:根據論文中的 3.4 式,論文 2.5 式的 $t - 1$ 應該改成 $t$。

根據 $\eqref{14}\eqref{15}\eqref{16}$,當遺忘閘門 $y_k^{\opfg}(t + 1) \approx 0$ (關閉)時,不只記憶單元 $s^{\blk{k}}(t + 1)$ 會重設,與其相關的梯度也會重設,因此更新時需要額外紀錄以下的項次

\[\pd{s_i^{\blk{k}}(t + 1)}{\wfg_{k, q}}, \pd{s_i^{\blk{k}}(t + 1)}{\wig_{k, q}}, \pd{s_i^{\blk{k}}(t + 1)}{\wblk{k}_{p, q}}\]

同樣的概念在原始 LSTM 中也有出現,細節可以看我的筆記

實驗 1:Continual Embedded Reber Grammar

圖 2:Continual Embedded Reber Grammar。 圖片來源:論文

圖 2

任務定義

  • 根據原始 LSTM 論文中的實驗 1(Embedded Reber Grammar)進行修改,輸入為連續序列,連續序列的定義是由多個 Embedded Reber Grammar 產生的序列組合而成(細節可以看我的筆記
  • 每個分支的生成機率值為 $0.5$
  • 當所有輸出單元的平方誤差低於 $0.49$ 時就當成預測正確
  • 在一次的訓練過程中,給予模型的輸入只會在以下兩種狀況之一發生時停止
    • 當模型產生一次的預測錯誤
    • 模型連續接收 $10^6$ 個輸入
  • 每次訓練停止就進行一次測試
    • 一次測試會執行 $10$ 次的連續輸入
    • 評估結果是 $10$ 次連續輸入的平均值
  • 每輸入一個訊號就進行更新(RTRL)
  • 訓練最多執行 $30000$ 次,實驗結果由 $100$ 個訓練模型實驗進行平均

LSTM 架構

圖 3:LSTM 架構。 圖片來源:論文

圖 3

參數 數值(或範圍) 備註
$\din$ $7$  
$\nblk$ $4$  
$\dblk$ $2$  
$\dout$ $7$  
$\dim(\wblk{k})$ $\dblk \times [\din + \nblk \cdot \dblk]$ 訊號來源為外部輸入與記憶單元
$\dim(\wfg)$ $\nblk \times [\din + \nblk \cdot \dblk + 1]$ 訊號來源為外部輸入與記憶單元,有額外使用偏差項
$\dim(\wig)$ $\nblk \times [\din + \nblk \cdot \dblk + 1]$ 訊號來源為外部輸入與記憶單元,有額外使用偏差項
$\dim(\wog)$ $\nblk \times [\din + \nblk \cdot \dblk + 1]$ 訊號來源為外部輸入與記憶單元,有額外使用偏差項
$\dim(\wout)$ $\dout \times [\din + \nblk \cdot \dblk + 1]$ 訊號來源為外部輸入與記憶單元,有額外使用偏差項
總參數量 $424$  
參數初始化 $[-0.2, 0.2]$ 平均分佈
輸入閘門偏差項初始化 $\set{-0.5, -1.0, -1.5, -2.0}$ 依序初始化成不同數值
輸出閘門偏差項初始化 $\set{-0.5, -1.0, -1.5, -2.0}$ 依序初始化成不同數值
遺忘閘門偏差項初始化 $\set{0.5, 1.0, 1.5, 2.0}$ 依序初始化成不同數值
Learning rate $\alpha$ $0.5$ 訓練過程可以固定 $\alpha$,或是以 $0.99$ 的 decay factor 在每次更新後進行衰減

實驗結果

圖 4:Continual Embedded Reber Grammar 實驗結果。 圖片來源:論文

圖 4

  • 原始 LSTM 在有手動進行計算狀態的重置時表現非常好,但當沒有手動重置時完全無法執行任務
    • 就算讓記憶單元內部狀態進行 decay 也無濟於事
  • 使用遺忘閘門的 LSTM 不需要手動重置計算狀態也能達成完美預測
    • 完美預測指的是連續 $10^6$ 輸入都預測正確
  • 有嘗試使用 $\alpha / t$ 或 $\alpha / \sqrt{T}$ 作為 learning rate,實驗發現不論是哪種最佳化的方法使用遺忘閘門的 LSTM 都表現的不錯
    • 在其他模型架構上(包含原版 LSTM)就算使用這些最佳化演算法也無法解決任務
  • 額外實驗在將 Embedded Reber Grammar 開頭的 B 與結尾的 E 去除的狀態下,使用遺忘閘門的 LSTM 仍然表現不錯

分析

圖 5:原版 LSTM 記憶單元內部狀態的累加值。 圖片來源:論文

圖 5

圖 6:LSTM 加上遺忘閘門後第三個記憶單元內部狀態。 圖片來源:論文

圖 6

圖 7:LSTM 加上遺忘閘門後第一個記憶單元內部狀態。 圖片來源:論文

圖 7

  • 觀察原版 LSTM 的記憶單元內部狀態,可以發現在不進行手動重設的狀態下,記憶單元內部狀態的數值只會不斷的累加(朝向極正或極負前進)
  • 觀察架上遺忘閘門後 LSTM 的記憶單元內部狀態,可以發現模型學會自動重設
    • 在第三個記憶單元中展現了長期記憶重設的能力
    • 在第一個記憶單元中展現了短期記憶重設的能力

實驗 2:Noisy Temporal Order Problem

任務定義

  • 就是原始 LSTM 論文中的實驗 6b,細節可以看我的筆記
  • 由於此任務需要讓記憶維持一段不短的時間,因此遺忘資訊對於這個任務可能有害,透過這個任務想要驗證是否有任務是只能使用原版 LSTM 可以解決但增加遺忘閘門後不能解決

LSTM 架構

與實驗 1 大致相同,只做以下修改

  • $\din = \dout = 8$
  • 將遺忘閘門的偏差項初始化成較大的正數(論文使用 $5$),讓遺忘閘門很難被關閉,藉此達到跟原本 LSTM 幾乎相同的計算能力

實驗結果

  • 使用遺忘閘門的 LSTM 仍然能夠解決 Noisy Temporal Order Problem
    • 當偏差項初始化成較大的正數(例如 $5$)時,收斂速度與原版 LSTM 一樣快
    • 當偏差項初始化成較小的正數(例如 $1$)時,收斂速度約為原版 LSTM 的 $3$ 倍
  • 因此根據實驗沒有什麼任務是原版 LSTM 可以解決但加上遺忘閘門後不能解決的

實驗 3:Continual Noisy Temporal Order Problem

任務定義

  • 根據原始 LSTM 論文中的實驗 6b 進行修改,輸入為連續序列,連續序列的定義是由 $100$ 筆 Noisy Temporal Order 序列所組成
  • 在一次的訓練過程中,給予模型的輸入只會在以下兩種狀況之一發生時停止
    • 當模型產生一次的預測錯誤
    • 模型連續接收 $100$ 個 Noisy Temporal Order 序列
  • 每次訓練停止就進行一次測試
    • 一次測試會執行 $10$ 次的連續輸入
    • 評估結果是 $10$ 次連續輸入中預測正確的序列個數平均值
  • 論文沒有講怎麼計算誤差與更新,我猜變成每個非預測時間點必須輸出 $0$,預測時間點時輸出預測結果
  • 訓練最多執行 $10^5$ 次,實驗結果由 $100$ 個訓練模型實驗進行平均

LSTM 架構

與實驗 2 相同。

實驗結果

圖 8:Continual Noisy Temporal Order Problem 實驗結果。 圖片來源:論文

圖 8

  • 圖 8 中的註解 a 應該寫錯了,應該改為 correct classification of 100 successive NTO sequences
  • 實驗再次驗證原版 LSTM 無法解決連續輸入,但使用輸入閘門後就能夠解決問題
  • 將 learning rate 使用 decay factor $0.9$ 逐漸下降可以讓模型表現變更好,但作者認為這不重要