Browse Source

ResourceVariable.sparse_read()

tags/v0.12
Oceania2018 6 years ago
parent
commit
5754c49247
6 changed files with 43 additions and 6 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +4
    -4
      src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
  3. +8
    -0
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  4. +3
    -2
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  5. +15
    -0
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  6. +12
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs

+ 1
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -129,6 +129,7 @@ namespace Tensorflow
}
}

[DebuggerStepThrough]
[DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception
public static TOut tf_with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IObjectLife
{


+ 4
- 4
src/TensorFlowNET.Core/Keras/Engine/Sequential.cs View File

@@ -56,9 +56,9 @@ namespace Tensorflow.Keras.Engine
{
// Instantiate an input layer.
var x = keras.layers.Input(
batch_shape: batch_shape,
dtype: dtype,
name: layer.name + "_input");
batch_shape: batch_shape,
dtype: dtype,
name: layer.name + "_input");

// This will build the current layer
// and create the node connecting the current layer
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Engine
if (set_inputs)
{
// If an input layer (placeholder) is available.
// outputs = layer._inbound_nodes;
// outputs = layer.inbound_nodes;
}

}


+ 8
- 0
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -106,6 +106,7 @@ namespace Tensorflow.Keras.Layers
VariableScope scope = null)
{
var input_list = inputs;
var input = inputs[0];
Tensor outputs = null;

// We will attempt to build a TF graph if & only if all inputs are symbolic.
@@ -139,6 +140,7 @@ namespace Tensorflow.Keras.Layers
_maybe_build(inputs[0]);

outputs = call(inputs[0], training: training);
(input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null);
});
@@ -147,6 +149,12 @@ namespace Tensorflow.Keras.Layers
return outputs;
}

private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
{
//_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
return (inputs, outputs);
}

private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
{
//if(_activity_regularizer != null)


+ 3
- 2
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -605,8 +605,9 @@ namespace Tensorflow
if (axis != 0)
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
if (@params is ResourceVariable variable)
return variable.sparse_read();
if (@params is ResourceVariable variable &&
indices is Tensor indices_tensor)
return variable.sparse_read(indices_tensor, name);
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
}


+ 15
- 0
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs View File

@@ -73,5 +73,20 @@ namespace Tensorflow

return _op.output;
}

public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype,
int batch_dims = 0, bool validate_indices = true, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ResourceGather", name, new
{
resource,
indices,
dtype,
batch_dims,
validate_indices
});

return _op.output;
}
}
}

+ 12
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -216,6 +216,18 @@ namespace Tensorflow
_dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype"));
}

public Tensor sparse_read(Tensor indices, string name = "Gather")
{
return tf_with(ops.name_scope(name), scope =>
{
name = scope;
var value = gen_resource_variable_ops.resource_gather(
_handle, indices, dtype: _dtype, name: name);

return array_ops.identity(value);
});
}

public override string ToString()
{
return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}";


Loading…
Cancel
Save