yield数据集生成器用法学习
方法一,简单yield方法。
假如我都数据集是一个列表[0,1,2,3,4,5,6,7,8,9]。
需求:
- 打乱顺序,相当于随机取数。
- 一轮取完,重新打乱顺序,继续取数。无穷无尽。
def test_yeild():
np.random.shuffle(a)
for j in range(10):
yield a[j]
if __name__ == '__main__':
a = list(range(10))
while True:
for i in test_yeild():
print(i)
print("finish***")
下面是pytorch的自带的图像数据集cifar10的处理方法。
import numpy as np
import _pickle as pickle
def unpickle(file):
fo = open(file, 'rb')
dict = pickle.load(fo, encoding='bytes')
fo.close()
return dict[b'data']
def cifar_generator(filenames, batch_size, data_dir):
all_data = []
for filename in filenames:
all_data.append(unpickle(data_dir + '/' + filename))
images = np.concatenate(all_data, axis=0)
def get_epoch():
np.random.shuffle(images)
for i in range(int(len(images) / batch_size)):
yield np.copy(images[i*batch_size:(i+1)*batch_size])
return get_epoch
def load(batch_size, data_dir):
return (
cifar_generator(['data_batch_1','data_batch_2','data_batch_3','data_batch_4','data_batch_5'], batch_size, data_dir),
cifar_generator(['test_batch'], batch_size, data_dir)
)
解释:
- 代码看起来很复杂,其实核心是返回get_epoch数据集生成器(注意,没有括号’()')。
- load() 方法是返回训练集、测试集的两个生成器的方法。
- cifar_generator() 方法是核心处理步骤。
- unpickle() 是解析数据集的文件,因为自带的cifar10是以二进制的格式存储的。return dict[b’data’] 返回的是n3072的数据矩阵,n是数据的个数,3072=3232*3,也就是32像素的彩色图像。
方法二,ImageFolder+DataLoader+yield方法
from torchvision import datasets
datasets.ImageFolder
torch.utils.data.DataLoader
下面是我来曾用过的一个数据集加载方法。
class DataLoaderMy(datasets.ImageFolder):
def __init__(self, baseDir):
self.baseDir = baseDir
self.img_lists = self.getImgList(self.baseDir)
self.transform = transforms.Compose([
transforms.ToPILImage(),
transforms.ToTensor()
])
def getImgList(self, baseDir):
img_list = []
list_dir = os.walk(baseDir)
for root, dirs, files in list_dir:
for f in files:
file_name = f.split('.')
if file_name[-1] in ['png', 'jpg', 'JPG', 'PNG', 'JPG', 'jpeg', 'JPEG', 'bmp']:
img_file = root + '/' + f
img_list.append(img_file)
return img_list
def __getitem__(self, idx):
img_name = self.img_lists[idx]
img_path = os.path.join(self.baseDir,img_name)
img = cv2.imread(img_path)
# img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
# img = torch.from_numpy(img).cuda().float()
img = self.transform(img)
return img.cuda()
def __len__(self):
return len(self.img_lists)
如何加载?看下面代码
imgLoaderG = DataLoaderMy(data_dir)
data_loader_G = torch.utils.data.DataLoader(imgLoaderG, batch_size=batch_size, shuffle=True)
while True:
for images in data_loader:
yield images