Browse Source

ResourceVariable.sparse_read()

tags/v0.12
Oceania2018 6 years ago
parent
commit
5754c49247
6 changed files with 43 additions and 6 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +4
    -4
      src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
  3. +8
    -0
      src/TensorFlowNET.Core/Keras/Layers/Layer.cs
  4. +3
    -2
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  5. +15
    -0
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  6. +12
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs

+ 1
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -129,6 +129,7 @@ namespace Tensorflow
} }
} }


[DebuggerStepThrough]
[DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception
public static TOut tf_with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IObjectLife public static TOut tf_with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IObjectLife
{ {


+ 4
- 4
src/TensorFlowNET.Core/Keras/Engine/Sequential.cs View File

@@ -56,9 +56,9 @@ namespace Tensorflow.Keras.Engine
{ {
// Instantiate an input layer. // Instantiate an input layer.
var x = keras.layers.Input( var x = keras.layers.Input(
batch_shape: batch_shape,
dtype: dtype,
name: layer.name + "_input");
batch_shape: batch_shape,
dtype: dtype,
name: layer.name + "_input");


// This will build the current layer // This will build the current layer
// and create the node connecting the current layer // and create the node connecting the current layer
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Engine
if (set_inputs) if (set_inputs)
{ {
// If an input layer (placeholder) is available. // If an input layer (placeholder) is available.
// outputs = layer._inbound_nodes;
// outputs = layer.inbound_nodes;
} }


} }


+ 8
- 0
src/TensorFlowNET.Core/Keras/Layers/Layer.cs View File

@@ -106,6 +106,7 @@ namespace Tensorflow.Keras.Layers
VariableScope scope = null) VariableScope scope = null)
{ {
var input_list = inputs; var input_list = inputs;
var input = inputs[0];
Tensor outputs = null; Tensor outputs = null;


// We will attempt to build a TF graph if & only if all inputs are symbolic. // We will attempt to build a TF graph if & only if all inputs are symbolic.
@@ -139,6 +140,7 @@ namespace Tensorflow.Keras.Layers
_maybe_build(inputs[0]); _maybe_build(inputs[0]);


outputs = call(inputs[0], training: training); outputs = call(inputs[0], training: training);
(input, outputs) = _set_connectivity_metadata_(input, outputs);
_handle_activity_regularization(inputs[0], outputs); _handle_activity_regularization(inputs[0], outputs);
_set_mask_metadata(inputs[0], outputs, null); _set_mask_metadata(inputs[0], outputs, null);
}); });
@@ -147,6 +149,12 @@ namespace Tensorflow.Keras.Layers
return outputs; return outputs;
} }


private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs)
{
//_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
return (inputs, outputs);
}

private void _handle_activity_regularization(Tensor inputs, Tensor outputs) private void _handle_activity_regularization(Tensor inputs, Tensor outputs)
{ {
//if(_activity_regularizer != null) //if(_activity_regularizer != null)


+ 3
- 2
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -605,8 +605,9 @@ namespace Tensorflow
if (axis != 0) if (axis != 0)
return gen_array_ops.gather_v2(@params, indices, axis, name: name); return gen_array_ops.gather_v2(@params, indices, axis, name: name);
if (@params is ResourceVariable variable)
return variable.sparse_read();
if (@params is ResourceVariable variable &&
indices is Tensor indices_tensor)
return variable.sparse_read(indices_tensor, name);
return gen_array_ops.gather_v2(@params, indices, axis, name: name); return gen_array_ops.gather_v2(@params, indices, axis, name: name);
} }


+ 15
- 0
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs View File

@@ -73,5 +73,20 @@ namespace Tensorflow


return _op.output; return _op.output;
} }

public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype,
int batch_dims = 0, bool validate_indices = true, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ResourceGather", name, new
{
resource,
indices,
dtype,
batch_dims,
validate_indices
});

return _op.output;
}
} }
} }

+ 12
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -216,6 +216,18 @@ namespace Tensorflow
_dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); _dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype"));
} }


public Tensor sparse_read(Tensor indices, string name = "Gather")
{
return tf_with(ops.name_scope(name), scope =>
{
name = scope;
var value = gen_resource_variable_ops.resource_gather(
_handle, indices, dtype: _dtype, name: name);

return array_ops.identity(value);
});
}

public override string ToString() public override string ToString()
{ {
return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}"; return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}";


Loading…
Cancel
Save