241122论文阅读


论文代码阅读——手打版

论文链接:

github仓库链接:
https://github.com/qingzhenduyu/TAMER

这里会更多地关注该项目用到的模型架构和训练流程。

tamer 文件夹中的 lit_tamer.py 开始阅读。

class LitTAMER(pl.LightningModule) 中,定义了模型为 self.tamer_model = TAMER(...),并定义了 self.exprate_recorder = ExpRateRecorder()

TAMER 模型(tamer.py)由 Encoder(encoder.py)Decoder(decoder.py) 组成。在定义的 beam_search 函数中,除了 self.encoder,还调用了 self.decoder.beam_search

关于 beam_search,下面是 GPT 对它的介绍:

点击查看 Beam Search 的详细介绍

Beam Search 的原理与方法

1. Beam Search 是什么?

Beam Search 是一种序列解码算法,通常用于自然语言生成任务(如机器翻译、文本摘要、图像字幕生成等)。它在逐步生成序列时,不是只选择一个最优候选,而是维护一个固定大小的候选集合,从中寻找全局最优的生成结果。


2. 核心思想

  • 扩展搜索树
    将序列生成问题看作树的扩展,每次扩展节点时,基于模型预测的概率,选择前几个(如 beam size = 3)概率最高的候选。

  • 权衡搜索空间和效率
    相比贪婪搜索(一次只保留最优路径),Beam Search 能探索更多路径,避免局部最优;相比暴力穷举(尝试所有可能路径),它限制了搜索宽度,提升效率。


3. 工作流程

以下是逐步生成序列的过程:

  1. 初始化

    • 起始标记(如 <SOS>)作为初始序列,分数为 0。
    • 初始化候选集合(称为 Beam),大小为 beam_size
  2. 扩展 Beam

    • 对当前 Beam 中的每个序列,用模型预测下一个词的概率分布。
    • 将每个候选序列与其可能的扩展(即词)组合,计算总得分。
    • 保留总得分最高的 beam_size 个候选序列。
  3. 终止条件

    • 所有候选序列都生成了终止标记 <EOS>
    • 或生成的序列长度达到预设最大值 max_len
  4. 输出结果

    • 从最终的候选集合中,选择分数最高的序列作为输出。

4. 关键公式

总得分计算
[
\text{Score} = \frac{\sum_{t=1}^L \log P(y_t | y_{1:t-1})}{L^\alpha}
]

  • ( L ):序列长度。
  • ( \alpha ):长度惩罚系数(( \alpha > 0 ) 时,惩罚短序列)。

长度惩罚的作用

  • 防止模型过于偏向短序列(短序列的得分累积更高)。
  • 调整模型对不同长度的偏好。

5. 优势与不足

  • 优势

    • 在可接受的计算开销下,生成质量较高的序列。
    • 能发现更优解,避免陷入局部最优。
  • 不足

    • 生成的序列仍可能不是全局最优(近似搜索)。
    • Beam Size 过大会显著增加计算量。

6. 与 TAMER 的关系

在 TAMER 中:

  • Encoder
    提取图像特征,作为解码器的上下文信息。

  • Decoder 的 Beam Search

    • 使用解码器 beam_search,根据编码器输出的特征,逐步生成序列。
    • 每个时间步中扩展候选序列,保留 beam_size 个得分最高的序列,直到生成完成。

7. Beam Search 的改进方向

  • 加速优化:采用动态 Beam Size 或并行化实现,减少计算负担。
  • 多样性增强:引入随机性(如 Top-K 或采样)来生成更具多样性的候选序列。
  • 混合方法:结合贪婪搜索和 Beam Search,平衡速度与质量。


Encoder中,主要使用了DenseNet-B的架构。给出的can项目代码的encoder也使用到了 DenseNet-B

笔记图片

Decoder中,使用到了Transformer 解码器以及结构相似性计算模块(StructSim)等组件。其中对Transformer相关部分的定义调用了Transformer文件夹中封装的encoder 、decoder等。

点击查看 GPT对 Decoder.py 的详细介绍

decoder.py 的模块和主要原理

decoder.py 定义了 Decoder 模块,该模块实现了目标序列生成的核心功能,使用了 Transformer 解码器、结构相似性计算模块(StructSim)等组件。以下是各模块及其原理的详细解释:


1. 主要组件和功能

(1) LBR (Linear + BatchNorm + ReLU)
  • 模块定义:
    • 线性变换 + 层归一化(LayerNorm)+ 激活函数(ReLU)。
  • 功能:作为一种轻量的预处理单元,对输入进行简单变换,增强非线性表达能力。
(2) StructSim 和 StructSimOneDir
  • 模块定义:
    • StructSim 包含两个方向的相似性计算模块:从左到右(l2r)从右到左(r2l)
    • 每个方向使用 TransformerEncoder 提取表示,并通过 q + k 的方式计算结构相似性。
  • 功能:
    • 结构相似性计算:用来评估序列在目标方向上的一致性,确保生成序列结构合理。
    • 输出一个矩阵,表示目标序列内部的相似性分布。
(3) _build_transformer_decoder
  • 模块定义:
    • 基于 nn.TransformerDecoder,结合 AttentionRefinementModule(ARM)进一步增强解码器性能。
  • 功能:
    • ARM 模块:通过自监督(self-coverage)或交叉覆盖(cross-coverage)的方式,改善注意力权重的分布,从而生成更加准确的序列。
    • Transformer 解码器层:实现目标序列的生成逻辑。
(4) Decoder 类
  • 主要模块:
    • **词嵌入 (word_embed)**:将目标序列的单词 ID 映射为向量表示。
    • **位置编码 (pos_enc)**:通过 WordPosEnc 为嵌入添加位置信息。
    • **Transformer 解码器 (model)**:通过多层 Transformer 解码序列。
    • **输出投影 (proj)**:将解码结果映射回词汇表空间,输出每个单词的概率分布。
    • **结构相似性模块 (struct_sim)**:通过 StructSim 评估解码器的输出序列。

2. 主要原理

Decoder 模块采用 Transformer 解码器 的架构,结合目标任务需求,增加了结构相似性和注意力优化模块。以下是核心原理:

(1) Transformer 解码器
  • 解码器由多层堆叠的解码器层组成,每一层包括:
    • **自注意力机制 (Self-Attention)**:关注目标序列已生成部分的内部关系。
    • **交叉注意力机制 (Cross-Attention)**:关注输入特征(图片编码)与目标序列之间的关系。
    • **前馈网络 (Feedforward Neural Network)**:对注意力结果进一步提取非线性特征。
  • 解码器通过 mask 限制模型只能看到当前时间步之前的输出。
(2) ARM (Attention Refinement Module)
  • 通过调整注意力权重的分布,解决:
    • 自监督问题(自注意力的权重集中问题)。
    • 交叉覆盖问题(输入特征未充分利用问题)。
(3) 结构相似性
  • StructSim 模块引入两种相似性度量:
    • 从左到右生成的结构是否合理。
    • 从右到左生成的结构是否一致。
  • 通过 Transformer Encoder 学习表示,并通过加和查询 (q) 和键 (k) 的方式计算相似性。
(4) 目标序列生成过程
  1. 词嵌入和位置编码:目标序列转换为向量,并加入位置信息。
  2. Transformer 解码
    • 自注意力机制理解序列上下文。
    • 交叉注意力与图片特征交互。
  3. 输出预测:通过全连接层输出词汇表上的概率分布。
  4. 相似性评估:计算生成序列的结构一致性,提供额外的监督信号。

3. 函数解析

以下是模块主要函数及其功能:

Decoder.forward
  • 输入:图片特征(src)、目标序列(tgt)。
  • 输出:词汇表概率分布(out)和结构相似性(sim)。
  • 过程:
    1. 对目标序列 tgt 进行词嵌入和位置编码。
    2. 构造目标序列的掩码(tgt_masktgt_pad_mask)。
    3. 使用 Transformer 解码器处理 tgtsrc
    4. 计算目标序列的结构相似性。
    5. 投影到词汇表空间输出概率分布。
Decoder._build_attention_mask
  • 构造一个因果掩码,确保生成序列时只能看到当前时间步之前的内容。
Decoder.transform
  • 封装 forward,专用于生成任务的调用,接收图片特征和输入 ID。

4. 总结

decoder.py 实现了一个面向图文序列生成任务的解码器,核心特点包括:

  1. Transformer 架构:高效处理序列生成任务。
  2. ARM 和结构相似性模块:解决注意力分布不均和生成结构不一致问题。
  3. 多功能设计:支持训练和推理过程(如 beam search)。

某人偷懒没修改的pos_enc.py功能简介 这段代码定义了几种**位置编码(Positional Encoding)**方法,主要用于为输入特征添加位置信息,从而在模型中保留顺序或空间信息。这些编码方式被应用在不同类型的数据上,如序列数据(单词)或二维图像数据。

功能分类

  1. 基于正弦和余弦的绝对位置编码

    • **WordPosEnc**:
      • 实现了用于序列输入(如文本)的标准正弦和余弦位置编码。
      • 被用于给序列特征(形状 [b, l, d])添加位置感知信息。
    • **ImgPosEnc**:
      • 将绝对位置编码扩展到二维图像特征,处理形状 [b, h, w, d] 的输入。
      • 通过累积mask反转计算生成二维位置坐标,并添加正弦和余弦的编码值。
  2. 基于旋转变换的相对位置编码(Rotary Positional Embedding)

    • **WordRotaryEmbed**:
      • 应用于序列输入,结合正弦和余弦,通过旋转操作实现相对位置编码。
      • 提供更灵活的位置感知能力,适合更复杂的上下文关联建模。
    • **ImageRotaryEmbed**:
      • 处理二维图像特征,支持旋转的二维位置编码。
      • 类似于WordRotaryEmbed,但扩展到二维输入(形状 [b, h, w, d])。
  3. 工具函数

    • **rotate_every_two**:实现二维旋转,用于旋转嵌入的偶数和奇数维度的值。

在前述代码中的调用

这些位置编码模块被用在以下地方:

  • **Decoder 类中的 self.pos_enc**:
    • WordPosEnc 用于对目标序列(tgt)添加位置编码,以增强 Transformer 解码器对序列顺序的感知。
  • 可能的其他调用场景(未直接展示):
    • 对图像或序列数据进行位置编码(如ImgPosEncImageRotaryEmbed)可在编码器或解码器的预处理阶段使用,尤其是在 CAN 项目中对图像-文本跨模态任务可能需要的场景。

总结:这段代码提供了多种灵活的位置编码方案,可根据输入数据的维度和模型需求灵活选择合适的方法,以增强输入的位置信息表达能力。


对于transformer文件夹,里面包含了支持AttentionRefinementModule(注意力精细化模块)的transformer解码器,也包含MaskBatchNorm2d(自定义掩码批标准化)、AttentionRefinementModule(注意力精细化模块)、MultiheadAttention(多头注意力模块)的具体实现。

utils 文件夹是 “utilities”(工具)的缩写,通常在其中存放封装好的各个辅助功能或工具函数。

某人偷懒没修改的pos_enc.py功能简介


文章作者: Qijia Huang
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Qijia Huang !
  目录