1. SFT 微调概述
SFT(Supervised Fine-Tuning)指的是在大模型预训练完成后,使用人工标注的 (prompt, response) 对,进行标准监督学习训练,
目的是引导模型输出符合特定任务要求的响应。
输入/输出定义
- 输入 ( x ):Prompt(指令)
- 输出 ( y = (y_1, y_2, \ldots, y_T) ):Response 的 token 序列
2. SFT 训练流程(带公式)
2.1 目标函数(Objective)
SFT 训练的目标是:
最大化在已知输入 Prompt ( x ) 条件下,生成正确 Response ( y ) 的概率。
换句话说,最小化负对数似然(Negative Log-Likelihood, NLL)损失。
公式表示为:
[ \mathcal{L}(\theta) = -\mathbb{E}{(x, y) \sim D}\left[ \sum{t=1}^{T} \log P_\theta(y_t \mid y_{<t}, x) \right] ]
- ( D ):训练数据集
- ( \theta ):模型参数
- ( y_{<t} ):当前时刻前的所有 token
- ( P_\theta ):由当前模型参数给出的预测分布
直观理解:每一步生成正确 token 的概率越高,总 loss 越小。
2.2 输入处理(Tokenization)
将 prompt 和 response 合并成一个连续的 token 序列:
[ \text{input} = \text{tokenize}(\text{prompt} + \text{response}) ]
注意在 loss 计算时,只针对 response 部分进行监督。
2.3 Forward 过程(前向推理)
模型输出每个位置的 token 概率分布:
[ P_\theta(\cdot \mid x, y_{<t}) ]
每个位置上,根据预测分布与真实 token 对比,计算交叉熵损失(Cross Entropy Loss):
[ \text{CrossEntropy}(p, q) = - \sum_i p_i \log q_i ]
这里 ( p ) 是 one-hot 的 ground truth 分布,( q ) 是模型输出的预测分布。
2.4 Mask 策略(Loss Masking)
为了避免 prompt 部分干扰 loss 计算,使用一个 mask:
- 对 prompt token,mask=0(不计算 loss)
- 对 response token,mask=1(计算 loss)
最终实际 loss:
[ \mathcal{L}{\text{masked}} = \frac{\sum{t=1}^{T_{\text{total}}} \text{mask}t \times \text{CrossEntropy}(y_t, \hat{y}_t)}{\sum{t=1}^{T_{\text{total}}} \text{mask}_t} ]
2.5 优化与反向传播(Backward)
通过反向传播算法(Backpropagation),基于 loss 梯度更新参数 ( \theta )。
通常使用的优化器:
- AdamW
- 学习率调度器(如 Linear Warmup + Cosine Decay)
3. SFT 训练关键指标计算
训练过程中常监控以下指标:
3.1 Loss (损失)
定义:
[ \text{Loss} = \text{Masked Cross Entropy Loss} ]
代表模型在 token 预测上的整体误差。
- 训练 loss 下降 → 说明模型拟合训练数据
- 验证 loss 下降 → 说明模型具有泛化能力
3.2 Accuracy (准确率)
可以在 token 级别计算:
Token-level Accuracy:
[ \text{Accuracy} = \frac{\text{Number of Correct Tokens}}{\text{Total Number of Target Tokens}} ]
注意:通常也只对 response 部分 token 计算。
3.3 Perplexity (困惑度)
衡量语言模型对数据的拟合难度。
定义为:
[ \text{Perplexity} = \exp(\mathcal{L}) ]
其中 ( \mathcal{L} ) 是平均的 cross-entropy loss。
直观理解:
- 困惑度越低,代表模型越确定地生成正确 token;
- 理想值接近 1(实际中 10~30 左右为正常范围,取决于任务复杂度)。