论文代码阅读——手打版
论文链接:
github仓库链接:
https://github.com/qingzhenduyu/TAMER
这里会更多地关注该项目用到的模型架构和训练流程。
从 tamer
文件夹中的 lit_tamer.py
开始阅读。
在 class LitTAMER(pl.LightningModule)
中,定义了模型为 self.tamer_model = TAMER(...)
,并定义了 self.exprate_recorder = ExpRateRecorder()
。
TAMER 模型(tamer.py
)由 Encoder(encoder.py)
和 Decoder(decoder.py)
组成。在定义的 beam_search
函数中,除了 self.encoder
,还调用了 self.decoder.beam_search
。
关于 beam_search
,下面是 GPT 对它的介绍:
点击查看 Beam Search 的详细介绍
Beam Search 的原理与方法
1. Beam Search 是什么?
Beam Search 是一种序列解码算法,通常用于自然语言生成任务(如机器翻译、文本摘要、图像字幕生成等)。它在逐步生成序列时,不是只选择一个最优候选,而是维护一个固定大小的候选集合,从中寻找全局最优的生成结果。
2. 核心思想
扩展搜索树:
将序列生成问题看作树的扩展,每次扩展节点时,基于模型预测的概率,选择前几个(如 beam size = 3)概率最高的候选。权衡搜索空间和效率:
相比贪婪搜索(一次只保留最优路径),Beam Search 能探索更多路径,避免局部最优;相比暴力穷举(尝试所有可能路径),它限制了搜索宽度,提升效率。
3. 工作流程
以下是逐步生成序列的过程:
初始化:
- 起始标记(如
<SOS>
)作为初始序列,分数为 0。 - 初始化候选集合(称为 Beam),大小为
beam_size
。
- 起始标记(如
扩展 Beam:
- 对当前 Beam 中的每个序列,用模型预测下一个词的概率分布。
- 将每个候选序列与其可能的扩展(即词)组合,计算总得分。
- 保留总得分最高的
beam_size
个候选序列。
终止条件:
- 所有候选序列都生成了终止标记
<EOS>
。 - 或生成的序列长度达到预设最大值
max_len
。
- 所有候选序列都生成了终止标记
输出结果:
- 从最终的候选集合中,选择分数最高的序列作为输出。
4. 关键公式
总得分计算:
[
\text{Score} = \frac{\sum_{t=1}^L \log P(y_t | y_{1:t-1})}{L^\alpha}
]
- ( L ):序列长度。
- ( \alpha ):长度惩罚系数(( \alpha > 0 ) 时,惩罚短序列)。
长度惩罚的作用:
- 防止模型过于偏向短序列(短序列的得分累积更高)。
- 调整模型对不同长度的偏好。
5. 优势与不足
优势:
- 在可接受的计算开销下,生成质量较高的序列。
- 能发现更优解,避免陷入局部最优。
不足:
- 生成的序列仍可能不是全局最优(近似搜索)。
- Beam Size 过大会显著增加计算量。
6. 与 TAMER 的关系
在 TAMER 中:
Encoder:
提取图像特征,作为解码器的上下文信息。Decoder 的 Beam Search:
- 使用解码器
beam_search
,根据编码器输出的特征,逐步生成序列。 - 每个时间步中扩展候选序列,保留
beam_size
个得分最高的序列,直到生成完成。
- 使用解码器
7. Beam Search 的改进方向
- 加速优化:采用动态 Beam Size 或并行化实现,减少计算负担。
- 多样性增强:引入随机性(如 Top-K 或采样)来生成更具多样性的候选序列。
- 混合方法:结合贪婪搜索和 Beam Search,平衡速度与质量。
Encoder中,主要使用了DenseNet-B
的架构。给出的can项目代码的encoder也使用到了 DenseNet-B
。
笔记图片
Decoder中,使用到了Transformer 解码器
以及结构相似性计算模块(StructSim
)等组件。其中对Transformer相关部分的定义调用了Transformer
文件夹中封装的encoder 、decoder等。
点击查看 GPT对 Decoder.py 的详细介绍
decoder.py
的模块和主要原理
decoder.py
定义了 Decoder 模块,该模块实现了目标序列生成的核心功能,使用了 Transformer 解码器、结构相似性计算模块(StructSim
)等组件。以下是各模块及其原理的详细解释:
1. 主要组件和功能
(1) LBR (Linear + BatchNorm + ReLU)
- 模块定义:
- 线性变换 + 层归一化(LayerNorm)+ 激活函数(ReLU)。
- 功能:作为一种轻量的预处理单元,对输入进行简单变换,增强非线性表达能力。
(2) StructSim 和 StructSimOneDir
- 模块定义:
StructSim
包含两个方向的相似性计算模块:从左到右(l2r)和从右到左(r2l)。- 每个方向使用
TransformerEncoder
提取表示,并通过q + k
的方式计算结构相似性。
- 功能:
- 结构相似性计算:用来评估序列在目标方向上的一致性,确保生成序列结构合理。
- 输出一个矩阵,表示目标序列内部的相似性分布。
(3) _build_transformer_decoder
- 模块定义:
- 基于
nn.TransformerDecoder
,结合AttentionRefinementModule
(ARM)进一步增强解码器性能。
- 基于
- 功能:
- ARM 模块:通过自监督(self-coverage)或交叉覆盖(cross-coverage)的方式,改善注意力权重的分布,从而生成更加准确的序列。
- Transformer 解码器层:实现目标序列的生成逻辑。
(4) Decoder 类
- 主要模块:
- **词嵌入 (word_embed)**:将目标序列的单词 ID 映射为向量表示。
- **位置编码 (pos_enc)**:通过
WordPosEnc
为嵌入添加位置信息。 - **Transformer 解码器 (model)**:通过多层 Transformer 解码序列。
- **输出投影 (proj)**:将解码结果映射回词汇表空间,输出每个单词的概率分布。
- **结构相似性模块 (struct_sim)**:通过
StructSim
评估解码器的输出序列。
2. 主要原理
Decoder 模块采用 Transformer 解码器 的架构,结合目标任务需求,增加了结构相似性和注意力优化模块。以下是核心原理:
(1) Transformer 解码器
- 解码器由多层堆叠的解码器层组成,每一层包括:
- **自注意力机制 (Self-Attention)**:关注目标序列已生成部分的内部关系。
- **交叉注意力机制 (Cross-Attention)**:关注输入特征(图片编码)与目标序列之间的关系。
- **前馈网络 (Feedforward Neural Network)**:对注意力结果进一步提取非线性特征。
- 解码器通过 mask 限制模型只能看到当前时间步之前的输出。
(2) ARM (Attention Refinement Module)
- 通过调整注意力权重的分布,解决:
- 自监督问题(自注意力的权重集中问题)。
- 交叉覆盖问题(输入特征未充分利用问题)。
(3) 结构相似性
StructSim
模块引入两种相似性度量:- 从左到右生成的结构是否合理。
- 从右到左生成的结构是否一致。
- 通过 Transformer Encoder 学习表示,并通过加和查询 (
q
) 和键 (k
) 的方式计算相似性。
(4) 目标序列生成过程
- 词嵌入和位置编码:目标序列转换为向量,并加入位置信息。
- Transformer 解码:
- 自注意力机制理解序列上下文。
- 交叉注意力与图片特征交互。
- 输出预测:通过全连接层输出词汇表上的概率分布。
- 相似性评估:计算生成序列的结构一致性,提供额外的监督信号。
3. 函数解析
以下是模块主要函数及其功能:
Decoder.forward
- 输入:图片特征(
src
)、目标序列(tgt
)。 - 输出:词汇表概率分布(
out
)和结构相似性(sim
)。 - 过程:
- 对目标序列
tgt
进行词嵌入和位置编码。 - 构造目标序列的掩码(
tgt_mask
和tgt_pad_mask
)。 - 使用 Transformer 解码器处理
tgt
和src
。 - 计算目标序列的结构相似性。
- 投影到词汇表空间输出概率分布。
- 对目标序列
Decoder._build_attention_mask
- 构造一个因果掩码,确保生成序列时只能看到当前时间步之前的内容。
Decoder.transform
- 封装
forward
,专用于生成任务的调用,接收图片特征和输入 ID。
4. 总结
decoder.py
实现了一个面向图文序列生成任务的解码器,核心特点包括:
- Transformer 架构:高效处理序列生成任务。
- ARM 和结构相似性模块:解决注意力分布不均和生成结构不一致问题。
- 多功能设计:支持训练和推理过程(如
beam search
)。
某人偷懒没修改的pos_enc.py功能简介
这段代码定义了几种**位置编码(Positional Encoding)**方法,主要用于为输入特征添加位置信息,从而在模型中保留顺序或空间信息。这些编码方式被应用在不同类型的数据上,如序列数据(单词)或二维图像数据。功能分类
基于正弦和余弦的绝对位置编码:
- **
WordPosEnc
**:- 实现了用于序列输入(如文本)的标准正弦和余弦位置编码。
- 被用于给序列特征(形状
[b, l, d]
)添加位置感知信息。
- **
ImgPosEnc
**:- 将绝对位置编码扩展到二维图像特征,处理形状
[b, h, w, d]
的输入。 - 通过累积mask反转计算生成二维位置坐标,并添加正弦和余弦的编码值。
- 将绝对位置编码扩展到二维图像特征,处理形状
- **
基于旋转变换的相对位置编码(Rotary Positional Embedding):
- **
WordRotaryEmbed
**:- 应用于序列输入,结合正弦和余弦,通过旋转操作实现相对位置编码。
- 提供更灵活的位置感知能力,适合更复杂的上下文关联建模。
- **
ImageRotaryEmbed
**:- 处理二维图像特征,支持旋转的二维位置编码。
- 类似于
WordRotaryEmbed
,但扩展到二维输入(形状[b, h, w, d]
)。
- **
工具函数:
- **
rotate_every_two
**:实现二维旋转,用于旋转嵌入的偶数和奇数维度的值。
- **
在前述代码中的调用
这些位置编码模块被用在以下地方:
- **
Decoder
类中的self.pos_enc
**:WordPosEnc
用于对目标序列(tgt
)添加位置编码,以增强 Transformer 解码器对序列顺序的感知。
- 可能的其他调用场景(未直接展示):
- 对图像或序列数据进行位置编码(如
ImgPosEnc
或ImageRotaryEmbed
)可在编码器或解码器的预处理阶段使用,尤其是在 CAN 项目中对图像-文本跨模态任务可能需要的场景。
- 对图像或序列数据进行位置编码(如
总结:这段代码提供了多种灵活的位置编码方案,可根据输入数据的维度和模型需求灵活选择合适的方法,以增强输入的位置信息表达能力。
对于transformer文件夹,里面包含了支持AttentionRefinementModule(注意力精细化模块)的transformer解码器,也包含MaskBatchNorm2d(自定义掩码批标准化)、AttentionRefinementModule(注意力精细化模块)、MultiheadAttention(多头注意力模块)的具体实现。
utils
文件夹是 “utilities”(工具)的缩写,通常在其中存放封装好的各个辅助功能或工具函数。