using Tensorflow.NumPy; using System; using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using System.Diagnostics; using Tensorflow.Keras.Callbacks; using System.Data; namespace Tensorflow.Keras.Engine { public partial class Model { /// /// Trains the model for a fixed number of epochs (iterations on a dataset). /// /// /// /// /// /// /// /// /// /// public ICallback fit(NDArray x, NDArray y, int batch_size = -1, int epochs = 1, int verbose = 1, List callbacks = null, float validation_split = 0f, (NDArray val_x, NDArray val_y)? validation_data = null, bool shuffle = true, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, bool use_multiprocessing = false) { if (x.dims[0] != y.dims[0]) { throw new InvalidArgumentError( $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); } var train_x = x; var train_y = y; if (validation_split != 0f && validation_data == null) { int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); train_x = x[new Slice(0, train_count)]; train_y = y[new Slice(0, train_count)]; validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]); } var data_handler = new DataHandler(new DataHandlerArgs { X = train_x, Y = train_y, BatchSize = batch_size, InitialEpoch = initial_epoch, Epochs = epochs, Shuffle = shuffle, MaxQueueSize = max_queue_size, Workers = workers, UseMultiprocessing = use_multiprocessing, Model = this, StepsPerExecution = _steps_per_execution }); return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, train_step_func: train_step_function); } public ICallback fit(IEnumerable x, NDArray y, int batch_size = -1, int epochs = 1, int verbose = 1, List callbacks = null, float validation_split = 0f, (IEnumerable val_x, NDArray val_y)? validation_data = null, bool shuffle = true, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, bool use_multiprocessing = false) { foreach(var tx in x) { if (tx.dims[0] != y.dims[0]) { throw new InvalidArgumentError( $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); } } var train_x = x; var train_y = y; if (validation_split != 0f && validation_data == null) { int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); train_y = y[new Slice(0, train_count)]; var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); var val_y = y[new Slice(train_count)]; validation_data = (val_x, val_y); } var data_handler = new DataHandler(new DataHandlerArgs { X = new Tensors(train_x.ToArray()), Y = train_y, BatchSize = batch_size, InitialEpoch = initial_epoch, Epochs = epochs, Shuffle = shuffle, MaxQueueSize = max_queue_size, Workers = workers, UseMultiprocessing = use_multiprocessing, Model = this, StepsPerExecution = _steps_per_execution }); if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) { return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, train_step_func: train_step_multi_inputs_function); } else { return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, train_step_func: train_step_function); } } public History fit(IDatasetV2 dataset, int batch_size = -1, int epochs = 1, int verbose = 1, List callbacks = null, IDatasetV2 validation_data = null, int validation_step = 10, // 间隔多少次会进行一次验证 bool shuffle = true, int initial_epoch = 0, int max_queue_size = 10, int workers = 1, bool use_multiprocessing = false) { var data_handler = new DataHandler(new DataHandlerArgs { Dataset = dataset, BatchSize = batch_size, InitialEpoch = initial_epoch, Epochs = epochs, Shuffle = shuffle, MaxQueueSize = max_queue_size, Workers = workers, UseMultiprocessing = use_multiprocessing, Model = this, StepsPerExecution = _steps_per_execution }); return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data, train_step_func: train_step_function); } History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List callbackList, IDatasetV2 validation_data, Func> train_step_func) { stop_training = false; _train_counter.assign(0); var callbacks = new CallbackList(new CallbackParams { Model = this, Verbose = verbose, Epochs = epochs, Steps = data_handler.Inferredsteps }); if (callbackList != null) { foreach(var callback in callbackList) callbacks.callbacks.add(callback); } callbacks.on_train_begin(); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { reset_metrics(); callbacks.on_epoch_begin(epoch); // data_handler.catch_stop_iteration(); var logs = new Dictionary(); long End_step = 0; foreach (var step in data_handler.steps()) { callbacks.on_train_batch_begin(step); logs = train_step_func(data_handler, iterator); var end_step = step + data_handler.StepIncrement; End_step = end_step; callbacks.on_train_batch_end(end_step, logs); } if (validation_data != null) { if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0) continue; var val_logs = evaluate(validation_data); foreach(var log in val_logs) { logs["val_" + log.Key] = log.Value; } callbacks.on_train_batch_end(End_step, logs); } callbacks.on_epoch_end(epoch, logs); GC.Collect(); GC.WaitForPendingFinalizers(); if (stop_training) { break; } } return callbacks.History; } History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (NDArray, NDArray)? validation_data, Func> train_step_func) { stop_training = false; _train_counter.assign(0); var callbacks = new CallbackList(new CallbackParams { Model = this, Verbose = verbose, Epochs = epochs, Steps = data_handler.Inferredsteps }); if (callbackList != null) { foreach (var callback in callbackList) callbacks.callbacks.add(callback); } callbacks.on_train_begin(); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { reset_metrics(); callbacks.on_epoch_begin(epoch); // data_handler.catch_stop_iteration(); var logs = new Dictionary(); long End_step = 0; foreach (var step in data_handler.steps()) { callbacks.on_train_batch_begin(step); logs = train_step_func(data_handler, iterator); var end_step = step + data_handler.StepIncrement; End_step = end_step; callbacks.on_train_batch_end(end_step, logs); } if (validation_data != null) { // Because evaluate calls call_test_batch_end, this interferes with our output on the screen // so we need to pass a is_val parameter to stop on_test_batch_end var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); foreach (var log in val_logs) { logs["val_" + log.Key] = log.Value; } // because after evaluate, logs add some new log which we need to print callbacks.on_train_batch_end(End_step, logs); } callbacks.on_epoch_end(epoch, logs); GC.Collect(); GC.WaitForPendingFinalizers(); if (stop_training) { break; } } return callbacks.History; } History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (IEnumerable, NDArray)? validation_data, Func> train_step_func) { stop_training = false; _train_counter.assign(0); var callbacks = new CallbackList(new CallbackParams { Model = this, Verbose = verbose, Epochs = epochs, Steps = data_handler.Inferredsteps }); if (callbackList != null) { foreach (var callback in callbackList) callbacks.callbacks.add(callback); } callbacks.on_train_begin(); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { reset_metrics(); callbacks.on_epoch_begin(epoch); // data_handler.catch_stop_iteration(); var logs = new Dictionary(); long End_step = 0; foreach (var step in data_handler.steps()) { callbacks.on_train_batch_begin(step); logs = train_step_func(data_handler, iterator); var end_step = step + data_handler.StepIncrement; End_step = end_step; callbacks.on_train_batch_end(end_step, logs); } if (validation_data != null) { var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2); foreach (var log in val_logs) { logs["val_" + log.Key] = log.Value; callbacks.on_train_batch_end(End_step, logs); } } callbacks.on_epoch_end(epoch, logs); GC.Collect(); GC.WaitForPendingFinalizers(); if (stop_training) { break; } } return callbacks.History; } } }