You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Eager.TensorHandleDevices.cs 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.Eager;
  namespace Tensorflow.Native.UnitTest.Eager
  {
      public partial class CApiEagerTest
      {
          /// <summary>
          /// Port of the C API test <c>TEST(CAPI, TensorHandleDevices)</c>.
          /// Checks that the test matrix handle reports a device name and a backing
          /// device name containing "CPU:0". When a GPU device is found, it also copies
          /// the handle to that GPU and runs a Shape op on it with the op device set to
          /// the GPU. It then checks that the op's single result reports "GPU:0" as its
          /// device and "CPU:0" as its backing device.
          /// </summary>
          [TestMethod]
          public unsafe void TensorHandleDevices()
          {
              // Declared first so it is disposed last: using declarations unwind in
              // reverse order, so the executor and the context created with this
              // status are released before it.
              using var status = c_api.TF_NewStatus();

              // Creates an eager context. The options handle is only needed for the
              // creation call, so it is disposed on return.
              static SafeContextHandle NewContext(SafeStatusHandle status)
              {
                  using var opts = c_api.TFE_NewContextOptions();
                  return c_api.TFE_NewContext(opts, status);
              }

              using var ctx = NewContext(status);
              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

              using (var hcpu = TestMatrixTensorHandle())
              {
                  var device_name = TFE_TensorHandleDeviceName(hcpu, status);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  ASSERT_TRUE(device_name.Contains("CPU:0"));

                  var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  ASSERT_TRUE(backing_device_name.Contains("CPU:0"));

                  // Only the GPU section below is skipped when GetDeviceName finds no
                  // "GPU" device. The CPU checks above always run.
                  string gpu_device_name = "";
                  if (GetDeviceName(ctx, ref gpu_device_name, "GPU"))
                  {
                      using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                      var retvals = new SafeEagerTensorHandle[1];
                      using (var shape_op = ShapeOp(ctx, hgpu))
                      {
                          TFE_OpSetDevice(shape_op, gpu_device_name, status);
                          ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                          c_api.TFE_Execute(shape_op, retvals, out int num_retvals, status);

                          // Start the try immediately after execution so the output
                          // handle is released even if a status or count assertion
                          // fails. The null-conditional avoids a NullReferenceException
                          // masking the real failure when no handle was produced.
                          try
                          {
                              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                              ASSERT_EQ(1, num_retvals);

                              // The op was placed on the GPU, so the result's device is the GPU.
                              device_name = TFE_TensorHandleDeviceName(retvals[0], status);
                              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                              ASSERT_TRUE(device_name.Contains("GPU:0"));

                              // The test expects the result's backing device to still be CPU.
                              backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
                              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                              ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
                          }
                          finally
                          {
                              retvals[0]?.Dispose();
                          }
                      }
                  }
              }

              // NOTE(review): this was originally commented "not export api". Presumably the
              // executor functions are not part of the exported C API surface; confirm.
              // Waits on the executor returned for this context and thread before the
              // context and status are disposed, and checks the wait succeeded.
              using var executor = TFE_ContextGetExecutorForThread(ctx);
              TFE_ExecutorWaitForAllPendingNodes(executor, status);
              ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
          }
      }
  }