diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 00c55fae..60b22f71 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -48,7 +48,7 @@ namespace Tensorflow public tensorflow() { Logger = new LoggerConfiguration() - .MinimumLevel.Debug() + .MinimumLevel.Error() .WriteTo.Console() .CreateLogger(); diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 98fd4741..6633ce19 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters num_samples = args.X.shape[0]; var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; _batch_size = batch_size; - _size = num_samples < batch_size ? num_samples % batch_size : num_samples / batch_size; + _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); num_full_batches = num_samples / batch_size; _partial_batch_size = num_samples % batch_size; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index ad58efa1..939cd1c9 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -89,7 +89,7 @@ namespace Tensorflow.Keras.Engine _train_counter.assign(0); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { - // reset_metrics(); + reset_metrics(); // callbacks.on_epoch_begin(epoch) // data_handler.catch_stop_iteration(); foreach (var step in data_handler.steps()) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs index 821cf781..214b9934 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Metrics.cs @@ -10,6 +10,7 @@ namespace Tensorflow.Keras.Engine get { var _metrics = new List(); + if (_is_compiled) { if (compiled_loss != null) @@ -18,13 +19,17 @@ namespace Tensorflow.Keras.Engine _metrics.add(compiled_metrics.metrics); } - foreach (var layer in _flatten_layers()) - { - // _metrics.extend(layer.metrics); - } + /*foreach (var layer in _flatten_layers()) + _metrics.extend(layer.metrics);*/ return _metrics; } } + + void reset_metrics() + { + foreach (var metric in metrics) + metric.reset_states(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Metric.cs b/src/TensorFlowNET.Keras/Metrics/Metric.cs index 9cbaaeb7..2a34ef53 100644 --- a/src/TensorFlowNET.Keras/Metrics/Metric.cs +++ b/src/TensorFlowNET.Keras/Metrics/Metric.cs @@ -2,6 +2,7 @@ using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Metrics { @@ -53,6 +54,12 @@ namespace Tensorflow.Keras.Metrics public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) => throw new NotImplementedException(""); + public virtual void reset_states() + { + foreach (var v in weights) + v.assign(0); + } + public virtual Tensor result() => throw new NotImplementedException("");