束搜索(Beam Search)是一种在序列生成任务(如机器翻译、文本摘要)中常用的解码策略。它旨在平衡贪心搜索(Greedy Search)的速度和穷举搜索的准确性,通过在每一步保留多个最可能的候选序列来寻找高质量的输出。
核心思想
- 贪心搜索的局限:贪心搜索在每个时间步只选择当前概率最高的词元,这可能导致错过全局最优解。例如,一个在早期步骤中概率稍低的词元,可能在后续步骤中引出整体概率更高的序列。
- 束搜索的策略:为了弥补这一不足,束搜索在每个时间步维护一个固定大小的候选序列集合。这个集合的大小被称为束宽(Beam Size),通常用
表示。 - 在解码的每一步,我们都会考虑上一步中
个候选序列的所有可能扩展。 - 然后,从所有这些新生成的序列中,我们计算并选择出总得分最高的
个,作为当前步骤的新候选集。 - 这个过程会一直持续,直到所有候选序列都生成了结束符
<eos>
或达到了预设的最大长度。
- 在解码的每一步,我们都会考虑上一步中
算法流程
下图展示了束搜索的解码过程:
graph TD
A["开始: <bos>"] --> B["t=1: 生成 k 个最可能的词<br/>{y₁¹, y₁², ..., y₁ᵏ}"]
B --> C["对于每个候选序列 y<t"]
C --> D["计算下一个词的条件概率<br/>P(y_t | y<t, c)"]
D --> E["生成 |V|*k 个新序列"]
E --> F["计算所有新序列的得分<br/>(通常是累积对数概率)"]
F --> G["选择得分最高的 k 个序列<br/>作为新的候选集"]
G --> H{"达到 <eos> 或最大长度?"}
H -- 否 --> C
H -- 是 --> I["从 k 个候选中选择最优序列"]
style B fill:#FF5733,stroke:#333,stroke-width:2px,color:#fff
style F fill:#40E0D0,stroke:#333,stroke-width:2px,color:#fff
style G fill:#FF1493,stroke:#333,stroke-width:2px,color:#fff
style I fill:#7CFC00,stroke:#333,stroke-width:2px,color:#fff
得分计算
在时间步
其中
长度惩罚
由于对数概率是负值,序列越长,累加的得分会越低,这会导致算法偏向于选择更短的序列。为了缓解这个问题,通常会引入长度惩罚(Length Penalty):
其中
图示
下图直观地展示了束搜索的过程,其中束宽
总结
- 束搜索在每次搜索时,保存 K 歌最好的候选
- K=1 时,就是贪心搜索
- K=n 时,是穷举搜索
- 束搜索是一种启发式算法,它不保证能找到全局最优解,但在实践中通常能以可控的计算成本生成高质量的序列。