@@ -3,6 +3,7 @@ using System.Collections; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -98,6 +99,20 @@ namespace Tensorflow | |||||
return dataset; | return dataset; | ||||
} | } | ||||
public Tensor dataset_cardinality(string name = null) | |||||
{ | |||||
if (tf.Context.executing_eagerly()) | |||||
{ | |||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"DatasetCardinality", name, | |||||
null, | |||||
variant_tensor); | |||||
return results[0]; | |||||
} | |||||
throw new NotImplementedException(""); | |||||
} | |||||
public override string ToString() | public override string ToString() | ||||
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; | => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; | ||||
@@ -117,7 +132,9 @@ namespace Tensorflow | |||||
break; | break; | ||||
} | } | ||||
yield return (results[0], results.Length == 1 ? null : results[1]); | |||||
yield return results.Length == 2 | |||||
? (results[0], results[1]) | |||||
: (null, results[0]); | |||||
} | } | ||||
} | } | ||||
@@ -74,5 +74,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
IDatasetV2 apply_options(); | IDatasetV2 apply_options(); | ||||
Tensor dataset_cardinality(string name = null); | |||||
} | } | ||||
} | } |
@@ -15,11 +15,11 @@ namespace Tensorflow | |||||
bool preserve_cardinality = false, | bool preserve_cardinality = false, | ||||
bool use_legacy_function = false) : base(input_dataset) | bool use_legacy_function = false) : base(input_dataset) | ||||
{ | { | ||||
var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); | |||||
var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input"); | |||||
using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); | |||||
var input = tf.placeholder(input_dataset.element_spec[0].dtype); | |||||
var output = map_func(input); | var output = map_func(input); | ||||
func.ToGraph(input, output); | func.ToGraph(input, output); | ||||
structure = func.OutputStructure; | structure = func.OutputStructure; | ||||
variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | ||||
@@ -130,6 +130,9 @@ namespace Tensorflow.Functions | |||||
return new ForwardBackwardCall(functions, args, tape_watching: true); | return new ForwardBackwardCall(functions, args, tape_watching: true); | ||||
} | } | ||||
public override string ToString() | |||||
=> Name; | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | ||||
@@ -2,10 +2,11 @@ | |||||
namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
{ | { | ||||
public class TensorLikeDataAdapterArgs | |||||
public class DataAdapterArgs | |||||
{ | { | ||||
public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
public IDatasetV2 Dataset { get; set; } | |||||
public int BatchSize { get; set; } = 32; | public int BatchSize { get; set; } = 32; | ||||
public int Steps { get; set; } | public int Steps { get; set; } | ||||
public int Epochs { get; set; } | public int Epochs { get; set; } |
@@ -6,6 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
public IDatasetV2 Dataset { get; set; } | |||||
public int BatchSize { get; set; } = 32; | public int BatchSize { get; set; } = 32; | ||||
public int StepsPerEpoch { get; set; } = -1; | public int StepsPerEpoch { get; set; } = -1; | ||||
public int InitialEpoch { get; set; } = 0; | public int InitialEpoch { get; set; } = 0; | ||||