@@ -91,6 +91,9 @@ namespace Tensorflow.Eager | |||||
Tensor[] op_outputs) | Tensor[] op_outputs) | ||||
=> (output_grads, unneeded_gradients) => | => (output_grads, unneeded_gradients) => | ||||
{ | { | ||||
if (ops.gradientFunctions[op_name] == null) | |||||
return new Tensor[op_inputs.Length]; | |||||
var gradients = ops.gradientFunctions[op_name](new EagerOperation | var gradients = ops.gradientFunctions[op_name](new EagerOperation | ||||
{ | { | ||||
Name = op_name, | Name = op_name, | ||||
@@ -15,5 +15,7 @@ namespace Tensorflow.Gradients | |||||
public TapeTensor[] output_tensor_info { get; set; } | public TapeTensor[] output_tensor_info { get; set; } | ||||
public long[] input_tensor_id { get; set; } | public long[] input_tensor_id { get; set; } | ||||
public BackwardFunction backward_function { get; set; } | public BackwardFunction backward_function { get; set; } | ||||
// Human-readable form of a tape entry, used when dumping the op tape
// while debugging gradient computation.
public override string ToString()
{
    var inputIds = string.Join(",", input_tensor_id);
    return $"{op_type}, inputs: {inputIds}";
}
} | } | ||||
} | } |
@@ -29,12 +29,13 @@ namespace Tensorflow.Gradients | |||||
tensor_tape_, | tensor_tape_, | ||||
state.op_tape); | state.op_tape); | ||||
while (op_stack.Count > 0) | |||||
while (!op_stack.empty()) | |||||
{ | { | ||||
var op = op_stack.Dequeue(); | var op = op_stack.Dequeue(); | ||||
if (!state.op_tape.find(op, out var trace)) | if (!state.op_tape.find(op, out var trace)) | ||||
continue; | continue; | ||||
Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||||
state.op_tape.erase(op); | state.op_tape.erase(op); | ||||
var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | ||||
@@ -103,7 +104,7 @@ namespace Tensorflow.Gradients | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
in_gradients = new Tensor[trace.input_tensor_id.Length]; | |||||
} | } | ||||
for (int i = 0; i < in_gradients.Length; ++i) | for (int i = 0; i < in_gradients.Length; ++i) | ||||
@@ -113,17 +114,18 @@ namespace Tensorflow.Gradients | |||||
{ | { | ||||
var unaggregated_grads = gradients[id]; | var unaggregated_grads = gradients[id]; | ||||
unaggregated_grads.Add(in_gradients[i]); | unaggregated_grads.Add(in_gradients[i]); | ||||
if(unaggregated_grads.Count > kMinAggregateCount) | |||||
if (unaggregated_grads.Count > kMinAggregateCount) | |||||
{ | { | ||||
if(!gradients_size.ContainsKey(id)) | |||||
if (!gradients_size.find(id, out var size)) | |||||
{ | { | ||||
size = (long)unaggregated_grads[0].size; | |||||
gradients_size.emplace(id, size); | |||||
} | } | ||||
else | |||||
{ | |||||
if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | } | ||||
throw new NotImplementedException(""); | |||||
} | } | ||||
} | } | ||||
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
/// <summary>
/// Configuration bundle for the RMSprop optimizer. Every property carries
/// the Keras default and can be overridden through an object initializer.
/// </summary>
public class RMSpropArgs
{
    /// <summary>Step size applied to each gradient update.</summary>
    public float LearningRate { get; set; } = 0.001f;

    /// <summary>Discounting factor for the moving average of squared gradients.</summary>
    public float RHO { get; set; } = 0.9f;

    /// <summary>Momentum factor; 0 disables momentum.</summary>
    public float Momentum { get; set; } = 0.0f;

    /// <summary>Small constant added for numerical stability.</summary>
    public float Epsilon { get; set; } = 1e-7f;

    /// <summary>When true, gradients are normalized by their estimated variance.</summary>
    public bool Centered { get; set; } = false;

    /// <summary>Display name for the optimizer instance.</summary>
    public string Name { get; set; } = "RMSprop";
}
} |
@@ -23,8 +23,6 @@ namespace Tensorflow.Keras.Engine | |||||
List<KerasHistory> _input_coordinates; | List<KerasHistory> _input_coordinates; | ||||
List<KerasHistory> _output_coordinates; | List<KerasHistory> _output_coordinates; | ||||
public string[] NetworkNodes { get; set; } | public string[] NetworkNodes { get; set; } | ||||
public Dictionary<int, List<Node>> NodesByDepth { get; set; } | |||||
public List<Layer> Layers => _layers; | |||||
Dictionary<int, int> tensor_usage_count; | Dictionary<int, int> tensor_usage_count; | ||||
public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | ||||
@@ -43,9 +41,10 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
} | } | ||||
public Functional(Tensors inputs, Tensors outputs) | |||||
public Functional(Tensors inputs, Tensors outputs, string name = null) | |||||
: base(new ModelArgs | : base(new ModelArgs | ||||
{ | { | ||||
Name = name, | |||||
Inputs = inputs, | Inputs = inputs, | ||||
Outputs = outputs | Outputs = outputs | ||||
}) | }) | ||||
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <summary> | /// <summary> | ||||
/// `Model` groups layers into an object with training and inference features. | /// `Model` groups layers into an object with training and inference features. | ||||
/// </summary> | /// </summary> | ||||
public class Model : Layer | |||||
public partial class Model : Layer | |||||
{ | { | ||||
#pragma warning disable CS0169 // The field 'Model._cloning' is never used | #pragma warning disable CS0169 // The field 'Model._cloning' is never used | ||||
bool _cloning; | bool _cloning; | ||||
@@ -33,12 +33,20 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
// Overload mirroring Keras' compile(loss, optimizer, metrics) signature.
// NOTE(review): currently a no-op stub — the loss, optimizer and metrics
// are accepted but not stored or wired into training. Confirm whether
// callers rely on this doing nothing before implementing it.
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
{
}
public void compile(string optimizerName, string lossName) | public void compile(string optimizerName, string lossName) | ||||
{ | { | ||||
switch (optimizerName) | switch (optimizerName) | ||||
{ | { | ||||
case "rmsprop": | case "rmsprop": | ||||
optimizer = new RMSprop(); | |||||
optimizer = new RMSprop(new RMSpropArgs | |||||
{ | |||||
}); | |||||
break; | break; | ||||
} | } | ||||
@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// Each time the output of a layer is used by another layer, | /// Each time the output of a layer is used by another layer, | ||||
/// a node is added to `layer._outbound_nodes`. | /// a node is added to `layer._outbound_nodes`. | ||||
/// </summary> | /// </summary> | ||||
public class Node | |||||
public partial class Node | |||||
{ | { | ||||
NodeArgs args; | NodeArgs args; | ||||
@@ -39,8 +39,8 @@ namespace Tensorflow | |||||
/// <param name="input"></param> | /// <param name="input"></param> | ||||
/// <param name="output"></param> | /// <param name="output"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Functional Model(Tensors inputs, Tensors outputs) | |||||
=> new Functional(inputs, outputs); | |||||
public Functional Model(Tensors inputs, Tensors outputs, string name = null) | |||||
=> new Functional(inputs, outputs, name: name); | |||||
/// <summary> | /// <summary> | ||||
/// Instantiate a Keras tensor. | /// Instantiate a Keras tensor. | ||||
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
namespace Tensorflow.Keras.Optimizers | namespace Tensorflow.Keras.Optimizers | ||||
{ | { | ||||
@@ -29,5 +30,31 @@ namespace Tensorflow.Keras.Optimizers | |||||
epsilon: epsilon, | epsilon: epsilon, | ||||
amsgrad: amsgrad, | amsgrad: amsgrad, | ||||
name: name); | name: name); | ||||
/// <summary>
/// Construct a new RMSprop optimizer.
/// </summary>
/// <param name="learning_rate">Step size, defaults to 0.001.</param>
/// <param name="rho">Discounting factor for the squared-gradient average.</param>
/// <param name="momentum">Momentum factor; 0 disables momentum.</param>
/// <param name="epsilon">Small constant for numerical stability.</param>
/// <param name="centered">Whether to normalize gradients by their estimated variance.</param>
/// <param name="name">Display name for the optimizer.</param>
/// <returns>The configured optimizer instance.</returns>
public OptimizerV2 RMSprop(float learning_rate = 0.001f,
    float rho = 0.9f,
    float momentum = 0.0f,
    float epsilon = 1e-7f,
    bool centered = false,
    string name = "RMSprop")
{
    // Package the individual parameters into the args object the
    // optimizer constructor expects.
    var args = new RMSpropArgs
    {
        LearningRate = learning_rate,
        RHO = rho,
        Momentum = momentum,
        Epsilon = epsilon,
        Centered = centered,
        Name = name
    };
    return new RMSprop(args);
}
} | } | ||||
} | } |
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
namespace Tensorflow.Keras.Optimizers | namespace Tensorflow.Keras.Optimizers | ||||
{ | { | ||||
@@ -9,6 +10,11 @@ namespace Tensorflow.Keras.Optimizers | |||||
/// </summary> | /// </summary> | ||||
public class RMSprop : OptimizerV2
{
    // Hyper-parameter bundle captured at construction time.
    RMSpropArgs args;

    /// <summary>
    /// Creates the optimizer from a pre-populated argument object.
    /// </summary>
    /// <param name="args">Hyper-parameters for RMSprop.</param>
    public RMSprop(RMSpropArgs args) => this.args = args;
}
} | } |
@@ -0,0 +1,193 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Keras.Engine; | |||||
using NumSharp; | |||||
using System.Security.Cryptography; | |||||
namespace Tensorflow.Keras.Utils | |||||
{ | |||||
/// <summary>
/// Helpers for rendering a Keras-style text summary of a model.
/// </summary>
internal class layer_utils
{
    /// <summary>
    /// Prints a table describing every layer of the model: name/type,
    /// output shape, parameter count and — for non-sequential graphs —
    /// the inbound connections, followed by parameter totals.
    /// </summary>
    /// <param name="model">Model to summarize.</param>
    /// <param name="line_length">Total printed row width; -1 selects a default.</param>
    /// <param name="positions">Column boundaries, relative (&lt;= 1) or absolute.</param>
    public static void print_summary(Model model, int line_length = -1, float[] positions = null)
    {
        bool sequential_like = model is Sequential;
        // || model.IsGraphNetwork;
        if (!sequential_like)
            sequential_like = _is_sequential_like(model);

        string[] headers;
        var relevant_nodes = new List<Node>();
        if (sequential_like)
        {
            if (line_length < 0)
                line_length = 65;
            if (positions == null)
                positions = new[] { 0.45f, 0.85f, 1.0f };
            // Relative boundaries are scaled to absolute character columns.
            if (positions[^1] <= 1)
                positions = positions.Select(p => line_length * p).ToArray();
            headers = new[] { "Layer (type)", "Output Shape", "Param #" };
        }
        else
        {
            if (line_length < 0)
                line_length = 98;
            if (positions == null)
                positions = new[] { 0.33f, 0.55f, 0.67f, 1.0f };
            if (positions[^1] <= 1)
                positions = positions.Select(p => line_length * p).ToArray();
            headers = new[] { "Layer (type)", "Output Shape", "Param #", "Connected to" };
            // Only connections to these nodes are shown in the table.
            foreach (var depth in model.NodesByDepth)
                relevant_nodes.AddRange(depth.Value);
        }

        var columns = positions.Select(x => Convert.ToInt32(x)).ToArray();
        var thinRule = new string('_', line_length);
        var thickRule = new string('=', line_length);

        print($"Model: {model.Name}");
        print(thinRule);
        print_row(headers, columns);
        print(thickRule);

        foreach (var (i, layer) in enumerate(model.Layers))
        {
            if (sequential_like)
                print_layer_summary(layer, columns);
            else
                print_layer_summary_with_connections(layer, columns, relevant_nodes);
            print(i == model.Layers.Count - 1 ? thickRule : thinRule);
        }

        var trainable_count = count_params(model, model.trainable_variables);
        var non_trainable_count = count_params(model, model.non_trainable_variables);
        print($"Total params: {trainable_count + non_trainable_count}");
        print($"Trainable params: {trainable_count}");
        print($"Non-trainable params: {non_trainable_count}");
        print(thinRule);
    }

    /// <summary>
    /// True when the model's node graph forms a simple chain, so the
    /// narrow three-column summary can be used.
    /// </summary>
    static bool _is_sequential_like(Model model)
    {
        var seen = new List<Node>();
        foreach (var depth in model.NodesByDepth)
        {
            var nodesAtDepth = depth.Value;
            // Multiple nodes at one depth, or a node with several inputs,
            // means the graph branches and is not sequential.
            if (nodesAtDepth.Count > 1
                || (nodesAtDepth.Count == 1 && nodesAtDepth[0].KerasInputs.Count > 1))
                return false;
            seen.AddRange(nodesAtDepth);
        }

        // A layer that appears in more than one recorded node is shared,
        // which also rules out the sequential layout.
        foreach (var layer in model.Layers)
        {
            var hits = 0;
            foreach (var node in layer.InboundNodes)
            {
                if (seen.Contains(node) && ++hits > 1)
                    return false;
            }
        }
        return true;
    }

    /// <summary>
    /// Prints one table row: each field is clipped and padded to its column.
    /// </summary>
    static void print_row(string[] fields, int[] positions)
    {
        var line = "";
        for (int i = 0; i < fields.Length; i++)
        {
            // Keep one space between columns by replacing the last pad char.
            if (i > 0)
                line = line[0..^1] + " ";
            line += fields[i];
            if (line.Length > positions[i])
                line = line.Substring(0, positions[i]);
            line = line.PadRight(positions[i]);
        }
        print(line);
    }

    /// <summary>
    /// Prints a summary for a single layer.
    /// </summary>
    /// <param name="layer"></param>
    static void print_layer_summary(Layer layer, int[] positions)
    {
        var fields = new[]
        {
            $"{layer.Name} ({layer.GetType().Name})",
            $"{layer.output_shape}",
            $"{layer.count_params()}"
        };
        print_row(fields, positions);
    }

    /// <summary>
    /// Prints a summary row for one layer including its inbound
    /// connections; extra connections go on follow-up rows.
    /// </summary>
    static void print_layer_summary_with_connections(Layer layer, int[] positions, List<Node> relevant_nodes)
    {
        var connections = new List<string>();
        foreach (var node in layer.InboundNodes)
        {
            // Skip nodes that belong to graphs outside this summary.
            if (!relevant_nodes.Contains(node))
                continue;
            foreach (var (inbound_layer, node_index, tensor_index, _) in node.iterate_inbound())
                connections.Add($"{inbound_layer.Name}[{node_index}][{tensor_index}]");
        }

        var first_connection = connections.Count > 0 ? connections[0] : "";
        print_row(new[]
        {
            $"{layer.Name}({layer.GetType().Name})",
            $"{layer.output_shape}",
            $"{layer.count_params()}",
            first_connection
        }, positions);

        // Remaining connections each get their own otherwise-empty row.
        for (int i = 1; i < connections.Count; i++)
            print_row(new[] { "", "", "", connections[i] }, positions);
    }

    /// <summary>
    /// Total number of scalar elements across the given weight variables.
    /// </summary>
    public static int count_params(Layer layer, List<IVariableV1> weights)
    {
        var total = 0;
        foreach (var shape in weights.Select(w => w.shape))
            total += (int)np.prod(shape.dims);
        return total;
    }
}
} |