diff --git a/src/TensorFlowNET.Core/Data/DatasetManager.cs b/src/TensorFlowNET.Core/Data/DatasetManager.cs
index 02fe38fd..be36287b 100644
--- a/src/TensorFlowNET.Core/Data/DatasetManager.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetManager.cs
@@ -14,10 +14,10 @@ namespace Tensorflow
///
///
///
- public IDatasetV2 from_tensor(NDArray tensors)
+ public IDatasetV2 from_tensors(NDArray tensors)
=> new TensorDataset(tensors);
- public IDatasetV2 from_tensor(Tensors tensors)
+ public IDatasetV2 from_tensors(Tensors tensors)
=> new TensorDataset(tensors);
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
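
The rename brings the C# surface in line with Python's `tf.data.Dataset.from_tensors`, which wraps the whole input as a single dataset element, in contrast to `from_tensor_slices`, which slices along the first dimension. A minimal sketch of the distinction, assuming the usual `Tensorflow.Binding` / `Tensorflow.NumPy` usings (illustrative only, not part of the patch):

    using static Tensorflow.Binding;
    using Tensorflow.NumPy;

    var t = np.array(new[,] { { 1, 2 }, { 3, 4 } });
    var whole = tf.data.Dataset.from_tensors(t);        // one element: the full 2x2 array
    var rows  = tf.data.Dataset.from_tensor_slices(t);  // two elements: [1, 2] and [3, 4]
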
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index 6633ce19..d73dc8b1 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -63,7 +63,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
var array = array_ops.slice(indices,
new[] { constant_op.constant(num_in_full_batch)},
new[] { constant_op.constant(_partial_batch_size)});
- var index_remainder = tf.data.Dataset.from_tensor(array);
+ var index_remainder = tf.data.Dataset.from_tensors(array);
flat_dataset = flat_dataset.concatenate(index_remainder);
}
@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
{
- var dataset = tf.data.Dataset.from_tensor(elements).repeat();
+ var dataset = tf.data.Dataset.from_tensors(elements).repeat();
dataset = tf.data.Dataset.zip(indices_dataset, dataset);
dataset = dataset.map(inputs =>
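
Here `slice_inputs` relies on `from_tensors(elements).repeat()` so that every step of the zipped dataset sees one batch of shuffled indices together with the full in-memory `elements`, from which the batch rows are then gathered. A small sketch of that gather step with made-up values, using `tf.gather` (which should resolve to the same op as the `array_ops.gather` call in the patch):

    using static Tensorflow.Binding;

    var elements = tf.constant(new[,] { { 10, 11 }, { 20, 21 }, { 30, 31 } });
    var batchIdx = tf.constant(new[] { 2, 0 });
    var batch = tf.gather(elements, batchIdx);   // [[30, 31], [10, 11]]
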
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
index 1fbcb3bd..f820da9a 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
@@ -123,19 +123,27 @@ namespace Tensorflow.Keras
var sampling_rate_tensor = math_ops.cast(sampling_rate, dtype: index_dtype);
var start_positions_tensor = tf.constant(start_positions);
- var positions_ds = tf.data.Dataset.from_tensor(start_positions_tensor).repeat();
+ var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat();
var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds);
var indices = z.map(m =>
{
var (i, positions) = (m[0], m[1]);
return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
}, num_parallel_calls: -1);
- return null;
+ var dataset = sequences_from_indices(data, indices, start_index, end_index);
+
+ if (shuffle)
+ dataset = dataset.shuffle(buffer_size: batch_size * 8, seed: seed);
+ dataset = dataset.batch(batch_size);
+ return dataset;
}
- IDatasetV2 sequences_from_indices(Tensor array, Tensor indices_ds, Tensor start_index, Tensor end_index)
+ IDatasetV2 sequences_from_indices(Tensor array, IDatasetV2 indices_ds, int start_index, int? end_index)
{
- return null;
+ var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]);
+ dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds)
+ .map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1);
+ return dataset;
}
}
}
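
With `sequences_from_indices` filled in, each window is produced by gathering `range(position, position + sequence_length * sampling_rate, sampling_rate)` out of the sliced data, after which the dataset is optionally shuffled and then batched. A rough sketch of a single window with made-up parameters (sequence_length = 3, sampling_rate = 2, start position = 1), again only illustrative:

    using static Tensorflow.Binding;
    using Tensorflow.NumPy;

    var data = tf.constant(np.arange(10));   // [0, 1, ..., 9]
    var idx = tf.range(1, 1 + 3 * 2, 2);     // window indices stepped by sampling_rate: [1, 3, 5]
    var window = tf.gather(data, idx);       // the window's values: [1, 3, 5]
    // shuffle (when requested) and batch(batch_size) then group such windows into batches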