Browse Source

1IKCU

fix [MA][diff_privacy][Doc] the tutorials of diff_privacy has problem
tags/v0.5.0-beta
zhenghuanhuan 5 years ago
parent
commit
be377fb165
2 changed files with 4 additions and 5 deletions
  1. +3
    -3
      example/mnist_demo/lenet5_dp_model_train.py
  2. +1
    -2
      example/mnist_demo/mnist_train.py

+ 3
- 3
example/mnist_demo/lenet5_dp_model_train.py View File

@@ -92,15 +92,15 @@ if __name__ == "__main__":
parser.add_argument('--data_path', type=str, default="./MNIST_unzip",
help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
parser.add_argument('--micro_batches', type=float, default=None,
parser.add_argument('--micro_batches', type=int, default=None,
help='optional, if use differential privacy, need to set micro_batches')
parser.add_argument('--l2_norm_bound', type=float, default=1,
parser.add_argument('--l2_norm_bound', type=float, default=0.1,
help='optional, if use differential privacy, need to set l2_norm_bound')
parser.add_argument('--initial_noise_multiplier', type=float, default=0.001,
help='optional, if use differential privacy, need to set initial_noise_multiplier')
args = parser.parse_args()

context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target, enable_mem_reuse=False)
context.set_context(mode=context.PYNATIVE_MODE, device_target=args.device_target)

network = LeNet5()
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")


+ 1
- 2
example/mnist_demo/mnist_train.py View File

@@ -61,6 +61,5 @@ def mnist_train(epoch_size, batch_size, lr, momentum):


if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target="CPU",
enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
mnist_train(10, 32, 0.01, 0.9)

Loading…
Cancel
Save