| @@ -1,7 +1,5 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | # Copyright (c) Alibaba, Inc. and its affiliates. | ||||
| import glob | |||||
| import os | import os | ||||
| import os.path as osp | |||||
| import shutil | import shutil | ||||
| import unittest | import unittest | ||||
| @@ -54,7 +52,7 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| 'drop_worst_ratio': 0.0, | 'drop_worst_ratio': 0.0, | ||||
| 'ignore_eos': False, | 'ignore_eos': False, | ||||
| 'ignore_prefix_size': 0, | 'ignore_prefix_size': 0, | ||||
| 'label_smoothing': 0.0, | |||||
| 'label_smoothing': 0.1, | |||||
| 'reg_alpha': 1.0, | 'reg_alpha': 1.0, | ||||
| 'report_accuracy': False, | 'report_accuracy': False, | ||||
| 'sample_patch_num': 196, | 'sample_patch_num': 196, | ||||
| @@ -77,11 +75,11 @@ class TestOfaTrainer(unittest.TestCase): | |||||
| def test_trainer_std(self): | def test_trainer_std(self): | ||||
| WORKSPACE = './workspace/ckpts/caption' | WORKSPACE = './workspace/ckpts/caption' | ||||
| os.makedirs(WORKSPACE, exist_ok=True) | os.makedirs(WORKSPACE, exist_ok=True) | ||||
| config_file = os.path.join(WORKSPACE, 'configuration.json') | |||||
| config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) | |||||
| with open(config_file, 'w') as writer: | with open(config_file, 'w') as writer: | ||||
| json.dump(self.finetune_cfg, writer) | json.dump(self.finetune_cfg, writer) | ||||
| pretrained_model = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_image-caption_coco_large_en' | |||||
| pretrained_model = 'damo/ofa_image-caption_coco_large_en' | |||||
| args = dict( | args = dict( | ||||
| model=pretrained_model, | model=pretrained_model, | ||||
| work_dir=WORKSPACE, | work_dir=WORKSPACE, | ||||