@@ -48,7 +48,7 @@ namespace Tensorflow | |||||
public tensorflow() | public tensorflow() | ||||
{ | { | ||||
Logger = new LoggerConfiguration() | Logger = new LoggerConfiguration() | ||||
.MinimumLevel.Debug() | |||||
.MinimumLevel.Error() | |||||
.WriteTo.Console() | .WriteTo.Console() | ||||
.CreateLogger(); | .CreateLogger(); | ||||
@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
num_samples = args.X.shape[0]; | num_samples = args.X.shape[0]; | ||||
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; | var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; | ||||
_batch_size = batch_size; | _batch_size = batch_size; | ||||
_size = num_samples < batch_size ? num_samples % batch_size : num_samples / batch_size; | |||||
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); | |||||
num_full_batches = num_samples / batch_size; | num_full_batches = num_samples / batch_size; | ||||
_partial_batch_size = num_samples % batch_size; | _partial_batch_size = num_samples % batch_size; | ||||
@@ -89,7 +89,7 @@ namespace Tensorflow.Keras.Engine | |||||
_train_counter.assign(0); | _train_counter.assign(0); | ||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
{ | { | ||||
// reset_metrics(); | |||||
reset_metrics(); | |||||
// callbacks.on_epoch_begin(epoch) | // callbacks.on_epoch_begin(epoch) | ||||
// data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
@@ -10,6 +10,7 @@ namespace Tensorflow.Keras.Engine | |||||
get | get | ||||
{ | { | ||||
var _metrics = new List<Metric>(); | var _metrics = new List<Metric>(); | ||||
if (_is_compiled) | if (_is_compiled) | ||||
{ | { | ||||
if (compiled_loss != null) | if (compiled_loss != null) | ||||
@@ -18,13 +19,17 @@ namespace Tensorflow.Keras.Engine | |||||
_metrics.add(compiled_metrics.metrics); | _metrics.add(compiled_metrics.metrics); | ||||
} | } | ||||
foreach (var layer in _flatten_layers()) | |||||
{ | |||||
// _metrics.extend(layer.metrics); | |||||
} | |||||
/*foreach (var layer in _flatten_layers()) | |||||
_metrics.extend(layer.metrics);*/ | |||||
return _metrics; | return _metrics; | ||||
} | } | ||||
} | } | ||||
void reset_metrics() | |||||
{ | |||||
foreach (var metric in metrics) | |||||
metric.reset_states(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -2,6 +2,7 @@ | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | |||||
namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
{ | { | ||||
@@ -53,6 +54,12 @@ namespace Tensorflow.Keras.Metrics | |||||
public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | ||||
=> throw new NotImplementedException(""); | => throw new NotImplementedException(""); | ||||
public virtual void reset_states() | |||||
{ | |||||
foreach (var v in weights) | |||||
v.assign(0); | |||||
} | |||||
public virtual Tensor result() | public virtual Tensor result() | ||||
=> throw new NotImplementedException(""); | => throw new NotImplementedException(""); | ||||