diff --git a/src/TensorFlowNET.Core/APIs/tf.reshape.cs b/src/TensorFlowNET.Core/APIs/tf.reshape.cs index 5da7b795..102a8132 100644 --- a/src/TensorFlowNET.Core/APIs/tf.reshape.cs +++ b/src/TensorFlowNET.Core/APIs/tf.reshape.cs @@ -31,6 +31,6 @@ namespace Tensorflow public Tensor reshape(Tensor tensor, object[] shape, string name = null) - => gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); + => array_ops.reshape(tensor, shape, name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.tile.cs b/src/TensorFlowNET.Core/APIs/tf.tile.cs index 65975ac8..1220230d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tile.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tile.cs @@ -23,7 +23,7 @@ namespace Tensorflow => gen_array_ops.tile(input, multiples, name); public Tensor tile(Tensor input, object[] multiples, string name = null) - => gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); + => array_ops.tile(input, multiples, name); public Tensor tile(Tensor input, Shape multiples, string name = null) { diff --git a/src/TensorFlowNET.Core/GlobalUsing.cs b/src/TensorFlowNET.Core/GlobalUsing.cs index 209bc291..7e02c908 100644 --- a/src/TensorFlowNET.Core/GlobalUsing.cs +++ b/src/TensorFlowNET.Core/GlobalUsing.cs @@ -5,4 +5,5 @@ global using System.Collections; global using System.Data; global using System.Linq; global using Tensorflow.Keras.Engine; -global using Tensorflow.Framework.Models; \ No newline at end of file +global using Tensorflow.Framework.Models; +global using static Tensorflow.Binding; \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs index 9287284f..5a264b63 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasTensor.cs @@ -30,21 +30,32 @@ public class KerasTensor public static KerasTensor from_tensor(Tensor tensor) { var type_spec = tensor.ToTensorSpec(); - var kt = new KerasTensor(type_spec, name: tensor.name); + Shape? inferred_value = default; + if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) + { + inferred_value = tf.ones(tensor).shape; + } + var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); kt.original_tensors = tensor; return kt; } + public KerasTensor this[int idx] + => _original_tensors.First()[idx]; + + public KerasTensor this[params Slice[] slices] + => _original_tensors.First()[slices]; + public override string ToString() => _original_tensors.Length switch { - > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]", - 1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}", + > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", + 1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", _ => _original_tensors.ToString(), }; private string GetInferredValueString() - => _inferred_value == null ? "" : ""; + => _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; public static implicit operator Tensors(KerasTensor kt) => kt._original_tensors; diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 02bf0e86..9d4647fa 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -137,7 +137,7 @@ namespace Tensorflow if(shape.Length > 1) { shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); - if(shapeTensor.ndim > 1) + if (shapeTensor.ndim > 1) { shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); } @@ -304,6 +304,10 @@ namespace Tensorflow { elems_as_tensors.Add(tensor); } + else if (elem is KerasTensor kt) + { + elems_as_tensors.Add(kt); + } else { var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); @@ -404,7 +408,10 @@ namespace Tensorflow => gen_array_ops.reshape(tensor, shape, name: name); public static Tensor reshape(Tensor tensor, object[] shape, string name = null) - => gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name: name); + { + var dims = shape_utils.from_object_array(shape); + return gen_array_ops.reshape(tensor, dims, name: name); + } private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) { @@ -425,6 +432,10 @@ namespace Tensorflow return tf_with(ops.name_scope(name, "ones", new { shape }), scope => { name = scope; + if (shape._shape_tuple().Length == 0) + { + shape = reshape(shape, new Shape(-1)); + } var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); return output; }); @@ -647,6 +658,20 @@ namespace Tensorflow } }); + public static Tensor tile(Tensor input, object[] multiples, string name = null) + { + Shape dims = shape_utils.from_object_array(multiples); + + return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Tmultiples = op.get_attr("Tmultiples") + } + }); + } + public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) { return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => diff --git a/src/TensorFlowNET.Core/Tensors/shape_utils.cs b/src/TensorFlowNET.Core/Tensors/shape_utils.cs index 254cdad8..a77dd34c 100644 --- a/src/TensorFlowNET.Core/Tensors/shape_utils.cs +++ b/src/TensorFlowNET.Core/Tensors/shape_utils.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow @@ -13,5 +14,31 @@ namespace Tensorflow throw new NotImplementedException(""); } + + public static Shape from_object_array(object[] shape) + { + var dims = shape.Select(x => + { + if (x is KerasTensor kt && kt.inferred_value != null) + { + return kt.inferred_value.as_int_list()[0]; + } + else if (x is EagerTensor et && et.dtype == TF_DataType.TF_INT32) + { + return et.ToArray()[0]; + } + else if (x is int i) + { + return i; + } + else if (x is long l) + { + return l; + } + throw new NotImplementedException(); + }).ToArray(); + + return new Shape(dims); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index 6a62d34a..ac26b3da 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -46,6 +46,9 @@ namespace Tensorflow public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) => array_ops.ones(shape, dtype, name); + public Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + => array_ops.ones(shape, dtype, name); + public Tensor size(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index c624c990..351fd18f 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -144,11 +144,18 @@ namespace Tensorflow } if (!graph.building_function) { - throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); - // return eager_tensor.AsPlaceholder(name: name); + // throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); + return eager_tensor.AsPlaceholder(name: name); } } } + else if (value is KerasTensor kt) + { + if (kt.inferred_value != null) + { + return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name); + } + } // graph mode Tensor ret = value switch diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index c7fa7711..eeb7c559 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -141,7 +141,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac - + diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj index 240960c9..7a6a7f92 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj @@ -41,8 +41,8 @@ - - + +