Browse Source

Support the multiple inputs of keras model.fit.

tags/v0.100.5-BERT-load
Yaohui Liu Haiping 2 years ago
parent
commit
404c803ce2
21 changed files with 343 additions and 60 deletions
  1. +6
    -2
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  2. +3
    -1
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  3. +3
    -2
      src/TensorFlowNET.Core/Data/OwnedIterator.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
  6. +11
    -0
      src/TensorFlowNET.Core/Keras/Engine/IModel.cs
  7. +70
    -1
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  8. +1
    -1
      src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
  9. +3
    -0
      src/TensorFlowNET.Core/Numpy/NDArray.cs
  10. +34
    -0
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  11. +12
    -6
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
  12. +5
    -1
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  13. +2
    -2
      src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
  14. +8
    -5
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  15. +61
    -4
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  16. +10
    -1
      src/TensorFlowNET.Keras/Engine/Model.Train.cs
  17. +30
    -0
      test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs
  18. +69
    -0
      test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
  19. +1
    -0
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
  20. +1
    -21
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
  21. +9
    -9
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 6
- 2
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -19,6 +19,8 @@ namespace Tensorflow

public TensorSpec[] structure { get; set; }

public int FirstInputTensorCount { get; set; } = 1;

public Shape[] output_shapes => structure.Select(x => x.shape).ToArray();

public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
@@ -131,6 +133,7 @@ namespace Tensorflow

// (4) Apply stats aggregator options

dataset.FirstInputTensorCount = this.FirstInputTensorCount;
return dataset;
}

@@ -142,7 +145,7 @@ namespace Tensorflow
$"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " +
$"len: {length}";

public IEnumerator<(Tensor, Tensor)> GetEnumerator()
public IEnumerator<(Tensors, Tensors)> GetEnumerator()
{
using var ownedIterator = new OwnedIterator(this);

@@ -158,7 +161,8 @@ namespace Tensorflow
break;
}

yield return (results[0], results.Length == 1 ? null : results[1]);
yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount)));
}
}



+ 3
- 1
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -4,7 +4,7 @@ using Tensorflow.Framework.Models;

namespace Tensorflow
{
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)>
{
string[] class_names { get; set; }

@@ -18,6 +18,8 @@ namespace Tensorflow

TensorSpec[] structure { get; set; }

int FirstInputTensorCount { get; set; }

/// <summary>
/// Caches the elements in this dataset.
/// </summary>


+ 3
- 2
src/TensorFlowNET.Core/Data/OwnedIterator.cs View File

@@ -27,7 +27,8 @@ namespace Tensorflow
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
// TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
}

@@ -48,7 +49,7 @@ namespace Tensorflow

public void Dispose()
{
tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
//tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs View File

@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataAdapterArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public Tensors X { get; set; }
public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }


+ 2
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs View File

@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataHandlerArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public Tensors X { get; set; }
public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;


+ 11
- 0
src/TensorFlowNET.Core/Keras/Engine/IModel.cs View File

@@ -24,6 +24,17 @@ public interface IModel : ILayer
int workers = 1,
bool use_multiprocessing = false);

ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,


+ 70
- 1
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -14,7 +14,76 @@ namespace Tensorflow.NumPy
red = data[2];
}

public static implicit operator NDArray(Array array)
public static implicit operator NDArray(int[] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[] array)
=> new NDArray(array);

public static implicit operator NDArray(float[] array)
=> new NDArray(array);

public static implicit operator NDArray(double[] array)
=> new NDArray(array);

public static implicit operator NDArray(long[] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[] array)
=> new NDArray(array);

public static implicit operator NDArray(int[,] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[,] array)
=> new NDArray(array);

public static implicit operator NDArray(float[,] array)
=> new NDArray(array);

public static implicit operator NDArray(double[,] array)
=> new NDArray(array);

public static implicit operator NDArray(long[,] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[,] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[,] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[,] array)
=> new NDArray(array);

public static implicit operator NDArray(int[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(byte[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(float[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(double[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(long[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(bool[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(uint[,,] array)
=> new NDArray(array);

public static implicit operator NDArray(ulong[,,] array)
=> new NDArray(array);

public unsafe static implicit operator bool(NDArray nd)


+ 1
- 1
src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs View File

@@ -25,7 +25,7 @@ public class NpzDictionary
return array;

using var s = entry.Open();
return LoadMatrix(s);
return (NDArray)LoadMatrix(s);
}

public Array LoadMatrix(Stream stream)


+ 3
- 0
src/TensorFlowNET.Core/Numpy/NDArray.cs View File

@@ -49,5 +49,8 @@ namespace Tensorflow.NumPy

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();

public static explicit operator NDArray(Array array)
=> new NDArray(array);
}
}

+ 34
- 0
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -1,6 +1,9 @@
using System;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -220,6 +223,37 @@ namespace Tensorflow
return (results[0], results[1]);
}

public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null)
{
var ctx = tf.Context;
Dictionary<string, object> attrs = new();
attrs["output_types"] = output_types;
attrs["output_shapes"] = output_shapes;
if (ctx.executing_eagerly())
{
try
{
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name)
{
attrs = attrs
});
return result[0];
}
catch (Exception)
{
return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx);
}
}
return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0];
}

public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx)
{
object[] attrs = new object[] { output_types, output_shapes };
var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name);
return result[0];
}

/// <summary>
/// Makes a new iterator from the given `dataset` and stores it in `iterator`.
/// </summary>


+ 12
- 6
src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
protected DataAdapterArgs args;
protected IDatasetV2 dataset;

public virtual bool CanHandle(Tensor x, Tensor y = null)
public virtual bool CanHandle(Tensors x, Tensors y = null)
=> throw new NotImplementedException();

public virtual IDatasetV2 GetDataset()
@@ -19,12 +19,18 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public virtual int GetSize()
=> throw new NotImplementedException("");

public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y)
{
if (x.shape.ndim == 1)
x = array_ops.expand_dims(x, axis: -1);
if (y.shape.ndim == 1)
y = array_ops.expand_dims(y, axis: -1);
for(int i = 0; i < x.Length; i++)
{
if (x[i].shape.ndim == 1)
x[i] = array_ops.expand_dims(x[i], axis: -1);
}
for (int i = 0; i < y.Length; i++)
{
if (y[i].shape.ndim == 1)
y[i] = array_ops.expand_dims(y[i], axis: -1);
}
return (x, y);
}



+ 5
- 1
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -93,11 +93,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters

public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
{
var data_iterator = new OwnedIterator(_dataset);
foreach (var epoch in range(_initial_epoch, _epochs))
{
if (_insufficient_data)
break;
using var data_iterator = new OwnedIterator(_dataset);
if (_adapter.ShouldRecreateIterator())
{
data_iterator = new OwnedIterator(_dataset);
}
yield return (epoch, data_iterator);
}
// _adapter.on_epoch_end()


+ 2
- 2
src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs View File

@@ -13,10 +13,10 @@
/// <param name="x">input features</param>
/// <param name="y">target labels</param>
/// <returns></returns>
bool CanHandle(Tensor x, Tensor y = null);
bool CanHandle(Tensors x, Tensors y = null);
IDatasetV2 GetDataset();
int GetSize();
(Tensor, Tensor) Expand1d(Tensor x, Tensor y);
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
bool ShouldRecreateIterator();
}
}

+ 8
- 5
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -1,4 +1,5 @@
using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
@@ -20,7 +21,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
{
this.args = args;
_process_tensorlike();
num_samples = (int)args.X.shape[0];
num_samples = (int)args.X[0].shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
@@ -33,10 +34,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
var inputs = new Tensors();
if (args.X != null)
inputs.Add(args.X);
inputs.AddRange(args.X);
if (args.Y != null)
inputs.Add(args.Y);
inputs.AddRange(args.Y);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}

Tensors permutation(Tensors tensor)
@@ -87,8 +89,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return dataset.with_options(new DatasetOptions { });
}

public override int GetSize()
=> _size;
public override int GetSize() => _size;

public override bool ShouldRecreateIterator() => false;

void _process_tensorlike()
{


+ 61
- 4
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -59,7 +59,62 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

return FitInternal(data_handler, epochs, verbose);
return FitInternal(data_handler, epochs, verbose, validation_data: null,
train_step_func: train_step_function);
}

public ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false)
{
foreach(var tx in x)
{
if (tx.dims[0] != y.dims[0])
{
throw new InvalidArgumentError(
$"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
}
}
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor);
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as Tensor);
var val_y = y[new Slice(train_count)];

var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(train_x),
Y = train_y,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Model = this,
StepsPerExecution = _steps_per_execution
});

if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
return FitInternal(data_handler, epochs, verbose, validation_data: null,
train_step_func: train_step_multi_inputs_function);
}
else
{
return FitInternal(data_handler, epochs, verbose, validation_data: null,
train_step_func: train_step_function);
}
}

public History fit(IDatasetV2 dataset,
@@ -88,10 +143,12 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

return FitInternal(data_handler, epochs, verbose, validation_data: validation_data);
return FitInternal(data_handler, epochs, verbose, validation_data: validation_data,
train_step_func: train_step_function);
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null)
History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
_train_counter.assign(0);
@@ -113,7 +170,7 @@ namespace Tensorflow.Keras.Engine
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
logs = train_step_function(data_handler, iterator);
logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
callbacks.on_train_batch_end(end_step, logs);
}


+ 10
- 1
src/TensorFlowNET.Keras/Engine/Model.Train.cs View File

@@ -17,12 +17,21 @@ namespace Tensorflow.Keras.Engine
return outputs;
}

Dictionary<string, float> train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}

/// <summary>
/// The logic for one training step.
/// </summary>
/// <param name="data"></param>
/// <returns></returns>
Dictionary<string, float> train_step(DataHandler data_handler, Tensor x, Tensor y)
Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();


+ 30
- 0
test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs View File

@@ -0,0 +1,30 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.UnitTest.Helpers
{
public class RandomDataSet : DataSetBase
{
private Shape _shape;

public RandomDataSet(Shape shape, int count)
{
_shape = shape;
Debug.Assert(_shape.ndim == 3);
long[] dims = new long[4];
dims[0] = count;
for (int i = 1; i < 4; i++)
{
dims[i] = _shape[i - 1];
}
Shape s = new Shape(dims);
Data = np.random.normal(0, 2, s);
Labels = np.random.uniform(0, 1, (count, 1));
}
}
}

+ 69
- 0
test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs View File

@@ -0,0 +1,69 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Xml.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
using Microsoft.VisualBasic;
using static HDF.PInvoke.H5T;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.Keras.Optimizers;

namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class MultiInputModelTest
{
[TestMethod]
public void SimpleModel()
{
var inputs = keras.Input((28, 28, 1));
var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1);
var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1);
var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2);
var flat1 = keras.layers.Flatten().Apply(pool2);

var inputs_2 = keras.Input((28, 28, 1));
var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2);
var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2);
var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2);
var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2);
var flat1_2 = keras.layers.Flatten().Apply(pool2_2);

var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
var output = keras.layers.Softmax(-1).Apply(dense3);

var model = keras.Model((inputs, inputs_2), output);
model.summary();

var data_loader = new MnistModelLoader();

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 59000,
}).Result;

var loss = keras.losses.SparseCategoricalCrossentropy();
var optimizer = new Adam(0.001f);
model.compile(optimizer, loss, new string[] { "accuracy" });

NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
NDArray x2 = x1;

var x = new NDArray[] { x1, x2 };
model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);
}
}
}

+ 1
- 0
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -13,6 +13,7 @@ using Tensorflow;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
using Tensorflow.Keras.UnitTest.Helpers;
using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;


+ 1
- 21
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs View File

@@ -6,7 +6,7 @@ using Tensorflow.Keras;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
using Tensorflow.NumPy;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

@@ -175,24 +175,4 @@ public class SequentialModelSave
// )
#endregion
}

public class RandomDataSet : DataSetBase
{
private Shape _shape;

public RandomDataSet(Shape shape, int count)
{
_shape = shape;
Debug.Assert(_shape.ndim == 3);
long[] dims = new long[4];
dims[0] = count;
for (int i = 1; i < 4; i++)
{
dims[i] = _shape[i - 1];
}
Shape s = new Shape(dims);
Data = np.random.normal(0, 2, s);
Labels = np.random.uniform(0, 1, (count, 1));
}
}
}

+ 9
- 9
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset
Assert.AreEqual(iStep, step);
iStep++;

Assert.AreEqual(value, (long)item.Item1);
Assert.AreEqual(value, (long)item.Item1[0]);
value++;
}
}
@@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset
Assert.AreEqual(iStep, step);
iStep++;

Assert.AreEqual(value, (long)item.Item1);
Assert.AreEqual(value, (long)item.Item1[0]);
value += 2;
}
}
@@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset
int n = 0;
foreach (var (item_x, item_y) in dataset)
{
print($"x:{item_x.numpy()},y:{item_y.numpy()}");
print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}");
n += 1;
}
Assert.AreEqual(5, n);
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset
int n = 0;
foreach (var x in dataset)
{
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray<int>()));
n += 1;
}
Assert.AreEqual(1, n);
@@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset2)
{
Assert.AreEqual(value, (long)item.Item1);
Assert.AreEqual(value, (long)item.Item1[0]);
value += 3;
}

@@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset
var dataset3 = dataset1.shard(num_shards: 3, index: 1);
foreach (var item in dataset3)
{
Assert.AreEqual(value, (long)item.Item1);
Assert.AreEqual(value, (long)item.Item1[0]);
value += 3;
}
}
@@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1);
Assert.AreEqual(value, (long)item.Item1[0]);
value++;
}
}
@@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset)
{
Assert.AreEqual(value + 10, (long)item.Item1);
Assert.AreEqual(value + 10, (long)item.Item1[0]);
value++;
}
}
@@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1);
Assert.AreEqual(value, (long)item.Item1[0]);
value++;
}
}


Loading…
Cancel
Save