
test_inference.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import torch
from torch import nn
from torch.utils.data import DataLoader

from modelscope.metrics.builder import MetricKeys
from modelscope.metrics.sequence_classification_metric import \
    SequenceClassificationMetric
from modelscope.models.base import Model
from modelscope.trainers import EpochBasedTrainer
from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test
from modelscope.utils.test_utils import (DistributedTestCase,
                                         create_dummy_test_dataset, test_level)
from modelscope.utils.torch_utils import get_dist_info, init_dist

# 20 dummy samples: 5-dim float features with integer labels in [0, 4).
dummy_dataset = create_dummy_test_dataset(
    torch.rand((5, )), torch.randint(0, 4, (1, )), 20)


class DummyModel(nn.Module, Model):
    # Minimal model: a linear layer plus batch norm, returning logits and a dummy loss.

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, feat, labels):
        x = self.linear(feat)
        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)


class DummyTrainer(EpochBasedTrainer):
    # Bare trainer wrapper that only stores the model (EpochBasedTrainer.__init__ is not called).

    def __init__(self, model):
        self.model = model


def test_func(dist=False):
    dummy_model = DummyModel()
    dataset = dummy_dataset.to_torch_dataset()
    dummy_loader = DataLoader(
        dataset,
        batch_size=2,
    )
    metric_class = SequenceClassificationMetric()

    if dist:
        init_dist(launcher='pytorch')

    rank, world_size = get_dist_info()
    device = torch.device(f'cuda:{rank}')
    dummy_model.cuda()

    if world_size > 1:
        # Wrap the model with DDP and use the multi-GPU evaluation loop.
        from torch.nn.parallel.distributed import DistributedDataParallel
        dummy_model = DistributedDataParallel(
            dummy_model, device_ids=[torch.cuda.current_device()])
        test_func = multi_gpu_test
    else:
        test_func = single_gpu_test

    dummy_trainer = DummyTrainer(dummy_model)

    metric_results = test_func(
        dummy_trainer,
        dummy_loader,
        device=device,
        metric_classes=[metric_class])
    return metric_results


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class SingleGpuTestTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_single_gpu_test(self):
        metric_results = test_func()
        self.assertIn(MetricKeys.ACCURACY, metric_results)


@unittest.skipIf(not torch.cuda.is_available()
                 or torch.cuda.device_count() <= 1, 'distributed unittest')
class MultiGpuTestTest(DistributedTestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_multi_gpu_test(self):
        # Spawn test_func on 2 GPUs and assert that accuracy appears in the results.
        self.start(
            test_func,
            num_gpus=2,
            assert_callback=lambda x: self.assertIn(MetricKeys.ACCURACY, x),
            dist=True)


if __name__ == '__main__':
    unittest.main()
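
# How these tests might be run (a sketch; the TEST_LEVEL environment variable
# controlling test_level() is an assumption, see modelscope.utils.test_utils):
#
#   TEST_LEVEL=1 python test_inference.py
#
# The single-GPU case needs a CUDA device; the multi-GPU case is skipped unless
# torch.cuda.device_count() > 1, and DistributedTestCase.start appears to spawn
# the distributed workers itself, passing dist=True into test_func.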