You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Shape.cs 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. /*****************************************************************************
  2. Copyright 2021 Haiping Chen. All Rights Reserved.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. ******************************************************************************/
  13. using System;
  14. using System.Collections.Generic;
  15. using System.Linq;
  16. using System.Text;
  17. using Tensorflow.NumPy;
  namespace Tensorflow
  {
      /// <summary>
      /// The dimensions of a tensor. A null <see cref="dims"/> array means the rank is
      /// unknown (see <see cref="Null"/>); a dimension value of -1 means that dimension's
      /// size is unknown (see <see cref="unknown_shape"/>).
      /// </summary>
      public class Shape
      {
          /// <summary>Number of dimensions, or -1 when the rank is unknown.</summary>
          public int ndim => _dims == null ? -1 : _dims.Length;

          long[] _dims;

          /// <summary>
          /// The underlying dimension array (null when the rank is unknown).
          /// NOTE(review): this is the live internal array, not a copy; writing to it
          /// directly bypasses the stride-cache invalidation done by the indexer setter.
          /// </summary>
          public long[] dims => _dims;

          /// <summary>Alias of <see cref="ndim"/>.</summary>
          public int rank => ndim;

          // Computed by ShapeHelper.GetStrides on first access and cached; cleared by the
          // int indexer's setter because changing a dimension changes the strides.
          long[] _strides;

          /// <summary>Strides for this shape, computed lazily and cached.</summary>
          public long[] strides
          {
              get
              {
                  _strides = _strides ?? ShapeHelper.GetStrides(this);
                  return _strides;
              }
          }

          #region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
          // Length + Slice(int, int) (together with this[int]) let the compiler lower
          // index/range expressions such as shape[^1] and shape[1..3] (the latter yields long[]).
          public int Length => ndim;

          /// <summary>Copies <paramref name="length"/> dimensions starting at <paramref name="start"/>.</summary>
          public long[] Slice(int start, int length)
          {
              var slice = new long[length];
              Array.Copy(_dims, start, slice, 0, length);
              return slice;
          }
          #endregion

          private Shape()
          {
          }

          /// <summary>
          /// Builds a shape from a serialized <see cref="TensorShapeProto"/>. A proto flagged
          /// as unknown-rank yields a shape with null dims instead of a rank-0 (scalar) shape
          /// built from its empty Dim list.
          /// </summary>
          public Shape(TensorShapeProto proto)
          {
              _dims = proto.UnknownRank
                  ? null
                  : proto.Dim.Select(x => x.Size).ToArray();
          }

          /// <summary>Deconstructs a shape into its first two dimensions.</summary>
          public void Deconstruct(out long h, out long w)
          {
              h = dims[0];
              w = dims[1];
          }

          /// <summary>Creates a shape from int dimensions; a null array gives an unknown-rank shape.</summary>
          public Shape(params int[] dims)
              => _dims = dims?.Select(x => Convert.ToInt64(x)).ToArray();

          /// <summary>Creates a shape that wraps <paramref name="dims"/> directly (no copy).</summary>
          public Shape(params long[] dims)
              => _dims = dims;

          /// <summary>A single int becomes a rank-1 shape with that one dimension.</summary>
          public static implicit operator Shape(int dims)
              => new Shape(dims);

          public static implicit operator Shape(long[] dims)
              => dims == null ? null : new Shape(dims);

          public static implicit operator Shape(int[] dims)
              => dims == null ? null : new Shape(dims);

          public static implicit operator Shape((int, int) dims)
              => new Shape(dims.Item1, dims.Item2);

          public static implicit operator Shape((long, long) dims)
              => new Shape(dims.Item1, dims.Item2);

          public static implicit operator Shape((int, int, int) dims)
              => new Shape(dims.Item1, dims.Item2, dims.Item3);

          public static implicit operator Shape((long, long, long) dims)
              => new Shape(dims.Item1, dims.Item2, dims.Item3);

          public static implicit operator Shape((int, int, int, int) dims)
              => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

          public static implicit operator Shape((long, long, long, long) dims)
              => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

          public static implicit operator Shape((int, int, int, int, int) dims)
              => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);

          public static implicit operator Shape((long, long, long, long, long) dims)
              => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);

          /// <summary>
          /// Narrows each dimension to int (unchecked cast). A null Shape or an unknown-rank
          /// Shape converts to null, mirroring the null handling of the reverse conversions.
          /// </summary>
          public static implicit operator int[](Shape shape)
              => shape?.dims?.Select(x => (int)x).ToArray();

          /// <summary>Returns the underlying dims array (null for a null Shape or unknown rank).</summary>
          public static implicit operator long[](Shape shape)
              => shape?.dims;

          public static implicit operator Tensor(Shape shape)
              => constant_op.constant(shape);

          /// <summary>True when <see cref="size"/> is 0.</summary>
          public bool IsEmpty => size == 0;

          /// <summary>True for rank-0 shapes.</summary>
          public bool IsScalar => ndim == 0;

          /// <summary>True when the rank is unknown.</summary>
          public bool IsNull => _dims == null;

          /// <summary>
          /// True when the rank is known and no dimension is unknown (negative). A dimension
          /// of 0 is a known size, so shapes such as (0, 3) count as fully defined.
          /// </summary>
          public bool IsFullyDefined => ndim > -1 && dims.All(x => x >= 0);

          /// <summary>A new rank-0 shape.</summary>
          public static Shape Scalar => new Shape(new long[0]);

          /// <summary>A new shape of unknown rank.</summary>
          public static Shape Null => new Shape();

          /// <summary>
          /// Gets or sets a dimension. Negative indices count from the end
          /// (-1 is the last dimension) for both get and set.
          /// </summary>
          public long this[int n]
          {
              get => dims[n < 0 ? ndim + n : n];
              set
              {
                  dims[n < 0 ? ndim + n : n] = value;
                  // Strides derive from the dimensions; drop the cached copy.
                  _strides = null;
              }
          }

          /// <summary>
          /// Returns the sub-shape selected by <paramref name="slice"/>. A slice without a
          /// stop runs to the last dimension. The caller's <see cref="Slice"/> is not modified.
          /// </summary>
          /// <exception cref="ArgumentException">Start is missing, or Stop is set but Length is not.</exception>
          public Shape this[Slice slice]
          {
              get
              {
                  if (!slice.Start.HasValue)
                      throw new ArgumentException("Slice must has Start and Length.");

                  int start = slice.Start.Value;
                  int count;
                  if (slice.Stop.HasValue)
                  {
                      if (!slice.Length.HasValue)
                          throw new ArgumentException("Slice must has Start and Length.");
                      count = slice.Length.Value;
                  }
                  else
                  {
                      // Open-ended slice: everything from start to the end. The previous
                      // formula (Length - Start + 1) dropped dimensions once Start >= 2.
                      count = dims.Length - start;
                  }

                  return new Shape(dims.Skip(start)
                      .Take(count)
                      .ToArray());
              }
          }

          /// <summary>
          /// Returns the size this shape represents.
          /// </summary>
          public long size => ShapeHelper.GetSize(this);

          /// <summary>
          /// Returns false only when both ranks are known, neither shape contains an unknown
          /// (-1) dimension, and their total sizes differ; otherwise returns true.
          /// NOTE(review): this compares total element counts, not rank and per-dimension
          /// values, so e.g. (2, 3) and (3, 2) are reported compatible.
          /// </summary>
          public bool is_compatible_with(Shape shape2)
          {
              if (dims != null && shape2.dims != null)
              {
                  if (dims.Contains(-1) || shape2.dims.Contains(-1))
                      return true;
                  if (size != shape2.size)
                      return false;
              }
              return true;
          }

          /// <summary>Returns this shape if its rank is at least <paramref name="rank"/>.</summary>
          /// <exception cref="ValueError">The rank is smaller than <paramref name="rank"/> (including unknown rank).</exception>
          public Shape with_rank_at_least(int rank)
          {
              if (ndim < rank)
                  throw new ValueError($"Shape {this} must have rank at least {rank}");
              else
                  return this;
          }

          /// <summary>
          /// Returns this shape merged with an unknown shape of rank <paramref name="rank"/>.
          /// </summary>
          /// <exception cref="ValueError">This shape has a known rank different from <paramref name="rank"/>.</exception>
          public Shape with_rank(int rank)
          {
              return merge_with(unknown_shape(rank: rank));
          }

          /// <summary>
          /// Returns an unknown Shape, optionally with a known rank.
          /// </summary>
          /// <param name="rank">Number of dimensions, or -1 for unknown rank.</param>
          /// <returns><see cref="Null"/> for rank -1; otherwise a shape of <paramref name="rank"/> dimensions, all -1.</returns>
          public Shape unknown_shape(int rank = -1)
          {
              if (rank == -1)
                  return Shape.Null;
              else
                  return new Shape(Enumerable.Repeat(-1L, rank).ToArray());
          }

          /// <summary>Concatenates these dimensions with <paramref name="other"/>.</summary>
          public Shape concatenate(long[] other)
          {
              return concatenate(new Shape(other));
          }

          /// <summary>
          /// Returns the concatenation of the dimension in `self` and `other`.
          /// </summary>
          /// <param name="other">Shape whose dimensions are appended.</param>
          /// <returns><see cref="Null"/> if either rank is unknown; otherwise a new shape.</returns>
          public Shape concatenate(Shape other)
          {
              if (ndim < 0 || other.ndim < 0)
                  return Shape.Null;

              return new Shape(dims.Concat(other.dims).ToArray());
          }

          /// <summary>
          /// Returns a `Shape` combining the information in `self` and `other`.
          /// Each dimension pair is combined with <see cref="Dimension.merge_with"/>.
          /// </summary>
          /// <param name="other">Shape to merge with; null or unknown rank contributes nothing.</param>
          /// <returns>
          /// <paramref name="other"/> when this rank is unknown; this shape when the other
          /// rank is unknown; otherwise a new shape of merged dimensions.
          /// </returns>
          /// <exception cref="ValueError">Both ranks are known and differ.</exception>
          public Shape merge_with(Shape other)
          {
              if (dims == null)
                  return other;

              // `is null` avoids the overloaded == operator (ShapeHelper.Equals).
              if (other is null || other.dims == null)
                  return this;

              if (ndim != other.ndim)
                  throw new ValueError($"Shapes {this} and {other} must have the same rank");

              var new_dims = new long[ndim];
              for (int i = 0; i < ndim; i++)
              {
                  var merged = new Dimension(dims[i]).merge_with(new Dimension(other.dims[i]));
                  new_dims[i] = merged.value;
              }
              return new Shape(new_dims);
          }

          /// <summary>Returns the dimensions narrowed to int (unchecked cast).</summary>
          public int[] as_int_list()
          {
              return _dims.Select(x => (int)x).ToArray();
          }

          /// <summary>Throws unless this shape's rank equals <paramref name="rank"/>.</summary>
          /// <exception cref="ValueError">The rank differs (an unknown rank counts as -1).</exception>
          public void assert_has_rank(int rank)
          {
              if (rank != ndim)
                  throw new ValueError($"Shape {this} must have rank {rank}");
          }

          // NOTE(review): Equals and ==/!= are overridden without GetHashCode, so using Shape
          // as a dictionary/set key relies on reference hashing. Add a GetHashCode consistent
          // with ShapeHelper.Equals once its exact equality rules are confirmed.
          public override bool Equals(object obj) => ShapeHelper.Equals(this, obj);

          public override string ToString() => ShapeHelper.ToString(this);

          public static bool operator ==(Shape a, Shape b)
              => ShapeHelper.Equals(a, b);

          public static bool operator !=(Shape a, Shape b)
              => !ShapeHelper.Equals(a, b);
      }
  }