get_batch:从采样到GPU传输

 

1 函数目的

def get_batch(data: np.array, batch_size: int, sequence_length: int, device: str) -> torch.Tensor:

该函数的核心功能是从连续序列数据 data 中随机采样若干长度为 sequence_length 的片段,组成一个batch张量并传输到指定设备(通常是GPU)。这是语言模型、RNN、Transformer等模型训练中最常见的mini-batch采样逻辑。

2 随机采样机制

2.1 生成随机起始位置

start_indices = torch.randint(len(data) - sequence_length, (batch_size,))

这个代码:

  • [0, len(data) - sequence_length) 范围内随机选取 batch_size 个起始位置;
  • 每个起点代表从数据中抽取一段连续的 token 子序列。

示例说明:

data = np.arange(100)
batch_size = 3
sequence_length = 5
# 可能采样到的 start_indices = [10, 27, 51]

那么每个样本的序列为:

data[10:15], data[27:32], data[51:56]

2.2 构建 batch 张量

x = torch.tensor([data[start:start + sequence_length] for start in start_indices])

这一步将多个样本拼成一个张量:

  • 外层长度是 batch size;
  • 每个样本长度是 sequence length。

最终形状为:

torch.Size([batch_size, sequence_length])

张量结构示例:

(batch_size=3, sequence_length=5)
[[data[10:15]],
 [data[27:32]],
 [data[51:56]]]

3 内存优化 pinned memory

3.1 默认内存类型的限制

“By default, CPU tensors are in paged memory.”

在操作系统中,CPU 内存默认是 分页内存(paged memory), 这意味着数据页可以被换出(swap)或重新映射。 但 GPU 无法直接访问这种内存,它只能从 锁页内存(pinned memory) 中快速 DMA 复制数据。

3.2 锁页内存(Pinned Memory)作用

if torch.cuda.is_available():
    x = x.pin_memory()

这一步的作用是:

将 CPU 内存中的 tensor “锁定”到物理内存中,使 GPU 可以直接、异步读取。

这样做的好处是:

  • 避免了 CPU → GPU 数据拷贝的系统开销;
  • 可以实现 异步拷贝(non-blocking transfer), 实现non-blocking数据传输,提升并行效率

4 异步拷贝到 GPU

4.1 非阻塞拷贝

x = x.to(device, non_blocking=True)

这一步把数据从 CPU(pinned memory)传到 GPU。

non_blocking=True 表示:

  • 不等待传输完成;
  • 可以同时让 GPU 继续执行前一个 batch 的计算;
  • CPU 可以同时加载下一个 batch。

4.2 并行流水线的概念

“This allows us to do two things in parallel (not done here): Fetch the next batch of data into CPU Process x on the GPU.”

该机制实现了CPU数据加载与GPU计算的并行流水线:

含义是:

  • 当 GPU 正在训练当前 batch 时;
  • CPU 可以 并行加载下一个 batch 的数据
  • 下一步训练时直接异步传给 GPU,无需等待。

这是 PyTorch DataLoader + pinned memory 的核心性能优化思路。

时间轴:
|--- GPU处理Batch N ---|--- GPU处理Batch N+1 ---|
     |--- CPU加载Batch N+1 ---|--- CPU加载Batch N+2 ---|

5 性能优化

根据NVIDIA技术博客和PyTorch社区的实践经验:

  • 内存分配策略
    • 避免过度分配锁页内存(推荐不超过系统总内存的50%)
    • 使用 torch.utils.data.DataLoaderpin_memory=True 参数自动管理
  • 批量传输优化
    • 将多个小数据传输合并为单次大传输
    • 对于二维数组,使用 cudaMemcpy2D() 提升效率
  • 性能监控工具
    • 使用 nvprofnsight systems 分析数据传输瓶颈
    • 通过 torch.cuda.Event 精确测量传输耗时

参考资料

  1. PyTorch GitHub Gist: Tricks to Speed Up Data Loading
  2. NVIDIA Technical Blog: How to Optimize Data Transfers in CUDA C/C++