From c17ba07e7020602c4f17579af1bb36a1a9d34405 Mon Sep 17 00:00:00 2001
From: wjtest001
Date: Wed, 8 Mar 2023 15:47:42 +0800
Subject: [PATCH] add continue scripts

---
 gpu/train_continue.py           | 121 +++++++++++++++++++
 gpu/train_continue_c2net.py     | 122 ++++++++++++++++++++
 npu/lewis/c2net_npu_continue.py | 196 +++++++++++++++++++++++++++++++
 npu/train_continue.py           | 199 ++++++++++++++++++++++++++++++++
 4 files changed, 638 insertions(+)
 create mode 100755 gpu/train_continue.py
 create mode 100755 gpu/train_continue_c2net.py
 create mode 100755 npu/lewis/c2net_npu_continue.py
 create mode 100755 npu/train_continue.py

diff --git a/gpu/train_continue.py b/gpu/train_continue.py
new file mode 100755
index 0000000..91ee495
--- /dev/null
+++ b/gpu/train_continue.py
@@ -0,0 +1,121 @@
+#####################################################################################################
+# Continue-training feature: when a training task is modified with "reuse last result" ticked,
+# the previous task's output can be read from the new task's output path.
+#
+# Example usage
+# - Add two training parameters:
+#     'ckpt_save_name'  name of the model file written by the current task
+#     'ckpt_load_name'  name of the model file written by the previous task; defaults to empty,
+#                       in which case no file is loaded
+# - The training code checks whether 'ckpt_load_name' is empty; if it is not, this run is a
+#   continue-training task.
+#####################################################################################################
+
+from model import Model
+import numpy as np
+import torch
+from torchvision.datasets import mnist
+from torch.nn import CrossEntropyLoss
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+from torchvision.transforms import ToTensor
+import argparse
+import os
+
+# Training settings
+parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+# The dataset is mounted under /dataset
+parser.add_argument('--traindata', default="/dataset/train", help='path to train dataset')
+parser.add_argument('--testdata', default="/dataset/test", help='path to test dataset')
+parser.add_argument('--epoch_size', type=int, default=10, help='number of epochs to train')
+parser.add_argument('--batch_size', type=int, default=256, help='batch size per training step')
+# Pretrained model file path
+parser.add_argument('--ckpt_url', default="", help='pretrain model path')
+# Checkpoint names used for continue training
+parser.add_argument('--ckpt_save_name', default="", help='save model name')
+parser.add_argument('--ckpt_load_name', default="", help='load model name')
+
+# Parameter declarations
+WORKERS = 0  # number of DataLoader worker threads
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = Model().to(device)
+optimizer = SGD(model.parameters(), lr=1e-1)
+cost = CrossEntropyLoss()
+
+# Model training
+def train(model, train_loader, epoch):
+    model.train()
+    train_loss = 0
+    for i, data in enumerate(train_loader, 0):
+        x, y = data
+        x = x.to(device)
+        y = y.to(device)
+        optimizer.zero_grad()
+        y_hat = model(x)
+        loss = cost(y_hat, y)
+        loss.backward()
+        optimizer.step()
+        train_loss += loss.item()
+    loss_mean = train_loss / (i + 1)
+    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))
+
+# Model evaluation
+def test(model, test_loader, test_data):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for i, data in enumerate(test_loader, 0):
+            x, y = data
+            x = x.to(device)
+            y = y.to(device)
+            y_hat = model(x)
+            test_loss += cost(y_hat, y).item()
+            pred = y_hat.max(1, keepdim=True)[1]
+            correct += pred.eq(y.view_as(pred)).sum().item()
+        test_loss /= (i + 1)
+        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
+
+def main():
+    base_path = "/model"  # change to "/tmp/output" when running on the intelligent computing (C2Net) cluster
+
+    # Load the pretrained model; this only applies to the first task, i.e. when args.ckpt_load_name is empty
+    if os.path.exists(args.ckpt_url) and not args.ckpt_load_name:
+        checkpoint = torch.load(args.ckpt_url)
+        model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        start_epoch = checkpoint['epoch']
+        print('Loaded pretrained weights from epoch {}.'.format(start_epoch))
+    # Load the continue-training checkpoint; requires the previous task to have written an output file
+    elif args.ckpt_load_name:
+        load_path = "{}/{}.pkl".format(base_path, args.ckpt_load_name)
+        checkpoint = torch.load(load_path)
+        model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        start_epoch = checkpoint['epoch']
+        print('Loaded continue-training weights from epoch {}.'.format(start_epoch))
+    else:
+        print('No saved model found, training from scratch.')
+
+    for epoch in range(epochs):
+        train(model, train_loader, epoch)
+        test(model, test_loader, test_dataset)
+        # Save the model together with the optimizer state and the current epoch index
+        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
+        save_path = "{}/{}.pkl".format(base_path, args.ckpt_save_name)
+        torch.save(state, save_path)
+
+
+if __name__ == '__main__':
+    args, unknown = parser.parse_known_args()
+    # log output
+    print('cuda is available:{}'.format(torch.cuda.is_available()))
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    batch_size = args.batch_size
+    epochs = args.epoch_size
+    train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(), download=False)
+    test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(), download=False)
+    train_loader = DataLoader(train_dataset, batch_size=batch_size)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size)
+    main()
+
+
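+# ---------------------------------------------------------------------------------------------------
+# Optional sketch (not called anywhere in this script): the checkpoint saved in main() already stores
+# the last finished 'epoch', but the training loop above always restarts at epoch 0. If resuming the
+# epoch counter is wanted, a helper along these lines could be used; the .pkl layout matches the
+# 'state' dict saved above, everything else here is an assumption.
+def resume_start_epoch(base_path, ckpt_load_name):
+    """Return the epoch index to resume from, or 0 when no previous checkpoint exists."""
+    load_path = "{}/{}.pkl".format(base_path, ckpt_load_name)
+    if ckpt_load_name and os.path.exists(load_path):
+        # the saved dict holds 'model', 'optimizer' and the last epoch index under 'epoch'
+        return torch.load(load_path)['epoch'] + 1
+    return 0
+# The loop in main() could then read: for epoch in range(resume_start_epoch(base_path, args.ckpt_load_name), epochs)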
diff --git a/gpu/train_continue_c2net.py b/gpu/train_continue_c2net.py
new file mode 100755
index 0000000..196981e
--- /dev/null
+++ b/gpu/train_continue_c2net.py
@@ -0,0 +1,122 @@
+#####################################################################################################
+# Continue-training feature: when a training task is modified with "reuse last result" ticked,
+# the previous task's output can be read from the new task's output path.
+#
+# Example usage
+# - Add two training parameters:
+#     'ckpt_save_name'  name of the model file written by the current task
+#     'ckpt_load_name'  name of the model file written by the previous task; defaults to empty,
+#                       in which case no file is loaded
+# - The training code checks whether 'ckpt_load_name' is empty; if it is not, this run is a
+#   continue-training task.
+#####################################################################################################
+
+from model import Model
+import numpy as np
+import torch
+from torchvision.datasets import mnist
+from torch.nn import CrossEntropyLoss
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+from torchvision.transforms import ToTensor
+import argparse
+import os
+
+# Training settings
+parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+# On the intelligent computing (C2Net) cluster the dataset is mounted under /tmp/dataset
+parser.add_argument('--traindata', default="/tmp/dataset/train", help='path to train dataset')
+parser.add_argument('--testdata', default="/tmp/dataset/test", help='path to test dataset')
+parser.add_argument('--epoch_size', type=int, default=10, help='number of epochs to train')
+parser.add_argument('--batch_size', type=int, default=256, help='batch size per training step')
+# Pretrained model file path
+parser.add_argument('--ckpt_url', default="", help='pretrain model path')
+# Checkpoint names used for continue training
+parser.add_argument('--ckpt_save_name', default="", help='save model name')
+parser.add_argument('--ckpt_load_name', default="", help='load model name')
+
+# Parameter declarations
+WORKERS = 0  # number of DataLoader worker threads
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+model = Model().to(device)
+optimizer = SGD(model.parameters(), lr=1e-1)
+cost = CrossEntropyLoss()
+
+# Model training
+def train(model, train_loader, epoch):
+    model.train()
+    train_loss = 0
+    for i, data in enumerate(train_loader, 0):
+        x, y = data
+        x = x.to(device)
+        y = y.to(device)
+        optimizer.zero_grad()
+        y_hat = model(x)
+        loss = cost(y_hat, y)
+        loss.backward()
+        optimizer.step()
+        train_loss += loss.item()
+    loss_mean = train_loss / (i + 1)
+    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))
+
+# Model evaluation
+def test(model, test_loader, test_data):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for i, data in enumerate(test_loader, 0):
+            x, y = data
+            x = x.to(device)
+            y = y.to(device)
+            y_hat = model(x)
+            test_loss += cost(y_hat, y).item()
+            pred = y_hat.max(1, keepdim=True)[1]
+            correct += pred.eq(y.view_as(pred)).sum().item()
+        test_loss /= (i + 1)
+        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
+
+def main():
+    base_path = "/tmp/output"  # output path on the intelligent computing (C2Net) cluster
+
+    # Load the pretrained model; this only applies to the first task, i.e. when args.ckpt_load_name is empty
+    if os.path.exists(args.ckpt_url) and not args.ckpt_load_name:
+        checkpoint = torch.load(args.ckpt_url)
+        model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        start_epoch = checkpoint['epoch']
+        print('Loaded pretrained weights from epoch {}.'.format(start_epoch))
+    # Load the continue-training checkpoint; requires the previous task to have written an output file
+    elif args.ckpt_load_name:
+        load_path = "{}/{}.pkl".format(base_path, args.ckpt_load_name)
+        checkpoint = torch.load(load_path)
+        model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        start_epoch = checkpoint['epoch']
+        print('Loaded continue-training weights from epoch {}.'.format(start_epoch))
+    else:
+        print('No saved model found, training from scratch.')
+
+    for epoch in range(epochs):
+        train(model, train_loader, epoch)
+        test(model, test_loader, test_dataset)
+        # Save the model together with the optimizer state and the current epoch index
+        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
+        save_path = "{}/{}.pkl".format(base_path, args.ckpt_save_name)
+        torch.save(state, save_path)
+        # sync the output directory back to the platform with the grampus uploader
+        os.system("cd /tmp/script_for_grampus/ && ./uploader_for_gpu " + "/tmp/output/")
+
+
+if __name__ == '__main__':
+    args, unknown = parser.parse_known_args()
+    # log output
+    print('cuda is available:{}'.format(torch.cuda.is_available()))
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    batch_size = args.batch_size
+    epochs = args.epoch_size
+    train_dataset = mnist.MNIST(root=args.traindata, train=True, transform=ToTensor(), download=False)
+    test_dataset = mnist.MNIST(root=args.testdata, train=False, transform=ToTensor(), download=False)
+    train_loader = DataLoader(train_dataset, batch_size=batch_size)
+    test_loader = DataLoader(test_dataset, batch_size=batch_size)
+    main()
+
+
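+# ---------------------------------------------------------------------------------------------------
+# Optional sketch (not called anywhere in this script): a guarded wrapper around the grampus uploader
+# used in main(), so the script can also run outside the cluster image where
+# /tmp/script_for_grampus/uploader_for_gpu does not exist. The uploader path and output directory are
+# taken from the os.system() call above; the skip-and-continue behaviour is an assumption.
+def upload_output_if_available(output_dir="/tmp/output/"):
+    """Invoke the platform uploader when present; otherwise just report that syncing was skipped."""
+    uploader = "/tmp/script_for_grampus/uploader_for_gpu"
+    if os.path.exists(uploader):
+        return os.system("cd /tmp/script_for_grampus/ && ./uploader_for_gpu " + output_dir)
+    print("uploader not found, skipping sync of {}".format(output_dir))
+    return 0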
diff --git a/npu/lewis/c2net_npu_continue.py b/npu/lewis/c2net_npu_continue.py
new file mode 100755
index 0000000..d17c47d
--- /dev/null
+++ b/npu/lewis/c2net_npu_continue.py
@@ -0,0 +1,196 @@
+#!/usr/bin/python
+# coding=utf-8
+"""
+######################## train lenet example ########################
+Train LeNet and save the network model files (.ckpt).
+"""
+
+import os
+import argparse
+
+import moxing as mox
+from config import mnist_cfg as cfg
+from dataset import create_dataset
+from dataset_distributed import create_dataset_parallel
+from lenet import LeNet5
+import json
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from mindspore import load_checkpoint, load_param_into_net
+from mindspore.context import ParallelMode
+from mindspore.communication.management import init, get_rank
+import time
+
+### Copy multiple datasets from obs to the training image and unzip them ###
+def C2netMultiObsToEnv(multi_data_url, data_dir):
+    # --multi_data_url is a json string, so parse it first
+    multi_data_json = json.loads(multi_data_url)
+    for i in range(len(multi_data_json)):
+        zipfile_path = data_dir + "/" + multi_data_json[i]["dataset_name"]
+        try:
+            mox.file.copy(multi_data_json[i]["dataset_url"], zipfile_path)
+            print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"], zipfile_path))
+            # get the filename and unzip the dataset
+            filename = os.path.splitext(multi_data_json[i]["dataset_name"])[0]
+            filePath = data_dir + "/" + filename
+            if not os.path.exists(filePath):
+                os.makedirs(filePath)
+            os.system("unzip {} -d {}".format(zipfile_path, filePath))
+        except Exception as e:
+            print('moxing download {} to {} failed: '.format(
+                multi_data_json[i]["dataset_url"], zipfile_path) + str(e))
+    # Write a cache file to mark that the data has been copied from obs.
+    # If this file exists during multi-card training, the dataset does not need to be copied again.
+    f = open("/cache/download_input.txt", 'w')
+    f.close()
+    try:
+        if os.path.exists("/cache/download_input.txt"):
+            print("download_input succeed")
+    except Exception as e:
+        print("download_input failed")
+    return
+
+### Copy the output model to obs ###
+def EnvToObs(train_dir, obs_train_url):
+    try:
+        mox.file.copy_parallel(train_dir, obs_train_url)
+        print("Successfully Upload {} to {}".format(train_dir, obs_train_url))
+    except Exception as e:
+        print('moxing upload {} to {} failed: '.format(train_dir, obs_train_url) + str(e))
+    return
+
+def DownloadFromQizhi(multi_data_url, data_dir):
+    device_num = int(os.getenv('RANK_SIZE'))
+    if device_num == 1:
+        C2netMultiObsToEnv(multi_data_url, data_dir)
+        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+    if device_num > 1:
+        # set device_id and init for multi-card training
+        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID')))
+        context.reset_auto_parallel_context()
+        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
+        init()
+        # Copying obs data does not need to be executed on every card; only card 0 copies the data
+        local_rank = int(os.getenv('RANK_ID'))
+        if local_rank % 8 == 0:
+            C2netMultiObsToEnv(multi_data_url, data_dir)
+        # If the cache file does not exist, the copy has not finished yet;
+        # wait for card 0 to finish copying the data
+        while not os.path.exists("/cache/download_input.txt"):
+            time.sleep(1)
+    return
+
+def UploadToQizhi(train_dir, obs_train_url):
+    device_num = int(os.getenv('RANK_SIZE'))
+    local_rank = int(os.getenv('RANK_ID'))
+    if device_num == 1:
+        EnvToObs(train_dir, obs_train_url)
+    if device_num > 1:
+        if local_rank % 8 == 0:
+            EnvToObs(train_dir, obs_train_url)
+    return
+
+
+parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
+### --multi_data_url, --ckpt_url, --train_url and --device_target: these 4 parameters must be defined
+### first when using multiple datasets, otherwise an error will be reported.
+### There is no need to add them to the running parameters of the Qizhi platform, because they are
+### predefined in the background; you only need to define them in your code.
+
+parser.add_argument('--multi_data_url',
+                    help='path to multi dataset',
+                    default='/cache/data/')
+
+parser.add_argument('--ckpt_url',
+                    help='pre_train_model path in obs')
+
+parser.add_argument('--train_url',
+                    help='model folder to save/load',
+                    default='/cache/output/')
+
+parser.add_argument(
+    '--device_target',
+    type=str,
+    default="Ascend",
+    choices=['Ascend', 'CPU'],
+    help='device where the code will be implemented (default: Ascend); to use the CPU on the Qizhi platform set device_target=CPU')
+
+parser.add_argument('--epoch_size',
+                    type=int,
+                    default=5,
+                    help='Training epochs.')
+
+### continue task parameters
+parser.add_argument('--ckpt_load_name',
+                    help='model name to load',
+                    default='')
+
+parser.add_argument('--ckpt_save_name',
+                    help='model name to save',
+                    default='checkpoint')
+
+
+if __name__ == "__main__":
+    args, unknown = parser.parse_known_args()
+    data_dir = '/cache/dataset'
+    train_dir = '/cache/output'
+    if not os.path.exists(data_dir):
+        os.makedirs(data_dir)
+    if not os.path.exists(train_dir):
+        os.makedirs(train_dir)
+    ### Initialize the training image and copy the datasets into it
+    DownloadFromQizhi(args.multi_data_url, data_dir)
+    ### The dataset path used here is data_dir + "/MNISTData" + "/train"
+    device_num = int(os.getenv('RANK_SIZE'))
+    if device_num == 1:
+        ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
+    if device_num > 1:
+        ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
+    if ds_train.get_dataset_size() == 0:
+        raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")
+    network = LeNet5(cfg.num_classes)
+    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
+    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
+
+    ### Load the checkpoint for continue training: first copy the previous task's output
+    ### back into train_dir, then load the named .ckpt file
+    if args.ckpt_load_name:
+        C2netMultiObsToEnv(args.train_url, train_dir)
+        load_path = "{}/{}.ckpt".format(train_dir, args.ckpt_load_name)
+        load_param_into_net(network, load_checkpoint(load_path))
+
+    if args.device_target != "Ascend":
+        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()})
+    else:
+        model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()}, amp_level="O2")
+    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
+                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
+    # Note that this method saves the model file on each card, so a save path must be given per card.
+    # In this example, get_rank() is added to distinguish the different paths.
+    if device_num == 1:
+        outputDirectory = train_dir + "/"
+    if device_num > 1:
+        outputDirectory = train_dir + "/" + str(get_rank()) + "/"
+    ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_save_name,
+                                 directory=outputDirectory,
+                                 config=config_ck)
+    print("============== Starting Training ==============")
+    epoch_size = cfg['epoch_size']
+    if args.epoch_size:
+        epoch_size = args.epoch_size
+    print('epoch_size is: ', epoch_size)
+    model.train(epoch_size,
+                ds_train,
+                callbacks=[time_cb, ckpoint_cb, LossMonitor()])
+
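+# ---------------------------------------------------------------------------------------------------
+# Optional sketch (not called anywhere in this script): the continue-training branch above assumes the
+# previous task really left <ckpt_load_name>.ckpt in train_dir after the copy step. A guarded variant
+# of that load step, using only the MindSpore APIs already imported here, could look like this;
+# falling back to training from scratch when the file is missing is an assumption.
+def try_load_previous_ckpt(net, ckpt_dir, ckpt_load_name):
+    """Load <ckpt_load_name>.ckpt from ckpt_dir into net; return True when weights were loaded."""
+    load_path = "{}/{}.ckpt".format(ckpt_dir, ckpt_load_name)
+    if ckpt_load_name and os.path.exists(load_path):
+        load_param_into_net(net, load_checkpoint(load_path))
+        return True
+    print("no previous checkpoint found at {}, training from scratch".format(load_path))
+    return False
+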
+ f = open("/cache/download_input.txt", 'w') + f.close() + try: + if os.path.exists("/cache/download_input.txt"): + print("download_input succeed") + except Exception as e: + print("download_input failed") + return + +### Copy the output to obs### +def EnvToObs(train_dir, obs_train_url): + try: + mox.file.copy_parallel(train_dir, obs_train_url) + print("Successfully Upload {} to {}".format(train_dir,obs_train_url)) + except Exception as e: + print('moxing upload {} to {} failed: '.format(train_dir,obs_train_url) + str(e)) + return + +def DownloadFromQizhi(obs_data_url, data_dir): + device_num = int(os.getenv('RANK_SIZE')) + if device_num == 1: + ObsToEnv(obs_data_url,data_dir) + context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target) + if device_num > 1: + # set device_id and init for multi-card training + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID'))) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True) + init() + #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data + local_rank=int(os.getenv('RANK_ID')) + if local_rank%8==0: + ObsToEnv(obs_data_url,data_dir) + #If the cache file does not exist, it means that the copy data has not been completed, + #and Wait for 0th card to finish copying data + while not os.path.exists("/cache/download_input.txt"): + time.sleep(1) + return + +def UploadToQizhi(train_dir, obs_train_url): + device_num = int(os.getenv('RANK_SIZE')) + local_rank=int(os.getenv('RANK_ID')) + if device_num == 1: + EnvToObs(train_dir, obs_train_url) + if device_num > 1: + if local_rank%8==0: + EnvToObs(train_dir, obs_train_url) + return + +### --data_url,--train_url,--device_target,These 3 parameters must be defined first in a single dataset, +### otherwise an error will be reported. +###There is no need to add these parameters to the running parameters of the Qizhi platform, +###because they are predefined in the background, you only need to define them in your code. 
+parser = argparse.ArgumentParser(description='MindSpore Lenet Example') +parser.add_argument('--data_url', + help='path to training/inference dataset folder', + default= '/cache/data/') + +parser.add_argument('--train_url', + help='output folder to save/load', + default= '/cache/output/') + +parser.add_argument( + '--device_target', + type=str, + default="Ascend", + choices=['Ascend', 'CPU'], + help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU') + +parser.add_argument('--epoch_size', + type=int, + default=5, + help='Training epochs.') + +### continue task parameters +parser.add_argument('--ckpt_load_name', + help='model name to load', + default= '') + +parser.add_argument('--ckpt_save_name', + help='model name to save', + default= 'checkpoint') + + +if __name__ == "__main__": + args, unknown = parser.parse_known_args() + data_dir = '/cache/data' + base_path = '/cache/output' + + try: + if not os.path.exists(data_dir): + os.makedirs(data_dir) + if not os.path.exists(base_path): + os.makedirs(base_path) + except Exception as e: + print("path already exists") + + ###Initialize and copy data to training image + ###Copy data from obs to training image + DownloadFromQizhi(args.data_url, data_dir) + ###The dataset path is used here:data_dir +"/train" + device_num = int(os.getenv('RANK_SIZE')) + if device_num == 1: + ds_train = create_dataset(os.path.join(data_dir, "train"), cfg.batch_size) + if device_num > 1: + ds_train = create_dataset_parallel(os.path.join(data_dir, "train"), cfg.batch_size) + if ds_train.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") + + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) + + ### 继续训练模型加载 + if args.ckpt_load_name: + ObsToEnv(args.train_url, base_path) + load_path = "{}/{}.ckpt".format(base_path,args.ckpt_load_name) + load_param_into_net(network, load_checkpoint(load_path)) + + if args.device_target != "Ascend": + model = Model(network, + net_loss, + net_opt, + metrics={"accuracy": Accuracy()}) + else: + model = Model(network, + net_loss, + net_opt, + metrics={"accuracy": Accuracy()}, + amp_level="O2") + + config_ck = CheckpointConfig( + save_checkpoint_steps=cfg.save_checkpoint_steps, + keep_checkpoint_max=1) + #Note that this method saves the model file on each card. You need to specify the save path on each card. + # In this example, get_rank() is added to distinguish different paths. 
+ if device_num == 1: + save_path = base_path + "/" + if device_num > 1: + save_path = base_path + "/" + str(get_rank()) + "/" + ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_save_name, + directory=save_path, + config=config_ck) + print("============== Starting Training ==============") + epoch_size = cfg['epoch_size'] + if (args.epoch_size): + epoch_size = args.epoch_size + print('epoch_size is: ', epoch_size) + #Custom callback, upload output after each epoch + uploadOutput = UploadOutput(base_path,args.train_url) + model.train(epoch_size, + ds_train, + callbacks=[time_cb, ckpoint_cb, + LossMonitor(), uploadOutput]) + + ###Copy the trained output data from the local running environment back to obs, + ###and download it in the training task corresponding to the Qizhi platform + #This step is not required if UploadOutput is called + UploadToQizhi(base_path,args.train_url) \ No newline at end of file
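+
+# ---------------------------------------------------------------------------------------------------
+# Optional sketch (not called anywhere in this script): the per-card save path chosen above, factored
+# into a helper. The RANK_SIZE / get_rank() convention mirrors the code in this file; anything beyond
+# that is an assumption.
+def rank_save_dir(base_path):
+    """Return the checkpoint directory for this card: base_path/ on 1 card, base_path/<rank>/ otherwise."""
+    if int(os.getenv('RANK_SIZE')) > 1:
+        return base_path + "/" + str(get_rank()) + "/"
+    return base_path + "/"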