Merge pull request !321 from 张澍坤/mastertags/v1.8.0
@@ -111,6 +111,31 @@ class ImageInversionAttack: | |||||
TypeError: If the type of network is not Cell. | TypeError: If the type of network is not Cell. | ||||
ValueError: If any value of input_shape is not positive int. | ValueError: If any value of input_shape is not positive int. | ||||
ValueError: If any value of loss_weights is not positive value. | ValueError: If any value of loss_weights is not positive value. | ||||
Examples: | |||||
>>> import mindspore.ops.operations as P | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindarmour.privacy.evaluation.inversion_attack import ImageInversionAttack | |||||
>>> class Net(Cell): | |||||
... def __init__(self): | |||||
... super(Net, self).__init__() | |||||
... self._softmax = P.Softmax() | |||||
... self._reduce = P.ReduceSum() | |||||
... self._squeeze = P.Squeeze(1) | |||||
... def construct(self, inputs): | |||||
... out = self._softmax(inputs) | |||||
... out = self._reduce(out, 2) | |||||
... return self._squeeze(out) | |||||
>>> net = Net() | |||||
>>> original_images = np.random.random((2,1,10,10)).astype(np.float32) | |||||
>>> target_features = np.random.random((2,10)).astype(np.float32) | |||||
>>> inversion_attack = ImageInversionAttack(net, | |||||
... input_shape=(1, 10, 10), | |||||
... input_bound=(0, 1), | |||||
... loss_weights=[1, 0.2, 5]) | |||||
>>> inversion_images = inversion_attack.generate(target_features, iters=10) | |||||
>>> evaluate_result = inversion_attack.evaluate(original_images, inversion_images) | |||||
>>> print(evaluate_result) | |||||
""" | """ | ||||
def __init__(self, network, input_shape, input_bound, loss_weights=(1, 0.2, 5)): | def __init__(self, network, input_shape, input_bound, loss_weights=(1, 0.2, 5)): | ||||
self._network = check_param_type('network', network, Cell) | self._network = check_param_type('network', network, Cell) | ||||
@@ -144,15 +169,6 @@ class ImageInversionAttack: | |||||
Raises: | Raises: | ||||
TypeError: If the type of target_features is not numpy.ndarray. | TypeError: If the type of target_features is not numpy.ndarray. | ||||
ValueError: If any value of iters is not positive int.Z | ValueError: If any value of iters is not positive int.Z | ||||
Examples: | |||||
>>> net = LeNet5() | |||||
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||||
>>> loss_weights=[1, 0.2, 5]) | |||||
>>> features = np.random.random((2, 10)).astype(np.float32) | |||||
>>> images = inversion_attack.generate(features, iters=10) | |||||
>>> print(images.shape) | |||||
(2, 1, 32, 32) | |||||
""" | """ | ||||
target_features = check_numpy_param('target_features', target_features) | target_features = check_numpy_param('target_features', target_features) | ||||
iters = check_int_positive('iters', iters) | iters = check_int_positive('iters', iters) | ||||
@@ -203,16 +219,6 @@ class ImageInversionAttack: | |||||
- float, average ssim value. | - float, average ssim value. | ||||
- Union[float, None], average confidence. It would be None if labels or new_network is None. | - Union[float, None], average confidence. It would be None if labels or new_network is None. | ||||
Examples: | |||||
>>> net = LeNet5() | |||||
>>> inversion_attack = ImageInversionAttack(net, input_shape=(1, 32, 32), input_bound=(0, 1), | |||||
>>> loss_weights=[1, 0.2, 5]) | |||||
>>> features = np.random.random((2, 10)).astype(np.float32) | |||||
>>> inver_images = inversion_attack.generate(features, iters=10) | |||||
>>> ori_images = np.random.random((2, 1, 32, 32)) | |||||
>>> result = inversion_attack.evaluate(ori_images, inver_images) | |||||
>>> print(len(result)) | |||||
""" | """ | ||||
check_numpy_param('original_images', original_images) | check_numpy_param('original_images', original_images) | ||||
check_numpy_param('inversion_images', inversion_images) | check_numpy_param('inversion_images', inversion_images) | ||||
@@ -106,21 +106,52 @@ class MembershipInference: | |||||
n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | n_jobs (int): Number of jobs run in parallel. -1 means using all processors, | ||||
otherwise the value of n_jobs must be a positive integer. | otherwise the value of n_jobs must be a positive integer. | ||||
Examples: | |||||
>>> # train_1, train_2 are non-overlapping datasets from training dataset of target model. | |||||
>>> # test_1, test_2 are non-overlapping datasets from test dataset of target model. | |||||
>>> # We use train_1, test_1 to train attack model, and use train_2, test_2 to evaluate attack model. | |||||
>>> model = Model(network=net, loss_fn=loss, optimizer=opt, metrics={'acc', 'loss'}) | |||||
>>> attack_model = MembershipInference(model, n_jobs=-1) | |||||
>>> config = [{"method": "KNN", "params": {"n_neighbors": [3, 5, 7]}}] | |||||
>>> attack_model.train(train_1, test_1, config) | |||||
>>> metrics = ["precision", "recall", "accuracy"] | |||||
>>> result = attack_model.eval(train_2, test_2, metrics) | |||||
Raises: | Raises: | ||||
TypeError: If type of model is not mindspore.train.Model. | TypeError: If type of model is not mindspore.train.Model. | ||||
TypeError: If type of n_jobs is not int. | TypeError: If type of n_jobs is not int. | ||||
ValueError: The value of n_jobs is neither -1 nor a positive integer. | ValueError: The value of n_jobs is neither -1 nor a positive integer. | ||||
Examples: | |||||
>>> import mindspore.ops.operations as P | |||||
>>> from mindspore.nn import Cell | |||||
>>> from mindspore import Model | |||||
>>> from mindarmour.privacy.evaluation import MembershipInference | |||||
>>> def dataset_generator(): | |||||
... batch_size = 16 | |||||
... batches = 1 | |||||
... data = np.random.randn(batches * batch_size,1,10).astype(np.float32) | |||||
... label = np.random.randint(0,10, batches * batch_size).astype(np.int32) | |||||
... for i in range(batches): | |||||
... yield data[i*batch_size:(i+1)*batch_size],\ | |||||
... label[i*batch_size:(i+1)*batch_size] | |||||
>>> class Net(Cell): | |||||
... def __init__(self): | |||||
... super(Net, self).__init__() | |||||
... self._softmax = P.Softmax() | |||||
... self._Dense = nn.Dense(10,10) | |||||
... self._squeeze = P.Squeeze(1) | |||||
... def construct(self, inputs): | |||||
... out = self._softmax(inputs) | |||||
... out = self._Dense(out) | |||||
... return self._squeeze(out) | |||||
>>> net = Net() | |||||
>>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True) | |||||
>>> opt = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
>>> model = Model(network=net, loss_fn=loss, optimizer=opt) | |||||
>>> inference_model = MembershipInference(model, 2) | |||||
>>> config = [{ | |||||
... "method": "KNN", | |||||
... "params": {"n_neighbors": [3, 5, 7],} | |||||
}] | |||||
>>> ds_train = ds.GeneratorDataset(dataset_generator, ["image", "label"]) | |||||
>>> ds_test = ds.GeneratorDataset(dataset_generator, ["image", "label"]) | |||||
>>> inference_model.train(ds_train, ds_test, config) | |||||
>>> metrics = ["precision", "accuracy", "recall"] | |||||
>>> eval_train = ds.GeneratorDataset(dataset_generator, ["image", "label"]) | |||||
>>> eval_test = ds.GeneratorDataset(dataset_generator, ["image", "label"]) | |||||
>>> result = inference_model.eval(eval_train. eval_test, metrics) | |||||
>>> print(result) | |||||
""" | """ | ||||
def __init__(self, model, n_jobs=-1): | def __init__(self, model, n_jobs=-1): | ||||
@@ -25,39 +25,54 @@ TAG = 'suppress masker' | |||||
class SuppressMasker(Callback): | class SuppressMasker(Callback): | ||||
""" | """ | ||||
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_ | |||||
Args: | Args: | ||||
model (SuppressModel): SuppressModel instance. | model (SuppressModel): SuppressModel instance. | ||||
suppress_ctrl (SuppressCtrl): SuppressCtrl instance. | suppress_ctrl (SuppressCtrl): SuppressCtrl instance. | ||||
Examples: | Examples: | ||||
>>> networks_l5 = LeNet5() | |||||
>>> import mindspore.nn as nn | |||||
>>> import mindspore.dataset as ds | |||||
>>> import mindspore.ops.operations as P | |||||
>>> from mindspore import context | |||||
>>> from mindspore.nn import Accuracy | |||||
>>> from mindarmour.privacy.sup_privacy import SuppressModel | |||||
>>> from mindarmour.privacy.sup_privacy import SuppressMasker | |||||
>>> from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory | |||||
>>> from mindarmour.privacy.sup_privacy import MaskLayerDes | |||||
>>> class Net(nn.Cell): | |||||
... def __init__(self): | |||||
... super(Net, self).__init__() | |||||
... self._softmax = P.Softmax() | |||||
... self._Dense = nn.Dense(10,10) | |||||
... self._squeeze = P.Squeeze(1) | |||||
... def construct(self, inputs): | |||||
... out = self._softmax(inputs) | |||||
... out = self._Dense(out) | |||||
... return self._squeeze(out) | |||||
>>> context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
>>> network = Net() | |||||
>>> masklayers = [] | >>> masklayers = [] | ||||
>>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
>>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, | |||||
... mask_layers=masklayers, | |||||
... policy="local_train", | |||||
... end_epoch=10, | |||||
... batch_num=(int)(10000/cfg.batch_size), | |||||
... start_epoch=3, | |||||
... mask_times=1000, | |||||
... lr=lr, | |||||
... sparse_end=0.90, | |||||
... sparse_start=0.0) | |||||
>>> masklayers.append(MaskLayerDes("_Dense.weight", 0, False, True, 10)) | |||||
>>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=network, | |||||
... mask_layers=masklayers, | |||||
... policy="local_train", | |||||
... end_epoch=10, | |||||
... batch_num=1, | |||||
... start_epoch=3, | |||||
... mask_times=10, | |||||
... lr=0.05, | |||||
... sparse_end=0.95, | |||||
... sparse_start=0.0) | |||||
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
>>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
>>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
>>> model_instance = SuppressModel(network=networks_l5, | |||||
>>> net_opt = nn.SGD(network.trainable_params(), 0.05) | |||||
>>> model_instance = SuppressModel(network=network, | |||||
... loss_fn=net_loss, | ... loss_fn=net_loss, | ||||
... optimizer=net_opt, | ... optimizer=net_opt, | ||||
... metrics={"Accuracy": Accuracy()}) | ... metrics={"Accuracy": Accuracy()}) | ||||
>>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) | >>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) | ||||
>>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
... batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
>>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
... directory="./trained_ckpt_file/", | |||||
... config=config_ck) | |||||
>>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
... dataset_sink_mode=False) | |||||
>>> masker_instance = SuppressMasker(model_instance, suppress_ctrl_instance) | |||||
""" | """ | ||||
def __init__(self, model, suppress_ctrl): | def __init__(self, model, suppress_ctrl): | ||||
@@ -38,6 +38,8 @@ class SuppressPrivacyFactory: | |||||
def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3, | def create(networks, mask_layers, policy="local_train", end_epoch=10, batch_num=20, start_epoch=3, | ||||
mask_times=1000, lr=0.05, sparse_end=0.90, sparse_start=0.0): | mask_times=1000, lr=0.05, sparse_end=0.90, sparse_start=0.0): | ||||
""" | """ | ||||
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_ | |||||
Args: | Args: | ||||
networks (Cell): The training network. | networks (Cell): The training network. | ||||
This networks parameter should be same as 'network' parameter of SuppressModel(). | This networks parameter should be same as 'network' parameter of SuppressModel(). | ||||
@@ -57,35 +59,45 @@ class SuppressPrivacyFactory: | |||||
SuppressCtrl, class of Suppress Privavy Mechanism. | SuppressCtrl, class of Suppress Privavy Mechanism. | ||||
Examples: | Examples: | ||||
>>> networks_l5 = LeNet5() | |||||
>>> mask_layers = [] | |||||
>>> mask_layers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
>>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, | |||||
... mask_layers=mask_layers, | |||||
... policy="local_train", | |||||
... end_epoch=10, | |||||
... batch_num=(int)(10000/cfg.batch_size), | |||||
... start_epoch=3, | |||||
... mask_times=1000, | |||||
... lr=lr, | |||||
... sparse_end=0.90, | |||||
... sparse_start=0.0) | |||||
>>> import mindspore.nn as nn | |||||
>>> import mindspore.dataset as ds | |||||
>>> import mindspore.ops.operations as P | |||||
>>> from mindspore import context | |||||
>>> from mindspore.nn import Accuracy | |||||
>>> from mindarmour.privacy.sup_privacy import SuppressPrivacyFactory | |||||
>>> from mindarmour.privacy.sup_privacy import MaskLayerDes | |||||
>>> from mindarmour.privacy.sup_privacy import SuppressModel | |||||
>>> class Net(nn.Cell): | |||||
... def __init__(self): | |||||
... super(Net, self).__init__() | |||||
... self._softmax = P.Softmax() | |||||
... self._Dense = nn.Dense(10,10) | |||||
... self._squeeze = P.Squeeze(1) | |||||
... def construct(self, inputs): | |||||
... out = self._softmax(inputs) | |||||
... out = self._Dense(out) | |||||
... return self._squeeze(out) | |||||
>>> context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") | |||||
>>> network = Net() | |||||
>>> masklayers = [] | |||||
>>> masklayers.append(MaskLayerDes("_Dense.weight", 0, False, True, 10)) | |||||
>>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=network, | |||||
... mask_layers=masklayers, | |||||
... policy="local_train", | |||||
... end_epoch=10, | |||||
... batch_num=1, | |||||
... start_epoch=3, | |||||
... mask_times=10, | |||||
... lr=0.05, | |||||
... sparse_end=0.95, | |||||
... sparse_start=0.0) | |||||
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
>>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
>>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), | |||||
... keep_checkpoint_max=10) | |||||
>>> model_instance = SuppressModel(network=networks_l5, | |||||
>>> net_opt = nn.SGD(network.trainable_params(), 0.05) | |||||
>>> model_instance = SuppressModel(network=network, | |||||
... loss_fn=net_loss, | ... loss_fn=net_loss, | ||||
... optimizer=net_opt, | ... optimizer=net_opt, | ||||
... metrics={"Accuracy": Accuracy()}) | ... metrics={"Accuracy": Accuracy()}) | ||||
>>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) | >>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) | ||||
>>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
... batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
>>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
... directory="./trained_ckpt_file/", | |||||
... config=config_ck) | |||||
>>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
... dataset_sink_mode=False) | |||||
""" | """ | ||||
check_param_type('policy', policy, str) | check_param_type('policy', policy, str) | ||||
if policy == "local_train": | if policy == "local_train": | ||||
@@ -97,6 +109,8 @@ class SuppressPrivacyFactory: | |||||
class SuppressCtrl(Cell): | class SuppressCtrl(Cell): | ||||
""" | """ | ||||
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html#%E5%BC%95%E5%85%A5%E6%8A%91%E5%88%B6%E9%9A%90%E7%A7%81%E8%AE%AD%E7%BB%83>`_ | |||||
Args: | Args: | ||||
networks (Cell): The training network. | networks (Cell): The training network. | ||||
mask_layers (list): Description of those layers that need to be suppressed. | mask_layers (list): Description of those layers that need to be suppressed. | ||||
@@ -107,37 +121,6 @@ class SuppressCtrl(Cell): | |||||
lr (Union[float, int]): Learning rate. | lr (Union[float, int]): Learning rate. | ||||
sparse_end (float): The sparsity to reach. | sparse_end (float): The sparsity to reach. | ||||
sparse_start (Union[float, int]): The sparsity to start. | sparse_start (Union[float, int]): The sparsity to start. | ||||
Examples: | |||||
>>> networks_l5 = LeNet5() | |||||
>>> masklayers = [] | |||||
>>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
>>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, | |||||
... mask_layers=masklayers, | |||||
... policy="local_train", | |||||
... end_epoch=10, | |||||
... batch_num=(int)(10000/cfg.batch_size), | |||||
... start_epoch=3, | |||||
... mask_times=1000, | |||||
... lr=lr, | |||||
... sparse_end=0.90, | |||||
... sparse_start=0.0) | |||||
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
>>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
>>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), | |||||
... keep_checkpoint_max=10) | |||||
>>> model_instance = SuppressModel(network=networks_l5, | |||||
... loss_fn=net_loss, | |||||
... optimizer=net_opt, | |||||
... metrics={"Accuracy": Accuracy()}) | |||||
>>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
>>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
... batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
>>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
... directory="./trained_ckpt_file/", | |||||
... config=config_ck) | |||||
>>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
... dataset_sink_mode=False) | |||||
""" | """ | ||||
def __init__(self, networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, | def __init__(self, networks, mask_layers, end_epoch, batch_num, start_epoch, mask_times, lr, | ||||
sparse_end, sparse_start): | sparse_end, sparse_start): | ||||
@@ -776,6 +759,7 @@ class MaskLayerDes: | |||||
If parameter num is greater than 100000, upper_bound has not effect. | If parameter num is greater than 100000, upper_bound has not effect. | ||||
Examples: | Examples: | ||||
>>> from mindarmour.privacy.sup_privacy import MaskLayerDes | |||||
>>> masklayers = [] | >>> masklayers = [] | ||||
>>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | >>> masklayers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | ||||
""" | """ | ||||
@@ -57,42 +57,13 @@ def tensor_grad_scale(scale, grad): | |||||
class SuppressModel(Model): | class SuppressModel(Model): | ||||
""" | """ | ||||
This class is overload mindspore.train.model.Model. | This class is overload mindspore.train.model.Model. | ||||
For details, please check `Tutorial <https://mindspore.cn/mindarmour/docs/zh-CN/master/protect_user_privacy_with_suppress_privacy.html>`_ | |||||
Args: | Args: | ||||
network (Cell): The training network. | network (Cell): The training network. | ||||
loss_fn (Cell): Computes softmax cross entropy between logits and labels. | loss_fn (Cell): Computes softmax cross entropy between logits and labels. | ||||
optimizer (Optimizer): optimizer instance. | optimizer (Optimizer): optimizer instance. | ||||
kwargs: Keyword parameters used for creating a suppress model. | kwargs: Keyword parameters used for creating a suppress model. | ||||
Examples: | |||||
>>> networks_l5 = LeNet5() | |||||
>>> mask_layers = [] | |||||
>>> mask_layers.append(MaskLayerDes("conv1.weight", 0, False, True, 10)) | |||||
>>> suppress_ctrl_instance = SuppressPrivacyFactory().create(networks=networks_l5, | |||||
... mask_layers=mask_layers, | |||||
... policy="local_train", | |||||
... end_epoch=10, | |||||
... batch_num=(int)(10000/cfg.batch_size), | |||||
... start_epoch=3, | |||||
... mask_times=1000, | |||||
... lr=lr, | |||||
... sparse_end=0.90, | |||||
... sparse_start=0.0) | |||||
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | |||||
>>> net_opt = nn.Momentum(params=networks_l5.trainable_params(), learning_rate=lr, momentum=0.0) | |||||
>>> config_ck = CheckpointConfig(save_checkpoint_steps=(int)(samples/cfg.batch_size), keep_checkpoint_max=10) | |||||
>>> model_instance = SuppressModel(network=networks_l5, | |||||
... loss_fn=net_loss, | |||||
... optimizer=net_opt, | |||||
... metrics={"Accuracy": Accuracy()}) | |||||
>>> model_instance.link_suppress_ctrl(suppress_ctrl_instance) | |||||
>>> ds_train = generate_mnist_dataset("./MNIST_unzip/train", | |||||
... batch_size=cfg.batch_size, repeat_size=1, samples=samples) | |||||
>>> ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", | |||||
... directory="./trained_ckpt_file/", | |||||
... config=config_ck) | |||||
>>> model_instance.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(), suppress_masker], | |||||
... dataset_sink_mode=False) | |||||
""" | """ | ||||
def __init__(self, | def __init__(self, | ||||