苟利国家生死以,岂因祸福避趋之。

简单写写 RNN 和 LSTM 的基本原理。如果后续有时间就细化一下。

RNN 和 LSTM 都是用来进行序列预测的。

# RNN

# 原理

RNN 可以概括为下面两个式子:

ht=σ(Whhht1+Whxxt1+bh)(update hidden state)ot=σ(Wohht+bo)(output)\begin{aligned} & \boldsymbol{h}_t = \sigma(\boldsymbol{W}_{hh} \boldsymbol{h}_{t-1} + \boldsymbol{W}_{hx}\boldsymbol{x}_{t-1} + \boldsymbol{b}_h) && \text{(update hidden state)}\\ & \boldsymbol{o}_t = \sigma(\boldsymbol{W}_{oh} \boldsymbol{h}_t + \boldsymbol{b}_o) && \text{(output)} \end{aligned}

其中 o\boldsymbol{o} 是输出特征,h\boldsymbol{h} 是隐变量。

关于权重矩阵的下标定义:WhxW_{hx} 代表输入是 x\boldsymbol{x}, 输出是 hh, 类似地,WohW_{oh} 代表输入是 h\boldsymbol{h}, 输出是 o\boldsymbol{o}.

# LSTM

长短期记忆网络 (Long Short-Term Memory, LSTM) 包括下面几个结构:

It=σ(XtWxi+Ht1Whi+bi)输入门(input gate)Ft=σ(XtWxf+Ht1Whf+bf)忘记门(forget gate)It=σ(XtWxo+Ht1Who+bo)输出门(output gate)C~t=tanh(XtWxc+Ht1Whc)候选记忆单元(candidate memory)Ct=FtCt1+ItC~t记忆单元(memory)Ht=Ottanh(Ct)隐状态(hidden state)\begin{aligned} & \boldsymbol{I}_t = \sigma(\boldsymbol{X}_t\boldsymbol{W}_{xi}+\boldsymbol{H}_{t-1}\boldsymbol{W}_{hi}+\boldsymbol{b}_i) && \text{输入门(input gate)} \\ & \boldsymbol{F}_t = \sigma(\boldsymbol{X}_t\boldsymbol{W}_{xf}+\boldsymbol{H}_{t-1}\boldsymbol{W}_{hf}+\boldsymbol{b}_f) && \text{忘记门(forget gate)} \\ & \boldsymbol{I}_t = \sigma(\boldsymbol{X}_t\boldsymbol{W}_{xo}+\boldsymbol{H}_{t-1}\boldsymbol{W}_{ho}+\boldsymbol{b}_o) && \text{输出门(output gate)}\\ & \tilde{\boldsymbol{C}}_t = \tanh(\boldsymbol{X}_t\boldsymbol{W}_{xc}+\boldsymbol{H}_{t-1}\boldsymbol{W}_{hc}) && \text{候选记忆单元(candidate memory)} \\ & \boldsymbol{C}_t = \boldsymbol{F}_t \odot \boldsymbol{C}_{t-1} + \boldsymbol{I}_t \odot \tilde{\boldsymbol{C}}_t && \text{记忆单元(memory)} \\ & \boldsymbol{H}_t = \boldsymbol{O}_t \odot \tanh(\boldsymbol{C}_t) && \text{隐状态(hidden state)} \\ \end{aligned}

# GRU

Rt=σ(XtWxr+Ht1Whr+br)Zt=σ(XtWxz+Ht1Whz+bz)H~t=tanh(XtWxh+(RtHt1)Whh+bh)H=ZtHt1+(1Zt)H~t隐状态(hidden state)\begin{aligned} & \boldsymbol{R}_t = \sigma(\boldsymbol{X}_t\boldsymbol{W}_{xr} + \boldsymbol{H}_{t-1}\boldsymbol{W}_{hr} + \boldsymbol{b}_r) \\ & \boldsymbol{Z}_t = \sigma(\boldsymbol{X}_t\boldsymbol{W}_{xz} + \boldsymbol{H}_{t-1}\boldsymbol{W}_{hz} + \boldsymbol{b}_z) \\ & \tilde{\boldsymbol{H}}_t = \tanh(\boldsymbol{X}_t\boldsymbol{W}_{xh} + (\boldsymbol{R}_t \odot \boldsymbol{H}_{t-1})\boldsymbol{W}_{hh}+\boldsymbol{b}_h) \\ & \boldsymbol{H} = \boldsymbol{Z}_t \odot \boldsymbol{H}_{t-1} + (1-\boldsymbol{Z}_t) \odot \tilde{\boldsymbol{H}}_t && \text{隐状态(hidden state)} \end{aligned}