|
|
@@ -123,19 +123,27 @@ namespace Tensorflow.Keras |
|
|
|
var sampling_rate_tensor = math_ops.cast(sampling_rate, dtype: index_dtype); |
|
|
|
|
|
|
|
var start_positions_tensor = tf.constant(start_positions); |
|
|
|
var positions_ds = tf.data.Dataset.from_tensor(start_positions_tensor).repeat(); |
|
|
|
var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat(); |
|
|
|
var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds); |
|
|
|
var indices = z.map(m => |
|
|
|
{ |
|
|
|
var (i, positions) = (m[0], m[1]); |
|
|
|
return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); |
|
|
|
}, num_parallel_calls: -1); |
|
|
|
return null; |
|
|
|
var dataset = sequences_from_indices(data, indices, start_index, end_index); |
|
|
|
|
|
|
|
if (shuffle) |
|
|
|
dataset = dataset.shuffle(buffer_size: batch_size * 8, seed: seed); |
|
|
|
dataset = dataset.batch(batch_size); |
|
|
|
return dataset; |
|
|
|
} |
|
|
|
|
|
|
|
IDatasetV2 sequences_from_indices(Tensor array, Tensor indices_ds, Tensor start_index, Tensor end_index) |
|
|
|
IDatasetV2 sequences_from_indices(Tensor array, IDatasetV2 indices_ds, int start_index, int? end_index) |
|
|
|
{ |
|
|
|
return null; |
|
|
|
var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]); |
|
|
|
dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds) |
|
|
|
.map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1); |
|
|
|
return dataset; |
|
|
|
} |
|
|
|
} |
|
|
|
} |