using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Engine.DataAdapters
{
/// <summary>
/// Handles iterating over epoch-level `tf.data.Iterator` objects.
/// </summary>
/// <remarks>
/// Builds a <see cref="TensorLikeDataAdapter"/> when no dataset is supplied, otherwise a
/// <see cref="DatasetAdapter"/>, and exposes epoch and step enumerators that take the
/// configured steps-per-execution into account.
/// </remarks>
public class DataHandler
{
    readonly DataHandlerArgs args;
    readonly IDataAdapter _adapter;
    public IDataAdapter DataAdapter => _adapter;
    readonly IDatasetV2 _dataset;
    readonly int _inferred_steps;
    int _current_step;
    int _step_increment;
    bool _insufficient_data;
    readonly int _steps_per_execution_value;
    int _initial_epoch => args.InitialEpoch;
    int _epochs => args.Epochs;
    readonly IVariableV1 _steps_per_execution;

    /// <summary>
    /// Creates the data adapter for <paramref name="args"/> and infers the number of steps per epoch.
    /// </summary>
    /// <param name="args">Data, batching and epoch configuration.</param>
    /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
    public DataHandler(DataHandlerArgs args)
    {
        this.args = args ?? throw new ArgumentNullException(nameof(args));

        if (args.StepsPerExecution is null)
        {
            // No variable supplied: create one holding 1 so the field is never null.
            _steps_per_execution = tf.Variable(1);
            _steps_per_execution_value = 1;
        }
        else
        {
            _steps_per_execution = args.StepsPerExecution;
            _steps_per_execution_value = args.StepsPerExecution.numpy();
        }

        // In-memory tensors go through TensorLikeDataAdapter; an explicit dataset through DatasetAdapter.
        if (args.Dataset is null)
        {
            _adapter = new TensorLikeDataAdapter(new DataAdapterArgs
            {
                X = args.X,
                Y = args.Y,
                BatchSize = args.BatchSize,
                Steps = args.StepsPerEpoch,
                Epochs = args.Epochs - args.InitialEpoch,
                Shuffle = args.Shuffle,
                MaxQueueSize = args.MaxQueueSize,
                Worker = args.Workers,
                UseMultiprocessing = args.UseMultiprocessing,
                Model = args.Model
            });
        }
        else
        {
            _adapter = new DatasetAdapter(new DataAdapterArgs
            {
                Dataset = args.Dataset,
                BatchSize = args.BatchSize,
                Steps = args.StepsPerEpoch,
                Epochs = args.Epochs - args.InitialEpoch,
                Shuffle = args.Shuffle,
                MaxQueueSize = args.MaxQueueSize,
                Worker = args.Workers,
                UseMultiprocessing = args.UseMultiprocessing,
                Model = args.Model
            });
        }

        _dataset = _adapter.GetDataset();
        _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
        _current_step = 0;
        // Use the resolved value: args.StepsPerExecution may be null here.
        _step_increment = _steps_per_execution_value - 1;
        _insufficient_data = false;
    }

    /// <summary>
    /// Determines how many steps make up one epoch.
    /// </summary>
    /// <param name="steps_per_epoch">Explicit step count; used when it is non-negative.</param>
    /// <param name="dataset">Dataset whose cardinality is used as the last resort.</param>
    /// <returns>
    /// <paramref name="steps_per_epoch"/> if non-negative, else the adapter's size if non-negative,
    /// else the dataset's cardinality.
    /// </returns>
    int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
    {
        if (steps_per_epoch > -1)
            return steps_per_epoch;

        var adapter_steps = _adapter.GetSize();
        if (adapter_steps > -1)
            return adapter_steps;

        var size = dataset.dataset_cardinality();
        return size.numpy();
    }

    /// <summary>
    /// Yields <c>(epoch, iterator)</c> for each epoch in <c>range(_initial_epoch, _epochs)</c>.
    /// </summary>
    /// <remarks>
    /// A single <see cref="OwnedIterator"/> over the dataset is shared by all epochs and is
    /// disposed when enumeration finishes or is abandoned. Stops early if the handler has
    /// flagged insufficient data.
    /// </remarks>
    public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
    {
        using var ownedIterator = new OwnedIterator(_dataset);
        foreach (var epoch in range(_initial_epoch, _epochs))
        {
            if (_insufficient_data)
                break;
            yield return (epoch, ownedIterator);
        }
    }

    /// <summary>
    /// Yields the index of the first step of each execution within the current epoch.
    /// </summary>
    /// <remarks>
    /// Normally advances by the configured steps-per-execution. When fewer steps than that remain,
    /// the steps-per-execution variable is temporarily set to the remaining count for the final
    /// execution and restored afterwards, even if the consumer stops enumerating early.
    /// NOTE(review): when <c>_inferred_steps</c> is not positive (e.g. if dataset cardinality is
    /// reported as a negative sentinel for unknown/infinite data — confirm), no steps are yielded,
    /// so the <c>_inferred_steps &lt; 0</c> clause below can never be true inside the loop.
    /// </remarks>
    public IEnumerable<int> steps()
    {
        _current_step = 0;
        while (_current_step < _inferred_steps)
        {
            if (_insufficient_data)
                break;

            bool can_run_full_execution = _steps_per_execution_value == 1
                || _inferred_steps < 0
                || _inferred_steps - _current_step >= _steps_per_execution_value;

            if (can_run_full_execution)
            {
                _step_increment = _steps_per_execution_value - 1;
                yield return _current_step;
                _current_step += _steps_per_execution_value;
            }
            else
            {
                // Final, shorter execution: shrink the variable to the remaining step count.
                var steps_remaining = _inferred_steps - _current_step;
                _steps_per_execution.assign(steps_remaining);
                _step_increment = steps_remaining - 1;
                try
                {
                    yield return _current_step;
                    _current_step += steps_remaining;
                }
                finally
                {
                    // Runs on normal completion and when the enumerator is disposed early (e.g. `break`).
                    _steps_per_execution.assign(_steps_per_execution_value);
                }
            }
        }
    }
}
}