From 9710340c0dec49df79155e3f3a5af3735290cf64 Mon Sep 17 00:00:00 2001
From: wjtest1201
Date: Thu, 10 Nov 2022 11:52:31 +0800
Subject: [PATCH] Add 'c2net_listdata.py'
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 c2net_listdata.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 114 insertions(+)
 create mode 100644 c2net_listdata.py

diff --git a/c2net_listdata.py b/c2net_listdata.py
new file mode 100644
index 0000000..df94803
--- /dev/null
+++ b/c2net_listdata.py
@@ -0,0 +1,114 @@
+#!/usr/bin/python
+# coding=utf-8
+"""
+######################## c2net list-data example ########################
+Download the datasets of a Qizhi (c2net) training task from OBS into the
+training image, unzip them, and list the contents of /cache/dataset.
+"""
+
+import os
+import argparse
+import json
+import time
+
+import moxing as mox
+from config import mnist_cfg as cfg
+from dataset import create_dataset
+from dataset_distributed import create_dataset_parallel
+from lenet import LeNet5
+import mindspore.nn as nn
+from mindspore import context
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train import Model
+from mindspore.nn.metrics import Accuracy
+from mindspore import load_checkpoint, load_param_into_net
+from mindspore.context import ParallelMode
+from mindspore.communication.management import init, get_rank
+
+### Copy multiple datasets from OBS to the training image ###
+def MultiObsToEnv(multi_data_url, data_dir):
+    # multi_data_url is a JSON string; parse it to get the list of datasets
+    multi_data_json = json.loads(multi_data_url)
+    for i in range(len(multi_data_json)):
+        path = data_dir + "/" + multi_data_json[i]["dataset_name"]
+        file_path = data_dir + "/" + os.path.splitext(multi_data_json[i]["dataset_name"])[0]
+        if not os.path.exists(file_path):
+            os.makedirs(file_path)
+        try:
+            mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
+            print("Successfully downloaded {} to {}".format(multi_data_json[i]["dataset_url"], path))
+            # unzip the dataset
+            os.system("unzip -d %s %s" % (file_path, path))
+        except Exception as e:
+            print('moxing download {} to {} failed: '.format(
+                multi_data_json[i]["dataset_url"], path) + str(e))
+    # Create a cache file that marks the data as already copied from OBS.
+    # If this file exists during multi-card training, the dataset does not need to be copied again by the other cards.
+ f = open("/cache/download_input.txt", 'w') + f.close() + try: + if os.path.exists("/cache/download_input.txt"): + print("download_input succeed") + except Exception as e: + print("download_input failed") + return + +def DownloadFromQizhi(multi_data_url, data_dir): + device_num = int(os.getenv('RANK_SIZE')) + if device_num == 1: + MultiObsToEnv(multi_data_url,data_dir) + context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target) + if device_num > 1: + # set device_id and init for multi-card training + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID'))) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True) + init() + #Copying obs data does not need to be executed multiple times, just let the 0th card copy the data + local_rank=int(os.getenv('RANK_ID')) + if local_rank%8==0: + MultiObsToEnv(multi_data_url,data_dir) + #If the cache file does not exist, it means that the copy data has not been completed, + #and Wait for 0th card to finish copying data + while not os.path.exists("/cache/download_input.txt"): + time.sleep(1) + return + +parser = argparse.ArgumentParser(description='MindSpore Lenet Example') +### --multi_data_url,--ckpt_url,--device_target,These 4 parameters must be defined first in a multi-dataset, +### otherwise an error will be reported. +### There is no need to add these parameters to the running parameters of the Qizhi platform, +### because they are predefined in the background, you only need to define them in your code. + +parser.add_argument('--multi_data_url', + help='dataset path in obs') + +parser.add_argument('--ckpt_url', + help='pre_train_model path in obs') + +parser.add_argument( + '--device_target', + type=str, + default="Ascend", + choices=['Ascend', 'CPU'], + help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU') + +parser.add_argument('--epoch_size', + type=int, + default=5, + help='Training epochs.') + +if __name__ == "__main__": + args, unknown = parser.parse_known_args() + data_dir = '/cache/dataset' + train_dir = '/cache/output' + if not os.path.exists(data_dir): + os.makedirs(data_dir) + if not os.path.exists(train_dir): + os.makedirs(train_dir) + ###Initialize and copy data to training image + DownloadFromQizhi(args.multi_data_url, data_dir) + print("--------start ls:") + os.system("cd /cache/dataset; ls -al") + print("--------end ls-----------") +