From a389ed6d29eff5bc3db150a399d56c1708708c4a Mon Sep 17 00:00:00 2001 From: ZhidanLiu Date: Thu, 17 Mar 2022 11:29:18 +0800 Subject: [PATCH] add a graph model sup privacy ut --- .../privacy/sup_privacy/test_model_train.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/ut/python/privacy/sup_privacy/test_model_train.py b/tests/ut/python/privacy/sup_privacy/test_model_train.py index fdde47f..f334fe0 100644 --- a/tests/ut/python/privacy/sup_privacy/test_model_train.py +++ b/tests/ut/python/privacy/sup_privacy/test_model_train.py @@ -85,3 +85,51 @@ def test_suppress_model_with_pynative_mode(): model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], dataset_sink_mode=False) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_card +@pytest.mark.component_mindarmour +def test_suppress_model_with_graph_mode(): + """ + Feature: suppress model + Description:graph mode + Expectation: none. + """ + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + networks_l5 = LeNet5() + epochs = 5 + batch_num = 10 + mask_times = 10 + lr = 0.01 + masklayers_lenet5 = [] + masklayers_lenet5.append(MaskLayerDes("conv1.weight", 0, False, False, -1)) + suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5, + masklayers_lenet5, + policy="local_train", + end_epoch=epochs, + batch_num=batch_num, + start_epoch=1, + mask_times=mask_times, + lr=lr, + sparse_end=0.50, + sparse_start=0.0) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.SGD(networks_l5.trainable_params(), lr) + model_instance = SuppressModel( + network=networks_l5, + loss_fn=net_loss, + optimizer=net_opt, + metrics={"Accuracy": Accuracy()}) + model_instance.link_suppress_ctrl(suppress_ctrl_instance) + suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) + config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10) + ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", + directory="./trained_ckpt_file/", + config=config_ck) + ds_train = ds.GeneratorDataset(dataset_generator, ['data', 'label']) + + model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], + dataset_sink_mode=False)