using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
{
    /// <summary>
    /// The shape of a tensor: an optional list of dimension sizes.
    /// A <c>null</c> dimension list means the rank itself is unknown; a dimension
    /// value of <c>-1</c> marks an individual dimension whose size is unknown.
    /// </summary>
    public class Shape
    {
        /// <summary>Rank of the shape, or -1 when the rank is unknown.</summary>
        public int ndim => _dims == null ? -1 : _dims.Length;

        long[] _dims;

        /// <summary>
        /// The underlying dimension array (returned directly, not copied),
        /// or <c>null</c> when the rank is unknown.
        /// </summary>
        public long[] dims => _dims;

        private Shape() { }

        /// <summary>
        /// Builds a shape from its protobuf representation.
        /// </summary>
        /// <param name="proto">Serialized tensor shape.</param>
        public Shape(TensorShapeProto proto)
        {
            // An unknown-rank proto carries no Dim entries; without this check it
            // would be misread as a scalar shape instead of an unknown-rank one.
            if (proto.UnknownRank)
                return;

            _dims = proto.Dim.Select(x => x.Size).ToArray();
        }

        /// <summary>Deconstructs the shape into its first two dimensions.</summary>
        /// <param name="h">Receives <c>dims[0]</c>.</param>
        /// <param name="w">Receives <c>dims[1]</c>.</param>
        public void Deconstruct(out long h, out long w)
        {
            h = dims[0];
            w = dims[1];
        }

        /// <summary>
        /// Creates a shape from 32-bit dimension sizes; a <c>null</c> array yields
        /// an unknown-rank shape.
        /// </summary>
        public Shape(params int[] dims) => _dims = dims?.Select(x => Convert.ToInt64(x))?.ToArray();

        /// <summary>
        /// Creates a shape that wraps <paramref name="dims"/> directly (the array is not copied);
        /// a <c>null</c> array yields an unknown-rank shape.
        /// </summary>
        public Shape(params long[] dims) => _dims = dims;

        public static implicit operator Shape(int dims) => new Shape(dims);
        public static implicit operator Shape(long[] dims) => dims == null ? null : new Shape(dims);
        public static implicit operator Shape(int[] dims) => dims == null ? null : new Shape(dims);
        public static implicit operator Shape((int, int) dims) => new Shape(dims.Item1, dims.Item2);
        public static implicit operator Shape((long, long) dims) => new Shape(dims.Item1, dims.Item2);
        public static implicit operator Shape((int, int, int) dims) => new Shape(dims.Item1, dims.Item2, dims.Item3);
        public static implicit operator Shape((long, long, long) dims) => new Shape(dims.Item1, dims.Item2, dims.Item3);
        public static implicit operator Shape((int, int, int, int) dims) => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
        public static implicit operator Shape((long, long, long, long) dims) => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

        // Null-tolerant, mirroring the array -> Shape conversions above.
        public static implicit operator int[](Shape shape) => shape?.dims?.Select(x => (int)x).ToArray();
        public static implicit operator long[](Shape shape) => shape?.dims;

        /// <summary>True when the shape describes zero elements (some dimension is 0).</summary>
        public bool IsEmpty => size == 0;

        /// <summary>True for a rank-0 shape.</summary>
        public bool IsScalar => ndim == 0;

        /// <summary>True when the rank is unknown.</summary>
        public bool IsNull => _dims == null;

        /// <summary>
        /// True when the rank is known and no dimension is unknown (negative).
        /// A zero-sized dimension is a fully known size.
        /// </summary>
        public bool IsFullyDefined => ndim > -1 && dims.All(x => x >= 0);

        /// <summary>A new rank-0 shape.</summary>
        public static Shape Scalar => new Shape(new long[0]);

        /// <summary>A new shape of unknown rank.</summary>
        public static Shape Null => new Shape();

        /// <summary>Gets or sets the size of dimension <paramref name="n"/> in place.</summary>
        public long this[int n]
        {
            get => dims[n];
            set => dims[n] = value;
        }

        /// <summary>
        /// Returns the sub-shape selected by <paramref name="slice"/>. A slice without
        /// <c>Stop</c> selects every dimension from <c>Start</c> to the end.
        /// The slice argument is not modified.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// <c>Start</c> is missing, or a bounded slice has no <c>Length</c>.
        /// </exception>
        public Shape this[Slice slice]
        {
            get
            {
                if (!slice.Start.HasValue)
                    throw new ArgumentException("Slice must has Start and Length.");

                var start = slice.Start.Value;

                // Open-ended slice: take everything from start onward instead of
                // synthesizing a Stop value and writing it back into the caller's slice.
                if (!slice.Stop.HasValue)
                    return new Shape(dims.Skip(start).ToArray());

                if (!slice.Length.HasValue)
                    throw new ArgumentException("Slice must has Start and Length.");

                return new Shape(dims.Skip(start)
                    .Take(slice.Length.Value)
                    .ToArray());
            }
        }

        /// <summary>
        /// Returns the number of elements this shape represents.
        /// The result is 1 for a scalar, 0 if any dimension is 0, and -1 when the rank is unknown.
        /// Unknown (negative) dimensions are left out of the product.
        /// </summary>
        public long size
        {
            get
            {
                // Unknown rank: the element count cannot be determined.
                if (_dims == null)
                    return -1;

                // The empty product is 1, which also covers the scalar case.
                var computed = 1L;
                foreach (var val in _dims)
                {
                    if (val == 0)
                        return 0;
                    if (val < 0)
                        continue;
                    computed *= val;
                }

                return computed;
            }
        }

        /// <summary>
        /// Loose compatibility check. Two shapes are compatible if either has an unknown
        /// rank, or if either contains an unknown (-1) dimension. Otherwise they are
        /// compatible when their element counts match. Ranks and individual dimensions
        /// are not compared.
        /// </summary>
        public bool is_compatible_with(Shape shape2)
        {
            if (dims != null && shape2.dims != null)
            {
                if (dims.Contains(-1) || shape2.dims.Contains(-1))
                    return true;

                if (size != shape2.size)
                    return false;
            }

            return true;
        }

        /// <summary>Returns this shape if its rank is at least <paramref name="rank"/>.</summary>
        /// <exception cref="ValueError">The rank is lower than <paramref name="rank"/> (or unknown).</exception>
        public Shape with_rank_at_least(int rank)
        {
            if (ndim < rank)
                throw new ValueError($"Shape {this} must have rank at least {rank}");

            return this;
        }

        /// <summary>
        /// Returns a shape with the given rank, combining this shape with an
        /// all-unknown shape of that rank.
        /// </summary>
        /// <exception cref="ValueError">This shape has a known, different rank.</exception>
        public Shape with_rank(int rank)
        {
            return merge_with(unknown_shape(rank: rank));
        }

        /// <summary>
        /// Returns an unknown Shape, optionally with a known rank.
        /// </summary>
        /// <param name="rank">Rank of the result; -1 yields an unknown-rank shape.</param>
        /// <returns>A shape whose dimensions are all -1, or an unknown-rank shape.</returns>
        public Shape unknown_shape(int rank = -1)
        {
            if (rank == -1)
                return Shape.Null;

            return new Shape(Enumerable.Repeat(-1L, rank).ToArray());
        }

        /// <summary>Returns the concatenation of this shape's dimensions and <paramref name="other"/>.</summary>
        public Shape concatenate(long[] other)
        {
            return concatenate(new Shape(other));
        }

        /// <summary>
        /// Returns the concatenation of the dimensions in <c>this</c> and <paramref name="other"/>.
        /// </summary>
        /// <returns>An unknown-rank shape if either input has an unknown rank.</returns>
        public Shape concatenate(Shape other)
        {
            if (ndim < 0 || other.ndim < 0)
                return Shape.Null;

            return new Shape(dims.Concat(other.dims).ToArray());
        }

        /// <summary>
        /// Returns a <see cref="Shape"/> combining the information in <c>this</c> and <paramref name="other"/>.
        /// </summary>
        /// <returns>
        /// <paramref name="other"/> if this rank is unknown. Otherwise, <c>this</c> if
        /// <paramref name="other"/> is null or has an unknown rank. Otherwise, the dimension-wise merge.
        /// </returns>
        /// <exception cref="ValueError">Both ranks are known and they differ.</exception>
        public Shape merge_with(Shape other)
        {
            if (dims == null)
                return other;

            // Nothing is known about `other`, so this shape already holds all the information.
            if (other?.dims == null)
                return this;

            // Without this check, mismatched ranks would either index past the end of
            // other.dims or silently truncate to this shape's rank.
            if (ndim != other.ndim)
                throw new ValueError($"Shapes {this} and {other} are not compatible");

            var new_dims = new List<long>(ndim);
            for (int i = 0; i < ndim; i++)
            {
                var merged = new Dimension(dims[i]).merge_with(new Dimension(other.dims[i]));
                new_dims.Add(merged.value);
            }

            return new Shape(new_dims.ToArray());
        }

        /// <summary>Returns the dimensions narrowed to <c>int</c>.</summary>
        /// <exception cref="ValueError">The rank is unknown.</exception>
        public int[] as_int_list()
        {
            if (_dims == null)
                throw new ValueError("as_int_list() is not defined on a shape of unknown rank.");

            return _dims.Select(x => (int)x).ToArray();
        }

        /// <summary>Throws unless this shape has exactly rank <paramref name="rank"/>.</summary>
        /// <exception cref="ValueError">The rank differs (an unknown rank counts as -1).</exception>
        public void assert_has_rank(int rank)
        {
            if (rank != ndim)
                throw new ValueError($"Shape {this} must have rank {rank}");
        }

        /// <summary>
        /// Compares dimension-wise against another <see cref="Shape"/> or a <c>long[]</c>.
        /// Two shapes of unknown rank never compare equal.
        /// </summary>
        public override bool Equals(object obj)
        {
            switch (obj)
            {
                case Shape shape1:
                    if (ndim == -1 && shape1.ndim == -1)
                        return false;
                    if (ndim != shape1.ndim)
                        return false;
                    return Enumerable.SequenceEqual(shape1.dims, dims);
                case long[] shape2:
                    if (ndim != shape2.Length)
                        return false;
                    return Enumerable.SequenceEqual(dims, shape2);
                default:
                    return false;
            }
        }

        /// <summary>
        /// Hash over the dimension values. It is consistent with <see cref="Equals(object)"/>:
        /// shapes with identical dimensions hash identically.
        /// </summary>
        public override int GetHashCode()
        {
            if (_dims == null)
                return 0;

            unchecked
            {
                var hash = 17;
                foreach (var dim in _dims)
                    hash = hash * 31 + dim.GetHashCode();
                return hash;
            }
        }

        /// <summary>
        /// Python-style rendering, e.g. <c>(2, None, 3)</c>; <c>(n,)</c> for rank 1,
        /// <c>()</c> for a scalar and <c>&lt;unknown&gt;</c> for an unknown rank.
        /// </summary>
        public override string ToString() => ndim switch
        {
            -1 => "<unknown>",
            0 => "()",
            1 => $"({FormatDim(dims[0])},)",
            _ => $"({string.Join(", ", _dims.Select(FormatDim))})"
        };

        // Renders one dimension. Only exactly -1 means "unknown"; the whole value is
        // mapped, so no substring replacement can corrupt values such as -10.
        private static string FormatDim(long dim) => dim == -1 ? "None" : dim.ToString();
    }
}