diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 2bd25da0..abfb6840 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -589,23 +589,17 @@ would not be rank 1.", tensor.op.get_attr("axis"))); { return ""; } + else if (dtype == TF_DataType.TF_RESOURCE) + { + return ""; + } var nd = tensor.numpy(); if (nd.size == 0) return "[]"; - switch (dtype) - { - case TF_DataType.TF_STRING: - return string.Join(string.Empty, nd.ToArray() - .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); - case TF_DataType.TF_VARIANT: - case TF_DataType.TF_RESOURCE: - return ""; - default: - return nd.ToString(); - } + return nd.ToString(); } public static ParsedSliceArgs ParseSlices(Slice[] slices) diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 39ba2a27..6fed2bf3 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Losses; using Tensorflow.Keras.Metrics; using static Tensorflow.KerasApi; @@ -74,11 +75,15 @@ namespace Tensorflow.Keras.Engine metric_obj = keras.metrics.sparse_categorical_accuracy; else metric_obj = keras.metrics.categorical_accuracy; - - return new MeanMetricWrapper(metric_obj, metric); } + else if(metric == "mean_absolute_error" || metric == "mae") + { + metric_obj = keras.metrics.mean_absolute_error; + } + else + throw new NotImplementedException(""); - throw new NotImplementedException(""); + return new MeanMetricWrapper(metric_obj, metric); } public IEnumerable metrics diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index f165a347..64723a22 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -40,5 +40,11 @@ namespace Tensorflow.Keras.Metrics return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); } + + public Tensor mean_absolute_error(Tensor y_true, Tensor y_pred) + { + y_true = math_ops.cast(y_true, y_pred.dtype); + return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1); + } } }