|
- import os
-
- import nvidia.dali.ops as ops
- import nvidia.dali.types as types
- import torch.utils.data
- from nvidia.dali.pipeline import Pipeline
- from nvidia.dali.plugin.pytorch import DALIClassificationIterator
-
-
class HybridTrainPipe(Pipeline):
    """DALI pipeline that produces augmented training batches.

    Images are read from ``data_dir``, decoded with the ``"mixed"`` decoder,
    random-resized-cropped to ``crop``, colour-jittered, randomly mirrored and
    normalised into float NCHW tensors on the GPU.

    Args:
        batch_size: Forwarded to ``Pipeline.__init__``.
        num_threads: Forwarded to ``Pipeline.__init__``.
        device_id: Forwarded to ``Pipeline.__init__``; also added to ``seed``
            so each device gets a different random stream.
        data_dir: Directory passed to the file reader as ``file_root``.
        crop: Output size of the random resized crop.
        seed: Base random seed (``seed + device_id`` is what is actually used).
        local_rank: Shard index handed to the file reader.
        world_size: Number of shards handed to the file reader.
        spos_pre: Selects the preprocessing flavour. When true, images are
            decoded as BGR, resized with linear interpolation and left
            un-normalised (mean 0, std 1). When false, images are decoded as
            RGB, resized with triangular interpolation and normalised with
            the per-channel mean/std below (the usual ImageNet RGB
            statistics scaled to the 0-255 range).
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, seed=12, local_rank=0, world_size=1,
                 spos_pre=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
        color_space_type = types.BGR if spos_pre else types.RGB
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
                                    random_shuffle=True)
        # The channel order is chosen here, at decode time. It must agree with
        # the order of the mean/std lists given to CropMirrorNormalize below,
        # so it follows ``color_space_type`` instead of being hard-coded.
        self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
        self.res = ops.RandomResizedCrop(
            device="gpu", size=crop,
            interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
        self.twist = ops.ColorTwist(device="gpu")
        # Shared generator for the saturation / contrast / brightness factors.
        self.jitter_rng = ops.Uniform(range=[0.6, 1.4])
        # The mean/std lists are in R, G, B order, which is only valid for the
        # RGB decode path; the BGR (spos_pre) path is left un-normalised.
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            output_layout=types.NCHW,
            mean=[0.0] if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[1.0] if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])
        self.coin = ops.CoinFlip(probability=0.5)

    def define_graph(self):
        """Wire the operators together and return ``[images, labels]``."""
        rng = self.coin()
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        # Each call to jitter_rng() yields an independent random factor.
        images = self.twist(images, saturation=self.jitter_rng(),
                            contrast=self.jitter_rng(), brightness=self.jitter_rng())
        # The coin flip decides, per sample, whether the image is mirrored.
        output = self.cmnp(images, mirror=rng)
        return [output, self.labels]
-
-
class HybridValPipe(Pipeline):
    """DALI pipeline that produces deterministic validation batches.

    Images are read from ``data_dir``, decoded with the ``"mixed"`` decoder,
    resized so the shorter side equals ``size``, cropped to ``crop`` x ``crop``
    and normalised into float NCHW tensors on the GPU.

    Args:
        batch_size: Forwarded to ``Pipeline.__init__``.
        num_threads: Forwarded to ``Pipeline.__init__``.
        device_id: Forwarded to ``Pipeline.__init__``; also added to ``seed``.
        data_dir: Directory passed to the file reader as ``file_root``.
        crop: Side length of the square crop taken after resizing.
        size: Target length of the shorter image side for the resize step.
        seed: Base random seed (``seed + device_id`` is what is actually used).
        local_rank: Shard index handed to the file reader.
        world_size: Number of shards handed to the file reader.
        spos_pre: Selects the preprocessing flavour. When true, images are
            decoded as BGR, resized with linear interpolation and left
            un-normalised (mean 0, std 1). When false, images are decoded as
            RGB, resized with triangular interpolation and normalised with
            the per-channel mean/std below (the usual ImageNet RGB
            statistics scaled to the 0-255 range).
        shuffle: Passed to the file reader as ``random_shuffle``.
    """

    def __init__(self, batch_size, num_threads, device_id, data_dir, crop, size, seed=12, local_rank=0, world_size=1,
                 spos_pre=False, shuffle=False):
        super(HybridValPipe, self).__init__(batch_size, num_threads, device_id, seed=seed + device_id)
        color_space_type = types.BGR if spos_pre else types.RGB
        self.input = ops.FileReader(file_root=data_dir, shard_id=local_rank, num_shards=world_size,
                                    random_shuffle=shuffle)
        # The channel order is chosen here, at decode time. It must agree with
        # the order of the mean/std lists given to CropMirrorNormalize below,
        # so it follows ``color_space_type`` instead of being hard-coded.
        self.decode = ops.ImageDecoder(device="mixed", output_type=color_space_type)
        self.res = ops.Resize(
            device="gpu", resize_shorter=size,
            interp_type=types.INTERP_LINEAR if spos_pre else types.INTERP_TRIANGULAR)
        # The mean/std lists are in R, G, B order, which is only valid for the
        # RGB decode path; the BGR (spos_pre) path is left un-normalised.
        self.cmnp = ops.CropMirrorNormalize(
            device="gpu",
            dtype=types.FLOAT,
            output_layout=types.NCHW,
            crop=(crop, crop),
            mean=[0.0] if spos_pre else [0.485 * 255, 0.456 * 255, 0.406 * 255],
            std=[1.0] if spos_pre else [0.229 * 255, 0.224 * 255, 0.225 * 255])

    def define_graph(self):
        """Wire the operators together and return ``[images, labels]``."""
        self.jpegs, self.labels = self.input(name="Reader")
        images = self.decode(self.jpegs)
        images = self.res(images)
        output = self.cmnp(images)
        return [output, self.labels]
-
-
class ClassificationWrapper:
    """Adapt a DALI classification iterator to yield ``(images, labels)`` pairs.

    Args:
        loader: Iterator whose items are sequences; element 0 of each item is
            a mapping with ``"data"`` and ``"label"`` tensor entries.
        size: Value reported by ``len()`` (the number of batches per epoch,
            as computed by the caller).
    """

    def __init__(self, loader, size):
        self.loader = loader
        self.size = size

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(images, labels)`` batch.

        ``StopIteration`` raised by the wrapped loader propagates unchanged.
        """
        batch = next(self.loader)[0]
        images = batch["data"]
        # Labels are flattened to 1-D int64 and placed on the same device as
        # the images. A hard-coded "cuda:0" would split the batch across
        # devices whenever the pipeline runs on any other GPU.
        labels = batch["label"].view(-1).long().to(images.device, non_blocking=True)
        return images, labels

    def __len__(self):
        return self.size
-
-
def get_imagenet_iter_dali(split, image_dir, batch_size, num_threads, crop=224, val_size=256,
                           spos_preprocessing=False, seed=12, shuffle=False, device_id=None):
    """Build a DALI-backed batch iterator for one dataset split.

    Args:
        split: ``"train"`` or ``"val"``; selects both the pipeline class and
            the sub-directory of ``image_dir`` that is read.
        image_dir: Directory containing the ``train`` and ``val`` folders.
        batch_size: Batch size of the pipeline.
        num_threads: Number of worker threads given to the pipeline.
        crop: Final crop size produced by the pipeline.
        val_size: Shorter-side resize target; used by the ``"val"`` split only.
        spos_preprocessing: Forwarded to the pipeline as ``spos_pre``.
        seed: Base random seed forwarded to the pipeline.
        shuffle: Whether the ``"val"`` reader shuffles; the ``"train"`` split
            ignores this argument.
        device_id: GPU index for the pipeline. ``None`` selects the last GPU
            reported by ``torch.cuda.device_count()``.

    Returns:
        A ``ClassificationWrapper`` whose ``len()`` is
        ``ceil(num_samples / batch_size)``.

    Raises:
        AssertionError: If ``split`` is neither ``"train"`` nor ``"val"``.
    """
    # Reject a bad split before any pipeline or GPU work is started.
    if split not in ("train", "val"):
        raise AssertionError("unknown split {!r}; expected 'train' or 'val'".format(split))

    # Imported locally so this function does not depend on a change to the
    # module-level import block.
    from nvidia.dali.plugin.base_iterator import LastBatchPolicy

    world_size, local_rank = 1, 0
    if device_id is None:
        device_id = torch.cuda.device_count() - 1  # use last gpu

    if split == "train":
        pipeline = HybridTrainPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
                                   data_dir=os.path.join(image_dir, "train"), seed=seed,
                                   crop=crop, world_size=world_size, local_rank=local_rank,
                                   spos_pre=spos_preprocessing)
    else:
        pipeline = HybridValPipe(batch_size=batch_size, num_threads=num_threads, device_id=device_id,
                                 data_dir=os.path.join(image_dir, "val"), seed=seed,
                                 crop=crop, size=val_size, world_size=world_size, local_rank=local_rank,
                                 spos_pre=spos_preprocessing, shuffle=shuffle)
    pipeline.build()
    num_samples = pipeline.epoch_size("Reader")

    # The iterator needs ``reader_name`` (or an explicit ``size``) to know
    # where an epoch ends, and ``last_batch_policy`` must be a LastBatchPolicy
    # member rather than a string. Train and val share one policy: PARTIAL
    # returns the final short batch as-is, so no sample is duplicated and the
    # batch count matches the ceil-division used for the wrapper length.
    iterator = DALIClassificationIterator(pipeline,
                                          reader_name="Reader",
                                          last_batch_policy=LastBatchPolicy.PARTIAL,
                                          auto_reset=True)
    return ClassificationWrapper(iterator, (num_samples + batch_size - 1) // batch_size)
|