diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs index 879a38e4..7876a990 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs @@ -19,6 +19,7 @@ using System; using System.Collections.Generic; using System.Text; using static Tensorflow.Binding; +using System.Linq; namespace Tensorflow { @@ -62,5 +63,8 @@ namespace Tensorflow }); } } + + public Tensor this[params string[] slices] + => this[slices.Select(x => new Slice(x)).ToArray()]; } } diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 037703c8..790221f8 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -77,9 +77,9 @@ namespace Tensorflow.Keras.Engine metric_obj = keras.metrics.categorical_accuracy; } else if(metric == "mean_absolute_error" || metric == "mae") - { metric_obj = keras.metrics.mean_absolute_error; - } + else if (metric == "mean_absolute_percentage_error" || metric == "mape") + metric_obj = keras.metrics.mean_absolute_percentage_error; else throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs index 71bd2f38..7b051f1d 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs @@ -42,9 +42,10 @@ namespace Tensorflow.Keras.Engine _ => throw new NotImplementedException("") }; - var _loss = loss switch + ILossFunc _loss = loss switch { "mse" => new MeanSquaredError(), + "mae" => new MeanAbsoluteError(), _ => throw new NotImplementedException("") }; diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index c8d54fc9..3d614e02 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -46,5 +46,12 @@ namespace Tensorflow.Keras.Metrics y_true = math_ops.cast(y_true, y_pred.dtype); return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1); } + + public Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred) + { + y_true = math_ops.cast(y_true, y_pred.dtype); + var diff = (y_true - y_pred) / math_ops.maximum(math_ops.abs(y_true), keras.backend.epsilon()); + return 100f * keras.backend.mean(math_ops.abs(diff), axis: -1); + } } }