Browse Source

Change TensorDataset construct.

tags/yolov3
Oceania2018 4 years ago
parent
commit
2cc629532b
4 changed files with 15 additions and 19 deletions
  1. +1
    -4
      src/TensorFlowNET.Core/Data/DatasetManager.cs
  2. +2
    -9
      src/TensorFlowNET.Core/Data/TensorDataset.cs
  3. +9
    -4
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  4. +3
    -2
      src/TensorFlowNET.Keras/Engine/Model.Predict.cs

+ 1
- 4
src/TensorFlowNET.Core/Data/DatasetManager.cs View File

@@ -17,10 +17,7 @@ namespace Tensorflow
public IDatasetV2 from_tensor(NDArray tensors)
=> new TensorDataset(tensors);

public IDatasetV2 from_tensor(Tensor features, Tensor labels)
=> new TensorDataset(features, labels);

public IDatasetV2 from_tensor(Tensor tensors)
public IDatasetV2 from_tensor(Tensors tensors)
=> new TensorDataset(tensors);

public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)


+ 2
- 9
src/TensorFlowNET.Core/Data/TensorDataset.cs View File

@@ -9,16 +9,9 @@ namespace Tensorflow
/// </summary>
public class TensorDataset : DatasetSource
{
public TensorDataset(Tensor feature, Tensor label)
public TensorDataset(Tensors elements)
{
_tensors = new[] { feature, label };
structure = _tensors.Select(x => x.ToTensorSpec()).ToArray();

variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
}
public TensorDataset(Tensor element)
{
_tensors = new[] { element };
_tensors = elements;
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
structure = batched_spec.Select(x => x._unbatch()).ToArray();



+ 9
- 4
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -26,10 +26,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters
var _partial_batch_size = num_samples % batch_size;

var indices_dataset = tf.data.Dataset.range(1);
indices_dataset = indices_dataset.repeat();
indices_dataset = indices_dataset.repeat(args.Epochs);
indices_dataset = indices_dataset.map(permutation).prefetch(1);
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
dataset = slice_inputs(indices_dataset, args.X, args.Y);
var elements = new Tensors();
if (args.X != null)
elements.Add(args.X);
if (args.Y != null)
elements.Add(args.Y);
dataset = slice_inputs(indices_dataset, elements);
}

Tensor permutation(Tensor tensor)
@@ -54,9 +59,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return flat_dataset;
}

IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensor x, Tensor y)
IDatasetV2 slice_inputs(IDatasetV2 indices_dataset, Tensors elements)
{
var dataset2 = tf.data.Dataset.from_tensor(x, y).repeat();
var dataset2 = tf.data.Dataset.from_tensor(elements).repeat();
var dataset = tf.data.Dataset.zip(indices_dataset, dataset2);

dataset = dataset.map((batch, data) =>


+ 3
- 2
src/TensorFlowNET.Keras/Engine/Model.Predict.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp;
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;

@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="use_multiprocessing"></param>
/// <returns></returns>
public Tensor predict(Tensor x,
int batch_size = 32,
int batch_size = -1,
int verbose = 0,
int steps = -1,
int max_queue_size = 10,


Loading…
Cancel
Save