Browse Source

Add concatenate in dataset.

tags/yolov3
Oceania2018 4 years ago
parent
commit
ceccf40890
4 changed files with 52 additions and 1 deletions
  1. +35
    -0
      src/TensorFlowNET.Core/Data/ConcatenateDataset.cs
  2. +4
    -0
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  3. +7
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  4. +6
    -1
      src/TensorFlowNET.Core/Data/OwnedIterator.cs

+ 35
- 0
src/TensorFlowNET.Core/Data/ConcatenateDataset.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;

namespace Tensorflow.Data
{
/// <summary>
/// A `Dataset` that concatenates its input with given dataset.
/// </summary>
public class ConcatenateDataset : DatasetV2
{
IDatasetV2 _input_dataset;
IDatasetV2 _dataset_to_concatenate;
public ConcatenateDataset(IDatasetV2 input_dataset, IDatasetV2 dataset_to_concatenate)
{
_input_dataset = input_dataset;
_dataset_to_concatenate = dataset_to_concatenate;
var _structure = new List<TensorSpec>();
foreach(var (i, spec) in enumerate(dataset_to_concatenate.element_spec))
{
var shape = _input_dataset.output_shapes[i].most_specific_compatible_shape(spec.shape);
_structure.Add(new TensorSpec(shape, dtype: spec.dtype));
}
structure = _structure.ToArray();

variant_tensor = ops.concatenate_dataset(input_dataset.variant_tensor,
dataset_to_concatenate.variant_tensor,
output_types, output_shapes);
}
}
}

+ 4
- 0
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -2,6 +2,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Data;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;

@@ -26,6 +27,9 @@ namespace Tensorflow
public IDatasetV2 cache(string filename = "")
=> new CacheDataset(this, filename: filename);

public IDatasetV2 concatenate(IDatasetV2 dataset)
=> new ConcatenateDataset(this, dataset);

public IDatasetV2 take(int count = -1)
=> new TakeDataset(this, count: count);



+ 7
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -23,6 +23,13 @@ namespace Tensorflow
/// <returns></returns>
IDatasetV2 cache(string filename = "");

/// <summary>
/// Creates a `Dataset` by concatenating the given dataset with this dataset.
/// </summary>
/// <param name="dataset"></param>
/// <returns></returns>
IDatasetV2 concatenate(IDatasetV2 dataset);

/// <summary>
///
/// </summary>


+ 6
- 1
src/TensorFlowNET.Core/Data/OwnedIterator.cs View File

@@ -1,5 +1,7 @@
using System;
using System.Linq;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -36,7 +38,10 @@ namespace Tensorflow
{
try
{
return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
var results = ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
foreach(var (i, tensor) in enumerate(results))
tensor.set_shape(_element_spec[i].shape);
return results;
}
catch (OutOfRangeError ex)
{


Loading…
Cancel
Save