You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.py 5.6 kB

3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YoloV5 train."""
  16. import os
  17. import time
  18. import mindspore as ms
  19. import mindspore.nn as nn
  20. import mindspore.communication as comm
  21. from src.yolo import YOLOV5, YoloWithLossCell
  22. from src.logger import get_logger
  23. from src.util import AverageMeter, get_param_groups, cpu_affinity
  24. from src.lr_scheduler import get_lr
  25. from src.yolo_dataset import create_yolo_dataset
  26. from src.initializer import default_recurisive_init, load_yolov5_params
  27. from model_utils.config import config
  28. from model_utils.device_adapter import get_device_id
  29. # only useful for huawei cloud modelarts.
  30. from model_utils.moxing_adapter import moxing_wrapper, modelarts_pre_process, modelarts_post_process
  31. ms.set_seed(1)
  32. def init_distribute():
  33. comm.init()
  34. config.rank = comm.get_rank()
  35. config.group_size = comm.get_group_size()
  36. ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL, gradients_mean=True,
  37. device_num=config.group_size)
  38. def train_preprocess():
  39. if config.lr_scheduler == 'cosine_annealing' and config.max_epoch > config.T_max:
  40. config.T_max = config.max_epoch
  41. config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
  42. config.data_root = os.path.join(config.data_dir, 'dataset/mask/images/train')
  43. config.annFile = os.path.join(config.data_dir, 'coco/annotations.json')
  44. device_id = get_device_id()
  45. ms.set_context(mode=ms.GRAPH_MODE, device_target=config.device_target, device_id=device_id)
  46. if config.is_distributed:
  47. # init distributed
  48. init_distribute()
  49. # for promoting performance in GPU device
  50. if config.device_target == "GPU" and config.bind_cpu:
  51. cpu_affinity(config.rank, min(config.group_size, config.device_num))
  52. # logger module is managed by config, it is used in other function. e.x. config.logger.info("xxx")
  53. config.logger = get_logger(config.output_dir, config.rank)
  54. config.logger.save_args(config)
  55. @moxing_wrapper(pre_process=modelarts_pre_process, post_process=modelarts_post_process, pre_args=[config])
  56. def run_train():
  57. train_preprocess()
  58. loss_meter = AverageMeter('loss')
  59. dict_version = {'yolov5s': 0, 'yolov5m': 1, 'yolov5l': 2, 'yolov5x': 3}
  60. network = YOLOV5(is_training=True, version=dict_version[config.yolov5_version])
  61. # default is kaiming-normal
  62. default_recurisive_init(network)
  63. load_yolov5_params(config, network)
  64. network = YoloWithLossCell(network)
  65. ds = create_yolo_dataset(image_dir=config.data_root, anno_path=config.annFile, is_training=True,
  66. batch_size=config.per_batch_size, device_num=config.group_size,
  67. rank=config.rank, config=config)
  68. config.logger.info('Finish loading dataset')
  69. steps_per_epoch = ds.get_dataset_size()
  70. lr = get_lr(config, steps_per_epoch)
  71. opt = nn.Momentum(params=get_param_groups(network), momentum=config.momentum, learning_rate=ms.Tensor(lr),
  72. weight_decay=config.weight_decay, loss_scale=config.loss_scale)
  73. network = nn.TrainOneStepCell(network, opt, config.loss_scale // 2)
  74. network.set_train()
  75. data_loader = ds.create_tuple_iterator(do_copy=False)
  76. first_step = True
  77. t_end = time.time()
  78. for epoch_idx in range(config.max_epoch):
  79. for step_idx, data in enumerate(data_loader):
  80. images = data[0]
  81. input_shape = images.shape[2:4]
  82. input_shape = ms.Tensor(tuple(input_shape[::-1]), ms.float32)
  83. loss = network(images, data[2], data[3], data[4], data[5], data[6],
  84. data[7], input_shape)
  85. loss_meter.update(loss.asnumpy())
  86. # it is used for loss, performance output per config.log_interval steps.
  87. if (epoch_idx * steps_per_epoch + step_idx) % config.log_interval == 0:
  88. time_used = time.time() - t_end
  89. if first_step:
  90. fps = config.per_batch_size * config.group_size / time_used
  91. per_step_time = time_used * 1000
  92. first_step = False
  93. else:
  94. fps = config.per_batch_size * config.log_interval * config.group_size / time_used
  95. per_step_time = time_used / config.log_interval * 1000
  96. config.logger.info('epoch[{}], iter[{}], {}, fps:{:.2f} imgs/sec, '
  97. 'lr:{}, per step time: {}ms'.format(epoch_idx + 1, step_idx + 1,
  98. loss_meter, fps, lr[step_idx], per_step_time))
  99. t_end = time.time()
  100. loss_meter.reset()
  101. if config.rank == 0:
  102. ckpt_name = os.path.join(config.output_dir, "yolov5_{}_{}.ckpt".format(epoch_idx + 1, steps_per_epoch))
  103. ms.save_checkpoint(network, ckpt_name)
  104. config.logger.info('==========end training===============')
  105. if __name__ == "__main__":
  106. run_train()

随着人工智能和大数据的发展，各行各业对自动化工具都有着一定的需求。在当下疫情防控期间，本项目使用 MindSpore 实现 YOLO 模型进行目标检测及语义分割，可对视频或图片进行口罩佩戴检测和行人社交距离检测，从而实现公共场所疫情防控的自动化管理。