Skip gradient when no grad_func is found.

tags/v0.30
Oceania2018 committed 5 years ago
parent commit 1dd95bdb35
11 changed files with 272 additions and 16 deletions
1. src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs (+3 -0)
2. src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs (+2 -0)
3. src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs (+10 -8)
4. src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs (+16 -0)
5. src/TensorFlowNET.Core/Keras/Engine/Functional.cs (+2 -3)
6. src/TensorFlowNET.Core/Keras/Engine/Model.cs (+10 -2)
7. src/TensorFlowNET.Core/Keras/Engine/Node.cs (+1 -1)
8. src/TensorFlowNET.Core/Keras/KerasApi.cs (+2 -2)
9. src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs (+27 -0)
10. src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs (+6 -0)
11. src/TensorFlowNET.Core/Keras/Utils/layer_utils.cs (+193 -0)

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs (+3 -0)

@@ -91,6 +91,9 @@ namespace Tensorflow.Eager
             Tensor[] op_outputs)
             => (output_grads, unneeded_gradients) =>
             {
+                if (ops.gradientFunctions[op_name] == null)
+                    return new Tensor[op_inputs.Length];
+
                 var gradients = ops.gradientFunctions[op_name](new EagerOperation
                 {
                     Name = op_name,
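
This guard is the heart of the commit: when an op has no registered gradient function, the backward callback now returns one null gradient per op input instead of dereferencing a missing registry entry. A minimal self-contained sketch of the same pattern (the registry shape and all names here are hypothetical, not the TensorFlow.NET API):

    using System;
    using System.Collections.Generic;

    public static class GradientGuardSketch
    {
        // Hypothetical registry: op name -> backward function,
        // null when no gradient function was registered for the op.
        static readonly Dictionary<string, Func<float[][], float[][]>> gradientFunctions =
            new Dictionary<string, Func<float[][], float[][]>>
            {
                ["Square"] = outputGrads => outputGrads, // stand-in backward function
                ["NoGradOp"] = null                      // op without a grad_func
            };

        public static float[][] Backward(string opName, int numInputs, float[][] outputGrads)
        {
            // Mirrors the commit: "no grad_func" becomes one null entry per
            // op input rather than a NullReferenceException.
            if (gradientFunctions[opName] == null)
                return new float[numInputs][]; // elements default to null

            return gradientFunctions[opName](outputGrads);
        }
    }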


src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs (+2 -0)

@@ -15,5 +15,7 @@ namespace Tensorflow.Gradients
         public TapeTensor[] output_tensor_info { get; set; }
         public long[] input_tensor_id { get; set; }
         public BackwardFunction backward_function { get; set; }
+        public override string ToString()
+            => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
     }
 }

src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs (+10 -8)

@@ -29,12 +29,13 @@ namespace Tensorflow.Gradients
                 tensor_tape_,
                 state.op_tape);
 
-            while (op_stack.Count > 0)
+            while (!op_stack.empty())
             {
                 var op = op_stack.Dequeue();
                 if (!state.op_tape.find(op, out var trace))
                     continue;
 
+                Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}");
                 state.op_tape.erase(op);
 
                 var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
@@ -103,7 +104,7 @@
                 }
                 else
                 {
-                    throw new NotImplementedException("");
+                    in_gradients = new Tensor[trace.input_tensor_id.Length];
                 }
 
                 for (int i = 0; i < in_gradients.Length; ++i)
@@ -113,17 +114,18 @@
                 {
                     var unaggregated_grads = gradients[id];
                     unaggregated_grads.Add(in_gradients[i]);
-                    if(unaggregated_grads.Count > kMinAggregateCount)
+                    if (unaggregated_grads.Count > kMinAggregateCount)
                     {
-                        if(!gradients_size.ContainsKey(id))
+                        if (!gradients_size.find(id, out var size))
                        {
                             size = (long)unaggregated_grads[0].size;
                             gradients_size.emplace(id, size);
                         }
                         else
                         {
-                            throw new NotImplementedException("");
+                            if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
+                            {
+                                throw new NotImplementedException("");
+                            }
                         }
                     }
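
Two changes matter here: a missing backward result no longer throws, since in_gradients becomes an array of nulls (one per input, matching the RecordGradient guard), and the size-tracking branch switches to find/emplace with a byte-budget test before eager aggregation (the Console.WriteLine reads like a leftover debug trace). A sketch of that heuristic follows; the constant values are the ones used by TensorFlow's eager tape, assumed here since they are not visible in this diff, and note the aggregation branch itself still ends in NotImplementedException in this commit:

    static class AggregationHeuristic
    {
        const int kMinAggregateCount = 4;                  // assumed TF default
        const long kMinAggregateBytes = 128 * 1024 * 1024; // assumed TF default (128 MB)

        // True once enough gradients have piled up for one tensor that summing
        // them now, instead of at the very end, would cap memory use.
        public static bool ShouldAggregateEagerly(int count, long elementCount)
            => count > kMinAggregateCount
            && count * elementCount * 4 > kMinAggregateBytes; // 4 bytes per float32
    }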



src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs (+16 -0)

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class RMSpropArgs
{
public float LearningRate { get; set; } = 0.001f;
public float RHO { get; set; } = 0.9f;
public float Momentum { get; set; } = 0.0f;
public float Epsilon { get; set; } = 1e-7f;
public bool Centered { get; set; } = false;
public string Name { get; set; } = "RMSprop";
}
}
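
The defaults mirror Keras's RMSprop signature. A usage sketch (the values are illustrative; the types are the ones added in this commit):

    using Tensorflow.Keras.ArgsDefinition;
    using Tensorflow.Keras.Optimizers;

    var opt = new RMSprop(new RMSpropArgs
    {
        LearningRate = 0.01f, // override the 0.001f default
        Momentum = 0.9f       // remaining fields keep their defaults
    });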

src/TensorFlowNET.Core/Keras/Engine/Functional.cs (+2 -3)

@@ -23,8 +23,6 @@ namespace Tensorflow.Keras.Engine
-        List<KerasHistory> _input_coordinates;
-        List<KerasHistory> _output_coordinates;
         public string[] NetworkNodes { get; set; }
         public Dictionary<int, List<Node>> NodesByDepth { get; set; }
         public List<Layer> Layers => _layers;
 
         Dictionary<int, int> tensor_usage_count;
         public Dictionary<int, int> TensorUsageCount => tensor_usage_count;
@@ -43,9 +41,10 @@ namespace Tensorflow.Keras.Engine
             }
         }
 
-        public Functional(Tensors inputs, Tensors outputs)
+        public Functional(Tensors inputs, Tensors outputs, string name = null)
             : base(new ModelArgs
             {
+                Name = name,
                 Inputs = inputs,
                 Outputs = outputs
             })


src/TensorFlowNET.Core/Keras/Engine/Model.cs (+10 -2)

@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine
     /// <summary>
     /// `Model` groups layers into an object with training and inference features.
     /// </summary>
-    public class Model : Layer
+    public partial class Model : Layer
     {
 #pragma warning disable CS0169 // The field 'Model._cloning' is never used
         bool _cloning;
@@ -33,12 +33,20 @@ namespace Tensorflow.Keras.Engine
         }
 
+        public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
+        {
+
+        }
+
         public void compile(string optimizerName, string lossName)
        {
             switch (optimizerName)
             {
                 case "rmsprop":
-                    optimizer = new RMSprop();
+                    optimizer = new RMSprop(new RMSpropArgs
+                    {
+
+                    });
                     break;
             }
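
Model is now partial so its helpers can live in separate files, and compiling by name builds the args-based RMSprop. A usage sketch: the string overload is the one this commit wires up; loss handling is not visible in this hunk, so "mse" below is only a plausible placeholder:

    model.compile("rmsprop", "mse"); // optimizer resolved via the switch above

    // The typed overload exists but is still an empty stub in this commit:
    // model.compile(loss, optimizer, metrics);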



src/TensorFlowNET.Core/Keras/Engine/Node.cs (+1 -1)

@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine
     /// Each time the output of a layer is used by another layer,
     /// a node is added to `layer._outbound_nodes`.
     /// </summary>
-    public class Node
+    public partial class Node
     {
         NodeArgs args;



src/TensorFlowNET.Core/Keras/KerasApi.cs (+2 -2)

@@ -39,8 +39,8 @@ namespace Tensorflow
         /// <param name="input"></param>
         /// <param name="output"></param>
         /// <returns></returns>
-        public Functional Model(Tensors inputs, Tensors outputs)
-            => new Functional(inputs, outputs);
+        public Functional Model(Tensors inputs, Tensors outputs, string name = null)
+            => new Functional(inputs, outputs, name: name);
 
         /// <summary>
         /// Instantiate a Keras tensor.
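
With name threaded through KerasApi.Model -> Functional -> ModelArgs, a model can be labeled at construction. Usage sketch (assuming a keras handle as in TF.NET, with inputs/outputs produced by layer calls as usual):

    var model = keras.Model(inputs, outputs, name: "mnist_classifier");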


src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs (+27 -0)

@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
 
 namespace Tensorflow.Keras.Optimizers
 {
@@ -29,5 +30,31 @@ namespace Tensorflow.Keras.Optimizers
                 epsilon: epsilon,
                 amsgrad: amsgrad,
                 name: name);
+
+        /// <summary>
+        /// Construct a new RMSprop optimizer.
+        /// </summary>
+        /// <param name="learning_rate"></param>
+        /// <param name="rho"></param>
+        /// <param name="momentum"></param>
+        /// <param name="epsilon"></param>
+        /// <param name="centered"></param>
+        /// <param name="name"></param>
+        /// <returns></returns>
+        public OptimizerV2 RMSprop(float learning_rate = 0.001f,
+            float rho = 0.9f,
+            float momentum = 0.0f,
+            float epsilon = 1e-7f,
+            bool centered = false,
+            string name = "RMSprop")
+            => new RMSprop(new RMSpropArgs
+            {
+                LearningRate = learning_rate,
+                RHO = rho,
+                Momentum = momentum,
+                Epsilon = epsilon,
+                Centered = centered,
+                Name = name
+            });
     }
 }
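
The factory mirrors the Adam helper directly above it. Usage sketch (assuming the API is exposed as keras.optimizers, as elsewhere in TF.NET):

    var opt = keras.optimizers.RMSprop(learning_rate: 0.001f, momentum: 0.9f);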

src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs (+6 -0)

@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
 
 namespace Tensorflow.Keras.Optimizers
 {
@@ -9,6 +10,11 @@ namespace Tensorflow.Keras.Optimizers
     /// </summary>
     public class RMSprop : OptimizerV2
     {
+        RMSpropArgs args;
+
+        public RMSprop(RMSpropArgs args)
+        {
+            this.args = args;
         }
     }
 }

src/TensorFlowNET.Core/Keras/Utils/layer_utils.cs (+193 -0)

@@ -0,0 +1,193 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Keras.Engine;
using NumSharp;
using System.Security.Cryptography;

namespace Tensorflow.Keras.Utils
{
internal class layer_utils
{
public static void print_summary(Model model, int line_length = -1, float[] positions = null)
{
bool sequential_like = model is Sequential;
// || model.IsGraphNetwork;

if (!sequential_like)
{
sequential_like = true;
var nodes = new List<Node>();

foreach (var v in model.NodesByDepth)
{
// if the model has multiple nodes
// or if the nodes have multiple inbound_layers
// the model is no longer sequential
if (v.Value.Count > 1 || (v.Value.Count == 1 && v.Value[0].KerasInputs.Count > 1))
{
sequential_like = false;
break;
}

nodes.AddRange(v.Value);
}

if (sequential_like)
{
// search for shared layers
foreach(var layer in model.Layers)
{
var flag = false;
foreach(var node in layer.InboundNodes)
{
if(nodes.Contains(node))
{
if (flag)
{
sequential_like = false;
break;
}
else
flag = true;
}
}
if (!sequential_like)
break;
}
}
}

string[] to_display;
var relevant_nodes = new List<Node>();

if (sequential_like)
{
if (line_length < 0)
line_length = 65;
if (positions == null)
positions = new[] { 0.45f, 0.85f, 1.0f };
if (positions[^1] <= 1)
positions = positions.Select(p => line_length * p).ToArray();
to_display = new[] { "Layer (type)", "Output Shape", "Param #" };
}
else
{
if (line_length < 0)
line_length = 98;
if (positions == null)
positions = new[] { 0.33f, 0.55f, 0.67f, 1.0f };
if (positions[^1] <= 1)
positions = positions.Select(p => line_length * p).ToArray();
to_display = new[] { "Layer (type)", "Output Shape", "Param #", "Connected to" };

foreach (var v in model.NodesByDepth)
relevant_nodes.AddRange(v.Value);
}
int[] positions_int = positions.Select(x => Convert.ToInt32(x)).ToArray();
print($"Model: {model.Name}");
print(string.Join("", range(line_length).Select(x => "_")));
print_row(to_display, positions_int);
print(string.Join("", range(line_length).Select(x => "=")));

foreach(var (i, layer) in enumerate(model.Layers))
{
if (sequential_like)
print_layer_summary(layer, positions_int);
else
print_layer_summary_with_connections(layer, positions_int, relevant_nodes);
if(i == model.Layers.Count - 1)
print(string.Join("", range(line_length).Select(x => "=")));
else
print(string.Join("", range(line_length).Select(x => "_")));
}

var trainable_count = count_params(model, model.trainable_variables);
var non_trainable_count = count_params(model, model.non_trainable_variables);

print($"Total params: {trainable_count + non_trainable_count}");
print($"Trainable params: {trainable_count}");
print($"Non-trainable params: {non_trainable_count}");
print(string.Join("", range(line_length).Select(x => "_")));
}

static void print_row(string[] fields, int[] positions)
{
var line = "";
foreach(var i in range(fields.Length))
{
if (i > 0)
line = line[0..^1] + " ";
line += fields[i];
line = string.Join("", line.Take(positions[i]));
line += string.Join("", range(positions[i] - len(line)).Select(x => " "));
}
print(line);
}

/// <summary>
/// Prints a summary for a single layer.
/// </summary>
/// <param name="layer"></param>
static void print_layer_summary(Layer layer, int[] positions)
{
var name = layer.Name;

var fields = new string[]
{
$"{name} ({layer.GetType().Name})",
$"{layer.output_shape}",
$"{layer.count_params()}"
};

print_row(fields, positions);
}

static void print_layer_summary_with_connections(Layer layer, int[] positions, List<Node> relevant_nodes)
{
var connections = new List<string>();
foreach (var node in layer.InboundNodes)
{
if (!relevant_nodes.Contains(node))
continue;

foreach (var (inbound_layer, node_index, tensor_index, _) in node.iterate_inbound())
connections.append($"{inbound_layer.Name}[{node_index}][{tensor_index}]");
}

var name = layer.Name;
string first_connection = "";
if (connections.Count > 0)
first_connection = connections[0];

var fields = new string[]
{
$"{name}({layer.GetType().Name})",
$"{layer.output_shape}",
$"{layer.count_params()}",
first_connection
};

print_row(fields, positions);

if(connections.Count > 1)
{
foreach(var i in range(1, connections.Count))
{
fields = new string[] { "", "", "", connections[i] };
print_row(fields, positions);
}
}
}

public static int count_params(Layer layer, List<IVariableV1> weights)
{
var weight_shapes = weights.Select(x => x.shape).ToArray();
var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum();
return total;
}
}
}
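
To make the column arithmetic in print_row concrete, here is a self-contained console sketch of the same truncate-then-pad logic (only the demo rows and the class name are invented):

    using System;
    using System.Linq;

    class SummaryRowDemo
    {
        // positions are absolute column-end offsets; each field is truncated
        // to its column, then right-padded with spaces, exactly as above.
        static void PrintRow(string[] fields, int[] positions)
        {
            var line = "";
            for (int i = 0; i < fields.Length; i++)
            {
                if (i > 0)
                    line = line[0..^1] + " "; // keep a gap between columns
                line += fields[i];
                line = string.Concat(line.Take(positions[i]));                    // truncate
                line += new string(' ', Math.Max(0, positions[i] - line.Length)); // pad
            }
            Console.WriteLine(line);
        }

        static void Main()
        {
            // 0.45 / 0.85 / 1.0 of the default 65-char line, as in print_summary.
            var positions = new[] { 29, 55, 65 };
            PrintRow(new[] { "Layer (type)", "Output Shape", "Param #" }, positions);
            PrintRow(new[] { "dense (Dense)", "(None, 10)", "7850" }, positions);
        }
    }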
