
change tensor shape to Shape.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
commit 8784c31cb3
33 changed files with 118 additions and 72 deletions
  1. src/TensorFlowNET.Core/Binding.Util.cs (+1 -1)
  2. src/TensorFlowNET.Core/Framework/tensor_shape.cs (+2 -2)
  3. src/TensorFlowNET.Core/Gradients/image_grad.cs (+2 -2)
  4. src/TensorFlowNET.Core/Gradients/math_grad.cs (+2 -2)
  5. src/TensorFlowNET.Core/NumPy/Axis.cs (+7 -0)
  6. src/TensorFlowNET.Core/Numpy/NDArray.cs (+3 -2)
  7. src/TensorFlowNET.Core/Numpy/Shape.cs (+27 -0)
  8. src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs (+1 -1)
  9. src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs (+2 -2)
  10. src/TensorFlowNET.Core/Operations/array_ops.cs (+9 -9)
  11. src/TensorFlowNET.Core/Operations/functional_ops.cs (+1 -1)
  12. src/TensorFlowNET.Core/Operations/image_ops_impl.cs (+11 -11)
  13. src/TensorFlowNET.Core/Operations/nn_ops.cs (+1 -1)
  14. src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs (+1 -1)
  15. src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs (+8 -0)
  16. src/TensorFlowNET.Core/Tensors/Tensor.Value.cs (+2 -2)
  17. src/TensorFlowNET.Core/Tensors/Tensor.cs (+9 -9)
  18. src/TensorFlowNET.Core/Tensors/c_api.tensor.cs (+2 -1)
  19. src/TensorFlowNET.Core/Tensors/constant_op.cs (+3 -1)
  20. src/TensorFlowNET.Core/Tensors/tensor_util.cs (+1 -1)
  21. src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs (+1 -1)
  22. src/TensorFlowNET.Core/Variables/RefVariable.cs (+1 -1)
  23. src/TensorFlowNET.Keras/BackendImpl.cs (+3 -3)
  24. src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs (+1 -1)
  25. src/TensorFlowNET.Keras/Engine/MetricsContainer.cs (+2 -2)
  26. src/TensorFlowNET.Keras/Metrics/MetricsApi.cs (+1 -1)
  27. src/TensorFlowNET.Keras/tf.layers.cs (+3 -3)
  28. test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs (+2 -2)
  29. test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs (+1 -1)
  30. test/TensorFlowNET.UnitTest/Basics/VariableTest.cs (+1 -1)
  31. test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs (+1 -1)
  32. test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs (+1 -1)
  33. test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs (+5 -5)

src/TensorFlowNET.Core/Binding.Util.cs (+1 -1)

@@ -155,7 +155,7 @@ namespace Tensorflow
    switch (a)
    {
        case Tensor tensor:
-           return tensor.shape[0];
+           return (int)tensor.shape[0];
        case Tensors arr:
            return arr.Length;
        case Array arr:


src/TensorFlowNET.Core/Framework/tensor_shape.cs (+2 -2)

@@ -10,7 +10,7 @@ namespace Tensorflow.Framework
    {
        public static void assert_is_compatible_with(this Tensor self, Tensor other)
        {
-           if (!self.is_compatible_with(other))
+           /*if (!self.is_compatible_with(other))
            {
                var selfDim = self.shape
                    .Aggregate(new StringBuilder("{"), (sb, i) => sb.Append(i).Append(", "), sb => sb.ToString())
@@ -21,7 +21,7 @@ namespace Tensorflow.Framework
                    .Replace(", }", "}");

                throw new ArgumentException($"Dimensions {selfDim} and {otherDim} are not compatible");
-           }
+           }*/
        }

        public static bool is_compatible_with(this Tensor self, Tensor other)


src/TensorFlowNET.Core/Gradients/image_grad.cs (+2 -2)

@@ -27,10 +27,10 @@ namespace Tensorflow.Gradients
    {
        var grad = grads[0];
        var image = op.inputs[0];
-       var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
+       var shape = new TensorShape(image.shape.dims.Skip(1).Take(2).ToArray());
        Tensor image_shape = null;
        if (shape.is_fully_defined())
-           image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray());
+           image_shape = constant_op.constant(image.shape.dims.Skip(1).Take(2).ToArray());
        else
            image_shape = array_ops.shape(image)["1:3"];




src/TensorFlowNET.Core/Gradients/math_grad.cs (+2 -2)

@@ -195,7 +195,7 @@ namespace Tensorflow.Gradients

    if (op is EagerOperation op_eager &&
        op_eager.SkipInputIndices.Contains(1) &&
-       y.NDims == 0)
+       y.ndim == 0)
    {
        return new Tensor[]
        {
@@ -759,7 +759,7 @@ namespace Tensorflow.Gradients

    if (op is EagerOperation op_eager &&
        op_eager.SkipInputIndices.Contains(1) &&
-       y.NDims == 0)
+       y.ndim == 0)
    {
        x = math_ops.conj(x);
        y = math_ops.conj(y);


src/TensorFlowNET.Core/NumPy/Axis.cs (+7 -0)

@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;

 namespace Tensorflow
@@ -22,6 +23,12 @@ namespace Tensorflow

        public static implicit operator Axis(int[] axis)
            => new Axis(axis);
+
+       public static implicit operator Axis(long[] shape)
+           => new Axis(shape.Select(x => (int)x).ToArray());
+
+       public static implicit operator Axis(Shape shape)
+           => new Axis(shape.dims.Select(x => (int)x).ToArray());
    }
}
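The new conversions let a long[] or a Shape stand in for an Axis (dimensions narrowed to int), which is what keeps call sites such as math_ops.reduce_mean(cs, axes.dims) in image_ops_impl.cs compiling now that Tensor.dims is long[]. A minimal sketch of the conversions, assuming the Axis and Shape types shown in this commit (the Shape constructor arguments here are illustrative):

    long[] dims = { 0, 1 };
    Axis fromLongs = dims;               // long[] -> Axis via the new implicit operator
    Axis fromShape = new Shape(2, 3);    // Shape -> Axis, dims narrowed to int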




src/TensorFlowNET.Core/Numpy/NDArray.cs (+3 -2)

@@ -11,8 +11,9 @@ namespace Tensorflow.NumPy
    Tensor _tensor;
    public TF_DataType dtype => _tensor.dtype;
    public ulong size => _tensor.size;
-   public ulong dtypesize => _tensor.itemsize;
-   public int ndim => _tensor.NDims;
+   public ulong dtypesize => _tensor.dtypesize;
+   public ulong bytesize => _tensor.bytesize;
+   public int ndim => _tensor.ndim;
    public long[] dims => _tensor.dims.Select(x => Convert.ToInt64(x)).ToArray();
    public Shape shape => _tensor.shape;
    public IntPtr data => _tensor.TensorDataPointer;


src/TensorFlowNET.Core/Numpy/Shape.cs (+27 -0)

@@ -48,6 +48,12 @@ namespace Tensorflow
        public static implicit operator Shape((long, long, long, long) dims)
            => new Shape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

+       public static implicit operator int[](Shape shape)
+           => shape.dims.Select(x => (int)x).ToArray();
+
+       public static implicit operator long[](Shape shape)
+           => shape.dims;
+
        public bool IsEmpty => size == 0;

        public bool IsScalar => ndim == 0;
@@ -55,6 +61,8 @@ namespace Tensorflow
        public static Shape Scalar
            => new Shape(new long[0]);

+       public long this[int n] => dims[n];
+
        /// <summary>
        /// Returns the size this shape represents.
        /// </summary>
@@ -81,6 +89,25 @@ namespace Tensorflow
            }
        }

+       public bool is_fully_defined()
+       {
+           return ndim > -1 && dims != null && dims.Count(x => x < 1) == 0;
+       }
+
+       public bool is_compatible_with(TensorShape shape2)
+       {
+           if (dims != null && shape2.dims != null)
+           {
+               if (dims.Contains(-1) || shape2.dims.Contains(-1))
+                   return true;
+
+               if (size != (ulong)shape2.size)
+                   return false;
+           }
+
+           return true;
+       }
+
        public override bool Equals(object obj)
        {
            if(obj is Shape shape)
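A small usage sketch, assumed rather than taken from the commit, of the members added to Shape above: the implicit conversions to int[]/long[], the long indexer, and is_fully_defined():

    var shape = new Shape(2, 3, 1);
    long[] asLongs = shape;                   // implicit operator long[]
    int[] asInts = shape;                     // implicit operator int[] (narrowing)
    long last = shape[shape.ndim - 1];        // new long indexer
    bool defined = shape.is_fully_defined();  // true: no dimension is unknown (< 1)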


src/TensorFlowNET.Core/Operations/Distributions/normal.py.cs (+1 -1)

@@ -92,7 +92,7 @@ namespace Tensorflow

    public Tensor _batch_shape()
    {
-       return array_ops.broadcast_static_shape(new Tensor(_loc.shape), new Tensor(_scale.shape));
+       return array_ops.broadcast_static_shape(new Tensor(_loc.shape.dims), new Tensor(_scale.shape.dims));
    }

    protected override Tensor _log_prob(Tensor x)


src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs (+2 -2)

@@ -27,9 +27,9 @@ namespace Tensorflow.Operations
    {
        var p = prefix;
        var p_static = tensor_util.constant_value(prefix);
-       if (p.NDims == 0)
+       if (p.ndim == 0)
            p = array_ops.expand_dims(p, 0);
-       else if (p.NDims != 1)
+       else if (p.ndim != 1)
            throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}");

        var s_tensor_shape = new TensorShape(suffix);


src/TensorFlowNET.Core/Operations/array_ops.cs (+9 -9)

@@ -186,7 +186,7 @@ namespace Tensorflow

    private static Tensor _constant_if_small(int value, Tensor shape)
    {
-       return shape < 1000L;
+       return shape < 1000UL;
    }

    private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name)
@@ -330,7 +330,7 @@ namespace Tensorflow
    {
        name = scope;
        var input_tensor = ops.convert_to_tensor(inputs);
-       return constant_op.constant(input_tensor.NDims, dtype: tf.int32, name: name);
+       return constant_op.constant(input_tensor.ndim, dtype: tf.int32, name: name);
    });
}

@@ -340,7 +340,7 @@ namespace Tensorflow
    {
        name = scope;
        var input_tensor = ops.convert_to_tensor(input);
-       var input_shape = tensor_util.to_shape(input_tensor.shape);
+       var input_shape = input_tensor.shape;
        if (optimize && input_shape.ndim > 0)
            return constant_op.constant(input_shape.ndim, dtype: tf.int32, name: name);
        else
@@ -364,7 +364,7 @@ namespace Tensorflow
    tensor = ops.convert_to_tensor(tensor, name: "tensor");

    // is_fully_defined return unexpected value.
-   if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
+   if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
    {

    }
@@ -589,9 +589,9 @@ namespace Tensorflow
    if (!tf.Context.executing_eagerly())
    {
        var input_shape = input.TensorShape;
-       if (optimize && input.NDims > -1 && input_shape.is_fully_defined())
+       if (optimize && input.ndim > -1 && input_shape.is_fully_defined())
        {
-           var nd = np.array(input.shape).astype(out_type.as_system_dtype());
+           var nd = np.array(input.shape.dims).astype(out_type.as_system_dtype());
            return constant_op.constant(nd, name: name);
        }
    }
@@ -607,7 +607,7 @@ namespace Tensorflow
    name = scope;

    var input_tensor = ops.convert_to_tensor(input);
-   var input_shape = tensor_util.to_shape(input_tensor.shape);
+   var input_shape = input_tensor.shape;
    if (optimize)
    {
        if (input_shape.is_fully_defined())
@@ -633,7 +633,7 @@ namespace Tensorflow
    tensor = ops.convert_to_tensor(tensor, name: "tensor");

    // is_fully_defined return unexpected value.
-   if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
+   if (optimize && tensor.shape.is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
    {

    }
@@ -933,7 +933,7 @@ namespace Tensorflow
    string name = "split")
{
    if (num == -1)
-       num = size_splits.shape[0];
+       num = (int)size_splits.shape[0];

    return gen_array_ops.split_v(value, size_splits, axis, num, name: name);
}


src/TensorFlowNET.Core/Operations/functional_ops.cs (+1 -1)

@@ -91,7 +91,7 @@ namespace Tensorflow
    elem.dtype,
    size: tf.constant(n),
    dynamic_size: false,
-   element_shape: elem.shape.Skip(1).ToArray(),
+   element_shape: elem.shape.dims.Skip(1).ToArray(),
    infer_shape: true)).ToList();

    for (int index = 0; index < elems_ta.Count; index++)


src/TensorFlowNET.Core/Operations/image_ops_impl.cs (+11 -11)

@@ -341,14 +341,14 @@ or rank = 4. Had rank = {0}", rank));
    {
        h = _get_dim(image, 0); // img_h == h[0], dynamic_h == h[1]
        w = _get_dim(image, 1);
-       d = image.shape[3];
+       d = (int)image.shape[3];
    }
    else
    {
-       bs = image.shape[0];
+       bs = (int)image.shape[0];
        h = _get_dim(image, 1);
        w = _get_dim(image, 2);
-       d = image.shape[3];
+       d = (int)image.shape[3];
    }

    object hd, bbox_h_start;
@@ -1115,7 +1115,7 @@ new_height, new_width");
    array_ops.expand_dims(tf.constant(3), 0));
    var multiples = array_ops.concat(new Tensor[] { shape_list }, 0);
    var rgb = array_ops.tile(images, multiples, name: name);
-   int[] rgb_temp = images.shape.Take(images.shape.Length - 1).ToArray();
+   int[] rgb_temp = images.shape.dims.Take(images.shape.ndim - 1).Select(x => (int)x).ToArray();
    rgb.set_shape(array_ops.concat(new Tensor[] { ops.convert_to_tensor(rgb_temp) }, 3));
    return rgb;
});
@@ -1459,7 +1459,7 @@ new_height, new_width");

    // shape takes an int, python code passes size, a Tensor. NDims is the only int type
    // i could think of a Tensor having. it might be incorrect tho, so keep that in mind.
-   return array_ops.reshape(g, shape: new int[] { size.NDims, size.NDims, 1, 1 });
+   return array_ops.reshape(g, shape: new int[] { size.ndim, size.ndim, 1, 1 });
}

internal static (Tensor, Tensor) _ssim_per_channel(Tensor img1, Tensor img2, float max_val = 1f,
@@ -1487,7 +1487,7 @@ new_height, new_width");
    img1 = array_ops.identity(img1);

    var kernel = _fspecial_gauss(filter_size_tensor, filter_sigma_tensor);
-   kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 }));
+   kernel = array_ops.tile(kernel, multiples: new Tensor(new int[] { 1, 1, (int)shape1_tensor.dims[shape1_tensor.dims.Length - 2], 1 }));

    float compensation = 1.0f;

@@ -1503,8 +1503,8 @@ new_height, new_width");
    (Tensor luminance, Tensor cs) = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2);

    var axes = constant_op.constant(new[] { -3, -2 }, dtype: dtypes.int32);
-   var ssim_val = math_ops.reduce_mean(luminance * cs, new(axes.dims));
-   cs = math_ops.reduce_mean(cs, new(axes.dims));
+   var ssim_val = math_ops.reduce_mean(luminance * cs, axes.dims);
+   cs = math_ops.reduce_mean(cs, axes.dims);
    return (ssim_val, cs);
}

@@ -1685,7 +1685,7 @@ new_height, new_width");
    var kernels_tf = constant_op.constant(kernels, dtype: image.dtype);

    kernels_tf = array_ops.tile(
-       kernels_tf, new Tensor(new int[] { 1, 1, image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters");
+       kernels_tf, new Tensor(new int[] { 1, 1, (int)image_shape.dims[image_shape.dims.Length - 2], 1 }), name: "sobel_filters");

    var pad_sizes = new int[,] { { 0, 0 }, { 1, 1 }, { 1, 1 }, { 0, 0 } };
    var padded = array_ops.pad(image, new Tensor(pad_sizes), mode: "reflect");
@@ -1966,8 +1966,8 @@ new_height, new_width");
    Tensor index_offsets, indices, sorted_scores, sorted_boxes, sorted_scores_indices;
    using (ops.name_scope("sort_scores_and_boxes"))
    {
-       batch_size = array_ops.shape(boxes).dims[0];
-       num_boxes = array_ops.shape(boxes).dims[1];
+       batch_size = (int)array_ops.shape(boxes).dims[0];
+       num_boxes = (int)array_ops.shape(boxes).dims[1];
        sorted_scores_indices = null; /*sort_ops.argsort(
            scores, axis: 1, direction: "DESCENDING); */
        index_offsets = math_ops.range(batch_size) * num_boxes;


src/TensorFlowNET.Core/Operations/nn_ops.cs (+1 -1)

@@ -178,7 +178,7 @@ namespace Tensorflow
    logits = ops.convert_to_tensor(logits);

    var shape = logits.shape;
-   bool is_last_dim = dim == -1 || dim == shape.Length - 1;
+   bool is_last_dim = dim == -1 || dim == shape.ndim - 1;
    if (is_last_dim)
        return compute_op(logits, name);




src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs (+1 -1)

@@ -37,7 +37,7 @@ namespace Tensorflow
    {
        get
        {
-           return _row_splits.shape[0] - 1;
+           return (int)_row_splits.shape[0] - 1;
        }
    }




src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs (+8 -0)

@@ -145,6 +145,10 @@ namespace Tensorflow
    byte[,] val => InitTensor(val, shape, dtype),
    byte[,,] val => InitTensor(val, shape, dtype),
    byte[,,,] val => InitTensor(val, shape, dtype),
+   short[] val => InitTensor(val, shape, dtype),
+   short[,] val => InitTensor(val, shape, dtype),
+   short[,,] val => InitTensor(val, shape, dtype),
+   short[,,,] val => InitTensor(val, shape, dtype),
    int[] val => InitTensor(val, shape, dtype),
    int[,] val => InitTensor(val, shape, dtype),
    int[,,] val => InitTensor(val, shape, dtype),
@@ -153,6 +157,10 @@ namespace Tensorflow
    long[,] val => InitTensor(val, shape, dtype),
    long[,,] val => InitTensor(val, shape, dtype),
    long[,,,] val => InitTensor(val, shape, dtype),
+   ulong[] val => InitTensor(val, shape, dtype),
+   ulong[,] val => InitTensor(val, shape, dtype),
+   ulong[,,] val => InitTensor(val, shape, dtype),
+   ulong[,,,] val => InitTensor(val, shape, dtype),
    float[] val => InitTensor(val, shape, dtype),
    float[,] val => InitTensor(val, shape, dtype),
    float[,,] val => InitTensor(val, shape, dtype),


src/TensorFlowNET.Core/Tensors/Tensor.Value.cs (+2 -2)

@@ -18,7 +18,7 @@ namespace Tensorflow
    if (typeof(T).as_tf_dtype() != dtype)
        throw new ArrayTypeMismatchException($"dtype {dtype} mismatch.");

-   if (NDims == 0 && size == 1) //is it a scalar?
+   if (ndim == 0 && size == 1) //is it a scalar?
    {
        unsafe
        {
@@ -28,7 +28,7 @@ namespace Tensorflow

    //types match, no need to perform cast
    var ret = new T[size];
-   var len = (long)(size * itemsize);
+   var len = (long)(size * dtypesize);
    var src = (T*)buffer;

    fixed (T* dst = ret)


src/TensorFlowNET.Core/Tensors/Tensor.cs (+9 -9)

@@ -72,17 +72,17 @@ namespace Tensorflow
    /// </summary>
    public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle);
    public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
-   public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
-   public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
+   public ulong dtypesize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
+   public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dtypesize;
    public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
    public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
-   public int NDims => rank;
+   public int ndim => rank;

    /// <summary>
    /// The name of the device on which this tensor will be produced, or null.
    /// </summary>
    public virtual string Device => op.Device;
-   public int[] dims => shape;
+   public long[] dims => shape.dims;

    /// <summary>
    /// Used for keep other pointer when do implicit operating
@@ -107,7 +107,7 @@ namespace Tensorflow
    /// Returns the shape of a tensor.
    /// </summary>
    /// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks>
-   public int[] shape
+   public Shape shape
    {
        get
        {
@@ -123,7 +123,7 @@ namespace Tensorflow
            dims[i] = c_api.TF_Dim(_handle, i);
        }

-       return dims.Select(x => ((IConvertible)x).ToInt32(CultureInfo.InvariantCulture)).ToArray();
+       return dims;
    }

    set
@@ -131,7 +131,7 @@ namespace Tensorflow
        if (value == null)
            c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, tf.Status.Handle);
        else
-           c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, tf.Status.Handle);
+           c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.dims, value.ndim, tf.Status.Handle);

        tf.Status.Check(true);
    }
@@ -139,10 +139,10 @@ namespace Tensorflow

    public int[] _shape_tuple()
    {
-       return rank < 0 ? null : shape;
+       return rank < 0 ? null : shape.dims.Select(x => (int)x).ToArray();
    }

-   public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape);
+   public TensorShape TensorShape => rank < 0 ? new TensorShape() : shape;

    /// <summary>
    /// Keras History: (Layer, (node_index, tensor_index))


src/TensorFlowNET.Core/Tensors/c_api.tensor.cs (+2 -1)

@@ -109,7 +109,8 @@ namespace Tensorflow
    var length = shape.size * (ulong)dtype.get_datatype_size();
    var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, length);
    var tensor = TF_TensorData(handle);
-   System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length);
+   if (tensor != IntPtr.Zero)
+       System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length);
    return handle;
}




src/TensorFlowNET.Core/Tensors/constant_op.cs (+3 -1)

@@ -124,6 +124,8 @@ namespace Tensorflow
        return new EagerTensor(new[] { val }, Shape.Scalar);
    case long val:
        return new EagerTensor(new[] { val }, Shape.Scalar);
+   case ulong val:
+       return new EagerTensor(new[] { val }, Shape.Scalar);
    case float val:
        return new EagerTensor(new[] { val }, Shape.Scalar);
    case double val:
@@ -146,7 +148,7 @@ namespace Tensorflow
    if (shape == null)
        return t;

-   if (t.shape.Select(x => Convert.ToInt64(x)).SequenceEqual(shape.dims))
+   if (t.shape.dims.SequenceEqual(shape.dims))
        return t;

    if (verify_shape)


src/TensorFlowNET.Core/Tensors/tensor_util.cs (+1 -1)

@@ -127,7 +127,7 @@ namespace Tensorflow
    }
    else if (values is Tensor tensor && tensor.IsReferencedByNDArray)
    {
-       var len = tensor.itemsize * tensor.size;
+       var len = tensor.dtypesize * tensor.size;
        byte[] bytes = tensor.BufferToArray();
        tensor_proto.TensorContent = Google.Protobuf.ByteString.CopyFrom(bytes);
    }


src/TensorFlowNET.Core/Training/Saving/SaveableObject.cs (+1 -1)

@@ -45,7 +45,7 @@ namespace Tensorflow
            var restored_tensor = restored_tensors[0];
            return gen_state_ops.assign(op,
                restored_tensor,
-               validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined());
+               validate_shape: restored_shapes == null && op.shape.is_fully_defined());
        }
    }
}

src/TensorFlowNET.Core/Variables/RefVariable.cs (+1 -1)

@@ -50,7 +50,7 @@ namespace Tensorflow
    public Operation Op => _variable.op;

    public TF_DataType dtype => _variable.dtype;
-   public TensorShape shape => tensor_util.to_shape(_variable.shape);
+   public TensorShape shape => _variable.shape;
    public string Device => "";

    public string Name => _variable.name;


src/TensorFlowNET.Keras/BackendImpl.cs (+3 -3)

@@ -297,8 +297,8 @@ namespace Tensorflow.Keras
    // x = permute_dimensions(x, [0, 3, 1, 2]);
    throw new NotImplementedException("");

-   int new_height = original_shape[rows] < 0 ? -1 : original_shape[rows] * height_factor;
-   int new_width = original_shape[cols] < 0 ? -1 : original_shape[cols] * width_factor;
+   int new_height = original_shape[rows] < 0 ? -1 : (int)original_shape[rows] * height_factor;
+   int new_width = original_shape[cols] < 0 ? -1 : (int)original_shape[cols] * width_factor;

    TensorShape output_shape = data_format == "channels_first" ?
        (-1, -1, new_height, new_width) : (-1, new_height, new_width, -1);
@@ -316,7 +316,7 @@ namespace Tensorflow.Keras
    {
        if(axis < 0)
        {
-           var rank = tensors[0].NDims;
+           var rank = tensors[0].ndim;
            if (rank > -1)
                axis += rank;
            else


src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs (+1 -1)

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
    {
        this.args = args;
        _process_tensorlike();
-       num_samples = args.X.shape[0];
+       num_samples = (int)args.X.shape[0];
        var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
        _batch_size = batch_size;
        _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));


src/TensorFlowNET.Keras/Engine/MetricsContainer.cs (+2 -2)

@@ -63,8 +63,8 @@ namespace Tensorflow.Keras.Engine
    {
        var y_t_rank = y_t.rank;
        var y_p_rank = y_p.rank;
-       var y_t_last_dim = y_t.shape[y_t.shape.Length - 1];
-       var y_p_last_dim = y_p.shape[y_p.shape.Length - 1];
+       var y_t_last_dim = y_t.shape[y_t.shape.ndim - 1];
+       var y_p_last_dim = y_p.shape[y_p.shape.ndim - 1];

        bool is_binary = y_p_last_dim == 1;
        bool is_sparse_categorical = (y_t_rank < y_p_rank || y_t_last_dim == 1) && y_p_last_dim > 1;


src/TensorFlowNET.Keras/Metrics/MetricsApi.cs (+1 -1)

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Metrics
    var y_true_rank = y_true.TensorShape.ndim;
    // If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
    if (y_true_rank != -1 && y_pred_rank != -1
-       && y_true.shape.Length == y_pred.shape.Length)
+       && y_true.shape.ndim == y_pred.shape.ndim)
        y_true = array_ops.squeeze(y_true, axis: new[] { -1 });
    y_pred = math_ops.argmax(y_pred, -1);




src/TensorFlowNET.Keras/tf.layers.cs (+3 -3)

@@ -212,13 +212,13 @@ namespace Tensorflow.Keras
    string data_format = "channels_last")
{
    var input_shape = inputs.shape;
-   if (inputs.shape.Length == 0)
+   if (inputs.shape.ndim == 0)
        throw new ValueError($"Input 0 of layer flatten is incompatible with the layer: : expected min_ndim={1}, found ndim={0}. Full shape received: ()");

    var premutation = new List<int>() { 0 };
-   if (data_format == "channels_first" && inputs.NDims > 1)
+   if (data_format == "channels_first" && inputs.ndim > 1)
    {
-       premutation.AddRange(Binding.range(2, inputs.NDims));
+       premutation.AddRange(Binding.range(2, inputs.ndim));
        premutation.Add(1);
        inputs = array_ops.transpose(inputs, premutation.ToArray());
    }


test/TensorFlowNET.Native.UnitTest/Sessions/SessionTest.cs (+2 -2)

@@ -40,7 +40,7 @@ namespace Tensorflow.Native.UnitTest.Sessions
    csession.Run(s);
    Tensor outTensor = csession.output_tensor(0);
    EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
-   EXPECT_EQ(0, outTensor.NDims);
+   EXPECT_EQ(0, outTensor.ndim);
    ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
    var output_contents = outTensor.ToArray<int>();
    EXPECT_EQ(3 + 2, output_contents[0]);
@@ -61,7 +61,7 @@ namespace Tensorflow.Native.UnitTest.Sessions
    outTensor = csession.output_tensor(0);
    ASSERT_TRUE(outTensor != IntPtr.Zero);
    EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
-   EXPECT_EQ(0, outTensor.NDims); // scalar
+   EXPECT_EQ(0, outTensor.ndim); // scalar
    ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
    output_contents = outTensor.ToArray<int>();
    EXPECT_EQ(-(7 + 2), output_contents[0]);


test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs (+1 -1)

@@ -66,7 +66,7 @@ namespace Tensorflow.Native.UnitTest.Tensors
    long[] dims = { 2, 3 };
    Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
    EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
-   EXPECT_EQ(2, t.NDims);
+   EXPECT_EQ(2, t.ndim);
    EXPECT_EQ((int)dims[0], t.shape[0]);
    EXPECT_EQ(num_bytes, t.bytesize);
    t.Dispose();


test/TensorFlowNET.UnitTest/Basics/VariableTest.cs (+1 -1)

@@ -126,7 +126,7 @@ namespace TensorFlowNET.UnitTest.Basics
    {
        var x = tf.constant(new[,] { { 1, 2 } });
        var neg_x = tf.negative(x);
-       Assert.IsTrue(Enumerable.SequenceEqual(new[] { 1, 2 }, neg_x.shape));
+       Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 2 }, neg_x.shape.dims));
        Assert.IsTrue(Enumerable.SequenceEqual(new[] { -1, -2 }, neg_x.numpy().ToArray<int>()));
    }




test/TensorFlowNET.UnitTest/ManagedAPI/ConstantTest.cs (+1 -1)

@@ -145,7 +145,7 @@ namespace TensorFlowNET.UnitTest.Basics
    var tensor = tf.constant(nd);
    var data = tensor.numpy().ToArray<int>();

-   Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape));
+   Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3 }, tensor.shape.dims));
    Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
}




test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs (+1 -1)

@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
    var b = tf.Variable(-0.73f, name: "bias");
    using var g = tf.GradientTape();
    var pred = W * X + b;
-   var test = tf.slice(pred, new[] { 0 }, pred.shape);
+   var test = tf.slice(pred, new[] { 0 }, (int[])pred.shape);
    var gradients = g.gradient(test, (W, b));
    Assert.AreEqual((float)gradients.Item1, 0f);
    Assert.AreEqual((float)gradients.Item2, 10f);


test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs (+5 -5)

@@ -85,14 +85,14 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
        { { 1 }, { 2 }, { 3 } },
        { { 4 }, { 5 }, { 6 } }
    }));
-   Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape));
+   Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, a.shape.dims));

    var b = tf.constant(new[, ,]
    {
        { { 1 }, { 2 }, { 3 } },
        { { 4 }, { 5 }, { 6 } }
    });
-   Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape));
+   Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, b.shape.dims));
}

[TestMethod]
@@ -103,7 +103,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
    var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } });

    var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
-   Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
+   Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims));
}

[TestMethod]
@@ -114,7 +114,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
    var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } });

    var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
-   Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
+   Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims));
}

[TestMethod]
@@ -128,7 +128,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI

    var splitValue = tf.split(value, 3, axis: 0);
    Assert.AreEqual(3, splitValue.Length);
-   Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape));
+   Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 2 }, splitValue[0].shape.dims));
}

#region ones/zeros like

