You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ShapeHelper.cs 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  namespace Tensorflow.NumPy
  {
      /// <summary>
      /// Static helpers that compute derived properties of a <see cref="Shape"/>:
      /// element count, element strides, equality, display text and flat offsets.
      /// </summary>
      internal class ShapeHelper
      {
          /// <summary>
          /// Returns the number of elements described by <paramref name="shape"/>.
          /// </summary>
          /// <remarks>
          /// A scalar (rank 0) has size 1, and any zero-length dimension makes the size 0.
          /// Negative (unknown) dimensions are skipped rather than multiplied in, so the
          /// result is the product of the known dimensions only. When <c>ndim</c> is -1
          /// the loop never executes and 1 is returned.
          /// </remarks>
          public static long GetSize(Shape shape)
          {
              // scalar
              if (shape.ndim == 0)
              {
                  return 1;
              }

              var computed = 1L;
              for (int i = 0; i < shape.ndim; i++)
              {
                  var val = shape.dims[i];
                  if (val == 0)
                  {
                      // An empty axis empties the whole tensor.
                      return 0;
                  }

                  if (val < 0)
                  {
                      // Unknown dimension: intentionally left out of the product.
                      continue;
                  }

                  computed *= val;
              }

              return computed;
          }

          /// <summary>
          /// Computes row-major (C-order) strides, measured in elements, for
          /// <paramref name="shape"/>. The last axis has stride 1 and each earlier axis
          /// has the next axis's stride multiplied by the next axis's length.
          /// </summary>
          /// <returns>An array with one stride per axis; empty for a scalar.</returns>
          /// <remarks>
          /// Unknown (negative) dimensions are not special-cased and propagate into the
          /// strides of earlier axes. For an unknown rank (<c>ndim</c> == -1) the array
          /// allocation throws <see cref="OverflowException"/>.
          /// </remarks>
          public static long[] GetStrides(Shape shape)
          {
              var strides = new long[shape.ndim];
              if (shape.ndim == 0)
              {
                  return strides;
              }

              strides[strides.Length - 1] = 1;
              for (int idx = strides.Length - 1; idx >= 1; idx--)
              {
                  strides[idx - 1] = strides[idx] * shape.dims[idx];
              }

              return strides;
          }

          /// <summary>
          /// Compares <paramref name="shape"/> against another <see cref="Shape"/> or a
          /// <c>long[]</c> of dimensions.
          /// </summary>
          /// <param name="shape">The shape being compared.</param>
          /// <param name="target">A <see cref="Shape"/>, a <c>long[]</c>, or anything else.</param>
          /// <returns>
          /// <c>true</c> when the ranks match and the dimensions are equal element-wise.
          /// Two unknown-rank shapes are never equal. Any other <paramref name="target"/>
          /// type, including <c>null</c>, yields <c>false</c>.
          /// </returns>
          public static bool Equals(Shape shape, object target)
          {
              switch (target)
              {
                  case Shape other:
                      // Unknown-rank shapes are deliberately treated as unequal, even to each other.
                      if (shape.ndim == -1 && other.ndim == -1)
                      {
                          return false;
                      }

                      if (shape.ndim != other.ndim)
                      {
                          return false;
                      }

                      return Enumerable.SequenceEqual(other.dims, shape.dims);

                  case long[] dims:
                      if (shape.ndim != dims.Length)
                      {
                          return false;
                      }

                      return Enumerable.SequenceEqual(shape.dims, dims);

                  default:
                      return false;
              }
          }

          /// <summary>
          /// Formats <paramref name="shape"/> as a Python/NumPy-style tuple, for example
          /// <c>()</c>, <c>(3,)</c> or <c>(2, None)</c>. An unknown (-1) dimension prints
          /// as <c>None</c>, and an unknown-rank shape prints as <c>&lt;unknown&gt;</c>.
          /// </summary>
          public static string ToString(Shape shape)
          {
              return shape.ndim switch
              {
                  -1 => "<unknown>",
                  0 => "()",
                  // Each dimension is mapped individually, so the 1-D case also renders
                  // unknown dims as None and no culture-specific minus sign can leak in.
                  1 => $"({FormatDim(shape.dims[0])},)",
                  _ => $"({string.Join(", ", shape.dims.Select(FormatDim))})"
              };
          }

          /// <summary>
          /// Renders one dimension: -1 becomes <c>None</c>; any other value is written
          /// as invariant-culture integer text.
          /// </summary>
          private static string FormatDim(long dim)
              => dim == -1 ? "None" : dim.ToString(System.Globalization.CultureInfo.InvariantCulture);

          /// <summary>
          /// Converts a multi-dimensional index into a flat element offset using
          /// <c>shape.strides</c>.
          /// </summary>
          /// <param name="shape">The shape whose strides are used.</param>
          /// <param name="indices">
          /// One index per leading axis. If fewer indices than axes are supplied, the
          /// trailing axes contribute nothing (as if indexed at 0). Indices are not
          /// bounds-checked against the dimensions.
          /// </param>
          /// <returns>The sum of <c>strides[i] * indices[i]</c> over the supplied indices.</returns>
          public static long GetOffset(Shape shape, params int[] indices)
          {
              // A scalar addressed with a single index: the index itself is the offset.
              if (shape.ndim == 0 && indices.Length == 1)
              {
                  return indices[0];
              }

              long offset = 0;
              var strides = shape.strides;
              for (int i = 0; i < indices.Length; i++)
              {
                  offset += strides[i] * indices[i];
              }

              return offset;
          }
      }
  }