@@ -1,5 +1,5 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Copyright 2023 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
@@ -57,7 +57,7 @@ namespace Tensorflow | |||
public Tensor tanh(Tensor x, string name = null) | |||
=> math_ops.tanh(x, name: name); | |||
/// <summary> | |||
/// Finds values and indices of the `k` largest entries for the last dimension. | |||
/// </summary> | |||
@@ -93,6 +93,16 @@ namespace Tensorflow | |||
bool binary_output = false) | |||
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength, | |||
dtype: dtype, name: name, axis: axis, binary_output: binary_output); | |||
public Tensor real(Tensor x, string name = null) | |||
=> gen_ops.real(x, x.dtype.real_dtype(), name); | |||
public Tensor imag(Tensor x, string name = null) | |||
=> gen_ops.imag(x, x.dtype.real_dtype(), name); | |||
public Tensor conj(Tensor x, string name = null) | |||
=> gen_ops.conj(x, name); | |||
public Tensor angle(Tensor x, string name = null) | |||
=> gen_ops.angle(x, x.dtype.real_dtype(), name); | |||
} | |||
public Tensor abs(Tensor x, string name = null) | |||
@@ -537,7 +547,7 @@ namespace Tensorflow | |||
public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null, | |||
bool keepdims = false, string name = null) | |||
{ | |||
if(keepdims) | |||
if (keepdims) | |||
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name); | |||
else | |||
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices)); | |||
@@ -585,5 +595,7 @@ namespace Tensorflow | |||
=> gen_math_ops.square(x, name: name); | |||
public Tensor squared_difference(Tensor x, Tensor y, string name = null) | |||
=> gen_math_ops.squared_difference(x: x, y: y, name: name); | |||
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | |||
string name = null) => gen_ops.complex(real, imag, dtype, name); | |||
} | |||
} |
@@ -0,0 +1,40 @@ | |||
/***************************************************************************** | |||
Copyright 2023 Konstantin Balashov All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public SignalApi signal { get; } = new SignalApi(); | |||
public class SignalApi | |||
{ | |||
public Tensor fft(Tensor input, string name = null) | |||
=> gen_ops.f_f_t(input, name: name); | |||
public Tensor ifft(Tensor input, string name = null) | |||
=> gen_ops.i_f_f_t(input, name: name); | |||
public Tensor fft2d(Tensor input, string name = null) | |||
=> gen_ops.f_f_t2d(input, name: name); | |||
public Tensor ifft2d(Tensor input, string name = null) | |||
=> gen_ops.i_f_f_t2d(input, name: name); | |||
public Tensor fft3d(Tensor input, string name = null) | |||
=> gen_ops.f_f_t3d(input, name: name); | |||
public Tensor ifft3d(Tensor input, string name = null) | |||
=> gen_ops.i_f_f_t3d(input, name: name); | |||
} | |||
} | |||
} |
@@ -840,7 +840,7 @@ namespace Tensorflow.Gradients | |||
/// <param name="x"></param> | |||
/// <param name="y"></param> | |||
/// <returns></returns> | |||
private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad) | |||
public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad) | |||
{ | |||
Tensor sx, sy; | |||
if (x.shape.IsFullyDefined && | |||
@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
@@ -135,13 +136,35 @@ namespace Tensorflow.Gradients | |||
{ | |||
Tensor x = op.inputs[0]; | |||
Tensor y = op.inputs[1]; | |||
var grad = grads[0]; | |||
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype); | |||
var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y); | |||
return new Tensor[] | |||
var x_grad = math_ops.scalar_mul(scale, grad) * (x - y); | |||
if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad)) | |||
{ | |||
x_grad, | |||
-x_grad | |||
}; | |||
return new Tensor[] { x_grad, -x_grad }; | |||
} | |||
var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad); | |||
Debug.Assert(broadcast_info.Length == 2); | |||
var (sx, rx, must_reduce_x) = broadcast_info[0]; | |||
var (sy, ry, must_reduce_y) = broadcast_info[1]; | |||
Tensor gx, gy; | |||
if (must_reduce_x) | |||
{ | |||
gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx); | |||
} | |||
else | |||
{ | |||
gx = x_grad; | |||
} | |||
if (must_reduce_y) | |||
{ | |||
gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy); | |||
} | |||
else | |||
{ | |||
gy = -x_grad; | |||
} | |||
return new Tensor[] { gx, gy }; | |||
} | |||
/// <summary> | |||
@@ -15,5 +15,5 @@ public interface ICallback | |||
void on_predict_end(); | |||
void on_test_begin(); | |||
void on_test_batch_begin(long step); | |||
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs); | |||
void on_test_batch_end(long end_step, Dictionary<string, float> logs); | |||
} |
@@ -22,6 +22,7 @@ public interface IModel : ILayer | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
(NDArray val_x, NDArray val_y)? validation_data = null, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
@@ -34,6 +35,7 @@ public interface IModel : ILayer | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
@@ -65,7 +67,8 @@ public interface IModel : ILayer | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false, | |||
bool return_dict = false); | |||
bool return_dict = false, | |||
bool is_val = false); | |||
Tensors predict(Tensors x, | |||
int batch_size = -1, | |||
@@ -79,5 +82,5 @@ public interface IModel : ILayer | |||
IKerasConfig get_config(); | |||
void set_stopTraining_true(); | |||
bool Stop_training { get;set; } | |||
} |
@@ -730,12 +730,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
if (Tout.HasValue) | |||
dict["Tout"] = Tout.Value; | |||
var op = tf.OpDefLib._apply_op_helper("Angle", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout })); | |||
} | |||
/// <summary> | |||
@@ -4976,15 +4971,14 @@ namespace Tensorflow.Operations | |||
/// tf.complex(real, imag) ==&gt; [[2.25 + 4.75j], [3.25 + 5.75j]] | |||
/// </code> | |||
/// </remarks> | |||
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex") | |||
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["real"] = real; | |||
dict["imag"] = imag; | |||
if (Tout.HasValue) | |||
dict["Tout"] = Tout.Value; | |||
var op = tf.OpDefLib._apply_op_helper("Complex", name: name, keywords: dict); | |||
return op.output; | |||
TF_DataType Tin = real.GetDataType(); | |||
if (a_Tout is null) | |||
{ | |||
a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64; | |||
} | |||
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout })); | |||
} | |||
/// <summary> | |||
@@ -5008,12 +5002,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["x"] = x; | |||
if (Tout.HasValue) | |||
dict["Tout"] = Tout.Value; | |||
var op = tf.OpDefLib._apply_op_helper("ComplexAbs", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout })); | |||
} | |||
/// <summary> | |||
@@ -5313,10 +5302,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor conj(Tensor input, string name = "Conj") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
var op = tf.OpDefLib._apply_op_helper("Conj", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input })); | |||
} | |||
/// <summary> | |||
@@ -10489,10 +10475,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor f_f_t(Tensor input, string name = "FFT") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
var op = tf.OpDefLib._apply_op_helper("FFT", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("FFT", name, new ExecuteOpArgs(input)); | |||
} | |||
/// <summary> | |||
@@ -10519,10 +10502,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor f_f_t2d(Tensor input, string name = "FFT2D") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
var op = tf.OpDefLib._apply_op_helper("FFT2D", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("FFT2D", name, new ExecuteOpArgs(input)); | |||
} | |||
/// <summary> | |||
@@ -10549,10 +10529,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor f_f_t3d(Tensor input, string name = "FFT3D") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
var op = tf.OpDefLib._apply_op_helper("FFT3D", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("FFT3D", name, new ExecuteOpArgs(input)); | |||
} | |||
/// <summary> | |||
@@ -12875,10 +12852,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor i_f_f_t(Tensor input, string name = "IFFT") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
var op = tf.OpDefLib._apply_op_helper("IFFT", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("IFFT", name, new ExecuteOpArgs(input)); | |||
} | |||
/// <summary> | |||
@@ -12905,10 +12879,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor i_f_f_t2d(Tensor input, string name = "IFFT2D") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
var op = tf.OpDefLib._apply_op_helper("IFFT2D", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("IFFT2D", name, new ExecuteOpArgs(input)); | |||
} | |||
/// <summary> | |||
@@ -12935,10 +12906,7 @@ namespace Tensorflow.Operations | |||
/// </remarks> | |||
public static Tensor i_f_f_t3d(Tensor input, string name = "IFFT3D") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
var op = tf.OpDefLib._apply_op_helper("IFFT3D", name: name, keywords: dict); | |||
return op.output; | |||
return tf.Context.ExecuteOp("IFFT3D", name, new ExecuteOpArgs(input)); | |||
} | |||
/// <summary> | |||
@@ -13325,14 +13293,12 @@ namespace Tensorflow.Operations | |||
/// tf.imag(input) ==&gt; [4.75, 5.75] | |||
/// </code> | |||
/// </remarks> | |||
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag") | |||
public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
if (Tout.HasValue) | |||
dict["Tout"] = Tout.Value; | |||
var op = tf.OpDefLib._apply_op_helper("Imag", name: name, keywords: dict); | |||
return op.output; | |||
TF_DataType Tin = input.GetDataType(); | |||
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); | |||
// return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input })); | |||
} | |||
/// <summary> | |||
@@ -23863,14 +23829,12 @@ namespace Tensorflow.Operations | |||
/// tf.real(input) ==&gt; [-2.25, 3.25] | |||
/// </code> | |||
/// </remarks> | |||
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real") | |||
public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real") | |||
{ | |||
var dict = new Dictionary<string, object>(); | |||
dict["input"] = input; | |||
if (Tout.HasValue) | |||
dict["Tout"] = Tout.Value; | |||
var op = tf.OpDefLib._apply_op_helper("Real", name: name, keywords: dict); | |||
return op.output; | |||
TF_DataType Tin = input.GetDataType(); | |||
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); | |||
// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input})); | |||
} | |||
/// <summary> | |||
@@ -20,6 +20,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
@@ -35,8 +36,9 @@ namespace Tensorflow | |||
name = scope; | |||
x = ops.convert_to_tensor(x, name: "x"); | |||
if (x.dtype.is_complex()) | |||
throw new NotImplementedException("math_ops.abs for dtype.is_complex"); | |||
//return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name); | |||
{ | |||
return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name); | |||
} | |||
return gen_math_ops._abs(x, name: name); | |||
}); | |||
} | |||
@@ -69,7 +69,7 @@ public class CallbackList | |||
{ | |||
callbacks.ForEach(x => x.on_test_batch_begin(step)); | |||
} | |||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) | |||
{ | |||
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); | |||
} | |||
@@ -95,7 +95,7 @@ public class EarlyStopping: ICallback | |||
if (_wait >= _paitence && epoch > 0) | |||
{ | |||
_stopped_epoch = epoch; | |||
_parameters.Model.set_stopTraining_true(); | |||
_parameters.Model.Stop_training = true; | |||
if (_restore_best_weights && _best_weights != null) | |||
{ | |||
if (_verbose > 0) | |||
@@ -121,7 +121,7 @@ public class EarlyStopping: ICallback | |||
public void on_predict_end() { } | |||
public void on_test_begin() { } | |||
public void on_test_batch_begin(long step) { } | |||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { } | |||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) { } | |||
float get_monitor_value(Dictionary<string, float> logs) | |||
{ | |||
@@ -48,7 +48,7 @@ public class History : ICallback | |||
{ | |||
history[log.Key] = new List<float>(); | |||
} | |||
history[log.Key].Add((float)log.Value); | |||
history[log.Key].Add(log.Value); | |||
} | |||
} | |||
@@ -78,7 +78,7 @@ public class History : ICallback | |||
} | |||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) | |||
{ | |||
} | |||
} |
@@ -105,11 +105,11 @@ namespace Tensorflow.Keras.Callbacks | |||
{ | |||
_sw.Restart(); | |||
} | |||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) | |||
{ | |||
_sw.Stop(); | |||
var elapse = _sw.ElapsedMilliseconds; | |||
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}")); | |||
var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {x.Value:F6}")); | |||
Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}"); | |||
if (!Console.IsOutputRedirected) | |||
@@ -26,6 +26,7 @@ namespace Tensorflow.Keras.Engine | |||
/// <param name="workers"></param> | |||
/// <param name="use_multiprocessing"></param> | |||
/// <param name="return_dict"></param> | |||
/// <param name="is_val"></param> | |||
public Dictionary<string, float> evaluate(NDArray x, NDArray y, | |||
int batch_size = -1, | |||
int verbose = 1, | |||
@@ -33,7 +34,9 @@ namespace Tensorflow.Keras.Engine | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false, | |||
bool return_dict = false) | |||
bool return_dict = false, | |||
bool is_val = false | |||
) | |||
{ | |||
if (x.dims[0] != y.dims[0]) | |||
{ | |||
@@ -63,11 +66,11 @@ namespace Tensorflow.Keras.Engine | |||
}); | |||
callbacks.on_test_begin(); | |||
IEnumerable<(string, Tensor)> logs = null; | |||
//Dictionary<string, float>? logs = null; | |||
var logs = new Dictionary<string, float>(); | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
reset_metrics(); | |||
callbacks.on_epoch_begin(epoch); | |||
// data_handler.catch_stop_iteration(); | |||
foreach (var step in data_handler.steps()) | |||
@@ -75,19 +78,64 @@ namespace Tensorflow.Keras.Engine | |||
callbacks.on_test_batch_begin(step); | |||
logs = test_function(data_handler, iterator); | |||
var end_step = step + data_handler.StepIncrement; | |||
callbacks.on_test_batch_end(end_step, logs); | |||
if (is_val == false) | |||
callbacks.on_test_batch_end(end_step, logs); | |||
} | |||
} | |||
var results = new Dictionary<string, float>(); | |||
foreach (var log in logs) | |||
{ | |||
results[log.Item1] = (float)log.Item2; | |||
results[log.Key] = log.Value; | |||
} | |||
return results; | |||
} | |||
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1) | |||
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false) | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = new Tensors(x), | |||
Y = y, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
var callbacks = new CallbackList(new CallbackParams | |||
{ | |||
Model = this, | |||
Verbose = verbose, | |||
Steps = data_handler.Inferredsteps | |||
}); | |||
callbacks.on_test_begin(); | |||
Dictionary<string, float> logs = null; | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
reset_metrics(); | |||
callbacks.on_epoch_begin(epoch); | |||
// data_handler.catch_stop_iteration(); | |||
foreach (var step in data_handler.steps()) | |||
{ | |||
callbacks.on_test_batch_begin(step); | |||
logs = test_step_multi_inputs_function(data_handler, iterator); | |||
var end_step = step + data_handler.StepIncrement; | |||
if (is_val == false) | |||
callbacks.on_test_batch_end(end_step, logs); | |||
} | |||
} | |||
var results = new Dictionary<string, float>(); | |||
foreach (var log in logs) | |||
{ | |||
results[log.Key] = log.Value; | |||
} | |||
return results; | |||
} | |||
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false) | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
@@ -104,7 +152,7 @@ namespace Tensorflow.Keras.Engine | |||
}); | |||
callbacks.on_test_begin(); | |||
IEnumerable<(string, Tensor)> logs = null; | |||
Dictionary<string, float> logs = null; | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
reset_metrics(); | |||
@@ -113,28 +161,38 @@ namespace Tensorflow.Keras.Engine | |||
foreach (var step in data_handler.steps()) | |||
{ | |||
// callbacks.on_train_batch_begin(step) | |||
callbacks.on_test_batch_begin(step); | |||
logs = test_function(data_handler, iterator); | |||
var end_step = step + data_handler.StepIncrement; | |||
if (is_val == false) | |||
callbacks.on_test_batch_end(end_step, logs); | |||
} | |||
} | |||
var results = new Dictionary<string, float>(); | |||
foreach (var log in logs) | |||
{ | |||
results[log.Item1] = (float)log.Item2; | |||
results[log.Key] = log.Value; | |||
} | |||
return results; | |||
} | |||
IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator) | |||
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator) | |||
{ | |||
var data = iterator.next(); | |||
var outputs = test_step(data_handler, data[0], data[1]); | |||
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); | |||
return outputs; | |||
} | |||
List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y) | |||
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) | |||
{ | |||
var data = iterator.next(); | |||
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
return outputs; | |||
} | |||
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y) | |||
{ | |||
(x, y) = data_handler.DataAdapter.Expand1d(x, y); | |||
var y_pred = Apply(x, training: false); | |||
@@ -142,7 +200,7 @@ namespace Tensorflow.Keras.Engine | |||
compiled_metrics.update_state(y, y_pred); | |||
return metrics.Select(x => (x.Name, x.result())).ToList(); | |||
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2); | |||
} | |||
} | |||
} |
@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Engine | |||
/// <param name="callbacks"></param> | |||
/// <param name="verbose"></param> | |||
/// <param name="validation_split"></param> | |||
/// <param name="validation_data"></param> | |||
/// <param name="shuffle"></param> | |||
public ICallback fit(NDArray x, NDArray y, | |||
int batch_size = -1, | |||
@@ -29,6 +30,7 @@ namespace Tensorflow.Keras.Engine | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
(NDArray val_x, NDArray val_y)? validation_data = null, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
@@ -40,11 +42,17 @@ namespace Tensorflow.Keras.Engine | |||
throw new InvalidArgumentError( | |||
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | |||
} | |||
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); | |||
var train_x = x[new Slice(0, train_count)]; | |||
var train_y = y[new Slice(0, train_count)]; | |||
var val_x = x[new Slice(train_count)]; | |||
var val_y = y[new Slice(train_count)]; | |||
var train_x = x; | |||
var train_y = y; | |||
if (validation_split != 0f && validation_data == null) | |||
{ | |||
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); | |||
train_x = x[new Slice(0, train_count)]; | |||
train_y = y[new Slice(0, train_count)]; | |||
validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]); | |||
} | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
@@ -61,7 +69,7 @@ namespace Tensorflow.Keras.Engine | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, | |||
train_step_func: train_step_function); | |||
} | |||
@@ -71,6 +79,7 @@ namespace Tensorflow.Keras.Engine | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
@@ -85,12 +94,19 @@ namespace Tensorflow.Keras.Engine | |||
$"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); | |||
} | |||
} | |||
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); | |||
var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor); | |||
var train_y = y[new Slice(0, train_count)]; | |||
var val_x = x.Select(x => x[new Slice(train_count)] as Tensor); | |||
var val_y = y[new Slice(train_count)]; | |||
var train_x = x; | |||
var train_y = y; | |||
if (validation_split != 0f && validation_data == null) | |||
{ | |||
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); | |||
train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); | |||
train_y = y[new Slice(0, train_count)]; | |||
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | |||
var val_y = y[new Slice(train_count)]; | |||
validation_data = (val_x, val_y); | |||
} | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
@@ -110,29 +126,29 @@ namespace Tensorflow.Keras.Engine | |||
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | |||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | |||
{ | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, | |||
train_step_func: train_step_multi_inputs_function); | |||
} | |||
else | |||
{ | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, | |||
train_step_func: train_step_function); | |||
} | |||
} | |||
public History fit(IDatasetV2 dataset, | |||
IDatasetV2 validation_data = null, | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
List<ICallback> callbacks = null, | |||
float validation_split = 0f, | |||
IDatasetV2 validation_data = null, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false) | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
Dataset = dataset, | |||
@@ -147,6 +163,7 @@ namespace Tensorflow.Keras.Engine | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | |||
train_step_func: train_step_function); | |||
} | |||
@@ -178,11 +195,13 @@ namespace Tensorflow.Keras.Engine | |||
callbacks.on_epoch_begin(epoch); | |||
// data_handler.catch_stop_iteration(); | |||
var logs = new Dictionary<string, float>(); | |||
long End_step = 0; | |||
foreach (var step in data_handler.steps()) | |||
{ | |||
callbacks.on_train_batch_begin(step); | |||
logs = train_step_func(data_handler, iterator); | |||
var end_step = step + data_handler.StepIncrement; | |||
End_step = end_step; | |||
callbacks.on_train_batch_end(end_step, logs); | |||
} | |||
@@ -193,6 +212,123 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
logs["val_" + log.Key] = log.Value; | |||
} | |||
callbacks.on_train_batch_end(End_step, logs); | |||
} | |||
callbacks.on_epoch_end(epoch, logs); | |||
GC.Collect(); | |||
GC.WaitForPendingFinalizers(); | |||
} | |||
return callbacks.History; | |||
} | |||
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data, | |||
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||
{ | |||
stop_training = false; | |||
_train_counter.assign(0); | |||
var callbacks = new CallbackList(new CallbackParams | |||
{ | |||
Model = this, | |||
Verbose = verbose, | |||
Epochs = epochs, | |||
Steps = data_handler.Inferredsteps | |||
}); | |||
if (callbackList != null) | |||
{ | |||
foreach (var callback in callbackList) | |||
callbacks.callbacks.add(callback); | |||
} | |||
callbacks.on_train_begin(); | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
reset_metrics(); | |||
callbacks.on_epoch_begin(epoch); | |||
// data_handler.catch_stop_iteration(); | |||
var logs = new Dictionary<string, float>(); | |||
long End_step = 0; | |||
foreach (var step in data_handler.steps()) | |||
{ | |||
callbacks.on_train_batch_begin(step); | |||
logs = train_step_func(data_handler, iterator); | |||
var end_step = step + data_handler.StepIncrement; | |||
End_step = end_step; | |||
callbacks.on_train_batch_end(end_step, logs); | |||
} | |||
if (validation_data != null) | |||
{ | |||
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||
// so we need to pass a is_val parameter to stop on_test_batch_end | |||
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); | |||
foreach (var log in val_logs) | |||
{ | |||
logs["val_" + log.Key] = log.Value; | |||
} | |||
// because after evaluate, logs add some new log which we need to print | |||
callbacks.on_train_batch_end(End_step, logs); | |||
} | |||
callbacks.on_epoch_end(epoch, logs); | |||
GC.Collect(); | |||
GC.WaitForPendingFinalizers(); | |||
} | |||
return callbacks.History; | |||
} | |||
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data, | |||
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||
{ | |||
stop_training = false; | |||
_train_counter.assign(0); | |||
var callbacks = new CallbackList(new CallbackParams | |||
{ | |||
Model = this, | |||
Verbose = verbose, | |||
Epochs = epochs, | |||
Steps = data_handler.Inferredsteps | |||
}); | |||
if (callbackList != null) | |||
{ | |||
foreach (var callback in callbackList) | |||
callbacks.callbacks.add(callback); | |||
} | |||
callbacks.on_train_begin(); | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
reset_metrics(); | |||
callbacks.on_epoch_begin(epoch); | |||
// data_handler.catch_stop_iteration(); | |||
var logs = new Dictionary<string, float>(); | |||
long End_step = 0; | |||
foreach (var step in data_handler.steps()) | |||
{ | |||
callbacks.on_train_batch_begin(step); | |||
logs = train_step_func(data_handler, iterator); | |||
var end_step = step + data_handler.StepIncrement; | |||
End_step = end_step; | |||
callbacks.on_train_batch_end(end_step, logs); | |||
} | |||
if (validation_data != null) | |||
{ | |||
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2); | |||
foreach (var log in val_logs) | |||
{ | |||
logs["val_" + log.Key] = log.Value; | |||
callbacks.on_train_batch_end(End_step, logs); | |||
} | |||
} | |||
callbacks.on_epoch_end(epoch, logs); | |||
@@ -46,6 +46,12 @@ namespace Tensorflow.Keras.Engine | |||
set => optimizer = value; | |||
} | |||
public bool Stop_training | |||
{ | |||
get => stop_training; | |||
set => stop_training = value; | |||
} | |||
public Model(ModelArgs args) | |||
: base(args) | |||
{ | |||
@@ -58,6 +58,12 @@ namespace Tensorflow.Keras | |||
Name = name | |||
}); | |||
public Sequential Sequential(params ILayer[] layers) | |||
=> new Sequential(new SequentialArgs | |||
{ | |||
Layers = layers.ToList() | |||
}); | |||
/// <summary> | |||
/// `Model` groups layers into an object with training and inference features. | |||
/// </summary> | |||
@@ -72,7 +72,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<ItemGroup> | |||
<PackageReference Include="HDF5-CSharp" Version="1.16.3" /> | |||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||
<PackageReference Include="SharpZipLib" Version="1.4.2" /> | |||
</ItemGroup> | |||
@@ -0,0 +1,202 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow.NumPy; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
using Buffer = Tensorflow.Buffer; | |||
using TensorFlowNET.Keras.UnitTest; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
[TestClass] | |||
public class ComplexTest : EagerModeTestBase | |||
{ | |||
// Tests for Complex128 | |||
[TestMethod] | |||
public void complex128_basic() | |||
{ | |||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; | |||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; | |||
Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_complex = tf.complex(t_real, t_imag); | |||
Tensor t_real_result = tf.math.real(t_complex); | |||
Tensor t_imag_result = tf.math.imag(t_complex); | |||
NDArray n_real_result = t_real_result.numpy(); | |||
NDArray n_imag_result = t_imag_result.numpy(); | |||
double[] d_real_result =n_real_result.ToArray<double>(); | |||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||
} | |||
[TestMethod] | |||
public void complex128_abs() | |||
{ | |||
tf.enable_eager_execution(); | |||
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; | |||
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; | |||
double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_complex = tf.complex(t_real, t_imag); | |||
Tensor t_abs_result = tf.abs(t_complex); | |||
double[] d_abs_result = t_abs_result.numpy().ToArray<double>(); | |||
Assert.IsTrue(base.Equal(d_abs_result, d_abs)); | |||
} | |||
[TestMethod] | |||
public void complex128_conj() | |||
{ | |||
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; | |||
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; | |||
double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 }; | |||
double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); | |||
Tensor t_result = tf.math.conj(t_complex); | |||
NDArray n_real_result = tf.math.real(t_result).numpy(); | |||
NDArray n_imag_result = tf.math.imag(t_result).numpy(); | |||
double[] d_real_result = n_real_result.ToArray<double>(); | |||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||
Assert.IsTrue(base.Equal(d_real_result, d_real_expected)); | |||
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected)); | |||
} | |||
[TestMethod] | |||
public void complex128_angle() | |||
{ | |||
double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 }; | |||
double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 }; | |||
double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); | |||
Tensor t_result = tf.math.angle(t_complex); | |||
NDArray n_result = t_result.numpy(); | |||
double[] d_result = n_result.ToArray<double>(); | |||
Assert.IsTrue(base.Equal(d_result, d_expected)); | |||
} | |||
// Tests for Complex64 | |||
[TestMethod] | |||
public void complex64_basic() | |||
{ | |||
tf.init_scope(); | |||
float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | |||
float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_complex = tf.complex(t_real, t_imag); | |||
Tensor t_real_result = tf.math.real(t_complex); | |||
Tensor t_imag_result = tf.math.imag(t_complex); | |||
// Convert the EagerTensors to NumPy arrays directly | |||
float[] d_real_result = t_real_result.numpy().ToArray<float>(); | |||
float[] d_imag_result = t_imag_result.numpy().ToArray<float>(); | |||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||
} | |||
[TestMethod] | |||
public void complex64_abs() | |||
{ | |||
tf.enable_eager_execution(); | |||
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; | |||
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f }; | |||
float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); | |||
Tensor t_abs_result = tf.abs(t_complex); | |||
NDArray n_abs_result = t_abs_result.numpy(); | |||
float[] d_abs_result = n_abs_result.ToArray<float>(); | |||
Assert.IsTrue(base.Equal(d_abs_result, d_abs)); | |||
} | |||
[TestMethod] | |||
public void complex64_conj() | |||
{ | |||
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; | |||
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f }; | |||
float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; | |||
float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); | |||
Tensor t_result = tf.math.conj(t_complex); | |||
NDArray n_real_result = tf.math.real(t_result).numpy(); | |||
NDArray n_imag_result = tf.math.imag(t_result).numpy(); | |||
float[] d_real_result = n_real_result.ToArray<float>(); | |||
float[] d_imag_result = n_imag_result.ToArray<float>(); | |||
Assert.IsTrue(base.Equal(d_real_result, d_real_expected)); | |||
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected)); | |||
} | |||
[TestMethod] | |||
public void complex64_angle() | |||
{ | |||
float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f }; | |||
float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f }; | |||
float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); | |||
Tensor t_result = tf.math.angle(t_complex); | |||
NDArray n_result = t_result.numpy(); | |||
float[] d_result = n_result.ToArray<float>(); | |||
Assert.IsTrue(base.Equal(d_result, d_expected)); | |||
} | |||
} | |||
} |
@@ -0,0 +1,103 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow.NumPy; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
using Buffer = Tensorflow.Buffer; | |||
using TensorFlowNET.Keras.UnitTest; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
[TestClass] | |||
public class SignalTest : EagerModeTestBase | |||
{ | |||
[TestMethod] | |||
public void fft() | |||
{ | |||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; | |||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_complex = tf.complex(t_real, t_imag); | |||
Tensor t_frequency_domain = tf.signal.fft(t_complex); | |||
Tensor f_time_domain = tf.signal.ifft(t_frequency_domain); | |||
Tensor t_real_result = tf.math.real(f_time_domain); | |||
Tensor t_imag_result = tf.math.imag(f_time_domain); | |||
NDArray n_real_result = t_real_result.numpy(); | |||
NDArray n_imag_result = t_imag_result.numpy(); | |||
double[] d_real_result = n_real_result.ToArray<double>(); | |||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||
} | |||
[TestMethod] | |||
public void fft2d() | |||
{ | |||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; | |||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_complex = tf.complex(t_real, t_imag); | |||
Tensor t_complex_2d = tf.reshape(t_complex,new int[] { 2, 2 }); | |||
Tensor t_frequency_domain_2d = tf.signal.fft2d(t_complex_2d); | |||
Tensor t_time_domain_2d = tf.signal.ifft2d(t_frequency_domain_2d); | |||
Tensor t_time_domain = tf.reshape(t_time_domain_2d, new int[] { 4 }); | |||
Tensor t_real_result = tf.math.real(t_time_domain); | |||
Tensor t_imag_result = tf.math.imag(t_time_domain); | |||
NDArray n_real_result = t_real_result.numpy(); | |||
NDArray n_imag_result = t_imag_result.numpy(); | |||
double[] d_real_result = n_real_result.ToArray<double>(); | |||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||
} | |||
[TestMethod] | |||
public void fft3d() | |||
{ | |||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0, -3.0, -2.0, -1.0, -4.0 }; | |||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0, 6.0, 4.0, 2.0, 0.0}; | |||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||
Tensor t_complex = tf.complex(t_real, t_imag); | |||
Tensor t_complex_3d = tf.reshape(t_complex, new int[] { 2, 2, 2 }); | |||
Tensor t_frequency_domain_3d = tf.signal.fft2d(t_complex_3d); | |||
Tensor t_time_domain_3d = tf.signal.ifft2d(t_frequency_domain_3d); | |||
Tensor t_time_domain = tf.reshape(t_time_domain_3d, new int[] { 8 }); | |||
Tensor t_real_result = tf.math.real(t_time_domain); | |||
Tensor t_imag_result = tf.math.imag(t_time_domain); | |||
NDArray n_real_result = t_real_result.numpy(); | |||
NDArray n_imag_result = t_imag_result.numpy(); | |||
double[] d_real_result = n_real_result.ToArray<double>(); | |||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||
} | |||
} | |||
} |
@@ -36,6 +36,7 @@ | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
<ProjectReference Include="..\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj" /> | |||
</ItemGroup> | |||
</Project> |