@@ -129,6 +129,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
[DebuggerStepThrough] | |||||
[DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | ||||
public static TOut tf_with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IObjectLife | public static TOut tf_with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IObjectLife | ||||
{ | { | ||||
@@ -56,9 +56,9 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
// Instantiate an input layer. | // Instantiate an input layer. | ||||
var x = keras.layers.Input( | var x = keras.layers.Input( | ||||
batch_shape: batch_shape, | |||||
dtype: dtype, | |||||
name: layer.name + "_input"); | |||||
batch_shape: batch_shape, | |||||
dtype: dtype, | |||||
name: layer.name + "_input"); | |||||
// This will build the current layer | // This will build the current layer | ||||
// and create the node connecting the current layer | // and create the node connecting the current layer | ||||
@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Engine | |||||
if (set_inputs) | if (set_inputs) | ||||
{ | { | ||||
// If an input layer (placeholder) is available. | // If an input layer (placeholder) is available. | ||||
// outputs = layer._inbound_nodes; | |||||
// outputs = layer.inbound_nodes; | |||||
} | } | ||||
} | } | ||||
@@ -106,6 +106,7 @@ namespace Tensorflow.Keras.Layers | |||||
VariableScope scope = null) | VariableScope scope = null) | ||||
{ | { | ||||
var input_list = inputs; | var input_list = inputs; | ||||
var input = inputs[0]; | |||||
Tensor outputs = null; | Tensor outputs = null; | ||||
// We will attempt to build a TF graph if & only if all inputs are symbolic. | // We will attempt to build a TF graph if & only if all inputs are symbolic. | ||||
@@ -139,6 +140,7 @@ namespace Tensorflow.Keras.Layers | |||||
_maybe_build(inputs[0]); | _maybe_build(inputs[0]); | ||||
outputs = call(inputs[0], training: training); | outputs = call(inputs[0], training: training); | ||||
(input, outputs) = _set_connectivity_metadata_(input, outputs); | |||||
_handle_activity_regularization(inputs[0], outputs); | _handle_activity_regularization(inputs[0], outputs); | ||||
_set_mask_metadata(inputs[0], outputs, null); | _set_mask_metadata(inputs[0], outputs, null); | ||||
}); | }); | ||||
@@ -147,6 +149,12 @@ namespace Tensorflow.Keras.Layers | |||||
return outputs; | return outputs; | ||||
} | } | ||||
private (Tensor, Tensor) _set_connectivity_metadata_(Tensor inputs, Tensor outputs) | |||||
{ | |||||
//_add_inbound_node(input_tensors: inputs, output_tensors: outputs); | |||||
return (inputs, outputs); | |||||
} | |||||
private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | private void _handle_activity_regularization(Tensor inputs, Tensor outputs) | ||||
{ | { | ||||
//if(_activity_regularizer != null) | //if(_activity_regularizer != null) | ||||
@@ -605,8 +605,9 @@ namespace Tensorflow | |||||
if (axis != 0) | if (axis != 0) | ||||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | return gen_array_ops.gather_v2(@params, indices, axis, name: name); | ||||
if (@params is ResourceVariable variable) | |||||
return variable.sparse_read(); | |||||
if (@params is ResourceVariable variable && | |||||
indices is Tensor indices_tensor) | |||||
return variable.sparse_read(indices_tensor, name); | |||||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | return gen_array_ops.gather_v2(@params, indices, axis, name: name); | ||||
} | } | ||||
@@ -73,5 +73,20 @@ namespace Tensorflow | |||||
return _op.output; | return _op.output; | ||||
} | } | ||||
public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype, | |||||
int batch_dims = 0, bool validate_indices = true, string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("ResourceGather", name, new | |||||
{ | |||||
resource, | |||||
indices, | |||||
dtype, | |||||
batch_dims, | |||||
validate_indices | |||||
}); | |||||
return _op.output; | |||||
} | |||||
} | } | ||||
} | } |
@@ -216,6 +216,18 @@ namespace Tensorflow | |||||
_dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); | _dtype = dtypes.as_tf_dtype((DataType)_handle.op.get_attr("dtype")); | ||||
} | } | ||||
public Tensor sparse_read(Tensor indices, string name = "Gather") | |||||
{ | |||||
return tf_with(ops.name_scope(name), scope => | |||||
{ | |||||
name = scope; | |||||
var value = gen_resource_variable_ops.resource_gather( | |||||
_handle, indices, dtype: _dtype, name: name); | |||||
return array_ops.identity(value); | |||||
}); | |||||
} | |||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}"; | return $"tf.ResourceVariable '{name}' shape={shape} dtype={dtype}"; | ||||