Browse Source

tf.data.Dataset.from_tensor #446

tags/v0.20
Oceania2018 5 years ago
parent
commit
436afe9703
5 changed files with 95 additions and 1 deletions
  1. +16
    -1
      src/TensorFlowNET.Core/Data/DatasetManager.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Data/GeneratorDataset.cs
  3. +33
    -0
      src/TensorFlowNET.Core/Data/TensorDataset.cs
  4. +18
    -0
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  5. +17
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 16
- 1
src/TensorFlowNET.Core/Data/DatasetManager.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Data;
@@ -7,6 +8,20 @@ namespace Tensorflow
{
public class DatasetManager
{
public IDatasetV2 from_generator<T>(IEnumerable<T> generator, TF_DataType[] output_types, TensorShape[] output_shapes)
=> new GeneratorDataset();

/// <summary>
/// Creates a `Dataset` with a single element, comprising the given tensors.
/// </summary>
/// <param name="tensors"></param>
/// <returns></returns>
public IDatasetV2 from_tensor(NDArray tensors)
=> new TensorDataset(tensors);

public IDatasetV2 from_tensor(Tensor tensors)
=> new TensorDataset(tensors);

public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
=> new TensorSliceDataset(features, labels);



+ 11
- 0
src/TensorFlowNET.Core/Data/GeneratorDataset.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Data
{
public class GeneratorDataset : DatasetSource
{

}
}

+ 33
- 0
src/TensorFlowNET.Core/Data/TensorDataset.cs View File

@@ -0,0 +1,33 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` with a single element.
/// </summary>
public class TensorDataset : DatasetSource
{
public TensorDataset(Tensor element)
{
_tensors = new[] { element };
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
structure = batched_spec.Select(x => x._unbatch()).ToArray();

variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
}

public TensorDataset(NDArray element)
{
_tensors = new[] { tf.convert_to_tensor(element) };
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
structure = batched_spec.ToArray();

variant_tensor = ops.tensor_dataset(_tensors, output_shapes);
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -8,6 +8,24 @@ namespace Tensorflow
{
public class dataset_ops
{
public Tensor tensor_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"TensorDataset", name,
null,
new object[]
{
components,
"output_shapes", output_shapes
});
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// Creates a dataset that emits each dim-0 slice of `components` once.
/// </summary>


+ 17
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -1,7 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

@@ -62,6 +64,21 @@ namespace TensorFlowNET.UnitTest.Dataset
Assert.AreEqual(5, n);
}

[TestMethod]
public void FromTensor()
{
var X = new[] { 2013, 2014, 2015, 2016, 2017 };

var dataset = tf.data.Dataset.from_tensor(X);
int n = 0;
foreach (var x in dataset)
{
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
n += 1;
}
Assert.AreEqual(1, n);
}

[TestMethod]
public void Shard()
{


Loading…
Cancel
Save