You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_clip_multi_modal_embedding_trainer.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import os
  2. import tempfile
  3. import unittest
  4. import requests
  5. import torch
  6. import torch.distributed as dist
  7. import torch.multiprocessing as mp
  8. from modelscope.hub.snapshot_download import snapshot_download
  9. from modelscope.metainfo import Trainers
  10. from modelscope.trainers import build_trainer
  11. from modelscope.utils.constant import ModelFile
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.test_utils import test_level
  14. logger = get_logger()
  15. def clip_train_worker(local_rank, ngpus, node_size, node_rank):
  16. global_rank = local_rank + node_rank * ngpus
  17. dist_world_size = node_size * ngpus
  18. dist.init_process_group(
  19. backend='nccl', world_size=dist_world_size, rank=global_rank)
  20. model_id = 'damo/multi-modal_clip-vit-large-patch14_zh'
  21. local_model_dir = snapshot_download(model_id)
  22. default_args = dict(
  23. cfg_file='{}/{}'.format(local_model_dir, ModelFile.CONFIGURATION),
  24. model=model_id,
  25. device_id=local_rank)
  26. trainer = build_trainer(
  27. name=Trainers.clip_multi_modal_embedding, default_args=default_args)
  28. trainer.train()
  29. trainer.evaluate()
  30. class CLIPMultiModalEmbeddingTrainerTest(unittest.TestCase):
  31. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  32. def test_trainer(self):
  33. os.environ['MASTER_ADDR'] = '127.0.0.1'
  34. os.environ['MASTER_PORT'] = '2001'
  35. NODE_SIZE, NODE_RANK = 1, 0
  36. logger.info('Train clip with {} machines'.format(NODE_SIZE))
  37. ngpus = torch.cuda.device_count()
  38. logger.info('Machine: {} has {} GPUs'.format(NODE_RANK, ngpus))
  39. mp.spawn(
  40. clip_train_worker,
  41. nprocs=ngpus,
  42. args=(ngpus, NODE_SIZE, NODE_RANK))
  43. logger.info('Training done')
  44. if __name__ == '__main__':
  45. unittest.main()
  46. ...