From cc9f06277fd3ef2604a6583f657235749bc34f1e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 1 Nov 2020 16:07:56 -0600 Subject: [PATCH] Fix SparseCategoricalCrossentropy. --- .../Keras/Engine/DataAdapters/DataHandler.cs | 1 + .../Keras/Engine/DataAdapters/IDataAdapter.cs | 1 + .../DataAdapters/TensorLikeDataAdapter.cs | 7 ++ .../Keras/Engine/LossesContainer.cs | 42 +++++++++- .../Keras/Engine/MetricsContainer.cs | 13 +++ .../Keras/Engine/Model.Compile.cs | 6 +- .../Keras/Engine/Model.Fit.cs | 2 +- .../Keras/Engine/Model.Train.cs | 32 ++++++- src/TensorFlowNET.Core/Keras/Engine/Model.cs | 4 +- .../Keras/Losses/ILossFunc.cs | 2 + src/TensorFlowNET.Core/Keras/Losses/Loss.cs | 15 ++++ .../Keras/Losses/LossFunctionWrapper.cs | 6 +- .../Keras/Losses/ReductionV2.cs | 3 + .../Losses/SparseCategoricalCrossentropy.cs | 19 ++++- .../Keras/Metrics/Reduce.cs | 38 +++++++++ .../Keras/Optimizers/OptimizerV2.cs | 7 +- .../Keras/Utils/losses_utils.cs | 83 +++++++++++++++++++ .../Operations/Losses/losses_impl.py.cs | 3 + .../Operations/array_ops.cs | 3 + .../Operations/gen_math_ops.cs | 11 ++- 20 files changed, 276 insertions(+), 22 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Utils/losses_utils.cs diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs index f988cef4..f744473d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/DataHandler.cs @@ -13,6 +13,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters { DataHandlerArgs args; IDataAdapter _adapter; + public IDataAdapter DataAdapter => _adapter; IDatasetV2 _dataset; int _inferred_steps; int _current_step; diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs index 41253824..b39a9eaf 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/IDataAdapter.cs @@ -20,5 +20,6 @@ namespace Tensorflow.Keras.Engine.DataAdapters bool CanHandle(Tensor x, Tensor y = null); IDatasetV2 GetDataset(); int GetSize(); + (Tensor, Tensor) Expand1d(Tensor x, Tensor y); } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 9713694a..8797ba7c 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -89,5 +89,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters public int GetSize() => _size; + + public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) + { + if (y.TensorShape.ndim == 1) + y = array_ops.expand_dims(y, axis: -1); + return (x, y); + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs b/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs index 8596f8f4..4db7c0b3 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs @@ -1,9 +1,11 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Utils; namespace Tensorflow.Keras.Engine { @@ -12,6 +14,10 @@ namespace Tensorflow.Keras.Engine ILossFunc _user_losses; ILossFunc _losses; Mean _loss_metric; + bool _built; + Tensor[] _per_output_metrics; + List loss_values; + List loss_metric_values; public LossesContainer(ILossFunc losses, string[] output_names = null) : base(output_names) @@ -19,6 +25,8 @@ namespace Tensorflow.Keras.Engine _user_losses = losses; _losses = losses; _loss_metric = new Mean(name: "loss"); + loss_values = new List(); + loss_metric_values = new List(); _built = false; } @@ -27,14 +35,44 @@ namespace Tensorflow.Keras.Engine /// /// /// - public void Apply(Tensor y_true, Tensor y_pred) + public Tensor Call(Tensor y_true, Tensor y_pred) { + Build(y_pred); + var loss_value = _losses.Call(y_true, y_pred); + var loss_metric_value = loss_value; + var batch_dim = array_ops.shape(y_true)[0]; + /*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE + || _losses.Reduction == ReductionV2.AUTO) + loss_value = losses_utils.scale_loss_for_distribution(loss_value);*/ + + loss_values.append(loss_value); + loss_metric_values.append(loss_metric_value); + + if(loss_values.Count > 0) + { + var total_loss_metric_value = math_ops.add_n(loss_metric_values.ToArray()); + _loss_metric.update_state(total_loss_metric_value, batch_dim); + // loss_values = losses_utils.cast_losses_to_common_dtype(loss_values); + var total_loss = math_ops.add_n(loss_values.ToArray()); + return total_loss; + } + else + { + // Ok for a model to have no compiled loss. + return array_ops.zeros(new TensorShape()); + } } - public void Build() + public void Build(Tensor y_pred) { + _create_metrics(); + _built = true; + } + void _create_metrics() + { + // _per_output_metrics = _output_names.Select(x => null); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs index c494ec6f..5e8b4122 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs @@ -16,5 +16,18 @@ namespace Tensorflow.Keras.Engine _metrics = metrics; _built = false; } + + public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) + { + if (!_built) + Build(); + + _built = true; + } + + void Build() + { + + } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs index d9e4a0e2..2b037103 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Compile.cs @@ -10,6 +10,8 @@ namespace Tensorflow.Keras.Engine { public partial class Model { + LossesContainer compiled_loss; + MetricsContainer compiled_metrics; public void compile(string optimizerName, ILossFunc lossName) { throw new NotImplementedException(""); @@ -18,8 +20,8 @@ namespace Tensorflow.Keras.Engine public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) { this.optimizer = optimizer; - var compiled_loss = new LossesContainer(loss, output_names: output_names); - var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); + compiled_loss = new LossesContainer(loss, output_names: output_names); + compiled_metrics = new MetricsContainer(metrics, output_names: output_names); int experimental_steps_per_execution = 1; _configure_steps_per_execution(experimental_steps_per_execution); diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs index a22aa44f..09e66b54 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs @@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Engine var val_x = x[new Slice(train_count)]; var val_y = y[new Slice(train_count)]; - var data_handler = new DataHandler(new DataHandlerArgs + data_handler = new DataHandler(new DataHandlerArgs { X = train_x, Y = train_y, diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs index ebe4f710..db00fabe 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs @@ -1,7 +1,10 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using Tensorflow.Gradients; +using Tensorflow.Keras.Optimizers; using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine @@ -11,7 +14,8 @@ namespace Tensorflow.Keras.Engine Tensor step_function(OwnedIterator iterator) { var data = iterator.next(); - train_step(data[0], data[1]); + var outputs = train_step(data[0], data[1]); + tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); throw new NotImplementedException(""); } @@ -20,11 +24,33 @@ namespace Tensorflow.Keras.Engine /// /// /// - Tensor train_step(Tensor x, Tensor y) + IEnumerable<(string, Tensor)> train_step(Tensor x, Tensor y) { + (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); var y_pred = Apply(x, is_training: true); - throw new NotImplementedException(""); + var loss = compiled_loss.Call(y, y_pred); + + // For custom training steps, users can just write: + // trainable_variables = self.trainable_variables + // gradients = tape.gradient(loss, trainable_variables) + // self.optimizer.apply_gradients(zip(gradients, trainable_variables)) + // The _minimize call does a few extra steps unnecessary in most cases, + // such as loss scaling and gradient clipping. + _minimize(tape, optimizer, loss, trainable_variables); + + compiled_metrics.update_state(y, y_pred); + return new[] { ("loss", loss) }; + } + + void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List trainable_variables) + { + var gradients = tape.gradient(loss, trainable_variables); + gradients = optimizer._aggregate_gradients(zip(gradients, trainable_variables)); + gradients = optimizer._clip_gradients(gradients); + + optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)), + experimental_aggregate_gradients: false); } } } diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs index 0337c467..bbbc5649 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs @@ -6,6 +6,7 @@ using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; using NumSharp; using System.Collections.Generic; +using System.Data.Common; namespace Tensorflow.Keras.Engine { @@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Engine #pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used #pragma warning restore CS0108 // Member hides inherited member; missing new keyword ILossFunc loss; - IOptimizer optimizer; + OptimizerV2 optimizer; IVariableV1 _steps_per_execution; protected bool _is_graph_network; protected Tensors inputs; @@ -34,6 +35,7 @@ namespace Tensorflow.Keras.Engine IVariableV1 _predict_counter; bool _base_model_initialized; bool stop_training; + DataHandler data_handler; public Model(ModelArgs args) : base(args) diff --git a/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs b/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs index 391203b3..2ec58125 100644 --- a/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs +++ b/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs @@ -6,5 +6,7 @@ namespace Tensorflow.Keras.Losses { public interface ILossFunc { + string Reduction { get; } + Tensor Call(Tensor y_true, Tensor y_pred); } } diff --git a/src/TensorFlowNET.Core/Keras/Losses/Loss.cs b/src/TensorFlowNET.Core/Keras/Losses/Loss.cs index 445f377b..f29a1ae1 100644 --- a/src/TensorFlowNET.Core/Keras/Losses/Loss.cs +++ b/src/TensorFlowNET.Core/Keras/Losses/Loss.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Utils; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Losses { @@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Losses bool _allow_sum_over_batch_size; string _name_scope; + public string Reduction => reduction; + public Loss(string reduction = ReductionV2.AUTO, string name = null) { this.reduction = reduction; @@ -21,6 +25,17 @@ namespace Tensorflow.Keras.Losses _allow_sum_over_batch_size = false; } + public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) + { + throw new NotImplementedException(""); + } + + public Tensor Call(Tensor y_true, Tensor y_pred) + { + var losses = Apply(y_true, y_pred); + return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); + } + void _set_name_scope() { _name_scope = name; diff --git a/src/TensorFlowNET.Core/Keras/Losses/LossFunctionWrapper.cs b/src/TensorFlowNET.Core/Keras/Losses/LossFunctionWrapper.cs index 672f2e72..eeac4632 100644 --- a/src/TensorFlowNET.Core/Keras/Losses/LossFunctionWrapper.cs +++ b/src/TensorFlowNET.Core/Keras/Losses/LossFunctionWrapper.cs @@ -6,15 +6,11 @@ namespace Tensorflow.Keras.Losses { public class LossFunctionWrapper : Loss { - Action fn; - - public LossFunctionWrapper(Action fn, - string reduction = ReductionV2.AUTO, + public LossFunctionWrapper(string reduction = ReductionV2.AUTO, string name = null) : base(reduction: reduction, name: name) { - this.fn = fn; } } } diff --git a/src/TensorFlowNET.Core/Keras/Losses/ReductionV2.cs b/src/TensorFlowNET.Core/Keras/Losses/ReductionV2.cs index 82bbb7af..da4dca2e 100644 --- a/src/TensorFlowNET.Core/Keras/Losses/ReductionV2.cs +++ b/src/TensorFlowNET.Core/Keras/Losses/ReductionV2.cs @@ -6,6 +6,9 @@ namespace Tensorflow.Keras.Losses { public class ReductionV2 { + public const string NONE = "none"; public const string AUTO = "auto"; + public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size"; + public const string WEIGHTED_MEAN = "weighted_mean"; } } diff --git a/src/TensorFlowNET.Core/Keras/Losses/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Core/Keras/Losses/SparseCategoricalCrossentropy.cs index 72e5555b..c7ec060f 100644 --- a/src/TensorFlowNET.Core/Keras/Losses/SparseCategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Core/Keras/Losses/SparseCategoricalCrossentropy.cs @@ -1,6 +1,8 @@ using System; using System.Collections.Generic; +using System.Data; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Losses { @@ -9,16 +11,27 @@ namespace Tensorflow.Keras.Losses public SparseCategoricalCrossentropy(bool from_logits = false, string reduction = ReductionV2.AUTO, string name = "sparse_categorical_crossentropy") : - base(sparse_categorical_crossentropy, - reduction: reduction, + base(reduction: reduction, name: name) { } - static void sparse_categorical_crossentropy() + public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) { + target = tf.cast(target, dtype: TF_DataType.TF_INT64); + // Try to adjust the shape so that rank of labels = rank of logits - 1. + var output_shape = array_ops.shape_v2(output); + var output_rank = output.TensorShape.ndim; + var target_rank = target.TensorShape.ndim; + var update_shape = target_rank != output_rank - 1; + if (update_shape) + { + target = array_ops.reshape(target, new int[] { -1 }); + output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() }); + } + return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output); } } } diff --git a/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs b/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs index 79613b2d..99d1473e 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Utils; using static Tensorflow.Binding; namespace Tensorflow.Keras.Metrics @@ -13,9 +15,14 @@ namespace Tensorflow.Keras.Metrics { IVariableV1 total; IVariableV1 count; + string _reduction; + TF_DataType _dtype; + public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) : base(name: name, dtype: dtype) { + _reduction = reduction; + _dtype = dtype; total = add_weight("total", initializer: tf.zeros_initializer); if (reduction == Reduction.WEIGHTED_MEAN || @@ -24,5 +31,36 @@ namespace Tensorflow.Keras.Metrics count = add_weight("count", initializer: tf.zeros_initializer); } } + + public Tensor update_state(Tensor values, Tensor sample_weight = null) + { + if(sample_weight != null) + { + (values, sample_weight) = losses_utils.squeeze_or_expand_dimensions( + values, sample_weight: sample_weight); + + sample_weight = math_ops.cast(sample_weight, dtype: values.dtype); + values = math_ops.multiply(values, sample_weight); + } + + Tensor update_total_op = null; + var value_sum = math_ops.reduce_sum(values); + tf_with(ops.control_dependencies(new[] { value_sum }), ctl => + { + var update_total_op = total.assign_add(value_sum); + }); + + Tensor num_values = null; + if (_reduction == ReductionV2.WEIGHTED_MEAN) + { + if (sample_weight == null) + num_values = math_ops.cast(array_ops.size(values), _dtype); + else + num_values = math_ops.reduce_sum(sample_weight); + } + + return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl + => count.assign_add(num_values)); + } } } diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs index 54f217e3..81b9f59a 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs @@ -111,11 +111,16 @@ namespace Tensorflow.Keras.Optimizers }); } - Tensor[] _aggregate_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars) + public Tensor[] _aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars) { return grads_and_vars.Select(x => x.Item1).ToArray(); } + public Tensor[] _clip_gradients(Tensor[] grads) + { + return grads; + } + protected IVariableV1 get_slot(IVariableV1 var, string slot_name) { var slot_dict = _slots[var.UniqueId]; diff --git a/src/TensorFlowNET.Core/Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/losses_utils.cs new file mode 100644 index 00000000..c0f420e7 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Utils/losses_utils.cs @@ -0,0 +1,83 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Keras.Losses; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Utils +{ + public class losses_utils + { + public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) + { + if (sample_weight == null) + sample_weight = tf.constant(1.0f); + var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight); + // Apply reduction function to the individual weighted losses. + var loss = reduce_weighted_loss(weighted_losses, reduction); + // Convert the result back to the input type. + // loss = math_ops.cast(loss, losses.dtype); + return loss; + } + + public static Tensor scale_losses_by_sample_weight(Tensor losses, Tensor sample_weight) + { + // losses = math_ops.cast(losses, dtypes.float32); + // sample_weight = math_ops.cast(sample_weight, dtypes.float32); + // Update dimensions of `sample_weight` to match with `losses` if possible. + // (losses, sample_weight) = squeeze_or_expand_dimensions(losses, sample_weight); + return math_ops.multiply(losses, sample_weight); + } + + public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight) + { + var weights_shape = sample_weight.TensorShape; + var weights_rank = weights_shape.ndim; + if (weights_rank == 0) + return (y_pred, sample_weight); + throw new NotImplementedException(""); + } + + public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) + { + if (reduction == ReductionV2.NONE) + return weighted_losses; + else + { + var loss = math_ops.reduce_sum(weighted_losses); + if (reduction == ReductionV2.SUM_OVER_BATCH_SIZE) + loss = _safe_mean(loss, _num_elements(weighted_losses)); + return loss; + } + } + + static Tensor _safe_mean(Tensor losses, Tensor num_present) + { + var total_loss = math_ops.reduce_sum(losses); + return math_ops.div_no_nan(total_loss, num_present, name: "value"); + } + + static Tensor _num_elements(Tensor losses) + { + return tf_with(ops.name_scope("num_elements"), scope => + { + return math_ops.cast(array_ops.size(losses, name: scope), dtype: losses.dtype); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs index 783a20da..f4770792 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -24,6 +24,9 @@ namespace Tensorflow public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) { + if (weights == null) + weights = tf.constant(1.0f); + return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate { // Save the `reduction` argument for loss normalization when distributing diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 35d89e6c..5ef191da 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -521,6 +521,9 @@ namespace Tensorflow public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) => shape_internal(input, name, optimize: true, out_type: out_type); + public static Tensor shape_v2(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) + => shape_internal(input, name, optimize: true, out_type: out_type); + public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) => size_internal(input, name, optimize: optimize, out_type: out_type); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index c48fa3d2..2666f4dd 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -118,10 +118,13 @@ namespace Tensorflow /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) /// public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) - { - var op = tf.OpDefLib._apply_op_helper("DivNoNan", name: name, args: new { x, y }); - return op.output; - } + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "DivNoNan", name, + null, + x, y).FirstOrDefault(), + x, y); /// /// Computes the mean of elements across dimensions of a tensor.