当前位置:首页> AI教程> 长短期记忆网络(LSTM):解决RNN梯度问题的方法

长短期记忆网络(LSTM):解决RNN梯度问题的方法

释放双眼,带上耳机,听听看~!
本文介绍了长短期记忆网络(LSTM)的原理和公式,以及其在解决循环神经网络(RNN)梯度问题中的应用。

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第6天,点击查看活动详情

Long Short-Term Memory | MIT Press Journals & Magazine | IEEE Xplore

长短期存储器(long short-term memory, LSTM) 是为了解决RNN梯度爆炸梯度消失问题提出来的,因为RNN每一步都会保留上一步的一些东西,随着时间步逐渐变长,离得远的那些信息占比就很小了,所以提出了诸如LSTM、GRU等方法来解决这些问题。两者的主要思想是对于前边时间步的内容有选择地保留,直观可以理解为有用的信息多留一点,没用的适当丢弃。

先来看一下公式

长短期记忆网络(LSTM):解决RNN梯度问题的方法

  • 每一步计算都需要三部分内容:

    • 上一个时间步传递过来的记忆单元

    • 上一个时间步步传递过来的隐状态

    • 本时间步的输入

  • 每一步计算的输出都有两部分内容:

    • 本时间步的记忆单元

    • 本时间步的隐藏状态

  • 浅蓝色圆圈表示神经网络使用的激活函数

  • 深蓝色圆圈表示运算过程

三门

LSTM有三个门,它分别是输入门Itmathbf{I}_t、忘记门Ftmathbf{F}_t和输出门Otmathbf{O}_t

假设有 hh 个隐藏单元,批量大小为 nn,输入数为 dd

公式如下:

It=σ(XtWxi+Ht−1Whi+bi),Ft=σ(XtWxf+Ht−1Whf+bf),Ot=σ(XtWxo+Ht−1Who+bo),begin{aligned}
&mathbf{I}_t = sigma(mathbf{X}_t mathbf{W}_{xi} + mathbf{H}_{t-1} mathbf{W}_{hi} + mathbf{b}_i),
&mathbf{F}_t = sigma(mathbf{X}_t mathbf{W}_{xf} + mathbf{H}_{t-1} mathbf{W}_{hf} + mathbf{b}_f),
&mathbf{O}_t = sigma(mathbf{X}_t mathbf{W}_{xo} + mathbf{H}_{t-1} mathbf{W}_{ho} + mathbf{b}_o),
end{aligned}

  • 其中输入 Xt∈Rn×dmathbf{X}_t in mathbb{R}^{n times d}
  • 前一时间步的隐藏状态为 Ht−1∈Rn×hmathbf{H}_{t-1} in mathbb{R}^{n times h}
  • tt时间步时, 输入门It∈Rn×hmathbf{I}_t in mathbb{R}^{n times h},遗忘门Ft∈Rn×hmathbf{F}_t in mathbb{R}^{n times h},输出门Ot∈Rn×hmathbf{O}_t in mathbb{R}^{n times h}
  • Wxi,Wxf,Wxo∈Rd×hmathbf{W}_{xi}, mathbf{W}_{xf}, mathbf{W}_{xo} in mathbb{R}^{d times h}Whi,Whf,Who∈Rh×hmathbf{W}_{hi}, mathbf{W}_{hf}, mathbf{W}_{ho} in mathbb{R}^{h times h} 是权重参数
  • bi,bf,bo∈R1×hmathbf{b}_i, mathbf{b}_f, mathbf{b}_o in mathbb{R}^{1 times h} 是偏置参数。
  • 激活函数依旧使用sigmoid

当然也可以合并起来写:

It=σ([Xt,Ht−1]Wi+bi),Ft=σ([Xt,Ht−1]Wf+bf),Ot=σ([Xt,Ht−1]Wo+bo),begin{aligned}
&mathbf{I}_t = sigma([mathbf{X}_t ,mathbf{H}_{t-1}] mathbf{W}_{i} + mathbf{b}_i),
&mathbf{F}_t = sigma([mathbf{X}_t ,mathbf{H}_{t-1}] mathbf{W}_{f} + mathbf{b}_f),
&mathbf{O}_t = sigma([mathbf{X}_t ,mathbf{H}_{t-1}] mathbf{W}_{o} + mathbf{b}_o),
end{aligned}

候选记忆单元

长短期记忆网络引入了存储记忆单元(memory cell)。

C~t=tanh(XtWxc+Ht−1Whc+bc)tilde{mathbf{C}}_t = text{tanh}(mathbf{X}_t mathbf{W}_{xc} + mathbf{H}_{t-1} mathbf{W}_{hc} + mathbf{b}_c)

  • Wxc∈Rd×hmathbf{W}_{xc} in mathbb{R}^{d times h}Whc∈Rh×hmathbf{W}_{hc} in mathbb{R}^{h times h} 是权重参数。
  • bc∈R1×hmathbf{b}_c in mathbb{R}^{1 times h} 是偏置参数。
  • 候选记忆单元使用的激活函数是tanh。

记忆单元

输入门 Itmathbf{I}_t 控制采用多少来自 C~ttilde{mathbf{C}}_t 的新数据,而遗忘门 Ftmathbf{F}_t 控制保留了多少旧记忆单元 Ct−1∈Rn×hmathbf{C}_{t-1} in mathbb{R}^{n times h} 的内容。最后计算结果存储在记忆单元Ctmathbf{C}_t 中。

Ct=Ft⊙Ct−1+It⊙C~tmathbf{C}_t = mathbf{F}_t odot mathbf{C}_{t-1} + mathbf{I}_t odot tilde{mathbf{C}}_t

  • ⊙odot在这里的意思是按矩阵元素位置相乘,不是做普通的矩阵运算。

因为输入门、忘记门他们都使用的sigmoid作为激活函数。因此它们两个的值都是趋近于0或者近于1的。

  • 如果遗忘门为 11 且输入门为 00,则过去的记忆单元 Ct−1mathbf{C}_{t-1} 将随时间被保存并传递到当前时间步。
  • 如果遗忘门为 00 且输入门为 11,则过去的记忆单元 Ct−1mathbf{C}_{t-1} 被丢弃掉,仅使用当前的候选记忆单元C~ttilde{mathbf{C}}_t

引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。

隐藏单元

输入门遗忘门都介绍了,输出门的作用就在 隐藏单元Htmathbf{H}_t 计算这一步。

公式如下:

Ht=Ot⊙tanh⁡(Ct)mathbf{H}_t = mathbf{O}_t odot tanh(mathbf{C}_t)

  • 输出门接近 11,我们就能够把我们的记忆单元信息传递下去。
  • 输出门接近 00,我们只保留存储单元内的所有信息。

代码

之前我还会写手动实现,就是实现以下计算过程,然而实际上其实就是用代码堆出来计算公式,也没什么意思,以后就不搞了,直接写怎么用pytorch实现。

import torch
from torch import nn
from d2l import torch as d2l
from torch.nn import functional as F

导包啊导包,这个不用解释了吧。

train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
batch_size, num_steps = 32, 35

这里直接使用d2l加载数据集。

设定批量大小batch_size和时间步的长度num_steps,时间步的长度就是每次LSTM处理的一个序列的长度。

class LSTMModel(nn.Module):
    def __init__(self, lstm_layer, vocab_size, **kwargs):
        super(LSTMModel, self).__init__(**kwargs)
        self.lstm = lstm_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.lstm.hidden_size
        if not self.lstm.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.lstm(X, state)
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        return (torch.zeros((
            self.num_directions * self.lstm.num_layers,
            batch_size, self.num_hiddens), device=device),
                torch.zeros((
                    self.num_directions * self.lstm.num_layers,
                    batch_size, self.num_hiddens), device=device))
  • __init__初始化这个类,这个类是继承了nn.Module的。

    • self.lstm设定计算层是LSTM层

    • self.vocab_size设定字典的大小,这里大小是28,因为为了方便演示,我们这里使用的是字母进行分词,所以其中只有a~z26个字母外加 <unk>(空格和unknown)。

    • self.num_hiddens设置隐藏层的大小。

      可能这里会导致迷惑,我刚开始看的时候也有一瞬间的迷惑。普通rnn的隐藏层,不是说其他的就是隐状态了吗?
      不是这样的,普通rnn是普通隐藏层,现在的GRU。LSTM是含有隐状态的隐藏层。所以还是隐藏层。

    • if-else语句是设定LSTM单双向的,毕竟还有双向LSTM这种东西的存在。

  • forward定义前向传播网络。

    也就是描述计算过程。

    • 这里首先是将输入转化为对应的one-hot向量,再将其类型转化为float。

    • Ystate是计算隐状态的,这里Y是输出全部的隐状态,state是输出最后一个时间步的隐状态。注意 在这里Y不是 输出。

    • output是用于存储输出的。

  • begin_state是进行初始化。

    这里是return了好长一个句子。可以拆解开看一下子。

    长短期记忆网络(LSTM):解决RNN梯度问题的方法

    这里是初始化为0张量,初始化位置device=device由你传入的位置决定是CPU还是GPU。这里和普通RNN的区别在于普通RNN和GRU不同,二者只需要返回一个张量即可,但LSTM这里是一个元组里两个张量。

这段代码看似是写了个LSTM的类,其实是换汤不换药的,就是之前手动简洁实现RNN那个文章里的RNN类改了一下子。不论是RNN还是GRU、LSTM,都是在那个类的基础上改的。那个RNN的类写的更齐全,详细的可以看→洁实现RNN循环神经网络](简洁实现RNN循环神经网络)。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 100, 1
num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = LSTMModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

然后就是我们的训练和预测过程。

  • 第一句和第二句分别是设置词典长度、隐藏层大小、运行在CPU还是GPU上、训练epoch数量、学习率。

  • 设置lstm使用pytorch自带的nn.LSTM

  • 模型使用我们的那个类,并且将其放到对应的设备上运行。

  • 最后一句就是预测训练的过程。

下图是100个epoch之后的结果,还没降到底,我这组训练数据要跑400多才能差不多平稳。结果会输出困惑度以及前缀为“time traveler”和“traveler”的预测结果。

长短期记忆网络(LSTM):解决RNN梯度问题的方法

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

基于深度学习的高精度鸡蛋检测识别系统

2023-12-17 12:21:14

AI教程

Windows10下YOLOv8 TensorRT CUDA加速部署

2023-12-17 12:31:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索