From 267d775ab1addd1d16144cd4a92e73f3581593f7 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 12 Oct 2019 20:25:43 -0500 Subject: [PATCH] _read_variable_op --- src/TensorFlowNET.Core/Keras/Layers/Layer.cs | 2 +- .../Operations/array_ops.py.cs | 10 ++++++++- .../Operations/gen_resource_variable_ops.cs | 22 +++++++++++++++++-- .../Variables/ResourceVariable.cs | 12 ++++++++-- 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 22cef8e1..16a5c67d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -224,7 +224,7 @@ namespace Tensorflow.Keras.Layers overwrite: true, initializer: initializer, trainable: trainable.Value); - backend.track_variable(variable); + //backend.track_variable(variable); _trainable_weights.Add(variable); return variable; diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index b7ef6440..45a2946f 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -601,7 +601,15 @@ namespace Tensorflow } public static Tensor gather(T1 @params, T2 indices, string name = null, int axis = 0) - => gen_array_ops.gather_v2(@params, indices, axis, name: name); + { + if (axis != 0) + return gen_array_ops.gather_v2(@params, indices, axis, name: name); + + if (@params is ResourceVariable variable) + return variable.sparse_read(); + + return gen_array_ops.gather_v2(@params, indices, axis, name: name); + } public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false) { diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index df46ad55..1927014f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -31,7 +31,7 @@ namespace Tensorflow { var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); - return _op; + return _op.output; } /// @@ -53,7 +53,25 @@ namespace Tensorflow shared_name }); - return _op; + return _op.output; + } + + /// + /// Reads the value of a variable. + /// + /// + /// + /// + /// + public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new + { + resource, + dtype + }); + + return _op.output; } } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index a05a549c..8398deac 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -149,12 +149,20 @@ namespace Tensorflow // messages. tf_with(ops.name_scope("Read"), delegate { - + var value = _read_variable_op(); + _graph_element = value; }); } + + ops.add_to_collections(collections, this); }); + } - throw new NotImplementedException(""); + private Tensor _read_variable_op() + { + var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype); + // _maybe_set_handle_data(_dtype, _handle, result); + return result; } private void _init_from_proto(VariableDef variable_def, string import_scope = null)