Browse Source

Fix input dtype for MapDataset. #666

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
d8afa8cefb
2 changed files with 9 additions and 1 deletions
  1. +7
    -1
      src/TensorFlowNET.Core/Data/MapDataset.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

+ 7
- 1
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using Tensorflow.Functions; using Tensorflow.Functions;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
@@ -14,7 +15,12 @@ namespace Tensorflow
bool preserve_cardinality = false, bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset) bool use_legacy_function = false) : base(input_dataset)
{ {
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input");
var output = map_func(input);
func.ToGraph(input, output);

structure = func.OutputStructure;


variant_tensor = ops.map_dataset(input_dataset.variant_tensor, variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
func, func,


+ 2
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -109,6 +109,8 @@ namespace Tensorflow.Functions
inputs, inputs,
outputs, outputs,
null); null);

OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
} }


public Tensors Invoke(Tensors inputs) public Tensors Invoke(Tensors inputs)


Loading…
Cancel
Save