|
|
@@ -1,6 +1,7 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Linq; |
|
|
|
using Tensorflow.Keras.Losses; |
|
|
|
using Tensorflow.Keras.Metrics; |
|
|
|
using static Tensorflow.KerasApi; |
|
|
|
|
|
|
@@ -74,11 +75,15 @@ namespace Tensorflow.Keras.Engine |
|
|
|
metric_obj = keras.metrics.sparse_categorical_accuracy; |
|
|
|
else |
|
|
|
metric_obj = keras.metrics.categorical_accuracy; |
|
|
|
|
|
|
|
return new MeanMetricWrapper(metric_obj, metric); |
|
|
|
} |
|
|
|
else if(metric == "mean_absolute_error" || metric == "mae") |
|
|
|
{ |
|
|
|
metric_obj = keras.metrics.mean_absolute_error; |
|
|
|
} |
|
|
|
else |
|
|
|
throw new NotImplementedException(""); |
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
|
return new MeanMetricWrapper(metric_obj, metric); |
|
|
|
} |
|
|
|
|
|
|
|
public IEnumerable<Metric> metrics |
|
|
|