|
|
@@ -71,61 +71,32 @@ namespace Tensorflow |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public TensorShape(params object[] dims) |
|
|
|
public TensorShape(params int[] dims) |
|
|
|
{ |
|
|
|
Array arr; |
|
|
|
switch (dims.Length) |
|
|
|
{ |
|
|
|
case 0: shape = new Shape(new int[0]); break; |
|
|
|
case 1: shape = Shape.Vector((int)dims[0]); break; |
|
|
|
case 2: shape = Shape.Matrix(dims[0], dims[1]); break; |
|
|
|
default: shape = new Shape(dims); break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (dims.Length == 1) |
|
|
|
public TensorShape(int[][] dims) |
|
|
|
{ |
|
|
|
if(dims.Length == 1) |
|
|
|
{ |
|
|
|
switch (dims[0]) |
|
|
|
switch (dims[0].Length) |
|
|
|
{ |
|
|
|
case int[] intarr: |
|
|
|
arr = intarr; |
|
|
|
break; |
|
|
|
case long[] longarr: |
|
|
|
arr = longarr; |
|
|
|
break; |
|
|
|
case object[] objarr: |
|
|
|
arr = objarr; |
|
|
|
break; |
|
|
|
case int _: |
|
|
|
case long _: |
|
|
|
arr = dims; |
|
|
|
break; |
|
|
|
case null: //==Binding.None |
|
|
|
arr = dims; |
|
|
|
break; |
|
|
|
default: |
|
|
|
Binding.print(dims); |
|
|
|
throw new ArgumentException(nameof(dims)); |
|
|
|
case 0: shape = new Shape(new int[0]); break; |
|
|
|
case 1: shape = Shape.Vector((int)dims[0][0]); break; |
|
|
|
case 2: shape = Shape.Matrix(dims[0][0], dims[1][2]); break; |
|
|
|
default: shape = new Shape(dims[0]); break; |
|
|
|
} |
|
|
|
} else |
|
|
|
arr = dims; |
|
|
|
|
|
|
|
var intdims = new int[arr.Length]; |
|
|
|
for (int i = 0; i < arr.Length; i++) |
|
|
|
{ |
|
|
|
var val = arr.GetValue(i); |
|
|
|
if (val == Binding.None) |
|
|
|
intdims[i] = -1; |
|
|
|
else |
|
|
|
intdims[i] = Converts.ToInt32(val); |
|
|
|
} |
|
|
|
|
|
|
|
switch (intdims.Length) |
|
|
|
else |
|
|
|
{ |
|
|
|
case 0: |
|
|
|
shape = new Shape(new int[0]); |
|
|
|
break; |
|
|
|
case 1: |
|
|
|
shape = Shape.Vector((int) intdims[0]); |
|
|
|
break; |
|
|
|
case 2: |
|
|
|
shape = Shape.Matrix(intdims[0], intdims[1]); |
|
|
|
break; |
|
|
|
default: |
|
|
|
shape = new Shape(intdims); |
|
|
|
break; |
|
|
|
throw new NotImplementedException("TensorShape int[][] dims"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -232,8 +203,6 @@ namespace Tensorflow |
|
|
|
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); |
|
|
|
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); |
|
|
|
|
|
|
|
public static implicit operator TensorShape(object[] dims) => new TensorShape(dims); |
|
|
|
|
|
|
|
public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes |
|
|
|
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); |
|
|
|
|
|
|
@@ -260,16 +229,5 @@ namespace Tensorflow |
|
|
|
|
|
|
|
public static explicit operator (int, int, int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 8 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5], shape.dims[6], shape.dims[7]) : (0, 0, 0, 0, 0, 0, 0, 0); |
|
|
|
public static implicit operator TensorShape((int, int, int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); |
|
|
|
|
|
|
|
public static implicit operator TensorShape(int?[] dims) => new TensorShape(dims); |
|
|
|
public static implicit operator TensorShape(int? dim) => new TensorShape(dim); |
|
|
|
public static implicit operator TensorShape((object, object) dims) => new TensorShape(dims.Item1, dims.Item2); |
|
|
|
public static implicit operator TensorShape((object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); |
|
|
|
public static implicit operator TensorShape((object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); |
|
|
|
public static implicit operator TensorShape((object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); |
|
|
|
public static implicit operator TensorShape((object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); |
|
|
|
public static implicit operator TensorShape((object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7); |
|
|
|
public static implicit operator TensorShape((object, object, object, object, object, object, object, object) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6, dims.Item7, dims.Item8); |
|
|
|
|
|
|
|
} |
|
|
|
} |