From b63a44e1fe72cfb03349223e0f54c3f485713c49 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 7 Oct 2020 18:16:28 -0500 Subject: [PATCH] Refactor convolutional layer. --- src/TensorFlowNET.Core/APIs/tf.layers.cs | 36 ++--- src/TensorFlowNET.Core/APIs/tf.ops.cs | 2 + .../Eager/EagerRunner.TFE_FastPathExecute.cs | 5 + .../Eager/EagerTensor.Creation.cs | 20 +++ src/TensorFlowNET.Core/Gradients/nn_grad.cs | 34 ++--- .../ArgsDefinition/BatchNormalizationArgs.cs | 24 ++++ .../Keras/ArgsDefinition/Conv2DArgs.cs | 2 +- .../{ConvArgs.cs => ConvolutionalArgs.cs} | 7 +- .../Keras/ArgsDefinition/LayerArgs.cs | 2 +- .../Keras/ArgsDefinition/ModelArgs.cs | 4 +- src/TensorFlowNET.Core/Keras/BackendImpl.cs | 22 ++- .../Keras/Engine/Functional.cs | 67 +++++++++ .../Keras/Engine/InputSpec.cs | 8 ++ .../Keras/Engine/KerasHistory.cs | 10 +- src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 110 ++++++++++++--- src/TensorFlowNET.Core/Keras/Engine/Model.cs | 9 +- src/TensorFlowNET.Core/Keras/KerasApi.cs | 9 +- .../Keras/Layers/BatchNormalization.cs | 128 +++++++----------- src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs | 7 +- .../Layers/{Conv.cs => Convolutional.cs} | 29 ++-- .../Keras/Layers/InputLayer.cs | 5 +- .../Keras/Layers/LayersApi.cs | 60 ++++++-- src/TensorFlowNET.Core/Keras/Regularizers.cs | 12 ++ .../Keras/Regularizers/IRegularizer.cs | 11 ++ .../Keras/Regularizers/L2.cs | 21 +++ .../Keras/Regularizers/RegularizerArgs.cs | 10 ++ .../Keras/Utils/base_layer_utils.cs | 23 ++-- src/TensorFlowNET.Core/Layers/Layer.cs | 4 +- .../Operations/Initializers/RandomNormal.cs | 4 +- .../Operations/NnOps/Convolution.cs | 84 ------------ .../Operations/NnOps/ConvolutionInternal.cs | 100 ++++++++++++++ .../Operations/NnOps/_NonAtrousConvolution.cs | 83 ------------ .../Operations/NnOps/_WithSpaceToBatch.cs | 76 ----------- .../Operations/NnOps/gen_nn_ops.cs | 96 +++++++------ .../Operations/gen_random_ops.cs | 10 +- src/TensorFlowNET.Core/Operations/nn_ops.cs | 20 +-- src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 2 +- .../Variables/BaseResourceVariable.cs | 17 ++- .../Variables/IVariableV1.cs | 2 +- .../Variables/RefVariable.cs | 4 +- .../Variables/ResourceVariable.cs | 4 +- src/TensorFlowNET.Core/ops.cs | 25 +--- src/TensorFlowNET.Core/ops.name_scope.cs | 1 + .../Keras/ModelSaveTest.cs | 3 +- .../ManagedAPI/BitwiseApiTest.cs | 5 + 46 files changed, 671 insertions(+), 548 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs rename src/TensorFlowNET.Core/Keras/ArgsDefinition/{ConvArgs.cs => ConvolutionalArgs.cs} (82%) create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Functional.cs rename src/TensorFlowNET.Core/Keras/Layers/{Conv.cs => Convolutional.cs} (85%) create mode 100644 src/TensorFlowNET.Core/Keras/Regularizers.cs create mode 100644 src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs create mode 100644 src/TensorFlowNET.Core/Keras/Regularizers/L2.cs create mode 100644 src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs delete mode 100644 src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs create mode 100644 src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs delete mode 100644 src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs delete mode 100644 src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs index 3485fbd5..7330e957 100644 --- a/src/TensorFlowNET.Core/APIs/tf.layers.cs +++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs @@ -92,7 +92,7 @@ namespace Tensorflow /// /// /// - public Tensor batch_normalization(Tensor inputs, + public Tensors batch_normalization(Tensor inputs, int axis = -1, float momentum = 0.99f, float epsilon = 0.001f, @@ -108,22 +108,24 @@ namespace Tensorflow bool renorm = false, float renorm_momentum = 0.99f) { - var layer = new BatchNormalization( - axis: axis, - momentum: momentum, - epsilon: epsilon, - center: center, - scale: scale, - beta_initializer: beta_initializer, - gamma_initializer: gamma_initializer, - moving_mean_initializer: moving_mean_initializer, - moving_variance_initializer: moving_variance_initializer, - renorm: renorm, - renorm_momentum: renorm_momentum, - trainable: trainable, - name: name); - - return layer.apply(inputs, training: training).Item1; + var layer = new BatchNormalization(new BatchNormalizationArgs + { + Axis = axis, + Momentum = momentum, + Epsilon = epsilon, + Center = center, + Scale = scale, + BetaInitializer = beta_initializer, + GammaInitializer = gamma_initializer, + MovingMeanInitializer = moving_mean_initializer, + MovingVarianceInitializer = moving_variance_initializer, + Renorm = renorm, + RenormMomentum = renorm_momentum, + Trainable = trainable, + Name = name + }); + + return layer.Apply(inputs); } /// diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs index c651bba9..d8109676 100644 --- a/src/TensorFlowNET.Core/APIs/tf.ops.cs +++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs @@ -41,6 +41,8 @@ namespace Tensorflow /// /// A context manager that lifts ops out of control-flow scopes and function-building graphs. + /// When eager execution is enabled, code inside an init_scope block runs with + /// eager execution enabled even when tracing a `tf.function`. /// public void init_scope() => ops.init_scope(); diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index 84e27cc6..d1c7eb13 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -227,6 +227,11 @@ namespace Tensorflow.Eager input_handle = input.EagerTensorHandle; flattened_inputs.Add(input); break; + case ResourceVariable variable: + var var_tensor = variable.AsTensor(); + input_handle = var_tensor.EagerTensorHandle; + flattened_inputs.Add(var_tensor); + break; default: var tensor = tf.convert_to_tensor(inputs); input_handle = tensor.EagerTensorHandle; diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 809c4cea..5733e08d 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -57,6 +57,26 @@ namespace Tensorflow.Eager return this; } + /// + /// _create_substitute_placeholder + /// + /// + public Tensor AsPlaceholder(string name = null) + { + Tensor placeholder = null; + tf_with(ops.control_dependencies(null), delegate + { + placeholder = tf.placeholder(dtype, shape: shape, name: name ?? this.name); + }); + // custom_gradient.copy_handle_data(value, placeholder) + return placeholder; + } + + void copy_handle_data() + { + + } + public override IntPtr ToPointer() => EagerTensorHandle?.DangerousGetHandle() ?? IntPtr.Zero; diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs index e2564ff5..b3e4039c 100644 --- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs @@ -138,30 +138,16 @@ namespace Tensorflow.Gradients return new Tensor[] { - gen_nn_ops.conv2d_backprop_input(new Conv2dParams - { - InputSizes = shape[0], - Filter = op.inputs[1], - OutBackProp = grads[0], - Dilations = dilations, - Strides = strides, - Padding = padding.ToString(), - ExplicitPaddings = explicit_paddings, - UseCudnnOnGpu = (bool)use_cudnn_on_gpu, - DataFormat = data_format.ToString(), - }), - gen_nn_ops.conv2d_backprop_filter(new Conv2dParams - { - Input = op.inputs[0], - FilterSizes = shape[1], - OutBackProp = grads[0], - Dilations = dilations, - Strides = strides, - Padding = padding.ToString(), - ExplicitPaddings = explicit_paddings, - UseCudnnOnGpu = (bool)use_cudnn_on_gpu, - DataFormat = data_format.ToString() - }) + gen_nn_ops.conv2d_backprop_input(shape[0], op.inputs[1], grads[0], + strides, padding, use_cudnn_on_gpu, explicit_paddings, + dilations: dilations, + data_format: data_format), + gen_nn_ops.conv2d_backprop_filter(op.inputs[0], shape[1], grads[0], + strides, padding, + dilations: dilations, + explicit_paddings: explicit_paddings, + use_cudnn_on_gpu: use_cudnn_on_gpu, + data_format: data_format) }; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs new file mode 100644 index 00000000..888082c7 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs @@ -0,0 +1,24 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class BatchNormalizationArgs : LayerArgs + { + public TensorShape Axis { get; set; } = -1; + public float Momentum { get; set; } = 0.99f; + public float Epsilon { get; set; } = 1e-3f; + public bool Center { get; set; } = true; + public bool Scale { get; set; } = true; + public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; + public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; + public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; + public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; + public IRegularizer BetaRegularizer { get; set; } + public IRegularizer GammaRegularizer { get; set; } + public bool Renorm { get; set; } + public float RenormMomentum { get; set; } = 0.99f; + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs index be0ef74e..838954fc 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs @@ -4,7 +4,7 @@ using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class Conv2DArgs : ConvArgs + public class Conv2DArgs : ConvolutionalArgs { } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs similarity index 82% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs index b96a6ba7..00d1706b 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs @@ -5,10 +5,11 @@ using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { - public class ConvArgs : LayerArgs + public class ConvolutionalArgs : LayerArgs { public int Rank { get; set; } = 2; public int Filters { get; set; } + public int NumSpatialDims { get; set; } = Unknown; public TensorShape KernelSize { get; set; } = 5; /// @@ -24,8 +25,8 @@ namespace Tensorflow.Keras.ArgsDefinition public bool UseBias { get; set; } public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; - public IInitializer KernelRegularizer { get; set; } - public IInitializer BiasRegularizer { get; set; } + public IRegularizer KernelRegularizer { get; set; } + public IRegularizer BiasRegularizer { get; set; } public Action KernelConstraint { get; set; } public Action BiasConstraint { get; set; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs index aaf89a0c..182e616e 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.ArgsDefinition /// /// Regularizer function applied to the output of the layer(its "activation"). /// - public IInitializer ActivityRegularizer { get; set; } + public IRegularizer ActivityRegularizer { get; set; } public bool Autocast { get; set; } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs index 70238405..b1f3569c 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs @@ -6,7 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition { public class ModelArgs : LayerArgs { - public Tensor[] Inputs { get; set; } - public Tensor[] Outputs { get; set; } + public Tensors Inputs { get; set; } + public Tensors Outputs { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/BackendImpl.cs b/src/TensorFlowNET.Core/Keras/BackendImpl.cs index 84b244a2..ef9b3d97 100644 --- a/src/TensorFlowNET.Core/Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Core/Keras/BackendImpl.cs @@ -42,7 +42,7 @@ namespace Tensorflow.Keras /// for various layer names in each graph. /// Allows to give unique autogenerated names to layers, in a graph-specific way. /// - public Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); + public Dictionary> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); public Dictionary _GRAPH_VARIABLES = new Dictionary(); public Dictionary _GRAPH_TF_OPTIMIZERS = new Dictionary(); @@ -80,25 +80,19 @@ namespace Tensorflow.Keras return ops.get_default_graph(); } - public int get_uid(string prefix, string @namespace = "") + public int get_uid(string prefix) { var graph = tf.get_default_graph(); if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) - PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); - PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1; + PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict()); + if (!PER_GRAPH_LAYER_NAME_UIDS[graph].ContainsKey(prefix)) + PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] = 0; + PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] += 1; - return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)]; + return PER_GRAPH_LAYER_NAME_UIDS[graph][prefix]; } - public int get_uid((string, string) name) - { - var graph = tf.get_default_graph(); - if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) - PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); - PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1; - return PER_GRAPH_LAYER_NAME_UIDS[graph][name]; - } - public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); + public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary>(); public void clear_session() { ops.reset_default_graph(); diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs new file mode 100644 index 00000000..fe2f0728 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs @@ -0,0 +1,67 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine +{ + /// + /// A `Functional` model is a `Model` defined as a directed graph of layers. + /// + public class Functional : Model + { + TensorShape _build_input_shape; + bool _compute_output_and_mask_jointly; + bool _expects_training_arg; + bool _expects_mask_arg; + bool _autocast; + List _output_layers; + List _input_layers; + List _input_coordinates; + List _output_coordinates; + + public Functional(Tensors inputs, Tensors outputs) + : base(new ModelArgs + { + Inputs = inputs, + Outputs = outputs + }) + { + _input_layers = new List(); + _output_layers = new List(); + _input_coordinates = new List(); + _output_coordinates = new List(); + _init_graph_network(inputs, outputs); + } + + void _init_graph_network(Tensors inputs, Tensors outputs) + { + _is_graph_network = true; + this.inputs = inputs; + this.outputs = outputs; + built = true; + _build_input_shape = inputs.shape; + _compute_output_and_mask_jointly = true; + _expects_training_arg = true; + _expects_mask_arg = true; + // A graph network does not autocast inputs, as its layers will cast them instead. + _autocast = false; + + // Build self._output_layers: + foreach(var x in outputs) + { + var (layer, node_index, tensor_index) = x.KerasHistory; + _output_layers.append(layer); + _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); + } + + // Build self._input_layers: + foreach(var x in inputs) + { + var (layer, node_index, tensor_index) = x.KerasHistory; + _input_layers.append(layer); + _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index 2041fe7d..cae054ce 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using System.Linq; namespace Tensorflow.Keras.Engine { @@ -27,6 +28,7 @@ namespace Tensorflow.Keras.Engine public int? min_ndim; Dictionary axes; TensorShape shape; + public int[] AllAxisDim; public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, int? ndim = null, @@ -42,6 +44,12 @@ namespace Tensorflow.Keras.Engine this.shape = shape; if (ndim == null && shape != null) this.ndim = shape.ndim; + + if(axes != null) + AllAxisDim = axes.Select(x => x.Value).ToArray(); } + + public override string ToString() + => $"min_ndim={min_ndim}, , axes={axes.Count}"; } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs index 832124e4..dd32f473 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs @@ -20,10 +20,14 @@ namespace Tensorflow.Keras.Engine this.tensor_index = tensor_index; } + public void Deconstruct(out Layer layer, out int node_index, out int tensor_index) + { + layer = this.layer; + node_index = this.node_index; + tensor_index = this.tensor_index; + } + public static implicit operator Layer(KerasHistory history) => history.layer; - - public static implicit operator (Layer, int, int)(KerasHistory history) - => (history.layer, history.node_index, history.tensor_index); } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index 8c943235..b9df4ce7 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -72,10 +72,10 @@ namespace Tensorflow.Keras.Engine protected List nonTrainableWeights; public List non_trainable_variables => nonTrainableWeights; - string name; + protected string name; + protected string base_name; public string Name => name; - - protected string baseName; + protected bool computePreviousMask; protected List updates; public TensorShape BatchInputShape => args.BatchInputShape; @@ -98,9 +98,9 @@ namespace Tensorflow.Keras.Engine // Indicates whether `build` needs to be called upon layer call, to create // the layer's weights. built = false; - this.SupportsMasking = false; + SupportsMasking = false; - _init_set_name(name); + _init_set_name(args.Name); trainableWeights = new List(); nonTrainableWeights = new List(); computePreviousMask = false; @@ -124,23 +124,25 @@ namespace Tensorflow.Keras.Engine /// public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) { - Tensors outputs = null; - callContext = callContext ?? new ThreadLocal() { Value = new CallContext() }; + if (_in_functional_construction_mode(inputs)) + return _functional_construction_call(inputs); + + Tensors outputs = null; + var eager = tf.executing_eagerly(); using var ctxManager = CallContext.enter(); string nameScope = ""; if (eager) - nameScope = name; + nameScope = Name; else nameScope = _name_scope(); - // using var graph = tf.keras.backend.get_graph().as_default(); if (!inputs.IsEagerTensor) tf.Context.graph_mode(); @@ -162,6 +164,46 @@ namespace Tensorflow.Keras.Engine return outputs; } + bool _in_functional_construction_mode(Tensors inputs) + { + return inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); + } + + Tensors _functional_construction_call(Tensors inputs) + { + bool mask_arg_passed_by_framework = false; + bool training_arg_passed_by_framework = false; + Tensor training_value = null; + if(training_value == null) + { + training_arg_passed_by_framework = true; + } + + Tensors outputs = null; + using var ctxManager = CallContext.enter(); + + // using var graph = tf.keras.backend.get_graph().as_default(); + + if (!inputs.IsEagerTensor) + tf.Context.graph_mode(); + + tf_with(ops.name_scope(_name_scope()), scope => + { + MaybeBuild(inputs); + + outputs = call(inputs); + + outputs = _set_connectivity_metadata_(inputs, outputs); + _handle_activity_regularization(inputs, outputs); + _set_mask_metadata(inputs, outputs, null); + }); + + if (!inputs.IsEagerTensor) + tf.Context.restore_mode(); + + return outputs; + } + private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) { /*var returnOutputs = new List(); @@ -219,8 +261,12 @@ namespace Tensorflow.Keras.Engine if (DType == TF_DataType.DtInvalid) args.DType = inputs.dtype; - var input_shapes = inputs.shape; - build(input_shapes); + tf.init_scope(); + + //tf.Context.eager_mode(); + build(inputs.shape); + //tf.Context.restore_mode(); + built = true; } @@ -229,10 +275,16 @@ namespace Tensorflow.Keras.Engine built = true; } + protected virtual void add_loss(Func losses) + { + + } + protected virtual IVariableV1 add_weight(string name, TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, + IRegularizer regularizer = null, bool? trainable = null, Func getter = null) { @@ -251,7 +303,7 @@ namespace Tensorflow.Keras.Engine else if (dtype.is_integer()) initializer = tf.zeros_initializer; else - throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}"); + throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); } var args = new VariableArgs @@ -266,6 +318,12 @@ namespace Tensorflow.Keras.Engine }; var variable = _add_variable_with_custom_getter(args); + if(regularizer != null) + { + var name_in_scope = variable.Name.Split(':')[0]; + _handle_weight_regularization(name_in_scope, variable, regularizer); + } + //backend.track_variable(variable); if (trainable == true) trainableWeights.Add(variable); @@ -275,6 +333,20 @@ namespace Tensorflow.Keras.Engine return variable; } + /// + /// Create lambdas which compute regularization losses. + /// + /// + /// + /// + void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer) + { + add_loss(() => regularizer.Apply(new RegularizerArgs + { + + })); + } + protected virtual void add_update(Tensor[] updates, bool inputs = false) { var updates_op = updates.Select(x => x.op).ToArray(); @@ -284,17 +356,13 @@ namespace Tensorflow.Keras.Engine // Determine layer name (non-unique). protected virtual void _init_set_name(string name, bool zero_based = true) { - var base_name = name; + base_name = name; this.name = name; if (name == null) - (this.name, baseName) = _make_unique_name(); - } - - protected virtual (string, string) _make_unique_name() - { - string base_name = generic_utils.to_snake_case(this.GetType().Name); - string name = base_layer_utils.unique_layer_name(base_name); - return (name, base_name); + { + base_name = generic_utils.to_snake_case(this.GetType().Name); + this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based); + } } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index b5e2b0c8..c816e85d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -23,15 +23,14 @@ namespace Tensorflow.Keras.Engine string loss; IOptimizer optimizer; IVariableV1 _steps_per_execution; + protected bool _is_graph_network; + protected Tensors inputs; + protected Tensors outputs; public Model(ModelArgs args) : base(args) { - // Build _output_layers - /*foreach(var x in args.Outputs) - { - var layer = x.KerasHistory; - }*/ + } public void compile(string optimizerName, string lossName) diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs index 603dd2cf..5d08e8e8 100644 --- a/src/TensorFlowNET.Core/Keras/KerasApi.cs +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -16,6 +16,7 @@ namespace Tensorflow { public KerasDataset datasets { get; } = new KerasDataset(); public Initializers initializers { get; } = new Initializers(); + public Regularizers regularizers { get; } = new Regularizers(); public LayersApi layers { get; } = new LayersApi(); public LossesApi losses { get; } = new LossesApi(); public Activations activations { get; } = new Activations(); @@ -36,12 +37,8 @@ namespace Tensorflow /// /// /// - public Model Model(Tensor input, Tensor output) - => new Model(new ModelArgs - { - Inputs = new[] { input }, - Outputs = new[] { output } - }); + public Functional Model(Tensors inputs, Tensors outputs) + => new Functional(inputs, outputs); /// /// Instantiate a Keras tensor. diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index d7664493..3d6287cb 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -15,73 +15,41 @@ ******************************************************************************/ using System; +using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { - public class BatchNormalization : Tensorflow.Layers.Layer + public class BatchNormalization : Layer { -#pragma warning disable CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used - private bool _USE_V2_BEHAVIOR = true; -#pragma warning restore CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used - private float momentum; - private float epsilon; - private bool center; - private bool scale; - private bool renorm; - private bool fused; -#pragma warning disable CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used - private bool _bessels_correction_test_only; -#pragma warning restore CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used - private int[] axis; - private string _data_format; - private IInitializer beta_initializer; - private IInitializer gamma_initializer; - private IInitializer moving_mean_initializer; - private IInitializer moving_variance_initializer; - private IVariableV1 gamma; - private IVariableV1 beta; - private RefVariable moving_mean; - private RefVariable moving_variance; - - public BatchNormalization(int axis = -1, - float momentum = 0.99f, - float epsilon = 0.001f, - bool center = true, - bool scale = true, - IInitializer beta_initializer = null, - IInitializer gamma_initializer = null, - IInitializer moving_mean_initializer = null, - IInitializer moving_variance_initializer = null, - bool renorm = false, - float renorm_momentum = 0.99f, - bool trainable = true, - string name = null) : base(trainable: trainable, - name: name) + BatchNormalizationArgs args; + + float momentum => args.Momentum; + float epsilon => args.Epsilon; + bool center => args.Center; + bool scale => args.Scale; + bool renorm => args.Renorm; + bool fused; + int[] axis; + string _data_format; + IInitializer beta_initializer => args.BetaInitializer; + IInitializer gamma_initializer => args.GammaInitializer; + IInitializer moving_mean_initializer; + IInitializer moving_variance_initializer; + IRegularizer gamma_regularizer => args.GammaRegularizer; + IVariableV1 gamma; + IVariableV1 beta; + IVariableV1 moving_mean; + IVariableV1 moving_variance; + + public BatchNormalization(BatchNormalizationArgs args) : base(args) { - this.axis = new int[] { axis }; - this.momentum = momentum; - this.epsilon = epsilon; - this.center = center; - this.scale = scale; - if (beta_initializer == null) - beta_initializer = tf.zeros_initializer; - if (gamma_initializer == null) - gamma_initializer = tf.ones_initializer; - if (moving_mean_initializer == null) - moving_mean_initializer = tf.zeros_initializer; - if (moving_variance_initializer == null) - moving_variance_initializer = tf.ones_initializer; - this.beta_initializer = beta_initializer; - this.gamma_initializer = gamma_initializer; - this.moving_mean_initializer = moving_mean_initializer; - this.moving_variance_initializer = moving_variance_initializer; - this.renorm = renorm; - this.fused = true; - this.SupportsMasking = true; - this._bessels_correction_test_only = true; + this.args = args; + axis = args.Axis.dims; } protected override void build(TensorShape input_shape) @@ -91,12 +59,25 @@ namespace Tensorflow.Keras.Layers if (x < 0) axis[idx] = ndims + x; + fused = ndims == 4; + if (fused) - if (Enumerable.SequenceEqual(axis, new int[] { 3 })) + { + if (Enumerable.SequenceEqual(axis, new int[] { 1 })) + _data_format = "NCHW"; + else if (Enumerable.SequenceEqual(axis, new int[] { 3 })) _data_format = "NHWC"; + else + throw new ValueError($"Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]"); + } + + var axis_to_dim = new Dictionary(); + foreach(var x in axis) + axis_to_dim[x] = input_shape[x]; + inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; - var param_shape = new int[] { input_shape.dims[axis[0]] }; + var param_shape = inputSpec.AllAxisDim; if (scale) gamma = add_weight("gamma", @@ -116,26 +97,17 @@ namespace Tensorflow.Keras.Layers else throw new NotImplementedException("add_weight beta"); - if(_scope != null) - { - - } - - moving_mean = (RefVariable)add_weight("moving_mean", + moving_mean = add_weight("moving_mean", param_shape, dtype: param_dtype, initializer: moving_mean_initializer, - synchronization: VariableSynchronization.OnRead, - trainable: false, - aggregation: VariableAggregation.Mean); + trainable: false); - moving_variance = (RefVariable)add_weight("moving_variance", + moving_variance = add_weight("moving_variance", shape: param_shape, dtype: param_dtype, initializer: moving_variance_initializer, - synchronization: VariableSynchronization.OnRead, - trainable: false, - aggregation: VariableAggregation.Mean); + trainable: false); if (renorm) throw new NotImplementedException("build when renorm is true"); @@ -178,8 +150,8 @@ namespace Tensorflow.Keras.Layers inputs, gamma, beta, - mean: moving_mean, - variance: moving_variance, + mean: moving_mean.AsTensor(), + variance: moving_variance.AsTensor(), epsilon: epsilon, is_training: false, data_format: _data_format); @@ -202,8 +174,8 @@ namespace Tensorflow.Keras.Layers if(training_value == null) { - var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor); - var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor); + var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor); + var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor); add_update(new Tensor[] { mean_update }, inputs: true); add_update(new Tensor[] { variance_update }, inputs: true); } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs index 9fe38ad2..371d6cfd 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs @@ -19,12 +19,11 @@ using Tensorflow.Operations.Activation; namespace Tensorflow.Keras.Layers { - public class Conv2D : Conv + public class Conv2D : Convolutional { - public Conv2D(Conv2DArgs args) - : base(args) + public Conv2D(Conv2DArgs args) : base(args) { - + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs similarity index 85% rename from src/TensorFlowNET.Core/Keras/Layers/Conv.cs rename to src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs index b26f5465..43739c7e 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs @@ -20,13 +20,13 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; using Tensorflow.Operations; -using Tensorflow.Operations.Activation; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { - public class Conv : Layer + public class Convolutional : Layer { - ConvArgs args; + ConvolutionalArgs args; protected int rank => args.Rank; protected int filters => args.Filters; protected TensorShape kernel_size => args.KernelSize; @@ -37,13 +37,14 @@ namespace Tensorflow.Keras.Layers protected Activation activation => args.Activation; protected bool use_bias => args.UseBias; protected IInitializer kernel_initializer => args.KernelInitializer; + protected IRegularizer kernel_regularizer => args.KernelRegularizer; protected IInitializer bias_initializer => args.BiasInitializer; protected IVariableV1 kernel; protected IVariableV1 bias; - protected Convolution _convolution_op; - string _tf_data_format; + ConvolutionInternal _convolution_op; + protected string _tf_data_format; - public Conv(ConvArgs args) : base(args) + public Convolutional(ConvolutionalArgs args) : base(args) { this.args = args; args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size"); @@ -65,6 +66,7 @@ namespace Tensorflow.Keras.Layers kernel = add_weight(name: "kernel", shape: kernel_shape, initializer: kernel_initializer, + regularizer: kernel_regularizer, trainable: true, dtype: DType); if (use_bias) @@ -76,7 +78,7 @@ namespace Tensorflow.Keras.Layers var axes = new Dictionary(); axes.Add(-1, input_channel); - inputSpec = new InputSpec(ndim: rank + 2, axes: axes); + inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes); string tf_padding; if (padding == "causal") @@ -84,20 +86,21 @@ namespace Tensorflow.Keras.Layers else tf_padding = padding.ToUpper(); - - _convolution_op = nn_ops.Convolution(input_shape, - kernel.shape, - tf_padding, + string tf_op_name = GetType().Name; + + + _convolution_op = nn_ops.convolution_internal(tf_padding, strides, dilation_rate, - data_format: _tf_data_format); + data_format: _tf_data_format, + name: tf_op_name); built = true; } protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) { - var outputs = _convolution_op.__call__(inputs, kernel); + var outputs = _convolution_op.Apply(inputs, kernel); if (use_bias) { if (data_format == "channels_first") diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs index 7d31bd40..8cdaf101 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs @@ -47,10 +47,10 @@ namespace Tensorflow.Keras.Layers } // moved to base class - if (string.IsNullOrEmpty(Name)) + if (string.IsNullOrEmpty(args.Name)) { var prefix = "input"; - args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix); + name = prefix + '_' + tf.keras.backend.get_uid(prefix); } if(args.DType == TF_DataType.DtInvalid) @@ -91,7 +91,6 @@ namespace Tensorflow.Keras.Layers // input_tensor._keras_mask = None new Node(this, new NodeArgs { - InputTensors = args.InputTensor, Outputs = args.InputTensor }); diff --git a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs index fc0b209f..51c1056a 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs @@ -11,15 +11,35 @@ namespace Tensorflow.Keras.Layers { public Conv2D Conv2D(int filters, TensorShape kernel_size = null, + TensorShape strides = null, string padding = "valid", - string activation = "relu") - => new Conv2D(new Conv2DArgs - { - Filters = filters, - KernelSize = kernel_size, - Padding = padding, - Activation = GetActivationByName(activation) - }); + string data_format = null, + TensorShape dilation_rate = null, + int groups = 1, + string activation = null, + bool use_bias = true, + IInitializer kernel_initializer = null, + IInitializer bias_initializer = null, + IRegularizer kernel_regularizer = null, + IRegularizer bias_regularizer = null, + IRegularizer activity_regularizer = null) + => new Conv2D(new Conv2DArgs + { + Rank = 2, + Filters = filters, + KernelSize = kernel_size, + Strides = strides == null ? (1, 1) : strides, + Padding = padding, + DataFormat = data_format, + DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, + Groups = groups, + KernelRegularizer = kernel_regularizer, + KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer, + BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, + BiasRegularizer = bias_regularizer, + ActivityRegularizer = activity_regularizer, + Activation = GetActivationByName(activation) + }); public Dense Dense(int units, @@ -65,6 +85,30 @@ namespace Tensorflow.Keras.Layers DataFormat = data_format }); + /// + /// `Input()` is used to instantiate a Keras tensor. + /// + /// A shape tuple not including the batch size. + /// + /// + /// + /// + public Tensors Input(TensorShape shape, + string name = null, + bool sparse = false, + bool ragged = false) + { + var input_layer = new InputLayer(new InputLayerArgs + { + InputShape = shape, + Name = name, + Sparse = sparse, + Ragged = ragged + }); + + return input_layer.InboundNodes[0].Outputs; + } + public MaxPooling2D MaxPooling2D(TensorShape pool_size = null, TensorShape strides = null, string padding = "valid") diff --git a/src/TensorFlowNET.Core/Keras/Regularizers.cs b/src/TensorFlowNET.Core/Keras/Regularizers.cs new file mode 100644 index 00000000..1102b62b --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Regularizers.cs @@ -0,0 +1,12 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras +{ + public class Regularizers + { + public IRegularizer l2(float l2 = 0.01f) + => new L2(l2); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs new file mode 100644 index 00000000..a54a81c7 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras +{ + public interface IRegularizer + { + Tensor Apply(RegularizerArgs args); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/L2.cs b/src/TensorFlowNET.Core/Keras/Regularizers/L2.cs new file mode 100644 index 00000000..c0fa7078 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Regularizers/L2.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras +{ + public class L2 : IRegularizer + { + float l2; + + public L2(float l2 = 0.01f) + { + this.l2 = l2; + } + + public Tensor Apply(RegularizerArgs args) + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs new file mode 100644 index 00000000..18bf87a5 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras +{ + public class RegularizerArgs + { + } +} diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs index de9f479b..c49618cf 100644 --- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs @@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Utils /// /// /// - public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null, - string[] avoid_names = null, string @namespace = "", bool zero_based = false) + public static string unique_layer_name(string name, Dictionary name_uid_map = null, + string[] avoid_names = null, bool zero_based = false) { if (name_uid_map == null) name_uid_map = get_default_graph_uid_map(); @@ -66,41 +66,40 @@ namespace Tensorflow.Keras.Utils string proposed_name = null; while (proposed_name == null || avoid_names.Contains(proposed_name)) { - var name_key = (@namespace, name); - if (!name_uid_map.ContainsKey(name_key)) - name_uid_map[name_key] = 0; + if (!name_uid_map.ContainsKey(name)) + name_uid_map[name] = 0; if (zero_based) { - int number = name_uid_map[name_key]; + int number = name_uid_map[name]; if (number > 0) proposed_name = $"{name}_{number}"; else proposed_name = name; - name_uid_map[name_key] += 1; + name_uid_map[name] += 1; } else { - name_uid_map[name_key] += 1; - proposed_name = $"{name}_{name_uid_map[name_key]}"; + name_uid_map[name] += 1; + proposed_name = $"{name}_{name_uid_map[name]}"; } } return proposed_name; } - public static Dictionary<(string, string), int> get_default_graph_uid_map() + public static Dictionary get_default_graph_uid_map() { var graph = ops.get_default_graph(); - Dictionary<(string, string), int> name_uid_map = null; + Dictionary name_uid_map = null; if (tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) { name_uid_map = tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; } else { - name_uid_map = new Dictionary<(string, string), int>(); + name_uid_map = new Dictionary(); tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; } diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index e07677e5..688e8266 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -183,8 +183,6 @@ namespace Tensorflow.Layers }); } - - protected override string _name_scope() { return _current_scope.original_name_scope; @@ -202,7 +200,7 @@ namespace Tensorflow.Layers } else { - tf_with(tf.variable_scope(scope, default_name: baseName), captured_scope => + tf_with(tf.variable_scope(scope, default_name: base_name), captured_scope => { // convert variable_scope to VariableScope _scope = captured_scope; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index 147dccde..13635860 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -41,8 +41,8 @@ namespace Tensorflow.Operations.Initializers public Tensor Apply(InitializerArgs args) { if (args.DType == TF_DataType.DtInvalid) - args.DType = this.dtype; - return random_ops.random_normal(args.Shape, mean, stddev, dtype, seed: seed); + args.DType = dtype; + return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed); } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs deleted file mode 100644 index be4aca3c..00000000 --- a/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs +++ /dev/null @@ -1,84 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System.Linq; - -namespace Tensorflow.Operations -{ - public class Convolution - { - public TensorShape input_shape; - public TensorShape filter_shape; - public string data_format; - public int[] strides; - public string name; - public _WithSpaceToBatch conv_op; - - public Convolution(TensorShape input_shape, - TensorShape filter_shape, - string padding, - int[] strides, - int[] dilation_rate, - string name = null, - string data_format = null) - { - var num_total_dims = filter_shape.ndim; - var num_spatial_dims = num_total_dims - 2; - int input_channels_dim; - int[] spatial_dims; - if (string.IsNullOrEmpty(data_format) || !data_format.StartsWith("NC")) - { - input_channels_dim = input_shape.dims[num_spatial_dims + 1]; - spatial_dims = Enumerable.Range(1, num_spatial_dims).ToArray(); - } - else - { - input_channels_dim = input_shape.dims[1]; - spatial_dims = Enumerable.Range(2, num_spatial_dims).ToArray(); - } - - this.input_shape = input_shape; - this.filter_shape = filter_shape; - this.data_format = data_format; - this.strides = strides; - this.name = name; - - conv_op = new _WithSpaceToBatch( - input_shape, - dilation_rate: dilation_rate, - padding: padding, - build_op: _build_op, - filter_shape: filter_shape, - spatial_dims: spatial_dims, - data_format: data_format); - } - - public _NonAtrousConvolution _build_op(int _, string padding) - { - return new _NonAtrousConvolution(input_shape, - filter_shape: filter_shape, - padding: padding, - data_format: data_format, - strides: strides, - name: name); - } - - public Tensor __call__(Tensor inp, IVariableV1 filter) - { - return conv_op.__call__(inp, filter); - } - } -} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs new file mode 100644 index 00000000..75b44af3 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs @@ -0,0 +1,100 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Xml; +using Tensorflow.Keras.ArgsDefinition; +using static Tensorflow.Binding; + +namespace Tensorflow.Operations +{ + internal class ConvolutionInternal + { + ConvolutionalArgs args; + + string data_format => args.DataFormat; + string name; + string padding => args.Padding; + + public ConvolutionInternal(ConvolutionalArgs args) + { + this.args = args; + name = args.Name; + } + + public Tensor Apply(Tensors input, IVariableV1 filters) + { + var filters_rank = filters.shape.rank; + var inputs_rank = input.shape.rank; + var num_spatial_dims = args.NumSpatialDims; + if (num_spatial_dims == Unknown) + num_spatial_dims = filters_rank - 2; + + // Channel dimension. + var num_batch_dims = inputs_rank - num_spatial_dims - 1; + if (!new[] { 1, 2, 3 }.Contains(num_spatial_dims)) + throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " + + $"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}."); + + var channel_index = num_batch_dims + num_spatial_dims; + var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index); + var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index); + + Tensor result = null; + tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope => + { + name = scope; + if (num_spatial_dims == 2) + result = gen_nn_ops.conv2d(new Conv2dParams + { + Input = input, + Filter = filters.AsTensor(), + Strides = strides, + Padding = padding, + DataFormat = data_format, + Dilations = dilations, + Name = name + }); + else + throw new NotImplementedException(""); + }); + + return result; + } + + int[] _get_sequence(int[] value, int n, int channel_index) + { + var seq = new List(); + + if (channel_index == 1) + { + seq.Add(1); + seq.Add(1); + seq.AddRange(value); + } + else + { + seq.Add(1); + seq.AddRange(value); + seq.Add(1); + } + + return seq.ToArray(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs deleted file mode 100644 index f947cdbc..00000000 --- a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs +++ /dev/null @@ -1,83 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Linq; - -namespace Tensorflow.Operations -{ - public class _NonAtrousConvolution - { - public string padding; - public string name; - public int[] strides; - public string data_format; - private Func conv_op; - - public _NonAtrousConvolution(TensorShape input_shape, - TensorShape filter_shape, - string padding, - string data_format, - int[] strides, - string name) - { - this.padding = padding; - this.name = name; - var conv_dims = input_shape.ndim - 2; - if (conv_dims == 1) - { - throw new NotImplementedException("_NonAtrousConvolution conv_dims 1"); - } - else if (conv_dims == 2) - { - var list = strides.ToList(); - - if (string.IsNullOrEmpty(data_format) || data_format == "NHWC") - { - data_format = "NHWC"; - list.Insert(0, 1); - list.Add(1); - } - else if (data_format == "NCHW") - list.InsertRange(0, new int[] { 1, 1 }); - else - throw new ValueError("data_format must be \"NHWC\" or \"NCHW\"."); - - strides = list.ToArray(); - this.strides = strides; - this.data_format = data_format; - conv_op = gen_nn_ops.conv2d; - } - else if (conv_dims == 3) - { - throw new NotImplementedException("_NonAtrousConvolution conv_dims 3"); - } - } - - public Tensor __call__(Tensor inp, IVariableV1 filter) - { - return conv_op(new Conv2dParams - { - Input = inp, - Filter = filter.AsTensor(), - Strides = strides, - Padding = padding, - DataFormat = data_format, - Name = name - }); - } - } -} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs deleted file mode 100644 index 8ae4ee36..00000000 --- a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs +++ /dev/null @@ -1,76 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System; -using System.Linq; - -namespace Tensorflow.Operations -{ - public class _WithSpaceToBatch - { - private _NonAtrousConvolution call; - - public _WithSpaceToBatch(TensorShape input_shape, - int[] dilation_rate, - string padding, - Func build_op, - TensorShape filter_shape = null, - int[] spatial_dims = null, - string data_format = null) - { - var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate"); - var rate_shape = dilation_rate_tensor.TensorShape; - var num_spatial_dims = rate_shape.dims[0]; -#pragma warning disable CS0219 // Variable is assigned but its value is never used - int starting_spatial_dim = -1; -#pragma warning restore CS0219 // Variable is assigned but its value is never used - if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) - starting_spatial_dim = 2; - else - starting_spatial_dim = 1; - - if (spatial_dims == null) - throw new NotImplementedException("_WithSpaceToBatch spatial_dims"); - - var orig_spatial_dims = spatial_dims; - spatial_dims = spatial_dims.OrderBy(x => x).ToArray(); - if (!Enumerable.SequenceEqual(spatial_dims, orig_spatial_dims) || spatial_dims.Any(x => x < 1)) - throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers"); - - int expected_input_rank = -1; - if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) - expected_input_rank = spatial_dims.Last(); - else - expected_input_rank = spatial_dims.Last() + 1; - - var const_rate = tensor_util.constant_value(dilation_rate_tensor); - var rate_or_const_rate = dilation_rate; - if(!(const_rate is null)) - { - if (const_rate.Data().Count(x => x == 1) == const_rate.size) - { - call = build_op(num_spatial_dims, padding); - return; - } - } - } - - public Tensor __call__(Tensor inp, IVariableV1 filter) - { - return call.__call__(inp, filter); - } - } -} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index b239cfd8..fb19ab4e 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -78,35 +78,45 @@ namespace Tensorflow.Operations /// /// /// - public static Tensor conv2d_backprop_filter(Conv2dParams parameters) + public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, + int[] strides, string padding, bool use_cudnn_on_gpu = true, + int[] explicit_paddings = null, + string data_format = "NHWC", + int[] dilations = null, + string name = null) { + if (explicit_paddings == null) + explicit_paddings = new int[0]; + if (dilations == null) + dilations = new int[] { 1, 1, 1, 1 }; + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Conv2DBackpropFilter", parameters.Name, + "Conv2DBackpropFilter", name, null, - parameters.Input, parameters.FilterSizes, parameters.OutBackProp, - "strides", parameters.Strides, - "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, - "padding", parameters.Padding, - "explicit_paddings", parameters.ExplicitPaddings, - "data_format", parameters.DataFormat, - "dilations", parameters.Dilations); + input, filter_sizes, out_backprop, + "strides", strides, + "use_cudnn_on_gpu", use_cudnn_on_gpu, + "padding", padding, + "explicit_paddings", explicit_paddings, + "data_format", data_format, + "dilations", dilations); return results[0]; } - var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new + var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, args: new { - input = parameters.Input, - filter_sizes = parameters.FilterSizes, - out_backprop = parameters.OutBackProp, - strides = parameters.Strides, - padding = parameters.Padding, - use_cudnn_on_gpu = parameters.UseCudnnOnGpu, - explicit_paddings = parameters.ExplicitPaddings, - data_format = parameters.DataFormat, - dilations = parameters.Dilations + input, + filter_sizes, + out_backprop, + strides, + padding, + use_cudnn_on_gpu, + explicit_paddings, + data_format, + dilations }); return _op.outputs[0]; @@ -117,35 +127,45 @@ namespace Tensorflow.Operations /// /// /// - public static Tensor conv2d_backprop_input(Conv2dParams parameters) + public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, + int[] strides, string padding, bool use_cudnn_on_gpu = true, + int[] explicit_paddings = null, + string data_format= "NHWC", + int[] dilations = null, + string name = null) { + if (explicit_paddings == null) + explicit_paddings = new int[0]; + if (dilations == null) + dilations = new int[] { 1, 1, 1, 1 }; + if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Conv2DBackpropInput", parameters.Name, + "Conv2DBackpropInput", name, null, - parameters.InputSizes, parameters.Filter, parameters.OutBackProp, - "strides", parameters.Strides, - "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, - "padding", parameters.Padding, - "explicit_paddings", parameters.ExplicitPaddings, - "data_format", parameters.DataFormat, - "dilations", parameters.Dilations); + input_sizes, filter, out_backprop, + "strides", strides, + "use_cudnn_on_gpu", use_cudnn_on_gpu, + "padding", padding, + "explicit_paddings", explicit_paddings, + "data_format", data_format, + "dilations", dilations); return results[0]; } - var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new + var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, args: new { - input_sizes = parameters.InputSizes, - filter = parameters.Filter, - out_backprop = parameters.OutBackProp, - strides = parameters.Strides, - padding = parameters.Padding, - use_cudnn_on_gpu = parameters.UseCudnnOnGpu, - explicit_paddings = parameters.ExplicitPaddings, - data_format = parameters.DataFormat, - dilations = parameters.Dilations + input_sizes, + filter, + out_backprop, + strides, + padding, + use_cudnn_on_gpu, + explicit_paddings, + data_format, + dilations }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs index f3442be8..a56a4cdc 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -33,11 +33,6 @@ namespace Tensorflow /// public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null) { - if (!seed.HasValue) - seed = 0; - if (!seed2.HasValue) - seed2 = 0; - if (tf.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, @@ -51,6 +46,11 @@ namespace Tensorflow return results[0]; } + if (!seed.HasValue) + seed = 0; + if (!seed2.HasValue) + seed2 = 0; + var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal", name: name, args: new { shape, dtype, seed, seed2 }); diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index 4c30c34e..8ded44f1 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -16,6 +16,7 @@ using System; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Operations; using static Tensorflow.Binding; @@ -23,19 +24,18 @@ namespace Tensorflow { public class nn_ops { - public static Convolution Convolution(TensorShape input_shape, - TensorShape filter_shape, - string padding, + internal static ConvolutionInternal convolution_internal(string padding, int[] strides, int[] dilation_rate, string name = null, - string data_format = null) => new Convolution(input_shape, - filter_shape, - padding, - strides, - dilation_rate, - name: name, - data_format: data_format); + string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs + { + Padding = padding, + Strides = strides, + DilationRate = dilation_rate, + DataFormat = data_format, + Name = name + }); /// /// Adds `bias` to `value`. diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 0a3ea47a..7d4f57d9 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -64,7 +64,7 @@ namespace Tensorflow /// The string name of this tensor.
/// Tensor.name is meaningless when eager execution is enabled. ///
- public string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}"; + public virtual string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}"; /// /// The index of this tensor in the outputs of its Operation. diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 2f130002..34c26bbb 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -132,7 +132,7 @@ namespace Tensorflow } } - public int this[int index] => dims[index]; + public int this[int index] => index < 0 ? dims[ndim + index] : dims[index]; /// /// Returns True iff `self` is fully defined in every dimension. diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 5fe0043e..fca60f88 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -8,7 +8,7 @@ using static Tensorflow.Binding; namespace Tensorflow { - public class BaseResourceVariable : DisposableObject, IVariableV1 + public class BaseResourceVariable : DisposableObject { protected string _name; public virtual string Name => _handle_name; @@ -92,7 +92,8 @@ namespace Tensorflow return assign_op; } - public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement; + public Tensor value() + => GraphElement ?? _read_variable_op(); protected Tensor _read_variable_op() { @@ -159,7 +160,15 @@ namespace Tensorflow { } - public Tensor AsTensor() - => tf.executing_eagerly() ? read_value() : GraphElement; + public Tensor AsTensor(bool as_ref = true) + { + if (!as_ref && GraphElement != null) + return GraphElement; + + if (as_ref) + return tf.executing_eagerly() ? read_value() : GraphElement; + else + return _read_variable_op(); + } } } diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs index 52549ecc..4367cf09 100644 --- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs +++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs @@ -49,6 +49,6 @@ namespace Tensorflow public TensorShape shape { get; } Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true); Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true); - Tensor AsTensor(); + Tensor AsTensor(bool as_ref = true); } } diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index cf9fe2f1..68df1e66 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -152,7 +152,7 @@ namespace Tensorflow if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); - tf_with(ops.init_scope2(), delegate + tf_with(ops.init_scope(), init_scope => { var values = init_from_fn ? new object[0] : new object[] { initial_value }; tf_with(ops.name_scope(name, "Variable", values), scope => @@ -222,7 +222,7 @@ namespace Tensorflow public Tensor value() => _snapshot; - public Tensor AsTensor() => _snapshot; + public Tensor AsTensor(bool as_ref = true) => _snapshot; public Tensor _as_graph_element() => _variable; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 3655a6db..40fb07bc 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -26,7 +26,7 @@ namespace Tensorflow /// /// Variable based on resource handles. /// - public partial class ResourceVariable : BaseResourceVariable + public partial class ResourceVariable : BaseResourceVariable, IVariableV1 { Tensor _cached_value; public string Device => handle.Device; @@ -90,7 +90,7 @@ namespace Tensorflow collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); _in_graph_mode = !tf.Context.executing_eagerly(); - tf_with(ops.init_scope2(), delegate + tf_with(ops.init_scope(), init_scope => { var values = init_from_fn ? new object[0] : new object[] { initial_value }; tf_with(ops.name_scope(name, "Variable", values), scope => diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index cf935ab3..fb74cb89 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -239,11 +239,8 @@ namespace Tensorflow /// A context manager that lifts ops out of control-flow scopes and function-building graphs. /// /// - public static void init_scope() + public static NameScope init_scope() { - if (tf.Context.executing_eagerly()) - return; - // Retrieve the active name scope: entering an `init_scope` preserves // the name scope of the current context. var default_graph = get_default_graph(); @@ -257,25 +254,11 @@ namespace Tensorflow tf_with(ops.control_dependencies(null), delegate { - var outer_graph = get_default_graph(); + // var outer_graph = get_default_graph(); // outer_device_stack = None }); - } - - public static ITensorFlowObject init_scope2() - { - // Retrieve the active name scope: entering an `init_scope` preserves - // the name scope of the current context. - var default_graph = get_default_graph(); - var scope = default_graph.get_name_scope(); - if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) - // Names that end with trailing slashes are treated by `name_scope` as - // absolute. - scope += "/"; - // inner_device_stack = default_graph._device_function_stack - // var outer_context = default_graph.as_default; - return ops.control_dependencies(null); + return ops.name_scope(scope); } private static int uid_number = -1; @@ -460,6 +443,8 @@ namespace Tensorflow { case NDArray nd: return constant_op.constant(nd, dtype: dtype, name: name); + case EagerTensor tensor: + return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name); case Tensor tensor: return tensor; case Tensor[] tensors: diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 97a24525..1fb2b43c 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -90,6 +90,7 @@ namespace Tensorflow return (scope_name, old_name); } + [DebuggerHidden] public void Dispose() { if (tf.Context.executing_eagerly()) diff --git a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs index 050151af..40d0e22b 100644 --- a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs @@ -28,7 +28,8 @@ namespace TensorFlowNET.UnitTest.Keras // Create a simple model. var inputs = keras.Input(shape: 32); - var outputs = keras.layers.Dense(1).Apply(inputs); + var dense_layer = keras.layers.Dense(1); + var outputs = dense_layer.Apply(inputs); var model = keras.Model(inputs, outputs); model.compile("adam", "mean_squared_error"); return model; diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs index 8c5420da..4ec4eb25 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs @@ -8,6 +8,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI [TestClass] public class BitwiseApiTest : TFNetApiTest { + [TestInitialize] + public void Init() + { + tf.enable_eager_execution(); + } [TestMethod] public void BitwiseAnd()