|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import shutil
- import unittest
-
- import json
-
- from modelscope.metainfo import Metrics, Trainers
- from modelscope.msdatasets import MsDataset
- from modelscope.trainers import build_trainer
- from modelscope.utils.constant import ModelFile
- from modelscope.utils.test_utils import test_level
-
-
class TestClipTrainer(unittest.TestCase):
    """Smoke test for fine-tuning with the CLIP multi-modal embedding trainer.

    The test writes a configuration file into ``WORKSPACE``, builds the
    trainer registered as ``Trainers.clip_multi_modal_embedding`` and checks
    that a model weight file shows up in the trainer's ``output`` directory.
    """

    # Directory holding the generated configuration file and trainer outputs.
    # Used both as the config's ``train.work_dir`` and as the ``work_dir``
    # argument handed to the trainer, so the two can never drift apart.
    WORKSPACE = './workspace/ckpts/clip'

    # Model id referenced by the configuration and passed as ``model``.
    PRETRAINED_MODEL = 'damo/multi-modal_clip-vit-base-patch16_zh'

    def setUp(self) -> None:
        """Build the fine-tuning configuration that the test dumps to JSON."""
        self.finetune_cfg = {
            'framework': 'pytorch',
            'task': 'multi-modal-embedding',
            'pipeline': {'type': 'multi-modal-embedding'},
            'pretrained_model': {'model_name': self.PRETRAINED_MODEL},
            'dataset': {'column_map': {'img': 'image', 'text': 'query'}},
            'train': {
                'work_dir': self.WORKSPACE,
                'max_epochs': 1,
                'use_fp16': True,
                'dataloader': {
                    'batch_size_per_gpu': 8,
                    'workers_per_gpu': 0,
                    'shuffle': True,
                    'drop_last': True,
                },
                'lr_scheduler': {
                    'name': 'cosine',
                    'warmup_proportion': 0.01,
                },
                'lr_scheduler_hook': {
                    'type': 'LrSchedulerHook',
                    'by_epoch': False,
                },
                'optimizer': {'type': 'AdamW'},
                'optimizer_hparams': {'lr': 5e-05, 'weight_decay': 0.01},
                'optimizer_hook': {
                    'type': 'TorchAMPOptimizerHook',
                    'cumulative_iters': 1,
                    'loss_keys': 'loss',
                },
                'loss_cfg': {'aggregate': True},
                'hooks': [
                    {
                        'type': 'BestCkptSaverHook',
                        'metric_key': 'inbatch_t2i_recall_at_1',
                        'interval': 100,
                    },
                    {'type': 'TextLoggerHook', 'interval': 1},
                    {'type': 'IterTimerHook'},
                    {
                        'type': 'EvaluationHook',
                        'by_epoch': True,
                        'interval': 1,
                    },
                    {'type': 'ClipClampLogitScaleHook'},
                ],
            },
            'evaluation': {
                'dataloader': {
                    'batch_size_per_gpu': 8,
                    'workers_per_gpu': 0,
                    'shuffle': True,
                    'drop_last': True,
                },
                'metrics': [{'type': 'inbatch_recall'}],
            },
            'preprocessor': [],
        }

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer_std(self) -> None:
        """Train on a small 'muge' subset and check that weights are saved.

        Uses the first 200 training samples and the first 100 validation
        samples of the ``muge`` dataset, then asserts that
        ``ModelFile.TORCH_MODEL_BIN_FILE`` exists in ``<WORKSPACE>/output``.
        """
        os.makedirs(self.WORKSPACE, exist_ok=True)
        # Register the removal immediately so the workspace is deleted even
        # when training or the assertion below fails; a trailing rmtree call
        # would be skipped on any exception and leave stale files behind.
        self.addCleanup(shutil.rmtree, self.WORKSPACE, ignore_errors=True)

        config_file = os.path.join(self.WORKSPACE, ModelFile.CONFIGURATION)
        with open(config_file, 'w', encoding='utf-8') as writer:
            json.dump(self.finetune_cfg, writer)

        args = dict(
            model=self.PRETRAINED_MODEL,
            work_dir=self.WORKSPACE,
            train_dataset=MsDataset.load(
                'muge', namespace='modelscope', split='train[:200]'),
            eval_dataset=MsDataset.load(
                'muge', namespace='modelscope', split='validation[:100]'),
            metrics=[Metrics.inbatch_recall],
            cfg_file=config_file)
        trainer = build_trainer(
            name=Trainers.clip_multi_modal_embedding, default_args=args)
        trainer.train()

        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
                      os.listdir(os.path.join(self.WORKSPACE, 'output')))
-
-
if __name__ == '__main__':
    # Run the tests in this module when the file is executed directly.
    unittest.main()
|