|
|
@@ -179,9 +179,20 @@ namespace Tensorflow.Keras.Engine |
|
|
|
StepsPerExecution = _steps_per_execution |
|
|
|
}); |
|
|
|
|
|
|
|
Func<DataHandler, OwnedIterator, Dictionary<string, float>> trainStepFunction; |
|
|
|
|
|
|
|
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || |
|
|
|
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) |
|
|
|
{ |
|
|
|
trainStepFunction = train_step_multi_inputs_function; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
trainStepFunction = train_step_function; |
|
|
|
} |
|
|
|
|
|
|
|
return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, |
|
|
|
train_step_func: train_step_function); |
|
|
|
train_step_func: trainStepFunction); |
|
|
|
} |
|
|
|
|
|
|
|
History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data, |
|
|
|