diff --git a/src/TensorFlowNET.Core/Data/DatasetManager.cs b/src/TensorFlowNET.Core/Data/DatasetManager.cs index 02fe38fd..be36287b 100644 --- a/src/TensorFlowNET.Core/Data/DatasetManager.cs +++ b/src/TensorFlowNET.Core/Data/DatasetManager.cs @@ -14,10 +14,10 @@ namespace Tensorflow /// /// /// - public IDatasetV2 from_tensor(NDArray tensors) + public IDatasetV2 from_tensors(NDArray tensors) => new TensorDataset(tensors); - public IDatasetV2 from_tensor(Tensors tensors) + public IDatasetV2 from_tensors(Tensors tensors) => new TensorDataset(tensors); public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 6633ce19..d73dc8b1 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -63,7 +63,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters var array = array_ops.slice(indices, new[] { constant_op.constant(num_in_full_batch)}, new[] { constant_op.constant(_partial_batch_size)}); - var index_remainder = tf.data.Dataset.from_tensor(array); + var index_remainder = tf.data.Dataset.from_tensors(array); flat_dataset = flat_dataset.concatenate(index_remainder); } @@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements) { - var dataset = tf.data.Dataset.from_tensor(elements).repeat(); + var dataset = tf.data.Dataset.from_tensors(elements).repeat(); dataset = tf.data.Dataset.zip(indices_dataset, dataset); dataset = dataset.map(inputs => diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index 1fbcb3bd..f820da9a 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -123,19 +123,27 @@ namespace Tensorflow.Keras var sampling_rate_tensor = math_ops.cast(sampling_rate, dtype: index_dtype); var start_positions_tensor = tf.constant(start_positions); - var positions_ds = tf.data.Dataset.from_tensor(start_positions_tensor).repeat(); + var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat(); var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds); var indices = z.map(m => { var (i, positions) = (m[0], m[1]); return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); }, num_parallel_calls: -1); - return null; + var dataset = sequences_from_indices(data, indices, start_index, end_index); + + if (shuffle) + dataset = dataset.shuffle(buffer_size: batch_size * 8, seed: seed); + dataset = dataset.batch(batch_size); + return dataset; } - IDatasetV2 sequences_from_indices(Tensor array, Tensor indices_ds, Tensor start_index, Tensor end_index) + IDatasetV2 sequences_from_indices(Tensor array, IDatasetV2 indices_ds, int start_index, int? end_index) { - return null; + var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]); + dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds) + .map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1); + return dataset; } } }