You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApi.Eager.TensorHandleDevices.cs 3.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using Tensorflow.Eager;
  4. namespace TensorFlowNET.UnitTest.NativeAPI
  5. {
  6. public partial class CApiEagerTest
  7. {
  8. /// <summary>
  9. /// TEST(CAPI, TensorHandleDevices)
  10. /// </summary>
  11. [TestMethod]
  12. public unsafe void TensorHandleDevices()
  13. {
  14. var status = c_api.TF_NewStatus();
  15. static SafeContextHandle NewContext(SafeStatusHandle status)
  16. {
  17. using var opts = c_api.TFE_NewContextOptions();
  18. return c_api.TFE_NewContext(opts, status);
  19. }
  20. using var ctx = NewContext(status);
  21. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  22. using (var hcpu = TestMatrixTensorHandle())
  23. {
  24. var device_name = TFE_TensorHandleDeviceName(hcpu, status);
  25. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  26. ASSERT_TRUE(device_name.Contains("CPU:0"));
  27. var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status);
  28. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  29. ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
  30. // Disable the test if no GPU is present.
  31. string gpu_device_name = "";
  32. if (GetDeviceName(ctx, ref gpu_device_name, "GPU"))
  33. {
  34. using var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
  35. ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
  36. var retvals = new SafeTensorHandleHandle[1];
  37. using (var shape_op = ShapeOp(ctx, hgpu))
  38. {
  39. TFE_OpSetDevice(shape_op, gpu_device_name, status);
  40. ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
  41. int num_retvals;
  42. c_api.TFE_Execute(shape_op, retvals, out num_retvals, status);
  43. ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
  44. ASSERT_EQ(1, num_retvals);
  45. try
  46. {
  47. // .device of shape is GPU since the op is executed on GPU
  48. device_name = TFE_TensorHandleDeviceName(retvals[0], status);
  49. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  50. ASSERT_TRUE(device_name.Contains("GPU:0"));
  51. // .backing_device of shape is CPU since the tensor is backed by CPU
  52. backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
  53. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  54. ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
  55. }
  56. finally
  57. {
  58. retvals[0].Dispose();
  59. }
  60. }
  61. }
  62. }
  63. // not export api
  64. using var executor = TFE_ContextGetExecutorForThread(ctx);
  65. TFE_ExecutorWaitForAllPendingNodes(executor, status);
  66. ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
  67. }
  68. }
  69. }