|
|
|
@@ -1,7 +1,5 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
import glob |
|
|
|
import os |
|
|
|
import os.path as osp |
|
|
|
import shutil |
|
|
|
import unittest |
|
|
|
|
|
|
|
@@ -54,7 +52,7 @@ class TestOfaTrainer(unittest.TestCase): |
|
|
|
'drop_worst_ratio': 0.0, |
|
|
|
'ignore_eos': False, |
|
|
|
'ignore_prefix_size': 0, |
|
|
|
'label_smoothing': 0.0, |
|
|
|
'label_smoothing': 0.1, |
|
|
|
'reg_alpha': 1.0, |
|
|
|
'report_accuracy': False, |
|
|
|
'sample_patch_num': 196, |
|
|
|
@@ -77,11 +75,11 @@ class TestOfaTrainer(unittest.TestCase): |
|
|
|
def test_trainer_std(self): |
|
|
|
WORKSPACE = './workspace/ckpts/caption' |
|
|
|
os.makedirs(WORKSPACE, exist_ok=True) |
|
|
|
config_file = os.path.join(WORKSPACE, 'configuration.json') |
|
|
|
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) |
|
|
|
with open(config_file, 'w') as writer: |
|
|
|
json.dump(self.finetune_cfg, writer) |
|
|
|
|
|
|
|
pretrained_model = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_image-caption_coco_large_en' |
|
|
|
pretrained_model = 'damo/ofa_image-caption_coco_large_en' |
|
|
|
args = dict( |
|
|
|
model=pretrained_model, |
|
|
|
work_dir=WORKSPACE, |
|
|
|
|