|
@@ -1,5 +1,6 @@ |
|
|
using System; |
|
|
using System; |
|
|
using Tensorflow.Functions; |
|
|
using Tensorflow.Functions; |
|
|
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
namespace Tensorflow |
|
|
{ |
|
|
{ |
|
@@ -14,7 +15,12 @@ namespace Tensorflow |
|
|
bool preserve_cardinality = false, |
|
|
bool preserve_cardinality = false, |
|
|
bool use_legacy_function = false) : base(input_dataset) |
|
|
bool use_legacy_function = false) : base(input_dataset) |
|
|
{ |
|
|
{ |
|
|
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); |
|
|
|
|
|
|
|
|
var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); |
|
|
|
|
|
var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input"); |
|
|
|
|
|
var output = map_func(input); |
|
|
|
|
|
func.ToGraph(input, output); |
|
|
|
|
|
|
|
|
|
|
|
structure = func.OutputStructure; |
|
|
|
|
|
|
|
|
variant_tensor = ops.map_dataset(input_dataset.variant_tensor, |
|
|
variant_tensor = ops.map_dataset(input_dataset.variant_tensor, |
|
|
func, |
|
|
func, |
|
|