You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApi.Eager.TensorHandleDevices.cs 3.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using Tensorflow.Eager;
  namespace TensorFlowNET.UnitTest.NativeAPI
  {
      public partial class CApiEagerTest
      {
          /// <summary>
          /// TEST(CAPI, TensorHandleDevices)
          /// <para>
          /// Checks that the handle returned by <c>TestMatrixTensorHandle()</c> reports a
          /// device name and a backing device name that both contain "CPU:0".
          /// </para>
          /// <para>
          /// If <c>GetDeviceName</c> finds a "GPU" device, the test also does the following:
          /// copies the handle to that device, builds a Shape op on the copy, pins the op to
          /// the GPU, and executes it. It then checks that the single result handle has a
          /// device name containing "GPU:0" and a backing device name containing "CPU:0".
          /// </para>
          /// <para>
          /// Finally, it waits for all pending nodes on the context's executor and checks the
          /// status.
          /// </para>
          /// </summary>
          [TestMethod]
          public unsafe void TensorHandleDevices()
          {
              // The status handle is disposed at the end of the test.
              // The previous version never released it.
              using var status = c_api.TF_NewStatus();
              using var ctx = NewContext(status);
              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

              using (var hcpu = TestMatrixTensorHandle())
              {
                  var device_name = TFE_TensorHandleDeviceName(hcpu, status);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  ASSERT_TRUE(device_name.Contains("CPU:0"));

                  var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  ASSERT_TRUE(backing_device_name.Contains("CPU:0"));

                  // The GPU half of the test only runs when a GPU device is present.
                  string gpu_device_name = "";
                  if (GetDeviceName(ctx, ref gpu_device_name, "GPU"))
                  {
                      using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                      var retvals = new SafeTensorHandleHandle[1];
                      using (var shape_op = ShapeOp(ctx, hgpu))
                      {
                          TFE_OpSetDevice(shape_op, gpu_device_name, status);
                          ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                          c_api.TFE_Execute(shape_op, retvals, out int num_retvals, status);

                          // The try starts right after Execute so the returned handle is
                          // released even if the status or retval-count assertions fail.
                          try
                          {
                              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                              ASSERT_EQ(1, num_retvals);

                              // .device of shape is GPU since the op is executed on GPU.
                              device_name = TFE_TensorHandleDeviceName(retvals[0], status);
                              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                              ASSERT_TRUE(device_name.Contains("GPU:0"));

                              // .backing_device of shape is CPU since the tensor is backed by CPU.
                              backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
                              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                              ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
                          }
                          finally
                          {
                              // Null-safe: the slot is never populated if Execute failed.
                              retvals[0]?.Dispose();
                          }
                      }
                  }
              }

              // The original port annotated these executor calls as "not export api".
              // Presumably they are not part of the exported C API surface; confirm.
              using var executor = TFE_ContextGetExecutorForThread(ctx);
              TFE_ExecutorWaitForAllPendingNodes(executor, status);
              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

              // Creates an eager context from default options.
              // The options handle is released once the context has been created.
              static SafeContextHandle NewContext(SafeStatusHandle status)
              {
                  using var opts = c_api.TFE_NewContextOptions();
                  return c_api.TFE_NewContext(opts, status);
              }
          }
      }
  }