From be377fb165b32f765bc52de3c4650e8c0b70debd Mon Sep 17 00:00:00 2001 From: zhenghuanhuan Date: Thu, 28 May 2020 10:48:15 +0800 Subject: [PATCH] 1IKCU fix [MA][diff_privacy][Doc] the tutorials of diff_privacy has problem --- example/mnist_demo/lenet5_dp_model_train.py | 6 +++--- example/mnist_demo/mnist_train.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/example/mnist_demo/lenet5_dp_model_train.py b/example/mnist_demo/lenet5_dp_model_train.py index 61a359a..089c23f 100644 --- a/example/mnist_demo/lenet5_dp_model_train.py +++ b/example/mnist_demo/lenet5_dp_model_train.py @@ -92,15 +92,15 @@ if __name__ == "__main__": parser.add_argument('--data_path', type=str, default="./MNIST_unzip", help='path where the dataset is saved') parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') - parser.add_argument('--micro_batches', type=float, default=None, + parser.add_argument('--micro_batches', type=int, default=None, help='optional, if use differential privacy, need to set micro_batches') - parser.add_argument('--l2_norm_bound', type=float, default=1, + parser.add_argument('--l2_norm_bound', type=float, default=0.1, help='optional, if use differential privacy, need to set l2_norm_bound') parser.add_argument('--initial_noise_multiplier', type=float, default=0.001, help='optional, if use differential privacy, need to set initial_noise_multiplier') args = parser.parse_args() - context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False) + context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target) network = LeNet5() net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") diff --git a/example/mnist_demo/mnist_train.py b/example/mnist_demo/mnist_train.py index 81daa8b..7e595b4 100644 --- a/example/mnist_demo/mnist_train.py +++ b/example/mnist_demo/mnist_train.py @@ -61,6 +61,5 @@ def mnist_train(epoch_size, batch_size, lr, momentum): if __name__ == '__main__': - context.set_context(mode=context.GRAPH_MODE, device_target="CPU", - enable_mem_reuse=False) + context.set_context(mode=context.GRAPH_MODE, device_target="CPU") mnist_train(10, 32, 0.01, 0.9)