logologo

编码器-解码器架构

Jul 12

编码器-解码器架构(Encoder-Decoder Architecture) 是处理序列转换问题的核心框架,特别适用于机器翻译等输入和输出都是长度可变序列的任务。该架构包含两个主要组件:

  • 编码器(Encoder):接受长度可变的序列作为输入,将其转换为具有固定形状的编码状态
  • 解码器(Decoder):将固定形状的编码状态映射到长度可变的序列

编码器-解码器架构

架构设计原理

整体工作流程

flowchart TD
    A[输入序列 X] --> B[编码器 Encoder]
    B --> C[编码状态 State]
    C --> D[解码器 Decoder]
    D --> E[输出序列 Y]

    A1[词元化] --> A
    E --> E1[后处理]

    subgraph "编码过程"
        B --> B1[序列编码]
        B1 --> B2[状态提取]
        B2 --> B3[上下文表示]
    end

    subgraph "解码过程"
        D --> D1[状态初始化]
        D1 --> D2[逐步生成]
        D2 --> D3[序列输出]
    end

    style A fill:#FF6B6B,stroke:#FF1744,stroke-width:3px,color:#fff
    style B fill:#4ECDC4,stroke:#00BCD4,stroke-width:3px,color:#fff
    style C fill:#FECA57,stroke:#FF9800,stroke-width:3px,color:#333
    style D fill:#96CEB4,stroke:#4CAF50,stroke-width:3px,color:#fff
    style E fill:#DDA0DD,stroke:#9C27B0,stroke-width:3px,color:#fff

    style B1 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
    style B2 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
    style B3 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
    style D1 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
    style D2 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
    style D3 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333

数学表示

编码器-解码器架构的数学描述如下:

编码阶段:

c=fenc(x1,x2,,xT)

解码阶段:

yt=fdec(y1,y2,,yt1,c)

其中:

  • x1,x2,,xT是输入序列
  • c是编码状态(上下文向量)
  • yt是第t步的输出
  • fencfdec分别是编码器和解码器函数

编码器实现(基于 PyTorch)

from torch import nn

#@save
class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError

编码过程详解

sequenceDiagram
    participant Input as 输入序列
    participant Embed as 嵌入层
    participant RNN as RNN层
    participant Pool as 池化层
    participant Output as 编码状态

    Input->>Embed: 词元序列
    Embed->>RNN: 嵌入向量
    RNN->>RNN: 序列处理
    Note over RNN: 处理每个时间步
    RNN->>Pool: 隐藏状态序列
    Pool->>Output: 最终编码状态

    Note over Input,Output: 长度可变 → 固定形状

    %% 颜色主题
    %%{init: {'theme':'base', 'themeVariables': {
        'primaryColor': '#FF6B6B',
        'primaryTextColor': '#ffffff',
        'primaryBorderColor': '#FF1744',
        'lineColor': '#4ECDC4',
        'sectionBkgColor': '#FECA57',
        'altSectionBkgColor': '#96CEB4',
        'gridColor': '#DDA0DD',
        'c0': '#FF6B6B',
        'c1': '#4ECDC4',
        'c2': '#FECA57',
        'c3': '#96CEB4',
        'c4': '#DDA0DD'
    }}}%%

解码器实现(基于 PyTorch)

#@save
class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    def init_state(self, enc_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError

解码过程详解

graph TD
    A[编码状态 c] --> B[初始化解码器状态]
    B --> C[开始标记 <bos>]
    C --> D[解码器步骤 t=1]
    D --> E{生成 <eos>?}
    E -->|否| F[输出词元 y_t]
    F --> G[更新状态]
    G --> H[解码器步骤 t+1]
    H --> E
    E -->|是| I[完成序列生成]

    subgraph "单步解码"
        D --> D1[注意力计算]
        D1 --> D2[上下文融合]
        D2 --> D3[输出预测]
    end

    style A fill:#FF6B6B,stroke:#FF1744,stroke-width:3px,color:#fff
    style B fill:#4ECDC4,stroke:#00BCD4,stroke-width:3px,color:#fff
    style C fill:#FECA57,stroke:#FF9800,stroke-width:3px,color:#333
    style D fill:#96CEB4,stroke:#4CAF50,stroke-width:3px,color:#fff
    style E fill:#DDA0DD,stroke:#9C27B0,stroke-width:3px,color:#fff
    style F fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
    style G fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
    style H fill:#96CEB4,stroke:#4CAF50,stroke-width:3px,color:#fff
    style I fill:#FFAAA5,stroke:#FF5722,stroke-width:3px,color:#fff

    style D1 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333
    style D2 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333
    style D3 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333

编码器-解码器架构(基于 PyTorch)

#@save
class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args)
        return self.decoder(dec_X, dec_state)

训练与推理流程对比

flowchart TD
    subgraph Train["🎓 训练阶段 (Teacher Forcing)"]
        direction TB
        T1[源序列<br/>Source Sequence] --> T2[编码器<br/>Encoder]
        T2 --> T3[编码状态<br/>Encoded State]
        T4[目标序列<br/>Target Sequence] --> T5[解码器<br/>Decoder]
        T3 --> T5
        T5 --> T6[预测序列<br/>Predictions]
        T6 --> T7[损失计算<br/>Loss Calculation]
        T7 --> T8[反向传播<br/>Backpropagation]
    end

    subgraph Infer["🔮 推理阶段 (自回归生成)"]
        direction TB
        I1[源序列<br/>Source Sequence] --> I2[编码器<br/>Encoder]
        I2 --> I3[编码状态<br/>Encoded State]
        I4[开始标记<br/>&lt;bos&gt;] --> I5[解码器步骤1<br/>Decoder Step 1]
        I3 --> I5
        I5 --> I6[词元1<br/>Token 1]
        I6 --> I7[解码器步骤2<br/>Decoder Step 2]
        I3 --> I7
        I7 --> I8[词元2<br/>Token 2]
        I8 --> I9[继续生成...<br/>Continue...]
    end

    %% 训练阶段样式
    style T1 fill:#FF6B6B,stroke:#FF1744,stroke-width:2px,color:#fff
    style T2 fill:#4ECDC4,stroke:#00BCD4,stroke-width:2px,color:#fff
    style T3 fill:#FECA57,stroke:#FF9800,stroke-width:2px,color:#333
    style T4 fill:#FF6B6B,stroke:#FF1744,stroke-width:2px,color:#fff
    style T5 fill:#96CEB4,stroke:#4CAF50,stroke-width:2px,color:#fff
    style T6 fill:#DDA0DD,stroke:#9C27B0,stroke-width:2px,color:#fff
    style T7 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
    style T8 fill:#FFAAA5,stroke:#FF5722,stroke-width:2px,color:#fff

    %% 推理阶段样式
    style I1 fill:#FF6B6B,stroke:#FF1744,stroke-width:2px,color:#fff
    style I2 fill:#4ECDC4,stroke:#00BCD4,stroke-width:2px,color:#fff
    style I3 fill:#FECA57,stroke:#FF9800,stroke-width:2px,color:#333
    style I4 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
    style I5 fill:#96CEB4,stroke:#4CAF50,stroke-width:2px,color:#fff
    style I6 fill:#DDA0DD,stroke:#9C27B0,stroke-width:2px,color:#fff
    style I7 fill:#96CEB4,stroke:#4CAF50,stroke-width:2px,color:#fff
    style I8 fill:#DDA0DD,stroke:#9C27B0,stroke-width:2px,color:#fff
    style I9 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333

    %% 子图样式
    style Train fill:#f9f9f9,stroke:#333,stroke-width:3px
    style Infer fill:#f0f8ff,stroke:#333,stroke-width:3px

核心差异对比

特性训练阶段 (Teacher Forcing)推理阶段 (自回归生成)
输入方式完整目标序列同时输入逐个生成,前一个输出作为下一个输入
并行性可以并行计算所有时间步必须串行生成,无法并行
速度训练速度快推理速度相对较慢
稳定性稳定,不受生成错误影响容易累积错误
真实性使用真实标签,可能存在曝光偏差更接近实际使用场景

参考资源

浙ICP备2021022773号    2022-PRESENT © ZhengKe