@@ -1,4 +1,5 @@ | |||||
using System; | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Data; | using Tensorflow.Data; | ||||
@@ -7,6 +8,20 @@ namespace Tensorflow | |||||
{ | { | ||||
public class DatasetManager | public class DatasetManager | ||||
{ | { | ||||
public IDatasetV2 from_generator<T>(IEnumerable<T> generator, TF_DataType[] output_types, TensorShape[] output_shapes) | |||||
=> new GeneratorDataset(); | |||||
/// <summary> | |||||
/// Creates a `Dataset` with a single element, comprising the given tensors. | |||||
/// </summary> | |||||
/// <param name="tensors"></param> | |||||
/// <returns></returns> | |||||
public IDatasetV2 from_tensor(NDArray tensors) | |||||
=> new TensorDataset(tensors); | |||||
public IDatasetV2 from_tensor(Tensor tensors) | |||||
=> new TensorDataset(tensors); | |||||
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) | public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) | ||||
=> new TensorSliceDataset(features, labels); | => new TensorSliceDataset(features, labels); | ||||
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Data | |||||
{ | |||||
public class GeneratorDataset : DatasetSource | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,33 @@ | |||||
using NumSharp; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// A `Dataset` with a single element. | |||||
/// </summary> | |||||
public class TensorDataset : DatasetSource | |||||
{ | |||||
public TensorDataset(Tensor element) | |||||
{ | |||||
_tensors = new[] { element }; | |||||
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||||
variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | |||||
} | |||||
public TensorDataset(NDArray element) | |||||
{ | |||||
_tensors = new[] { tf.convert_to_tensor(element) }; | |||||
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
structure = batched_spec.ToArray(); | |||||
variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | |||||
} | |||||
} | |||||
} |
@@ -8,6 +8,24 @@ namespace Tensorflow | |||||
{ | { | ||||
public class dataset_ops | public class dataset_ops | ||||
{ | { | ||||
public Tensor tensor_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null) | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"TensorDataset", name, | |||||
null, | |||||
new object[] | |||||
{ | |||||
components, | |||||
"output_shapes", output_shapes | |||||
}); | |||||
return results[0]; | |||||
} | |||||
throw new NotImplementedException(""); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Creates a dataset that emits each dim-0 slice of `components` once. | /// Creates a dataset that emits each dim-0 slice of `components` once. | ||||
/// </summary> | /// </summary> | ||||
@@ -1,7 +1,9 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras; | |||||
using Tensorflow.UnitTest; | using Tensorflow.UnitTest; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -62,6 +64,21 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
Assert.AreEqual(5, n); | Assert.AreEqual(5, n); | ||||
} | } | ||||
[TestMethod] | |||||
public void FromTensor() | |||||
{ | |||||
var X = new[] { 2013, 2014, 2015, 2016, 2017 }; | |||||
var dataset = tf.data.Dataset.from_tensor(X); | |||||
int n = 0; | |||||
foreach (var x in dataset) | |||||
{ | |||||
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>())); | |||||
n += 1; | |||||
} | |||||
Assert.AreEqual(1, n); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void Shard() | public void Shard() | ||||
{ | { | ||||