Browse Source

add TensorFlowOpLayer.

tags/v0.30
Oceania2018 5 years ago
parent
commit
a0ec655372
12 changed files with 301 additions and 204 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs
  3. +0
    -47
      src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  5. +2
    -1
      src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
  6. +65
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.AddWeights.cs
  7. +62
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs
  8. +58
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs
  9. +4
    -149
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  10. +7
    -5
      src/TensorFlowNET.Core/Keras/Engine/Node.cs
  11. +31
    -0
      src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
  12. +58
    -0
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public Layer[] InboundLayers { get; set; }
public int[] NodeIndices { get; set; }
public int[] TensorIndices { get; set; }
public Tensor InputTensors { get; set; }
public Tensors InputTensors { get; set; }
public Tensors Outputs { get; set; }
}
}

+ 11
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class TensorFlowOpLayerArgs : LayerArgs
{
public NodeDef NodeDef { get; set; }
}
}

+ 0
- 47
src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs View File

@@ -1,47 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public class BaseLayerUtils
{
public static Layer[] CreateKerasHistoryHelper(Tensors tensors)
{
var processed_ops = new List<Operation>();
var created_layers = new List<Layer>();

foreach (var tensor in tensors)
{
if (tensor.KerasHistory != null)
continue;

var op = tensor.op;
if (!processed_ops.Contains(op))
{
var layer_inputs = new List<Tensor>();

foreach (var (i, op_input) in enumerate(op.inputs._inputs))
{
if (uses_keras_history(op_input))
layer_inputs.Add(op_input);
else
{

}
}
}
}

return created_layers.ToArray();
}

static bool uses_keras_history(Tensor op_input)
{
return Layer.KerasHistories.Any(x => x.tensor == op_input);
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -4,6 +4,7 @@ using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Engine
{
@@ -50,7 +51,7 @@ namespace Tensorflow.Keras.Engine
_autocast = false;

if (outputs.Any(x => x.KerasHistory == null))
BaseLayerUtils.CreateKerasHistoryHelper(outputs);
base_layer_utils.create_keras_history(outputs);

// Build self._output_layers:
foreach (var x in outputs)


+ 2
- 1
src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Engine
/// </summary>
public class KerasHistory
{
Layer layer;
public Layer layer;
int node_index;
int tensor_index;
public Tensor tensor;
@@ -20,6 +20,7 @@ namespace Tensorflow.Keras.Engine
this.node_index = node_index;
this.tensor_index = tensor_index;
this.tensor = tensor;
Layer.KerasHistories.Add(this);
Console.WriteLine(tensor.name);
}



+ 65
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.AddWeights.cs View File

@@ -0,0 +1,65 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Layer
{
protected virtual IVariableV1 add_weight(string name,
TensorShape shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
IRegularizer regularizer = null,
VariableSynchronization synchronization = VariableSynchronization.Auto,
VariableAggregation aggregation = VariableAggregation.None,
bool trainable = true,
Func<VariableArgs, IVariableV1> getter = null)
{
// Initialize variable when no initializer provided
if (initializer == null)
{
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer
if (dtype.is_floating())
initializer = tf.glorot_uniform_initializer;
else if (dtype.is_integer())
initializer = tf.zeros_initializer;
else
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
}

if (synchronization == VariableSynchronization.OnRead)
trainable = false;

var args = new VariableArgs
{
Name = name,
Shape = shape,
DType = dtype,
Getter = getter ?? base_layer_utils.make_variable,
Overwrite = true,
Initializer = initializer,
Synchronization = synchronization,
Aggregation = aggregation,
Trainable = trainable
};
var variable = _add_variable_with_custom_getter(args);

if (regularizer != null)
{
var name_in_scope = variable.Name.Split(':')[0];
_handle_weight_regularization(name_in_scope, variable, regularizer);
}

//backend.track_variable(variable);
if (trainable == true)
trainableWeights.Add(variable);
else
nonTrainableWeights.Add(variable);

return variable;
}
}
}

+ 62
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs View File

@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Layer
{
/// <summary>
/// Wraps `call`, applying pre- and post-processing steps.
/// </summary>
/// <param name="input"></param>
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
{
callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

if (_in_functional_construction_mode(inputs))
return FunctionalConstructionCall(inputs);

Tensors outputs = null;

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();

string nameScope = "";
if (eager)
nameScope = Name;
else
nameScope = _name_scope();

if (!inputs.IsEagerTensor)
tf.Context.graph_mode();

tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs);

outputs = call(inputs, state: state, is_training: is_training);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);
});

if (!inputs.IsEagerTensor)
tf.Context.restore_mode();

return outputs;
}
}
}

+ 58
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -0,0 +1,58 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Layer
{
Tensors FunctionalConstructionCall(Tensors inputs)
{
bool mask_arg_passed_by_framework = false;
bool training_arg_passed_by_framework = false;
Tensor training_value = null;
if (training_value == null)
{
training_arg_passed_by_framework = true;
}

if (base_layer_utils.needs_keras_history(inputs))
base_layer_utils.create_keras_history(inputs);

Tensors outputs = null;
using var ctxManager = CallContext.enter();

// using var graph = tf.keras.backend.get_graph().as_default();

if (!inputs.IsEagerTensor)
tf.Context.graph_mode();

tf_with(ops.name_scope(_name_scope()), scope =>
{
MaybeBuild(inputs);

// Wrapping `call` function in autograph to allow for dynamic control
// flow and control dependencies in call. We are limiting this to
// subclassed layers as autograph is strictly needed only for
// subclassed layers and models.
// tf_convert will respect the value of autograph setting in the
// enclosing tf.function, if any.
if (!dynamic)
throw new NotImplementedException("");

outputs = call(inputs);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);
});

if (!inputs.IsEagerTensor)
tf.Context.restore_mode();

return outputs;
}
}
}

+ 4
- 149
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -109,116 +109,24 @@ namespace Tensorflow.Keras.Engine
updates = new List<Operation>();

inboundNodes = new List<Node>();
outboundNodes = new List<Node>();

// Manage input shape information if passed.
if(args.BatchInputShape == null && args.InputShape != null)
if (args.BatchInputShape == null && args.InputShape != null)
{
args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray();
}
}

/// <summary>
/// Wraps `call`, applying pre- and post-processing steps.
/// </summary>
/// <param name="input"></param>
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
{
callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

var history = inputs.Where(x => x.KerasHistory != null
&& !KerasHistories.Contains(x.KerasHistory))
.Select(x => x.KerasHistory);
KerasHistories.AddRange(history);

if (_in_functional_construction_mode(inputs))
return _functional_construction_call(inputs);

Tensors outputs = null;

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();

string nameScope = "";
if (eager)
nameScope = Name;
else
nameScope = _name_scope();

if (!inputs.IsEagerTensor)
tf.Context.graph_mode();

tf_with(ops.name_scope(nameScope), scope =>
{
if (!built)
MaybeBuild(inputs);

outputs = call(inputs, state: state, is_training: is_training);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);
});

if (!inputs.IsEagerTensor)
tf.Context.restore_mode();

return outputs;
}

bool _in_functional_construction_mode(Tensors inputs)
{
return tf.Context.executing_eagerly()
&& inputs.Count(x => !x.IsEagerTensor) == inputs.Count();
}

Tensors _functional_construction_call(Tensors inputs)
public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)
{
bool mask_arg_passed_by_framework = false;
bool training_arg_passed_by_framework = false;
Tensor training_value = null;
if(training_value == null)
{
training_arg_passed_by_framework = true;
}

Tensors outputs = null;
using var ctxManager = CallContext.enter();

// using var graph = tf.keras.backend.get_graph().as_default();

if (!inputs.IsEagerTensor)
tf.Context.graph_mode();

tf_with(ops.name_scope(_name_scope()), scope =>
{
MaybeBuild(inputs);

// Wrapping `call` function in autograph to allow for dynamic control
// flow and control dependencies in call. We are limiting this to
// subclassed layers as autograph is strictly needed only for
// subclassed layers and models.
// tf_convert will respect the value of autograph setting in the
// enclosing tf.function, if any.
if (!dynamic)
throw new NotImplementedException("");

outputs = call(inputs);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);
});

if (!inputs.IsEagerTensor)
tf.Context.restore_mode();

return outputs;
}

private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
@@ -235,6 +143,7 @@ namespace Tensorflow.Keras.Engine

new Node(this, new NodeArgs
{
InputTensors = inputs,
Outputs = outputs
});

@@ -304,60 +213,6 @@ namespace Tensorflow.Keras.Engine
}

protected virtual IVariableV1 add_weight(string name,
TensorShape shape,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
IRegularizer regularizer = null,
VariableSynchronization synchronization = VariableSynchronization.Auto,
VariableAggregation aggregation = VariableAggregation.None,
bool trainable = true,
Func<VariableArgs, IVariableV1> getter = null)
{
// Initialize variable when no initializer provided
if (initializer == null)
{
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer
if (dtype.is_floating())
initializer = tf.glorot_uniform_initializer;
else if (dtype.is_integer())
initializer = tf.zeros_initializer;
else
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
}

if (synchronization == VariableSynchronization.OnRead)
trainable = false;

var args = new VariableArgs
{
Name = name,
Shape = shape,
DType = dtype,
Getter = getter ?? base_layer_utils.make_variable,
Overwrite = true,
Initializer = initializer,
Synchronization = synchronization,
Aggregation = aggregation,
Trainable = trainable
};
var variable = _add_variable_with_custom_getter(args);

if(regularizer != null)
{
var name_in_scope = variable.Name.Split(':')[0];
_handle_weight_regularization(name_in_scope, variable, regularizer);
}

//backend.track_variable(variable);
if (trainable == true)
trainableWeights.Add(variable);
else
nonTrainableWeights.Add(variable);

return variable;
}

/// <summary>
/// Create lambdas which compute regularization losses.
/// </summary>


+ 7
- 5
src/TensorFlowNET.Core/Keras/Engine/Node.cs View File

@@ -39,20 +39,22 @@ namespace Tensorflow.Keras.Engine
public Tensors Outputs => args.Outputs;
public TensorShape[] input_shapes;
public TensorShape[] output_shapes;
List<Layer> kerasInputs;
List<Tensor> kerasInputs = new List<Tensor>();

public Node(Layer layer, NodeArgs args)
{
this.args = args;

kerasInputs = new List<Layer>();
if (args.InputTensors != null)
kerasInputs.AddRange(args.InputTensors);

// Wire up Node to Layers.
layer.InboundNodes.Add(this);
foreach (var input in kerasInputs)
foreach (var kt in kerasInputs)
{
if (input != null)
input.OutboundNodes.Add(this);
var inbound_layer = kt.KerasHistory.layer;
if (inbound_layer != null)
inbound_layer.OutboundNodes.Add(this);
}

// Set metadata on outputs.


+ 31
- 0
src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine
{
public class TensorFlowOpLayer : Layer
{
TensorFlowOpLayerArgs args;
string _TF_OP_LAYER_NAME_PREFIX = "";

public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
: base(new LayerArgs
{
Name = "tf_op_layer_" + args.Name,
Trainable = args.Trainable,
DType = args.DType,
Autocast = false
})
{
this.args = args;
built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.call(inputs, state, is_training);
}
}
}

+ 58
- 0
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -17,6 +17,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Utils
@@ -105,5 +107,61 @@ namespace Tensorflow.Keras.Utils

return name_uid_map;
}

public static bool needs_keras_history(Tensors inputs)
{
if (inputs.Any(x => x.KerasHistory == null))
return true;

return false;
}

public static Layer[] create_keras_history(Tensors inputs)
{
var processed_ops = new List<Operation>();
var created_layers = new List<Layer>();
CreateKerasHistoryHelper(inputs, processed_ops, created_layers);
return created_layers.ToArray();
}

public static void CreateKerasHistoryHelper(Tensors tensors, List<Operation> processed_ops, List<Layer> created_layers)
{
foreach (var tensor in tensors)
{
if (tensor.KerasHistory != null)
continue;

var op = tensor.op;
if (!processed_ops.Contains(op))
{
var layer_inputs = new List<Tensor>();

foreach (var (i, op_input) in enumerate(op.inputs._inputs))
{
if (uses_keras_history(op_input))
layer_inputs.Add(op_input);
else
{

}

// recursively
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Name = op.name
});
created_layers.Add(op_layer);
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
}
}
}
}

static bool uses_keras_history(Tensor op_input)
{
return Layer.KerasHistories.Any(x => x.tensor.name == op_input.name);
}
}
}

Loading…
Cancel
Save