|
- import os
- import torch
- import numpy as np
- import cv2 as cv
- from os.path import join
- from torch.utils.data.dataset import Dataset
-
-
class DataTrain(Dataset):
    """Eagerly-loaded training dataset of 256x256 eye images and pupil masks.

    Expects two sibling directories under ``data_path``:
      - ``train_ori_images``  : 3-channel ``.tif`` images, 256x256
      - ``train_pupil_images``: single-channel ``.png`` masks with matching stems

    All samples are read into memory once at construction time.
    """

    def __init__(self, data_path, transforms=None):
        """Load every image/mask pair under ``data_path`` into memory.

        Args:
            data_path: root directory containing the two image folders.
            transforms: optional callable applied to ``{'img': ..., 'label': ...}``
                dicts in ``__getitem__``; when ``None``, samples are returned
                untransformed (the original crashed with ``TypeError`` here).

        Raises:
            Exception: with message ``'read error'`` (chained to the underlying
                cause) if any file cannot be read or has an unexpected shape.
        """
        self.data_dir = data_path
        self.image_list = os.listdir(join(data_path, 'train_ori_images'))
        files_len = len(self.image_list)
        imgs = np.zeros(shape=(files_len, 256, 256, 3), dtype=np.uint8)
        labels = np.zeros(shape=(files_len, 256, 256), dtype=np.uint8)
        try:
            for idx, file in enumerate(self.image_list):
                fname = os.path.splitext(file)[0]
                img = cv.imread(join(self.data_dir, 'train_ori_images', fname + '.tif'))
                if img is None:
                    # cv.imread signals failure by returning None, not raising;
                    # without this check np.asarray(None) fails cryptically later.
                    raise IOError('cannot read ' + fname + '.tif')
                label = cv.imread(
                    join(self.data_dir, 'train_pupil_images', fname + '.png'),
                    cv.IMREAD_UNCHANGED)
                if label is None:
                    raise IOError('cannot read ' + fname + '.png')
                imgs[idx, :, :, :] = np.asarray(img, dtype=np.uint8)
                # % 100 strips an offset encoding baked into the mask values
                # (assumption carried over from the original code — TODO confirm).
                labels[idx, :, :] = np.asarray(label, dtype=np.uint8) % 100
        except Exception as exc:
            # Preserve the original message for existing callers, but chain the
            # real cause instead of discarding the traceback.
            raise Exception('read error') from exc
        self.images = imgs
        self.labels = labels
        self.transforms = transforms

    def __getitem__(self, index):
        """Return an ``(img, label)`` pair with the label binarized to {0, 1}."""
        img = self.images[index]
        # Copy so binarization does not silently mutate the cached array.
        label = self.labels[index].copy()
        label[label > 0.5] = 1
        if self.transforms is not None:
            tx_sample = self.transforms({'img': img, 'label': label})
            img = tx_sample['img']
            label = tx_sample['label']
        return img, label

    def __len__(self):
        """Number of samples, taken from the original image directory listing."""
        return len(self.image_list)
|