You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_team_transfer_trainer.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import os
  2. import unittest
  3. import json
  4. import requests
  5. import torch
  6. import torch.distributed as dist
  7. import torch.multiprocessing as mp
  8. from modelscope.hub.snapshot_download import snapshot_download
  9. from modelscope.metainfo import Trainers
  10. from modelscope.msdatasets import MsDataset
  11. from modelscope.trainers import build_trainer
  12. from modelscope.trainers.multi_modal.team.team_trainer_utils import (
  13. collate_fn, train_mapping, val_mapping)
  14. from modelscope.utils.config import Config
  15. from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile
  16. from modelscope.utils.logger import get_logger
  17. from modelscope.utils.test_utils import test_level
# Module-level logger; the test case below uses it to report completion.
logger = get_logger()
  19. def train_worker(device_id):
  20. model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
  21. ckpt_dir = './ckpt'
  22. os.makedirs(ckpt_dir, exist_ok=True)
  23. # Use epoch=1 for faster training here
  24. cfg = Config({
  25. 'framework': 'pytorch',
  26. 'task': 'multi-modal-similarity',
  27. 'pipeline': {
  28. 'type': 'multi-modal-similarity'
  29. },
  30. 'model': {
  31. 'type': 'team-multi-modal-similarity'
  32. },
  33. 'dataset': {
  34. 'name': 'Caltech101',
  35. 'class_num': 101
  36. },
  37. 'preprocessor': {},
  38. 'train': {
  39. 'epoch': 1,
  40. 'batch_size': 32,
  41. 'ckpt_dir': ckpt_dir
  42. },
  43. 'evaluation': {
  44. 'batch_size': 64
  45. }
  46. })
  47. cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
  48. cfg.dump(cfg_file)
  49. train_dataset = MsDataset.load(
  50. cfg.dataset.name,
  51. namespace='modelscope',
  52. split='train',
  53. download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
  54. train_dataset = train_dataset.with_transform(train_mapping)
  55. val_dataset = MsDataset.load(
  56. cfg.dataset.name,
  57. namespace='modelscope',
  58. split='validation',
  59. download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
  60. val_dataset = val_dataset.with_transform(val_mapping)
  61. default_args = dict(
  62. cfg_file=cfg_file,
  63. model=model_id,
  64. device_id=device_id,
  65. data_collator=collate_fn,
  66. train_dataset=train_dataset,
  67. val_dataset=val_dataset)
  68. trainer = build_trainer(
  69. name=Trainers.image_classification_team, default_args=default_args)
  70. trainer.train()
  71. trainer.evaluate()
  72. class TEAMTransferTrainerTest(unittest.TestCase):
  73. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  74. def test_trainer(self):
  75. if torch.cuda.device_count() > 0:
  76. train_worker(device_id=0)
  77. else:
  78. train_worker(device_id=-1)
  79. logger.info('Training done')
if __name__ == '__main__':
    # Run this module's test cases when the file is executed as a script.
    unittest.main()