@@ -17,10 +17,7 @@ namespace Tensorflow | |||||
public IDatasetV2 from_tensor(NDArray tensors) | public IDatasetV2 from_tensor(NDArray tensors) | ||||
=> new TensorDataset(tensors); | => new TensorDataset(tensors); | ||||
public IDatasetV2 from_tensor(Tensor features, Tensor labels) | |||||
=> new TensorDataset(features, labels); | |||||
public IDatasetV2 from_tensor(Tensor tensors) | |||||
public IDatasetV2 from_tensor(Tensors tensors) | |||||
=> new TensorDataset(tensors); | => new TensorDataset(tensors); | ||||
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) | public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) | ||||
@@ -9,16 +9,9 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public class TensorDataset : DatasetSource | public class TensorDataset : DatasetSource | ||||
{ | { | ||||
public TensorDataset(Tensor feature, Tensor label) | |||||
public TensorDataset(Tensors elements) | |||||
{ | { | ||||
_tensors = new[] { feature, label }; | |||||
structure = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | |||||
} | |||||
public TensorDataset(Tensor element) | |||||
{ | |||||
_tensors = new[] { element }; | |||||
_tensors = elements; | |||||
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | ||||
structure = batched_spec.Select(x => x._unbatch()).ToArray(); | structure = batched_spec.Select(x => x._unbatch()).ToArray(); | ||||
@@ -26,10 +26,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
var _partial_batch_size = num_samples % batch_size; | var _partial_batch_size = num_samples % batch_size; | ||||
var indices_dataset = tf.data.Dataset.range(1); | var indices_dataset = tf.data.Dataset.range(1); | ||||
indices_dataset = indices_dataset.repeat(); | |||||
indices_dataset = indices_dataset.repeat(args.Epochs); | |||||
indices_dataset = indices_dataset.map(permutation).prefetch(1); | indices_dataset = indices_dataset.map(permutation).prefetch(1); | ||||
indices_dataset = indices_dataset.flat_map(slice_batch_indices); | indices_dataset = indices_dataset.flat_map(slice_batch_indices); | ||||
dataset = slice_inputs(indices_dataset, args.X, args.Y); | |||||
var elements = new Tensors(); | |||||
if (args.X != null) | |||||
elements.Add(args.X); | |||||
if (args.Y != null) | |||||
elements.Add(args.Y); | |||||
dataset = slice_inputs(indices_dataset, elements); | |||||
} | } | ||||
Tensor permutation(Tensor tensor) | Tensor permutation(Tensor tensor) | ||||
@@ -54,9 +59,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
return flat_dataset; | return flat_dataset; | ||||
} | } | ||||
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y) | |||||
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements) | |||||
{ | { | ||||
var dataset2 = tf.data.Dataset.from_tensor(x, y).repeat(); | |||||
var dataset2 = tf.data.Dataset.from_tensor(elements).repeat(); | |||||
var dataset = tf.data.Dataset.zip(indices_dataset, dataset2); | var dataset = tf.data.Dataset.zip(indices_dataset, dataset2); | ||||
dataset = dataset.map((batch, data) => | dataset = dataset.map((batch, data) => | ||||
@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using NumSharp; | |||||
using System; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <param name="use_multiprocessing"></param> | /// <param name="use_multiprocessing"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor predict(Tensor x, | public Tensor predict(Tensor x, | ||||
int batch_size = 32, | |||||
int batch_size = -1, | |||||
int verbose = 0, | int verbose = 0, | ||||
int steps = -1, | int steps = -1, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||