当我们在sft的时候发生了什么

 

1. SFT 微调概述

SFT(Supervised Fine-Tuning)指的是在大模型预训练完成后,使用人工标注的 (prompt, response) 对,进行标准监督学习训练,
目的是引导模型输出符合特定任务要求的响应

输入/输出定义

  • 输入 ( x ):Prompt(指令)
  • 输出 ( y = (y_1, y_2, \ldots, y_T) ):Response 的 token 序列

2. SFT 训练流程(带公式)

2.1 目标函数(Objective)

SFT 训练的目标是:
最大化在已知输入 Prompt ( x ) 条件下,生成正确 Response ( y ) 的概率

换句话说,最小化负对数似然(Negative Log-Likelihood, NLL)损失

公式表示为:

[ \mathcal{L}(\theta) = -\mathbb{E}{(x, y) \sim D}\left[ \sum{t=1}^{T} \log P_\theta(y_t \mid y_{<t}, x) \right] ]

  • ( D ):训练数据集
  • ( \theta ):模型参数
  • ( y_{<t} ):当前时刻前的所有 token
  • ( P_\theta ):由当前模型参数给出的预测分布

直观理解:每一步生成正确 token 的概率越高,总 loss 越小。


2.2 输入处理(Tokenization)

将 prompt 和 response 合并成一个连续的 token 序列:

[ \text{input} = \text{tokenize}(\text{prompt} + \text{response}) ]

注意在 loss 计算时,只针对 response 部分进行监督。


2.3 Forward 过程(前向推理)

模型输出每个位置的 token 概率分布:

[ P_\theta(\cdot \mid x, y_{<t}) ]

每个位置上,根据预测分布与真实 token 对比,计算交叉熵损失(Cross Entropy Loss):

[ \text{CrossEntropy}(p, q) = - \sum_i p_i \log q_i ]

这里 ( p ) 是 one-hot 的 ground truth 分布,( q ) 是模型输出的预测分布。


2.4 Mask 策略(Loss Masking)

为了避免 prompt 部分干扰 loss 计算,使用一个 mask:

  • 对 prompt token,mask=0(不计算 loss)
  • 对 response token,mask=1(计算 loss)

最终实际 loss:

[ \mathcal{L}{\text{masked}} = \frac{\sum{t=1}^{T_{\text{total}}} \text{mask}t \times \text{CrossEntropy}(y_t, \hat{y}_t)}{\sum{t=1}^{T_{\text{total}}} \text{mask}_t} ]


2.5 优化与反向传播(Backward)

通过反向传播算法(Backpropagation),基于 loss 梯度更新参数 ( \theta )。

通常使用的优化器:

  • AdamW
  • 学习率调度器(如 Linear Warmup + Cosine Decay)

3. SFT 训练关键指标计算

训练过程中常监控以下指标:

3.1 Loss (损失)

定义:

[ \text{Loss} = \text{Masked Cross Entropy Loss} ]

代表模型在 token 预测上的整体误差。

  • 训练 loss 下降 → 说明模型拟合训练数据
  • 验证 loss 下降 → 说明模型具有泛化能力

3.2 Accuracy (准确率)

可以在 token 级别计算:

Token-level Accuracy:

[ \text{Accuracy} = \frac{\text{Number of Correct Tokens}}{\text{Total Number of Target Tokens}} ]

注意:通常也只对 response 部分 token 计算。


3.3 Perplexity (困惑度)

衡量语言模型对数据的拟合难度。

定义为:

[ \text{Perplexity} = \exp(\mathcal{L}) ]

其中 ( \mathcal{L} ) 是平均的 cross-entropy loss。

直观理解:

  • 困惑度越低,代表模型越确定地生成正确 token;
  • 理想值接近 1(实际中 10~30 左右为正常范围,取决于任务复杂度)。