# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import time
import unittest

import torch

from modelscope.utils.constant import Frameworks
from modelscope.utils.device import (create_device, device_placement,
                                     verify_device)

# import tensorflow must be imported after torch is imported when using tf1.15
import tensorflow as tf  # isort:skip


class DeviceTest(unittest.TestCase):
    """Tests for ``verify_device``, ``create_device`` and ``device_placement``."""

    def setUp(self):
        super().setUp()
        print(f'Testing {type(self).__name__}.{self._testMethodName}')

    def tearDown(self):
        super().tearDown()

    def test_verify(self):
        """verify_device normalizes device specs and rejects invalid ones."""
        device_name, device_id = verify_device('cpu')
        self.assertEqual(device_name, 'cpu')
        self.assertIsNone(device_id)

        # Device names are matched case-insensitively.
        device_name, _ = verify_device('CPU')
        self.assertEqual(device_name, 'cpu')

        # Every GPU spelling normalizes to 'gpu'; a bare name implies index 0.
        gpu_cases = [('gpu', 0), ('cuda', 0), ('cuda:0', 0), ('gpu:1', 1)]
        for spec, expected_id in gpu_cases:
            with self.subTest(device=spec):
                device_name, device_id = verify_device(spec)
                self.assertEqual(device_name, 'gpu')
                self.assertEqual(device_id, expected_id)

        for bad_spec in ('xgu', '', None):
            with self.subTest(device=bad_spec):
                with self.assertRaises(AssertionError):
                    verify_device(bad_spec)

    def test_create_device_torch(self):
        """create_device returns a torch.device for every GPU spelling.

        With CUDA available the expected device is ``cuda:0``; without it the
        request is expected to fall back to ``cpu`` with no index.
        """
        if torch.cuda.is_available():
            target_device_type, target_device_index = 'cuda', 0
        else:
            target_device_type, target_device_index = 'cpu', None

        for spec in ('gpu', 'gpu:0', 'cuda', 'cuda:0'):
            with self.subTest(device=spec):
                device = create_device(spec)
                self.assertIsInstance(device, torch.device)
                self.assertEqual(device.type, target_device_type)
                self.assertEqual(device.index, target_device_index)

    def test_device_placement_cpu(self):
        """Entering a torch CPU placement context must not raise."""
        with device_placement(Frameworks.torch, 'cpu'):
            pass

    @unittest.skip('skip this test to avoid debug logging.')
    def test_device_placement_tf_gpu(self):
        """Run a small TF graph inside a 'gpu:0' placement context."""
        tf.debugging.set_log_device_placement(True)
        try:
            with device_placement(Frameworks.tf, 'gpu:0'):
                a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
                b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
                c = tf.matmul(a, b)
                # tf.Session is the TF1 graph-mode API (see the tf1.15 note on
                # the import); the context manager guarantees it is closed.
                with tf.Session() as sess:
                    sess.run(c)
        finally:
            # Restore the process-global logging flag even if the body fails,
            # so later tests are not flooded with placement logs.
            tf.debugging.set_log_device_placement(False)

    def test_device_placement_torch_gpu(self):
        """Under a torch 'gpu:0' placement, CUDA device 0 is current.

        The device check only runs when CUDA is available; otherwise the
        test only verifies that entering the context does not raise.
        """
        with device_placement(Frameworks.torch, 'gpu:0'):
            if torch.cuda.is_available():
                self.assertEqual(torch.cuda.current_device(), 0)


if __name__ == '__main__':
    unittest.main()