diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs
index 11ae5fe4..fb153301 100644
--- a/src/TensorFlowNET.Core/APIs/tf.init.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.init.cs
@@ -14,7 +14,7 @@ namespace Tensorflow
public static variable_scope variable_scope(string name,
string default_name = null,
- object values = null,
+ Tensor[] values = null,
bool auxiliary_name_scope = true) => new variable_scope(name,
default_name,
values,
@@ -22,7 +22,7 @@ namespace Tensorflow
public static variable_scope variable_scope(VariableScope scope,
string default_name = null,
- object values = null,
+ Tensor[] values = null,
bool? reuse = null,
bool auxiliary_name_scope = true) => new variable_scope(scope,
default_name,
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
index 4bd9e088..9c89aadf 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs
@@ -4,6 +4,7 @@ using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
+using Tensorflow.Train;
using static Tensorflow.Python;
namespace Tensorflow.Keras.Layers
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers
/// as convolution, batch norm, etc. These operations require managing weights,
/// losses, updates, and inter-layer connectivity.
///
- public class Layer : CheckpointableBase
+ public class Layer : AutoTrackable
{
///
/// Indicates whether `build` needs to be called upon layer call, to create
@@ -84,32 +85,35 @@ namespace Tensorflow.Keras.Layers
// models using the functional API).
bool build_graph = tf_utils.are_all_symbolic_tensors(input_list);
- // Handle Keras mask propagation from previous layer to current layer.
- Python.with(ops.name_scope(_name_scope()), delegate
+ if (build_graph)
{
- /*if (!built)
- {
- _maybe_build(inputs);
- built = true;
- }*/
+ // Only create Keras history if at least one tensor originates from a
+ // `keras.Input`. Otherwise this Layer may be being used outside the Keras
+ // framework.
+ // base_layer_utils.create_keras_history(inputs)
+ }
- if (build_graph)
+ // with base_layer_utils.call_context(self):
+
+ // Handle Keras mask propagation from previous layer to current layer.
+ // with base_layer_utils.call_context(self):
+ // Check input assumptions set after layer building, e.g. input shape.
+ if (build_graph)
+ {
+ // Symbolic execution on symbolic tensors. We will attempt to build
+ // the corresponding TF subgraph inside `backend.get_graph()`
+ var graph = backend.get_graph().as_default();
+ with(ops.name_scope(_name_scope()), delegate
{
- // Symbolic execution on symbolic tensors. We will attempt to build
- // the corresponding TF subgraph inside `backend.get_graph()`
- var graph = backend.get_graph().as_default();
- with(ops.name_scope(_name_scope()), delegate
- {
- // Build layer if applicable (if the `build` method has been
- // overridden).
- _maybe_build(inputs[0]);
- });
-
- outputs = call(inputs[0], training: training);
- _handle_activity_regularization(inputs[0], outputs);
- _set_mask_metadata(inputs[0], outputs, null);
- }
- });
+ // Build layer if applicable (if the `build` method has been
+ // overridden).
+ _maybe_build(inputs[0]);
+ });
+
+ outputs = call(inputs[0], training: training);
+ _handle_activity_regularization(inputs[0], outputs);
+ _set_mask_metadata(inputs[0], outputs, null);
+ }
return outputs;
}
@@ -147,6 +151,8 @@ namespace Tensorflow.Keras.Layers
// Check input assumptions set before layer building, e.g. input rank.
if (built)
return;
+ if (_dtype == TF_DataType.DtInvalid)
+ _dtype = input.dtype;
build(input.GetShape());
built = true;
@@ -170,10 +176,21 @@ namespace Tensorflow.Keras.Layers
if (trainable == null)
trainable = true;
+ // Initialize variable when no initializer provided
+ if(initializer == null)
+ {
+ // If dtype is DT_FLOAT, provide a uniform unit scaling initializer
+ if (dtype.is_floating())
+ initializer = tf.glorot_uniform_initializer;
+ else if (dtype.is_integer())
+ initializer = tf.zeros_initializer;
+ else
+ throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}");
+ }
var variable = _add_variable_with_custom_getter(name,
shape,
dtype: dtype,
- //getter: getter == null ? base_layer_utils.make_variable : getter,
+ getter: getter, // getter == null ? base_layer_utils.make_variable : getter,
overwrite: true,
initializer: initializer,
trainable: trainable.Value);
diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs
index 45d46ad4..4213957f 100644
--- a/src/TensorFlowNET.Core/Keras/backend.cs
+++ b/src/TensorFlowNET.Core/Keras/backend.cs
@@ -12,10 +12,11 @@ namespace Tensorflow.Keras
/// Allows to give unique autogenerated names to layers, in a graph-specific way.
///
public static Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>();
-
+ public static Dictionary _GRAPH_VARIABLES = new Dictionary();
public static void track_variable(RefVariable v)
{
-
+ var graph = v.graph;
+ _GRAPH_VARIABLES[graph.graph_key] = v;
}
public static Tensor placeholder(int[] shape = null,
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index a4aafc94..e569d9c0 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -51,9 +51,14 @@ namespace Tensorflow.Layers
auxiliary_name_scope: false);
}
- with(scope_context_manager, scope2 => _current_scope = scope2);
- // Actually call layer
- var outputs = base.__call__(new Tensor[] { inputs }, training: training);
+ Tensor outputs = null;
+ with(scope_context_manager, scope2 =>
+ {
+ _current_scope = scope2;
+ // Actually call layer
+ outputs = base.__call__(new Tensor[] { inputs }, training: training);
+ });
+
// Update global default collections.
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });
@@ -80,6 +85,17 @@ namespace Tensorflow.Layers
}
}
+ ///
+ /// Adds a new variable to the layer, or gets an existing one; returns it.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
protected virtual RefVariable add_weight(string name,
int[] shape,
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -157,7 +173,10 @@ namespace Tensorflow.Layers
else
{
with(tf.variable_scope(scope, default_name: _base_name),
- captured_scope => _scope = captured_scope);
+ captured_scope =>
+ {
+ _scope = captured_scope;
+ });
}
}
diff --git a/src/TensorFlowNET.Core/Train/AutoTrackable.cs b/src/TensorFlowNET.Core/Train/AutoTrackable.cs
new file mode 100644
index 00000000..d3479f1b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Train/AutoTrackable.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Train
+{
+ public abstract class AutoTrackable : Trackable
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs
index c5592ccf..466091fa 100644
--- a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs
+++ b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs
@@ -5,7 +5,7 @@ using Tensorflow.Train;
namespace Tensorflow
{
- public abstract class CheckpointableBase : Trackable
+ public abstract class CheckpointableBase : AutoTrackable
{
}
diff --git a/src/TensorFlowNET.Core/Train/Trackable.cs b/src/TensorFlowNET.Core/Train/Trackable.cs
index e565b15b..c16304a9 100644
--- a/src/TensorFlowNET.Core/Train/Trackable.cs
+++ b/src/TensorFlowNET.Core/Train/Trackable.cs
@@ -18,7 +18,13 @@ namespace Tensorflow.Train
bool overwrite = false,
bool trainable = false)
{
+ var checkpoint_initializer = true;
var new_variable = getter(name, shape, dtype, initializer, trainable);
+
+ // If we set an initializer and the variable processed it, tracking will not
+ // assign again. It will add this variable to our dependencies, and if there
+ // is a non-trivial restoration queued, it will handle that. This also
+ // handles slot variables.
if (!overwrite || new_variable is RefVariable)
return _track_checkpointable(new_variable, name: name,
overwrite: overwrite);
diff --git a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
index 6f97b19d..3476fc99 100644
--- a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
+++ b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs
@@ -35,7 +35,7 @@ namespace Tensorflow
_old_name_scope = old_name_scope;
_var_store = variable_scope._get_default_variable_store();
_var_scope_store = variable_scope.get_variable_scope_store();
- _new_name = _scope._name;
+ _new_name = _scope.name;
string name_scope = _scope._name_scope;
variable_scope_object = new VariableScope(_reuse,
@@ -55,7 +55,7 @@ namespace Tensorflow
}
else
{
- _new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name;
+ _new_name = string.IsNullOrEmpty(_old.name) ? _name : _old.name + "/" + _name;
_reuse = _reuse || _old.resue;
string name_scope = _old_name_scope == null ? _name : _old_name_scope;
diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs
index 9e97e373..fe6f973e 100644
--- a/src/TensorFlowNET.Core/Variables/VariableScope.cs
+++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs
@@ -15,7 +15,8 @@ namespace Tensorflow
public bool resue;
private TF_DataType _dtype;
- public string _name { get; set; }
+ string _name;
+ public string name => _name;
public string _name_scope { get; set; }
public string original_name_scope => _name_scope;
diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
index ba508e4f..ca183d19 100644
--- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
+++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs
@@ -19,15 +19,17 @@ namespace Tensorflow
private string _name;
private VariableScope _scope;
private string _default_name;
- private object _values;
+ private Tensor[] _values;
private ops.NameScope _current_name_scope;
private bool _auxiliary_name_scope;
private PureVariableScope _cached_pure_variable_scope;
private bool? _reuse;
+ bool _in_graph_mode;
+ protected Graph _graph;
public variable_scope(string name,
- string default_name = "",
- object values = null,
+ string default_name = "",
+ Tensor[] values = null,
bool? reuse = null,
bool auxiliary_name_scope = true)
{
@@ -45,7 +47,7 @@ namespace Tensorflow
public variable_scope(VariableScope scope,
string default_name = "",
- object values = null,
+ Tensor[] values = null,
bool? reuse = null,
bool auxiliary_name_scope = true)
{
@@ -58,6 +60,11 @@ namespace Tensorflow
if (_default_name == null && _scope == null)
throw new TypeError("If default_name is None then scope is required");
+ if (_values == null)
+ _values = new Tensor[0];
+ _in_graph_mode = true;
+ if (_in_graph_mode)
+ _graph = ops._get_graph_from_inputs(_values);
_auxiliary_name_scope = auxiliary_name_scope;
}
@@ -87,7 +94,7 @@ namespace Tensorflow
if (_name != null || _scope != null)
{
- var name_scope = _name == null ? _scope._name.Split('/').Last() : _name;
+ var name_scope = _name == null ? _scope.name.Split('/').Last() : _name;
if (name_scope != null || current_name_scope != null)
current_name_scope = ops.name_scope(name_scope);
current_name_scope.__enter__();
@@ -124,7 +131,7 @@ namespace Tensorflow
{
var var_scope_store = get_variable_scope_store();
var current_scope = get_variable_scope();
- string name = !string.IsNullOrEmpty(current_scope._name) ? current_scope._name + "/" + prefix : prefix;
+ string name = !string.IsNullOrEmpty(current_scope.name) ? current_scope.name + "/" + prefix : prefix;
if (var_scope_store.variable_scope_count(name) == 0)
return prefix;
throw new NotImplementedException("_get_unique_variable_scope");
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 289ed4d0..d6cd059f 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -174,8 +174,9 @@ namespace TensorFlowNET.Examples
x_emb = tf.expand_dims(x_emb, -1);
});
- foreach(var filter_size in filter_sizes)
+ for(int len = 0; len < filter_sizes.Rank; len++)
{
+ int filter_size = filter_sizes.GetLength(len);
var conv = tf.layers.conv2d(
x_emb,
filters: num_filters,
@@ -183,6 +184,12 @@ namespace TensorFlowNET.Examples
strides: new int[] { 1, 1 },
padding: "VALID",
activation: tf.nn.relu());
+
+ var pool = tf.layers.max_pooling2d(
+ conv,
+ pool_size: new[] { document_max_len - filter_size + 1, 1 },
+ strides: new[] { 1, 1 },
+ padding: "VALID");
}
return graph;