@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.ArgsDefinition
         public Layer[] InboundLayers { get; set; }
         public int[] NodeIndices { get; set; }
         public int[] TensorIndices { get; set; }
-        public Tensor InputTensors { get; set; }
+        public Tensors InputTensors { get; set; }
         public Tensors Outputs { get; set; }
     }
 }
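Reviewer note (not part of the diff): widening InputTensors from a single Tensor to Tensors lets a Node record every inbound tensor of a multi-input layer. A minimal sketch of how the argument object is now populated, mirroring the _set_connectivity_metadata_ change later in this diff; the variables a, b, outputs and layer are hypothetical:

    // Hypothetical call site; the real one is _set_connectivity_metadata_ below.
    var nodeArgs = new NodeArgs
    {
        InputTensors = new Tensors(a, b),   // several inbound tensors now fit here
        Outputs = outputs
    };
    var node = new Node(layer, nodeArgs);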
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class TensorFlowOpLayerArgs : LayerArgs
+    {
+        public NodeDef NodeDef { get; set; }
+    }
+}
@@ -1,47 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Security.Cryptography.X509Certificates;
-using System.Text;
-using static Tensorflow.Binding;
-
-namespace Tensorflow.Keras.Engine
-{
-    public class BaseLayerUtils
-    {
-        public static Layer[] CreateKerasHistoryHelper(Tensors tensors)
-        {
-            var processed_ops = new List<Operation>();
-            var created_layers = new List<Layer>();
-
-            foreach (var tensor in tensors)
-            {
-                if (tensor.KerasHistory != null)
-                    continue;
-
-                var op = tensor.op;
-                if (!processed_ops.Contains(op))
-                {
-                    var layer_inputs = new List<Tensor>();
-                    foreach (var (i, op_input) in enumerate(op.inputs._inputs))
-                    {
-                        if (uses_keras_history(op_input))
-                            layer_inputs.Add(op_input);
-                        else
-                        {
-                        }
-                    }
-                }
-            }
-
-            return created_layers.ToArray();
-        }
-
-        static bool uses_keras_history(Tensor op_input)
-        {
-            return Layer.KerasHistories.Any(x => x.tensor == op_input);
-        }
-    }
-}
@@ -4,6 +4,7 @@ using System.Linq;
 using System.Security.Cryptography.X509Certificates;
 using System.Text;
 using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Utils;
 
 namespace Tensorflow.Keras.Engine
 {
@@ -50,7 +51,7 @@ namespace Tensorflow.Keras.Engine
             _autocast = false;
 
             if (outputs.Any(x => x.KerasHistory == null))
-                BaseLayerUtils.CreateKerasHistoryHelper(outputs);
+                base_layer_utils.create_keras_history(outputs);
 
             // Build self._output_layers:
             foreach (var x in outputs)
@@ -9,7 +9,7 @@ namespace Tensorflow.Keras.Engine
     /// </summary>
     public class KerasHistory
     {
-        Layer layer;
+        public Layer layer;
         int node_index;
         int tensor_index;
        public Tensor tensor;
@@ -20,6 +20,7 @@ namespace Tensorflow.Keras.Engine
             this.node_index = node_index;
             this.tensor_index = tensor_index;
             this.tensor = tensor;
+            Layer.KerasHistories.Add(this);
 
             Console.WriteLine(tensor.name);
         }
@@ -0,0 +1,65 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Layer
+    {
+        protected virtual IVariableV1 add_weight(string name,
+            TensorShape shape,
+            TF_DataType dtype = TF_DataType.TF_FLOAT,
+            IInitializer initializer = null,
+            IRegularizer regularizer = null,
+            VariableSynchronization synchronization = VariableSynchronization.Auto,
+            VariableAggregation aggregation = VariableAggregation.None,
+            bool trainable = true,
+            Func<VariableArgs, IVariableV1> getter = null)
+        {
+            // Initialize variable when no initializer provided
+            if (initializer == null)
+            {
+                // If dtype is DT_FLOAT, provide a uniform unit scaling initializer
+                if (dtype.is_floating())
+                    initializer = tf.glorot_uniform_initializer;
+                else if (dtype.is_integer())
+                    initializer = tf.zeros_initializer;
+                else
+                    throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
+            }
+
+            if (synchronization == VariableSynchronization.OnRead)
+                trainable = false;
+
+            var args = new VariableArgs
+            {
+                Name = name,
+                Shape = shape,
+                DType = dtype,
+                Getter = getter ?? base_layer_utils.make_variable,
+                Overwrite = true,
+                Initializer = initializer,
+                Synchronization = synchronization,
+                Aggregation = aggregation,
+                Trainable = trainable
+            };
+            var variable = _add_variable_with_custom_getter(args);
+
+            if (regularizer != null)
+            {
+                var name_in_scope = variable.Name.Split(':')[0];
+                _handle_weight_regularization(name_in_scope, variable, regularizer);
+            }
+
+            //backend.track_variable(variable);
+            if (trainable == true)
+                trainableWeights.Add(variable);
+            else
+                nonTrainableWeights.Add(variable);
+
+            return variable;
+        }
+    }
+}
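Reviewer note (not part of the diff): add_weight is normally called from a layer's build step once the input shape is known. Below is a minimal, hypothetical layer sketching that usage; the build override signature and the built flag are assumed from the existing Layer base class, and the layer itself is not part of this PR:

    // Hypothetical layer, shown only to illustrate the add_weight contract.
    public class MyScaleLayer : Layer
    {
        IVariableV1 scale;

        public MyScaleLayer(LayerArgs args) : base(args) { }

        protected override void build(TensorShape input_shape)
        {
            // No initializer supplied and the dtype is float, so add_weight falls
            // back to tf.glorot_uniform_initializer; the variable is registered
            // in trainableWeights because trainable is true.
            scale = add_weight("scale", new TensorShape(1), trainable: true);
            built = true;
        }
    }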
@@ -0,0 +1,62 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Layer
+    {
+        /// <summary>
+        /// Wraps `call`, applying pre- and post-processing steps.
+        /// </summary>
+        /// <param name="inputs"></param>
+        /// <param name="state"></param>
+        /// <param name="is_training"></param>
+        /// <returns></returns>
+        public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
+        {
+            callContext = callContext ?? new ThreadLocal<CallContext>()
+            {
+                Value = new CallContext()
+            };
+
+            if (_in_functional_construction_mode(inputs))
+                return FunctionalConstructionCall(inputs);
+
+            Tensors outputs = null;
+
+            var eager = tf.executing_eagerly();
+            using var ctxManager = CallContext.enter();
+
+            string nameScope = "";
+            if (eager)
+                nameScope = Name;
+            else
+                nameScope = _name_scope();
+
+            if (!inputs.IsEagerTensor)
+                tf.Context.graph_mode();
+
+            tf_with(ops.name_scope(nameScope), scope =>
+            {
+                if (!built)
+                    MaybeBuild(inputs);
+
+                outputs = call(inputs, state: state, is_training: is_training);
+
+                outputs = _set_connectivity_metadata_(inputs, outputs);
+                _handle_activity_regularization(inputs, outputs);
+                _set_mask_metadata(inputs, outputs, null);
+            });
+
+            if (!inputs.IsEagerTensor)
+                tf.Context.restore_mode();
+
+            return outputs;
+        }
+    }
+}
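Reviewer note (not part of the diff): Apply remains the public entry point that wraps call — it sets up the call context, picks a name scope, switches to graph mode for symbolic inputs, and hands fully symbolic eager-mode calls to FunctionalConstructionCall. A hedged driver sketch; myLayer stands in for any built or buildable Layer instance and is hypothetical:

    // Hypothetical driver code.
    var x = tf.constant(new[] { 1.0f, 2.0f, 3.0f });            // eager tensor
    var y = myLayer.Apply(new Tensors(x), is_training: false);  // runs MaybeBuild + call under a name scope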
@@ -0,0 +1,58 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Utils;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Layer
+    {
+        Tensors FunctionalConstructionCall(Tensors inputs)
+        {
+            bool mask_arg_passed_by_framework = false;
+            bool training_arg_passed_by_framework = false;
+            Tensor training_value = null;
+            if (training_value == null)
+            {
+                training_arg_passed_by_framework = true;
+            }
+
+            if (base_layer_utils.needs_keras_history(inputs))
+                base_layer_utils.create_keras_history(inputs);
+
+            Tensors outputs = null;
+
+            using var ctxManager = CallContext.enter();
+
+            // using var graph = tf.keras.backend.get_graph().as_default();
+
+            if (!inputs.IsEagerTensor)
+                tf.Context.graph_mode();
+
+            tf_with(ops.name_scope(_name_scope()), scope =>
+            {
+                MaybeBuild(inputs);
+
+                // Wrapping `call` function in autograph to allow for dynamic control
+                // flow and control dependencies in call. We are limiting this to
+                // subclassed layers as autograph is strictly needed only for
+                // subclassed layers and models.
+                // tf_convert will respect the value of autograph setting in the
+                // enclosing tf.function, if any.
+                if (!dynamic)
+                    throw new NotImplementedException("");
+
+                outputs = call(inputs);
+
+                outputs = _set_connectivity_metadata_(inputs, outputs);
+                _handle_activity_regularization(inputs, outputs);
+                _set_mask_metadata(inputs, outputs, null);
+            });
+
+            if (!inputs.IsEagerTensor)
+                tf.Context.restore_mode();
+
+            return outputs;
+        }
+    }
+}
@@ -109,116 +109,24 @@ namespace Tensorflow.Keras.Engine
             updates = new List<Operation>();
             inboundNodes = new List<Node>();
+            outboundNodes = new List<Node>();
 
             // Manage input shape information if passed.
-            if(args.BatchInputShape == null && args.InputShape != null)
+            if (args.BatchInputShape == null && args.InputShape != null)
             {
                 args.BatchInputShape = new int[] { args.BatchSize }.Concat(args.InputShape.dims).ToArray();
             }
         }
 
-        /// <summary>
-        /// Wraps `call`, applying pre- and post-processing steps.
-        /// </summary>
-        /// <param name="input"></param>
-        /// <param name="state"></param>
-        /// <param name="is_training"></param>
-        /// <returns></returns>
-        public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
-        {
-            callContext = callContext ?? new ThreadLocal<CallContext>()
-            {
-                Value = new CallContext()
-            };
-
-            var history = inputs.Where(x => x.KerasHistory != null
-                && !KerasHistories.Contains(x.KerasHistory))
-                .Select(x => x.KerasHistory);
-            KerasHistories.AddRange(history);
-
-            if (_in_functional_construction_mode(inputs))
-                return _functional_construction_call(inputs);
-
-            Tensors outputs = null;
-
-            var eager = tf.executing_eagerly();
-            using var ctxManager = CallContext.enter();
-
-            string nameScope = "";
-            if (eager)
-                nameScope = Name;
-            else
-                nameScope = _name_scope();
-
-            if (!inputs.IsEagerTensor)
-                tf.Context.graph_mode();
-
-            tf_with(ops.name_scope(nameScope), scope =>
-            {
-                if (!built)
-                    MaybeBuild(inputs);
-
-                outputs = call(inputs, state: state, is_training: is_training);
-
-                outputs = _set_connectivity_metadata_(inputs, outputs);
-                _handle_activity_regularization(inputs, outputs);
-                _set_mask_metadata(inputs, outputs, null);
-            });
-
-            if (!inputs.IsEagerTensor)
-                tf.Context.restore_mode();
-
-            return outputs;
-        }
-
         bool _in_functional_construction_mode(Tensors inputs)
         {
             return tf.Context.executing_eagerly()
                 && inputs.Count(x => !x.IsEagerTensor) == inputs.Count();
         }
 
-        Tensors _functional_construction_call(Tensors inputs)
+        public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)
         {
-            bool mask_arg_passed_by_framework = false;
-            bool training_arg_passed_by_framework = false;
-            Tensor training_value = null;
-            if(training_value == null)
-            {
-                training_arg_passed_by_framework = true;
-            }
-
-            Tensors outputs = null;
-
-            using var ctxManager = CallContext.enter();
-
-            // using var graph = tf.keras.backend.get_graph().as_default();
-
-            if (!inputs.IsEagerTensor)
-                tf.Context.graph_mode();
-
-            tf_with(ops.name_scope(_name_scope()), scope =>
-            {
-                MaybeBuild(inputs);
-
-                // Wrapping `call` function in autograph to allow for dynamic control
-                // flow and control dependencies in call. We are limiting this to
-                // subclassed layers as autograph is strictly needed only for
-                // subclassed layers and models.
-                // tf_convert will respect the value of autograph setting in the
-                // enclosing tf.function, if any.
-                if (!dynamic)
-                    throw new NotImplementedException("");
-
-                outputs = call(inputs);
-
-                outputs = _set_connectivity_metadata_(inputs, outputs);
-                _handle_activity_regularization(inputs, outputs);
-                _set_mask_metadata(inputs, outputs, null);
-            });
-
-            if (!inputs.IsEagerTensor)
-                tf.Context.restore_mode();
-
-            return outputs;
         }
 
         private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
@@ -235,6 +143,7 @@ namespace Tensorflow.Keras.Engine
             new Node(this, new NodeArgs
             {
+                InputTensors = inputs,
                 Outputs = outputs
             });
@@ -304,60 +213,6 @@ namespace Tensorflow.Keras.Engine
         }
 
-        protected virtual IVariableV1 add_weight(string name,
-            TensorShape shape,
-            TF_DataType dtype = TF_DataType.TF_FLOAT,
-            IInitializer initializer = null,
-            IRegularizer regularizer = null,
-            VariableSynchronization synchronization = VariableSynchronization.Auto,
-            VariableAggregation aggregation = VariableAggregation.None,
-            bool trainable = true,
-            Func<VariableArgs, IVariableV1> getter = null)
-        {
-            // Initialize variable when no initializer provided
-            if (initializer == null)
-            {
-                // If dtype is DT_FLOAT, provide a uniform unit scaling initializer
-                if (dtype.is_floating())
-                    initializer = tf.glorot_uniform_initializer;
-                else if (dtype.is_integer())
-                    initializer = tf.zeros_initializer;
-                else
-                    throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
-            }
-
-            if (synchronization == VariableSynchronization.OnRead)
-                trainable = false;
-
-            var args = new VariableArgs
-            {
-                Name = name,
-                Shape = shape,
-                DType = dtype,
-                Getter = getter ?? base_layer_utils.make_variable,
-                Overwrite = true,
-                Initializer = initializer,
-                Synchronization = synchronization,
-                Aggregation = aggregation,
-                Trainable = trainable
-            };
-            var variable = _add_variable_with_custom_getter(args);
-
-            if(regularizer != null)
-            {
-                var name_in_scope = variable.Name.Split(':')[0];
-                _handle_weight_regularization(name_in_scope, variable, regularizer);
-            }
-
-            //backend.track_variable(variable);
-            if (trainable == true)
-                trainableWeights.Add(variable);
-            else
-                nonTrainableWeights.Add(variable);
-
-            return variable;
-        }
-
         /// <summary>
         /// Create lambdas which compute regularization losses.
         /// </summary>
@@ -39,20 +39,22 @@ namespace Tensorflow.Keras.Engine
         public Tensors Outputs => args.Outputs;
         public TensorShape[] input_shapes;
         public TensorShape[] output_shapes;
-        List<Layer> kerasInputs;
+        List<Tensor> kerasInputs = new List<Tensor>();
 
         public Node(Layer layer, NodeArgs args)
         {
             this.args = args;
-            kerasInputs = new List<Layer>();
+
+            if (args.InputTensors != null)
+                kerasInputs.AddRange(args.InputTensors);
 
             // Wire up Node to Layers.
             layer.InboundNodes.Add(this);
-            foreach (var input in kerasInputs)
+            foreach (var kt in kerasInputs)
             {
-                if (input != null)
-                    input.OutboundNodes.Add(this);
+                var inbound_layer = kt.KerasHistory.layer;
+                if (inbound_layer != null)
+                    inbound_layer.OutboundNodes.Add(this);
             }
 
             // Set metadata on outputs.
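Reviewer note (not part of the diff): the net effect of this change is that a Node no longer needs inbound Layer objects handed to it; it recovers each one from the KerasHistory stamped on its input tensors, which is also why KerasHistory.layer becomes public earlier in this diff. A small sketch of that lookup, with hypothetical variables:

    // Hypothetical sketch of the lookup the constructor above performs per input.
    Tensor t = previousLayerOutputs[0];      // symbolic tensor produced by an earlier layer (assumed)
    Layer producer = t.KerasHistory.layer;   // available because `layer` is now public
    producer.OutboundNodes.Add(thisNode);    // what the constructor above does for each kerasInputs entry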
@@ -0,0 +1,31 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Engine
+{
+    public class TensorFlowOpLayer : Layer
+    {
+        TensorFlowOpLayerArgs args;
+        string _TF_OP_LAYER_NAME_PREFIX = "";
+
+        public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
+            : base(new LayerArgs
+            {
+                Name = "tf_op_layer_" + args.Name,
+                Trainable = args.Trainable,
+                DType = args.DType,
+                Autocast = false
+            })
+        {
+            this.args = args;
+            built = true;
+        }
+
+        protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
+        {
+            return base.call(inputs, state, is_training);
+        }
+    }
+}
@@ -17,6 +17,8 @@
 using System;
 using System.Collections.Generic;
 using System.Linq;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
 using static Tensorflow.Binding;
 
 namespace Tensorflow.Keras.Utils
@@ -105,5 +107,61 @@ namespace Tensorflow.Keras.Utils
 
             return name_uid_map;
         }
+
+        public static bool needs_keras_history(Tensors inputs)
+        {
+            if (inputs.Any(x => x.KerasHistory == null))
+                return true;
+
+            return false;
+        }
+
+        public static Layer[] create_keras_history(Tensors inputs)
+        {
+            var processed_ops = new List<Operation>();
+            var created_layers = new List<Layer>();
+            CreateKerasHistoryHelper(inputs, processed_ops, created_layers);
+            return created_layers.ToArray();
+        }
+
+        public static void CreateKerasHistoryHelper(Tensors tensors, List<Operation> processed_ops, List<Layer> created_layers)
+        {
+            foreach (var tensor in tensors)
+            {
+                if (tensor.KerasHistory != null)
+                    continue;
+
+                var op = tensor.op;
+                if (!processed_ops.Contains(op))
+                {
+                    var layer_inputs = new List<Tensor>();
+                    foreach (var (i, op_input) in enumerate(op.inputs._inputs))
+                    {
+                        if (uses_keras_history(op_input))
+                            layer_inputs.Add(op_input);
+                        else
+                        {
+                        }
+
+                        // recursively
+                        CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
+                        var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
+                        {
+                            NodeDef = op.node_def,
+                            Name = op.name
+                        });
+                        created_layers.Add(op_layer);
+                        op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
+                    }
+                }
+            }
+        }
+
+        static bool uses_keras_history(Tensor op_input)
+        {
+            return Layer.KerasHistories.Any(x => x.tensor.name == op_input.name);
+        }
     }
 }
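Reviewer note (not part of the diff): create_keras_history exists for the case where a raw TensorFlow op is applied to Keras-tracked tensors, so its output has no KerasHistory yet. The helper walks back through the op graph and wraps each bare op in a TensorFlowOpLayer so the functional model can still trace connectivity. Hypothetical usage, assuming a and b are Keras-tracked symbolic tensors:

    // Hypothetical: mixing a raw op into a Keras graph.
    var outs = new Tensors(tf.add(a, b));                // plain op output, no KerasHistory yet
    if (base_layer_utils.needs_keras_history(outs))
        base_layer_utils.create_keras_history(outs);     // wraps the Add op in a TensorFlowOpLayer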