Browse Source

Fix timeseries_dataset_from_array.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
a7df67b13c
3 changed files with 16 additions and 8 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Data/DatasetManager.cs
  2. +2
    -2
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  3. +12
    -4
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs

+ 2
- 2
src/TensorFlowNET.Core/Data/DatasetManager.cs View File

@@ -14,10 +14,10 @@ namespace Tensorflow
/// </summary>
/// <param name="tensors"></param>
/// <returns></returns>
public IDatasetV2 from_tensor(NDArray tensors)
public IDatasetV2 from_tensors(NDArray tensors)
=> new TensorDataset(tensors);

public IDatasetV2 from_tensor(Tensors tensors)
public IDatasetV2 from_tensors(Tensors tensors)
=> new TensorDataset(tensors);

public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)


+ 2
- 2
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -63,7 +63,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
var array = array_ops.slice(indices,
new[] { constant_op.constant(num_in_full_batch)},
new[] { constant_op.constant(_partial_batch_size)});
var index_remainder = tf.data.Dataset.from_tensor(array);
var index_remainder = tf.data.Dataset.from_tensors(array);
flat_dataset = flat_dataset.concatenate(index_remainder);
}
@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters

IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
{
var dataset = tf.data.Dataset.from_tensor(elements).repeat();
var dataset = tf.data.Dataset.from_tensors(elements).repeat();
dataset = tf.data.Dataset.zip(indices_dataset, dataset);

dataset = dataset.map(inputs =>


+ 12
- 4
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -123,19 +123,27 @@ namespace Tensorflow.Keras
var sampling_rate_tensor = math_ops.cast(sampling_rate, dtype: index_dtype);

var start_positions_tensor = tf.constant(start_positions);
var positions_ds = tf.data.Dataset.from_tensor(start_positions_tensor).repeat();
var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat();
var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds);
var indices = z.map(m =>
{
var (i, positions) = (m[0], m[1]);
return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
}, num_parallel_calls: -1);
return null;
var dataset = sequences_from_indices(data, indices, start_index, end_index);

if (shuffle)
dataset = dataset.shuffle(buffer_size: batch_size * 8, seed: seed);
dataset = dataset.batch(batch_size);
return dataset;
}

IDatasetV2 sequences_from_indices(Tensor array, Tensor indices_ds, Tensor start_index, Tensor end_index)
IDatasetV2 sequences_from_indices(Tensor array, IDatasetV2 indices_ds, int start_index, int? end_index)
{
return null;
var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]);
dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds)
.map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1);
return dataset;
}
}
}

Loading…
Cancel
Save