1 函数目的
def get_batch(data: np.array, batch_size: int, sequence_length: int, device: str) -> torch.Tensor:
该函数的核心功能是从连续序列数据 data 中随机采样若干长度为 sequence_length 的片段,组成一个batch张量并传输到指定设备(通常是GPU)。这是语言模型、RNN、Transformer等模型训练中最常见的mini-batch采样逻辑。
2 随机采样机制
2.1 生成随机起始位置
start_indices = torch.randint(len(data) - sequence_length, (batch_size,))
这个代码:
- 在
[0, len(data) - sequence_length)范围内随机选取batch_size个起始位置; - 每个起点代表从数据中抽取一段连续的 token 子序列。
示例说明:
data = np.arange(100)
batch_size = 3
sequence_length = 5
# 可能采样到的 start_indices = [10, 27, 51]
那么每个样本的序列为:
data[10:15], data[27:32], data[51:56]
2.2 构建 batch 张量
x = torch.tensor([data[start:start + sequence_length] for start in start_indices])
这一步将多个样本拼成一个张量:
- 外层长度是 batch size;
- 每个样本长度是 sequence length。
最终形状为:
torch.Size([batch_size, sequence_length])
张量结构示例:
(batch_size=3, sequence_length=5)
[[data[10:15]],
[data[27:32]],
[data[51:56]]]
3 内存优化 pinned memory
3.1 默认内存类型的限制
“By default, CPU tensors are in paged memory.”
在操作系统中,CPU 内存默认是 分页内存(paged memory), 这意味着数据页可以被换出(swap)或重新映射。 但 GPU 无法直接访问这种内存,它只能从 锁页内存(pinned memory) 中快速 DMA 复制数据。
3.2 锁页内存(Pinned Memory)作用
if torch.cuda.is_available():
x = x.pin_memory()
这一步的作用是:
将 CPU 内存中的 tensor “锁定”到物理内存中,使 GPU 可以直接、异步读取。
这样做的好处是:
- 避免了 CPU → GPU 数据拷贝的系统开销;
- 可以实现 异步拷贝(non-blocking transfer), 实现non-blocking数据传输,提升并行效率
4 异步拷贝到 GPU
4.1 非阻塞拷贝
x = x.to(device, non_blocking=True)
这一步把数据从 CPU(pinned memory)传到 GPU。
non_blocking=True 表示:
- 不等待传输完成;
- 可以同时让 GPU 继续执行前一个 batch 的计算;
- CPU 可以同时加载下一个 batch。
4.2 并行流水线的概念
“This allows us to do two things in parallel (not done here): Fetch the next batch of data into CPU Process x on the GPU.”
该机制实现了CPU数据加载与GPU计算的并行流水线:
含义是:
- 当 GPU 正在训练当前 batch 时;
- CPU 可以 并行加载下一个 batch 的数据;
- 下一步训练时直接异步传给 GPU,无需等待。
这是 PyTorch DataLoader + pinned memory 的核心性能优化思路。
时间轴:
|--- GPU处理Batch N ---|--- GPU处理Batch N+1 ---|
|--- CPU加载Batch N+1 ---|--- CPU加载Batch N+2 ---|
5 性能优化
根据NVIDIA技术博客和PyTorch社区的实践经验:
- 内存分配策略
- 避免过度分配锁页内存(推荐不超过系统总内存的50%)
- 使用
torch.utils.data.DataLoader的pin_memory=True参数自动管理
- 批量传输优化
- 将多个小数据传输合并为单次大传输
- 对于二维数组,使用
cudaMemcpy2D()提升效率
- 性能监控工具
- 使用
nvprof或nsight systems分析数据传输瓶颈 - 通过
torch.cuda.Event精确测量传输耗时
- 使用
参考资料
- PyTorch GitHub Gist: Tricks to Speed Up Data Loading
- NVIDIA Technical Blog: How to Optimize Data Transfers in CUDA C/C++