You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

moxing_adapter.py 6.7 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Moxing adapter for ModelArts"""
  16. import os
  17. import functools
  18. import mindspore as ms
  19. from .config import config
  20. _global_sync_count = 0
  21. def get_device_id():
  22. device_id = os.getenv('DEVICE_ID', '0')
  23. return int(device_id)
  24. def get_device_num():
  25. device_num = os.getenv('RANK_SIZE', '1')
  26. return int(device_num)
  27. def get_rank_id():
  28. global_rank_id = os.getenv('RANK_ID', '0')
  29. return int(global_rank_id)
  30. def get_job_id():
  31. job_id = os.getenv('JOB_ID')
  32. job_id = job_id if job_id != "" else "default"
  33. return job_id
  34. def sync_data(from_path, to_path):
  35. """
  36. Download data from remote obs to local directory if the first url is remote url and the second one is local path
  37. Upload data from local directory to remote obs in contrast.
  38. """
  39. import moxing as mox
  40. import time
  41. global _global_sync_count
  42. sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
  43. _global_sync_count += 1
  44. # Each server contains 8 devices as most.
  45. if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
  46. print("from path: ", from_path)
  47. print("to path: ", to_path)
  48. mox.file.copy_parallel(from_path, to_path)
  49. print("===finish data synchronization===")
  50. try:
  51. os.mknod(sync_lock)
  52. except IOError:
  53. pass
  54. print("===save flag===")
  55. while True:
  56. if os.path.exists(sync_lock):
  57. break
  58. time.sleep(1)
  59. print("Finish sync data from {} to {}.".format(from_path, to_path))
  60. def modelarts_pre_process(args):
  61. '''modelarts pre process function.'''
  62. def unzip(zip_file, save_dir):
  63. import zipfile
  64. s_time = time.time()
  65. if not os.path.exists(os.path.join(save_dir, args.modelarts_dataset_unzip_name)):
  66. zip_isexist = zipfile.is_zipfile(zip_file)
  67. if zip_isexist:
  68. fz = zipfile.ZipFile(zip_file, 'r')
  69. data_num = len(fz.namelist())
  70. print("Extract Start...")
  71. print("unzip file num: {}".format(data_num))
  72. data_print = int(data_num / 100) if data_num > 100 else 1
  73. i = 0
  74. for file in fz.namelist():
  75. if i % data_print == 0:
  76. print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
  77. i += 1
  78. fz.extract(file, save_dir)
  79. print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
  80. int(int(time.time() - s_time) % 60)))
  81. print("Extract Done.")
  82. else:
  83. print("This is not zip.")
  84. else:
  85. print("Zip has been extracted.")
  86. if args.need_modelarts_dataset_unzip:
  87. zip_file_1 = os.path.join(args.data_path, args.modelarts_dataset_unzip_name + ".zip")
  88. save_dir_1 = os.path.join(args.data_path)
  89. sync_lock = "/tmp/unzip_sync.lock"
  90. # Each server contains 8 devices as most.
  91. if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
  92. print("Zip file path: ", zip_file_1)
  93. print("Unzip file save dir: ", save_dir_1)
  94. unzip(zip_file_1, save_dir_1)
  95. print("===Finish extract data synchronization===")
  96. try:
  97. os.mknod(sync_lock)
  98. except IOError:
  99. pass
  100. while True:
  101. if os.path.exists(sync_lock):
  102. break
  103. time.sleep(1)
  104. print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
  105. args.output_dir = os.path.join(args.output_path, args.output_dir)
  106. args.ckpt_path = os.path.join(args.output_path, args.ckpt_path)
  107. def modelarts_post_process():
  108. sync_data(from_path='/cache/output', to_path='obs://hit-cyf/yolov5_npu/outputs/')
  109. def modelarts_export_preprocess(args):
  110. args.file_name = os.path.join(args.output_path, args.file_name)
  111. def moxing_wrapper(pre_process=None, post_process=None, **kwargs):
  112. """
  113. Moxing wrapper to download dataset and upload outputs.
  114. """
  115. def wrapper(run_func):
  116. @functools.wraps(run_func)
  117. def wrapped_func(*args, **kwargs):
  118. # Download data from data_url
  119. if config.enable_modelarts:
  120. if config.data_url:
  121. sync_data(config.data_url, config.data_path)
  122. print("Dataset downloaded: ", os.listdir(config.data_path))
  123. if config.checkpoint_url:
  124. sync_data(config.checkpoint_url, config.load_path)
  125. print("Preload downloaded: ", os.listdir(config.load_path))
  126. if config.train_url:
  127. sync_data(config.train_url, config.output_path)
  128. print("Workspace downloaded: ", os.listdir(config.output_path))
  129. ms.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
  130. config.device_num = get_device_num()
  131. config.device_id = get_device_id()
  132. if not os.path.exists(config.output_path):
  133. os.makedirs(config.output_path)
  134. if pre_process:
  135. if "pre_args" in kwargs.keys():
  136. pre_process(*kwargs["pre_args"])
  137. else:
  138. pre_process()
  139. # Run the main function
  140. run_func(*args, **kwargs)
  141. # Upload data to train_url
  142. if config.enable_modelarts:
  143. if post_process:
  144. if "post_args" in kwargs.keys():
  145. post_process(*kwargs["post_args"])
  146. else:
  147. post_process()
  148. if config.train_url:
  149. print("Start to copy output directory")
  150. sync_data(config.output_path, config.train_url)
  151. return wrapped_func
  152. return wrapper

随着人工智能和大数据的发展，各行各业对自动化工具都有一定的需求。在当下疫情防控期间，本项目使用 MindSpore 实现 YOLO 模型，进行目标检测及语义分割，可对视频或图片进行口罩佩戴检测和行人社交距离检测，从而对公共场所的疫情防控实行自动化管理。