You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_device.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import time
  6. import unittest
  7. import torch
  8. from modelscope.utils.constant import Frameworks
  9. from modelscope.utils.device import (create_device, device_placement,
  10. verify_device)
  11. # import tensorflow must be imported after torch is imported when using tf1.15
  12. import tensorflow as tf # isort:skip
  13. class DeviceTest(unittest.TestCase):
  14. def setUp(self):
  15. print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
  16. def tearDown(self):
  17. super().tearDown()
  18. def test_verify(self):
  19. device_name, device_id = verify_device('cpu')
  20. self.assertEqual(device_name, 'cpu')
  21. self.assertTrue(device_id is None)
  22. device_name, device_id = verify_device('CPU')
  23. self.assertEqual(device_name, 'cpu')
  24. device_name, device_id = verify_device('gpu')
  25. self.assertEqual(device_name, 'gpu')
  26. self.assertTrue(device_id == 0)
  27. device_name, device_id = verify_device('cuda')
  28. self.assertEqual(device_name, 'gpu')
  29. self.assertTrue(device_id == 0)
  30. device_name, device_id = verify_device('cuda:0')
  31. self.assertEqual(device_name, 'gpu')
  32. self.assertTrue(device_id == 0)
  33. device_name, device_id = verify_device('gpu:1')
  34. self.assertEqual(device_name, 'gpu')
  35. self.assertTrue(device_id == 1)
  36. with self.assertRaises(AssertionError):
  37. verify_device('xgu')
  38. with self.assertRaises(AssertionError):
  39. verify_device('')
  40. with self.assertRaises(AssertionError):
  41. verify_device(None)
  42. def test_create_device_torch(self):
  43. if torch.cuda.is_available():
  44. target_device_type = 'cuda'
  45. target_device_index = 0
  46. else:
  47. target_device_type = 'cpu'
  48. target_device_index = None
  49. device = create_device('gpu')
  50. self.assertTrue(isinstance(device, torch.device))
  51. self.assertTrue(device.type == target_device_type)
  52. self.assertTrue(device.index == target_device_index)
  53. device = create_device('gpu:0')
  54. self.assertTrue(isinstance(device, torch.device))
  55. self.assertTrue(device.type == target_device_type)
  56. self.assertTrue(device.index == target_device_index)
  57. device = create_device('cuda')
  58. self.assertTrue(device.type == target_device_type)
  59. self.assertTrue(isinstance(device, torch.device))
  60. self.assertTrue(device.index == target_device_index)
  61. device = create_device('cuda:0')
  62. self.assertTrue(isinstance(device, torch.device))
  63. self.assertTrue(device.type == target_device_type)
  64. self.assertTrue(device.index == target_device_index)
  65. def test_device_placement_cpu(self):
  66. with device_placement(Frameworks.torch, 'cpu'):
  67. pass
  68. @unittest.skip('skip this test to avoid debug logging.')
  69. def test_device_placement_tf_gpu(self):
  70. tf.debugging.set_log_device_placement(True)
  71. with device_placement(Frameworks.tf, 'gpu:0'):
  72. a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  73. b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
  74. c = tf.matmul(a, b)
  75. s = tf.Session()
  76. s.run(c)
  77. tf.debugging.set_log_device_placement(False)
  78. def test_device_placement_torch_gpu(self):
  79. with device_placement(Frameworks.torch, 'gpu:0'):
  80. if torch.cuda.is_available():
  81. self.assertEqual(torch.cuda.current_device(), 0)
  82. if __name__ == '__main__':
  83. unittest.main()