|
|
@@ -79,7 +79,7 @@ class TestOfaTrainer(unittest.TestCase): |
|
|
|
with open(config_file, 'w') as writer: |
|
|
|
json.dump(self.finetune_cfg, writer) |
|
|
|
|
|
|
|
pretrained_model = 'damo/ofa_image-caption_coco_large_en' |
|
|
|
pretrained_model = 'damo/ofa_image-caption_coco_distilled_en' |
|
|
|
args = dict( |
|
|
|
model=pretrained_model, |
|
|
|
work_dir=WORKSPACE, |
|
|
@@ -97,8 +97,8 @@ class TestOfaTrainer(unittest.TestCase): |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, |
|
|
|
os.path.join(WORKSPACE, 'output')) |
|
|
|
shutil.rmtree(WORKSPACE) |
|
|
|
os.listdir(os.path.join(WORKSPACE, 'output'))) |
|
|
|
# shutil.rmtree(WORKSPACE) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|