Browse Source

Add mean_absolute_error to metrics.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
a994a86794
3 changed files with 19 additions and 14 deletions
  1. +5
    -11
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  2. +8
    -3
      src/TensorFlowNET.Keras/Engine/MetricsContainer.cs
  3. +6
    -0
      src/TensorFlowNET.Keras/Metrics/MetricsApi.cs

+ 5
- 11
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -589,23 +589,17 @@ would not be rank 1.", tensor.op.get_attr("axis")));
{ {
return "<unprintable>"; return "<unprintable>";
} }
else if (dtype == TF_DataType.TF_RESOURCE)
{
return "<unprintable>";
}


var nd = tensor.numpy(); var nd = tensor.numpy();


if (nd.size == 0) if (nd.size == 0)
return "[]"; return "[]";


switch (dtype)
{
case TF_DataType.TF_STRING:
return string.Join(string.Empty, nd.ToArray<byte>()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_VARIANT:
case TF_DataType.TF_RESOURCE:
return "<unprintable>";
default:
return nd.ToString();
}
return nd.ToString();
} }


public static ParsedSliceArgs ParseSlices(Slice[] slices) public static ParsedSliceArgs ParseSlices(Slice[] slices)


+ 8
- 3
src/TensorFlowNET.Keras/Engine/MetricsContainer.cs View File

@@ -1,6 +1,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Metrics;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;


@@ -74,11 +75,15 @@ namespace Tensorflow.Keras.Engine
metric_obj = keras.metrics.sparse_categorical_accuracy; metric_obj = keras.metrics.sparse_categorical_accuracy;
else else
metric_obj = keras.metrics.categorical_accuracy; metric_obj = keras.metrics.categorical_accuracy;

return new MeanMetricWrapper(metric_obj, metric);
} }
else if(metric == "mean_absolute_error" || metric == "mae")
{
metric_obj = keras.metrics.mean_absolute_error;
}
else
throw new NotImplementedException("");


throw new NotImplementedException("");
return new MeanMetricWrapper(metric_obj, metric);
} }


public IEnumerable<Metric> metrics public IEnumerable<Metric> metrics


+ 6
- 0
src/TensorFlowNET.Keras/Metrics/MetricsApi.cs View File

@@ -40,5 +40,11 @@ namespace Tensorflow.Keras.Metrics


return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT);
} }

public Tensor mean_absolute_error(Tensor y_true, Tensor y_pred)
{
y_true = math_ops.cast(y_true, y_pred.dtype);
return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1);
}
} }
} }

Loading…
Cancel
Save