@@ -589,23 +589,17 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
{ | { | ||||
return "<unprintable>"; | return "<unprintable>"; | ||||
} | } | ||||
else if (dtype == TF_DataType.TF_RESOURCE) | |||||
{ | |||||
return "<unprintable>"; | |||||
} | |||||
var nd = tensor.numpy(); | var nd = tensor.numpy(); | ||||
if (nd.size == 0) | if (nd.size == 0) | ||||
return "[]"; | return "[]"; | ||||
switch (dtype) | |||||
{ | |||||
case TF_DataType.TF_STRING: | |||||
return string.Join(string.Empty, nd.ToArray<byte>() | |||||
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); | |||||
case TF_DataType.TF_VARIANT: | |||||
case TF_DataType.TF_RESOURCE: | |||||
return "<unprintable>"; | |||||
default: | |||||
return nd.ToString(); | |||||
} | |||||
return nd.ToString(); | |||||
} | } | ||||
public static ParsedSliceArgs ParseSlices(Slice[] slices) | public static ParsedSliceArgs ParseSlices(Slice[] slices) | ||||
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.Losses; | |||||
using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
@@ -74,11 +75,15 @@ namespace Tensorflow.Keras.Engine | |||||
metric_obj = keras.metrics.sparse_categorical_accuracy; | metric_obj = keras.metrics.sparse_categorical_accuracy; | ||||
else | else | ||||
metric_obj = keras.metrics.categorical_accuracy; | metric_obj = keras.metrics.categorical_accuracy; | ||||
return new MeanMetricWrapper(metric_obj, metric); | |||||
} | } | ||||
else if(metric == "mean_absolute_error" || metric == "mae") | |||||
{ | |||||
metric_obj = keras.metrics.mean_absolute_error; | |||||
} | |||||
else | |||||
throw new NotImplementedException(""); | |||||
throw new NotImplementedException(""); | |||||
return new MeanMetricWrapper(metric_obj, metric); | |||||
} | } | ||||
public IEnumerable<Metric> metrics | public IEnumerable<Metric> metrics | ||||
@@ -40,5 +40,11 @@ namespace Tensorflow.Keras.Metrics | |||||
return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); | return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT); | ||||
} | } | ||||
public Tensor mean_absolute_error(Tensor y_true, Tensor y_pred) | |||||
{ | |||||
y_true = math_ops.cast(y_true, y_pred.dtype); | |||||
return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1); | |||||
} | |||||
} | } | ||||
} | } |