diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs index 303e832e..a1448b5f 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.ArgsDefinition public Layer[] InboundLayers { get; set; } public int[] NodeIndices { get; set; } public int[] TensorIndices { get; set; } - public Tensor InputTensors { get; set; } + public Tensors InputTensors { get; set; } public Tensors Outputs { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs new file mode 100644 index 00000000..743a47d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class TensorFlowOpLayerArgs : LayerArgs + { + public NodeDef NodeDef { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs b/src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs deleted file mode 100644 index dbaa1247..00000000 --- a/src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs +++ /dev/null @@ -1,47 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Security.Cryptography.X509Certificates; -using System.Text; -using static Tensorflow.Binding; - -namespace Tensorflow.Keras.Engine -{ - public class BaseLayerUtils - { - public static Layer[] CreateKerasHistoryHelper(Tensors tensors) - { - var processed_ops = new List(); - var created_layers = new List(); - - foreach (var tensor in tensors) - { - if (tensor.KerasHistory != null) - continue; - - var op = tensor.op; - if (!processed_ops.Contains(op)) - { - var layer_inputs = new List(); - - foreach (var (i, op_input) in enumerate(op.inputs._inputs)) - { - if (uses_keras_history(op_input)) - layer_inputs.Add(op_input); - else - { - - } - } - } - } - - return created_layers.ToArray(); - } - - static bool uses_keras_history(Tensor op_input) - { - return Layer.KerasHistories.Any(x => x.tensor == op_input); - } - } -} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs index adbf3073..ab83dc88 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs @@ -4,6 +4,7 @@ using System.Linq; using System.Security.Cryptography.X509Certificates; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Utils; namespace Tensorflow.Keras.Engine { @@ -50,7 +51,7 @@ namespace Tensorflow.Keras.Engine _autocast = false; if (outputs.Any(x => x.KerasHistory == null)) - BaseLayerUtils.CreateKerasHistoryHelper(outputs); + base_layer_utils.create_keras_history(outputs); // Build self._output_layers: foreach (var x in outputs) diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs index 2d627768..71b3aeda 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs @@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Engine /// public class KerasHistory { - Layer layer; + public Layer layer; int node_index; int tensor_index; public Tensor tensor; @@ -20,6 +20,7 @@ namespace Tensorflow.Keras.Engine this.node_index = node_index; this.tensor_index = tensor_index; this.tensor = tensor; + Layer.KerasHistories.Add(this); Console.WriteLine(tensor.name); } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.AddWeights.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.AddWeights.cs new file mode 100644 index 00000000..894503db --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.AddWeights.cs @@ -0,0 +1,65 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + protected virtual IVariableV1 add_weight(string name, + TensorShape shape, + TF_DataType dtype = TF_DataType.TF_FLOAT, + IInitializer initializer = null, + IRegularizer regularizer = null, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None, + bool trainable = true, + Func getter = null) + { + // Initialize variable when no initializer provided + if (initializer == null) + { + // If dtype is DT_FLOAT, provide a uniform unit scaling initializer + if (dtype.is_floating()) + initializer = tf.glorot_uniform_initializer; + else if (dtype.is_integer()) + initializer = tf.zeros_initializer; + else + throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); + } + + if (synchronization == VariableSynchronization.OnRead) + trainable = false; + + var args = new VariableArgs + { + Name = name, + Shape = shape, + DType = dtype, + Getter = getter ?? base_layer_utils.make_variable, + Overwrite = true, + Initializer = initializer, + Synchronization = synchronization, + Aggregation = aggregation, + Trainable = trainable + }; + var variable = _add_variable_with_custom_getter(args); + + if (regularizer != null) + { + var name_in_scope = variable.Name.Split(':')[0]; + _handle_weight_regularization(name_in_scope, variable, regularizer); + } + + //backend.track_variable(variable); + if (trainable == true) + trainableWeights.Add(variable); + else + nonTrainableWeights.Add(variable); + + return variable; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs new file mode 100644 index 00000000..b115e201 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs @@ -0,0 +1,62 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + /// + /// Wraps `call`, applying pre- and post-processing steps. + /// + /// + /// + /// + /// + public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) + { + callContext = callContext ?? new ThreadLocal() + { + Value = new CallContext() + }; + + if (_in_functional_construction_mode(inputs)) + return FunctionalConstructionCall(inputs); + + Tensors outputs = null; + + var eager = tf.executing_eagerly(); + using var ctxManager = CallContext.enter(); + + string nameScope = ""; + if (eager) + nameScope = Name; + else + nameScope = _name_scope(); + + if (!inputs.IsEagerTensor) + tf.Context.graph_mode(); + + tf_with(ops.name_scope(nameScope), scope => + { + if (!built) + MaybeBuild(inputs); + + outputs = call(inputs, state: state, is_training: is_training); + + outputs = _set_connectivity_metadata_(inputs, outputs); + _handle_activity_regularization(inputs, outputs); + _set_mask_metadata(inputs, outputs, null); + }); + + if (!inputs.IsEagerTensor) + tf.Context.restore_mode(); + + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs new file mode 100644 index 00000000..af24b6d4 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs @@ -0,0 +1,58 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + Tensors FunctionalConstructionCall(Tensors inputs) + { + bool mask_arg_passed_by_framework = false; + bool training_arg_passed_by_framework = false; + Tensor training_value = null; + if (training_value == null) + { + training_arg_passed_by_framework = true; + } + + if (base_layer_utils.needs_keras_history(inputs)) + base_layer_utils.create_keras_history(inputs); + + Tensors outputs = null; + using var ctxManager = CallContext.enter(); + + // using var graph = tf.keras.backend.get_graph().as_default(); + + if (!inputs.IsEagerTensor) + tf.Context.graph_mode(); + + tf_with(ops.name_scope(_name_scope()), scope => + { + MaybeBuild(inputs); + + // Wrapping `call` function in autograph to allow for dynamic control + // flow and control dependencies in call. We are limiting this to + // subclassed layers as autograph is strictly needed only for + // subclassed layers and models. + // tf_convert will respect the value of autograph setting in the + // enclosing tf.function, if any. + if (!dynamic) + throw new NotImplementedException(""); + + outputs = call(inputs); + + outputs = _set_connectivity_metadata_(inputs, outputs); + _handle_activity_regularization(inputs, outputs); + _set_mask_metadata(inputs, outputs, null); + }); + + if (!inputs.IsEagerTensor) + tf.Context.restore_mode(); + + return outputs; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index c3dfb665..70d806f3 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -109,116 +109,24 @@ namespace Tensorflow.Keras.Engine updates = new List(); inboundNodes = new List(); + outboundNodes = new List(); // Manage input shape information if passed. - if(args.BatchInputShape == null && args.InputShape != null) + if (args.BatchInputShape == null && args.InputShape != null) { args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray(); } } - /// - /// Wraps `call`, applying pre- and post-processing steps. - /// - /// - /// - /// - /// - public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) - { - callContext = callContext ?? new ThreadLocal() - { - Value = new CallContext() - }; - - var history = inputs.Where(x => x.KerasHistory != null - && !KerasHistories.Contains(x.KerasHistory)) - .Select(x => x.KerasHistory); - KerasHistories.AddRange(history); - - if (_in_functional_construction_mode(inputs)) - return _functional_construction_call(inputs); - - Tensors outputs = null; - - var eager = tf.executing_eagerly(); - using var ctxManager = CallContext.enter(); - - string nameScope = ""; - if (eager) - nameScope = Name; - else - nameScope = _name_scope(); - - if (!inputs.IsEagerTensor) - tf.Context.graph_mode(); - - tf_with(ops.name_scope(nameScope), scope => - { - if (!built) - MaybeBuild(inputs); - - outputs = call(inputs, state: state, is_training: is_training); - - outputs = _set_connectivity_metadata_(inputs, outputs); - _handle_activity_regularization(inputs, outputs); - _set_mask_metadata(inputs, outputs, null); - }); - - if (!inputs.IsEagerTensor) - tf.Context.restore_mode(); - - return outputs; - } - bool _in_functional_construction_mode(Tensors inputs) { return tf.Context.executing_eagerly() && inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); } - Tensors _functional_construction_call(Tensors inputs) + public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) { - bool mask_arg_passed_by_framework = false; - bool training_arg_passed_by_framework = false; - Tensor training_value = null; - if(training_value == null) - { - training_arg_passed_by_framework = true; - } - - Tensors outputs = null; - using var ctxManager = CallContext.enter(); - // using var graph = tf.keras.backend.get_graph().as_default(); - - if (!inputs.IsEagerTensor) - tf.Context.graph_mode(); - - tf_with(ops.name_scope(_name_scope()), scope => - { - MaybeBuild(inputs); - - // Wrapping `call` function in autograph to allow for dynamic control - // flow and control dependencies in call. We are limiting this to - // subclassed layers as autograph is strictly needed only for - // subclassed layers and models. - // tf_convert will respect the value of autograph setting in the - // enclosing tf.function, if any. - if (!dynamic) - throw new NotImplementedException(""); - - outputs = call(inputs); - - outputs = _set_connectivity_metadata_(inputs, outputs); - _handle_activity_regularization(inputs, outputs); - _set_mask_metadata(inputs, outputs, null); - }); - - if (!inputs.IsEagerTensor) - tf.Context.restore_mode(); - - return outputs; } private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) @@ -235,6 +143,7 @@ namespace Tensorflow.Keras.Engine new Node(this, new NodeArgs { + InputTensors = inputs, Outputs = outputs }); @@ -304,60 +213,6 @@ namespace Tensorflow.Keras.Engine } - protected virtual IVariableV1 add_weight(string name, - TensorShape shape, - TF_DataType dtype = TF_DataType.TF_FLOAT, - IInitializer initializer = null, - IRegularizer regularizer = null, - VariableSynchronization synchronization = VariableSynchronization.Auto, - VariableAggregation aggregation = VariableAggregation.None, - bool trainable = true, - Func getter = null) - { - // Initialize variable when no initializer provided - if (initializer == null) - { - // If dtype is DT_FLOAT, provide a uniform unit scaling initializer - if (dtype.is_floating()) - initializer = tf.glorot_uniform_initializer; - else if (dtype.is_integer()) - initializer = tf.zeros_initializer; - else - throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); - } - - if (synchronization == VariableSynchronization.OnRead) - trainable = false; - - var args = new VariableArgs - { - Name = name, - Shape = shape, - DType = dtype, - Getter = getter ?? base_layer_utils.make_variable, - Overwrite = true, - Initializer = initializer, - Synchronization = synchronization, - Aggregation = aggregation, - Trainable = trainable - }; - var variable = _add_variable_with_custom_getter(args); - - if(regularizer != null) - { - var name_in_scope = variable.Name.Split(':')[0]; - _handle_weight_regularization(name_in_scope, variable, regularizer); - } - - //backend.track_variable(variable); - if (trainable == true) - trainableWeights.Add(variable); - else - nonTrainableWeights.Add(variable); - - return variable; - } - /// /// Create lambdas which compute regularization losses. /// diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs index 923db038..0ae84ac8 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs @@ -39,20 +39,22 @@ namespace Tensorflow.Keras.Engine public Tensors Outputs => args.Outputs; public TensorShape[] input_shapes; public TensorShape[] output_shapes; - List kerasInputs; + List kerasInputs = new List(); public Node(Layer layer, NodeArgs args) { this.args = args; - kerasInputs = new List(); + if (args.InputTensors != null) + kerasInputs.AddRange(args.InputTensors); // Wire up Node to Layers. layer.InboundNodes.Add(this); - foreach (var input in kerasInputs) + foreach (var kt in kerasInputs) { - if (input != null) - input.OutboundNodes.Add(this); + var inbound_layer = kt.KerasHistory.layer; + if (inbound_layer != null) + inbound_layer.OutboundNodes.Add(this); } // Set metadata on outputs. diff --git a/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs new file mode 100644 index 00000000..b39421ff --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine +{ + public class TensorFlowOpLayer : Layer + { + TensorFlowOpLayerArgs args; + string _TF_OP_LAYER_NAME_PREFIX = ""; + + public TensorFlowOpLayer(TensorFlowOpLayerArgs args) + : base(new LayerArgs + { + Name = "tf_op_layer_" + args.Name, + Trainable = args.Trainable, + DType = args.DType, + Autocast = false + }) + { + this.args = args; + built = true; + } + + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) + { + return base.call(inputs, state, is_training); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index c49618cf..b3ed60e6 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -17,6 +17,8 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; using static Tensorflow.Binding; namespace Tensorflow.Keras.Utils @@ -105,5 +107,61 @@ namespace Tensorflow.Keras.Utils return name_uid_map; } + + public static bool needs_keras_history(Tensors inputs) + { + if (inputs.Any(x => x.KerasHistory == null)) + return true; + + return false; + } + + public static Layer[] create_keras_history(Tensors inputs) + { + var processed_ops = new List(); + var created_layers = new List(); + CreateKerasHistoryHelper(inputs, processed_ops, created_layers); + return created_layers.ToArray(); + } + + public static void CreateKerasHistoryHelper(Tensors tensors, List processed_ops, List created_layers) + { + foreach (var tensor in tensors) + { + if (tensor.KerasHistory != null) + continue; + + var op = tensor.op; + if (!processed_ops.Contains(op)) + { + var layer_inputs = new List(); + + foreach (var (i, op_input) in enumerate(op.inputs._inputs)) + { + if (uses_keras_history(op_input)) + layer_inputs.Add(op_input); + else + { + + } + + // recursively + CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); + var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs + { + NodeDef = op.node_def, + Name = op.name + }); + created_layers.Add(op_layer); + op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); + } + } + } + } + + static bool uses_keras_history(Tensor op_input) + { + return Layer.KerasHistories.Any(x => x.tensor.name == op_input.name); + } } }