From b79d6bcafb86f87cf0de5db43a108adf113c5955 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 10 Oct 2020 08:47:08 -0500 Subject: [PATCH] Add tf.Context.RunInAutoMode to switch mode automatically. --- src/TensorFlowNET.Core/Contexts/Context.cs | 24 ++++ .../Keras/ArgsDefinition/ZeroPadding2DArgs.cs | 12 ++ src/TensorFlowNET.Core/Keras/BackendImpl.cs | 33 ++++++ .../Keras/Engine/BaseLayerUtils.cs | 47 ++++++++ .../Keras/Engine/Functional.cs | 12 +- .../Keras/Engine/KerasHistory.cs | 8 +- .../Keras/Engine/Layer.LoadWeights.cs | 18 +++ src/TensorFlowNET.Core/Keras/Engine/Layer.cs | 49 ++++++-- src/TensorFlowNET.Core/Keras/Engine/Node.cs | 2 +- .../Keras/Layers/BatchNormalization.cs | 31 +++-- .../Keras/Layers/LayersApi.cs | 15 ++- .../Keras/Layers/ZeroPadding2D.cs | 39 +++++++ src/TensorFlowNET.Core/Layers/Layer.cs | 4 +- .../Operations/NnOps/ConvolutionInternal.cs | 8 +- .../Operations/OpDefLibrary.cs | 2 +- .../Operations/gen_array_ops.cs | 110 ++++++++---------- .../Operations/gen_math_ops.cs | 69 ++++------- src/TensorFlowNET.Core/ops.cs | 15 ++- src/TensorFlowNET.Core/ops.name_scope.cs | 7 +- 19 files changed, 355 insertions(+), 150 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/ZeroPadding2DArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs create mode 100644 src/TensorFlowNET.Core/Keras/Engine/Layer.LoadWeights.cs create mode 100644 src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 16bb6e5b..d605b8a8 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Linq; using Tensorflow.Eager; namespace Tensorflow.Contexts @@ -87,6 +88,29 @@ namespace Tensorflow.Contexts context_switches.Pop(); } + public Tensor RunInAutoMode(Func graphAction, Func eagerAction, params Tensor[] tensors) + { + var shouldRunInEager = executing_eagerly() + && tensors.Count(x => x.IsEagerTensor) == tensors.Length; + + if (shouldRunInEager) + return eagerAction(); + else + { + if (executing_eagerly()) + { + graph_mode(); + var result = graphAction(); + restore_mode(); + return result; + } + else + { + return graphAction(); + } + } + } + public void Dispose() => Handle.Dispose(); } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ZeroPadding2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ZeroPadding2DArgs.cs new file mode 100644 index 00000000..5103839c --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ZeroPadding2DArgs.cs @@ -0,0 +1,12 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class ZeroPadding2DArgs : LayerArgs + { + public NDArray Padding { get; set; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/BackendImpl.cs b/src/TensorFlowNET.Core/Keras/BackendImpl.cs index ef9b3d97..00e1587c 100644 --- a/src/TensorFlowNET.Core/Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Core/Keras/BackendImpl.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using NumSharp; using System; using System.Collections.Generic; using static Tensorflow.Binding; @@ -121,6 +122,38 @@ namespace Tensorflow.Keras _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); } + /// + /// Pads the 2nd and 3rd dimensions of a 4D tensor. + /// + /// + /// + /// + /// + public Tensor spatial_2d_padding(Tensor x, NDArray padding = null, string data_format = null) + { + if (padding == null) + padding = new[,] { { 1, 1 }, { 1, 1 } }; + + NDArray pattern; + + if (data_format == "channels_first") + pattern = new int[,] + { + { 0, 0 }, + { 0, 0 }, + { padding[0][0], padding[0][1] }, + { padding[1][0], padding[1][1] } + }; + else + pattern = new int[,] + { + { 0, 0 }, + { padding[0][0], padding[0][1] }, + { padding[1][0], padding[1][1] }, + { 0, 0 } + }; + return array_ops.pad(x, pattern); + } public class _DummyEagerGraph { } diff --git a/src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs b/src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs new file mode 100644 index 00000000..dbaa1247 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/BaseLayerUtils.cs @@ -0,0 +1,47 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography.X509Certificates; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine +{ + public class BaseLayerUtils + { + public static Layer[] CreateKerasHistoryHelper(Tensors tensors) + { + var processed_ops = new List(); + var created_layers = new List(); + + foreach (var tensor in tensors) + { + if (tensor.KerasHistory != null) + continue; + + var op = tensor.op; + if (!processed_ops.Contains(op)) + { + var layer_inputs = new List(); + + foreach (var (i, op_input) in enumerate(op.inputs._inputs)) + { + if (uses_keras_history(op_input)) + layer_inputs.Add(op_input); + else + { + + } + } + } + } + + return created_layers.ToArray(); + } + + static bool uses_keras_history(Tensor op_input) + { + return Layer.KerasHistories.Any(x => x.tensor == op_input); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs index fe2f0728..adbf3073 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs @@ -1,5 +1,7 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Security.Cryptography.X509Certificates; using System.Text; using Tensorflow.Keras.ArgsDefinition; @@ -47,12 +49,15 @@ namespace Tensorflow.Keras.Engine // A graph network does not autocast inputs, as its layers will cast them instead. _autocast = false; + if (outputs.Any(x => x.KerasHistory == null)) + BaseLayerUtils.CreateKerasHistoryHelper(outputs); + // Build self._output_layers: - foreach(var x in outputs) + foreach (var x in outputs) { var (layer, node_index, tensor_index) = x.KerasHistory; _output_layers.append(layer); - _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); + _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); } // Build self._input_layers: @@ -60,8 +65,9 @@ namespace Tensorflow.Keras.Engine { var (layer, node_index, tensor_index) = x.KerasHistory; _input_layers.append(layer); - _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); + _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); } } + } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs index dd32f473..2d627768 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs @@ -12,12 +12,15 @@ namespace Tensorflow.Keras.Engine Layer layer; int node_index; int tensor_index; + public Tensor tensor; - public KerasHistory(Layer layer, int node_index, int tensor_index) + public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) { this.layer = layer; this.node_index = node_index; this.tensor_index = tensor_index; + this.tensor = tensor; + Console.WriteLine(tensor.name); } public void Deconstruct(out Layer layer, out int node_index, out int tensor_index) @@ -27,6 +30,9 @@ namespace Tensorflow.Keras.Engine tensor_index = this.tensor_index; } + public override string ToString() + => $"{layer.GetType().Name} {layer.Name} {tensor.name}"; + public static implicit operator Layer(KerasHistory history) => history.layer; } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.LoadWeights.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.LoadWeights.cs new file mode 100644 index 00000000..99ced6a3 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.LoadWeights.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public partial class Layer + { + /// + /// Loads all layer weights, either from a TensorFlow or an HDF5 weight file. + /// + /// + public void load_weights(string filepath) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index b9df4ce7..c3dfb665 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -56,6 +56,7 @@ namespace Tensorflow.Keras.Engine /// Provides information about which inputs are compatible with the layer. /// protected InputSpec inputSpec; + bool dynamic = true; public bool SupportsMasking { get; set; } protected List trainableWeights; public List trainable_variables @@ -88,6 +89,7 @@ namespace Tensorflow.Keras.Engine ThreadLocal callContext; public CallContext CallContext => callContext.Value; + public static List KerasHistories = new List(); public Layer(LayerArgs args) { @@ -129,6 +131,11 @@ namespace Tensorflow.Keras.Engine Value = new CallContext() }; + var history = inputs.Where(x => x.KerasHistory != null + && !KerasHistories.Contains(x.KerasHistory)) + .Select(x => x.KerasHistory); + KerasHistories.AddRange(history); + if (_in_functional_construction_mode(inputs)) return _functional_construction_call(inputs); @@ -166,7 +173,8 @@ namespace Tensorflow.Keras.Engine bool _in_functional_construction_mode(Tensors inputs) { - return inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); + return tf.Context.executing_eagerly() + && inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); } Tensors _functional_construction_call(Tensors inputs) @@ -191,6 +199,15 @@ namespace Tensorflow.Keras.Engine { MaybeBuild(inputs); + // Wrapping `call` function in autograph to allow for dynamic control + // flow and control dependencies in call. We are limiting this to + // subclassed layers as autograph is strictly needed only for + // subclassed layers and models. + // tf_convert will respect the value of autograph setting in the + // enclosing tf.function, if any. + if (!dynamic) + throw new NotImplementedException(""); + outputs = call(inputs); outputs = _set_connectivity_metadata_(inputs, outputs); @@ -243,6 +260,13 @@ namespace Tensorflow.Keras.Engine return null; } + /// + /// Subclass has to override this method. + /// + /// + /// + /// + /// protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) { throw new NotImplementedException(""); @@ -263,9 +287,9 @@ namespace Tensorflow.Keras.Engine tf.init_scope(); - //tf.Context.eager_mode(); + tf.Context.eager_mode(); build(inputs.shape); - //tf.Context.restore_mode(); + tf.Context.restore_mode(); built = true; } @@ -282,18 +306,14 @@ namespace Tensorflow.Keras.Engine protected virtual IVariableV1 add_weight(string name, TensorShape shape, - TF_DataType dtype = TF_DataType.DtInvalid, + TF_DataType dtype = TF_DataType.TF_FLOAT, IInitializer initializer = null, IRegularizer regularizer = null, - bool? trainable = null, + VariableSynchronization synchronization = VariableSynchronization.Auto, + VariableAggregation aggregation = VariableAggregation.None, + bool trainable = true, Func getter = null) { - if (dtype == TF_DataType.DtInvalid) - dtype = TF_DataType.TF_FLOAT; - - if (trainable == null) - trainable = true; - // Initialize variable when no initializer provided if (initializer == null) { @@ -306,6 +326,9 @@ namespace Tensorflow.Keras.Engine throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); } + if (synchronization == VariableSynchronization.OnRead) + trainable = false; + var args = new VariableArgs { Name = name, @@ -314,7 +337,9 @@ namespace Tensorflow.Keras.Engine Getter = getter ?? base_layer_utils.make_variable, Overwrite = true, Initializer = initializer, - Trainable = trainable.Value + Synchronization = synchronization, + Aggregation = aggregation, + Trainable = trainable }; var variable = _add_variable_with_custom_getter(args); diff --git a/src/TensorFlowNET.Core/Keras/Engine/Node.cs b/src/TensorFlowNET.Core/Keras/Engine/Node.cs index 5eef1195..923db038 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Node.cs @@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Engine // Set metadata on outputs. var node_index = layer.InboundNodes.Count - 1; foreach (var (i, tensor) in enumerate(Outputs)) - tensor.KerasHistory = new KerasHistory(layer, node_index, i); + tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs index 3d6287cb..c452d485 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs @@ -38,8 +38,8 @@ namespace Tensorflow.Keras.Layers string _data_format; IInitializer beta_initializer => args.BetaInitializer; IInitializer gamma_initializer => args.GammaInitializer; - IInitializer moving_mean_initializer; - IInitializer moving_variance_initializer; + IInitializer moving_mean_initializer => args.MovingMeanInitializer; + IInitializer moving_variance_initializer => args.MovingVarianceInitializer; IRegularizer gamma_regularizer => args.GammaRegularizer; IVariableV1 gamma; IVariableV1 beta; @@ -101,13 +101,17 @@ namespace Tensorflow.Keras.Layers param_shape, dtype: param_dtype, initializer: moving_mean_initializer, + synchronization: VariableSynchronization.OnRead, + aggregation: VariableAggregation.Mean, trainable: false); moving_variance = add_weight("moving_variance", - shape: param_shape, - dtype: param_dtype, - initializer: moving_variance_initializer, - trainable: false); + shape: param_shape, + dtype: param_dtype, + initializer: moving_variance_initializer, + synchronization: VariableSynchronization.OnRead, + aggregation: VariableAggregation.Mean, + trainable: false); if (renorm) throw new NotImplementedException("build when renorm is true"); @@ -131,6 +135,12 @@ namespace Tensorflow.Keras.Layers private Tensor _fused_batch_norm(Tensor inputs, Tensor training) { + TensorShape input_batch_size = null; + var use_fused_avg_updates = true; + float exponential_avg_factor = 0; + if (use_fused_avg_updates) + exponential_avg_factor = 1.0f - momentum; + var beta = this.beta; var gamma = this.gamma; @@ -146,17 +156,22 @@ namespace Tensorflow.Keras.Layers Func _fused_batch_norm_inference = () => { + var moving_mean_tensor = moving_mean.AsTensor(); + var moving_variance_tensor = moving_variance.AsTensor(); return tf.nn.fused_batch_norm( inputs, gamma, beta, - mean: moving_mean.AsTensor(), - variance: moving_variance.AsTensor(), + mean: moving_mean_tensor, + variance: moving_variance_tensor, epsilon: epsilon, is_training: false, data_format: _data_format); }; + if (use_fused_avg_updates && input_batch_size != null) + throw new NotImplementedException(""); + var results = tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); var (output, mean, variance) = (results[0], results[1], results[2]); var training_value = tf_utils.constant_value(training); diff --git a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs index 51c1056a..98c45e15 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; @@ -33,6 +34,7 @@ namespace Tensorflow.Keras.Layers DataFormat = data_format, DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, Groups = groups, + UseBias = use_bias, KernelRegularizer = kernel_regularizer, KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer, BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, @@ -129,6 +131,17 @@ namespace Tensorflow.Keras.Layers InputShape = input_shape }); + /// + /// Zero-padding layer for 2D input (e.g. picture). + /// + /// + /// + public ZeroPadding2D ZeroPadding2D(NDArray padding) + => new ZeroPadding2D(new ZeroPadding2DArgs + { + Padding = padding + }); + Activation GetActivationByName(string name) => name switch { diff --git a/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs b/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs new file mode 100644 index 00000000..7e6d06a8 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs @@ -0,0 +1,39 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Zero-padding layer for 2D input (e.g. picture). + /// + /// This layer can add rows and columns of zeros + /// at the top, bottom, left and right side of an image tensor. + /// + public class ZeroPadding2D : Layer + { + string data_format; + NDArray padding; + InputSpec input_spec; + + public ZeroPadding2D(ZeroPadding2DArgs args, string data_format = null) + : base(args) + { + this.data_format = conv_utils.normalize_data_format(data_format); + this.padding = args.Padding; + this.input_spec = new InputSpec(ndim: 4); + } + + protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false) + { + return tf.keras.backend.spatial_2d_padding(inputs, + padding: padding, + data_format: data_format); + } + } +} diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs index 688e8266..b7ec2ea1 100644 --- a/src/TensorFlowNET.Core/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Layers/Layer.cs @@ -127,7 +127,7 @@ namespace Tensorflow.Layers int[] shape, TF_DataType dtype = TF_DataType.DtInvalid, IInitializer initializer = null, - bool? trainable = null, + bool trainable = true, VariableSynchronization synchronization = VariableSynchronization.Auto, VariableAggregation aggregation = VariableAggregation.None) { @@ -137,8 +137,6 @@ namespace Tensorflow.Layers if (synchronization == VariableSynchronization.OnRead) trainable = false; - else if (!trainable.HasValue) - trainable = true; if (default_graph.building_function) { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs index 75b44af3..c25f7da3 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs @@ -56,20 +56,24 @@ namespace Tensorflow.Operations var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index); Tensor result = null; - tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope => + tf_with(ops.name_scope(name, default_name: null), scope => { name = scope; if (num_spatial_dims == 2) + { + var filters_tensor = filters.AsTensor(); + result = gen_nn_ops.conv2d(new Conv2dParams { Input = input, - Filter = filters.AsTensor(), + Filter = filters_tensor, Strides = strides, Padding = padding, DataFormat = data_format, Dilations = dilations, Name = name }); + } else throw new NotImplementedException(""); }); diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index e3ee5ef8..20cb81a2 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -263,7 +263,7 @@ namespace Tensorflow List types, List base_types, List input_types, - dynamic values) + object values) { var input_name = input_arg.Name; diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 6bce44c5..1452271b 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -73,6 +73,16 @@ namespace Tensorflow return _op.output; } + public static Tensor concat_v2(Tensor[] values, int axis, string name = null) + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("ConcatV2", name: name, + args: new { values, axis }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "ConcatV2", name, + null, + values, axis).FirstOrDefault(), + values); + private static Tensor concat_v2_eager_fallback(T1[] values, T2 axis, string name, Context ctx) { var _attr_N = len(values); @@ -293,20 +303,13 @@ namespace Tensorflow } public static Tensor reshape(Tensor tensor, T shape, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Reshape", name, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Reshape", name, null, - tensor, shape); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }); - return _op.output; - } + tensor, shape).FirstOrDefault(), + tensor); public static Tensor reshape(Tensor tensor, int[] shape, string name = null) { @@ -399,21 +402,15 @@ namespace Tensorflow } public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Shape", name, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Shape", name, + new { input, out_type }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Shape", name, null, input, - "out_type", out_type); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Shape", name, new { input, out_type }); - return _op.outputs[0]; - } + "out_type", out_type).FirstOrDefault(), + input); /// /// Returns shape of tensors. @@ -460,20 +457,13 @@ namespace Tensorflow } public static Tensor tile(Tensor input, T multiples, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tile", name, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Tile", name, null, - input, multiples); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }); - return _op.outputs[0]; - } + input, multiples).FirstOrDefault(), + input); public static Tensor transpose(T1 x, T2 perm, string name = null) { @@ -510,37 +500,29 @@ namespace Tensorflow int new_axis_mask = 0, int shrink_axis_mask = 0, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "StridedSlice", name, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("StridedSlice", name, new + { + input, + begin, + end, + strides, + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "StridedSlice", name, null, input, begin, end, strides, "begin_mask", begin_mask, "end_mask", end_mask, "ellipsis_mask", ellipsis_mask, "new_axis_mask", new_axis_mask, - "shrink_axis_mask", shrink_axis_mask); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("StridedSlice", name, new - { - input, - begin, - end, - strides, - begin_mask, - end_mask, - ellipsis_mask, - new_axis_mask, - shrink_axis_mask - }); - - return _op.outputs[0]; - } + "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), + input, begin, end, strides); public static Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides, int begin_mask = 0, diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 1c27cb1d..46106851 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -319,21 +319,13 @@ namespace Tensorflow /// Specifically, y = 1 / (1 + exp(-x)). /// public static Tensor sigmoid(Tensor x, string name = "Sigmoid") - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sigmoid", name, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Sigmoid", name, null, - x); - - return results[0]; - } - - var op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }); - - return op.output; - } + x).FirstOrDefault(), + x); /// /// Computes the gradient of the sigmoid of x wrt its input. @@ -668,11 +660,13 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `x`. public static Tensor exp(Tensor x, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Exp", name, + null, + x).FirstOrDefault(), + x); /// /// Computes natural logarithm of x element-wise. @@ -698,22 +692,14 @@ namespace Tensorflow } public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate= false, string name= null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, "Cast", name, null, x, - "DstT", DstT, "Truncate", Truncate); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); - - return _op.outputs[0]; - } + "DstT", DstT, "Truncate", Truncate).FirstOrDefault(), + x); public static Tensor neg(Tensor x, string name = null) { @@ -1151,20 +1137,13 @@ namespace Tensorflow /// /// public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Range", name, + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Range", name, null, - start, limit, delta); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }); - - return _op.outputs[0]; - } + start, limit, delta).FirstOrDefault(), + start, limit, delta); /// /// Rounds the values of a tensor to the nearest integer, element-wise. diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index fb74cb89..153039a8 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -225,14 +225,12 @@ namespace Tensorflow public static string name_from_scope_name(string name) { - if (name.EndsWith("/")) - { + if (name == null) + return null; + else if (name.EndsWith("/")) return name.Substring(0, name.Length - 1); - } else - { return name; - } } /// @@ -444,7 +442,12 @@ namespace Tensorflow case NDArray nd: return constant_op.constant(nd, dtype: dtype, name: name); case EagerTensor tensor: - return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name); + if (tf.executing_eagerly()) + return tensor; + else + return tensor.dtype == TF_DataType.TF_RESOURCE + ? tensor.AsPlaceholder(name: name) + : tensor.AsContatnt(name: name); case Tensor tensor: return tensor; case Tensor[] tensors: diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 1fb2b43c..5fb67030 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -48,13 +48,13 @@ namespace Tensorflow public void __enter__() { - _name = _name ?? _default_name; if (tf.Context.executing_eagerly()) { (scope_name, old_scope_name) = enter_eager_name_scope(tf.Context, _name); } else { + _name = _name ?? _default_name; Graph g = null; if (_values is List vList) @@ -72,7 +72,8 @@ namespace Tensorflow private (string, string) enter_eager_name_scope(Context ctx, string name) { - if (name == null) + return (null, null); + /*if (name == null) name = ""; var scope_name = name; @@ -87,7 +88,7 @@ namespace Tensorflow } ctx.ScopeName = scope_name; - return (scope_name, old_name); + return (scope_name, old_name);*/ } [DebuggerHidden]