You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CApi.Eager.TensorHandleDevices.cs 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using Tensorflow.Eager;
  5. using Buffer = System.Buffer;
  namespace TensorFlowNET.UnitTest.Eager
  {
      public partial class CApiEagerTest
      {
          /// <summary>
          /// Port of the TensorFlow C API test <c>TEST(CAPI, TensorHandleDevices)</c>.
          /// </summary>
          /// <remarks>
          /// Checks that the handle returned by <c>TestMatrixTensorHandle</c> reports a device
          /// name and a backing device name that both contain "CPU:0".
          /// The GPU section only runs when <c>GetDeviceName</c> finds a "GPU" device. It then:
          /// <list type="bullet">
          /// <item>copies the handle to that device;</item>
          /// <item>executes a Shape op there;</item>
          /// <item>expects the result's device name to contain "GPU:0";</item>
          /// <item>expects the result's backing device name to still contain "CPU:0".</item>
          /// </list>
          /// </remarks>
          [TestMethod]
          public unsafe void TensorHandleDevices()
          {
              var status = c_api.TF_NewStatus();
              try
              {
                  var opts = TFE_NewContextOptions();
                  var ctx = TFE_NewContext(opts, status);
                  TFE_DeleteContextOptions(opts);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                  var hcpu = TestMatrixTensorHandle();
                  var device_name = TFE_TensorHandleDeviceName(hcpu, status);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  // The reported name is passed as the failure message so a mismatch shows what was actually returned.
                  ASSERT_TRUE(device_name.Contains("CPU:0"), device_name);

                  var backing_device_name = TFE_TensorHandleBackingDeviceName(hcpu, status);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  ASSERT_TRUE(backing_device_name.Contains("CPU:0"), backing_device_name);

                  // The GPU part of the test is skipped when no GPU device is present.
                  string gpu_device_name = "";
                  if (GetDeviceName(ctx, ref gpu_device_name, "GPU"))
                  {
                      var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                      var shape_op = ShapeOp(ctx, hgpu);
                      TFE_OpSetDevice(shape_op, gpu_device_name, status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                      var retvals = new IntPtr[1];
                      int num_retvals = 1;
                      c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));

                      // .device of shape is GPU since the op is executed on GPU
                      device_name = TFE_TensorHandleDeviceName(retvals[0], status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      ASSERT_TRUE(device_name.Contains("GPU:0"), device_name);

                      // .backing_device of shape is CPU since the tensor is backed by CPU
                      backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
                      ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                      ASSERT_TRUE(backing_device_name.Contains("CPU:0"), backing_device_name);

                      TFE_DeleteOp(shape_op);
                      TFE_DeleteTensorHandle(retvals[0]);
                      TFE_DeleteTensorHandle(hgpu);
                  }

                  TFE_DeleteTensorHandle(hcpu);

                  // Wait for all pending nodes on this thread's executor and check the status
                  // before the executor and context are deleted.
                  // NOTE(review): the original comment here read "not export api"; presumably these
                  // executor functions were not exported by some native builds at the time — confirm.
                  var executor = TFE_ContextGetExecutorForThread(ctx);
                  TFE_ExecutorWaitForAllPendingNodes(executor, status);
                  ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
                  TFE_DeleteExecutor(executor);
                  TFE_DeleteContext(ctx);
              }
              finally
              {
                  // The status was previously never released. The C++ test owns it via a
                  // unique_ptr with TF_DeleteStatus, so it is freed here even if an assertion throws.
                  c_api.TF_DeleteStatus(status);
              }
          }
      }
  }