跳转至

Knowledge Distillation: A Survey

约 2786 个字 预计阅读时间 14 分钟

一篇关于知识蒸馏领域的综述。

主要内容概览

知识蒸馏指的是将一个大型的、复杂的模型(教师模型)中提取的知识转移到一个较小的、简单的模型(学生模型)中。知识蒸馏的主要目标是提高学生模型的性能,使其在推理时更快、更高效。

  1. 按照教师模型中可提取的知识类型对知识蒸馏进行分类

    1. Response-based:直接使用教师模型的输出作为学生模型的目标。这里的输出包括了软标签(soft labels)和硬标签(hard labels)。软标签是指教师模型的输出概率分布,而硬标签是指真实标签。软标签包含了更多的信息,因此通常比硬标签更有效。

      • Response-based知识蒸馏是最基础的知识迁移方式
      • 通常使用KL散度或交叉熵损失来衡量学生模型输出与教师模型输出之间的差异
      • 引入温度参数(Temperature)对教师模型输出的概率分布进行"软化",使其包含更多信息
      • 例如:在分类任务中,教师可能对某个图像有90%确信是猫,8%是狗,2%是其他动物,这种分布比简单的"这是猫"标签包含更丰富的知识
    2. Feature-based:使用教师模型中间层的特征作为学生模型的目标。这里的特征可以是中间层的激活值,也可以是中间层的输出。

      • 假设教师模型通过多层处理提取了丰富的数据表示,这些中间特征包含有价值的知识
      • 实现方式包括在学生模型中添加额外的映射层,使其特征空间与教师模型对齐
      • 通常通过最小化学生特征与教师特征之间的距离(如L2距离、余弦相似度等)
      • 可以选择性地关注特定层或多层特征,不同层次的特征可能携带不同的抽象级别信息
    3. Relation-based:使用教师模型中间层的关系作为学生模型的目标。这里的关系可以是中间层之间的关系,也可以是中间层与输入之间的关系。具体来讲:

      • 关注的是模型内部结构化关系的迁移,而非单点特征或输出
      • 主要形式包括:
        • 样本关系迁移:保持批次内不同样本激活值之间的相似性关系
        • 特征关系迁移:保持同一层不同通道/神经元之间的关系结构
        • 层间关系迁移:捕获不同层之间的依赖关系
      • 常用方法包括相关系数矩阵迁移、注意力映射迁移、图结构知识迁移等
  2. 蒸馏方案:教师模型与学生模型有很多训练方式

    1. 离线蒸馏(Offline Distillation):教师模型和学生模型在不同的时间训练。教师模型先训练好,然后用其输出作为学生模型的训练目标。

      • 优点:可以充分利用教师模型的知识
      • 缺点:需要额外的存储空间来保存教师模型的输出
    2. 在线蒸馏(Online Distillation):教师模型和学生模型在同一时间训练。教师模型和学生模型共享参数,或者在同一批次中训练。

      • 优点:可以减少存储空间的需求
      • 缺点:需要更多的计算资源
    3. 自蒸馏(Self-Distillation):学生模型在训练过程中使用自己的输出作为目标。即学生模型在训练时既是教师模型也是学生模型。

  3. 蒸馏算法:有许多算法来实现在教师与学生中传递知识

    1. Adversarial Distillation:对抗蒸馏.利用GAN的思想,教师模型生成的知识加上噪声,学生模型通过对抗训练来学习教师模型的知识.

    2. Multi-Teacher Distillation:多教师蒸馏.使用多个教师模型来生成知识,学生模型通过融合多个教师模型的知识来学习.

    3. Cross-Modal Distillation:交叉模态蒸馏.在不同模态之间进行知识蒸馏,例如从图像到文本,或从文本到图像.

    4. Graph-Based Distillation:图蒸馏.使用图神经网络来进行知识蒸馏,通过图结构来表示教师模型和学生模型之间的关系.

    5. Attention-Based Distillation:注意力蒸馏.使用注意力机制来进行知识蒸馏,通过对教师模型的注意力权重进行蒸馏来提高学生模型的性能.

    6. Data-Free Distillation:无数据蒸馏.在没有训练数据的情况下进行知识蒸馏,通过生成合成数据来进行蒸馏.

    7. Quantized Distillation:量化蒸馏.对教师模型进行量化,即将模型的浮点数参数(通常是32位浮点数)转换为低精度表示(如8位整数或二值),以减少模型大小和推理时间。量化蒸馏将量化技术与知识蒸馏相结合:

      • 通过知识蒸馏来缓解量化过程中引入的精度损失
      • 教师模型通常是全精度模型,而学生模型为量化后的低精度版本
    8. Lifelong Distillation:终身蒸馏.在终身学习场景中进行知识蒸馏,教师模型在不断更新,学生模型需要不断适应新的知识.

    9. NAS-Based Distillation:基于神经架构搜索的蒸馏.使用神经架构搜索来自动化地设计教师模型和学生模型,通过搜索最优的蒸馏策略来提高学生模型的性能.

Response-Based Knowledge Distillation

在知识蒸馏中,传统的损失函数无法精确描述学生学习教师模型的程度,因此需要设计新的损失函数来更好地捕捉教师模型的知识。

于是,我们定义了一个新的损失函数,称为蒸馏损失函数,它结合了教师模型的输出和学生模型的输出。蒸馏损失函数可以表示为:

\[ L_{KD} = \alpha L_{hard} + (1 - \alpha) L_{soft} \]

其中,\(L_{hard}\)是学生模型的硬标签损失,\(L_{soft}\)是学生模型的软标签损失,\(\alpha\)是一个超参数,用于平衡两者的权重。

\(L_{hard}\)通常使用交叉熵损失函数来计算,而\(L_{soft}\)则使用KL散度来计算。具体来说,\(L_{hard}\)\(L_{soft}\)可以表示为:

\[ L_{hard} = -\sum_{i=1}^{N} y_i \log(p_i) \]
\[ L_{soft} = -\sum_{i=1}^{N} q_i \log(p_i) \]

其中,\(N\)是样本数量,\(y_i\)是样本\(i\)的真实标签,\(p_i\)是学生模型的输出概率分布,\(q_i\)是教师模型的输出概率分布。

对于概率分布,我们使用softmax函数来计算:

\[ p_i = \frac{e^{z_i / T}}{\sum_{j=1}^{N} e^{z_j / T}} \]

其中\(z_i\)是学生模型的logits,\(T\)是温度参数。温度参数用于控制softmax函数的平滑程度。当\(T=1\)时,softmax函数输出的是原始的概率分布;当\(T>1\)时,softmax函数输出的是平滑的概率分布;当\(T<1\)时,softmax函数输出的是尖锐的概率分布。

Feature-Based Knowledge Distillation

蒸馏损失函数为:

\[\begin{equation} L_{FeaD} \left( f_t (x), f_s(x) \right) = L_F \left( \Phi_t ( f_t (x)), \Phi_s( f_s(x)) \right) \end{equation}\]

其中:

  • \(f_t (x)\)\(f_s(x)\) 是教师和学生模型的中间层特征表示;

  • \(\Phi_t\)\(\Phi_s\) 是用于变换特征的函数,使教师和学生的特征维度匹配;

  • \(L_F (\cdot)\) 是衡量特征相似性的损失函数,如 L1、L2 或最大均值差异 (MMD) 损失。

Relation-Based Knowledge Distillation

关系蒸馏的核心思想是通过保持教师模型和学生模型之间的关系来进行知识蒸馏。我们可以使用以下损失函数来实现关系蒸馏: $$ L_{Rel} = \sum_{i=1}^{N} \left| R(f_t(x_i)) - R(f_s(x_i)) \right|_2^2 $$ 其中,\(R(\cdot)\)是一个函数,用于计算特征之间的关系。常用的关系包括欧几里得距离、余弦相似度等。

一些损失函数

Tip

\[ L_{2} = \left\| f_t(x) - f_s(x) \right\|^2 \]
\[ L_{cos} = 1 - \frac{f_t(x) \cdot f_s(x)}{\left\| f_t(x) \right\| \left\| f_s(x) \right\|} \]
\[ L_{MMD} = \left\| \frac{1}{N^2} \sum_{i=1}^{N} f_t(x_i) f_t(x_i)^T - \frac{1}{N^2} \sum_{i=1}^{N} f_s(x_i) f_s(x_i)^T \right\|^2 \]
\[ L_{KL} = \sum_{i=1}^{N} p_i \log \frac{p_i}{q_i} \]
\[ L_{CE} = -\sum_{i=1}^{N} y_i \log(p_i) \]
\[ L_{1} = \left\| f_t(x) - f_s(x) \right\| \]

知识蒸馏的一般过程

摘自另一篇论文

  1. Target Skill or Domain Steering Teacher LLM

    • 选择一个大型的、复杂的教师模型,通常是一个预训练的语言模型(LLM),如GPT-3、BERT等。

    • 该模型在特定领域或任务上表现良好,并且具有丰富的知识和能力。

    • 第一阶段涉及将教师LLM引导向特定目标技能或领域。这通过精心设计的指令或模板来实现,这些指令引导LLM的注意力。这些指令旨在引出展示LLM在特定领域的熟练程度的响应,无论是医疗保健或法律等专业领域,还是推理或语言理解等技能。

  2. Seed Knowledge as Input

    • 一旦确定了目标领域,下一步是向教师LLM提供种子知识。

    • 这种种子知识通常包括与目标技能或领域知识相关的小型数据集或特定数据线索,用于引导教师LLM生成响应。

    • 种子知识充当催化剂,促使教师LLM基于这些初始信息生成更加详细和深入的输出。

    • 种子知识至关重要,因为它提供了一个基础,教师模型可以在此基础上构建和扩展,从而创建更全面、更深入的知识示例。

  3. Generation of Distillation Knowledge

    • 在响应种子知识和引导指令后,教师LLM生成知识示例。

    • 这些示例主要以问答(QA)对话或叙述性解释的形式呈现,与LLM的自然语言处理/理解能力相符。

    • 在某些特殊情况下,输出也可能包括logits或隐藏特征

  4. Training the Student Model with a Specific Learning Objective

    • 最后一个阶段涉及利用生成的知识示例来训练学生模型。

    • 这种训练由与学习目标相符的损失函数指导。损失函数量化了学生模型在复制或适应来自教师模型知识方面的表现。

    • 通过最小化这种损失,学生模型学习模仿教师的目标技能或领域知识,从而获得类似的能力。

    • 这个过程涉及不断调整学生模型的参数,以减少其输出与教师模型输出之间的差异,确保知识的有效转移。

评论