diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index c7aa4670..83653c8b 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -1,5 +1,5 @@
/*****************************************************************************
- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+ Copyright 2023 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ namespace Tensorflow
public Tensor tanh(Tensor x, string name = null)
=> math_ops.tanh(x, name: name);
-
+
///
/// Finds values and indices of the `k` largest entries for the last dimension.
///
@@ -93,6 +93,16 @@ namespace Tensorflow
bool binary_output = false)
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
dtype: dtype, name: name, axis: axis, binary_output: binary_output);
+
+ public Tensor real(Tensor x, string name = null)
+ => gen_ops.real(x, x.dtype.real_dtype(), name);
+ public Tensor imag(Tensor x, string name = null)
+ => gen_ops.imag(x, x.dtype.real_dtype(), name);
+
+ public Tensor conj(Tensor x, string name = null)
+ => gen_ops.conj(x, name);
+ public Tensor angle(Tensor x, string name = null)
+ => gen_ops.angle(x, x.dtype.real_dtype(), name);
}
public Tensor abs(Tensor x, string name = null)
@@ -537,7 +547,7 @@ namespace Tensorflow
public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null,
bool keepdims = false, string name = null)
{
- if(keepdims)
+ if (keepdims)
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name);
else
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices));
@@ -585,5 +595,7 @@ namespace Tensorflow
=> gen_math_ops.square(x, name: name);
public Tensor squared_difference(Tensor x, Tensor y, string name = null)
=> gen_math_ops.squared_difference(x: x, y: y, name: name);
+ public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
+ string name = null) => gen_ops.complex(real, imag, dtype, name);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.signal.cs b/src/TensorFlowNET.Core/APIs/tf.signal.cs
new file mode 100644
index 00000000..2471124c
--- /dev/null
+++ b/src/TensorFlowNET.Core/APIs/tf.signal.cs
@@ -0,0 +1,40 @@
+/*****************************************************************************
+ Copyright 2023 Konstantin Balashov All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using Tensorflow.Operations;
+
+namespace Tensorflow
+{
+ public partial class tensorflow
+ {
+ public SignalApi signal { get; } = new SignalApi();
+ public class SignalApi
+ {
+ public Tensor fft(Tensor input, string name = null)
+ => gen_ops.f_f_t(input, name: name);
+ public Tensor ifft(Tensor input, string name = null)
+ => gen_ops.i_f_f_t(input, name: name);
+ public Tensor fft2d(Tensor input, string name = null)
+ => gen_ops.f_f_t2d(input, name: name);
+ public Tensor ifft2d(Tensor input, string name = null)
+ => gen_ops.i_f_f_t2d(input, name: name);
+ public Tensor fft3d(Tensor input, string name = null)
+ => gen_ops.f_f_t3d(input, name: name);
+ public Tensor ifft3d(Tensor input, string name = null)
+ => gen_ops.i_f_f_t3d(input, name: name);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index 22d3c641..89699d6b 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -840,7 +840,7 @@ namespace Tensorflow.Gradients
///
///
///
- private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
+ public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
{
Tensor sx, sy;
if (x.shape.IsFullyDefined &&
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index 15b72f55..e9516393 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System;
+using System.Diagnostics;
using System.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;
@@ -135,13 +136,35 @@ namespace Tensorflow.Gradients
{
Tensor x = op.inputs[0];
Tensor y = op.inputs[1];
+ var grad = grads[0];
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
- var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
- return new Tensor[]
+ var x_grad = math_ops.scalar_mul(scale, grad) * (x - y);
+ if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad))
{
- x_grad,
- -x_grad
- };
+ return new Tensor[] { x_grad, -x_grad };
+ }
+ var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad);
+ Debug.Assert(broadcast_info.Length == 2);
+ var (sx, rx, must_reduce_x) = broadcast_info[0];
+ var (sy, ry, must_reduce_y) = broadcast_info[1];
+ Tensor gx, gy;
+ if (must_reduce_x)
+ {
+ gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx);
+ }
+ else
+ {
+ gx = x_grad;
+ }
+ if (must_reduce_y)
+ {
+ gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy);
+ }
+ else
+ {
+ gy = -x_grad;
+ }
+ return new Tensor[] { gx, gy };
}
///
diff --git a/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
index 530a9368..096dbd2e 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
@@ -15,5 +15,5 @@ public interface ICallback
void on_predict_end();
void on_test_begin();
void on_test_batch_begin(long step);
- void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs);
+ void on_test_batch_end(long end_step, Dictionary logs);
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index 3928ef5f..19f3df9b 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -22,6 +22,7 @@ public interface IModel : ILayer
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
+ (NDArray val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -34,6 +35,7 @@ public interface IModel : ILayer
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
+ (IEnumerable val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -65,7 +67,8 @@ public interface IModel : ILayer
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false,
- bool return_dict = false);
+ bool return_dict = false,
+ bool is_val = false);
Tensors predict(Tensors x,
int batch_size = -1,
@@ -79,5 +82,5 @@ public interface IModel : ILayer
IKerasConfig get_config();
- void set_stopTraining_true();
+ bool Stop_training { get;set; }
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs
index 26a9b5be..bf178b60 100644
--- a/src/TensorFlowNET.Core/Operations/gen_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs
@@ -730,12 +730,7 @@ namespace Tensorflow.Operations
///
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle")
{
- var dict = new Dictionary();
- dict["input"] = input;
- if (Tout.HasValue)
- dict["Tout"] = Tout.Value;
- var op = tf.OpDefLib._apply_op_helper("Angle", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout }));
}
///
@@ -4976,15 +4971,14 @@ namespace Tensorflow.Operations
/// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
///
///
- public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex")
+ public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex")
{
- var dict = new Dictionary();
- dict["real"] = real;
- dict["imag"] = imag;
- if (Tout.HasValue)
- dict["Tout"] = Tout.Value;
- var op = tf.OpDefLib._apply_op_helper("Complex", name: name, keywords: dict);
- return op.output;
+ TF_DataType Tin = real.GetDataType();
+ if (a_Tout is null)
+ {
+ a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64;
+ }
+ return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout }));
}
///
@@ -5008,12 +5002,7 @@ namespace Tensorflow.Operations
///
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs")
{
- var dict = new Dictionary();
- dict["x"] = x;
- if (Tout.HasValue)
- dict["Tout"] = Tout.Value;
- var op = tf.OpDefLib._apply_op_helper("ComplexAbs", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout }));
}
///
@@ -5313,10 +5302,7 @@ namespace Tensorflow.Operations
///
public static Tensor conj(Tensor input, string name = "Conj")
{
- var dict = new Dictionary();
- dict["input"] = input;
- var op = tf.OpDefLib._apply_op_helper("Conj", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input }));
}
///
@@ -10489,10 +10475,7 @@ namespace Tensorflow.Operations
///
public static Tensor f_f_t(Tensor input, string name = "FFT")
{
- var dict = new Dictionary();
- dict["input"] = input;
- var op = tf.OpDefLib._apply_op_helper("FFT", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("FFT", name, new ExecuteOpArgs(input));
}
///
@@ -10519,10 +10502,7 @@ namespace Tensorflow.Operations
///
public static Tensor f_f_t2d(Tensor input, string name = "FFT2D")
{
- var dict = new Dictionary();
- dict["input"] = input;
- var op = tf.OpDefLib._apply_op_helper("FFT2D", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("FFT2D", name, new ExecuteOpArgs(input));
}
///
@@ -10549,10 +10529,7 @@ namespace Tensorflow.Operations
///
public static Tensor f_f_t3d(Tensor input, string name = "FFT3D")
{
- var dict = new Dictionary();
- dict["input"] = input;
- var op = tf.OpDefLib._apply_op_helper("FFT3D", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("FFT3D", name, new ExecuteOpArgs(input));
}
///
@@ -12875,10 +12852,7 @@ namespace Tensorflow.Operations
///
public static Tensor i_f_f_t(Tensor input, string name = "IFFT")
{
- var dict = new Dictionary();
- dict["input"] = input;
- var op = tf.OpDefLib._apply_op_helper("IFFT", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("IFFT", name, new ExecuteOpArgs(input));
}
///
@@ -12905,10 +12879,7 @@ namespace Tensorflow.Operations
///
public static Tensor i_f_f_t2d(Tensor input, string name = "IFFT2D")
{
- var dict = new Dictionary();
- dict["input"] = input;
- var op = tf.OpDefLib._apply_op_helper("IFFT2D", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("IFFT2D", name, new ExecuteOpArgs(input));
}
///
@@ -12935,10 +12906,7 @@ namespace Tensorflow.Operations
///
public static Tensor i_f_f_t3d(Tensor input, string name = "IFFT3D")
{
- var dict = new Dictionary();
- dict["input"] = input;
- var op = tf.OpDefLib._apply_op_helper("IFFT3D", name: name, keywords: dict);
- return op.output;
+ return tf.Context.ExecuteOp("IFFT3D", name, new ExecuteOpArgs(input));
}
///
@@ -13325,14 +13293,12 @@ namespace Tensorflow.Operations
/// tf.imag(input) ==> [4.75, 5.75]
///
///
- public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag")
+ public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag")
{
- var dict = new Dictionary();
- dict["input"] = input;
- if (Tout.HasValue)
- dict["Tout"] = Tout.Value;
- var op = tf.OpDefLib._apply_op_helper("Imag", name: name, keywords: dict);
- return op.output;
+ TF_DataType Tin = input.GetDataType();
+ return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));
+
+ // return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
}
///
@@ -23863,14 +23829,12 @@ namespace Tensorflow.Operations
/// tf.real(input) ==> [-2.25, 3.25]
///
///
- public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real")
+ public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real")
{
- var dict = new Dictionary();
- dict["input"] = input;
- if (Tout.HasValue)
- dict["Tout"] = Tout.Value;
- var op = tf.OpDefLib._apply_op_helper("Real", name: name, keywords: dict);
- return op.output;
+ TF_DataType Tin = input.GetDataType();
+ return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));
+
+// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
}
///
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 36f7db79..a89e7a22 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -20,6 +20,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -35,8 +36,9 @@ namespace Tensorflow
name = scope;
x = ops.convert_to_tensor(x, name: "x");
if (x.dtype.is_complex())
- throw new NotImplementedException("math_ops.abs for dtype.is_complex");
- //return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name);
+ {
+ return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name);
+ }
return gen_math_ops._abs(x, name: name);
});
}
diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
index a2847798..362f2280 100644
--- a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
+++ b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
@@ -69,7 +69,7 @@ public class CallbackList
{
callbacks.ForEach(x => x.on_test_batch_begin(step));
}
- public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
+ public void on_test_batch_end(long end_step, Dictionary logs)
{
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
}
diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
index 1e0418dc..0aa5006c 100644
--- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
+++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
@@ -95,7 +95,7 @@ public class EarlyStopping: ICallback
if (_wait >= _paitence && epoch > 0)
{
_stopped_epoch = epoch;
- _parameters.Model.set_stopTraining_true();
+ _parameters.Model.Stop_training = true;
if (_restore_best_weights && _best_weights != null)
{
if (_verbose > 0)
@@ -121,7 +121,7 @@ public class EarlyStopping: ICallback
public void on_predict_end() { }
public void on_test_begin() { }
public void on_test_batch_begin(long step) { }
- public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }
+ public void on_test_batch_end(long end_step, Dictionary logs) { }
float get_monitor_value(Dictionary logs)
{
diff --git a/src/TensorFlowNET.Keras/Callbacks/History.cs b/src/TensorFlowNET.Keras/Callbacks/History.cs
index d6113261..c34f253d 100644
--- a/src/TensorFlowNET.Keras/Callbacks/History.cs
+++ b/src/TensorFlowNET.Keras/Callbacks/History.cs
@@ -48,7 +48,7 @@ public class History : ICallback
{
history[log.Key] = new List();
}
- history[log.Key].Add((float)log.Value);
+ history[log.Key].Add(log.Value);
}
}
@@ -78,7 +78,7 @@ public class History : ICallback
}
- public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
+ public void on_test_batch_end(long end_step, Dictionary logs)
{
}
}
diff --git a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
index d22c779f..9f2b1eb3 100644
--- a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
+++ b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
@@ -105,11 +105,11 @@ namespace Tensorflow.Keras.Callbacks
{
_sw.Restart();
}
- public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
+ public void on_test_batch_end(long end_step, Dictionary logs)
{
_sw.Stop();
var elapse = _sw.ElapsedMilliseconds;
- var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));
+ var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {x.Value:F6}"));
Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
if (!Console.IsOutputRedirected)
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
index a4b59439..185de4f4 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
@@ -26,6 +26,7 @@ namespace Tensorflow.Keras.Engine
///
///
///
+ ///
public Dictionary evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
@@ -33,7 +34,9 @@ namespace Tensorflow.Keras.Engine
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false,
- bool return_dict = false)
+ bool return_dict = false,
+ bool is_val = false
+ )
{
if (x.dims[0] != y.dims[0])
{
@@ -63,11 +66,11 @@ namespace Tensorflow.Keras.Engine
});
callbacks.on_test_begin();
- IEnumerable<(string, Tensor)> logs = null;
+ //Dictionary? logs = null;
+ var logs = new Dictionary();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
- callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
foreach (var step in data_handler.steps())
@@ -75,19 +78,64 @@ namespace Tensorflow.Keras.Engine
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
- callbacks.on_test_batch_end(end_step, logs);
+ if (is_val == false)
+ callbacks.on_test_batch_end(end_step, logs);
}
}
var results = new Dictionary();
foreach (var log in logs)
{
- results[log.Item1] = (float)log.Item2;
+ results[log.Key] = log.Value;
}
return results;
}
- public Dictionary evaluate(IDatasetV2 x, int verbose = 1)
+ public Dictionary evaluate(IEnumerable x, NDArray y, int verbose = 1, bool is_val = false)
+ {
+ var data_handler = new DataHandler(new DataHandlerArgs
+ {
+ X = new Tensors(x),
+ Y = y,
+ Model = this,
+ StepsPerExecution = _steps_per_execution
+ });
+
+ var callbacks = new CallbackList(new CallbackParams
+ {
+ Model = this,
+ Verbose = verbose,
+ Steps = data_handler.Inferredsteps
+ });
+ callbacks.on_test_begin();
+
+ Dictionary logs = null;
+ foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+ {
+ reset_metrics();
+ callbacks.on_epoch_begin(epoch);
+ // data_handler.catch_stop_iteration();
+
+ foreach (var step in data_handler.steps())
+ {
+ callbacks.on_test_batch_begin(step);
+ logs = test_step_multi_inputs_function(data_handler, iterator);
+ var end_step = step + data_handler.StepIncrement;
+ if (is_val == false)
+ callbacks.on_test_batch_end(end_step, logs);
+ }
+ }
+
+ var results = new Dictionary();
+ foreach (var log in logs)
+ {
+ results[log.Key] = log.Value;
+ }
+ return results;
+ }
+
+
+ public Dictionary evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -104,7 +152,7 @@ namespace Tensorflow.Keras.Engine
});
callbacks.on_test_begin();
- IEnumerable<(string, Tensor)> logs = null;
+ Dictionary logs = null;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
@@ -113,28 +161,38 @@ namespace Tensorflow.Keras.Engine
foreach (var step in data_handler.steps())
{
- // callbacks.on_train_batch_begin(step)
+ callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
+ var end_step = step + data_handler.StepIncrement;
+ if (is_val == false)
+ callbacks.on_test_batch_end(end_step, logs);
}
}
var results = new Dictionary();
foreach (var log in logs)
{
- results[log.Item1] = (float)log.Item2;
+ results[log.Key] = log.Value;
}
return results;
}
- IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator)
+ Dictionary test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
-
- List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y)
+ Dictionary test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
+ {
+ var data = iterator.next();
+ var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
+ var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
+ tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
+ return outputs;
+ }
+ Dictionary test_step(DataHandler data_handler, Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
@@ -142,7 +200,7 @@ namespace Tensorflow.Keras.Engine
compiled_metrics.update_state(y, y_pred);
- return metrics.Select(x => (x.Name, x.result())).ToList();
+ return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index 7ad4d3ef..bb8e18cc 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Engine
///
///
///
+ ///
///
public ICallback fit(NDArray x, NDArray y,
int batch_size = -1,
@@ -29,6 +30,7 @@ namespace Tensorflow.Keras.Engine
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
+ (NDArray val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -40,11 +42,17 @@ namespace Tensorflow.Keras.Engine
throw new InvalidArgumentError(
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}
- int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
- var train_x = x[new Slice(0, train_count)];
- var train_y = y[new Slice(0, train_count)];
- var val_x = x[new Slice(train_count)];
- var val_y = y[new Slice(train_count)];
+
+ var train_x = x;
+ var train_y = y;
+
+ if (validation_split != 0f && validation_data == null)
+ {
+ int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
+ train_x = x[new Slice(0, train_count)];
+ train_y = y[new Slice(0, train_count)];
+ validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
+ }
var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -61,7 +69,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});
- return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
+ return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}
@@ -71,6 +79,7 @@ namespace Tensorflow.Keras.Engine
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
+ (IEnumerable val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -85,12 +94,19 @@ namespace Tensorflow.Keras.Engine
$"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
}
}
- int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
-
- var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor);
- var train_y = y[new Slice(0, train_count)];
- var val_x = x.Select(x => x[new Slice(train_count)] as Tensor);
- var val_y = y[new Slice(train_count)];
+
+ var train_x = x;
+ var train_y = y;
+ if (validation_split != 0f && validation_data == null)
+ {
+ int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+ train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
+ train_y = y[new Slice(0, train_count)];
+ var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
+ var val_y = y[new Slice(train_count)];
+ validation_data = (val_x, val_y);
+ }
+
var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -110,29 +126,29 @@ namespace Tensorflow.Keras.Engine
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
- return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
+ return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
train_step_func: train_step_multi_inputs_function);
}
else
{
- return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
+ return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}
}
public History fit(IDatasetV2 dataset,
- IDatasetV2 validation_data = null,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List callbacks = null,
- float validation_split = 0f,
+ IDatasetV2 validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false)
{
+
var data_handler = new DataHandler(new DataHandlerArgs
{
Dataset = dataset,
@@ -147,6 +163,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});
+
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}
@@ -178,11 +195,13 @@ namespace Tensorflow.Keras.Engine
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
var logs = new Dictionary();
+ long End_step = 0;
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
+ End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
}
@@ -193,6 +212,123 @@ namespace Tensorflow.Keras.Engine
{
logs["val_" + log.Key] = log.Value;
}
+ callbacks.on_train_batch_end(End_step, logs);
+ }
+
+
+ callbacks.on_epoch_end(epoch, logs);
+
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+ }
+
+ return callbacks.History;
+ }
+
+ History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (NDArray, NDArray)? validation_data,
+ Func> train_step_func)
+ {
+ stop_training = false;
+ _train_counter.assign(0);
+ var callbacks = new CallbackList(new CallbackParams
+ {
+ Model = this,
+ Verbose = verbose,
+ Epochs = epochs,
+ Steps = data_handler.Inferredsteps
+ });
+
+ if (callbackList != null)
+ {
+ foreach (var callback in callbackList)
+ callbacks.callbacks.add(callback);
+ }
+
+ callbacks.on_train_begin();
+
+ foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+ {
+ reset_metrics();
+ callbacks.on_epoch_begin(epoch);
+ // data_handler.catch_stop_iteration();
+ var logs = new Dictionary();
+ long End_step = 0;
+ foreach (var step in data_handler.steps())
+ {
+ callbacks.on_train_batch_begin(step);
+ logs = train_step_func(data_handler, iterator);
+ var end_step = step + data_handler.StepIncrement;
+ End_step = end_step;
+ callbacks.on_train_batch_end(end_step, logs);
+ }
+
+ if (validation_data != null)
+ {
+ // Because evaluate calls call_test_batch_end, this interferes with our output on the screen
+ // so we need to pass a is_val parameter to stop on_test_batch_end
+ var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
+ foreach (var log in val_logs)
+ {
+ logs["val_" + log.Key] = log.Value;
+ }
+ // because after evaluate, logs add some new log which we need to print
+ callbacks.on_train_batch_end(End_step, logs);
+ }
+
+ callbacks.on_epoch_end(epoch, logs);
+
+ GC.Collect();
+ GC.WaitForPendingFinalizers();
+ }
+
+ return callbacks.History;
+ }
+
+ History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (IEnumerable, NDArray)? validation_data,
+ Func> train_step_func)
+ {
+ stop_training = false;
+ _train_counter.assign(0);
+ var callbacks = new CallbackList(new CallbackParams
+ {
+ Model = this,
+ Verbose = verbose,
+ Epochs = epochs,
+ Steps = data_handler.Inferredsteps
+ });
+
+ if (callbackList != null)
+ {
+ foreach (var callback in callbackList)
+ callbacks.callbacks.add(callback);
+ }
+
+ callbacks.on_train_begin();
+
+ foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+ {
+ reset_metrics();
+ callbacks.on_epoch_begin(epoch);
+ // data_handler.catch_stop_iteration();
+ var logs = new Dictionary();
+ long End_step = 0;
+ foreach (var step in data_handler.steps())
+ {
+ callbacks.on_train_batch_begin(step);
+ logs = train_step_func(data_handler, iterator);
+ var end_step = step + data_handler.StepIncrement;
+ End_step = end_step;
+ callbacks.on_train_batch_end(end_step, logs);
+ }
+
+ if (validation_data != null)
+ {
+ var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
+ foreach (var log in val_logs)
+ {
+ logs["val_" + log.Key] = log.Value;
+ callbacks.on_train_batch_end(End_step, logs);
+ }
}
callbacks.on_epoch_end(epoch, logs);
diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs
index a3676007..1d9e9f06 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.cs
@@ -46,6 +46,12 @@ namespace Tensorflow.Keras.Engine
set => optimizer = value;
}
+ public bool Stop_training
+ {
+ get => stop_training;
+ set => stop_training = value;
+ }
+
public Model(ModelArgs args)
: base(args)
{
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index f7980706..7c6a692e 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -58,6 +58,12 @@ namespace Tensorflow.Keras
Name = name
});
+ public Sequential Sequential(params ILayer[] layers)
+ => new Sequential(new SequentialArgs
+ {
+ Layers = layers.ToList()
+ });
+
///
/// `Model` groups layers into an object with training and inference features.
///
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 1bbb3442..adb7be0c 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -72,7 +72,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
-
diff --git a/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs
new file mode 100644
index 00000000..a57ec929
--- /dev/null
+++ b/test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs
@@ -0,0 +1,202 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.NumPy;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow;
+using static Tensorflow.Binding;
+using Buffer = Tensorflow.Buffer;
+using TensorFlowNET.Keras.UnitTest;
+
+namespace TensorFlowNET.UnitTest.Basics
+{
+ [TestClass]
+ public class ComplexTest : EagerModeTestBase
+ {
+ // Tests for Complex128
+
+ [TestMethod]
+ public void complex128_basic()
+ {
+ double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
+ double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };
+
+ Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
+
+ Tensor t_complex = tf.complex(t_real, t_imag);
+
+ Tensor t_real_result = tf.math.real(t_complex);
+ Tensor t_imag_result = tf.math.imag(t_complex);
+
+ NDArray n_real_result = t_real_result.numpy();
+ NDArray n_imag_result = t_imag_result.numpy();
+
+ double[] d_real_result =n_real_result.ToArray();
+ double[] d_imag_result = n_imag_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_real_result, d_real));
+ Assert.IsTrue(base.Equal(d_imag_result, d_imag));
+ }
+ [TestMethod]
+ public void complex128_abs()
+ {
+ tf.enable_eager_execution();
+
+ double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
+ double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };
+
+ double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
+
+ Tensor t_complex = tf.complex(t_real, t_imag);
+
+ Tensor t_abs_result = tf.abs(t_complex);
+
+ double[] d_abs_result = t_abs_result.numpy().ToArray();
+ Assert.IsTrue(base.Equal(d_abs_result, d_abs));
+ }
+ [TestMethod]
+ public void complex128_conj()
+ {
+ double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
+ double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };
+
+ double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 };
+ double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
+
+ Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
+
+ Tensor t_result = tf.math.conj(t_complex);
+
+ NDArray n_real_result = tf.math.real(t_result).numpy();
+ NDArray n_imag_result = tf.math.imag(t_result).numpy();
+
+ double[] d_real_result = n_real_result.ToArray();
+ double[] d_imag_result = n_imag_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
+ Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
+ }
+ [TestMethod]
+ public void complex128_angle()
+ {
+ double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 };
+ double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 };
+
+ double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
+
+ Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
+
+ Tensor t_result = tf.math.angle(t_complex);
+
+ NDArray n_result = t_result.numpy();
+
+ double[] d_result = n_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_result, d_expected));
+ }
+
+ // Tests for Complex64
+ [TestMethod]
+ public void complex64_basic()
+ {
+ tf.init_scope();
+ float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
+ float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
+
+ Tensor t_complex = tf.complex(t_real, t_imag);
+
+ Tensor t_real_result = tf.math.real(t_complex);
+ Tensor t_imag_result = tf.math.imag(t_complex);
+
+ // Convert the EagerTensors to NumPy arrays directly
+ float[] d_real_result = t_real_result.numpy().ToArray();
+ float[] d_imag_result = t_imag_result.numpy().ToArray();
+
+ Assert.IsTrue(base.Equal(d_real_result, d_real));
+ Assert.IsTrue(base.Equal(d_imag_result, d_imag));
+ }
+ [TestMethod]
+ public void complex64_abs()
+ {
+ tf.enable_eager_execution();
+
+ float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
+ float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };
+
+ float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
+
+ Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
+
+ Tensor t_abs_result = tf.abs(t_complex);
+
+ NDArray n_abs_result = t_abs_result.numpy();
+
+ float[] d_abs_result = n_abs_result.ToArray();
+ Assert.IsTrue(base.Equal(d_abs_result, d_abs));
+
+ }
+ [TestMethod]
+ public void complex64_conj()
+ {
+ float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
+ float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };
+
+ float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
+ float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
+
+ Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
+
+ Tensor t_result = tf.math.conj(t_complex);
+
+ NDArray n_real_result = tf.math.real(t_result).numpy();
+ NDArray n_imag_result = tf.math.imag(t_result).numpy();
+
+ float[] d_real_result = n_real_result.ToArray();
+ float[] d_imag_result = n_imag_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
+ Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
+
+ }
+ [TestMethod]
+ public void complex64_angle()
+ {
+ float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f };
+ float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f };
+
+ float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
+
+ Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
+
+ Tensor t_result = tf.math.angle(t_complex);
+
+ NDArray n_result = t_result.numpy();
+
+ float[] d_result = n_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_result, d_expected));
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs
new file mode 100644
index 00000000..01014a10
--- /dev/null
+++ b/test/TensorFlowNET.Graph.UnitTest/SignalTest.cs
@@ -0,0 +1,103 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.NumPy;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Tensorflow;
+using static Tensorflow.Binding;
+using Buffer = Tensorflow.Buffer;
+using TensorFlowNET.Keras.UnitTest;
+
+namespace TensorFlowNET.UnitTest.Basics
+{
+ [TestClass]
+ public class SignalTest : EagerModeTestBase
+ {
+ [TestMethod]
+ public void fft()
+ {
+ double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
+ double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
+
+ Tensor t_complex = tf.complex(t_real, t_imag);
+
+ Tensor t_frequency_domain = tf.signal.fft(t_complex);
+ Tensor f_time_domain = tf.signal.ifft(t_frequency_domain);
+
+ Tensor t_real_result = tf.math.real(f_time_domain);
+ Tensor t_imag_result = tf.math.imag(f_time_domain);
+
+ NDArray n_real_result = t_real_result.numpy();
+ NDArray n_imag_result = t_imag_result.numpy();
+
+ double[] d_real_result = n_real_result.ToArray();
+ double[] d_imag_result = n_imag_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_real_result, d_real));
+ Assert.IsTrue(base.Equal(d_imag_result, d_imag));
+ }
+ [TestMethod]
+ public void fft2d()
+ {
+ double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
+ double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
+
+ Tensor t_complex = tf.complex(t_real, t_imag);
+
+ Tensor t_complex_2d = tf.reshape(t_complex,new int[] { 2, 2 });
+
+ Tensor t_frequency_domain_2d = tf.signal.fft2d(t_complex_2d);
+ Tensor t_time_domain_2d = tf.signal.ifft2d(t_frequency_domain_2d);
+
+ Tensor t_time_domain = tf.reshape(t_time_domain_2d, new int[] { 4 });
+
+ Tensor t_real_result = tf.math.real(t_time_domain);
+ Tensor t_imag_result = tf.math.imag(t_time_domain);
+
+ NDArray n_real_result = t_real_result.numpy();
+ NDArray n_imag_result = t_imag_result.numpy();
+
+ double[] d_real_result = n_real_result.ToArray();
+ double[] d_imag_result = n_imag_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_real_result, d_real));
+ Assert.IsTrue(base.Equal(d_imag_result, d_imag));
+ }
+ [TestMethod]
+ public void fft3d()
+ {
+ double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0, -3.0, -2.0, -1.0, -4.0 };
+ double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0, 6.0, 4.0, 2.0, 0.0};
+
+ Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
+ Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
+
+ Tensor t_complex = tf.complex(t_real, t_imag);
+
+ Tensor t_complex_3d = tf.reshape(t_complex, new int[] { 2, 2, 2 });
+
+ Tensor t_frequency_domain_3d = tf.signal.fft2d(t_complex_3d);
+ Tensor t_time_domain_3d = tf.signal.ifft2d(t_frequency_domain_3d);
+
+ Tensor t_time_domain = tf.reshape(t_time_domain_3d, new int[] { 8 });
+
+ Tensor t_real_result = tf.math.real(t_time_domain);
+ Tensor t_imag_result = tf.math.imag(t_time_domain);
+
+ NDArray n_real_result = t_real_result.numpy();
+ NDArray n_imag_result = t_imag_result.numpy();
+
+ double[] d_real_result = n_real_result.ToArray();
+ double[] d_imag_result = n_imag_result.ToArray();
+
+ Assert.IsTrue(base.Equal(d_real_result, d_real));
+ Assert.IsTrue(base.Equal(d_imag_result, d_imag));
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj
index 7f6f3c67..6762e603 100644
--- a/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj
+++ b/test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj
@@ -36,6 +36,7 @@
+