
test_inference.py 3.4 kB

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import torch
from torch import nn
from torch.utils.data import DataLoader

from modelscope.metrics.builder import MetricKeys
from modelscope.metrics.sequence_classification_metric import \
    SequenceClassificationMetric
from modelscope.trainers.utils.inference import multi_gpu_test, single_gpu_test
from modelscope.utils.test_utils import (DistributedTestCase,
                                         create_dummy_test_dataset, test_level)
from modelscope.utils.torch_utils import get_dist_info, init_dist

dummy_dataset = create_dummy_test_dataset(
    torch.rand((5, )), torch.randint(0, 4, (1, )), 20)


class DummyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, feat, labels):
        x = self.linear(feat)
        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)


def test_func(dist=False):
    dummy_model = DummyModel()
    dataset = dummy_dataset.to_torch_dataset()
    dummy_loader = DataLoader(
        dataset,
        batch_size=2,
    )
    metric_class = SequenceClassificationMetric()

    if dist:
        init_dist(launcher='pytorch')

    rank, world_size = get_dist_info()
    device = torch.device(f'cuda:{rank}')
    dummy_model.cuda()

    if world_size > 1:
        from torch.nn.parallel.distributed import DistributedDataParallel
        dummy_model = DistributedDataParallel(
            dummy_model, device_ids=[torch.cuda.current_device()])
        test_func = multi_gpu_test
    else:
        test_func = single_gpu_test

    metric_results = test_func(
        dummy_model,
        dummy_loader,
        device=device,
        metric_classes=[metric_class])
    return metric_results


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class SingleGpuTestTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_single_gpu_test(self):
        metric_results = test_func()
        self.assertIn(MetricKeys.ACCURACY, metric_results)


@unittest.skipIf(not torch.cuda.is_available()
                 or torch.cuda.device_count() <= 1, 'distributed unittest')
class MultiGpuTestTest(DistributedTestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_multi_gpu_test(self):
        self.start(
            test_func,
            num_gpus=2,
            assert_callback=lambda x: self.assertIn(MetricKeys.ACCURACY, x),
            dist=True)


if __name__ == '__main__':
    unittest.main()
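
Usage note (not part of the file above): the single-GPU path can also be exercised outside of unittest by calling test_func() directly, mirroring what SingleGpuTestTest.test_single_gpu_test does. The sketch below assumes at least one CUDA device is visible and that the file is importable under the module name test_inference; both are assumptions, not something the file itself guarantees.

# Minimal sketch: run the single-GPU inference path and print the accuracy
# reported by SequenceClassificationMetric.
# Assumptions: one CUDA device is available; the module name is test_inference.
from test_inference import test_func  # module name is an assumption
from modelscope.metrics.builder import MetricKeys

results = test_func(dist=False)
print('accuracy:', results[MetricKeys.ACCURACY])

The distributed case (MultiGpuTestTest) is driven by DistributedTestCase.start with num_gpus=2, so it only runs when at least two GPUs are visible.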