diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 103d7cff..324d7e83 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -19,6 +19,8 @@ namespace Tensorflow
public TensorSpec[] structure { get; set; }
+ public int FirstInputTensorCount { get; set; } = 1;
+
public Shape[] output_shapes => structure.Select(x => x.shape).ToArray();
public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
@@ -131,6 +133,7 @@ namespace Tensorflow
// (4) Apply stats aggregator options
+ dataset.FirstInputTensorCount = this.FirstInputTensorCount;
return dataset;
}
@@ -142,7 +145,7 @@ namespace Tensorflow
$"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " +
$"len: {length}";
- public IEnumerator<(Tensor, Tensor)> GetEnumerator()
+ public IEnumerator<(Tensors, Tensors)> GetEnumerator()
{
using var ownedIterator = new OwnedIterator(this);
@@ -158,7 +161,8 @@ namespace Tensorflow
break;
}
- yield return (results[0], results.Length == 1 ? null : results[1]);
+ yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
+ null : new Tensors(results.Skip(FirstInputTensorCount)));
}
}
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index 5cfeb27c..320cbe34 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -4,7 +4,7 @@ using Tensorflow.Framework.Models;
namespace Tensorflow
{
- public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
+ public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)>
{
string[] class_names { get; set; }
@@ -18,6 +18,8 @@ namespace Tensorflow
TensorSpec[] structure { get; set; }
+ int FirstInputTensorCount { get; set; }
+
///
/// Caches the elements in this dataset.
///
diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
index eb91272c..1dafc87e 100644
--- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs
+++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
@@ -27,7 +27,8 @@ namespace Tensorflow
_dataset = dataset;
_element_spec = dataset.element_spec;
// _flat_output_types =
- (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
+ _iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
+ // TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource);
}
@@ -48,7 +49,7 @@ namespace Tensorflow
public void Dispose()
{
- tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
+ //tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null);
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
index 8ce1ec65..78882e82 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataAdapterArgs: IKerasConfig
{
- public Tensor X { get; set; }
- public Tensor Y { get; set; }
+ public Tensors X { get; set; }
+ public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
index fd603a85..82530e95 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class DataHandlerArgs: IKerasConfig
{
- public Tensor X { get; set; }
- public Tensor Y { get; set; }
+ public Tensors X { get; set; }
+ public Tensors Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index 8bcfcbbb..e02642dc 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -24,6 +24,17 @@ public interface IModel : ILayer
int workers = 1,
bool use_multiprocessing = false);
+ ICallback fit(IEnumerable x, NDArray y,
+ int batch_size = -1,
+ int epochs = 1,
+ int verbose = 1,
+ float validation_split = 0f,
+ bool shuffle = true,
+ int initial_epoch = 0,
+ int max_queue_size = 10,
+ int workers = 1,
+ bool use_multiprocessing = false);
+
void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
index 53401a44..fd4f93fc 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
@@ -14,7 +14,76 @@ namespace Tensorflow.NumPy
red = data[2];
}
- public static implicit operator NDArray(Array array)
+ public static implicit operator NDArray(int[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(byte[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(float[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(double[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(long[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(bool[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(uint[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(ulong[] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(int[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(byte[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(float[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(double[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(long[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(bool[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(uint[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(ulong[,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(int[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(byte[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(float[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(double[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(long[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(bool[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(uint[,,] array)
+ => new NDArray(array);
+
+ public static implicit operator NDArray(ulong[,,] array)
=> new NDArray(array);
public unsafe static implicit operator bool(NDArray nd)
diff --git a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
index 6e81216e..ba7868fa 100644
--- a/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
+++ b/src/TensorFlowNET.Core/NumPy/Persistence/NpzDictionaryArray.cs
@@ -25,7 +25,7 @@ public class NpzDictionary
return array;
using var s = entry.Open();
- return LoadMatrix(s);
+ return (NDArray)LoadMatrix(s);
}
public Array LoadMatrix(Stream stream)
diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs
index 3a2cb3ee..6e4c6b32 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs
@@ -49,5 +49,8 @@ namespace Tensorflow.NumPy
IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
+
+ public static explicit operator NDArray(Array array)
+ => new NDArray(array);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
index 9407fd5a..c7e62777 100644
--- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs
@@ -1,6 +1,9 @@
using System;
+using Tensorflow.Contexts;
+using Tensorflow.Eager;
using Tensorflow.Framework.Models;
using Tensorflow.Functions;
+using Tensorflow.Operations;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -220,6 +223,37 @@ namespace Tensorflow
return (results[0], results[1]);
}
+ public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null)
+ {
+ var ctx = tf.Context;
+ Dictionary attrs = new();
+ attrs["output_types"] = output_types;
+ attrs["output_shapes"] = output_shapes;
+ if (ctx.executing_eagerly())
+ {
+ try
+ {
+ var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name)
+ {
+ attrs = attrs
+ });
+ return result[0];
+ }
+ catch (Exception)
+ {
+ return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx);
+ }
+ }
+ return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0];
+ }
+
+ public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx)
+ {
+ object[] attrs = new object[] { output_types, output_shapes };
+ var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name);
+ return result[0];
+ }
+
///
/// Makes a new iterator from the given `dataset` and stores it in `iterator`.
///
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
index 3314f5c4..6c7d53b2 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
protected DataAdapterArgs args;
protected IDatasetV2 dataset;
- public virtual bool CanHandle(Tensor x, Tensor y = null)
+ public virtual bool CanHandle(Tensors x, Tensors y = null)
=> throw new NotImplementedException();
public virtual IDatasetV2 GetDataset()
@@ -19,12 +19,18 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public virtual int GetSize()
=> throw new NotImplementedException("");
- public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
+ public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y)
{
- if (x.shape.ndim == 1)
- x = array_ops.expand_dims(x, axis: -1);
- if (y.shape.ndim == 1)
- y = array_ops.expand_dims(y, axis: -1);
+ for(int i = 0; i < x.Length; i++)
+ {
+ if (x[i].shape.ndim == 1)
+ x[i] = array_ops.expand_dims(x[i], axis: -1);
+ }
+ for (int i = 0; i < y.Length; i++)
+ {
+ if (y[i].shape.ndim == 1)
+ y[i] = array_ops.expand_dims(y[i], axis: -1);
+ }
return (x, y);
}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
index 1ddddd11..4723222f 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -93,11 +93,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
{
+ var data_iterator = new OwnedIterator(_dataset);
foreach (var epoch in range(_initial_epoch, _epochs))
{
if (_insufficient_data)
break;
- using var data_iterator = new OwnedIterator(_dataset);
+ if (_adapter.ShouldRecreateIterator())
+ {
+ data_iterator = new OwnedIterator(_dataset);
+ }
yield return (epoch, data_iterator);
}
// _adapter.on_epoch_end()
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
index df414b9f..4bdc4979 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
@@ -13,10 +13,10 @@
/// input features
/// target labels
///
- bool CanHandle(Tensor x, Tensor y = null);
+ bool CanHandle(Tensors x, Tensors y = null);
IDatasetV2 GetDataset();
int GetSize();
- (Tensor, Tensor) Expand1d(Tensor x, Tensor y);
+ (Tensors, Tensors) Expand1d(Tensors x, Tensors y);
bool ShouldRecreateIterator();
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index fc61aa71..f53c67c4 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -1,4 +1,5 @@
using System;
+using System.Diagnostics;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
@@ -20,7 +21,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
{
this.args = args;
_process_tensorlike();
- num_samples = (int)args.X.shape[0];
+ num_samples = (int)args.X[0].shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f)));
@@ -33,10 +34,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters
indices_dataset = indices_dataset.flat_map(slice_batch_indices);
var inputs = new Tensors();
if (args.X != null)
- inputs.Add(args.X);
+ inputs.AddRange(args.X);
if (args.Y != null)
- inputs.Add(args.Y);
+ inputs.AddRange(args.Y);
dataset = slice_inputs(indices_dataset, inputs);
+ dataset.FirstInputTensorCount = args.X.Length;
}
Tensors permutation(Tensors tensor)
@@ -87,8 +89,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return dataset.with_options(new DatasetOptions { });
}
- public override int GetSize()
- => _size;
+ public override int GetSize() => _size;
+
+ public override bool ShouldRecreateIterator() => false;
void _process_tensorlike()
{
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index 1ebd56d3..39004183 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -59,7 +59,62 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});
- return FitInternal(data_handler, epochs, verbose);
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_function);
+ }
+
+ public ICallback fit(IEnumerable x, NDArray y,
+ int batch_size = -1,
+ int epochs = 1,
+ int verbose = 1,
+ float validation_split = 0f,
+ bool shuffle = true,
+ int initial_epoch = 0,
+ int max_queue_size = 10,
+ int workers = 1,
+ bool use_multiprocessing = false)
+ {
+ foreach(var tx in x)
+ {
+ if (tx.dims[0] != y.dims[0])
+ {
+ throw new InvalidArgumentError(
+ $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
+ }
+ }
+ int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+
+ var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor);
+ var train_y = y[new Slice(0, train_count)];
+ var val_x = x.Select(x => x[new Slice(train_count)] as Tensor);
+ var val_y = y[new Slice(train_count)];
+
+ var data_handler = new DataHandler(new DataHandlerArgs
+ {
+ X = new Tensors(train_x),
+ Y = train_y,
+ BatchSize = batch_size,
+ InitialEpoch = initial_epoch,
+ Epochs = epochs,
+ Shuffle = shuffle,
+ MaxQueueSize = max_queue_size,
+ Workers = workers,
+ UseMultiprocessing = use_multiprocessing,
+ Model = this,
+ StepsPerExecution = _steps_per_execution
+ });
+
+ if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
+ data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
+ {
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_multi_inputs_function);
+ }
+ else
+ {
+ return FitInternal(data_handler, epochs, verbose, validation_data: null,
+ train_step_func: train_step_function);
+ }
}
public History fit(IDatasetV2 dataset,
@@ -88,10 +143,12 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});
- return FitInternal(data_handler, epochs, verbose, validation_data: validation_data);
+ return FitInternal(data_handler, epochs, verbose, validation_data: validation_data,
+ train_step_func: train_step_function);
}
- History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null)
+ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data,
+ Func> train_step_func)
{
stop_training = false;
_train_counter.assign(0);
@@ -113,7 +170,7 @@ namespace Tensorflow.Keras.Engine
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
- logs = train_step_function(data_handler, iterator);
+ logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
callbacks.on_train_batch_end(end_step, logs);
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
index 8d85d70d..d8171e2a 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
@@ -17,12 +17,21 @@ namespace Tensorflow.Keras.Engine
return outputs;
}
+ Dictionary train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
+ {
+ var data = iterator.next();
+ var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
+ var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
+ tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
+ return outputs;
+ }
+
///
/// The logic for one training step.
///
///
///
- Dictionary train_step(DataHandler data_handler, Tensor x, Tensor y)
+ Dictionary train_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
diff --git a/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs
new file mode 100644
index 00000000..e145ce58
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/Helpers/RandomDataset.cs
@@ -0,0 +1,30 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.UnitTest.Helpers
+{
+ public class RandomDataSet : DataSetBase
+ {
+ private Shape _shape;
+
+ public RandomDataSet(Shape shape, int count)
+ {
+ _shape = shape;
+ Debug.Assert(_shape.ndim == 3);
+ long[] dims = new long[4];
+ dims[0] = count;
+ for (int i = 1; i < 4; i++)
+ {
+ dims[i] = _shape[i - 1];
+ }
+ Shape s = new Shape(dims);
+ Data = np.random.normal(0, 2, s);
+ Labels = np.random.uniform(0, 1, (count, 1));
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
new file mode 100644
index 00000000..490178bc
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs
@@ -0,0 +1,69 @@
+using Microsoft.VisualStudio.TestPlatform.Utilities;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using System.Xml.Linq;
+using Tensorflow.Operations;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+using Tensorflow.NumPy;
+using Microsoft.VisualBasic;
+using static HDF.PInvoke.H5T;
+using Tensorflow.Keras.UnitTest.Helpers;
+using Tensorflow.Keras.Optimizers;
+
+namespace Tensorflow.Keras.UnitTest
+{
+ [TestClass]
+ public class MultiInputModelTest
+ {
+ [TestMethod]
+ public void SimpleModel()
+ {
+ var inputs = keras.Input((28, 28, 1));
+ var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
+ var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1);
+ var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1);
+ var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2);
+ var flat1 = keras.layers.Flatten().Apply(pool2);
+
+ var inputs_2 = keras.Input((28, 28, 1));
+ var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2);
+ var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2);
+ var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2);
+ var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2);
+ var flat1_2 = keras.layers.Flatten().Apply(pool2_2);
+
+ var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
+ var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
+ var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
+ var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
+ var output = keras.layers.Softmax(-1).Apply(dense3);
+
+ var model = keras.Model((inputs, inputs_2), output);
+ model.summary();
+
+ var data_loader = new MnistModelLoader();
+
+ var dataset = data_loader.LoadAsync(new ModelLoadSetting
+ {
+ TrainDir = "mnist",
+ OneHot = false,
+ ValidationSize = 59000,
+ }).Result;
+
+ var loss = keras.losses.SparseCategoricalCrossentropy();
+ var optimizer = new Adam(0.001f);
+ model.compile(optimizer, loss, new string[] { "accuracy" });
+
+ NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
+ NDArray x2 = x1;
+
+ var x = new NDArray[] { x1, x2 };
+ model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
index e778a5a4..385ec0f7 100644
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
@@ -13,6 +13,7 @@ using Tensorflow;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
+using Tensorflow.Keras.UnitTest.Helpers;
using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave;
namespace TensorFlowNET.Keras.UnitTest.SaveModel;
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
index 15823b9f..1cf68d3b 100644
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
@@ -6,7 +6,7 @@ using Tensorflow.Keras;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
-using Tensorflow.NumPy;
+using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -175,24 +175,4 @@ public class SequentialModelSave
// )
#endregion
}
-
- public class RandomDataSet : DataSetBase
- {
- private Shape _shape;
-
- public RandomDataSet(Shape shape, int count)
- {
- _shape = shape;
- Debug.Assert(_shape.ndim == 3);
- long[] dims = new long[4];
- dims[0] = count;
- for (int i = 1; i < 4; i++)
- {
- dims[i] = _shape[i - 1];
- }
- Shape s = new Shape(dims);
- Data = np.random.normal(0, 2, s);
- Labels = np.random.uniform(0, 1, (count, 1));
- }
- }
}
\ No newline at end of file
diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
index 8317346e..01f35a41 100644
--- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
+++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
@@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset
Assert.AreEqual(iStep, step);
iStep++;
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
value++;
}
}
@@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset
Assert.AreEqual(iStep, step);
iStep++;
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
value += 2;
}
}
@@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset
int n = 0;
foreach (var (item_x, item_y) in dataset)
{
- print($"x:{item_x.numpy()},y:{item_y.numpy()}");
+ print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}");
n += 1;
}
Assert.AreEqual(5, n);
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset
int n = 0;
foreach (var x in dataset)
{
- Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray()));
+ Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray()));
n += 1;
}
Assert.AreEqual(1, n);
@@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset
foreach (var item in dataset2)
{
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
value += 3;
}
@@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset
var dataset3 = dataset1.shard(num_shards: 3, index: 1);
foreach (var item in dataset3)
{
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
value += 3;
}
}
@@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset
foreach (var item in dataset)
{
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
value++;
}
}
@@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset
foreach (var item in dataset)
{
- Assert.AreEqual(value + 10, (long)item.Item1);
+ Assert.AreEqual(value + 10, (long)item.Item1[0]);
value++;
}
}
@@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset
foreach (var item in dataset)
{
- Assert.AreEqual(value, (long)item.Item1);
+ Assert.AreEqual(value, (long)item.Item1[0]);
value++;
}
}