Table of Contents

在 PyTorch 中,数据集的划分是模型训练前的重要步骤。我们可以使用 torch.utils.data.random_splittorch.utils.data.Subset 来将总的 Dataset 划分为 训练集测试集

以下是三种常用的方法:

方法 1:使用 random_split Link to 方法 1:使用 random_split

这是最简单的方法,适用于 随机划分 数据集。

PYTHON
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from torch.utils.data import random_split

# 假设 all_data_set 是你已经加载好的数据集
# dataset_size = len(all_data_set)

# 定义划分比例
train_ratio = 0.8  # 80% 训练集
test_ratio = 1 - train_ratio

# 计算训练集和测试集的样本数
dataset_size = len(all_data_set)
train_size = int(train_ratio * dataset_size)
test_size = dataset_size - train_size  # 确保总数不变

# 随机划分数据集
train_dataset, test_dataset = random_split(all_data_set, [train_size, test_size])

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
  • 优点:简单快捷,原生支持。
  • 缺点:每次运行都会重新随机划分。如果需要结果可复现,建议在 random_split() 之前设置随机种子 torch.manual_seed(seed)

方法 2:使用 Subset 指定索引 Link to 方法 2:使用 Subset 指定索引

适用于 自定义划分 数据集,例如需要固定前 80% 为训练集,后 20% 为测试集(非随机)。

PYTHON
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from torch.utils.data import Subset

# 获取数据集总长度
dataset_size = len(all_data_set)
train_size = int(0.8 * dataset_size)  # 80% 训练集
indices = list(range(dataset_size))   # 生成所有样本的索引

# 训练集和测试集的索引
train_indices = indices[:train_size]
test_indices = indices[train_size:]

# 使用 Subset 划分数据集
train_dataset = Subset(all_data_set, train_indices)
test_dataset = Subset(all_data_set, test_indices)

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
  • 优点:适用于有序数据(如时间序列),或需要手动控制样本分布的场景。
  • 缺点:如果数据本身有顺序(如按类别排序),直接截断可能导致训练集和测试集分布不一致。

方法 3:使用 train_test_split (sklearn) Link to 方法 3:使用 train_test_split (sklearn)

如果你希望 更加可控地随机 划分训练集和测试集(例如分层采样),可以使用 sklearn.model_selection.train_test_split

PYTHON
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# 获取数据集的所有索引
dataset_size = len(all_data_set)
indices = list(range(dataset_size))

# 使用 sklearn 进行随机划分
# random_state=42 保证每次划分结果一致
train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)

# 创建训练集和测试集
train_dataset = Subset(all_data_set, train_indices)
test_dataset = Subset(all_data_set, test_indices)

print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
  • 优点:支持 random_state 保证复现性,且支持 stratify 参数进行分层采样(保证各类别比例一致)。
  • 缺点:需要安装 scikit-learn 库。

总结 Link to 总结

方法适用场景是否随机代码简洁性备注
random_split通用随机划分✅ 是⭐⭐⭐推荐用于大多数简单任务
Subset按顺序/指定索引划分❌ 否⭐⭐⭐适合时序数据或手动控制
train_test_split需要复现或分层采样✅ 是⭐⭐功能最强大,依赖 sklearn

推荐建议:

  • 如果只是简单随机划分,首选 random_split
  • 如果希望划分结果严格可复现,推荐使用 train_test_split 配合 Subset

注意事项 Link to 注意事项

  • 随机复现性:使用 torch.manual_seed(seed)random_state 保持划分一致
  • 数据均衡:分类任务建议使用分层采样(train_test_split(..., stratify=labels)
  • 时间序列:不要打乱时间顺序,优先用 Subset 指定索引
  • 数据泄漏:确保测试集不参与任何归一化/特征选择的拟合过程
  • 交叉验证:需要更稳健评估时使用 KFold/StratifiedKFoldsklearnGroupKFold

参考资源 Link to 参考资源

Thanks for reading!

PyTorch 中划分训练集和测试集的方法

Fri Mar 28 2025
1043 words · 4 minutes