From 64276a3ce8bbe6e6402520faa5c8bca0574fd9f1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 10 Oct 2020 08:43:04 -0500 Subject: [PATCH] Fix AsTensor() for ref and copy. --- .../Tensors/Tensor.Operators.cs | 17 ++++++++++++--- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 7 +++++++ src/TensorFlowNET.Core/Tensors/Tensors.cs | 3 +++ .../Variables/BaseResourceVariable.cs | 21 +++++++++++++------ .../Variables/IVariableV1.cs | 2 +- .../Variables/RefVariable.cs | 2 +- .../Variables/ResourceVariable.Implicit.cs | 2 +- 7 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index fc97895d..ca022783 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -310,17 +310,25 @@ namespace Tensorflow private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { TF_DataType dtype = TF_DataType.DtInvalid; + bool switchToGraphModeTemp = !tf.executing_eagerly(); if (x is Tensor tl) + { dtype = tl.dtype.as_base_dtype(); + switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor; + } + if (y is Tensor tr) + { dtype = tr.dtype.as_base_dtype(); - - if (name == "div") - name = div_or_truediv(name, x, y); + switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor; + } return tf_with(ops.name_scope(null, name, new { x, y }), scope => { + if (switchToGraphModeTemp) + tf.Context.graph_mode(); + Tensor result; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); @@ -352,6 +360,9 @@ namespace Tensorflow throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); } + if (switchToGraphModeTemp) + tf.Context.restore_mode(); + return result; }); } diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 34c26bbb..889b800f 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -253,6 +253,13 @@ namespace Tensorflow return (int[]) dims.Clone(); } + public long[] as_list_long() + { + if (shape.IsEmpty) + throw new ValueError("as_list() is not defined on an unknown TensorShape."); + return dims.Select(x => Convert.ToInt64(x)).ToArray(); + } + public int num_elements() { if(is_fully_defined()) diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index af8796bd..50b1395c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -56,6 +56,9 @@ namespace Tensorflow public static implicit operator Tensors(Tensor[] tensors) => new Tensors(tensors); + public static implicit operator Tensors(List tensors) + => new Tensors(tensors.ToArray()); + public static implicit operator Tensor(Tensors tensors) => tensors.FirstOrDefault(); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index fca60f88..94790339 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -100,6 +100,18 @@ namespace Tensorflow variable_accessed(this); var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); // _maybe_set_handle_data(_dtype, _handle, result); + + // have to set shape when converting to substituent placeholder + if (result.TensorShape.ndim == -1) + { + c_api.TF_GraphSetTensorShape(result.graph, + result._as_tf_output(), + shape.as_list_long(), + shape.ndim, + tf.Status.Handle); + tf.Status.Check(true); + } + return result; } @@ -160,15 +172,12 @@ namespace Tensorflow { } - public Tensor AsTensor(bool as_ref = true) + public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) { - if (!as_ref && GraphElement != null) - return GraphElement; - if (as_ref) - return tf.executing_eagerly() ? read_value() : GraphElement; + return read_value().op.inputs[0]; else - return _read_variable_op(); + return value(); } } } diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index 4367cf09..36297a41 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -49,6 +49,6 @@ namespace Tensorflow public TensorShape shape { get; } Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true); - Tensor AsTensor(bool as_ref = true); + Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 68df1e66..f6de69ca 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -222,7 +222,7 @@ namespace Tensorflow public Tensor value() => _snapshot; - public Tensor AsTensor(bool as_ref = true) => _snapshot; + public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) => _snapshot; public Tensor _as_graph_element() => _variable; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs index d8a743dc..656e1653 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Implicit.cs @@ -37,7 +37,7 @@ namespace Tensorflow if (as_ref) return handle; else - return tf.executing_eagerly() ? AsTensor() : value(); + return AsTensor(); } } }