Browse Source

更新 'npu/train_for_c2net_testcopy2.py'

test_v20221116
wjtest1215 2 years ago
parent
commit
d52171723e
1 changed files with 92 additions and 0 deletions
  1. +92
    -0
      npu/train_for_c2net_testcopy2.py

+ 92
- 0
npu/train_for_c2net_testcopy2.py View File

@@ -0,0 +1,92 @@
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt)
"""
#!/usr/bin/python
#coding=utf-8

import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed

parser = argparse.ArgumentParser(description='MindSpore Lenet Example')

parser.add_argument(
'--device_target',
type=str,
default="Ascend",
choices=['Ascend', 'CPU'],
help='device where the code will be implemented (default: CPU),若要在启智平台上使用NPU,需要在启智平台训练界面上加上运行参数device_target=Ascend')

parser.add_argument('--epoch_size',
type=int,
default=5,
help='Training epochs.')

set_seed(1)

if __name__ == "__main__":
args = parser.parse_args()
print('args:')
print(args)

train_dir = '/cache/output'
data_dir = '/cache/dataset'
#注意:这里很重要,指定了训练所用的设备CPU还是Ascend NPU
context.set_context(mode=context.GRAPH_MODE,
device_target=args.device_target)
#创建数据集
ds_train = create_dataset(os.path.join(data_dir, "train"),
cfg.batch_size)
if ds_train.get_dataset_size() == 0:
raise ValueError(
"Please check dataset size > 0 and batch_size <= dataset size")
#创建网络
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())

if args.device_target != "Ascend":
model = Model(network,
net_loss,
net_opt,
metrics={"accuracy": Accuracy()})
else:
model = Model(network,
net_loss,
net_opt,
metrics={"accuracy": Accuracy()},
amp_level="O2")

config_ck = CheckpointConfig(
save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
#定义模型输出路径
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory=train_dir,
config=config_ck)
#开始训练
print("============== Starting Training ==============")
epoch_size = cfg['epoch_size']
if (args.epoch_size):
epoch_size = args.epoch_size
print('epoch_size is: ', epoch_size)
# 测试代码。结果回传
os.system("cd /cache/script_for_grampus/ &&./uploader_for_npu " + "/cache/code")

model.train(epoch_size,
ds_train,
callbacks=[time_cb, ckpoint_cb,
LossMonitor()])

print("============== Finish Training ==============")

Loading…
Cancel
Save