|
|
@@ -142,6 +142,7 @@ namespace Tensorflow.Keras.Engine |
|
|
|
int verbose = 1, |
|
|
|
List<ICallback> callbacks = null, |
|
|
|
IDatasetV2 validation_data = null, |
|
|
|
int validation_step = 10, // 间隔多少次会进行一次验证 |
|
|
|
bool shuffle = true, |
|
|
|
int initial_epoch = 0, |
|
|
|
int max_queue_size = 10, |
|
|
@@ -164,11 +165,11 @@ namespace Tensorflow.Keras.Engine |
|
|
|
}); |
|
|
|
|
|
|
|
|
|
|
|
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, |
|
|
|
return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, |
|
|
|
train_step_func: train_step_function); |
|
|
|
} |
|
|
|
|
|
|
|
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, |
|
|
|
History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, |
|
|
|
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) |
|
|
|
{ |
|
|
|
stop_training = false; |
|
|
@@ -207,6 +208,9 @@ namespace Tensorflow.Keras.Engine |
|
|
|
|
|
|
|
if (validation_data != null) |
|
|
|
{ |
|
|
|
if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) |
|
|
|
continue; |
|
|
|
|
|
|
|
var val_logs = evaluate(validation_data); |
|
|
|
foreach(var log in val_logs) |
|
|
|
{ |
|
|
|