编码器-解码器架构(Encoder-Decoder Architecture) 是处理序列转换问题的核心框架,特别适用于机器翻译等输入和输出都是长度可变序列的任务。该架构包含两个主要组件:
- 编码器(Encoder):接受长度可变的序列作为输入,将其转换为具有固定形状的编码状态
- 解码器(Decoder):将固定形状的编码状态映射到长度可变的序列
架构设计原理
整体工作流程
flowchart TD
A[输入序列 X] --> B[编码器 Encoder]
B --> C[编码状态 State]
C --> D[解码器 Decoder]
D --> E[输出序列 Y]
A1[词元化] --> A
E --> E1[后处理]
subgraph "编码过程"
B --> B1[序列编码]
B1 --> B2[状态提取]
B2 --> B3[上下文表示]
end
subgraph "解码过程"
D --> D1[状态初始化]
D1 --> D2[逐步生成]
D2 --> D3[序列输出]
end
style A fill:#FF6B6B,stroke:#FF1744,stroke-width:3px,color:#fff
style B fill:#4ECDC4,stroke:#00BCD4,stroke-width:3px,color:#fff
style C fill:#FECA57,stroke:#FF9800,stroke-width:3px,color:#333
style D fill:#96CEB4,stroke:#4CAF50,stroke-width:3px,color:#fff
style E fill:#DDA0DD,stroke:#9C27B0,stroke-width:3px,color:#fff
style B1 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
style B2 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
style B3 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
style D1 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
style D2 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
style D3 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
数学表示
编码器-解码器架构的数学描述如下:
编码阶段:
解码阶段:
其中:
是输入序列 是编码状态(上下文向量) 是第 步的输出 和 分别是编码器和解码器函数
编码器实现(基于 PyTorch)
from torch import nn
#@save
class Encoder(nn.Module):
"""编码器-解码器架构的基本编码器接口"""
def __init__(self, **kwargs):
super(Encoder, self).__init__(**kwargs)
def forward(self, X, *args):
raise NotImplementedError
编码过程详解
sequenceDiagram
participant Input as 输入序列
participant Embed as 嵌入层
participant RNN as RNN层
participant Pool as 池化层
participant Output as 编码状态
Input->>Embed: 词元序列
Embed->>RNN: 嵌入向量
RNN->>RNN: 序列处理
Note over RNN: 处理每个时间步
RNN->>Pool: 隐藏状态序列
Pool->>Output: 最终编码状态
Note over Input,Output: 长度可变 → 固定形状
%% 颜色主题
%%{init: {'theme':'base', 'themeVariables': {
'primaryColor': '#FF6B6B',
'primaryTextColor': '#ffffff',
'primaryBorderColor': '#FF1744',
'lineColor': '#4ECDC4',
'sectionBkgColor': '#FECA57',
'altSectionBkgColor': '#96CEB4',
'gridColor': '#DDA0DD',
'c0': '#FF6B6B',
'c1': '#4ECDC4',
'c2': '#FECA57',
'c3': '#96CEB4',
'c4': '#DDA0DD'
}}}%%
解码器实现(基于 PyTorch)
#@save
class Decoder(nn.Module):
"""编码器-解码器架构的基本解码器接口"""
def __init__(self, **kwargs):
super(Decoder, self).__init__(**kwargs)
def init_state(self, enc_outputs, *args):
raise NotImplementedError
def forward(self, X, state):
raise NotImplementedError
解码过程详解
graph TD
A[编码状态 c] --> B[初始化解码器状态]
B --> C[开始标记 <bos>]
C --> D[解码器步骤 t=1]
D --> E{生成 <eos>?}
E -->|否| F[输出词元 y_t]
F --> G[更新状态]
G --> H[解码器步骤 t+1]
H --> E
E -->|是| I[完成序列生成]
subgraph "单步解码"
D --> D1[注意力计算]
D1 --> D2[上下文融合]
D2 --> D3[输出预测]
end
style A fill:#FF6B6B,stroke:#FF1744,stroke-width:3px,color:#fff
style B fill:#4ECDC4,stroke:#00BCD4,stroke-width:3px,color:#fff
style C fill:#FECA57,stroke:#FF9800,stroke-width:3px,color:#333
style D fill:#96CEB4,stroke:#4CAF50,stroke-width:3px,color:#fff
style E fill:#DDA0DD,stroke:#9C27B0,stroke-width:3px,color:#fff
style F fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
style G fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
style H fill:#96CEB4,stroke:#4CAF50,stroke-width:3px,color:#fff
style I fill:#FFAAA5,stroke:#FF5722,stroke-width:3px,color:#fff
style D1 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333
style D2 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333
style D3 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333
编码器-解码器架构(基于 PyTorch)
#@save
class EncoderDecoder(nn.Module):
"""编码器-解码器架构的基类"""
def __init__(self, encoder, decoder, **kwargs):
super(EncoderDecoder, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
def forward(self, enc_X, dec_X, *args):
enc_outputs = self.encoder(enc_X, *args)
dec_state = self.decoder.init_state(enc_outputs, *args)
return self.decoder(dec_X, dec_state)
训练与推理流程对比
flowchart TD
subgraph Train["🎓 训练阶段 (Teacher Forcing)"]
direction TB
T1[源序列<br/>Source Sequence] --> T2[编码器<br/>Encoder]
T2 --> T3[编码状态<br/>Encoded State]
T4[目标序列<br/>Target Sequence] --> T5[解码器<br/>Decoder]
T3 --> T5
T5 --> T6[预测序列<br/>Predictions]
T6 --> T7[损失计算<br/>Loss Calculation]
T7 --> T8[反向传播<br/>Backpropagation]
end
subgraph Infer["🔮 推理阶段 (自回归生成)"]
direction TB
I1[源序列<br/>Source Sequence] --> I2[编码器<br/>Encoder]
I2 --> I3[编码状态<br/>Encoded State]
I4[开始标记<br/><bos>] --> I5[解码器步骤1<br/>Decoder Step 1]
I3 --> I5
I5 --> I6[词元1<br/>Token 1]
I6 --> I7[解码器步骤2<br/>Decoder Step 2]
I3 --> I7
I7 --> I8[词元2<br/>Token 2]
I8 --> I9[继续生成...<br/>Continue...]
end
%% 训练阶段样式
style T1 fill:#FF6B6B,stroke:#FF1744,stroke-width:2px,color:#fff
style T2 fill:#4ECDC4,stroke:#00BCD4,stroke-width:2px,color:#fff
style T3 fill:#FECA57,stroke:#FF9800,stroke-width:2px,color:#333
style T4 fill:#FF6B6B,stroke:#FF1744,stroke-width:2px,color:#fff
style T5 fill:#96CEB4,stroke:#4CAF50,stroke-width:2px,color:#fff
style T6 fill:#DDA0DD,stroke:#9C27B0,stroke-width:2px,color:#fff
style T7 fill:#FFE66D,stroke:#FFEB3B,stroke-width:2px,color:#333
style T8 fill:#FFAAA5,stroke:#FF5722,stroke-width:2px,color:#fff
%% 推理阶段样式
style I1 fill:#FF6B6B,stroke:#FF1744,stroke-width:2px,color:#fff
style I2 fill:#4ECDC4,stroke:#00BCD4,stroke-width:2px,color:#fff
style I3 fill:#FECA57,stroke:#FF9800,stroke-width:2px,color:#333
style I4 fill:#A8E6CF,stroke:#8BC34A,stroke-width:2px,color:#333
style I5 fill:#96CEB4,stroke:#4CAF50,stroke-width:2px,color:#fff
style I6 fill:#DDA0DD,stroke:#9C27B0,stroke-width:2px,color:#fff
style I7 fill:#96CEB4,stroke:#4CAF50,stroke-width:2px,color:#fff
style I8 fill:#DDA0DD,stroke:#9C27B0,stroke-width:2px,color:#fff
style I9 fill:#C7CEEA,stroke:#3F51B5,stroke-width:2px,color:#333
%% 子图样式
style Train fill:#f9f9f9,stroke:#333,stroke-width:3px
style Infer fill:#f0f8ff,stroke:#333,stroke-width:3px
核心差异对比
特性 | 训练阶段 (Teacher Forcing) | 推理阶段 (自回归生成) |
---|---|---|
输入方式 | 完整目标序列同时输入 | 逐个生成,前一个输出作为下一个输入 |
并行性 | 可以并行计算所有时间步 | 必须串行生成,无法并行 |
速度 | 训练速度快 | 推理速度相对较慢 |
稳定性 | 稳定,不受生成错误影响 | 容易累积错误 |
真实性 | 使用真实标签,可能存在曝光偏差 | 更接近实际使用场景 |
参考资源: