Browse Source

Fix MapDataset.

tags/keras_v0.3.0
Oceania2018 Haiping 4 years ago
parent
commit
400cde2ce6
6 changed files with 29 additions and 5 deletions
  1. +18
    -1
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  2. +2
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  3. +3
    -3
      src/TensorFlowNET.Core/Data/MapDataset.cs
  4. +3
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  5. +2
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
  6. +1
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs

+ 18
- 1
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -3,6 +3,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -98,6 +99,20 @@ namespace Tensorflow
return dataset;
}

public Tensor dataset_cardinality(string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DatasetCardinality", name,
null,
variant_tensor);
return results[0];
}

throw new NotImplementedException("");
}

public override string ToString()
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";

@@ -117,7 +132,9 @@ namespace Tensorflow
break;
}

yield return (results[0], results.Length == 1 ? null : results[1]);
yield return results.Length == 2
? (results[0], results[1])
: (null, results[0]);
}
}



+ 2
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -74,5 +74,7 @@ namespace Tensorflow
/// </summary>
/// <returns></returns>
IDatasetV2 apply_options();

Tensor dataset_cardinality(string name = null);
}
}

+ 3
- 3
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -15,11 +15,11 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
var input = tf.placeholder(input_dataset.element_spec[0].dtype, name: "input");
using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
var input = tf.placeholder(input_dataset.element_spec[0].dtype);
var output = map_func(input);
func.ToGraph(input, output);
structure = func.OutputStructure;

variant_tensor = ops.map_dataset(input_dataset.variant_tensor,


+ 3
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -130,6 +130,9 @@ namespace Tensorflow.Functions
return new ForwardBackwardCall(functions, args, tape_watching: true);
}

public override string ToString()
=> Name;

public void Dispose()
{
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);


src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs → src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs View File

@@ -2,10 +2,11 @@

namespace Tensorflow.Keras.ArgsDefinition
{
public class TensorLikeDataAdapterArgs
public class DataAdapterArgs
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }
public int Epochs { get; set; }

+ 1
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs View File

@@ -6,6 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;
public int InitialEpoch { get; set; } = 0;


Loading…
Cancel
Save