import os
import unittest

import torch

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
    collate_fn, train_mapping, val_mapping)
from modelscope.utils.config import Config
from modelscope.utils.constant import DownloadMode, ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


def train_worker(device_id):
    """Fine-tune the TEAM multi-modal similarity model on Caltech101.

    device_id selects the training GPU; -1 is passed when no GPU is
    available (CPU run).
    """
    model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
    ckpt_dir = './ckpt'
    os.makedirs(ckpt_dir, exist_ok=True)
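    # Training configuration for Caltech101 transfer (101 classes); it is
    # dumped to ckpt_dir so the trainer can read it via cfg_file below.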
    # Use epoch=1 for faster training here
    cfg = Config({
        'framework': 'pytorch',
        'task': 'multi-modal-similarity',
        'pipeline': {
            'type': 'multi-modal-similarity'
        },
        'model': {
            'type': 'team-multi-modal-similarity'
        },
        'dataset': {
            'name': 'Caltech101',
            'class_num': 101
        },
        'preprocessor': {},
        'train': {
            'epoch': 1,
            'batch_size': 32,
            'ckpt_dir': ckpt_dir
        },
        'evaluation': {
            'batch_size': 64
        }
    })
    cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
    cfg.dump(cfg_file)

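    # Load the Caltech101 train/validation splits from the ModelScope hub and
    # attach the TEAM preprocessing transforms for each split.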
    train_dataset = MsDataset.load(
        cfg.dataset.name,
        namespace='modelscope',
        split='train',
        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
    train_dataset = train_dataset.with_transform(train_mapping)
    val_dataset = MsDataset.load(
        cfg.dataset.name,
        namespace='modelscope',
        split='validation',
        download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
    val_dataset = val_dataset.with_transform(val_mapping)

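    # Build the TEAM image-classification trainer from the trainer registry
    # and run one training pass followed by evaluation.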
    default_args = dict(
        cfg_file=cfg_file,
        model=model_id,
        device_id=device_id,
        data_collator=collate_fn,
        train_dataset=train_dataset,
        val_dataset=val_dataset)

    trainer = build_trainer(
        name=Trainers.image_classification_team, default_args=default_args)
    trainer.train()
    trainer.evaluate()


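# Runs the transfer-training worker on GPU 0 when CUDA is available,
# otherwise on CPU (device_id=-1).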
class TEAMTransferTrainerTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        if torch.cuda.device_count() > 0:
            train_worker(device_id=0)
        else:
            train_worker(device_id=-1)
        logger.info('Training done')


if __name__ == '__main__':
    unittest.main()