Browse Source

Fix input dtype for MapDataset. #666

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
d8afa8cefb
2 changed files with 9 additions and 1 deletions
  1. +7
    -1
      src/TensorFlowNET.Core/Data/MapDataset.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

+ 7
- 1
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -1,5 +1,6 @@
using System;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -14,7 +15,12 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input");
var output = map_func(input);
func.ToGraph(input, output);

structure = func.OutputStructure;

variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
func,


+ 2
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -109,6 +109,8 @@ namespace Tensorflow.Functions
inputs,
outputs,
null);

OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
}

public Tensors Invoke(Tensors inputs)


Loading…
Cancel
Save