|
|
@@ -0,0 +1,90 @@ |
|
|
|
""" |
|
|
|
######################## train lenet example ######################## |
|
|
|
Train LeNet and save the resulting network model checkpoint files (.ckpt).
|
|
|
""" |
|
|
|
#!/usr/bin/python |
|
|
|
#coding=utf-8 |
|
|
|
|
|
|
|
import os |
|
|
|
import argparse |
|
|
|
from config import mnist_cfg as cfg |
|
|
|
from dataset import create_dataset |
|
|
|
from lenet import LeNet5 |
|
|
|
import mindspore.nn as nn |
|
|
|
from mindspore import context |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor |
|
|
|
from mindspore.train import Model |
|
|
|
from mindspore.nn.metrics import Accuracy |
|
|
|
from mindspore.common import set_seed |
|
|
|
|
|
|
|
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')

# Device on which training runs. The default named inside the help text must
# match ``default=`` below ("Ascend"), otherwise ``--help`` misleads the user.
# The trailing Chinese note in the help text says: to use an NPU on the OpenI
# platform, add the run parameter device_target=Ascend on the platform's
# training page.
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend),若要在启智平台上使用NPU,需要在启智平台训练界面上加上运行参数device_target=Ascend')

# Number of training epochs (defaults to 5).
parser.add_argument(
    '--epoch_size',
    type=int,
    default=5,
    help='Training epochs.')
|
|
|
|
|
|
|
# Set the global random seed to a fixed value (runs on import as well).
set_seed(1)

if __name__ == "__main__":
    args = parser.parse_args()
    print('args:')
    print(args)

    # Fixed locations: where checkpoints are written and where data is read.
    output_root = '/cache/output'
    dataset_root = '/cache/dataset'

    # Important: this selects the device used for training
    # (CPU or Ascend NPU).
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    # Create the training dataset from the "train" sub-directory.
    train_path = os.path.join(dataset_root, "train")
    ds_train = create_dataset(train_path, cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    # Create the network together with its loss function and optimizer.
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())

    # Both targets share the same model setup; amp_level="O2" is passed
    # only when training on Ascend.
    model_options = {"metrics": {"accuracy": Accuracy()}}
    if args.device_target == "Ascend":
        model_options["amp_level"] = "O2"
    model = Model(network, net_loss, net_opt, **model_options)

    ckpt_settings = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max,
    )
    # Define the model output path (checkpoints go to output_root).
    ckpoint_cb = ModelCheckpoint(
        prefix="checkpoint_lenet",
        directory=output_root,
        config=ckpt_settings,
    )

    # Start training.
    print("============== Starting Training ==============")
    # The config value is read first; a truthy (non-zero) --epoch_size
    # from the command line overrides it.
    configured_epochs = cfg['epoch_size']
    epoch_size = args.epoch_size if args.epoch_size else configured_epochs
    print('epoch_size is: ', epoch_size)

    training_callbacks = [time_cb, ckpoint_cb, LossMonitor()]
    model.train(epoch_size, ds_train, callbacks=training_callbacks)

    print("============== Finish Training ==============")