You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Shape.cs 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. namespace Tensorflow
  6. {
  7. public class Shape
  8. {
  9. public int ndim => _dims == null ? -1 : _dims.Length;
  10. long[] _dims;
  11. public long[] dims => _dims;
  12. private Shape()
  13. {
  14. }
  15. public Shape(TensorShapeProto proto)
  16. {
  17. _dims = proto.Dim.Select(x => x.Size).ToArray();
  18. }
  19. public void Deconstruct(out long h, out long w)
  20. {
  21. h = dims[0];
  22. w = dims[1];
  23. }
  24. public Shape(params int[] dims)
  25. => _dims = dims?.Select(x => Convert.ToInt64(x))?.ToArray();
  26. public Shape(params long[] dims)
  27. => _dims = dims;
  28. public static implicit operator Shape(int dims)
  29. => new Shape(dims);
  30. public static implicit operator Shape(long[] dims)
  31. => dims == null ? null : new Shape(dims);
  32. public static implicit operator Shape(int[] dims)
  33. => dims == null ? null : new Shape(dims);
  34. public static implicit operator Shape((int, int) dims)
  35. => new Shape(dims.Item1, dims.Item2);
  36. public static implicit operator Shape((long, long) dims)
  37. => new Shape(dims.Item1, dims.Item2);
  38. public static implicit operator Shape((int, int, int) dims)
  39. => new Shape(dims.Item1, dims.Item2, dims.Item3);
  40. public static implicit operator Shape((long, long, long) dims)
  41. => new Shape(dims.Item1, dims.Item2, dims.Item3);
  42. public static implicit operator Shape((int, int, int, int) dims)
  43. => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
  44. public static implicit operator Shape((long, long, long, long) dims)
  45. => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
  46. public static implicit operator int[](Shape shape)
  47. => shape.dims.Select(x => (int)x).ToArray();
  48. public static implicit operator long[](Shape shape)
  49. => shape.dims;
  50. public bool IsEmpty => size == 0;
  51. public bool IsScalar => ndim == 0;
  52. public bool IsNull => _dims == null;
  53. public bool IsFullyDefined => ndim > -1 && dims.Count(x => x < 1) == 0;
  54. public static Shape Scalar => new Shape(new long[0]);
  55. public static Shape Null => new Shape();
  56. public long this[int n]
  57. {
  58. get => dims[n];
  59. set => dims[n] = value;
  60. }
  61. public Shape this[Slice slice]
  62. {
  63. get
  64. {
  65. if (!slice.Stop.HasValue)
  66. slice.Stop = dims.Length - slice.Start + 1;
  67. if (slice.Start.HasValue == false || slice.Length.HasValue == false)
  68. throw new ArgumentException("Slice must has Start and Length.");
  69. return new Shape(dims.Skip(slice.Start.Value)
  70. .Take(slice.Length.Value)
  71. .ToArray());
  72. }
  73. }
  74. /// <summary>
  75. /// Returns the size this shape represents.
  76. /// </summary>
  77. public long size
  78. {
  79. get
  80. {
  81. // scalar
  82. if (ndim == 0)
  83. return 1;
  84. var computed = 1L;
  85. for (int i = 0; i < _dims.Length; i++)
  86. {
  87. var val = _dims[i];
  88. if (val == 0)
  89. return 0;
  90. else if (val < 0)
  91. continue;
  92. computed *= val;
  93. }
  94. return computed;
  95. }
  96. }
  97. public bool is_compatible_with(Shape shape2)
  98. {
  99. if (dims != null && shape2.dims != null)
  100. {
  101. if (dims.Contains(-1) || shape2.dims.Contains(-1))
  102. return true;
  103. if (size != shape2.size)
  104. return false;
  105. }
  106. return true;
  107. }
  108. public Shape with_rank_at_least(int rank)
  109. {
  110. if (ndim < rank)
  111. throw new ValueError($"Shape {this} must have rank at least {rank}");
  112. else
  113. return this;
  114. }
  115. public Shape with_rank(int rank)
  116. {
  117. return merge_with(unknown_shape(rank: rank));
  118. }
  119. /// <summary>
  120. /// Returns an unknown Shape, optionally with a known rank.
  121. /// </summary>
  122. /// <param name="rank"></param>
  123. /// <returns></returns>
  124. public Shape unknown_shape(int rank = -1)
  125. {
  126. if (rank == -1)
  127. return Shape.Null;
  128. else
  129. return new Shape(Enumerable.Repeat(-1L, rank).ToArray());
  130. }
  131. public Shape concatenate(long[] other)
  132. {
  133. return concatenate(new Shape(other));
  134. }
  135. /// <summary>
  136. /// Returns the concatenation of the dimension in `self` and `other`.
  137. /// </summary>
  138. /// <param name="other"></param>
  139. /// <returns></returns>
  140. public Shape concatenate(Shape other)
  141. {
  142. var otherShape = other;
  143. if (ndim < 0 || otherShape.ndim < 0)
  144. return Shape.Null;
  145. else
  146. {
  147. var concatenate_dims = new long[ndim + otherShape.ndim];
  148. for (int i = 0; i < ndim; i++)
  149. concatenate_dims[i] = dims[i];
  150. for (int i = 0; i < otherShape.ndim; i++)
  151. concatenate_dims[ndim + i] = otherShape.dims[i];
  152. return new Shape(concatenate_dims);
  153. }
  154. }
  155. /// <summary>
  156. /// Returns a `Shape` combining the information in `self` and `other`.
  157. /// </summary>
  158. /// <param name="other"></param>
  159. /// <returns></returns>
  160. public Shape merge_with(Shape other)
  161. {
  162. if (dims == null)
  163. return other;
  164. var new_dims = new List<long>();
  165. foreach (var i in Enumerable.Range(0, ndim))
  166. {
  167. var dim = new Dimension(dims[i]);
  168. var merged = dim.merge_with(new Dimension(other.dims[i]));
  169. new_dims.Add(merged.value);
  170. }
  171. return new Shape(new_dims.ToArray());
  172. }
  173. public int[] as_int_list()
  174. {
  175. return _dims.Select(x => (int)x).ToArray();
  176. }
  177. public void assert_has_rank(int rank)
  178. {
  179. if (rank != ndim)
  180. throw new ValueError(String.Format("Shape {0} must have rank {1}", ndim, rank));
  181. }
  182. public override bool Equals(object obj)
  183. {
  184. switch (obj)
  185. {
  186. case Shape shape1:
  187. if (ndim == -1 && shape1.ndim == -1)
  188. return false;
  189. else if (ndim != shape1.ndim)
  190. return false;
  191. return Enumerable.SequenceEqual(shape1.dims, dims);
  192. case long[] shape2:
  193. if (ndim != shape2.Length)
  194. return false;
  195. return Enumerable.SequenceEqual(dims, shape2);
  196. default:
  197. return false;
  198. }
  199. }
  200. public override string ToString()
  201. => ndim switch
  202. {
  203. -1 => "<unknown>",
  204. 0 => "()",
  205. 1 => $"({dims[0]},)",
  206. _ => $"({string.Join(", ", _dims).Replace("-1", "None")})"
  207. };
  208. }
  209. }