@@ -19,6 +19,8 @@ namespace Tensorflow | |||||
public TensorSpec[] structure { get; set; } | public TensorSpec[] structure { get; set; } | ||||
public int FirstInputTensorCount { get; set; } = 1; | |||||
public Shape[] output_shapes => structure.Select(x => x.shape).ToArray(); | public Shape[] output_shapes => structure.Select(x => x.shape).ToArray(); | ||||
public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray(); | public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray(); | ||||
@@ -131,6 +133,7 @@ namespace Tensorflow | |||||
// (4) Apply stats aggregator options | // (4) Apply stats aggregator options | ||||
dataset.FirstInputTensorCount = this.FirstInputTensorCount; | |||||
return dataset; | return dataset; | ||||
} | } | ||||
@@ -142,7 +145,7 @@ namespace Tensorflow | |||||
$"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " + | $"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " + | ||||
$"len: {length}"; | $"len: {length}"; | ||||
public IEnumerator<(Tensor, Tensor)> GetEnumerator() | |||||
public IEnumerator<(Tensors, Tensors)> GetEnumerator() | |||||
{ | { | ||||
using var ownedIterator = new OwnedIterator(this); | using var ownedIterator = new OwnedIterator(this); | ||||
@@ -158,7 +161,8 @@ namespace Tensorflow | |||||
break; | break; | ||||
} | } | ||||
yield return (results[0], results.Length == 1 ? null : results[1]); | |||||
yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? | |||||
null : new Tensors(results.Skip(FirstInputTensorCount))); | |||||
} | } | ||||
} | } | ||||
@@ -4,7 +4,7 @@ using Tensorflow.Framework.Models; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)> | |||||
public interface IDatasetV2 : IEnumerable<(Tensors, Tensors)> | |||||
{ | { | ||||
string[] class_names { get; set; } | string[] class_names { get; set; } | ||||
@@ -18,6 +18,8 @@ namespace Tensorflow | |||||
TensorSpec[] structure { get; set; } | TensorSpec[] structure { get; set; } | ||||
int FirstInputTensorCount { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Caches the elements in this dataset. | /// Caches the elements in this dataset. | ||||
/// </summary> | /// </summary> | ||||
@@ -27,7 +27,8 @@ namespace Tensorflow | |||||
_dataset = dataset; | _dataset = dataset; | ||||
_element_spec = dataset.element_spec; | _element_spec = dataset.element_spec; | ||||
// _flat_output_types = | // _flat_output_types = | ||||
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes); | |||||
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes); | |||||
// TODO(Rinne): deal with graph mode. | |||||
ops.make_iterator(dataset.variant_tensor, _iterator_resource); | ops.make_iterator(dataset.variant_tensor, _iterator_resource); | ||||
} | } | ||||
@@ -48,7 +49,7 @@ namespace Tensorflow | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null); | |||||
//tf.Runner.Execute(tf.Context, "DeleteIterator", 0, new[] { _iterator_resource, _deleter }, null); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class DataAdapterArgs: IKerasConfig | public class DataAdapterArgs: IKerasConfig | ||||
{ | { | ||||
public Tensor X { get; set; } | |||||
public Tensor Y { get; set; } | |||||
public Tensors X { get; set; } | |||||
public Tensors Y { get; set; } | |||||
public IDatasetV2 Dataset { get; set; } | public IDatasetV2 Dataset { get; set; } | ||||
public int BatchSize { get; set; } = 32; | public int BatchSize { get; set; } = 32; | ||||
public int Steps { get; set; } | public int Steps { get; set; } | ||||
@@ -5,8 +5,8 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class DataHandlerArgs: IKerasConfig | public class DataHandlerArgs: IKerasConfig | ||||
{ | { | ||||
public Tensor X { get; set; } | |||||
public Tensor Y { get; set; } | |||||
public Tensors X { get; set; } | |||||
public Tensors Y { get; set; } | |||||
public IDatasetV2 Dataset { get; set; } | public IDatasetV2 Dataset { get; set; } | ||||
public int BatchSize { get; set; } = 32; | public int BatchSize { get; set; } = 32; | ||||
public int StepsPerEpoch { get; set; } = -1; | public int StepsPerEpoch { get; set; } = -1; | ||||
@@ -24,6 +24,17 @@ public interface IModel : ILayer | |||||
int workers = 1, | int workers = 1, | ||||
bool use_multiprocessing = false); | bool use_multiprocessing = false); | ||||
ICallback fit(IEnumerable<NDArray> x, NDArray y, | |||||
int batch_size = -1, | |||||
int epochs = 1, | |||||
int verbose = 1, | |||||
float validation_split = 0f, | |||||
bool shuffle = true, | |||||
int initial_epoch = 0, | |||||
int max_queue_size = 10, | |||||
int workers = 1, | |||||
bool use_multiprocessing = false); | |||||
void save(string filepath, | void save(string filepath, | ||||
bool overwrite = true, | bool overwrite = true, | ||||
bool include_optimizer = true, | bool include_optimizer = true, | ||||
@@ -14,7 +14,76 @@ namespace Tensorflow.NumPy | |||||
red = data[2]; | red = data[2]; | ||||
} | } | ||||
public static implicit operator NDArray(Array array) | |||||
public static implicit operator NDArray(int[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(byte[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(float[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(double[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(long[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(bool[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(uint[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(ulong[] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(int[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(byte[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(float[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(double[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(long[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(bool[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(uint[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(ulong[,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(int[,,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(byte[,,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(float[,,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(double[,,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(long[,,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(bool[,,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(uint[,,] array) | |||||
=> new NDArray(array); | |||||
public static implicit operator NDArray(ulong[,,] array) | |||||
=> new NDArray(array); | => new NDArray(array); | ||||
public unsafe static implicit operator bool(NDArray nd) | public unsafe static implicit operator bool(NDArray nd) | ||||
@@ -25,7 +25,7 @@ public class NpzDictionary | |||||
return array; | return array; | ||||
using var s = entry.Open(); | using var s = entry.Open(); | ||||
return LoadMatrix(s); | |||||
return (NDArray)LoadMatrix(s); | |||||
} | } | ||||
public Array LoadMatrix(Stream stream) | public Array LoadMatrix(Stream stream) | ||||
@@ -49,5 +49,8 @@ namespace Tensorflow.NumPy | |||||
IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
=> GetEnumerator(); | => GetEnumerator(); | ||||
public static explicit operator NDArray(Array array) | |||||
=> new NDArray(array); | |||||
} | } | ||||
} | } |
@@ -1,6 +1,9 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Contexts; | |||||
using Tensorflow.Eager; | |||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
using Tensorflow.Operations; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -220,6 +223,37 @@ namespace Tensorflow | |||||
return (results[0], results[1]); | return (results[0], results[1]); | ||||
} | } | ||||
public Tensor anonymous_iterator_v3(TF_DataType[] output_types, Shape[] output_shapes, string name = null) | |||||
{ | |||||
var ctx = tf.Context; | |||||
Dictionary<string, object> attrs = new(); | |||||
attrs["output_types"] = output_types; | |||||
attrs["output_shapes"] = output_shapes; | |||||
if (ctx.executing_eagerly()) | |||||
{ | |||||
try | |||||
{ | |||||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name) | |||||
{ | |||||
attrs = attrs | |||||
}); | |||||
return result[0]; | |||||
} | |||||
catch (Exception) | |||||
{ | |||||
return anonymous_iterator_v3_eager_fallback(output_types, output_shapes, name, ctx); | |||||
} | |||||
} | |||||
return tf.OpDefLib._apply_op_helper("AnonymousIteratorV3", name, attrs).outputs[0]; | |||||
} | |||||
public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx) | |||||
{ | |||||
object[] attrs = new object[] { output_types, output_shapes }; | |||||
var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name); | |||||
return result[0]; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Makes a new iterator from the given `dataset` and stores it in `iterator`. | /// Makes a new iterator from the given `dataset` and stores it in `iterator`. | ||||
/// </summary> | /// </summary> | ||||
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
protected DataAdapterArgs args; | protected DataAdapterArgs args; | ||||
protected IDatasetV2 dataset; | protected IDatasetV2 dataset; | ||||
public virtual bool CanHandle(Tensor x, Tensor y = null) | |||||
public virtual bool CanHandle(Tensors x, Tensors y = null) | |||||
=> throw new NotImplementedException(); | => throw new NotImplementedException(); | ||||
public virtual IDatasetV2 GetDataset() | public virtual IDatasetV2 GetDataset() | ||||
@@ -19,12 +19,18 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
public virtual int GetSize() | public virtual int GetSize() | ||||
=> throw new NotImplementedException(""); | => throw new NotImplementedException(""); | ||||
public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||||
public virtual (Tensors, Tensors) Expand1d(Tensors x, Tensors y) | |||||
{ | { | ||||
if (x.shape.ndim == 1) | |||||
x = array_ops.expand_dims(x, axis: -1); | |||||
if (y.shape.ndim == 1) | |||||
y = array_ops.expand_dims(y, axis: -1); | |||||
for(int i = 0; i < x.Length; i++) | |||||
{ | |||||
if (x[i].shape.ndim == 1) | |||||
x[i] = array_ops.expand_dims(x[i], axis: -1); | |||||
} | |||||
for (int i = 0; i < y.Length; i++) | |||||
{ | |||||
if (y[i].shape.ndim == 1) | |||||
y[i] = array_ops.expand_dims(y[i], axis: -1); | |||||
} | |||||
return (x, y); | return (x, y); | ||||
} | } | ||||
@@ -93,11 +93,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
public IEnumerable<(int, OwnedIterator)> enumerate_epochs() | public IEnumerable<(int, OwnedIterator)> enumerate_epochs() | ||||
{ | { | ||||
var data_iterator = new OwnedIterator(_dataset); | |||||
foreach (var epoch in range(_initial_epoch, _epochs)) | foreach (var epoch in range(_initial_epoch, _epochs)) | ||||
{ | { | ||||
if (_insufficient_data) | if (_insufficient_data) | ||||
break; | break; | ||||
using var data_iterator = new OwnedIterator(_dataset); | |||||
if (_adapter.ShouldRecreateIterator()) | |||||
{ | |||||
data_iterator = new OwnedIterator(_dataset); | |||||
} | |||||
yield return (epoch, data_iterator); | yield return (epoch, data_iterator); | ||||
} | } | ||||
// _adapter.on_epoch_end() | // _adapter.on_epoch_end() | ||||
@@ -13,10 +13,10 @@ | |||||
/// <param name="x">input features</param> | /// <param name="x">input features</param> | ||||
/// <param name="y">target labels</param> | /// <param name="y">target labels</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
bool CanHandle(Tensor x, Tensor y = null); | |||||
bool CanHandle(Tensors x, Tensors y = null); | |||||
IDatasetV2 GetDataset(); | IDatasetV2 GetDataset(); | ||||
int GetSize(); | int GetSize(); | ||||
(Tensor, Tensor) Expand1d(Tensor x, Tensor y); | |||||
(Tensors, Tensors) Expand1d(Tensors x, Tensors y); | |||||
bool ShouldRecreateIterator(); | bool ShouldRecreateIterator(); | ||||
} | } | ||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using System; | using System; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -20,7 +21,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
_process_tensorlike(); | _process_tensorlike(); | ||||
num_samples = (int)args.X.shape[0]; | |||||
num_samples = (int)args.X[0].shape[0]; | |||||
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; | var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; | ||||
_batch_size = batch_size; | _batch_size = batch_size; | ||||
_size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); | _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0.0f))); | ||||
@@ -33,10 +34,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
indices_dataset = indices_dataset.flat_map(slice_batch_indices); | indices_dataset = indices_dataset.flat_map(slice_batch_indices); | ||||
var inputs = new Tensors(); | var inputs = new Tensors(); | ||||
if (args.X != null) | if (args.X != null) | ||||
inputs.Add(args.X); | |||||
inputs.AddRange(args.X); | |||||
if (args.Y != null) | if (args.Y != null) | ||||
inputs.Add(args.Y); | |||||
inputs.AddRange(args.Y); | |||||
dataset = slice_inputs(indices_dataset, inputs); | dataset = slice_inputs(indices_dataset, inputs); | ||||
dataset.FirstInputTensorCount = args.X.Length; | |||||
} | } | ||||
Tensors permutation(Tensors tensor) | Tensors permutation(Tensors tensor) | ||||
@@ -87,8 +89,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
return dataset.with_options(new DatasetOptions { }); | return dataset.with_options(new DatasetOptions { }); | ||||
} | } | ||||
public override int GetSize() | |||||
=> _size; | |||||
public override int GetSize() => _size; | |||||
public override bool ShouldRecreateIterator() => false; | |||||
void _process_tensorlike() | void _process_tensorlike() | ||||
{ | { | ||||
@@ -59,7 +59,62 @@ namespace Tensorflow.Keras.Engine | |||||
StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
}); | }); | ||||
return FitInternal(data_handler, epochs, verbose); | |||||
return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||||
train_step_func: train_step_function); | |||||
} | |||||
public ICallback fit(IEnumerable<NDArray> x, NDArray y, | |||||
int batch_size = -1, | |||||
int epochs = 1, | |||||
int verbose = 1, | |||||
float validation_split = 0f, | |||||
bool shuffle = true, | |||||
int initial_epoch = 0, | |||||
int max_queue_size = 10, | |||||
int workers = 1, | |||||
bool use_multiprocessing = false) | |||||
{ | |||||
foreach(var tx in x) | |||||
{ | |||||
if (tx.dims[0] != y.dims[0]) | |||||
{ | |||||
throw new InvalidArgumentError( | |||||
$"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); | |||||
} | |||||
} | |||||
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); | |||||
var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor); | |||||
var train_y = y[new Slice(0, train_count)]; | |||||
var val_x = x.Select(x => x[new Slice(train_count)] as Tensor); | |||||
var val_y = y[new Slice(train_count)]; | |||||
var data_handler = new DataHandler(new DataHandlerArgs | |||||
{ | |||||
X = new Tensors(train_x), | |||||
Y = train_y, | |||||
BatchSize = batch_size, | |||||
InitialEpoch = initial_epoch, | |||||
Epochs = epochs, | |||||
Shuffle = shuffle, | |||||
MaxQueueSize = max_queue_size, | |||||
Workers = workers, | |||||
UseMultiprocessing = use_multiprocessing, | |||||
Model = this, | |||||
StepsPerExecution = _steps_per_execution | |||||
}); | |||||
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | |||||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | |||||
{ | |||||
return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||||
train_step_func: train_step_multi_inputs_function); | |||||
} | |||||
else | |||||
{ | |||||
return FitInternal(data_handler, epochs, verbose, validation_data: null, | |||||
train_step_func: train_step_function); | |||||
} | |||||
} | } | ||||
public History fit(IDatasetV2 dataset, | public History fit(IDatasetV2 dataset, | ||||
@@ -88,10 +143,12 @@ namespace Tensorflow.Keras.Engine | |||||
StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
}); | }); | ||||
return FitInternal(data_handler, epochs, verbose, validation_data: validation_data); | |||||
return FitInternal(data_handler, epochs, verbose, validation_data: validation_data, | |||||
train_step_func: train_step_function); | |||||
} | } | ||||
History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null) | |||||
History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data, | |||||
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||||
{ | { | ||||
stop_training = false; | stop_training = false; | ||||
_train_counter.assign(0); | _train_counter.assign(0); | ||||
@@ -113,7 +170,7 @@ namespace Tensorflow.Keras.Engine | |||||
foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
{ | { | ||||
callbacks.on_train_batch_begin(step); | callbacks.on_train_batch_begin(step); | ||||
logs = train_step_function(data_handler, iterator); | |||||
logs = train_step_func(data_handler, iterator); | |||||
var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
callbacks.on_train_batch_end(end_step, logs); | callbacks.on_train_batch_end(end_step, logs); | ||||
} | } | ||||
@@ -17,12 +17,21 @@ namespace Tensorflow.Keras.Engine | |||||
return outputs; | return outputs; | ||||
} | } | ||||
Dictionary<string, float> train_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) | |||||
{ | |||||
var data = iterator.next(); | |||||
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||||
return outputs; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// The logic for one training step. | /// The logic for one training step. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="data"></param> | /// <param name="data"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
Dictionary<string, float> train_step(DataHandler data_handler, Tensor x, Tensor y) | |||||
Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y) | |||||
{ | { | ||||
(x, y) = data_handler.DataAdapter.Expand1d(x, y); | (x, y) = data_handler.DataAdapter.Expand1d(x, y); | ||||
using var tape = tf.GradientTape(); | using var tape = tf.GradientTape(); | ||||
@@ -0,0 +1,30 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using Tensorflow.NumPy; | |||||
namespace Tensorflow.Keras.UnitTest.Helpers | |||||
{ | |||||
public class RandomDataSet : DataSetBase | |||||
{ | |||||
private Shape _shape; | |||||
public RandomDataSet(Shape shape, int count) | |||||
{ | |||||
_shape = shape; | |||||
Debug.Assert(_shape.ndim == 3); | |||||
long[] dims = new long[4]; | |||||
dims[0] = count; | |||||
for (int i = 1; i < 4; i++) | |||||
{ | |||||
dims[i] = _shape[i - 1]; | |||||
} | |||||
Shape s = new Shape(dims); | |||||
Data = np.random.normal(0, 2, s); | |||||
Labels = np.random.uniform(0, 1, (count, 1)); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,69 @@ | |||||
using Microsoft.VisualStudio.TestPlatform.Utilities; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using System.Xml.Linq; | |||||
using Tensorflow.Operations; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.KerasApi; | |||||
using Tensorflow.NumPy; | |||||
using Microsoft.VisualBasic; | |||||
using static HDF.PInvoke.H5T; | |||||
using Tensorflow.Keras.UnitTest.Helpers; | |||||
using Tensorflow.Keras.Optimizers; | |||||
namespace Tensorflow.Keras.UnitTest | |||||
{ | |||||
[TestClass] | |||||
public class MultiInputModelTest | |||||
{ | |||||
[TestMethod] | |||||
public void SimpleModel() | |||||
{ | |||||
var inputs = keras.Input((28, 28, 1)); | |||||
var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs); | |||||
var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1); | |||||
var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1); | |||||
var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2); | |||||
var flat1 = keras.layers.Flatten().Apply(pool2); | |||||
var inputs_2 = keras.Input((28, 28, 1)); | |||||
var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2); | |||||
var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2); | |||||
var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2); | |||||
var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2); | |||||
var flat1_2 = keras.layers.Flatten().Apply(pool2_2); | |||||
var concat = keras.layers.Concatenate().Apply((flat1, flat1_2)); | |||||
var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat); | |||||
var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1); | |||||
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2); | |||||
var output = keras.layers.Softmax(-1).Apply(dense3); | |||||
var model = keras.Model((inputs, inputs_2), output); | |||||
model.summary(); | |||||
var data_loader = new MnistModelLoader(); | |||||
var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||||
{ | |||||
TrainDir = "mnist", | |||||
OneHot = false, | |||||
ValidationSize = 59000, | |||||
}).Result; | |||||
var loss = keras.losses.SparseCategoricalCrossentropy(); | |||||
var optimizer = new Adam(0.001f); | |||||
model.compile(optimizer, loss, new string[] { "accuracy" }); | |||||
NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1)); | |||||
NDArray x2 = x1; | |||||
var x = new NDArray[] { x1, x2 }; | |||||
model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3); | |||||
} | |||||
} | |||||
} |
@@ -13,6 +13,7 @@ using Tensorflow; | |||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using Tensorflow.Keras.UnitTest.Helpers; | |||||
using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave; | using static TensorFlowNET.Keras.UnitTest.SaveModel.SequentialModelSave; | ||||
namespace TensorFlowNET.Keras.UnitTest.SaveModel; | namespace TensorFlowNET.Keras.UnitTest.SaveModel; | ||||
@@ -6,7 +6,7 @@ using Tensorflow.Keras; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using Tensorflow.NumPy; | |||||
using Tensorflow.Keras.UnitTest.Helpers; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
@@ -175,24 +175,4 @@ public class SequentialModelSave | |||||
// ) | // ) | ||||
#endregion | #endregion | ||||
} | } | ||||
public class RandomDataSet : DataSetBase | |||||
{ | |||||
private Shape _shape; | |||||
public RandomDataSet(Shape shape, int count) | |||||
{ | |||||
_shape = shape; | |||||
Debug.Assert(_shape.ndim == 3); | |||||
long[] dims = new long[4]; | |||||
dims[0] = count; | |||||
for (int i = 1; i < 4; i++) | |||||
{ | |||||
dims[i] = _shape[i - 1]; | |||||
} | |||||
Shape s = new Shape(dims); | |||||
Data = np.random.normal(0, 2, s); | |||||
Labels = np.random.uniform(0, 1, (count, 1)); | |||||
} | |||||
} | |||||
} | } |
@@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
Assert.AreEqual(iStep, step); | Assert.AreEqual(iStep, step); | ||||
iStep++; | iStep++; | ||||
Assert.AreEqual(value, (long)item.Item1); | |||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||
@@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
Assert.AreEqual(iStep, step); | Assert.AreEqual(iStep, step); | ||||
iStep++; | iStep++; | ||||
Assert.AreEqual(value, (long)item.Item1); | |||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
value += 2; | value += 2; | ||||
} | } | ||||
} | } | ||||
@@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
int n = 0; | int n = 0; | ||||
foreach (var (item_x, item_y) in dataset) | foreach (var (item_x, item_y) in dataset) | ||||
{ | { | ||||
print($"x:{item_x.numpy()},y:{item_y.numpy()}"); | |||||
print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}"); | |||||
n += 1; | n += 1; | ||||
} | } | ||||
Assert.AreEqual(5, n); | Assert.AreEqual(5, n); | ||||
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
int n = 0; | int n = 0; | ||||
foreach (var x in dataset) | foreach (var x in dataset) | ||||
{ | { | ||||
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>())); | |||||
Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray<int>())); | |||||
n += 1; | n += 1; | ||||
} | } | ||||
Assert.AreEqual(1, n); | Assert.AreEqual(1, n); | ||||
@@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset2) | foreach (var item in dataset2) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1); | |||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
value += 3; | value += 3; | ||||
} | } | ||||
@@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
var dataset3 = dataset1.shard(num_shards: 3, index: 1); | var dataset3 = dataset1.shard(num_shards: 3, index: 1); | ||||
foreach (var item in dataset3) | foreach (var item in dataset3) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1); | |||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
value += 3; | value += 3; | ||||
} | } | ||||
} | } | ||||
@@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset) | foreach (var item in dataset) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1); | |||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||
@@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset) | foreach (var item in dataset) | ||||
{ | { | ||||
Assert.AreEqual(value + 10, (long)item.Item1); | |||||
Assert.AreEqual(value + 10, (long)item.Item1[0]); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||
@@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset) | foreach (var item in dataset) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1); | |||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||