
【PyTorch Basics】Loading Datasets

Recap

In the previous post, the neural network was trained with batch gradient descent, which can easily get stuck at saddle points. PyTorch provides a dataset-loading toolkit that makes it convenient to train networks with mini-batch stochastic gradient descent instead. It consists of two parts:

  • Dataset: constructs the dataset (with index support)
  • DataLoader: yields one mini-batch at a time for each training update

Epoch, Batch-Size, and Iterations Explained

  • Epoch: one full training pass in which every sample goes through a forward and backward computation once
  • Batch-Size: the number of samples in one mini-batch, i.e. the number of samples used per training step (one parameter update)
  • Iterations: the number of mini-batches the dataset is split into; e.g. with 1000 samples and Batch-Size=100, Iterations=10
# Training loop
for epoch in range(training_epochs):
    # Train on every mini-batch
    for i in range(total_batch):  # runs `Iterations` times
        pass
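The relationship between these three quantities can be checked with a few lines of arithmetic (plain Python, no torch needed; the numbers are just illustrative):

```python
import math

n_samples = 1000
batch_size = 100

# Number of iterations (mini-batches) per epoch
iterations = math.ceil(n_samples / batch_size)
print(iterations)  # 10

# When the batch size does not divide the dataset evenly, the last
# mini-batch is simply smaller (unless drop_last=True, see below).
print(math.ceil(1000 / 64))  # 16: fifteen full batches of 64, plus one of 40
```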

DataLoader

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

Parameter descriptions:

  • dataset (Dataset) – dataset from which to load the data.
  • batch_size (int, optional) – how many samples per batch to load (default: 1).
  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
  • sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
  • batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
  • num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  • collate_fn (callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
  • pin_memory (bool, optional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a custom-type batch, see the PyTorch documentation on memory pinning.
  • drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
  • timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
  • worker_init_fn (callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
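A small sketch (assuming torch is installed) of how batch_size and drop_last interact with the number and size of the batches a loader produces:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Ten dummy samples of three features each
x = torch.randn(10, 3)
y = torch.randn(10, 1)
dataset = TensorDataset(x, y)

# 10 samples with batch_size=4 -> batches of 4, 4, 2
loader = DataLoader(dataset, batch_size=4)
sizes = [xb.shape[0] for xb, yb in loader]
print(len(loader), sizes)  # 3 [4, 4, 2]

# drop_last=True discards the incomplete final batch
loader_dropped = DataLoader(dataset, batch_size=4, drop_last=True)
print(len(loader_dropped))  # 2
```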

Suppose batch_size=2 and shuffle=True; the DataLoader pipeline then works as follows:

From left to right, the samples are first shuffled and then grouped, finally producing an iterable loader; each iteration yields one mini-batch for training the network.
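This shuffle-then-group behaviour can be sketched with a toy dataset (the sample values here are made up purely so the shuffling is visible):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Six samples labelled 0..5 so the shuffled order is easy to see
x = torch.arange(6, dtype=torch.float32).unsqueeze(1)  # shape (6, 1)
loader = DataLoader(TensorDataset(x), batch_size=2, shuffle=True)

seen = []
for batch_idx, (xb,) in enumerate(loader):
    print(batch_idx, xb.flatten().tolist())  # e.g. 0 [4.0, 1.0] -- order varies per epoch
    seen.extend(xb.flatten().tolist())

# Every sample appears exactly once per epoch, just in a shuffled order
print(sorted(seen))  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
```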

Dataset

Dataset is an abstract class and cannot be instantiated directly; it can only be subclassed, and the subclass instantiated. So to use it, we must write our own class that inherits from Dataset. Its skeleton looks roughly like this:

from torch.utils.data import Dataset  # Dataset is abstract and cannot be instantiated
from torch.utils.data import DataLoader

class MyDataset(Dataset):
    def __init__(self, filepath):
        # Load the dataset (e.g. a CSV file). Two strategies:
        # 1. All-in: load everything into memory (fine when the dataset is small)
        # 2. For large datasets, split them into files that fit in memory, keep a
        #    list of filenames here, and read a file on demand in __getitem__
        pass

    def __getitem__(self, index):  # enables subscripting: dataset[index]
        pass

    def __len__(self):  # returns the number of samples in the dataset
        pass
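Strategy 2 (lazy loading) can be sketched as follows. The shard files, their contents, and the index bookkeeping here are illustrative assumptions, not part of any PyTorch API — only the Dataset protocol (__getitem__/__len__) is real:

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset

class ShardedCSVDataset(Dataset):
    """Keeps only a list of shard filenames in memory; rows are read on demand."""
    def __init__(self, filepaths):
        self.filepaths = filepaths
        # Count rows per shard once, so __len__ and index mapping stay cheap
        self.row_counts = [sum(1 for _ in open(p)) for p in filepaths]

    def __len__(self):
        return sum(self.row_counts)

    def __getitem__(self, index):
        # Map a global index to (shard, local row), then read just that shard
        for path, count in zip(self.filepaths, self.row_counts):
            if index < count:
                rows = np.loadtxt(path, delimiter=',', dtype=np.float32, ndmin=2)
                row = rows[index]
                return torch.from_numpy(row[:-1]), torch.from_numpy(row[[-1]])
            index -= count
        raise IndexError(index)

# Demo: write two tiny shards (last column plays the role of the label)
tmp = tempfile.mkdtemp()
paths = []
for i, rows in enumerate([[[1, 2, 0], [3, 4, 1]], [[5, 6, 0]]]):
    path = os.path.join(tmp, f'shard{i}.csv')
    np.savetxt(path, np.array(rows), delimiter=',')
    paths.append(path)

ds = ShardedCSVDataset(paths)
print(len(ds))                 # 3
x, y = ds[2]                   # the third sample lives in the second shard
print(x.tolist(), y.tolist())  # [5.0, 6.0] [0.0]
```

Re-reading a whole shard per sample is wasteful; a real implementation would cache the most recently loaded shard, but the index mapping is the essential idea.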

Instantiating the dataset object:

dataset = MyDataset(filepath)
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

The Diabetes Dataset

Loading the dataset:

import torch
import numpy as np
from torch.utils.data import Dataset  # Dataset is abstract and cannot be instantiated
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # The dataset is small, so load it into memory in one go
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # number of rows = number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])   # all rows, every column except the last
        self.y_data = torch.from_numpy(xy[:, [-1]])  # all rows, last column only; [-1] keeps a matrix rather than a vector

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]  # returns a tuple (x, y)

    def __len__(self):
        return self.len

dataset = DiabetesDataset('https://project-preview-1257022783.cos.ap-chengdu.myqcloud.com/diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)

Training:

for epoch in range(100):
    for i, data in enumerate(train_loader, 0):  # or: for i, (inputs, labels) in enumerate(train_loader, 0):
        # Prepare the data
        inputs, labels = data  # already converted to Tensors
        # Forward pass
        y_pred = model(inputs)
        loss = criterion(y_pred, labels)
        # print(epoch, i, loss.item())
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        # Update parameters
        optimizer.step()

Datasets Built into torchvision

  • MNIST
  • Fashion-MNIST
  • EMNIST
  • COCO
  • LSUN
  • ImageFolder
  • DatasetFolder
  • Imagenet-12
  • CIFAR
  • STL10
  • PhotoTour

These datasets all inherit from torch.utils.data.Dataset and implement __getitem__ and __len__, so they can be loaded directly with torch.utils.data.DataLoader.

Importing and Using Them

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets

# FashionMNIST
train_dataset = torchvision.datasets.FashionMNIST(root='./dataset/fmnist/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.FashionMNIST(root='./dataset/fmnist/', train=False, transform=transforms.ToTensor(), download=True)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False)  # no shuffle for testing, so result order stays stable

# Training
for epoch in range(training_epochs):
    for batch_idx, (inputs, target) in enumerate(train_loader, 0):
        pass  # forward and backward computation
Author: Liam
Link: https://www.ccyh.xyz/p/993f.html
Copyright: Unless otherwise stated, all articles on this blog are licensed under CC BY-NC-SA 4.0. Please credit Liam's Blog when reposting.