logologo

束搜索 (Beam Search)

Jul 12

束搜索(Beam Search)是一种在序列生成任务(如机器翻译、文本摘要)中常用的解码策略。它旨在平衡贪心搜索(Greedy Search)的速度和穷举搜索的准确性,通过在每一步保留多个最可能的候选序列来寻找高质量的输出。

核心思想

  • 贪心搜索的局限:贪心搜索在每个时间步只选择当前概率最高的词元,这可能导致错过全局最优解。例如,一个在早期步骤中概率稍低的词元,可能在后续步骤中引出整体概率更高的序列。
  • 束搜索的策略:为了弥补这一不足,束搜索在每个时间步维护一个固定大小的候选序列集合。这个集合的大小被称为束宽(Beam Size),通常用k表示。
    • 在解码的每一步,我们都会考虑上一步中k个候选序列的所有可能扩展。
    • 然后,从所有这些新生成的序列中,我们计算并选择出总得分最高的k个,作为当前步骤的新候选集。
    • 这个过程会一直持续,直到所有候选序列都生成了结束符 <eos> 或达到了预设的最大长度。

算法流程

下图展示了束搜索的解码过程:

graph TD
    A["开始: <bos>"] --> B["t=1: 生成 k 个最可能的词<br/>{y₁¹, y₁², ..., y₁ᵏ}"]
    B --> C["对于每个候选序列 y<t"]
    C --> D["计算下一个词的条件概率<br/>P(y_t | y<t, c)"]
    D --> E["生成 |V|*k 个新序列"]
    E --> F["计算所有新序列的得分<br/>(通常是累积对数概率)"]
    F --> G["选择得分最高的 k 个序列<br/>作为新的候选集"]
    G --> H{"达到 <eos> 或最大长度?"}
    H --  --> C
    H --  --> I["从 k 个候选中选择最优序列"]

    style B fill:#FF5733,stroke:#333,stroke-width:2px,color:#fff
    style F fill:#40E0D0,stroke:#333,stroke-width:2px,color:#fff
    style G fill:#FF1493,stroke:#333,stroke-width:2px,color:#fff
    style I fill:#7CFC00,stroke:#333,stroke-width:2px,color:#fff

得分计算

在时间步t,一个候选序列(y1,,yt)的得分通常是其对数似然之和:

logP(y1,,yt|c)=i=1tlogP(yi|y1,,yi1,c)

其中c是编码器输出的上下文向量。

长度惩罚

由于对数概率是负值,序列越长,累加的得分会越低,这会导致算法偏向于选择更短的序列。为了缓解这个问题,通常会引入长度惩罚(Length Penalty)

L(y1,,yt)=1tαi=1tlogP(yi|y1,,yi1,c)

其中t是当前序列的长度,α是一个超参数(通常在 0.6 到 0.75 之间),用于控制惩罚的强度。当α=0时,没有长度惩罚;当α=1时,得分是几何平均对数概率。

图示

下图直观地展示了束搜索的过程,其中束宽k=2。在每个时间步,都会保留两个得分最高的序列。

束搜索过程图示 (k=2)

总结

  • 束搜索在每次搜索时,保存 K 歌最好的候选
  • K=1 时,就是贪心搜索
  • K=n 时,是穷举搜索
  • 束搜索是一种启发式算法,它不保证能找到全局最优解,但在实践中通常能以可控的计算成本生成高质量的序列。
浙ICP备2021022773号    2022-PRESENT © ZhengKe