From 400cde2ce6f8bc045c41a11abe836bf75266a83e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 19 Dec 2020 08:32:12 -0600 Subject: [PATCH] Fix MapDataset. --- src/TensorFlowNET.Core/Data/DatasetV2.cs | 19 ++++++++++++++++++- src/TensorFlowNET.Core/Data/IDatasetV2.cs | 2 ++ src/TensorFlowNET.Core/Data/MapDataset.cs | 6 +++--- .../Functions/ConcreteFunction.cs | 3 +++ ...eDataAdapterArgs.cs => DataAdapterArgs.cs} | 3 ++- .../Keras/ArgsDefinition/DataHandlerArgs.cs | 1 + 6 files changed, 29 insertions(+), 5 deletions(-) rename src/TensorFlowNET.Core/Keras/ArgsDefinition/{TensorLikeDataAdapterArgs.cs => DataAdapterArgs.cs} (86%) diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 104789df..763baa31 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -3,6 +3,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; using Tensorflow.Framework.Models; +using static Tensorflow.Binding; namespace Tensorflow { @@ -98,6 +99,20 @@ namespace Tensorflow return dataset; } + public Tensor dataset_cardinality(string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "DatasetCardinality", name, + null, + variant_tensor); + return results[0]; + } + + throw new NotImplementedException(""); + } + public override string ToString() => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; @@ -117,7 +132,9 @@ namespace Tensorflow break; } - yield return (results[0], results.Length == 1 ? null : results[1]); + yield return results.Length == 2 + ? (results[0], results[1]) + : (null, results[0]); } } diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index fc47c832..9a31ff51 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -74,5 +74,7 @@ namespace Tensorflow /// /// IDatasetV2 apply_options(); + + Tensor dataset_cardinality(string name = null); } } diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs index 3ea54233..c593322b 100644 --- a/src/TensorFlowNET.Core/Data/MapDataset.cs +++ b/src/TensorFlowNET.Core/Data/MapDataset.cs @@ -15,11 +15,11 @@ namespace Tensorflow bool preserve_cardinality = false, bool use_legacy_function = false) : base(input_dataset) { - var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); - var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input"); + using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); + var input = tf.placeholder(input_dataset.element_spec[0].dtype); var output = map_func(input); func.ToGraph(input, output); - + structure = func.OutputStructure; variant_tensor = ops.map_dataset(input_dataset.variant_tensor, diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index e4754860..90cb0494 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -130,6 +130,9 @@ namespace Tensorflow.Functions return new ForwardBackwardCall(functions, args, tape_watching: true); } + public override string ToString() + => Name; + public void Dispose() { c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs similarity index 86% rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index 921a4726..f3cca438 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -2,10 +2,11 @@ namespace Tensorflow.Keras.ArgsDefinition { - public class TensorLikeDataAdapterArgs + public class DataAdapterArgs { public Tensor X { get; set; } public Tensor Y { get; set; } + public IDatasetV2 Dataset { get; set; } public int BatchSize { get; set; } = 32; public int Steps { get; set; } public int Epochs { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index 63de54ad..b6e6849b 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -6,6 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition { public Tensor X { get; set; } public Tensor Y { get; set; } + public IDatasetV2 Dataset { get; set; } public int BatchSize { get; set; } = 32; public int StepsPerEpoch { get; set; } = -1; public int InitialEpoch { get; set; } = 0;