diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 334f4f74..150fa89a 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -129,6 +129,7 @@ namespace Tensorflow } } + [DebuggerStepThrough] [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception public static TOut tf_with(TIn py, Func action) where TIn : IObjectLife { diff --git a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs index e9f85530..9aa6d619 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Sequential.cs @@ -56,9 +56,9 @@ namespace Tensorflow.Keras.Engine { // Instantiate an input layer. var x = keras.layers.Input( - batch_shape: batch_shape, - dtype: dtype, - name: layer.name + "_input"); + batch_shape: batch_shape, + dtype: dtype, + name: layer.name + "_input"); // This will build the current layer // and create the node connecting the current layer @@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Engine if (set_inputs) { // If an input layer (placeholder) is available. - // outputs = layer._inbound_nodes; + // outputs = layer.inbound_nodes; } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 16a5c67d..25161721 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -106,6 +106,7 @@ namespace Tensorflow.Keras.Layers VariableScope scope = null) { var input_list = inputs; + var input = inputs[0]; Tensor outputs = null; // We will attempt to build a TF graph if & only if all inputs are symbolic. @@ -139,6 +140,7 @@ namespace Tensorflow.Keras.Layers _maybe_build(inputs[0]); outputs = call(inputs[0], training: training); + (input, outputs) = _set_connectivity_metadata_(input, outputs); _handle_activity_regularization(inputs[0], outputs); _set_mask_metadata(inputs[0], outputs, null); }); @@ -147,6 +149,12 @@ namespace Tensorflow.Keras.Layers return outputs; } + private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) + { + //_add_inbound_node(input_tensors: inputs, output_tensors: outputs); + return (inputs, outputs); + } + private void _handle_activity_regularization(Tensor inputs, Tensor outputs) { //if(_activity_regularizer != null) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index 45a2946f..86ab150f 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -605,8 +605,9 @@ namespace Tensorflow if (axis != 0) return gen_array_ops.gather_v2(@params, indices, axis, name: name); - if (@params is ResourceVariable variable) - return variable.sparse_read(); + if (@params is ResourceVariable variable && + indices is Tensor indices_tensor) + return variable.sparse_read(indices_tensor, name); return gen_array_ops.gather_v2(@params, indices, axis, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index 1927014f..664572a5 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -73,5 +73,20 @@ namespace Tensorflow return _op.output; } + + public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype, + int batch_dims = 0, bool validate_indices = true, string name = null) + { + var _op = _op_def_lib._apply_op_helper("ResourceGather", name, new + { + resource, + indices, + dtype, + batch_dims, + validate_indices + }); + + return _op.output; + } } } diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 8398deac..7b887e22 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -216,6 +216,18 @@ namespace Tensorflow _dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); } + public Tensor sparse_read(Tensor indices, string name = "Gather") + { + return tf_with(ops.name_scope(name), scope => + { + name = scope; + var value = gen_resource_variable_ops.resource_gather( + _handle, indices, dtype: _dtype, name: name); + + return array_ops.identity(value); + }); + } + public override string ToString() { return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}";