@@ -105,13 +105,12 @@ def vgg_create_dataset100(data_home, image_size, batch_size, rank_id=0, rank_siz | |||||
data_set = data_set.map(input_columns="label", operations=type_cast_op) | data_set = data_set.map(input_columns="label", operations=type_cast_op) | ||||
data_set = data_set.map(input_columns="image", operations=c_trans) | data_set = data_set.map(input_columns="image", operations=c_trans) | ||||
# apply repeat operations | |||||
data_set = data_set.repeat(repeat_num) | |||||
# apply shuffle operations | # apply shuffle operations | ||||
# data_set = data_set.shuffle(buffer_size=1000) | |||||
data_set = data_set.shuffle(buffer_size=1000) | |||||
# apply batch operations | # apply batch operations | ||||
data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) | data_set = data_set.batch(batch_size=batch_size, drop_remainder=True) | ||||
# apply repeat operations | |||||
data_set = data_set.repeat(repeat_num) | |||||
return data_set | return data_set |
@@ -107,7 +107,7 @@ class ImageInversionAttack: | |||||
for sub_loss_weight in loss_weights: | for sub_loss_weight in loss_weights: | ||||
check_value_positive('sub_loss_weight', sub_loss_weight) | check_value_positive('sub_loss_weight', sub_loss_weight) | ||||
self._loss = InversionLoss(self._network, loss_weights) | self._loss = InversionLoss(self._network, loss_weights) | ||||
self._input_shape = check_param_multi_types('input_shape', input_shape, [list, tuple]) | |||||
self._input_shape = check_param_type('input_shape', input_shape, tuple) | |||||
for shape_dim in input_shape: | for shape_dim in input_shape: | ||||
check_int_positive('shape_dim', shape_dim) | check_int_positive('shape_dim', shape_dim) | ||||
self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple]) | self._input_bound = check_param_multi_types('input_bound', input_bound, [list, tuple]) | ||||