Browse Source

fix: Added support for training a multi-input model using a dataset.

pull/1260/head
Aleksej Solomatin 1 year ago
parent
commit
93dda17944
2 changed files with 25 additions and 2 deletions
  1. +13
    -1
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  2. +12
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+ 13
- 1
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -112,7 +112,19 @@ namespace Tensorflow.Keras.Engine
Steps = data_handler.Inferredsteps
});

return evaluate(data_handler, callbacks, is_val, test_function);
Func<DataHandler, OwnedIterator, Dictionary<string, float>> testFunction;

if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
testFunction = test_step_multi_inputs_function;
}
else
{
testFunction = test_function;
}

return evaluate(data_handler, callbacks, is_val, testFunction);
}

/// <summary>


+ 12
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -179,9 +179,20 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

Func<DataHandler, OwnedIterator, Dictionary<string, float>> trainStepFunction;

if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
trainStepFunction = train_step_multi_inputs_function;
}
else
{
trainStepFunction = train_step_function;
}

return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
train_step_func: trainStepFunction);
}

History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,


Loading…
Cancel
Save