diff --git a/src/TensorFlowNET.Keras/Metrics/BinaryCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/BinaryCrossentropy.cs index 5d26058f..14ef73b9 100644 --- a/src/TensorFlowNET.Keras/Metrics/BinaryCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Metrics/BinaryCrossentropy.cs @@ -4,7 +4,16 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class BinaryCrossentropy + public class BinaryCrossentropy : MeanMetricWrapper { + public BinaryCrossentropy(string name = "binary_crossentropy", string dtype = null, bool from_logits = false, float label_smoothing = 0) + : base(Fn, name, dtype) + { + } + + internal static Tensor Fn(Tensor y_true, Tensor y_pred) + { + return Losses.Loss.binary_crossentropy(y_true, y_pred); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs index 977d5368..c83bb5d5 100644 --- a/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs @@ -4,7 +4,16 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class CategoricalCrossentropy + public class CategoricalCrossentropy : MeanMetricWrapper { + public CategoricalCrossentropy(string name = "categorical_crossentropy", string dtype = null, bool from_logits = false, float label_smoothing = 0) + : base(Fn, name, dtype) + { + } + + internal static Tensor Fn(Tensor y_true, Tensor y_pred) + { + return Losses.Loss.categorical_crossentropy(y_true, y_pred); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/KLDivergence.cs b/src/TensorFlowNET.Keras/Metrics/KLDivergence.cs index c6447d1e..814b14ce 100644 --- a/src/TensorFlowNET.Keras/Metrics/KLDivergence.cs +++ b/src/TensorFlowNET.Keras/Metrics/KLDivergence.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class KLDivergence + public class KLDivergence : MeanMetricWrapper { + public KLDivergence(string name = "kullback_leibler_divergence", string dtype = null) + : base(Losses.Loss.logcosh, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanIoU.cs b/src/TensorFlowNET.Keras/Metrics/MeanIoU.cs index eda95a34..d8975218 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanIoU.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanIoU.cs @@ -1,10 +1,34 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanIoU + public class MeanIoU : Metric { + public MeanIoU(int num_classes, string name, string dtype) : base(name, dtype) + { + } + + public override void reset_states() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanTensor.cs b/src/TensorFlowNET.Keras/Metrics/MeanTensor.cs index 9bcab008..283b516f 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanTensor.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanTensor.cs @@ -4,7 +4,44 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanTensor + public class MeanTensor : Metric { + public int total + { + get + { + throw new NotImplementedException(); + } + } + + public int count + { + get + { + throw new NotImplementedException(); + } + } + + public MeanTensor(int num_classes, string name = "mean_tensor", string dtype) : base(name, dtype) + { + } + + + private void _build(TensorShape shape) => throw new NotImplementedException(); + + public override void reset_states() + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs index 7001a11b..b2513fd8 100644 --- a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs @@ -4,7 +4,16 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class SparseCategoricalCrossentropy + public class SparseCategoricalCrossentropy : MeanMetricWrapper { + public SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", string dtype = null, bool from_logits = false, int axis = -1) + : base(Fn, name, dtype) + { + } + + internal static Tensor Fn(Tensor y_true, Tensor y_pred) + { + return Losses.Loss.sparse_categorical_crossentropy(y_true, y_pred); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/SumOverBatchSize.cs b/src/TensorFlowNET.Keras/Metrics/SumOverBatchSize.cs index 5faa76f8..d25654c5 100644 --- a/src/TensorFlowNET.Keras/Metrics/SumOverBatchSize.cs +++ b/src/TensorFlowNET.Keras/Metrics/SumOverBatchSize.cs @@ -4,7 +4,10 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class SumOverBatchSize + public class SumOverBatchSize : Reduce { + public SumOverBatchSize(string name = "sum_over_batch_size", string dtype = null) : base(Reduction.SUM_OVER_BATCH_SIZE, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/SumOverBatchSizeMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/SumOverBatchSizeMetricWrapper.cs index 03fe2668..ff1c0497 100644 --- a/src/TensorFlowNET.Keras/Metrics/SumOverBatchSizeMetricWrapper.cs +++ b/src/TensorFlowNET.Keras/Metrics/SumOverBatchSizeMetricWrapper.cs @@ -1,10 +1,25 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class SumOverBatchSizeMetricWrapper + public class SumOverBatchSizeMetricWrapper : SumOverBatchSize { + public SumOverBatchSizeMetricWrapper(Func fn, string name, string dtype = null) + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } } }