Merge pull request !2555 from chenzhongming/r0.5tags/v0.5.0-beta
| @@ -593,6 +593,17 @@ def check_bool(input_param): | |||
| raise TypeError("Input type must be bool!") | |||
| def check_string(input_param, valid_values): | |||
| """String type judgment.""" | |||
| if isinstance(input_param, str) and input_param in valid_values: | |||
| return input_param | |||
| if len(valid_values) == 1: | |||
| raise ValueError(f'Input should be str and must be {valid_values[0]},' | |||
| f' but got {input_param}.') | |||
| raise ValueError(f'Input should be str and must be one of {valid_values},' | |||
| f' but got {input_param}.') | |||
| def check_input_format(input_param): | |||
| """Judge input format.""" | |||
| if input_param == "NCHW": | |||
| @@ -22,6 +22,7 @@ message Checkpoint { | |||
| required TensorProto tensor = 2; | |||
| } | |||
| repeated Value value = 1; | |||
| required string model_type = 2; | |||
| } | |||
| @@ -21,17 +21,16 @@ import time | |||
| import mindspore.context as context | |||
| from mindspore import log as logger | |||
| from mindspore._checkparam import check_bool, check_int_non_negative | |||
| from mindspore._checkparam import check_bool, check_string, check_int_non_negative | |||
| from mindspore.train._utils import _make_directory | |||
| from mindspore.train.serialization import _exec_save_checkpoint, _save_graph | |||
| from ._callback import Callback, set_cur_net | |||
| _cur_dir = os.getcwd() | |||
| _save_dir = _cur_dir | |||
| def _check_file_name_prefix(file_name_prefix): | |||
| """ | |||
| Check file name valid or not. | |||
| @@ -87,6 +86,7 @@ class CheckpointConfig: | |||
| Can't be used with keep_checkpoint_max at the same time. | |||
| integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. | |||
| Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. | |||
| model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal". | |||
| Raises: | |||
| ValueError: If the input_param is None or 0. | |||
| @@ -101,7 +101,8 @@ class CheckpointConfig: | |||
| save_checkpoint_seconds=0, | |||
| keep_checkpoint_max=5, | |||
| keep_checkpoint_per_n_minutes=0, | |||
| integrated_save=True): | |||
| integrated_save=True, | |||
| model_type="normal"): | |||
| if not save_checkpoint_steps and not save_checkpoint_seconds and \ | |||
| not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: | |||
| @@ -115,6 +116,8 @@ class CheckpointConfig: | |||
| keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) | |||
| if keep_checkpoint_per_n_minutes: | |||
| keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) | |||
| if model_type: | |||
| model_type = check_string(model_type, ["normal", "fusion", "quant"]) | |||
| self._save_checkpoint_steps = save_checkpoint_steps | |||
| self._save_checkpoint_seconds = save_checkpoint_seconds | |||
| @@ -129,6 +132,7 @@ class CheckpointConfig: | |||
| if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: | |||
| self._keep_checkpoint_max = 1 | |||
| self._model_type = model_type | |||
| self._integrated_save = check_bool(integrated_save) | |||
| @property | |||
| @@ -156,12 +160,18 @@ class CheckpointConfig: | |||
| """Get the value of _integrated_save.""" | |||
| return self._integrated_save | |||
| @property | |||
| def model_type(self): | |||
| """Get the value of model_type.""" | |||
| return self._model_type | |||
| def get_checkpoint_policy(self): | |||
| """Get the policy of checkpoint.""" | |||
| checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, | |||
| 'save_checkpoint_seconds': self._save_checkpoint_seconds, | |||
| 'keep_checkpoint_max': self._keep_checkpoint_max, | |||
| 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes} | |||
| 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes, | |||
| 'model_type': self._model_type} | |||
| return checkpoint_policy | |||
| @@ -226,7 +236,7 @@ class ModelCheckpoint(Callback): | |||
| graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') | |||
| _save_graph(cb_params.train_network, graph_file_name) | |||
| self._graph_saved = True | |||
| self._save_ckpt(cb_params) | |||
| self._save_ckpt(cb_params, self._config.model_type) | |||
| def end(self, run_context): | |||
| """ | |||
| @@ -237,7 +247,7 @@ class ModelCheckpoint(Callback): | |||
| """ | |||
| cb_params = run_context.original_args() | |||
| _to_save_last_ckpt = True | |||
| self._save_ckpt(cb_params, _to_save_last_ckpt) | |||
| self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt) | |||
| from mindspore.parallel._cell_wrapper import destroy_allgather_cell | |||
| destroy_allgather_cell() | |||
| @@ -256,7 +266,7 @@ class ModelCheckpoint(Callback): | |||
| return False | |||
| def _save_ckpt(self, cb_params, force_to_save=False): | |||
| def _save_ckpt(self, cb_params, model_type, force_to_save=False): | |||
| """Save checkpoint files.""" | |||
| if cb_params.cur_step_num == self._last_triggered_step: | |||
| return | |||
| @@ -292,7 +302,7 @@ class ModelCheckpoint(Callback): | |||
| set_cur_net(cb_params.train_network) | |||
| cb_params.train_network.exec_checkpoint_graph() | |||
| _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) | |||
| _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save) | |||
| if os.path.exists(gen_file): | |||
| shutil.move(gen_file, cur_file) | |||
| @@ -76,7 +76,7 @@ class LossMonitor(Callback): | |||
| step_loss = np.mean(step_loss.asnumpy()) | |||
| self.losses.append(step_loss) | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num | |||
| cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 | |||
| if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): | |||
| raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " | |||
| @@ -87,7 +87,7 @@ class LossMonitor(Callback): | |||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | |||
| print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " | |||
| "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( | |||
| cb_params.cur_epoch_num - 1, cb_params.epoch_num, | |||
| cur_step_in_epoch, cb_params.batch_num, | |||
| cb_params.cur_epoch_num, cb_params.epoch_num, | |||
| cur_step_in_epoch, int(cb_params.batch_num), | |||
| step_loss, np.mean(self.losses), | |||
| step_mseconds), flush=True) | |||
| @@ -29,6 +29,7 @@ from mindspore.common.api import _executor | |||
| from mindspore.common import dtype as mstype | |||
| from mindspore._checkparam import check_input_data | |||
| __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] | |||
| tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, | |||
| @@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin | |||
| "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, | |||
| "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} | |||
| ModelType = ["normal", "fusion", "quant"] | |||
| def _special_process_par(par, new_par): | |||
| """ | |||
| @@ -101,20 +104,22 @@ def _update_param(param, new_param): | |||
| param.set_parameter_data(type(param.data)(new_param.data)) | |||
| def save_checkpoint(parameter_list, ckpoint_file_name): | |||
| def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"): | |||
| """ | |||
| Saves checkpoint info to a specified file. | |||
| Args: | |||
| parameter_list (list): Parameters list, each element is a dict | |||
| like {"name":xx, "type":xx, "shape":xx, "data":xx}. | |||
| ckpoint_file_name (str): Checkpoint file name. | |||
| ckpt_file_name (str): Checkpoint file name. | |||
| model_type (str): The name of model type. Default: "normal". | |||
| Raises: | |||
| RuntimeError: Failed to save the Checkpoint file. | |||
| """ | |||
| logger.info("Execute save checkpoint process.") | |||
| checkpoint_list = Checkpoint() | |||
| checkpoint_list.model_type = model_type | |||
| try: | |||
| for param in parameter_list: | |||
| @@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name): | |||
| for dim in param['data'].shape: | |||
| param_tensor.dims.append(dim) | |||
| with open(ckpoint_file_name, "wb") as f: | |||
| with open(ckpt_file_name, "wb") as f: | |||
| f.write(checkpoint_list.SerializeToString()) | |||
| os.chmod(ckpoint_file_name, stat.S_IRUSR) | |||
| os.chmod(ckpt_file_name, stat.S_IRUSR) | |||
| except BaseException as e: | |||
| logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name) | |||
| logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) | |||
| raise RuntimeError(e.__str__()) | |||
| logger.info("Save checkpoint process finish.") | |||
| def load_checkpoint(ckpoint_file_name, net=None): | |||
| def load_checkpoint(ckpt_file_name, model_type="normal", net=None): | |||
| """ | |||
| Loads checkpoint info from a specified file. | |||
| Args: | |||
| ckpoint_file_name (str): Checkpoint file name. | |||
| ckpt_file_name (str): Checkpoint file name. | |||
| model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal". | |||
| net (Cell): Cell network. Default: None | |||
| Returns: | |||
| @@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None): | |||
| Raises: | |||
| ValueError: Checkpoint file is incorrect. | |||
| """ | |||
| if not isinstance(ckpoint_file_name, str): | |||
| raise ValueError("The ckpoint_file_name must be String.") | |||
| if not isinstance(ckpt_file_name, str): | |||
| raise ValueError("The ckpt_file_name must be string.") | |||
| if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt": | |||
| if model_type not in ModelType: | |||
| raise ValueError(f"The model_type is not in {ModelType}.") | |||
| if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt": | |||
| raise ValueError("Please input the correct checkpoint file name.") | |||
| if os.path.getsize(ckpoint_file_name) == 0: | |||
| if os.path.getsize(ckpt_file_name) == 0: | |||
| raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.") | |||
| logger.info("Execute load checkpoint process.") | |||
| checkpoint_list = Checkpoint() | |||
| try: | |||
| with open(ckpoint_file_name, "rb") as f: | |||
| with open(ckpt_file_name, "rb") as f: | |||
| pb_content = f.read() | |||
| checkpoint_list.ParseFromString(pb_content) | |||
| except BaseException as e: | |||
| logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name) | |||
| logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name) | |||
| raise ValueError(e.__str__()) | |||
| parameter_dict = {} | |||
| if model_type != checkpoint_list.model_type: | |||
| raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( | |||
| checkpoint_list.model_type, model_type)) | |||
| try: | |||
| for element in checkpoint_list.value: | |||
| data = element.tensor.tensor_content | |||
| @@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None): | |||
| logger.info("Load checkpoint process finish.") | |||
| except BaseException as e: | |||
| logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name) | |||
| logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) | |||
| raise RuntimeError(e.__str__()) | |||
| if net: | |||
| @@ -303,14 +314,15 @@ def _save_graph(network, file_name): | |||
| os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) | |||
| def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True): | |||
| def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True): | |||
| """ | |||
| Saves checkpoint for 'ms' backend. | |||
| Args: | |||
| train_network (Network): The train network for training. | |||
| ckpoint_file_name (str): The name of checkpoint file. | |||
| integrated_save (bool): Whether to intergrated save in automatic model parallel scene. | |||
| ckpt_file_name (str): The name of checkpoint file. | |||
| model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal". | |||
| integrated_save (bool): Whether to integrated save in automatic model parallel scene. | |||
| """ | |||
| param_dict = {} | |||
| @@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True | |||
| each_param["data"] = param_data | |||
| param_list.append(each_param) | |||
| save_checkpoint(param_list, ckpoint_file_name) | |||
| save_checkpoint(param_list, ckpt_file_name, model_type) | |||
| def _get_merged_param_data(net, param_name, param_data): | |||
| @@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt | |||
| import os | |||
| import argparse | |||
| from src.dataset import create_dataset | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet import LeNet5 | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from src.dataset import create_dataset | |||
| from src.config import mnist_cfg as cfg | |||
| from src.lenet import LeNet5 | |||
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser(description='MindSpore Lenet Example') | |||
| @@ -49,9 +47,6 @@ if __name__ == "__main__": | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| repeat_size = cfg.epoch_size | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| print("============== Starting Testing ==============") | |||
| @@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following: | |||
| ```bash | |||
| >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234] | |||
| >>> ... | |||
| >>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] | |||
| ``` | |||
| To save your time, just run this command. | |||
| @@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following: | |||
| ```bash | |||
| >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234] | |||
| >>> ... | |||
| >>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] | |||
| >>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] | |||
| ``` | |||
| ### Evaluate quantization aware model | |||
| @@ -214,8 +214,8 @@ network = LeNet5Fusion(cfg.num_classes) | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| load_param_into_net(network, param_dict) | |||
| # convert funsion netwrok to aware quantizaiton network | |||
| network = quant.convert_quant_network(network | |||
| # convert funsion netwrok to quantization aware network | |||
| network = quant.convert_quant_network(network) | |||
| ``` | |||
| To save your time, just run this command. | |||
| @@ -23,7 +23,6 @@ import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from src.dataset import create_dataset | |||
| @@ -47,16 +46,18 @@ if __name__ == "__main__": | |||
| ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) | |||
| step_size = ds_eval.get_dataset_size() | |||
| # define fusion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # define loss | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| repeat_size = cfg.epoch_size | |||
| # define network optimization | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| # call back and monitor | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| # load check point into network | |||
| param_dict = load_checkpoint(args.ckpt_path, network.type) | |||
| load_param_into_net(network, param_dict) | |||
| print("============== Starting Testing ==============") | |||
| @@ -23,7 +23,6 @@ import argparse | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train import Model | |||
| from mindspore.nn.metrics import Accuracy | |||
| from mindspore.train.quant import quant | |||
| @@ -48,20 +47,21 @@ if __name__ == "__main__": | |||
| ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) | |||
| step_size = ds_eval.get_dataset_size() | |||
| # define funsion network | |||
| # define fusion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # convert funsion netwrok to aware quantizaiton network | |||
| # convert fusion netwrok to quantization aware network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| # define loss | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| # define network optimization | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| # call back and monitor | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| # load aware quantizaiton network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| # load quantization aware network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path, model_type="quant") | |||
| load_param_into_net(network, param_dict) | |||
| print("============== Starting Testing ==============") | |||
| @@ -34,8 +34,8 @@ class LeNet5(nn.Cell): | |||
| super(LeNet5, self).__init__() | |||
| self.num_class = num_class | |||
| self.conv1 = nn.Conv2d(channel, 6, 5) | |||
| self.conv2 = nn.Conv2d(6, 16, 5) | |||
| self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') | |||
| self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') | |||
| self.fc1 = nn.Dense(16 * 5 * 5, 120) | |||
| self.fc2 = nn.Dense(120, 84) | |||
| self.fc3 = nn.Dense(84, self.num_class) | |||
| @@ -32,11 +32,12 @@ class LeNet5(nn.Cell): | |||
| def __init__(self, num_class=10, channel=1): | |||
| super(LeNet5, self).__init__() | |||
| self.type = "fusion" | |||
| self.num_class = num_class | |||
| # change `nn.Conv2d` to `nn.Conv2dBnAct` | |||
| self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu') | |||
| self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu') | |||
| self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') | |||
| self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') | |||
| # change `nn.Dense` to `nn.DenseBnAct` | |||
| self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') | |||
| self.fc2 = nn.DenseBnAct(120, 84, activation='relu') | |||
| @@ -46,16 +46,24 @@ if __name__ == "__main__": | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) | |||
| step_size = ds_train.get_dataset_size() | |||
| # define fusion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # define network loss | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| # define network optimization | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| # call back and monitor | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max, | |||
| model_type=network.type) | |||
| ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) | |||
| # define model | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], | |||
| dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== End Training ==============") | |||
| @@ -48,23 +48,30 @@ if __name__ == "__main__": | |||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) | |||
| step_size = ds_train.get_dataset_size() | |||
| # define funsion network | |||
| # define fusion network | |||
| network = LeNet5Fusion(cfg.num_classes) | |||
| # load aware quantizaiton network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path) | |||
| # load quantization aware network checkpoint | |||
| param_dict = load_checkpoint(args.ckpt_path, network.type) | |||
| load_param_into_net(network, param_dict) | |||
| # convert funsion netwrok to aware quantizaiton network | |||
| # convert fusion network to quantization aware network | |||
| network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) | |||
| # define network loss | |||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") | |||
| # define network optimization | |||
| net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) | |||
| # call back and monitor | |||
| time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max) | |||
| ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) | |||
| config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, | |||
| keep_checkpoint_max=cfg.keep_checkpoint_max, | |||
| model_type="quant") | |||
| ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) | |||
| # define model | |||
| model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) | |||
| print("============== Starting Training ==============") | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], | |||
| model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], | |||
| dataset_sink_mode=args.dataset_sink_mode) | |||
| print("============== End Training ==============") | |||
| @@ -85,7 +85,7 @@ if __name__ == '__main__': | |||
| is_ckpt_exist = os.path.exists(ckpt_file_path) | |||
| if is_ckpt_exist: | |||
| param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path) | |||
| param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path) | |||
| load_param_into_net(net, param_dict) | |||
| export(net, input_data, file_name=model_path_name, file_format='LITE') | |||
| print("test lenet predict success.") | |||
| @@ -111,19 +111,19 @@ def test_save_checkpoint(): | |||
| os.chmod('./parameters.ckpt', stat.S_IWRITE) | |||
| os.remove('./parameters.ckpt') | |||
| ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt') | |||
| save_checkpoint(parameter_list, ckpoint_file_name) | |||
| ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') | |||
| save_checkpoint(parameter_list, ckpt_file_name) | |||
| def test_load_checkpoint_error_filename(): | |||
| ckpoint_file_name = 1 | |||
| ckpt_file_name = 1 | |||
| with pytest.raises(ValueError): | |||
| load_checkpoint(ckpoint_file_name) | |||
| load_checkpoint(ckpt_file_name) | |||
| def test_load_checkpoint(): | |||
| ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt') | |||
| par_dict = load_checkpoint(ckpoint_file_name) | |||
| ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') | |||
| par_dict = load_checkpoint(ckpt_file_name) | |||
| assert len(par_dict) == 3 | |||
| assert par_dict['param_test'].name == 'param_test' | |||
| @@ -136,17 +136,17 @@ def test_checkpoint_manager(): | |||
| """ test_checkpoint_manager """ | |||
| ckp_mgr = _CheckpointManager() | |||
| ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt') | |||
| with open(ckpoint_file_name, 'w'): | |||
| os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR) | |||
| ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt') | |||
| with open(ckpt_file_name, 'w'): | |||
| os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR) | |||
| ckp_mgr.update_ckpoint_filelist(_cur_dir, "test") | |||
| assert ckp_mgr.ckpoint_num == 1 | |||
| ckp_mgr.remove_ckpoint_file(ckpoint_file_name) | |||
| ckp_mgr.remove_ckpoint_file(ckpt_file_name) | |||
| ckp_mgr.update_ckpoint_filelist(_cur_dir, "test") | |||
| assert ckp_mgr.ckpoint_num == 0 | |||
| assert not os.path.exists(ckpoint_file_name) | |||
| assert not os.path.exists(ckpt_file_name) | |||
| another_file_name = os.path.join(_cur_dir, './test2.ckpt') | |||
| another_file_name = os.path.realpath(another_file_name) | |||
| @@ -283,7 +283,7 @@ def test_exec_save_checkpoint(): | |||
| loss_net = WithLossCell(net, loss) | |||
| train_network = TrainOneStepCell(loss_net, opt) | |||
| _exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt") | |||
| _exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt") | |||
| load_checkpoint("new_ckpt.ckpt") | |||