Browse Source

TensorShape: Fixed construction when passing int[] or long[]

tags/v0.12
Eli Belash 6 years ago
parent
commit
cc54fe19bf
1 changed files with 41 additions and 7 deletions
  1. +41
    -7
      src/TensorFlowNET.Core/Tensors/TensorShape.cs

+ 41
- 7
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -57,10 +57,36 @@ namespace Tensorflow


public TensorShape(params object[] dims) public TensorShape(params object[] dims)
{ {
var intdims = new int[dims.Length];
for (int i = 0; i < dims.Length; i++)
Array arr;

if (dims.Length == 1)
{
switch (dims[0])
{
case int[] intarr:
arr = intarr;
break;
case long[] longarr:
arr = longarr;
break;
case object[] objarr:
arr = objarr;
break;
case int _:
case long _:
arr = dims;
break;
default:
Binding.print(dims);
throw new ArgumentException(nameof(dims));
}
} else
arr = dims;

var intdims = new int[arr.Length];
for (int i = 0; i < arr.Length; i++)
{ {
var val = dims[i];
var val = arr.GetValue(i);
if (val == Binding.None) if (val == Binding.None)
intdims[i] = -1; intdims[i] = -1;
else else
@@ -69,10 +95,18 @@ namespace Tensorflow


switch (dims.Length) switch (dims.Length)
{ {
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int) intdims[0]); break;
case 2: shape = Shape.Matrix(intdims[0], intdims[1]); break;
default: shape = new Shape(intdims); break;
case 0:
shape = new Shape(new int[0]);
break;
case 1:
shape = Shape.Vector((int) intdims[0]);
break;
case 2:
shape = Shape.Matrix(intdims[0], intdims[1]);
break;
default:
shape = new Shape(intdims);
break;
} }
} }




Loading…
Cancel
Save