logologo

长短期记忆网络 (LSTM)

Jul 10

长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的门控循环神经网络,专门设计用来解决传统 RNN 的梯度消失问题和长期依赖问题。LSTM 比 GRU 结构稍微复杂,但能更精细地控制信息流动。

核心特征

  • 三门控机制:输入门、遗忘门、输出门精确控制信息流
  • 记忆细胞:独立的记忆单元,专门存储长期信息
  • 长期依赖:能够有效捕获和保持长距离的时间依赖关系
  • 梯度稳定:通过门控机制有效缓解梯度消失和爆炸问题

LSTM 结构

LSTM 包含三个核心门控机制和一个记忆细胞:

输入门 (Input Gate)

控制当前输入信息有多少流入记忆细胞:

It=σ(XtWxi+Ht1Whi+bi)

遗忘门 (Forget Gate)

控制从上一时间步记忆细胞中丢弃多少信息:

Ft=σ(XtWxf+Ht1Whf+bf)

输出门 (Output Gate)

控制记忆细胞的多少信息输出到隐状态:

Ot=σ(XtWxo+Ht1Who+bo)

候选记忆细胞

计算候选的新信息,准备添加到记忆细胞:

C~t=tanh(XtWxc+Ht1Whc+bc)

记忆细胞

结合遗忘门和输入门,更新记忆细胞:

Ct=FtCt1+ItC~t

隐状态

通过输出门控制从记忆细胞到隐状态的信息流:

Ht=Ottanh(Ct)

LSTM 工作流程

flowchart TD
    A["输入 Xt"] --> B["遗忘门 Ft"]
    A --> C["输入门 It"]
    A --> D["输出门 Ot"]
    A --> E["候选记忆 C̃t"]

    F["前一记忆 Ct-1"] --> G["记忆细胞 Ct"]
    B --> G
    C --> G
    E --> G

    G --> H["隐状态 Ht"]
    D --> H

    I["前一隐状态 Ht-1"] --> B
    I --> C
    I --> D
    I --> E

    style A fill:#FF4444,stroke:#CC0000,stroke-width:3px,color:#fff
    style B fill:#44DDDD,stroke:#00AAAA,stroke-width:3px,color:#fff
    style C fill:#44DDDD,stroke:#00AAAA,stroke-width:3px,color:#fff
    style D fill:#44DDDD,stroke:#00AAAA,stroke-width:3px,color:#fff
    style E fill:#4488FF,stroke:#0044CC,stroke-width:3px,color:#fff
    style F fill:#88FF88,stroke:#44AA44,stroke-width:3px,color:#000
    style G fill:#FFAA44,stroke:#CC6600,stroke-width:3px,color:#fff
    style H fill:#FF8844,stroke:#CC4400,stroke-width:3px,color:#fff
    style I fill:#88FF88,stroke:#44AA44,stroke-width:3px,color:#000

计算步骤说明:

  1. 门控计算:根据当前输入和前一隐状态计算三个门的值
  2. 候选记忆:计算准备添加到记忆细胞的新信息
  3. 记忆更新:遗忘门控制保留多少旧记忆,输入门控制接收多少新信息
  4. 隐状态输出:输出门控制记忆细胞的多少信息传递到隐状态

关键机制说明

遗忘门的作用

  • Ft0时:完全遗忘之前的记忆
  • Ft1时:完全保留之前的记忆

输入门的作用

  • It0时:拒绝当前的新信息
  • It1时:完全接受当前的新信息

输出门的作用

  • Ot0时:隐状态不输出记忆信息
  • Ot1时:隐状态完全基于记忆细胞

计算示例

LSTM输入门、遗忘门和输出门计算

LSTM候选记忆细胞计算

LSTM记忆细胞计算

LSTM隐状态计算

LSTM vs GRU vs 传统 RNN

特性传统 RNNGRULSTM
门控数量023
记忆机制隐状态隐状态独立记忆细胞
参数量中等
计算复杂度中等
长期依赖很好
训练速度中等

应用场景

  • 序列建模:语言模型、文本生成
  • 机器翻译:序列到序列转换
  • 语音识别:音频序列处理
  • 时间序列预测:股票、天气预测
  • 情感分析:文本情感分类
  • 对话系统:聊天机器人

优势与限制

优势

  • 有效解决梯度消失问题
  • 优秀的长期依赖捕获能力
  • 精细的信息流控制
  • 广泛的应用场景适应性

限制

  • 计算复杂度高,训练慢
  • 参数量大,容易过拟合
  • 仍然难以并行化处理
  • 对于极长序列仍有挑战

实现要点

从零开始实现(PyTorch 版本)

初始化参数

import torch
import torch.nn as nn

def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(*shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    # 输入门参数
    W_xi, W_hi, b_i = three()
    # 遗忘门参数
    W_xf, W_hf, b_f = three()
    # 输出门参数
    W_xo, W_ho, b_o = three()
    # 候选记忆细胞参数
    W_xc, W_hc, b_c = three()
    # 输出层参数
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)

    # 附加梯度
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
              W_xc, W_hc, b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

初始化状态

def init_lstm_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device),  # 隐状态
            torch.zeros((batch_size, num_hiddens), device=device))  # 记忆细胞

LSTM 前向传播

def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o,
     W_xc, W_hc, b_c, W_hq, b_q] = params
    (H, C) = state
    outputs = []

    for X in inputs:
        # 输入门
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        # 遗忘门
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        # 输出门
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        # 候选记忆细胞
        C_tilde = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        # 记忆细胞
        C = F * C + I * C_tilde
        # 隐状态
        H = O * torch.tanh(C)
        # 输出
        Y = H @ W_hq + b_q
        outputs.append(Y)

    return torch.cat(outputs, dim=0), (H, C)

简洁实现

# 使用PyTorch内置LSTM
lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(lstm_layer, vocab_size)

# 训练模型
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)

自定义 LSTM 类

class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # 输入门
        self.W_xi = nn.Parameter(torch.randn(input_size, hidden_size))
        self.W_hi = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.randn(hidden_size))

        # 遗忘门
        self.W_xf = nn.Parameter(torch.randn(input_size, hidden_size))
        self.W_hf = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.randn(hidden_size))

        # 输出门
        self.W_xo = nn.Parameter(torch.randn(input_size, hidden_size))
        self.W_ho = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.randn(hidden_size))

        # 候选记忆细胞
        self.W_xc = nn.Parameter(torch.randn(input_size, hidden_size))
        self.W_hc = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_c = nn.Parameter(torch.randn(hidden_size))

    def forward(self, input, hidden):
        h, c = hidden

        # 计算门控
        forget_gate = torch.sigmoid(input @ self.W_xf + h @ self.W_hf + self.b_f)
        input_gate = torch.sigmoid(input @ self.W_xi + h @ self.W_hi + self.b_i)
        output_gate = torch.sigmoid(input @ self.W_xo + h @ self.W_ho + self.b_o)

        # 候选记忆
        candidate = torch.tanh(input @ self.W_xc + h @ self.W_hc + self.b_c)

        # 更新记忆细胞和隐状态
        c = forget_gate * c + input_gate * candidate
        h = output_gate * torch.tanh(c)

        return h, (h, c)

小结

LSTM 通过输入门遗忘门输出门记忆细胞的精密设计,实现了:

  • 精细控制:三个门控机制分工明确,精确控制信息流
  • 长期记忆:独立的记忆细胞专门存储长期信息
  • 梯度稳定:有效缓解梯度消失问题,支持长序列训练
  • 广泛应用:在自然语言处理、语音识别等领域表现优异

LSTM 是深度学习中处理序列数据的重要工具,虽然计算复杂度较高,但在需要精确控制信息流和处理长期依赖的任务中仍然是首选。

浙ICP备2021022773号    2022-PRESENT © ZhengKe