|
|
@@ -71,13 +71,20 @@ class TestOfaTrainer(unittest.TestCase): |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_trainer_std(self): |
|
|
|
# WORKSPACE = './workspace/ckpts/recognition' |
|
|
|
# os.makedirs(WORKSPACE, exist_ok=True) |
|
|
|
# |
|
|
|
# pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh' |
|
|
|
# cfg = read_config(pretrained_model) |
|
|
|
# config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) |
|
|
|
# cfg.dump(config_file) |
|
|
|
WORKSPACE = './workspace/ckpts/recognition' |
|
|
|
os.makedirs(WORKSPACE, exist_ok=True) |
|
|
|
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) |
|
|
|
with open(config_file, 'w') as writer: |
|
|
|
json.dump(self.finetune_cfg, writer) |
|
|
|
|
|
|
|
pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh' |
|
|
|
cfg = read_config(pretrained_model) |
|
|
|
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION) |
|
|
|
cfg.dump(config_file) |
|
|
|
|
|
|
|
args = dict( |
|
|
|
model=pretrained_model, |
|
|
|