Browse Source

添加 'c2net_listdata.py'

test_v20221116
wjtest1201 1 year ago
parent
commit
9710340c0d
1 changed files with 114 additions and 0 deletions
  1. +114
    -0
      c2net_listdata.py

+ 114
- 0
c2net_listdata.py View File

@@ -0,0 +1,114 @@
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt)
"""
#!/usr/bin/python
#coding=utf-8


import os
import argparse

import moxing as mox
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
import json
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore import load_checkpoint, load_param_into_net
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
import time

### Copy multiple datasets from obs to training image ###
def MultiObsToEnv(multi_data_url, data_dir):
#--multi_data_url is json data, need to do json parsing for multi_data_url
multi_data_json = json.loads(multi_data_url)
for i in range(len(multi_data_json)):
path = data_dir + "/" + multi_data_json[i]["dataset_name"]
file_path = data_dir + "/" + os.path.splitext(multi_data_json[i]["dataset_name"])[0]
if not os.path.exists(file_path):
os.makedirs(file_path)
try:
mox.file.copy_parallel(multi_data_json[i]["dataset_url"], path)
print("Successfully Download {} to {}".format(multi_data_json[i]["dataset_url"],path))
#unzip dataset
os.system("unzip -d %s %s" % (file_path, path))
except Exception as e:
print('moxing download {} to {} failed: '.format(
multi_data_json[i]["dataset_url"], path) + str(e))
#Set a cache file to determine whether the data has been copied to obs.
#If this file exists during multi-card training, there is no need to copy the dataset multiple times.
f = open("/cache/download_input.txt", 'w')
f.close()
try:
if os.path.exists("/cache/download_input.txt"):
print("download_input succeed")
except Exception as e:
print("download_input failed")
return
def DownloadFromQizhi(multi_data_url, data_dir):
device_num = int(os.getenv('RANK_SIZE'))
if device_num == 1:
MultiObsToEnv(multi_data_url,data_dir)
context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target)
if device_num > 1:
# set device_id and init for multi-card training
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID')))
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
init()
#Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
local_rank=int(os.getenv('RANK_ID'))
if local_rank%8==0:
MultiObsToEnv(multi_data_url,data_dir)
#If the cache file does not exist, it means that the copy data has not been completed,
#and Wait for 0th card to finish copying data
while not os.path.exists("/cache/download_input.txt"):
time.sleep(1)
return

parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
### --multi_data_url,--ckpt_url,--device_target,These 4 parameters must be defined first in a multi-dataset,
### otherwise an error will be reported.
### There is no need to add these parameters to the running parameters of the Qizhi platform,
### because they are predefined in the background, you only need to define them in your code.

parser.add_argument('--multi_data_url',
help='dataset path in obs')

parser.add_argument('--ckpt_url',
help='pre_train_model path in obs')

parser.add_argument(
'--device_target',
type=str,
default="Ascend",
choices=['Ascend', 'CPU'],
help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')

parser.add_argument('--epoch_size',
type=int,
default=5,
help='Training epochs.')

if __name__ == "__main__":
args, unknown = parser.parse_known_args()
data_dir = '/cache/dataset'
train_dir = '/cache/output'
if not os.path.exists(data_dir):
os.makedirs(data_dir)
if not os.path.exists(train_dir):
os.makedirs(train_dir)
###Initialize and copy data to training image
DownloadFromQizhi(args.multi_data_url, data_dir)
print("--------start ls:")
os.system("cd /cache/dataset; ls -al")
print("--------end ls-----------")

Loading…
Cancel
Save