diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index 83d94ee2..a6e16cc8 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -91,6 +91,9 @@ namespace Tensorflow.Eager Tensor[] op_outputs) => (output_grads, unneeded_gradients) => { + if (ops.gradientFunctions[op_name] == null) + return new Tensor[op_inputs.Length]; + var gradients = ops.gradientFunctions[op_name](new EagerOperation { Name = op_name, diff --git a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs index 7f478e48..ff04866d 100644 --- a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs +++ b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs @@ -15,5 +15,7 @@ namespace Tensorflow.Gradients public TapeTensor[] output_tensor_info { get; set; } public long[] input_tensor_id { get; set; } public BackwardFunction backward_function { get; set; } + public override string ToString() + => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}"; } } diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs index 770b75ca..5ef4214c 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs @@ -29,12 +29,13 @@ namespace Tensorflow.Gradients tensor_tape_, state.op_tape); - while (op_stack.Count > 0) + while (!op_stack.empty()) { var op = op_stack.Dequeue(); if (!state.op_tape.find(op, out var trace)) continue; + Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); state.op_tape.erase(op); var out_gradients = new List(trace.output_tensor_info.Length); @@ -103,7 +104,7 @@ namespace Tensorflow.Gradients } else { - throw new NotImplementedException(""); + in_gradients = new Tensor[trace.input_tensor_id.Length]; } for (int i = 0; i < in_gradients.Length; ++i) @@ -113,17 +114,18 @@ namespace Tensorflow.Gradients { var unaggregated_grads = gradients[id]; unaggregated_grads.Add(in_gradients[i]); - if(unaggregated_grads.Count > kMinAggregateCount) + if (unaggregated_grads.Count > kMinAggregateCount) { - if(!gradients_size.ContainsKey(id)) + if (!gradients_size.find(id, out var size)) { + size = (long)unaggregated_grads[0].size; + gradients_size.emplace(id, size); } - else - { + if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) + { + throw new NotImplementedException(""); } - - throw new NotImplementedException(""); } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs new file mode 100644 index 00000000..42a5bcb1 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class RMSpropArgs + { + public float LearningRate { get; set; } = 0.001f; + public float RHO { get; set; } = 0.9f; + public float Momentum { get; set; } = 0.0f; + public float Epsilon { get; set; } = 1e-7f; + public bool Centered { get; set; } = false; + public string Name { get; set; } = "RMSprop"; + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs index e4447840..c88fea71 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs @@ -23,8 +23,6 @@ namespace Tensorflow.Keras.Engine List _input_coordinates; List _output_coordinates; public string[] NetworkNodes { get; set; } - public Dictionary> NodesByDepth { get; set; } - public List Layers => _layers; Dictionary tensor_usage_count; public Dictionary TensorUsageCount => tensor_usage_count; @@ -43,9 +41,10 @@ namespace Tensorflow.Keras.Engine } } - public Functional(Tensors inputs, Tensors outputs) + public Functional(Tensors inputs, Tensors outputs, string name = null) : base(new ModelArgs { + Name = name, Inputs = inputs, Outputs = outputs }) diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index c816e85d..2b78aa2f 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine /// /// `Model` groups layers into an object with training and inference features. /// - public class Model : Layer + public partial class Model : Layer { #pragma warning disable CS0169 // The field 'Model._cloning' is never used bool _cloning; @@ -33,12 +33,20 @@ namespace Tensorflow.Keras.Engine } + public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) + { + + } + public void compile(string optimizerName, string lossName) { switch (optimizerName) { case "rmsprop": - optimizer = new RMSprop(); + optimizer = new RMSprop(new RMSpropArgs + { + + }); break; } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs index 74f138d3..b99d1790 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs @@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine /// Each time the output of a layer is used by another layer, /// a node is added to `layer._outbound_nodes`. /// - public class Node + public partial class Node { NodeArgs args; diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index 43b46ebf..9d2731c8 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -39,8 +39,8 @@ namespace Tensorflow /// /// /// - public Functional Model(Tensors inputs, Tensors outputs) - => new Functional(inputs, outputs); + public Functional Model(Tensors inputs, Tensors outputs, string name = null) + => new Functional(inputs, outputs, name: name); /// /// Instantiate a Keras tensor. diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs index e521a827..7ce23204 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.ArgsDefinition; namespace Tensorflow.Keras.Optimizers { @@ -29,5 +30,31 @@ namespace Tensorflow.Keras.Optimizers epsilon: epsilon, amsgrad: amsgrad, name: name); + + /// + /// Construct a new RMSprop optimizer. + /// + /// + /// + /// + /// + /// + /// + /// + public OptimizerV2 RMSprop(float learning_rate = 0.001f, + float rho = 0.9f, + float momentum = 0.0f, + float epsilon = 1e-7f, + bool centered = false, + string name = "RMSprop") + => new RMSprop(new RMSpropArgs + { + LearningRate = learning_rate, + RHO = rho, + Momentum = momentum, + Epsilon = epsilon, + Centered = centered, + Name = name + }); } } diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs b/src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs index 51b65b57..8a08282f 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.ArgsDefinition; namespace Tensorflow.Keras.Optimizers { @@ -9,6 +10,11 @@ namespace Tensorflow.Keras.Optimizers /// public class RMSprop : OptimizerV2 { + RMSpropArgs args; + public RMSprop(RMSpropArgs args) + { + this.args = args; + } } } diff --git a/src/TensorFlowNET.Core/Keras/Utils/layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/layer_utils.cs new file mode 100644 index 00000000..cbba92bf --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Utils/layer_utils.cs @@ -0,0 +1,193 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; +using Tensorflow.Keras.Engine; +using NumSharp; +using System.Security.Cryptography; + +namespace Tensorflow.Keras.Utils +{ + internal class layer_utils + { + public static void print_summary(Model model, int line_length = -1, float[] positions = null) + { + bool sequential_like = model is Sequential; + // || model.IsGraphNetwork; + + if (!sequential_like) + { + sequential_like = true; + var nodes = new List(); + + foreach (var v in model.NodesByDepth) + { + // if the model has multiple nodes + // or if the nodes have multiple inbound_layers + // the model is no longer sequential + if (v.Value.Count > 1 || (v.Value.Count == 1 && v.Value[0].KerasInputs.Count > 1)) + { + sequential_like = false; + break; + } + + nodes.AddRange(v.Value); + } + + if (sequential_like) + { + // search for shared layers + foreach(var layer in model.Layers) + { + var flag = false; + foreach(var node in layer.InboundNodes) + { + if(nodes.Contains(node)) + { + if (flag) + { + sequential_like = false; + break; + } + else + flag = true; + } + } + if (!sequential_like) + break; + } + } + } + + string[] to_display; + var relevant_nodes = new List(); + + if (sequential_like) + { + if (line_length < 0) + line_length = 65; + if (positions == null) + positions = new[] { 0.45f, 0.85f, 1.0f }; + if (positions[^1] <= 1) + positions = positions.Select(p => line_length * p).ToArray(); + to_display = new[] { "Layer (type)", "Output Shape", "Param #" }; + } + else + { + if (line_length < 0) + line_length = 98; + if (positions == null) + positions = new[] { 0.33f, 0.55f, 0.67f, 1.0f }; + if (positions[^1] <= 1) + positions = positions.Select(p => line_length * p).ToArray(); + to_display = new[] { "Layer (type)", "Output Shape", "Param #", "Connected to" }; + + foreach (var v in model.NodesByDepth) + relevant_nodes.AddRange(v.Value); + } + + int[] positions_int = positions.Select(x => Convert.ToInt32(x)).ToArray(); + print($"Model: {model.Name}"); + print(string.Join("", range(line_length).Select(x => "_"))); + print_row(to_display, positions_int); + print(string.Join("", range(line_length).Select(x => "="))); + + foreach(var (i, layer) in enumerate(model.Layers)) + { + if (sequential_like) + print_layer_summary(layer, positions_int); + else + print_layer_summary_with_connections(layer, positions_int, relevant_nodes); + if(i == model.Layers.Count - 1) + print(string.Join("", range(line_length).Select(x => "="))); + else + print(string.Join("", range(line_length).Select(x => "_"))); + } + + var trainable_count = count_params(model, model.trainable_variables); + var non_trainable_count = count_params(model, model.non_trainable_variables); + + print($"Total params: {trainable_count + non_trainable_count}"); + print($"Trainable params: {trainable_count}"); + print($"Non-trainable params: {non_trainable_count}"); + print(string.Join("", range(line_length).Select(x => "_"))); + } + + static void print_row(string[] fields, int[] positions) + { + var line = ""; + foreach(var i in range(fields.Length)) + { + if (i > 0) + line = line[0..^1] + " "; + line += fields[i]; + line = string.Join("", line.Take(positions[i])); + line += string.Join("", range(positions[i] - len(line)).Select(x => " ")); + } + print(line); + } + + /// + /// Prints a summary for a single layer. + /// + /// + static void print_layer_summary(Layer layer, int[] positions) + { + var name = layer.Name; + + var fields = new string[] + { + $"{name} ({layer.GetType().Name})", + $"{layer.output_shape}", + $"{layer.count_params()}" + }; + + print_row(fields, positions); + } + + static void print_layer_summary_with_connections(Layer layer, int[] positions, List relevant_nodes) + { + var connections = new List(); + foreach (var node in layer.InboundNodes) + { + if (!relevant_nodes.Contains(node)) + continue; + + foreach (var (inbound_layer, node_index, tensor_index, _) in node.iterate_inbound()) + connections.append($"{inbound_layer.Name}[{node_index}][{tensor_index}]"); + } + + var name = layer.Name; + string first_connection = ""; + if (connections.Count > 0) + first_connection = connections[0]; + + var fields = new string[] + { + $"{name}({layer.GetType().Name})", + $"{layer.output_shape}", + $"{layer.count_params()}", + first_connection + }; + + print_row(fields, positions); + + if(connections.Count > 1) + { + foreach(var i in range(1, connections.Count)) + { + fields = new string[] { "", "", "", connections[i] }; + print_row(fields, positions); + } + } + } + + public static int count_params(Layer layer, List weights) + { + var weight_shapes = weights.Select(x => x.shape).ToArray(); + var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum(); + return total; + } + } +}