diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 11ae5fe4..fb153301 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -14,7 +14,7 @@ namespace Tensorflow public static variable_scope variable_scope(string name, string default_name = null, - object values = null, + Tensor[] values = null, bool auxiliary_name_scope = true) => new variable_scope(name, default_name, values, @@ -22,7 +22,7 @@ namespace Tensorflow public static variable_scope variable_scope(VariableScope scope, string default_name = null, - object values = null, + Tensor[] values = null, bool? reuse = null, bool auxiliary_name_scope = true) => new variable_scope(scope, default_name, diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 4bd9e088..9c89aadf 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Text; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; +using Tensorflow.Train; using static Tensorflow.Python; namespace Tensorflow.Keras.Layers @@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers /// as convolution, batch norm, etc. These operations require managing weights, /// losses, updates, and inter-layer connectivity. /// - public class Layer : CheckpointableBase + public class Layer : AutoTrackable { /// /// Indicates whether `build` needs to be called upon layer call, to create @@ -84,32 +85,35 @@ namespace Tensorflow.Keras.Layers // models using the functional API). bool build_graph = tf_utils.are_all_symbolic_tensors(input_list); - // Handle Keras mask propagation from previous layer to current layer. - Python.with(ops.name_scope(_name_scope()), delegate + if (build_graph) { - /*if (!built) - { - _maybe_build(inputs); - built = true; - }*/ + // Only create Keras history if at least one tensor originates from a + // `keras.Input`. Otherwise this Layer may be being used outside the Keras + // framework. + // base_layer_utils.create_keras_history(inputs) + } - if (build_graph) + // with base_layer_utils.call_context(self): + + // Handle Keras mask propagation from previous layer to current layer. + // with base_layer_utils.call_context(self): + // Check input assumptions set after layer building, e.g. input shape. + if (build_graph) + { + // Symbolic execution on symbolic tensors. We will attempt to build + // the corresponding TF subgraph inside `backend.get_graph()` + var graph = backend.get_graph().as_default(); + with(ops.name_scope(_name_scope()), delegate { - // Symbolic execution on symbolic tensors. We will attempt to build - // the corresponding TF subgraph inside `backend.get_graph()` - var graph = backend.get_graph().as_default(); - with(ops.name_scope(_name_scope()), delegate - { - // Build layer if applicable (if the `build` method has been - // overridden). - _maybe_build(inputs[0]); - }); - - outputs = call(inputs[0], training: training); - _handle_activity_regularization(inputs[0], outputs); - _set_mask_metadata(inputs[0], outputs, null); - } - }); + // Build layer if applicable (if the `build` method has been + // overridden). + _maybe_build(inputs[0]); + }); + + outputs = call(inputs[0], training: training); + _handle_activity_regularization(inputs[0], outputs); + _set_mask_metadata(inputs[0], outputs, null); + } return outputs; } @@ -147,6 +151,8 @@ namespace Tensorflow.Keras.Layers // Check input assumptions set before layer building, e.g. input rank. if (built) return; + if (_dtype == TF_DataType.DtInvalid) + _dtype = input.dtype; build(input.GetShape()); built = true; @@ -170,10 +176,21 @@ namespace Tensorflow.Keras.Layers if (trainable == null) trainable = true; + // Initialize variable when no initializer provided + if(initializer == null) + { + // If dtype is DT_FLOAT, provide a uniform unit scaling initializer + if (dtype.is_floating()) + initializer = tf.glorot_uniform_initializer; + else if (dtype.is_integer()) + initializer = tf.zeros_initializer; + else + throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}"); + } var variable = _add_variable_with_custom_getter(name, shape, dtype: dtype, - //getter: getter == null ? base_layer_utils.make_variable : getter, + getter: getter, // getter == null ? base_layer_utils.make_variable : getter, overwrite: true, initializer: initializer, trainable: trainable.Value); diff --git a/src/TensorFlowNET.Core/Keras/backend.cs b/src/TensorFlowNET.Core/Keras/backend.cs index 45d46ad4..4213957f 100644 --- a/src/TensorFlowNET.Core/Keras/backend.cs +++ b/src/TensorFlowNET.Core/Keras/backend.cs @@ -12,10 +12,11 @@ namespace Tensorflow.Keras /// Allows to give unique autogenerated names to layers, in a graph-specific way. /// public static Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); - + public static Dictionary _GRAPH_VARIABLES = new Dictionary(); public static void track_variable(RefVariable v) { - + var graph = v.graph; + _GRAPH_VARIABLES[graph.graph_key] = v; } public static Tensor placeholder(int[] shape = null, diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index a4aafc94..e569d9c0 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -51,9 +51,14 @@ namespace Tensorflow.Layers auxiliary_name_scope: false); } - with(scope_context_manager, scope2 => _current_scope = scope2); - // Actually call layer - var outputs = base.__call__(new Tensor[] { inputs }, training: training); + Tensor outputs = null; + with(scope_context_manager, scope2 => + { + _current_scope = scope2; + // Actually call layer + outputs = base.__call__(new Tensor[] { inputs }, training: training); + }); + // Update global default collections. _add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS }); @@ -80,6 +85,17 @@ namespace Tensorflow.Layers } } + /// + /// Adds a new variable to the layer, or gets an existing one; returns it. + /// + /// + /// + /// + /// + /// + /// + /// + /// protected virtual RefVariable add_weight(string name, int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, @@ -157,7 +173,10 @@ namespace Tensorflow.Layers else { with(tf.variable_scope(scope, default_name: _base_name), - captured_scope => _scope = captured_scope); + captured_scope => + { + _scope = captured_scope; + }); } } diff --git a/src/TensorFlowNET.Core/Train/AutoTrackable.cs b/src/TensorFlowNET.Core/Train/AutoTrackable.cs new file mode 100644 index 00000000..d3479f1b --- /dev/null +++ b/src/TensorFlowNET.Core/Train/AutoTrackable.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Train +{ + public abstract class AutoTrackable : Trackable + { + } +} diff --git a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs index c5592ccf..466091fa 100644 --- a/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs +++ b/src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs @@ -5,7 +5,7 @@ using Tensorflow.Train; namespace Tensorflow { - public abstract class CheckpointableBase : Trackable + public abstract class CheckpointableBase : AutoTrackable { } diff --git a/src/TensorFlowNET.Core/Train/Trackable.cs b/src/TensorFlowNET.Core/Train/Trackable.cs index e565b15b..c16304a9 100644 --- a/src/TensorFlowNET.Core/Train/Trackable.cs +++ b/src/TensorFlowNET.Core/Train/Trackable.cs @@ -18,7 +18,13 @@ namespace Tensorflow.Train bool overwrite = false, bool trainable = false) { + var checkpoint_initializer = true; var new_variable = getter(name, shape, dtype, initializer, trainable); + + // If we set an initializer and the variable processed it, tracking will not + // assign again. It will add this variable to our dependencies, and if there + // is a non-trivial restoration queued, it will handle that. This also + // handles slot variables. if (!overwrite || new_variable is RefVariable) return _track_checkpointable(new_variable, name: name, overwrite: overwrite); diff --git a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs index 6f97b19d..3476fc99 100644 --- a/src/TensorFlowNET.Core/Variables/PureVariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/PureVariableScope.cs @@ -35,7 +35,7 @@ namespace Tensorflow _old_name_scope = old_name_scope; _var_store = variable_scope._get_default_variable_store(); _var_scope_store = variable_scope.get_variable_scope_store(); - _new_name = _scope._name; + _new_name = _scope.name; string name_scope = _scope._name_scope; variable_scope_object = new VariableScope(_reuse, @@ -55,7 +55,7 @@ namespace Tensorflow } else { - _new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name; + _new_name = string.IsNullOrEmpty(_old.name) ? _name : _old.name + "/" + _name; _reuse = _reuse || _old.resue; string name_scope = _old_name_scope == null ? _name : _old_name_scope; diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index 9e97e373..fe6f973e 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -15,7 +15,8 @@ namespace Tensorflow public bool resue; private TF_DataType _dtype; - public string _name { get; set; } + string _name; + public string name => _name; public string _name_scope { get; set; } public string original_name_scope => _name_scope; diff --git a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs index ba508e4f..ca183d19 100644 --- a/src/TensorFlowNET.Core/Variables/variable_scope.py.cs +++ b/src/TensorFlowNET.Core/Variables/variable_scope.py.cs @@ -19,15 +19,17 @@ namespace Tensorflow private string _name; private VariableScope _scope; private string _default_name; - private object _values; + private Tensor[] _values; private ops.NameScope _current_name_scope; private bool _auxiliary_name_scope; private PureVariableScope _cached_pure_variable_scope; private bool? _reuse; + bool _in_graph_mode; + protected Graph _graph; public variable_scope(string name, - string default_name = "", - object values = null, + string default_name = "", + Tensor[] values = null, bool? reuse = null, bool auxiliary_name_scope = true) { @@ -45,7 +47,7 @@ namespace Tensorflow public variable_scope(VariableScope scope, string default_name = "", - object values = null, + Tensor[] values = null, bool? reuse = null, bool auxiliary_name_scope = true) { @@ -58,6 +60,11 @@ namespace Tensorflow if (_default_name == null && _scope == null) throw new TypeError("If default_name is None then scope is required"); + if (_values == null) + _values = new Tensor[0]; + _in_graph_mode = true; + if (_in_graph_mode) + _graph = ops._get_graph_from_inputs(_values); _auxiliary_name_scope = auxiliary_name_scope; } @@ -87,7 +94,7 @@ namespace Tensorflow if (_name != null || _scope != null) { - var name_scope = _name == null ? _scope._name.Split('/').Last() : _name; + var name_scope = _name == null ? _scope.name.Split('/').Last() : _name; if (name_scope != null || current_name_scope != null) current_name_scope = ops.name_scope(name_scope); current_name_scope.__enter__(); @@ -124,7 +131,7 @@ namespace Tensorflow { var var_scope_store = get_variable_scope_store(); var current_scope = get_variable_scope(); - string name = !string.IsNullOrEmpty(current_scope._name) ? current_scope._name + "/" + prefix : prefix; + string name = !string.IsNullOrEmpty(current_scope.name) ? current_scope.name + "/" + prefix : prefix; if (var_scope_store.variable_scope_count(name) == 0) return prefix; throw new NotImplementedException("_get_unique_variable_scope"); diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs index 289ed4d0..d6cd059f 100644 --- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs +++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs @@ -174,8 +174,9 @@ namespace TensorFlowNET.Examples x_emb = tf.expand_dims(x_emb, -1); }); - foreach(var filter_size in filter_sizes) + for(int len = 0; len < filter_sizes.Rank; len++) { + int filter_size = filter_sizes.GetLength(len); var conv = tf.layers.conv2d( x_emb, filters: num_filters, @@ -183,6 +184,12 @@ namespace TensorFlowNET.Examples strides: new int[] { 1, 1 }, padding: "VALID", activation: tf.nn.relu()); + + var pool = tf.layers.max_pooling2d( + conv, + pool_size: new[] { document_max_len - filter_size + 1, 1 }, + strides: new[] { 1, 1 }, + padding: "VALID"); } return graph;