Table of Contents
在 PyTorch 中,数据集的划分是模型训练前的重要步骤。我们可以使用 torch.utils.data.random_split 或 torch.utils.data.Subset 来将总的 Dataset 划分为 训练集 和 测试集。
以下是三种常用的方法:
方法 1:使用 random_split Link to 方法 1:使用 random_split
这是最简单的方法,适用于 随机划分 数据集。
PYTHON
12345678910111213141516171819from torch.utils.data import random_split
# 假设 all_data_set 是你已经加载好的数据集
# dataset_size = len(all_data_set)
# 定义划分比例
train_ratio = 0.8 # 80% 训练集
test_ratio = 1 - train_ratio
# 计算训练集和测试集的样本数
dataset_size = len(all_data_set)
train_size = int(train_ratio * dataset_size)
test_size = dataset_size - train_size # 确保总数不变
# 随机划分数据集
train_dataset, test_dataset = random_split(all_data_set, [train_size, test_size])
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
- 优点:简单快捷,原生支持。
- 缺点:每次运行都会重新随机划分。如果需要结果可复现,建议在
random_split()之前设置随机种子torch.manual_seed(seed)。
方法 2:使用 Subset 指定索引 Link to 方法 2:使用 Subset 指定索引
适用于 自定义划分 数据集,例如需要固定前 80% 为训练集,后 20% 为测试集(非随机)。
PYTHON
1234567891011121314151617from torch.utils.data import Subset
# 获取数据集总长度
dataset_size = len(all_data_set)
train_size = int(0.8 * dataset_size) # 80% 训练集
indices = list(range(dataset_size)) # 生成所有样本的索引
# 训练集和测试集的索引
train_indices = indices[:train_size]
test_indices = indices[train_size:]
# 使用 Subset 划分数据集
train_dataset = Subset(all_data_set, train_indices)
test_dataset = Subset(all_data_set, test_indices)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
- 优点:适用于有序数据(如时间序列),或需要手动控制样本分布的场景。
- 缺点:如果数据本身有顺序(如按类别排序),直接截断可能导致训练集和测试集分布不一致。
方法 3:使用 train_test_split (sklearn) Link to 方法 3:使用 train_test_split (sklearn)
如果你希望 更加可控地随机 划分训练集和测试集(例如分层采样),可以使用 sklearn.model_selection.train_test_split。
PYTHON
1234567891011121314151617from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
# 获取数据集的所有索引
dataset_size = len(all_data_set)
indices = list(range(dataset_size))
# 使用 sklearn 进行随机划分
# random_state=42 保证每次划分结果一致
train_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)
# 创建训练集和测试集
train_dataset = Subset(all_data_set, train_indices)
test_dataset = Subset(all_data_set, test_indices)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
- 优点:支持
random_state保证复现性,且支持stratify参数进行分层采样(保证各类别比例一致)。 - 缺点:需要安装
scikit-learn库。
总结 Link to 总结
| 方法 | 适用场景 | 是否随机 | 代码简洁性 | 备注 |
|---|---|---|---|---|
random_split | 通用随机划分 | ✅ 是 | ⭐⭐⭐ | 推荐用于大多数简单任务 |
Subset | 按顺序/指定索引划分 | ❌ 否 | ⭐⭐⭐ | 适合时序数据或手动控制 |
train_test_split | 需要复现或分层采样 | ✅ 是 | ⭐⭐ | 功能最强大,依赖 sklearn |
推荐建议:
- 如果只是简单随机划分,首选
random_split。 - 如果希望划分结果严格可复现,推荐使用
train_test_split配合Subset。
注意事项 Link to 注意事项
- 随机复现性:使用
torch.manual_seed(seed)或random_state保持划分一致 - 数据均衡:分类任务建议使用分层采样(
train_test_split(..., stratify=labels)) - 时间序列:不要打乱时间顺序,优先用
Subset指定索引 - 数据泄漏:确保测试集不参与任何归一化/特征选择的拟合过程
- 交叉验证:需要更稳健评估时使用
KFold/StratifiedKFold或sklearn的GroupKFold
参考资源 Link to 参考资源
- PyTorch DataLoader 与 Dataset: https://pytorch.org/docs/stable/data.html
- sklearn train_test_split: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
- 交叉验证指南: https://scikit-learn.org/stable/modules/cross_validation.html
Thanks for reading!
PyTorch 中划分训练集和测试集的方法
© EveSunMaple | CC BY-SA 4.0