Browse Source

!2555 checkpoint add model_type

Merge pull request !2555 from chenzhongming/r0.5
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
fe1d4ca3bd
15 changed files with 136 additions and 90 deletions
  1. +11
    -0
      mindspore/_checkparam.py
  2. +1
    -0
      mindspore/ccsrc/utils/checkpoint.proto
  3. +19
    -9
      mindspore/train/callback/_checkpoint.py
  4. +3
    -3
      mindspore/train/callback/_loss_monitor.py
  5. +31
    -19
      mindspore/train/serialization.py
  6. +3
    -8
      model_zoo/lenet/eval.py
  7. +8
    -8
      model_zoo/lenet_quant/README.md
  8. +7
    -6
      model_zoo/lenet_quant/eval.py
  9. +8
    -8
      model_zoo/lenet_quant/eval_quant.py
  10. +2
    -2
      model_zoo/lenet_quant/src/lenet.py
  11. +3
    -2
      model_zoo/lenet_quant/src/lenet_fusion.py
  12. +12
    -4
      model_zoo/lenet_quant/train.py
  13. +15
    -8
      model_zoo/lenet_quant/train_quant.py
  14. +1
    -1
      tests/ut/python/predict/test_predict_save_model.py
  15. +12
    -12
      tests/ut/python/utils/test_serialize.py

+ 11
- 0
mindspore/_checkparam.py View File

@@ -593,6 +593,17 @@ def check_bool(input_param):
raise TypeError("Input type must be bool!")


def check_string(input_param, valid_values):
"""String type judgment."""
if isinstance(input_param, str) and input_param in valid_values:
return input_param
if len(valid_values) == 1:
raise ValueError(f'Input should be str and must be {valid_values[0]},'
f' but got {input_param}.')
raise ValueError(f'Input should be str and must be one of {valid_values},'
f' but got {input_param}.')


def check_input_format(input_param):
"""Judge input format."""
if input_param == "NCHW":


+ 1
- 0
mindspore/ccsrc/utils/checkpoint.proto View File

@@ -22,6 +22,7 @@ message Checkpoint {
required TensorProto tensor = 2;
}
repeated Value value = 1;
required string model_type = 2;
}




+ 19
- 9
mindspore/train/callback/_checkpoint.py View File

@@ -21,17 +21,16 @@ import time

import mindspore.context as context
from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore._checkparam import check_bool, check_string, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph

from ._callback import Callback, set_cur_net


_cur_dir = os.getcwd()
_save_dir = _cur_dir



def _check_file_name_prefix(file_name_prefix):
"""
Check file name valid or not.
@@ -87,6 +86,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".

Raises:
ValueError: If the input_param is None or 0.
@@ -101,7 +101,8 @@ class CheckpointConfig:
save_checkpoint_seconds=0,
keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0,
integrated_save=True):
integrated_save=True,
model_type="normal"):

if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -115,6 +116,8 @@ class CheckpointConfig:
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
if keep_checkpoint_per_n_minutes:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
if model_type:
model_type = check_string(model_type, ["normal", "fusion", "quant"])

self._save_checkpoint_steps = save_checkpoint_steps
self._save_checkpoint_seconds = save_checkpoint_seconds
@@ -129,6 +132,7 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1

self._model_type = model_type
self._integrated_save = check_bool(integrated_save)

@property
@@ -156,12 +160,18 @@ class CheckpointConfig:
"""Get the value of _integrated_save."""
return self._integrated_save

@property
def model_type(self):
"""Get the value of model_type."""
return self._model_type

def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
'save_checkpoint_seconds': self._save_checkpoint_seconds,
'keep_checkpoint_max': self._keep_checkpoint_max,
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
'model_type': self._model_type}

return checkpoint_policy

@@ -226,7 +236,7 @@ class ModelCheckpoint(Callback):
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
_save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True
self._save_ckpt(cb_params)
self._save_ckpt(cb_params, self._config.model_type)

def end(self, run_context):
"""
@@ -237,7 +247,7 @@ class ModelCheckpoint(Callback):
"""
cb_params = run_context.original_args()
_to_save_last_ckpt = True
self._save_ckpt(cb_params, _to_save_last_ckpt)
self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)

from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell()
@@ -256,7 +266,7 @@ class ModelCheckpoint(Callback):

return False

def _save_ckpt(self, cb_params, force_to_save=False):
def _save_ckpt(self, cb_params, model_type, force_to_save=False):
"""Save checkpoint files."""
if cb_params.cur_step_num == self._last_triggered_step:
return
@@ -292,7 +302,7 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()

_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
_exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)

if os.path.exists(gen_file):
shutil.move(gen_file, cur_file)


+ 3
- 3
mindspore/train/callback/_loss_monitor.py View File

@@ -76,7 +76,7 @@ class LossMonitor(Callback):
step_loss = np.mean(step_loss.asnumpy())

self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1

if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
@@ -87,7 +87,7 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
cur_step_in_epoch, cb_params.batch_num,
cb_params.cur_epoch_num, cb_params.epoch_num,
cur_step_in_epoch, int(cb_params.batch_num),
step_loss, np.mean(self.losses),
step_mseconds), flush=True)

+ 31
- 19
mindspore/train/serialization.py View File

@@ -29,6 +29,7 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data


__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]

tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}

ModelType = ["normal", "fusion", "quant"]


def _special_process_par(par, new_par):
"""
@@ -101,20 +104,22 @@ def _update_param(param, new_param):
param.set_parameter_data(type(param.data)(new_param.data))


def save_checkpoint(parameter_list, ckpoint_file_name):
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
"""
Saves checkpoint info to a specified file.

Args:
parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type. Default: "normal".

Raises:
RuntimeError: Failed to save the Checkpoint file.
"""
logger.info("Execute save checkpoint process.")
checkpoint_list = Checkpoint()
checkpoint_list.model_type = model_type

try:
for param in parameter_list:
@@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
for dim in param['data'].shape:
param_tensor.dims.append(dim)

with open(ckpoint_file_name, "wb") as f:
with open(ckpt_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpoint_file_name, stat.S_IRUSR)
os.chmod(ckpt_file_name, stat.S_IRUSR)

except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
raise RuntimeError(e.__str__())
logger.info("Save checkpoint process finish.")


def load_checkpoint(ckpoint_file_name, net=None):
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
"""
Loads checkpoint info from a specified file.

Args:
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
net (Cell): Cell network. Default: None

Returns:
@@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
Raises:
ValueError: Checkpoint file is incorrect.
"""
if not isinstance(ckpoint_file_name, str):
raise ValueError("The ckpoint_file_name must be String.")
if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.")

if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
if model_type not in ModelType:
raise ValueError(f"The model_type is not in {ModelType}.")

if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.")

if os.path.getsize(ckpoint_file_name) == 0:
if os.path.getsize(ckpt_file_name) == 0:
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")

logger.info("Execute load checkpoint process.")
checkpoint_list = Checkpoint()

try:
with open(ckpoint_file_name, "rb") as f:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__())

parameter_dict = {}

if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
try:
for element in checkpoint_list.value:
data = element.tensor.tensor_content
@@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
logger.info("Load checkpoint process finish.")

except BaseException as e:
logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__())

if net:
@@ -303,14 +314,15 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)


def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
"""
Saves checkpoint for 'ms' backend.

Args:
train_network (Network): The train network for training.
ckpoint_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
ckpt_file_name (str): The name of checkpoint file.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
"""

param_dict = {}
@@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
each_param["data"] = param_data
param_list.append(each_param)

save_checkpoint(param_list, ckpoint_file_name)
save_checkpoint(param_list, ckpt_file_name, model_type)


def _get_merged_param_data(net, param_name, param_data):


+ 3
- 8
model_zoo/lenet/eval.py View File

@@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt

import os
import argparse
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy

from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
@@ -49,9 +47,6 @@ if __name__ == "__main__":
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

print("============== Starting Testing ==============")


+ 8
- 8
model_zoo/lenet_quant/README.md View File

@@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following:
```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```

To save your time, just run this command.
@@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following:
```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```

### Evaluate quantization aware model
@@ -214,8 +214,8 @@ network = LeNet5Fusion(cfg.num_classes)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)

# convert funsion netwrok to aware quantizaiton network
network = quant.convert_quant_network(network
# convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network)
```

To save your time, just run this command.


+ 7
- 6
model_zoo/lenet_quant/eval.py View File

@@ -23,7 +23,6 @@ import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
@@ -47,16 +46,18 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
step_size = ds_eval.get_dataset_size()

# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)

# call back and monitor
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

param_dict = load_checkpoint(args.ckpt_path)
# load check point into network
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict)

print("============== Starting Testing ==============")


+ 8
- 8
model_zoo/lenet_quant/eval_quant.py View File

@@ -23,7 +23,6 @@ import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant
@@ -48,20 +47,21 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
step_size = ds_eval.get_dataset_size()

# define funsion network
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert funsion netwrok to aware quantizaiton network
# convert fusion netwrok to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)

# call back and monitor
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

# load aware quantizaiton network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
load_param_into_net(network, param_dict)

print("============== Starting Testing ==============")


+ 2
- 2
model_zoo/lenet_quant/src/lenet.py View File

@@ -34,8 +34,8 @@ class LeNet5(nn.Cell):
super(LeNet5, self).__init__()
self.num_class = num_class

self.conv1 = nn.Conv2d(channel, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, self.num_class)


+ 3
- 2
model_zoo/lenet_quant/src/lenet_fusion.py View File

@@ -32,11 +32,12 @@ class LeNet5(nn.Cell):

def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.type = "fusion"
self.num_class = num_class

# change `nn.Conv2d` to `nn.Conv2dBnAct`
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu')
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
# change `nn.Dense` to `nn.DenseBnAct`
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')


+ 12
- 4
model_zoo/lenet_quant/train.py View File

@@ -46,16 +46,24 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size()

# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type=network.type)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

# define model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============")

+ 15
- 8
model_zoo/lenet_quant/train_quant.py View File

@@ -48,23 +48,30 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size()

# define funsion network
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)

# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type="quant")
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

# define model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============")

+ 1
- 1
tests/ut/python/predict/test_predict_save_model.py View File

@@ -85,7 +85,7 @@ if __name__ == '__main__':

is_ckpt_exist = os.path.exists(ckpt_file_path)
if is_ckpt_exist:
param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path)
param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path)
load_param_into_net(net, param_dict)
export(net, input_data, file_name=model_path_name, file_format='LITE')
print("test lenet predict success.")


+ 12
- 12
tests/ut/python/utils/test_serialize.py View File

@@ -111,19 +111,19 @@ def test_save_checkpoint():
os.chmod('./parameters.ckpt', stat.S_IWRITE)
os.remove('./parameters.ckpt')

ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
save_checkpoint(parameter_list, ckpoint_file_name)
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
save_checkpoint(parameter_list, ckpt_file_name)


def test_load_checkpoint_error_filename():
ckpoint_file_name = 1
ckpt_file_name = 1
with pytest.raises(ValueError):
load_checkpoint(ckpoint_file_name)
load_checkpoint(ckpt_file_name)


def test_load_checkpoint():
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
par_dict = load_checkpoint(ckpoint_file_name)
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
par_dict = load_checkpoint(ckpt_file_name)

assert len(par_dict) == 3
assert par_dict['param_test'].name == 'param_test'
@@ -136,17 +136,17 @@ def test_checkpoint_manager():
""" test_checkpoint_manager """
ckp_mgr = _CheckpointManager()

ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
with open(ckpoint_file_name, 'w'):
os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
with open(ckpt_file_name, 'w'):
os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)

ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
assert ckp_mgr.ckpoint_num == 1

ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
ckp_mgr.remove_ckpoint_file(ckpt_file_name)
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
assert ckp_mgr.ckpoint_num == 0
assert not os.path.exists(ckpoint_file_name)
assert not os.path.exists(ckpt_file_name)

another_file_name = os.path.join(_cur_dir, './test2.ckpt')
another_file_name = os.path.realpath(another_file_name)
@@ -283,7 +283,7 @@ def test_exec_save_checkpoint():

loss_net = WithLossCell(net, loss)
train_network = TrainOneStepCell(loss_net, opt)
_exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")
_exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")

load_checkpoint("new_ckpt.ckpt")



Loading…
Cancel
Save