@@ -0,0 +1,35 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Framework; | |||||
using Tensorflow.Framework.Models; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Data | |||||
{ | |||||
/// <summary> | |||||
/// A `Dataset` that concatenates its input with given dataset. | |||||
/// </summary> | |||||
public class ConcatenateDataset : DatasetV2 | |||||
{ | |||||
IDatasetV2 _input_dataset; | |||||
IDatasetV2 _dataset_to_concatenate; | |||||
public ConcatenateDataset(IDatasetV2 input_dataset, IDatasetV2 dataset_to_concatenate) | |||||
{ | |||||
_input_dataset = input_dataset; | |||||
_dataset_to_concatenate = dataset_to_concatenate; | |||||
var _structure = new List<TensorSpec>(); | |||||
foreach(var (i, spec) in enumerate(dataset_to_concatenate.element_spec)) | |||||
{ | |||||
var shape = _input_dataset.output_shapes[i].most_specific_compatible_shape(spec.shape); | |||||
_structure.Add(new TensorSpec(shape, dtype: spec.dtype)); | |||||
} | |||||
structure = _structure.ToArray(); | |||||
variant_tensor = ops.concatenate_dataset(input_dataset.variant_tensor, | |||||
dataset_to_concatenate.variant_tensor, | |||||
output_types, output_shapes); | |||||
} | |||||
} | |||||
} |
@@ -2,6 +2,7 @@ | |||||
using System.Collections; | using System.Collections; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Data; | |||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -26,6 +27,9 @@ namespace Tensorflow | |||||
public IDatasetV2 cache(string filename = "") | public IDatasetV2 cache(string filename = "") | ||||
=> new CacheDataset(this, filename: filename); | => new CacheDataset(this, filename: filename); | ||||
public IDatasetV2 concatenate(IDatasetV2 dataset) | |||||
=> new ConcatenateDataset(this, dataset); | |||||
public IDatasetV2 take(int count = -1) | public IDatasetV2 take(int count = -1) | ||||
=> new TakeDataset(this, count: count); | => new TakeDataset(this, count: count); | ||||
@@ -23,6 +23,13 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
IDatasetV2 cache(string filename = ""); | IDatasetV2 cache(string filename = ""); | ||||
/// <summary> | |||||
/// Creates a `Dataset` by concatenating the given dataset with this dataset. | |||||
/// </summary> | |||||
/// <param name="dataset"></param> | |||||
/// <returns></returns> | |||||
IDatasetV2 concatenate(IDatasetV2 dataset); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
/// </summary> | /// </summary> | ||||
@@ -1,5 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Linq; | |||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -36,7 +38,10 @@ namespace Tensorflow | |||||
{ | { | ||||
try | try | ||||
{ | { | ||||
return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); | |||||
var results = ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); | |||||
foreach(var (i, tensor) in enumerate(results)) | |||||
tensor.set_shape(_element_spec[i].shape); | |||||
return results; | |||||
} | } | ||||
catch (OutOfRangeError ex) | catch (OutOfRangeError ex) | ||||
{ | { | ||||