|
|
@@ -26,10 +26,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
var _partial_batch_size = num_samples % batch_size; |
|
|
|
|
|
|
|
var indices_dataset = tf.data.Dataset.range(1); |
|
|
|
indices_dataset = indices_dataset.repeat(); |
|
|
|
indices_dataset = indices_dataset.repeat(args.Epochs); |
|
|
|
indices_dataset = indices_dataset.map(permutation).prefetch(1); |
|
|
|
indices_dataset = indices_dataset.flat_map(slice_batch_indices); |
|
|
|
dataset = slice_inputs(indices_dataset, args.X, args.Y); |
|
|
|
var elements = new Tensors(); |
|
|
|
if (args.X != null) |
|
|
|
elements.Add(args.X); |
|
|
|
if (args.Y != null) |
|
|
|
elements.Add(args.Y); |
|
|
|
dataset = slice_inputs(indices_dataset, elements); |
|
|
|
} |
|
|
|
|
|
|
|
Tensor permutation(Tensor tensor) |
|
|
@@ -54,9 +59,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters |
|
|
|
return flat_dataset; |
|
|
|
} |
|
|
|
|
|
|
|
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y) |
|
|
|
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements) |
|
|
|
{ |
|
|
|
var dataset2 = tf.data.Dataset.from_tensor(x, y).repeat(); |
|
|
|
var dataset2 = tf.data.Dataset.from_tensor(elements).repeat(); |
|
|
|
var dataset = tf.data.Dataset.zip(indices_dataset, dataset2); |
|
|
|
|
|
|
|
dataset = dataset.map((batch, data) => |
|
|
|