diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 9e97154c..6f6657af 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -32,8 +32,19 @@ namespace Tensorflow { public static T2 get(this Dictionary dict, T1 key) => key == null ? - default(T2) : - (dict.ContainsKey(key) ? dict[key] : default(T2)); + default : + (dict.ContainsKey(key) ? dict[key] : default); + + public static void Update(this IList list, T element) + { + var index = list.IndexOf(element); + if (index < 0) + list.Add(element); + else + { + list[index] = element; + } + } public static void add(this IList list, T element) => list.Add(element); diff --git a/src/TensorFlowNET.Core/Keras/BackendImpl.cs b/src/TensorFlowNET.Core/Keras/BackendImpl.cs index 00e1587c..b765a48e 100644 --- a/src/TensorFlowNET.Core/Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Core/Keras/BackendImpl.cs @@ -155,6 +155,16 @@ namespace Tensorflow.Keras return array_ops.pad(x, pattern); } + /// + /// Method to evaluate a tensor in eager or in a tf.function. + /// + /// + /// + public Tensor eval_in_eager_or_function(Tensor outputs) + { + throw new NotImplementedException(""); + } + public class _DummyEagerGraph { } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs index ab83dc88..2b977dbc 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Security.Cryptography.X509Certificates; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Utils; @@ -69,6 +68,11 @@ namespace Tensorflow.Keras.Engine _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); } } - + + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) + { + return base.call(inputs, state, is_training); + } + } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs index 71b3aeda..05eb6fa7 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs @@ -20,8 +20,6 @@ namespace Tensorflow.Keras.Engine this.node_index = node_index; this.tensor_index = tensor_index; this.tensor = tensor; - Layer.KerasHistories.Add(this); - Console.WriteLine(tensor.name); } public void Deconstruct(out Layer layer, out int node_index, out int tensor_index) diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 70d806f3..84e9e750 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -89,7 +89,6 @@ namespace Tensorflow.Keras.Engine ThreadLocal callContext; public CallContext CallContext => callContext.Value; - public static List KerasHistories = new List(); public Layer(LayerArgs args) { @@ -125,29 +124,16 @@ namespace Tensorflow.Keras.Engine } public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) - { - - } + => _set_connectivity_metadata_(inputs, outputs); private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) { - /*var returnOutputs = new List(); - foreach(var x in outputs) - { - if (inputs.Contains(x)) - { - - } - returnOutputs.Add(x); - }*/ - new Node(this, new NodeArgs { InputTensors = inputs, Outputs = outputs }); - //_add_inbound_node(input_tensors: inputs, output_tensors: outputs); return outputs; } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs index 0ae84ac8..6c29850b 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs @@ -52,6 +52,8 @@ namespace Tensorflow.Keras.Engine layer.InboundNodes.Add(this); foreach (var kt in kerasInputs) { + if (kt.KerasHistory == null) + continue; var inbound_layer = kt.KerasHistory.layer; if (inbound_layer != null) inbound_layer.OutboundNodes.Add(this); diff --git a/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs index b39421ff..8e9725fa 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs @@ -8,12 +8,12 @@ namespace Tensorflow.Keras.Engine public class TensorFlowOpLayer : Layer { TensorFlowOpLayerArgs args; - string _TF_OP_LAYER_NAME_PREFIX = ""; + static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_"; public TensorFlowOpLayer(TensorFlowOpLayerArgs args) : base(new LayerArgs { - Name = "tf_op_layer_" + args.Name, + Name = TF_OP_LAYER_NAME_PREFIX + args.Name, Trainable = args.Trainable, DType = args.DType, Autocast = false diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index 5d08e8e8..43b46ebf 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -8,6 +8,7 @@ using Tensorflow.Keras.Datasets; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Optimizers; using static Tensorflow.Binding; namespace Tensorflow @@ -22,6 +23,7 @@ namespace Tensorflow public Activations activations { get; } = new Activations(); public Preprocessing preprocessing { get; } = new Preprocessing(); public BackendImpl backend { get; } = new BackendImpl(); + public OptimizerApi optimizers { get; } = new OptimizerApi(); public Sequential Sequential(List layers = null, string name = null) diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs new file mode 100644 index 00000000..e521a827 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Optimizers +{ + public class OptimizerApi + { + /// + /// Adam optimization is a stochastic gradient descent method that is based on + /// adaptive estimation of first-order and second-order moments. + /// + /// + /// + /// + /// + /// + /// + /// + public OptimizerV2 Adam(float learning_rate = 0.001f, + float beta_1 = 0.9f, + float beta_2 = 0.999f, + float epsilon = 1e-7f, + bool amsgrad = false, + string name = "Adam") + => new Adam(learning_rate: learning_rate, + beta_1: beta_1, + beta_2: beta_2, + epsilon: epsilon, + amsgrad: amsgrad, + name: name); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index b3ed60e6..ab246464 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -142,26 +142,39 @@ namespace Tensorflow.Keras.Utils layer_inputs.Add(op_input); else { + tf_with(ops.init_scope(), delegate + { - } - // recursively - CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); - var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs - { - NodeDef = op.node_def, - Name = op.name - }); - created_layers.Add(op_layer); - op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); + }); + } } + + // recursively + CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers); + var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs + { + NodeDef = op.node_def, + Name = op.name + }); + created_layers.Add(op_layer); + op_layer.SetConnectivityMetadata(layer_inputs, op.outputs); + processed_ops.Add(op); } } } + // recusive static bool uses_keras_history(Tensor op_input) { - return Layer.KerasHistories.Any(x => x.tensor.name == op_input.name); + if (op_input.KerasHistory != null) + return true; + + foreach (var input in op_input.op.inputs._inputs) + if (uses_keras_history(input)) + return true; + + return false; } } } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 015edfb2..29a33837 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -29,7 +29,7 @@ https://tensorflownet.readthedocs.io 0.21.0.0 LICENSE true - false + true Open.snk AnyCPU;x64