You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

NDArrayConverter.cs 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. namespace Tensorflow.NumPy
  6. {
  /// <summary>
  /// Converts the raw native buffer of an <see cref="NDArray"/> into managed .NET values:
  /// either a single scalar or a (possibly multi-dimensional) <see cref="Array"/>.
  /// </summary>
  public class NDArrayConverter
  {
      /// <summary>
      /// Reads the first element of <paramref name="nd"/> according to its <c>dtype</c>
      /// and converts it to <typeparamref name="T"/>.
      /// </summary>
      /// <typeparam name="T">Target numeric type.</typeparam>
      /// <param name="nd">Source array; only its first element is read.</param>
      /// <returns>The first element, converted via <see cref="Convert.ChangeType(object, TypeCode)"/>.</returns>
      /// <exception cref="NotImplementedException">
      /// <c>nd.dtype</c> is not float, double, int32 or int64, or <typeparamref name="T"/> is not a numeric primitive.
      /// </exception>
      /// <exception cref="ArgumentException">The buffer of <paramref name="nd"/> is smaller than one element.</exception>
      public static T Scalar<T>(NDArray nd) where T : unmanaged
          => nd.dtype switch
          {
              TF_DataType.TF_FLOAT => ConvertScalar<T>(ReadFirst<float>(nd)),
              TF_DataType.TF_DOUBLE => ConvertScalar<T>(ReadFirst<double>(nd)),
              TF_DataType.TF_INT32 => ConvertScalar<T>(ReadFirst<int>(nd)),
              TF_DataType.TF_INT64 => ConvertScalar<T>(ReadFirst<long>(nd)),
              _ => throw new NotImplementedException($"Scalar<{typeof(T).Name}> does not support dtype {nd.dtype}.")
          };

      /// <summary>
      /// Reads the first element of the native buffer of <paramref name="nd"/> as <typeparamref name="TElem"/>.
      /// </summary>
      /// <exception cref="ArgumentException">The buffer holds fewer bytes than one <typeparamref name="TElem"/>.</exception>
      static unsafe TElem ReadFirst<TElem>(NDArray nd) where TElem : unmanaged
      {
          // Guard against dereferencing past the end of an empty or undersized buffer.
          if ((long)nd.bytesize < sizeof(TElem))
          {
              throw new ArgumentException(
                  $"NDArray holds {nd.bytesize} bytes, fewer than one {typeof(TElem).Name}.", nameof(nd));
          }

          return *(TElem*)nd.data.ToPointer();
      }

      /// <summary>
      /// Converts a boxed numeric value to <typeparamref name="T"/> using <see cref="Convert.ChangeType(object, TypeCode)"/>.
      /// Float-to-integer conversions therefore round rather than truncate, and out-of-range
      /// values raise <see cref="OverflowException"/>.
      /// </summary>
      /// <exception cref="NotImplementedException"><typeparamref name="T"/> is not a numeric primitive (or enum over one).</exception>
      static T ConvertScalar<T>(object input)
      {
          // For enums GetTypeCode reports the underlying integral type; the boxed result can be unboxed to the enum.
          var code = Type.GetTypeCode(typeof(T));
          switch (code)
          {
              case TypeCode.SByte:
              case TypeCode.Byte:
              case TypeCode.Int16:
              case TypeCode.UInt16:
              case TypeCode.Int32:
              case TypeCode.UInt32:
              case TypeCode.Int64:
              case TypeCode.UInt64:
              case TypeCode.Single:
              case TypeCode.Double:
              case TypeCode.Decimal:
                  return (T)Convert.ChangeType(input, code);
              default:
                  throw new NotImplementedException($"Scalar conversion to {typeof(T).Name} is not supported.");
          }
      }

      /// <summary>
      /// Copies the raw buffer of <paramref name="nd"/> into a new managed array of element type
      /// <typeparamref name="T"/> whose dimensions are <c>nd.shape</c>. Any rank is supported.
      /// </summary>
      /// <typeparam name="T">Element type; its size must match the element size of <c>nd.dtype</c>.</typeparam>
      /// <param name="nd">Source array.</param>
      /// <returns>A <c>T[]</c>, <c>T[,]</c>, ... holding a bitwise copy of the data.</returns>
      /// <exception cref="ArgumentException">
      /// The byte size of <paramref name="nd"/> differs from the byte size of the target array
      /// (e.g. <typeparamref name="T"/> does not match the dtype), which would otherwise overrun memory.
      /// </exception>
      public static unsafe Array ToMultiDimArray<T>(NDArray nd) where T : unmanaged
      {
          var ret = Array.CreateInstance(typeof(T), nd.shape.as_int_list());

          long destinationBytes = ret.LongLength * sizeof(T);
          long sourceBytes = (long)nd.bytesize;
          if (sourceBytes != destinationBytes)
          {
              throw new ArgumentException(
                  $"NDArray holds {sourceBytes} bytes but a {typeof(T).Name} array of that shape needs {destinationBytes} bytes.",
                  nameof(nd));
          }

          // Nothing to copy; also avoids taking the address of an element that does not exist.
          if (destinationBytes == 0)
          {
              return ret;
          }

          // Keep the managed array pinned for the entire copy. A pointer obtained inside a `fixed`
          // block and returned out of it is not pinned anymore and may be invalidated by the GC.
          var handle = System.Runtime.InteropServices.GCHandle.Alloc(
              ret, System.Runtime.InteropServices.GCHandleType.Pinned);
          try
          {
              System.Buffer.MemoryCopy(
                  nd.data.ToPointer(),
                  handle.AddrOfPinnedObject().ToPointer(),
                  destinationBytes,
                  sourceBytes);
          }
          finally
          {
              handle.Free();
          }

          return ret;
      }
  }
  79. }