You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset.py 1.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. import os
  2. import torch
  3. import numpy as np
  4. import cv2 as cv
  5. from os.path import join
  6. from torch.utils.data.dataset import Dataset
  7. class DataTrain(Dataset):
  8. def __init__(self, data_path, transforms=None):
  9. self.data_dir = data_path
  10. self.image_list = os.listdir(join(data_path, 'train_ori_images'))
  11. files_len = len(self.image_list)
  12. try:
  13. imgs = np.zeros(shape=(files_len, 256, 256, 3), dtype=np.uint8)
  14. labels = np.zeros(shape=(files_len, 256, 256), dtype=np.uint8)
  15. for idx, file in enumerate(self.image_list):
  16. fname = file.split('.')[0]
  17. img = cv.imread(join(self.data_dir, 'train_ori_images', fname + '.tif'))
  18. img = np.asarray(img, dtype=np.uint8)
  19. label = cv.imread(
  20. join(self.data_dir, 'train_pupil_images', fname + '.png'),
  21. cv.IMREAD_UNCHANGED)
  22. label = np.asarray(label, dtype=np.uint8) % 100
  23. imgs[idx, :, :, :] = img
  24. labels[idx, :, :] = label
  25. self.images = imgs
  26. self.labels = labels
  27. self.transforms = transforms
  28. except Exception:
  29. raise Exception('read error')
  30. def __getitem__(self, index):
  31. img = self.images[index]
  32. label = self.labels[index]
  33. label[label > 0.5] = 1
  34. tx_sample = self.transforms({'img': img, 'label': label})
  35. img = tx_sample['img']
  36. label = tx_sample['label']
  37. return img, label
  38. def __len__(self):
  39. return len(self.image_list)

网络代码复现