You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ofa_trainer.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import unittest
  5. import json
  6. from modelscope.msdatasets import MsDataset
  7. from modelscope.trainers import build_trainer
  8. from modelscope.utils.constant import DownloadMode, ModelFile
  9. from modelscope.utils.hub import read_config
  10. from modelscope.utils.test_utils import test_level
  11. class TestOfaTrainer(unittest.TestCase):
  12. def setUp(self) -> None:
  13. self.finetune_cfg = \
  14. {'framework': 'pytorch',
  15. 'task': 'ocr-recognition',
  16. 'model': {'type': 'ofa',
  17. 'beam_search': {'beam_size': 5,
  18. 'max_len_b': 64,
  19. 'min_len': 1,
  20. 'no_repeat_ngram_size': 0},
  21. 'seed': 7,
  22. 'max_src_length': 128,
  23. 'language': 'zh',
  24. 'gen_type': 'generation',
  25. 'patch_image_size': 480,
  26. 'is_document': False,
  27. 'max_image_size': 480,
  28. 'imagenet_default_mean_and_std': False},
  29. 'pipeline': {'type': 'ofa-ocr-recognition'},
  30. 'dataset': {'column_map': {'text': 'label'}},
  31. 'train': {'work_dir': 'work/ckpts/recognition',
  32. # 'launcher': 'pytorch',
  33. 'max_epochs': 1,
  34. 'use_fp16': True,
  35. 'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
  36. 'lr_scheduler': {'name': 'polynomial_decay',
  37. 'warmup_proportion': 0.01,
  38. 'lr_end': 1e-07},
  39. 'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
  40. 'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
  41. 'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
  42. 'cumulative_iters': 1,
  43. 'grad_clip': {'max_norm': 1.0, 'norm_type': 2},
  44. 'loss_keys': 'loss'},
  45. 'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion',
  46. 'constraint_range': None,
  47. 'drop_worst_after': 0,
  48. 'drop_worst_ratio': 0.0,
  49. 'ignore_eos': False,
  50. 'ignore_prefix_size': 0,
  51. 'label_smoothing': 0.1,
  52. 'reg_alpha': 1.0,
  53. 'report_accuracy': False,
  54. 'sample_patch_num': 196,
  55. 'sentence_avg': False,
  56. 'use_rdrop': True},
  57. 'hooks': [{'type': 'BestCkptSaverHook',
  58. 'metric_key': 'accuracy',
  59. 'interval': 100},
  60. {'type': 'TextLoggerHook', 'interval': 1},
  61. {'type': 'IterTimerHook'},
  62. {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
  63. 'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
  64. 'metrics': [{'type': 'accuracy'}]},
  65. 'preprocessor': []}
  66. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  67. def test_trainer_std(self):
  68. # WORKSPACE = './workspace/ckpts/recognition'
  69. # os.makedirs(WORKSPACE, exist_ok=True)
  70. #
  71. # pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
  72. # cfg = read_config(pretrained_model)
  73. # config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
  74. # cfg.dump(config_file)
  75. WORKSPACE = './workspace/ckpts/recognition'
  76. os.makedirs(WORKSPACE, exist_ok=True)
  77. config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
  78. with open(config_file, 'w') as writer:
  79. json.dump(self.finetune_cfg, writer)
  80. pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
  81. args = dict(
  82. model=pretrained_model,
  83. work_dir=WORKSPACE,
  84. train_dataset=MsDataset.load(
  85. 'ocr_fudanvi_zh',
  86. subset_name='scene',
  87. namespace='modelscope',
  88. split='train[800:900]',
  89. download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
  90. eval_dataset=MsDataset.load(
  91. 'ocr_fudanvi_zh',
  92. subset_name='scene',
  93. namespace='modelscope',
  94. split='test[:20]',
  95. download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
  96. cfg_file=config_file)
  97. trainer = build_trainer(name='ofa', default_args=args)
  98. trainer.train()
  99. self.assertIn(
  100. ModelFile.TORCH_MODEL_BIN_FILE,
  101. os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
  102. shutil.rmtree(WORKSPACE)
  103. if __name__ == '__main__':
  104. unittest.main()