Browse Source

reset_metrics for every epoch.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
f0030ca9bb
5 changed files with 19 additions and 7 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/tensorflow.cs
  2. +1
    -1
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  3. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  4. +9
    -4
      src/TensorFlowNET.Keras/Engine/Model.Metrics.cs
  5. +7
    -0
      src/TensorFlowNET.Keras/Metrics/Metric.cs

+ 1
- 1
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow
public tensorflow()
{
Logger = new LoggerConfiguration()
.MinimumLevel.Debug()
.MinimumLevel.Error()
.WriteTo.Console()
.CreateLogger();



+ 1
- 1
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
num_samples = args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
_size = num_samples < batch_size ? num_samples % batch_size : num_samples / batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
num_full_batches = num_samples / batch_size;
_partial_batch_size = num_samples % batch_size;



+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -89,7 +89,7 @@ namespace Tensorflow.Keras.Engine
_train_counter.assign(0);
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
foreach (var step in data_handler.steps())


+ 9
- 4
src/TensorFlowNET.Keras/Engine/Model.Metrics.cs View File

@@ -10,6 +10,7 @@ namespace Tensorflow.Keras.Engine
get
{
var _metrics = new List<Metric>();

if (_is_compiled)
{
if (compiled_loss != null)
@@ -18,13 +19,17 @@ namespace Tensorflow.Keras.Engine
_metrics.add(compiled_metrics.metrics);
}

foreach (var layer in _flatten_layers())
{
// _metrics.extend(layer.metrics);
}
/*foreach (var layer in _flatten_layers())
_metrics.extend(layer.metrics);*/

return _metrics;
}
}

void reset_metrics()
{
foreach (var metric in metrics)
metric.reset_states();
}
}
}

+ 7
- 0
src/TensorFlowNET.Keras/Metrics/Metric.cs View File

@@ -2,6 +2,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Metrics
{
@@ -53,6 +54,12 @@ namespace Tensorflow.Keras.Metrics
public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
=> throw new NotImplementedException("");

public virtual void reset_states()
{
foreach (var v in weights)
v.assign(0);
}

public virtual Tensor result()
=> throw new NotImplementedException("");



Loading…
Cancel
Save