|
|
@@ -85,3 +85,51 @@ def test_suppress_model_with_pynative_mode(): |
|
|
|
|
|
|
|
model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], |
|
|
|
dataset_sink_mode=False) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.level0 |
|
|
|
@pytest.mark.platform_arm_ascend_training |
|
|
|
@pytest.mark.platform_x86_ascend_training |
|
|
|
@pytest.mark.env_card |
|
|
|
@pytest.mark.component_mindarmour |
|
|
|
def test_suppress_model_with_graph_mode(): |
|
|
|
""" |
|
|
|
Feature: suppress model |
|
|
|
Description:graph mode |
|
|
|
Expectation: none. |
|
|
|
""" |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") |
|
|
|
networks_l5 = LeNet5() |
|
|
|
epochs = 5 |
|
|
|
batch_num = 10 |
|
|
|
mask_times = 10 |
|
|
|
lr = 0.01 |
|
|
|
masklayers_lenet5 = [] |
|
|
|
masklayers_lenet5.append(MaskLayerDes("conv1.weight", 0, False, False, -1)) |
|
|
|
suppress_ctrl_instance = SuppressPrivacyFactory().create(networks_l5, |
|
|
|
masklayers_lenet5, |
|
|
|
policy="local_train", |
|
|
|
end_epoch=epochs, |
|
|
|
batch_num=batch_num, |
|
|
|
start_epoch=1, |
|
|
|
mask_times=mask_times, |
|
|
|
lr=lr, |
|
|
|
sparse_end=0.50, |
|
|
|
sparse_start=0.0) |
|
|
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") |
|
|
|
net_opt = nn.SGD(networks_l5.trainable_params(), lr) |
|
|
|
model_instance = SuppressModel( |
|
|
|
network=networks_l5, |
|
|
|
loss_fn=net_loss, |
|
|
|
optimizer=net_opt, |
|
|
|
metrics={"Accuracy": Accuracy()}) |
|
|
|
model_instance.link_suppress_ctrl(suppress_ctrl_instance) |
|
|
|
suppress_masker = SuppressMasker(model=model_instance, suppress_ctrl=suppress_ctrl_instance) |
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=10) |
|
|
|
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", |
|
|
|
directory="./trained_ckpt_file/", |
|
|
|
config=config_ck) |
|
|
|
ds_train = ds.GeneratorDataset(dataset_generator, ['data', 'label']) |
|
|
|
|
|
|
|
model_instance.train(epochs, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], |
|
|
|
dataset_sink_mode=False) |