- """
- ######################## train lenet example ########################
- train lenet and get network model files(.ckpt)
- """
- #!/usr/bin/python
- #coding=utf-8
- import os
- import argparse
- import moxing as mox
- from config import mnist_cfg as cfg
- from dataset import create_dataset
- from dataset_distributed import create_dataset_parallel
- from lenet import LeNet5
- import json
- import mindspore.nn as nn
- from mindspore import context
- from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
- from mindspore.train import Model
- from mindspore.nn.metrics import Accuracy
- from mindspore import load_checkpoint, load_param_into_net
- from mindspore.context import ParallelMode
- from mindspore.communication.management import init, get_rank
- import time
- ### Copy multiple datasets from obs to training image ###
def MultiObsToEnv(multi_data_url, data_dir):
    """Copy every dataset described by ``multi_data_url`` into ``data_dir`` and unzip it.

    Args:
        multi_data_url: JSON text holding a list of objects; each object must
            provide a "dataset_name" and a "dataset_url" entry.
        data_dir: Local directory that receives the downloaded archives and
            the extracted folders.

    For each entry the archive is copied with ``mox.file.copy_parallel`` to
    ``<data_dir>/<dataset_name>`` and extracted with ``unzip`` into
    ``<data_dir>/<dataset_name without its extension>``. A failure while
    handling one dataset is printed and the remaining datasets are still
    processed (best effort).

    After the loop the marker file ``/cache/download_input.txt`` is created.
    It is created even when a download failed, because other processes poll
    for it and would otherwise wait forever.
    """
    # Function-scope import: keeps this block self-contained without touching
    # the file-level import section.
    import subprocess

    # --multi_data_url arrives as JSON text and has to be parsed first.
    multi_data_json = json.loads(multi_data_url)
    for dataset in multi_data_json:
        dataset_name = dataset["dataset_name"]
        dataset_url = dataset["dataset_url"]
        path = data_dir + "/" + dataset_name
        file_path = data_dir + "/" + os.path.splitext(dataset_name)[0]
        # exist_ok avoids the check-then-create race between processes.
        os.makedirs(file_path, exist_ok=True)
        try:
            mox.file.copy_parallel(dataset_url, path)
            print("Successfully Download {} to {}".format(dataset_url, path))
            # Unzip the dataset. An argument list (no shell) keeps paths that
            # contain spaces or shell metacharacters intact.
            unzip_result = subprocess.run(["unzip", "-d", file_path, path])
            if unzip_result.returncode != 0:
                print("unzip {} to {} exited with code {}".format(
                    path, file_path, unzip_result.returncode))
        except Exception as e:
            print('moxing download {} to {} failed: '.format(
                dataset_url, path) + str(e))

    # Create a marker file that signals the copy step has finished. During
    # multi-card training the other processes wait for this file instead of
    # copying the datasets again.
    marker_path = "/cache/download_input.txt"
    with open(marker_path, 'w'):
        pass
    if os.path.exists(marker_path):
        print("download_input succeed")
    else:
        print("download_input failed")
    return
def DownloadFromQizhi(multi_data_url, data_dir):
    """Configure the MindSpore context and fetch the datasets for this process.

    Args:
        multi_data_url: JSON text describing the datasets; forwarded unchanged
            to ``MultiObsToEnv``.
        data_dir: Local directory the datasets are copied into.

    The number of devices is read from the ``RANK_SIZE`` environment variable.

    * Exactly one device: the data is copied, then the context is set to
      graph mode on ``args.device_target``.
    * More than one device: the context is set with the device id taken from
      ``ASCEND_DEVICE_ID``, the auto-parallel context is reset and configured
      for data-parallel training, and ``init()`` is called. Processes whose
      ``RANK_ID`` is a multiple of 8 copy the data; every process then polls
      once per second until ``/cache/download_input.txt`` exists.

    Note that the module-level ``args`` namespace is read here, so the
    command line must have been parsed before this function is called.
    """
    device_num = int(os.getenv('RANK_SIZE'))

    # Single-card run: copy first, then select the device.
    if device_num == 1:
        MultiObsToEnv(multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
        return

    # Anything below one device: nothing to set up.
    if device_num <= 1:
        return

    # Multi-card run: bind this process to its device and initialise the
    # communication backend before touching the data.
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target=args.device_target,
        device_id=int(os.getenv('ASCEND_DEVICE_ID')),
    )
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(
        device_num=device_num,
        parallel_mode=ParallelMode.DATA_PARALLEL,
        gradients_mean=True,
        parameter_broadcast=True,
    )
    init()

    # The OBS copy does not have to run in every process; only ranks that are
    # a multiple of 8 perform it.
    local_rank = int(os.getenv('RANK_ID'))
    copies_data = local_rank % 8 == 0
    if copies_data:
        MultiObsToEnv(multi_data_url, data_dir)

    # While the marker file is missing the copy has not completed yet, so
    # keep waiting for the copying process to finish.
    marker_path = "/cache/download_input.txt"
    while True:
        if os.path.exists(marker_path):
            break
        time.sleep(1)
    return
# Command-line interface of the training script.
#
# Per the original author's note: the platform-related parameters
# (--multi_data_url, --ckpt_url, --device_target) must be declared here for a
# multi-dataset job, otherwise an error is reported. They do not need to be
# added to the run parameters on the Qizhi platform because the platform
# predefines them in the background; declaring them in code is sufficient.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')

# Dataset description supplied by the platform.
parser.add_argument('--multi_data_url', help='dataset path in obs')

# Location of a pre-trained model.
parser.add_argument('--ckpt_url', help='pre_train_model path in obs')

# Execution device; restricted to the two supported targets.
parser.add_argument(
    '--device_target',
    choices=['Ascend', 'CPU'],
    default="Ascend",
    type=str,
    help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU',
)

# Number of training epochs.
parser.add_argument('--epoch_size', default=5, type=int, help='Training epochs.')
- if __name__ == "__main__":
- args, unknown = parser.parse_known_args()
- data_dir = '/cache/dataset'
- train_dir = '/cache/output'
- if not os.path.exists(data_dir):
- os.makedirs(data_dir)
- if not os.path.exists(train_dir):
- os.makedirs(train_dir)
- ###Initialize and copy data to training image
- DownloadFromQizhi(args.multi_data_url, data_dir)
- print("--------start ls:")
- os.system("cd /cache/dataset; ls -al")
- print("--------end ls-----------")