From 404c803ce23aa55bb400a118a7ab202a16fbffba Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sat, 4 Mar 2023 21:41:25 +0800 Subject: [PATCH] Support the multiple inputs of keras model.fit. --- src/TensorFlowNET.Core/Data/DatasetV2.cs | 8 ++- src/TensorFlowNET.Core/Data/IDatasetV2.cs | 4 +- src/TensorFlowNET.Core/Data/OwnedIterator.cs | 5 +- .../Keras/ArgsDefinition/DataAdapterArgs.cs | 4 +- .../Keras/ArgsDefinition/DataHandlerArgs.cs | 4 +- src/TensorFlowNET.Core/Keras/Engine/IModel.cs | 11 +++ .../NumPy/NDArray.Implicit.cs | 71 ++++++++++++++++++- .../NumPy/Persistence/NpzDictionaryArray.cs | 2 +- src/TensorFlowNET.Core/Numpy/NDArray.cs | 3 + .../Operations/dataset_ops.cs | 34 +++++++++ .../Engine/DataAdapters/DataAdapter.cs | 18 +++-- .../Engine/DataAdapters/DataHandler.cs | 6 +- .../Engine/DataAdapters/IDataAdapter.cs | 4 +- .../DataAdapters/TensorLikeDataAdapter.cs | 13 ++-- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 65 +++++++++++++++-- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 11 ++- .../Helpers/RandomDataset.cs | 30 ++++++++ .../MultiInputModelTest.cs | 69 ++++++++++++++++++ .../SaveModel/SequentialModelLoad.cs | 1 + .../SaveModel/SequentialModelSave.cs | 22 +----- .../Dataset/DatasetTest.cs | 18 ++--- 21 files changed, 343 insertions(+), 60 deletions(-) create mode 100644 test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs create mode 100644 test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 103d7cff..324d7e83 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -19,6 +19,8 @@ namespace Tensorflow public TensorSpec[] structure { get; set; } + public int FirstInputTensorCount { get; set; } = 1; + public Shape[] output_shapes => structure.Select(x => x.shape).ToArray(); public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray(); @@ -131,6 +133,7 @@ namespace Tensorflow // (4) Apply stats aggregator options + dataset.FirstInputTensorCount = this.FirstInputTensorCount; return dataset; } @@ -142,7 +145,7 @@ namespace Tensorflow $"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " + $"len: {length}"; - public IEnumerator<(Tensor, Tensor)> GetEnumerator() + public IEnumerator<(Tensors, Tensors)> GetEnumerator() { using var ownedIterator = new OwnedIterator(this); @@ -158,7 +161,8 @@ namespace Tensorflow break; } - yield return (results[0], results.Length == 1 ? null : results[1]); + yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? + null : new Tensors(results.Skip(FirstInputTensorCount))); } } diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index 5cfeb27c..320cbe34 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -4,7 +4,7 @@ using Tensorflow.Framework.Models; namespace Tensorflow { - public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> + public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)> { string[] class_names { get; set; } @@ -18,6 +18,8 @@ namespace Tensorflow TensorSpec[] structure { get; set; } + int FirstInputTensorCount { get; set; } + /// /// Caches the elements in this dataset. /// diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs index eb91272c..1dafc87e 100644 --- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs +++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs @@ -27,7 +27,8 @@ namespace Tensorflow _dataset = dataset; _element_spec = dataset.element_spec; // _flat_output_types = - (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes); + _iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes); + // TODO(Rinne): deal with graph mode. ops.make_iterator(dataset.variant_tensor, _iterator_resource); } @@ -48,7 +49,7 @@ namespace Tensorflow public void Dispose() { - tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null); + //tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null); } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index 8ce1ec65..78882e82 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition { public class DataAdapterArgs: IKerasConfig { - public Tensor X { get; set; } - public Tensor Y { get; set; } + public Tensors X { get; set; } + public Tensors Y { get; set; } public IDatasetV2 Dataset { get; set; } public int BatchSize { get; set; } = 32; public int Steps { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index fd603a85..82530e95 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition { public class DataHandlerArgs: IKerasConfig { - public Tensor X { get; set; } - public Tensor Y { get; set; } + public Tensors X { get; set; } + public Tensors Y { get; set; } public IDatasetV2 Dataset { get; set; } public int BatchSize { get; set; } = 32; public int StepsPerEpoch { get; set; } = -1; diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs index 8bcfcbbb..e02642dc 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs @@ -24,6 +24,17 @@ public interface IModel : ILayer int workers = 1, bool use_multiprocessing = false); + ICallback fit(IEnumerable x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + float validation_split = 0f, + bool shuffle = true, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false); + void save(string filepath, bool overwrite = true, bool include_optimizer = true, diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs index 53401a44..fd4f93fc 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -14,7 +14,76 @@ namespace Tensorflow.NumPy red = data[2]; } - public static implicit operator NDArray(Array array) + public static implicit operator NDArray(int[] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[] array) + => new NDArray(array); + + public static implicit operator NDArray(float[] array) + => new NDArray(array); + + public static implicit operator NDArray(double[] array) + => new NDArray(array); + + public static implicit operator NDArray(long[] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[] array) + => new NDArray(array); + + public static implicit operator NDArray(int[,] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[,] array) + => new NDArray(array); + + public static implicit operator NDArray(float[,] array) + => new NDArray(array); + + public static implicit operator NDArray(double[,] array) + => new NDArray(array); + + public static implicit operator NDArray(long[,] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[,] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[,] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[,] array) + => new NDArray(array); + + public static implicit operator NDArray(int[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(byte[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(float[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(double[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(long[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(bool[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(uint[,,] array) + => new NDArray(array); + + public static implicit operator NDArray(ulong[,,] array) => new NDArray(array); public unsafe static implicit operator bool(NDArray nd) diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs index 6e81216e..ba7868fa 100644 --- a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs +++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs @@ -25,7 +25,7 @@ public class NpzDictionary return array; using var s = entry.Open(); - return LoadMatrix(s); + return (NDArray)LoadMatrix(s); } public Array LoadMatrix(Stream stream) diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 3a2cb3ee..6e4c6b32 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -49,5 +49,8 @@ namespace Tensorflow.NumPy IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + public static explicit operator NDArray(Array array) + => new NDArray(array); } } diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 9407fd5a..c7e62777 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -1,6 +1,9 @@ using System; +using Tensorflow.Contexts; +using Tensorflow.Eager; using Tensorflow.Framework.Models; using Tensorflow.Functions; +using Tensorflow.Operations; using static Tensorflow.Binding; namespace Tensorflow @@ -220,6 +223,37 @@ namespace Tensorflow return (results[0], results[1]); } + public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null) + { + var ctx = tf.Context; + Dictionary attrs = new(); + attrs["output_types"] = output_types; + attrs["output_shapes"] = output_shapes; + if (ctx.executing_eagerly()) + { + try + { + var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name) + { + attrs = attrs + }); + return result[0]; + } + catch (Exception) + { + return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx); + } + } + return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0]; + } + + public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx) + { + object[] attrs = new object[] { output_types, output_shapes }; + var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name); + return result[0]; + } + /// /// Makes a new iterator from the given `dataset` and stores it in `iterator`. /// diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs index 3314f5c4..6c7d53b2 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters protected DataAdapterArgs args; protected IDatasetV2 dataset; - public virtual bool CanHandle(Tensor x, Tensor y = null) + public virtual bool CanHandle(Tensors x, Tensors y = null) => throw new NotImplementedException(); public virtual IDatasetV2 GetDataset() @@ -19,12 +19,18 @@ namespace Tensorflow.Keras.Engine.DataAdapters public virtual int GetSize() => throw new NotImplementedException(""); - public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y) + public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y) { - if (x.shape.ndim == 1) - x = array_ops.expand_dims(x, axis: -1); - if (y.shape.ndim == 1) - y = array_ops.expand_dims(y, axis: -1); + for(int i = 0; i < x.Length; i++) + { + if (x[i].shape.ndim == 1) + x[i] = array_ops.expand_dims(x[i], axis: -1); + } + for (int i = 0; i < y.Length; i++) + { + if (y[i].shape.ndim == 1) + y[i] = array_ops.expand_dims(y[i], axis: -1); + } return (x, y); } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index 1ddddd11..4723222f 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -93,11 +93,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters public IEnumerable<(int, OwnedIterator)> enumerate_epochs() { + var data_iterator = new OwnedIterator(_dataset); foreach (var epoch in range(_initial_epoch, _epochs)) { if (_insufficient_data) break; - using var data_iterator = new OwnedIterator(_dataset); + if (_adapter.ShouldRecreateIterator()) + { + data_iterator = new OwnedIterator(_dataset); + } yield return (epoch, data_iterator); } // _adapter.on_epoch_end() diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs index df414b9f..4bdc4979 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs @@ -13,10 +13,10 @@ /// input features /// target labels /// - bool CanHandle(Tensor x, Tensor y = null); + bool CanHandle(Tensors x, Tensors y = null); IDatasetV2 GetDataset(); int GetSize(); - (Tensor, Tensor) Expand1d(Tensor x, Tensor y); + (Tensors, Tensors) Expand1d(Tensors x, Tensors y); bool ShouldRecreateIterator(); } } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index fc61aa71..f53c67c4 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Linq; using Tensorflow.Keras.ArgsDefinition; using static Tensorflow.Binding; @@ -20,7 +21,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters { this.args = args; _process_tensorlike(); - num_samples = (int)args.X.shape[0]; + num_samples = (int)args.X[0].shape[0]; var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; _batch_size = batch_size; _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); @@ -33,10 +34,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters indices_dataset = indices_dataset.flat_map(slice_batch_indices); var inputs = new Tensors(); if (args.X != null) - inputs.Add(args.X); + inputs.AddRange(args.X); if (args.Y != null) - inputs.Add(args.Y); + inputs.AddRange(args.Y); dataset = slice_inputs(indices_dataset, inputs); + dataset.FirstInputTensorCount = args.X.Length; } Tensors permutation(Tensors tensor) @@ -87,8 +89,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters return dataset.with_options(new DatasetOptions { }); } - public override int GetSize() - => _size; + public override int GetSize() => _size; + + public override bool ShouldRecreateIterator() => false; void _process_tensorlike() { diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 1ebd56d3..39004183 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -59,7 +59,62 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - return FitInternal(data_handler, epochs, verbose); + return FitInternal(data_handler, epochs, verbose, validation_data: null, + train_step_func: train_step_function); + } + + public ICallback fit(IEnumerable x, NDArray y, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + float validation_split = 0f, + bool shuffle = true, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + foreach(var tx in x) + { + if (tx.dims[0] != y.dims[0]) + { + throw new InvalidArgumentError( + $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); + } + } + int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); + + var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor); + var train_y = y[new Slice(0, train_count)]; + var val_x = x.Select(x => x[new Slice(train_count)] as Tensor); + var val_y = y[new Slice(train_count)]; + + var data_handler = new DataHandler(new DataHandlerArgs + { + X = new Tensors(train_x), + Y = train_y, + BatchSize = batch_size, + InitialEpoch = initial_epoch, + Epochs = epochs, + Shuffle = shuffle, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || + data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) + { + return FitInternal(data_handler, epochs, verbose, validation_data: null, + train_step_func: train_step_multi_inputs_function); + } + else + { + return FitInternal(data_handler, epochs, verbose, validation_data: null, + train_step_func: train_step_function); + } } public History fit(IDatasetV2 dataset, @@ -88,10 +143,12 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - return FitInternal(data_handler, epochs, verbose, validation_data: validation_data); + return FitInternal(data_handler, epochs, verbose, validation_data: validation_data, + train_step_func: train_step_function); } - History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null) + History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data, + Func> train_step_func) { stop_training = false; _train_counter.assign(0); @@ -113,7 +170,7 @@ namespace Tensorflow.Keras.Engine foreach (var step in data_handler.steps()) { callbacks.on_train_batch_begin(step); - logs = train_step_function(data_handler, iterator); + logs = train_step_func(data_handler, iterator); var end_step = step + data_handler.StepIncrement; callbacks.on_train_batch_end(end_step, logs); } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index 8d85d70d..d8171e2a 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -17,12 +17,21 @@ namespace Tensorflow.Keras.Engine return outputs; } + Dictionary train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) + { + var data = iterator.next(); + var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; + var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); + tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); + return outputs; + } + /// /// The logic for one training step. /// /// /// - Dictionary train_step(DataHandler data_handler, Tensor x, Tensor y) + Dictionary train_step(DataHandler data_handler, Tensors x, Tensors y) { (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); diff --git a/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs new file mode 100644 index 00000000..e145ce58 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow.NumPy; + +namespace Tensorflow.Keras.UnitTest.Helpers +{ + public class RandomDataSet : DataSetBase + { + private Shape _shape; + + public RandomDataSet(Shape shape, int count) + { + _shape = shape; + Debug.Assert(_shape.ndim == 3); + long[] dims = new long[4]; + dims[0] = count; + for (int i = 1; i < 4; i++) + { + dims[i] = _shape[i - 1]; + } + Shape s = new Shape(dims); + Data = np.random.normal(0, 2, s); + Labels = np.random.uniform(0, 1, (count, 1)); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs new file mode 100644 index 00000000..490178bc --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs @@ -0,0 +1,69 @@ +using Microsoft.VisualStudio.TestPlatform.Utilities; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using System.Xml.Linq; +using Tensorflow.Operations; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow.NumPy; +using Microsoft.VisualBasic; +using static HDF.PInvoke.H5T; +using Tensorflow.Keras.UnitTest.Helpers; +using Tensorflow.Keras.Optimizers; + +namespace Tensorflow.Keras.UnitTest +{ + [TestClass] + public class MultiInputModelTest + { + [TestMethod] + public void SimpleModel() + { + var inputs = keras.Input((28, 28, 1)); + var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs); + var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1); + var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1); + var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2); + var flat1 = keras.layers.Flatten().Apply(pool2); + + var inputs_2 = keras.Input((28, 28, 1)); + var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2); + var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2); + var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2); + var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2); + var flat1_2 = keras.layers.Flatten().Apply(pool2_2); + + var concat = keras.layers.Concatenate().Apply((flat1, flat1_2)); + var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat); + var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1); + var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2); + var output = keras.layers.Softmax(-1).Apply(dense3); + + var model = keras.Model((inputs, inputs_2), output); + model.summary(); + + var data_loader = new MnistModelLoader(); + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 59000, + }).Result; + + var loss = keras.losses.SparseCategoricalCrossentropy(); + var optimizer = new Adam(0.001f); + model.compile(optimizer, loss, new string[] { "accuracy" }); + + NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1)); + NDArray x2 = x1; + + var x = new NDArray[] { x1, x2 }; + model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3); + } + } +} diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index e778a5a4..385ec0f7 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -13,6 +13,7 @@ using Tensorflow; using Tensorflow.Keras.Optimizers; using static Tensorflow.KerasApi; using Tensorflow.NumPy; +using Tensorflow.Keras.UnitTest.Helpers; using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave; namespace TensorFlowNET.Keras.UnitTest.SaveModel; diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs index 15823b9f..1cf68d3b 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs @@ -6,7 +6,7 @@ using Tensorflow.Keras; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; -using Tensorflow.NumPy; +using Tensorflow.Keras.UnitTest.Helpers; using static Tensorflow.Binding; using static Tensorflow.KerasApi; @@ -175,24 +175,4 @@ public class SequentialModelSave // ) #endregion } - - public class RandomDataSet : DataSetBase - { - private Shape _shape; - - public RandomDataSet(Shape shape, int count) - { - _shape = shape; - Debug.Assert(_shape.ndim == 3); - long[] dims = new long[4]; - dims[0] = count; - for (int i = 1; i < 4; i++) - { - dims[i] = _shape[i - 1]; - } - Shape s = new Shape(dims); - Data = np.random.normal(0, 2, s); - Labels = np.random.uniform(0, 1, (count, 1)); - } - } } \ No newline at end of file diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 8317346e..01f35a41 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset Assert.AreEqual(iStep, step); iStep++; - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value++; } } @@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset Assert.AreEqual(iStep, step); iStep++; - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value += 2; } } @@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset int n = 0; foreach (var (item_x, item_y) in dataset) { - print($"x:{item_x.numpy()},y:{item_y.numpy()}"); + print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}"); n += 1; } Assert.AreEqual(5, n); @@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset int n = 0; foreach (var x in dataset) { - Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray())); + Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray())); n += 1; } Assert.AreEqual(1, n); @@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset2) { - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value += 3; } @@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset var dataset3 = dataset1.shard(num_shards: 3, index: 1); foreach (var item in dataset3) { - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value += 3; } } @@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset) { - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value++; } } @@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset) { - Assert.AreEqual(value + 10, (long)item.Item1); + Assert.AreEqual(value + 10, (long)item.Item1[0]); value++; } } @@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset) { - Assert.AreEqual(value, (long)item.Item1); + Assert.AreEqual(value, (long)item.Item1[0]); value++; } }