# 需求
如图所示,想通过 Pytorch 加载自己的数据集,每一张图片后面有三个标签。该如何实现呢?看下面
# 代码
from PIL import Image | |
import torch | |
from torchvision import transforms | |
import numpy as np | |
from matplotlib import pyplot as plt | |
# 创建自己的类:MyDataset, 这个类是继承的 torch.utils.data.Dataset | |
class MyDataset(torch.utils.data.Dataset): | |
def __init__(self, filepath, transform=None, target_transform=None): | |
super(MyDataset, self).__init__() | |
# 按照传入的路径打开这个文本,并读取内容 | |
fh = open(filepath, 'r') | |
imgs = [] | |
for line in fh: # 按行循环 txt 文本中的内容 | |
line = line.rstrip() # 删除 本行 string 字符串末尾的指定字符,这个方法的详细介绍自己查询 python | |
words = line.split() # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等 | |
# 三分类问题,所以 label 有三个 | |
imgs.append((words[0], [float(words[1]), float(words[2]), float(words[3])])) | |
# imgs.append((words[0], float(words[1]))) | |
self.imgs = imgs | |
self.transform = transform | |
self.target_transform = target_transform | |
# 这个方法是必须要有的,用于按照索引读取每个元素的具体内容 | |
def __getitem__(self, index): | |
# 读取文件路径和标签 | |
fn, label = self.imgs[index] | |
# 读取图片信息 | |
img = Image.open(fn).convert('RGB') | |
# list 转 numpy | |
label = np.array(label) | |
# 是否进行 transform | |
if self.transform is not None: | |
img = self.transform(img) | |
# return 很关键,return 回哪些内容,那么我们在训练时循环读取每个 batch 时,就能获得哪些内容 | |
return img, torch.from_numpy(label) | |
# 这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和 loader 的长度作区分 | |
def __len__(self): | |
return len(self.imgs) | |
def main(): | |
train_data = MyDataset(filepath='./data/Train.txt', transform=transforms.ToTensor()) | |
train_loader = torch.utils.data.DataLoader(train_data, batch_size=2, shuffle=True) | |
x, label = iter(train_loader).next() | |
print('x:', x.shape, 'label:', label.shape) | |
# 显示一张图片 | |
img = x[0] | |
img = img.numpy() | |
img = np.transpose(img, (1, 2, 0)) | |
plt.imshow(img) | |
plt.show() | |
if __name__ == "__main__": | |
main() |
# 输出结果
因为我 batch_size 设为 2,所以一个 iterator 随机取出两张图片,所以输出结果的第一个维度是 batch_size,大小为 2
# [batch_size, channel, h, w] [batch_size, label] | |
x: torch.Size([2, 3, 320, 480]) label: torch.Size([2, 3]) |