Browse Source

Pre Merge pull request !335 from ZhidanLiu/ut

pull/335/MERGE
ZhidanLiu Gitee 3 years ago
parent
commit
b992911891
1 changed files with 48 additions and 0 deletions
  1. +48
    -0
      tests/ut/python/privacy/sup_privacy/test_model_train.py

+ 48
- 0
tests/ut/python/privacy/sup_privacy/test_model_train.py View File

@@ -85,3 +85,51 @@ def test_suppress_model_with_pynative_mode():

model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker],
dataset_sink_mode=False)


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_card
@pytest.mark.component_mindarmour
def test_suppress_model_with_graph_mode():
"""
Feature: suppress model
Description:graph mode
Expectation: none.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
networks_l5 = LeNet5()
epochs = 5
batch_num = 10
mask_times = 10
lr = 0.01
masklayers_lenet5 = []
masklayers_lenet5.append(MaskLayerDes("conv1.weight", 0, False, False, -1))
suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5,
masklayers_lenet5,
policy="local_train",
end_epoch=epochs,
batch_num=batch_num,
start_epoch=1,
mask_times=mask_times,
lr=lr,
sparse_end=0.50,
sparse_start=0.0)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.SGD(networks_l5.trainable_params(), lr)
model_instance = SuppressModel(
network=networks_l5,
loss_fn=net_loss,
optimizer=net_opt,
metrics={"Accuracy": Accuracy()})
model_instance.link_suppress_ctrl(suppress_ctrl_instance)
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance)
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory="./trained_ckpt_file/",
config=config_ck)
ds_train = ds.GeneratorDataset(dataset_generator, ['data', 'label'])

model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker],
dataset_sink_mode=False)

Loading…
Cancel
Save