|
|
@@ -92,15 +92,15 @@ if __name__ == "__main__": |
|
|
|
parser.add_argument('--data_path', type=str, default="./MNIST_unzip", |
|
|
|
help='path where the dataset is saved') |
|
|
|
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') |
|
|
|
parser.add_argument('--micro_batches', type=float, default=None, |
|
|
|
parser.add_argument('--micro_batches', type=int, default=None, |
|
|
|
help='optional, if use differential privacy, need to set micro_batches') |
|
|
|
parser.add_argument('--l2_norm_bound', type=float, default=1, |
|
|
|
parser.add_argument('--l2_norm_bound', type=float, default=0.1, |
|
|
|
help='optional, if use differential privacy, need to set l2_norm_bound') |
|
|
|
parser.add_argument('--initial_noise_multiplier', type=float, default=0.001, |
|
|
|
help='optional, if use differential privacy, need to set initial_noise_multiplier') |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False) |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target) |
|
|
|
|
|
|
|
network = LeNet5() |
|
|
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") |
|
|
|