import os import torch import numpy as np import cv2 as cv from os.path import join from torch.utils.data.dataset import Dataset class DataTrain(Dataset): def __init__(self, data_path, transforms=None): self.data_dir = data_path self.image_list = os.listdir(join(data_path, 'train_ori_images')) files_len = len(self.image_list) try: imgs = np.zeros(shape=(files_len, 256, 256, 3), dtype=np.uint8) labels = np.zeros(shape=(files_len, 256, 256), dtype=np.uint8) for idx, file in enumerate(self.image_list): fname = file.split('.')[0] img = cv.imread(join(self.data_dir, 'train_ori_images', fname + '.tif')) img = np.asarray(img, dtype=np.uint8) label = cv.imread( join(self.data_dir, 'train_pupil_images', fname + '.png'), cv.IMREAD_UNCHANGED) label = np.asarray(label, dtype=np.uint8) % 100 imgs[idx, :, :, :] = img labels[idx, :, :] = label self.images = imgs self.labels = labels self.transforms = transforms except Exception: raise Exception('read error') def __getitem__(self, index): img = self.images[index] label = self.labels[index] label[label > 0.5] = 1 tx_sample = self.transforms({'img': img, 'label': label}) img = tx_sample['img'] label = tx_sample['label'] return img, label def __len__(self): return len(self.image_list)