# 需求

如图所示,想通过 Pytorch 加载自己的数据集,每一张图片后面有三个标签。该如何实现呢?看下面

# 代码

from    PIL import Image
import  torch
from    torchvision import transforms
import  numpy as np
from    matplotlib import pyplot as plt
# 创建自己的类:MyDataset, 这个类是继承的 torch.utils.data.Dataset
class MyDataset(torch.utils.data.Dataset): 
    def __init__(self, filepath, transform=None, target_transform=None):
        super(MyDataset, self).__init__()
        # 按照传入的路径打开这个文本,并读取内容
        fh = open(filepath, 'r') 
        imgs = []
        for line in fh:                # 按行循环 txt 文本中的内容
            line = line.rstrip()       # 删除 本行 string 字符串末尾的指定字符,这个方法的详细介绍自己查询 python
            words = line.split()       # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
            # 三分类问题,所以 label 有三个
            imgs.append((words[0], [float(words[1]), float(words[2]), float(words[3])]))
            # imgs.append((words[0], float(words[1])))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
    # 这个方法是必须要有的,用于按照索引读取每个元素的具体内容
    def __getitem__(self, index):
        # 读取文件路径和标签
        fn, label = self.imgs[index]
        # 读取图片信息
        img = Image.open(fn).convert('RGB')
        # list 转 numpy
        label = np.array(label)
        # 是否进行 transform
        if self.transform is not None:
            img = self.transform(img)
        # return 很关键,return 回哪些内容,那么我们在训练时循环读取每个 batch 时,就能获得哪些内容
        return img, torch.from_numpy(label)
    # 这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和 loader 的长度作区分
    def __len__(self): 
        return len(self.imgs)
def main():
    
    train_data = MyDataset(filepath='./data/Train.txt', transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=2, shuffle=True)
    x, label = iter(train_loader).next()
    print('x:', x.shape, 'label:', label.shape)
    # 显示一张图片
    img = x[0]
    img = img.numpy()
    img = np.transpose(img, (1, 2, 0))
    plt.imshow(img)
    plt.show()
if __name__ == "__main__":
    main()

# 输出结果

因为我 batch_size 设为 2,所以一个 iterator 随机取出两张图片,所以输出结果的第一个维度是 batch_size,大小为 2

# [batch_size, channel, h, w]   [batch_size, label]
x: torch.Size([2, 3, 320, 480]) label: torch.Size([2, 3])
更新于 阅读次数

请我喝[茶]~( ̄▽ ̄)~*

宇凌喵 微信支付

微信支付

宇凌喵 支付宝

支付宝