数据处理

版本1
#数据处理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
#定义自己的数据集合
class DogCat(data.Dataset):
def __init__(self,root):
#所有图片的绝对路径
imgs=os.listdir(root)
self.imgs=[os.path.join(root,k) for k in imgs]
def __getitem__(self, index):
img_path=self.imgs[index]
#dog-> 1 cat ->0
label=1 if 'dog' in img_path.split('/')[-1] else 0
pil_img=Image.open(img_path)
array=np.asarray(pil_img)
data=torch.from_numpy(array)
return data,label
def __len__(self):
return len(self.imgs)
dataSet=DogCat('./data/dogcat')
print(dataSet[0])