@@ -1,5 +1,5 @@ | |||||
/***************************************************************************** | /***************************************************************************** | ||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Copyright 2023 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | Licensed under the Apache License, Version 2.0 (the "License"); | ||||
you may not use this file except in compliance with the License. | you may not use this file except in compliance with the License. | ||||
@@ -57,7 +57,7 @@ namespace Tensorflow | |||||
public Tensor tanh(Tensor x, string name = null) | public Tensor tanh(Tensor x, string name = null) | ||||
=> math_ops.tanh(x, name: name); | => math_ops.tanh(x, name: name); | ||||
/// <summary> | /// <summary> | ||||
/// Finds values and indices of the `k` largest entries for the last dimension. | /// Finds values and indices of the `k` largest entries for the last dimension. | ||||
/// </summary> | /// </summary> | ||||
@@ -93,6 +93,16 @@ namespace Tensorflow | |||||
bool binary_output = false) | bool binary_output = false) | ||||
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength, | => math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength, | ||||
dtype: dtype, name: name, axis: axis, binary_output: binary_output); | dtype: dtype, name: name, axis: axis, binary_output: binary_output); | ||||
public Tensor real(Tensor x, string name = null) | |||||
=> gen_ops.real(x, x.dtype.real_dtype(), name); | |||||
public Tensor imag(Tensor x, string name = null) | |||||
=> gen_ops.imag(x, x.dtype.real_dtype(), name); | |||||
public Tensor conj(Tensor x, string name = null) | |||||
=> gen_ops.conj(x, name); | |||||
public Tensor angle(Tensor x, string name = null) | |||||
=> gen_ops.angle(x, x.dtype.real_dtype(), name); | |||||
} | } | ||||
public Tensor abs(Tensor x, string name = null) | public Tensor abs(Tensor x, string name = null) | ||||
@@ -537,7 +547,7 @@ namespace Tensorflow | |||||
public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null, | public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null, | ||||
bool keepdims = false, string name = null) | bool keepdims = false, string name = null) | ||||
{ | { | ||||
if(keepdims) | |||||
if (keepdims) | |||||
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name); | return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name); | ||||
else | else | ||||
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices)); | return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices)); | ||||
@@ -585,5 +595,7 @@ namespace Tensorflow | |||||
=> gen_math_ops.square(x, name: name); | => gen_math_ops.square(x, name: name); | ||||
public Tensor squared_difference(Tensor x, Tensor y, string name = null) | public Tensor squared_difference(Tensor x, Tensor y, string name = null) | ||||
=> gen_math_ops.squared_difference(x: x, y: y, name: name); | => gen_math_ops.squared_difference(x: x, y: y, name: name); | ||||
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null, | |||||
string name = null) => gen_ops.complex(real, imag, dtype, name); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,40 @@ | |||||
/***************************************************************************** | |||||
Copyright 2023 Konstantin Balashov All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class tensorflow | |||||
{ | |||||
public SignalApi signal { get; } = new SignalApi(); | |||||
public class SignalApi | |||||
{ | |||||
public Tensor fft(Tensor input, string name = null) | |||||
=> gen_ops.f_f_t(input, name: name); | |||||
public Tensor ifft(Tensor input, string name = null) | |||||
=> gen_ops.i_f_f_t(input, name: name); | |||||
public Tensor fft2d(Tensor input, string name = null) | |||||
=> gen_ops.f_f_t2d(input, name: name); | |||||
public Tensor ifft2d(Tensor input, string name = null) | |||||
=> gen_ops.i_f_f_t2d(input, name: name); | |||||
public Tensor fft3d(Tensor input, string name = null) | |||||
=> gen_ops.f_f_t3d(input, name: name); | |||||
public Tensor ifft3d(Tensor input, string name = null) | |||||
=> gen_ops.i_f_f_t3d(input, name: name); | |||||
} | |||||
} | |||||
} |
@@ -840,7 +840,7 @@ namespace Tensorflow.Gradients | |||||
/// <param name="x"></param> | /// <param name="x"></param> | ||||
/// <param name="y"></param> | /// <param name="y"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad) | |||||
public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad) | |||||
{ | { | ||||
Tensor sx, sy; | Tensor sx, sy; | ||||
if (x.shape.IsFullyDefined && | if (x.shape.IsFullyDefined && | ||||
@@ -15,6 +15,7 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -135,13 +136,35 @@ namespace Tensorflow.Gradients | |||||
{ | { | ||||
Tensor x = op.inputs[0]; | Tensor x = op.inputs[0]; | ||||
Tensor y = op.inputs[1]; | Tensor y = op.inputs[1]; | ||||
var grad = grads[0]; | |||||
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype); | var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype); | ||||
var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y); | |||||
return new Tensor[] | |||||
var x_grad = math_ops.scalar_mul(scale, grad) * (x - y); | |||||
if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad)) | |||||
{ | { | ||||
x_grad, | |||||
-x_grad | |||||
}; | |||||
return new Tensor[] { x_grad, -x_grad }; | |||||
} | |||||
var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad); | |||||
Debug.Assert(broadcast_info.Length == 2); | |||||
var (sx, rx, must_reduce_x) = broadcast_info[0]; | |||||
var (sy, ry, must_reduce_y) = broadcast_info[1]; | |||||
Tensor gx, gy; | |||||
if (must_reduce_x) | |||||
{ | |||||
gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx); | |||||
} | |||||
else | |||||
{ | |||||
gx = x_grad; | |||||
} | |||||
if (must_reduce_y) | |||||
{ | |||||
gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy); | |||||
} | |||||
else | |||||
{ | |||||
gy = -x_grad; | |||||
} | |||||
return new Tensor[] { gx, gy }; | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -15,5 +15,5 @@ public interface ICallback | |||||
void on_predict_end(); | void on_predict_end(); | ||||
void on_test_begin(); | void on_test_begin(); | ||||
void on_test_batch_begin(long step); | void on_test_batch_begin(long step); | ||||
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs); | |||||
void on_test_batch_end(long end_step, Dictionary<string, float> logs); | |||||
} | } |
@@ -22,6 +22,7 @@ public interface IModel : ILayer | |||||
int verbose = 1, | int verbose = 1, | ||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
float validation_split = 0f, | float validation_split = 0f, | ||||
(NDArray val_x, NDArray val_y)? validation_data = null, | |||||
bool shuffle = true, | bool shuffle = true, | ||||
int initial_epoch = 0, | int initial_epoch = 0, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
@@ -34,6 +35,7 @@ public interface IModel : ILayer | |||||
int verbose = 1, | int verbose = 1, | ||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
float validation_split = 0f, | float validation_split = 0f, | ||||
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null, | |||||
bool shuffle = true, | bool shuffle = true, | ||||
int initial_epoch = 0, | int initial_epoch = 0, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
@@ -65,7 +67,8 @@ public interface IModel : ILayer | |||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
int workers = 1, | int workers = 1, | ||||
bool use_multiprocessing = false, | bool use_multiprocessing = false, | ||||
bool return_dict = false); | |||||
bool return_dict = false, | |||||
bool is_val = false); | |||||
Tensors predict(Tensors x, | Tensors predict(Tensors x, | ||||
int batch_size = -1, | int batch_size = -1, | ||||
@@ -79,5 +82,5 @@ public interface IModel : ILayer | |||||
IKerasConfig get_config(); | IKerasConfig get_config(); | ||||
void set_stopTraining_true(); | |||||
bool Stop_training { get;set; } | |||||
} | } |
@@ -730,12 +730,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle") | public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
if (Tout.HasValue) | |||||
dict["Tout"] = Tout.Value; | |||||
var op = tf.OpDefLib._apply_op_helper("Angle", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout })); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -4976,15 +4971,14 @@ namespace Tensorflow.Operations | |||||
/// tf.complex(real, imag) ==&gt; [[2.25 + 4.75j], [3.25 + 5.75j]] | /// tf.complex(real, imag) ==&gt; [[2.25 + 4.75j], [3.25 + 5.75j]] | ||||
/// </code> | /// </code> | ||||
/// </remarks> | /// </remarks> | ||||
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex") | |||||
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex") | |||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["real"] = real; | |||||
dict["imag"] = imag; | |||||
if (Tout.HasValue) | |||||
dict["Tout"] = Tout.Value; | |||||
var op = tf.OpDefLib._apply_op_helper("Complex", name: name, keywords: dict); | |||||
return op.output; | |||||
TF_DataType Tin = real.GetDataType(); | |||||
if (a_Tout is null) | |||||
{ | |||||
a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64; | |||||
} | |||||
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout })); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -5008,12 +5002,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs") | public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["x"] = x; | |||||
if (Tout.HasValue) | |||||
dict["Tout"] = Tout.Value; | |||||
var op = tf.OpDefLib._apply_op_helper("ComplexAbs", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout })); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -5313,10 +5302,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor conj(Tensor input, string name = "Conj") | public static Tensor conj(Tensor input, string name = "Conj") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
var op = tf.OpDefLib._apply_op_helper("Conj", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input })); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -10489,10 +10475,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor f_f_t(Tensor input, string name = "FFT") | public static Tensor f_f_t(Tensor input, string name = "FFT") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
var op = tf.OpDefLib._apply_op_helper("FFT", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("FFT", name, new ExecuteOpArgs(input)); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -10519,10 +10502,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor f_f_t2d(Tensor input, string name = "FFT2D") | public static Tensor f_f_t2d(Tensor input, string name = "FFT2D") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
var op = tf.OpDefLib._apply_op_helper("FFT2D", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("FFT2D", name, new ExecuteOpArgs(input)); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -10549,10 +10529,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor f_f_t3d(Tensor input, string name = "FFT3D") | public static Tensor f_f_t3d(Tensor input, string name = "FFT3D") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
var op = tf.OpDefLib._apply_op_helper("FFT3D", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("FFT3D", name, new ExecuteOpArgs(input)); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -12875,10 +12852,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor i_f_f_t(Tensor input, string name = "IFFT") | public static Tensor i_f_f_t(Tensor input, string name = "IFFT") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
var op = tf.OpDefLib._apply_op_helper("IFFT", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("IFFT", name, new ExecuteOpArgs(input)); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -12905,10 +12879,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor i_f_f_t2d(Tensor input, string name = "IFFT2D") | public static Tensor i_f_f_t2d(Tensor input, string name = "IFFT2D") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
var op = tf.OpDefLib._apply_op_helper("IFFT2D", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("IFFT2D", name, new ExecuteOpArgs(input)); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -12935,10 +12906,7 @@ namespace Tensorflow.Operations | |||||
/// </remarks> | /// </remarks> | ||||
public static Tensor i_f_f_t3d(Tensor input, string name = "IFFT3D") | public static Tensor i_f_f_t3d(Tensor input, string name = "IFFT3D") | ||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
var op = tf.OpDefLib._apply_op_helper("IFFT3D", name: name, keywords: dict); | |||||
return op.output; | |||||
return tf.Context.ExecuteOp("IFFT3D", name, new ExecuteOpArgs(input)); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -13325,14 +13293,12 @@ namespace Tensorflow.Operations | |||||
/// tf.imag(input) ==&gt; [4.75, 5.75] | /// tf.imag(input) ==&gt; [4.75, 5.75] | ||||
/// </code> | /// </code> | ||||
/// </remarks> | /// </remarks> | ||||
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag") | |||||
public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag") | |||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
if (Tout.HasValue) | |||||
dict["Tout"] = Tout.Value; | |||||
var op = tf.OpDefLib._apply_op_helper("Imag", name: name, keywords: dict); | |||||
return op.output; | |||||
TF_DataType Tin = input.GetDataType(); | |||||
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); | |||||
// return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input })); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -23863,14 +23829,12 @@ namespace Tensorflow.Operations | |||||
/// tf.real(input) ==&gt; [-2.25, 3.25] | /// tf.real(input) ==&gt; [-2.25, 3.25] | ||||
/// </code> | /// </code> | ||||
/// </remarks> | /// </remarks> | ||||
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real") | |||||
public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real") | |||||
{ | { | ||||
var dict = new Dictionary<string, object>(); | |||||
dict["input"] = input; | |||||
if (Tout.HasValue) | |||||
dict["Tout"] = Tout.Value; | |||||
var op = tf.OpDefLib._apply_op_helper("Real", name: name, keywords: dict); | |||||
return op.output; | |||||
TF_DataType Tin = input.GetDataType(); | |||||
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout })); | |||||
// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input})); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -20,6 +20,7 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -35,8 +36,9 @@ namespace Tensorflow | |||||
name = scope; | name = scope; | ||||
x = ops.convert_to_tensor(x, name: "x"); | x = ops.convert_to_tensor(x, name: "x"); | ||||
if (x.dtype.is_complex()) | if (x.dtype.is_complex()) | ||||
throw new NotImplementedException("math_ops.abs for dtype.is_complex"); | |||||
//return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name); | |||||
{ | |||||
return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name); | |||||
} | |||||
return gen_math_ops._abs(x, name: name); | return gen_math_ops._abs(x, name: name); | ||||
}); | }); | ||||
} | } | ||||
@@ -69,7 +69,7 @@ public class CallbackList | |||||
{ | { | ||||
callbacks.ForEach(x => x.on_test_batch_begin(step)); | callbacks.ForEach(x => x.on_test_batch_begin(step)); | ||||
} | } | ||||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) | |||||
{ | { | ||||
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); | callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); | ||||
} | } | ||||
@@ -95,7 +95,7 @@ public class EarlyStopping: ICallback | |||||
if (_wait >= _paitence && epoch > 0) | if (_wait >= _paitence && epoch > 0) | ||||
{ | { | ||||
_stopped_epoch = epoch; | _stopped_epoch = epoch; | ||||
_parameters.Model.set_stopTraining_true(); | |||||
_parameters.Model.Stop_training = true; | |||||
if (_restore_best_weights && _best_weights != null) | if (_restore_best_weights && _best_weights != null) | ||||
{ | { | ||||
if (_verbose > 0) | if (_verbose > 0) | ||||
@@ -121,7 +121,7 @@ public class EarlyStopping: ICallback | |||||
public void on_predict_end() { } | public void on_predict_end() { } | ||||
public void on_test_begin() { } | public void on_test_begin() { } | ||||
public void on_test_batch_begin(long step) { } | public void on_test_batch_begin(long step) { } | ||||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { } | |||||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) { } | |||||
float get_monitor_value(Dictionary<string, float> logs) | float get_monitor_value(Dictionary<string, float> logs) | ||||
{ | { | ||||
@@ -48,7 +48,7 @@ public class History : ICallback | |||||
{ | { | ||||
history[log.Key] = new List<float>(); | history[log.Key] = new List<float>(); | ||||
} | } | ||||
history[log.Key].Add((float)log.Value); | |||||
history[log.Key].Add(log.Value); | |||||
} | } | ||||
} | } | ||||
@@ -78,7 +78,7 @@ public class History : ICallback | |||||
} | } | ||||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) | |||||
{ | { | ||||
} | } | ||||
} | } |
@@ -105,11 +105,11 @@ namespace Tensorflow.Keras.Callbacks | |||||
{ | { | ||||
_sw.Restart(); | _sw.Restart(); | ||||
} | } | ||||
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) | |||||
{ | { | ||||
_sw.Stop(); | _sw.Stop(); | ||||
var elapse = _sw.ElapsedMilliseconds; | var elapse = _sw.ElapsedMilliseconds; | ||||
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}")); | |||||
var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {x.Value:F6}")); | |||||
Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}"); | Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}"); | ||||
if (!Console.IsOutputRedirected) | if (!Console.IsOutputRedirected) | ||||
@@ -26,6 +26,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <param name="workers"></param> | /// <param name="workers"></param> | ||||
/// <param name="use_multiprocessing"></param> | /// <param name="use_multiprocessing"></param> | ||||
/// <param name="return_dict"></param> | /// <param name="return_dict"></param> | ||||
/// <param name="is_val"></param> | |||||
public Dictionary<string, float> evaluate(NDArray x, NDArray y, | public Dictionary<string, float> evaluate(NDArray x, NDArray y, | ||||
int batch_size = -1, | int batch_size = -1, | ||||
int verbose = 1, | int verbose = 1, | ||||
@@ -33,7 +34,9 @@ namespace Tensorflow.Keras.Engine | |||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
int workers = 1, | int workers = 1, | ||||
bool use_multiprocessing = false, | bool use_multiprocessing = false, | ||||
bool return_dict = false) | |||||
bool return_dict = false, | |||||
bool is_val = false | |||||
) | |||||
{ | { | ||||
if (x.dims[0] != y.dims[0]) | if (x.dims[0] != y.dims[0]) | ||||
{ | { | ||||
@@ -63,11 +66,11 @@ namespace Tensorflow.Keras.Engine | |||||
}); | }); | ||||
callbacks.on_test_begin(); | callbacks.on_test_begin(); | ||||
IEnumerable<(string, Tensor)> logs = null; | |||||
//Dictionary<string, float>? logs = null; | |||||
var logs = new Dictionary<string, float>(); | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
{ | { | ||||
reset_metrics(); | reset_metrics(); | ||||
callbacks.on_epoch_begin(epoch); | |||||
// data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
@@ -75,19 +78,64 @@ namespace Tensorflow.Keras.Engine | |||||
callbacks.on_test_batch_begin(step); | callbacks.on_test_batch_begin(step); | ||||
logs = test_function(data_handler, iterator); | logs = test_function(data_handler, iterator); | ||||
var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
callbacks.on_test_batch_end(end_step, logs); | |||||
if (is_val == false) | |||||
callbacks.on_test_batch_end(end_step, logs); | |||||
} | } | ||||
} | } | ||||
var results = new Dictionary<string, float>(); | var results = new Dictionary<string, float>(); | ||||
foreach (var log in logs) | foreach (var log in logs) | ||||
{ | { | ||||
results[log.Item1] = (float)log.Item2; | |||||
results[log.Key] = log.Value; | |||||
} | } | ||||
return results; | return results; | ||||
} | } | ||||
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1) | |||||
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false) | |||||
{ | |||||
var data_handler = new DataHandler(new DataHandlerArgs | |||||
{ | |||||
X = new Tensors(x), | |||||
Y = y, | |||||
Model = this, | |||||
StepsPerExecution = _steps_per_execution | |||||
}); | |||||
var callbacks = new CallbackList(new CallbackParams | |||||
{ | |||||
Model = this, | |||||
Verbose = verbose, | |||||
Steps = data_handler.Inferredsteps | |||||
}); | |||||
callbacks.on_test_begin(); | |||||
Dictionary<string, float> logs = null; | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
{ | |||||
reset_metrics(); | |||||
callbacks.on_epoch_begin(epoch); | |||||
// data_handler.catch_stop_iteration(); | |||||
foreach (var step in data_handler.steps()) | |||||
{ | |||||
callbacks.on_test_batch_begin(step); | |||||
logs = test_step_multi_inputs_function(data_handler, iterator); | |||||
var end_step = step + data_handler.StepIncrement; | |||||
if (is_val == false) | |||||
callbacks.on_test_batch_end(end_step, logs); | |||||
} | |||||
} | |||||
var results = new Dictionary<string, float>(); | |||||
foreach (var log in logs) | |||||
{ | |||||
results[log.Key] = log.Value; | |||||
} | |||||
return results; | |||||
} | |||||
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false) | |||||
{ | { | ||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
@@ -104,7 +152,7 @@ namespace Tensorflow.Keras.Engine | |||||
}); | }); | ||||
callbacks.on_test_begin(); | callbacks.on_test_begin(); | ||||
IEnumerable<(string, Tensor)> logs = null; | |||||
Dictionary<string, float> logs = null; | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
{ | { | ||||
reset_metrics(); | reset_metrics(); | ||||
@@ -113,28 +161,38 @@ namespace Tensorflow.Keras.Engine | |||||
foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
{ | { | ||||
// callbacks.on_train_batch_begin(step) | |||||
callbacks.on_test_batch_begin(step); | |||||
logs = test_function(data_handler, iterator); | logs = test_function(data_handler, iterator); | ||||
var end_step = step + data_handler.StepIncrement; | |||||
if (is_val == false) | |||||
callbacks.on_test_batch_end(end_step, logs); | |||||
} | } | ||||
} | } | ||||
var results = new Dictionary<string, float>(); | var results = new Dictionary<string, float>(); | ||||
foreach (var log in logs) | foreach (var log in logs) | ||||
{ | { | ||||
results[log.Item1] = (float)log.Item2; | |||||
results[log.Key] = log.Value; | |||||
} | } | ||||
return results; | return results; | ||||
} | } | ||||
IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator) | |||||
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator) | |||||
{ | { | ||||
var data = iterator.next(); | var data = iterator.next(); | ||||
var outputs = test_step(data_handler, data[0], data[1]); | var outputs = test_step(data_handler, data[0], data[1]); | ||||
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); | tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); | ||||
return outputs; | return outputs; | ||||
} | } | ||||
List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y) | |||||
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator) | |||||
{ | |||||
var data = iterator.next(); | |||||
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||||
return outputs; | |||||
} | |||||
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y) | |||||
{ | { | ||||
(x, y) = data_handler.DataAdapter.Expand1d(x, y); | (x, y) = data_handler.DataAdapter.Expand1d(x, y); | ||||
var y_pred = Apply(x, training: false); | var y_pred = Apply(x, training: false); | ||||
@@ -142,7 +200,7 @@ namespace Tensorflow.Keras.Engine | |||||
compiled_metrics.update_state(y, y_pred); | compiled_metrics.update_state(y, y_pred); | ||||
return metrics.Select(x => (x.Name, x.result())).ToList(); | |||||
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Engine | |||||
/// <param name="callbacks"></param> | /// <param name="callbacks"></param> | ||||
/// <param name="verbose"></param> | /// <param name="verbose"></param> | ||||
/// <param name="validation_split"></param> | /// <param name="validation_split"></param> | ||||
/// <param name="validation_data"></param> | |||||
/// <param name="shuffle"></param> | /// <param name="shuffle"></param> | ||||
public ICallback fit(NDArray x, NDArray y, | public ICallback fit(NDArray x, NDArray y, | ||||
int batch_size = -1, | int batch_size = -1, | ||||
@@ -29,6 +30,7 @@ namespace Tensorflow.Keras.Engine | |||||
int verbose = 1, | int verbose = 1, | ||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
float validation_split = 0f, | float validation_split = 0f, | ||||
(NDArray val_x, NDArray val_y)? validation_data = null, | |||||
bool shuffle = true, | bool shuffle = true, | ||||
int initial_epoch = 0, | int initial_epoch = 0, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
@@ -40,11 +42,17 @@ namespace Tensorflow.Keras.Engine | |||||
throw new InvalidArgumentError( | throw new InvalidArgumentError( | ||||
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | ||||
} | } | ||||
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); | |||||
var train_x = x[new Slice(0, train_count)]; | |||||
var train_y = y[new Slice(0, train_count)]; | |||||
var val_x = x[new Slice(train_count)]; | |||||
var val_y = y[new Slice(train_count)]; | |||||
var train_x = x; | |||||
var train_y = y; | |||||
if (validation_split != 0f && validation_data == null) | |||||
{ | |||||
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split)); | |||||
train_x = x[new Slice(0, train_count)]; | |||||
train_y = y[new Slice(0, train_count)]; | |||||
validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]); | |||||
} | |||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
@@ -61,7 +69,7 @@ namespace Tensorflow.Keras.Engine | |||||
StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
}); | }); | ||||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, | |||||
train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
} | } | ||||
@@ -71,6 +79,7 @@ namespace Tensorflow.Keras.Engine | |||||
int verbose = 1, | int verbose = 1, | ||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
float validation_split = 0f, | float validation_split = 0f, | ||||
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null, | |||||
bool shuffle = true, | bool shuffle = true, | ||||
int initial_epoch = 0, | int initial_epoch = 0, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
@@ -85,12 +94,19 @@ namespace Tensorflow.Keras.Engine | |||||
$"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); | $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}"); | ||||
} | } | ||||
} | } | ||||
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); | |||||
var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor); | |||||
var train_y = y[new Slice(0, train_count)]; | |||||
var val_x = x.Select(x => x[new Slice(train_count)] as Tensor); | |||||
var val_y = y[new Slice(train_count)]; | |||||
var train_x = x; | |||||
var train_y = y; | |||||
if (validation_split != 0f && validation_data == null) | |||||
{ | |||||
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split)); | |||||
train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray); | |||||
train_y = y[new Slice(0, train_count)]; | |||||
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | |||||
var val_y = y[new Slice(train_count)]; | |||||
validation_data = (val_x, val_y); | |||||
} | |||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
@@ -110,29 +126,29 @@ namespace Tensorflow.Keras.Engine | |||||
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | if (data_handler.DataAdapter.GetDataset().structure.Length > 2 || | ||||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1) | ||||
{ | { | ||||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, | |||||
train_step_func: train_step_multi_inputs_function); | train_step_func: train_step_multi_inputs_function); | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null, | |||||
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data, | |||||
train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
} | } | ||||
} | } | ||||
public History fit(IDatasetV2 dataset, | public History fit(IDatasetV2 dataset, | ||||
IDatasetV2 validation_data = null, | |||||
int batch_size = -1, | int batch_size = -1, | ||||
int epochs = 1, | int epochs = 1, | ||||
int verbose = 1, | int verbose = 1, | ||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
float validation_split = 0f, | |||||
IDatasetV2 validation_data = null, | |||||
bool shuffle = true, | bool shuffle = true, | ||||
int initial_epoch = 0, | int initial_epoch = 0, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
int workers = 1, | int workers = 1, | ||||
bool use_multiprocessing = false) | bool use_multiprocessing = false) | ||||
{ | { | ||||
var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
{ | { | ||||
Dataset = dataset, | Dataset = dataset, | ||||
@@ -147,6 +163,7 @@ namespace Tensorflow.Keras.Engine | |||||
StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
}); | }); | ||||
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data, | ||||
train_step_func: train_step_function); | train_step_func: train_step_function); | ||||
} | } | ||||
@@ -178,11 +195,13 @@ namespace Tensorflow.Keras.Engine | |||||
callbacks.on_epoch_begin(epoch); | callbacks.on_epoch_begin(epoch); | ||||
// data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
var logs = new Dictionary<string, float>(); | var logs = new Dictionary<string, float>(); | ||||
long End_step = 0; | |||||
foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
{ | { | ||||
callbacks.on_train_batch_begin(step); | callbacks.on_train_batch_begin(step); | ||||
logs = train_step_func(data_handler, iterator); | logs = train_step_func(data_handler, iterator); | ||||
var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
End_step = end_step; | |||||
callbacks.on_train_batch_end(end_step, logs); | callbacks.on_train_batch_end(end_step, logs); | ||||
} | } | ||||
@@ -193,6 +212,123 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
logs["val_" + log.Key] = log.Value; | logs["val_" + log.Key] = log.Value; | ||||
} | } | ||||
callbacks.on_train_batch_end(End_step, logs); | |||||
} | |||||
callbacks.on_epoch_end(epoch, logs); | |||||
GC.Collect(); | |||||
GC.WaitForPendingFinalizers(); | |||||
} | |||||
return callbacks.History; | |||||
} | |||||
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data, | |||||
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||||
{ | |||||
stop_training = false; | |||||
_train_counter.assign(0); | |||||
var callbacks = new CallbackList(new CallbackParams | |||||
{ | |||||
Model = this, | |||||
Verbose = verbose, | |||||
Epochs = epochs, | |||||
Steps = data_handler.Inferredsteps | |||||
}); | |||||
if (callbackList != null) | |||||
{ | |||||
foreach (var callback in callbackList) | |||||
callbacks.callbacks.add(callback); | |||||
} | |||||
callbacks.on_train_begin(); | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
{ | |||||
reset_metrics(); | |||||
callbacks.on_epoch_begin(epoch); | |||||
// data_handler.catch_stop_iteration(); | |||||
var logs = new Dictionary<string, float>(); | |||||
long End_step = 0; | |||||
foreach (var step in data_handler.steps()) | |||||
{ | |||||
callbacks.on_train_batch_begin(step); | |||||
logs = train_step_func(data_handler, iterator); | |||||
var end_step = step + data_handler.StepIncrement; | |||||
End_step = end_step; | |||||
callbacks.on_train_batch_end(end_step, logs); | |||||
} | |||||
if (validation_data != null) | |||||
{ | |||||
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||||
// so we need to pass a is_val parameter to stop on_test_batch_end | |||||
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true); | |||||
foreach (var log in val_logs) | |||||
{ | |||||
logs["val_" + log.Key] = log.Value; | |||||
} | |||||
// because after evaluate, logs add some new log which we need to print | |||||
callbacks.on_train_batch_end(End_step, logs); | |||||
} | |||||
callbacks.on_epoch_end(epoch, logs); | |||||
GC.Collect(); | |||||
GC.WaitForPendingFinalizers(); | |||||
} | |||||
return callbacks.History; | |||||
} | |||||
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data, | |||||
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func) | |||||
{ | |||||
stop_training = false; | |||||
_train_counter.assign(0); | |||||
var callbacks = new CallbackList(new CallbackParams | |||||
{ | |||||
Model = this, | |||||
Verbose = verbose, | |||||
Epochs = epochs, | |||||
Steps = data_handler.Inferredsteps | |||||
}); | |||||
if (callbackList != null) | |||||
{ | |||||
foreach (var callback in callbackList) | |||||
callbacks.callbacks.add(callback); | |||||
} | |||||
callbacks.on_train_begin(); | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
{ | |||||
reset_metrics(); | |||||
callbacks.on_epoch_begin(epoch); | |||||
// data_handler.catch_stop_iteration(); | |||||
var logs = new Dictionary<string, float>(); | |||||
long End_step = 0; | |||||
foreach (var step in data_handler.steps()) | |||||
{ | |||||
callbacks.on_train_batch_begin(step); | |||||
logs = train_step_func(data_handler, iterator); | |||||
var end_step = step + data_handler.StepIncrement; | |||||
End_step = end_step; | |||||
callbacks.on_train_batch_end(end_step, logs); | |||||
} | |||||
if (validation_data != null) | |||||
{ | |||||
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2); | |||||
foreach (var log in val_logs) | |||||
{ | |||||
logs["val_" + log.Key] = log.Value; | |||||
callbacks.on_train_batch_end(End_step, logs); | |||||
} | |||||
} | } | ||||
callbacks.on_epoch_end(epoch, logs); | callbacks.on_epoch_end(epoch, logs); | ||||
@@ -46,6 +46,12 @@ namespace Tensorflow.Keras.Engine | |||||
set => optimizer = value; | set => optimizer = value; | ||||
} | } | ||||
public bool Stop_training | |||||
{ | |||||
get => stop_training; | |||||
set => stop_training = value; | |||||
} | |||||
public Model(ModelArgs args) | public Model(ModelArgs args) | ||||
: base(args) | : base(args) | ||||
{ | { | ||||
@@ -58,6 +58,12 @@ namespace Tensorflow.Keras | |||||
Name = name | Name = name | ||||
}); | }); | ||||
public Sequential Sequential(params ILayer[] layers) | |||||
=> new Sequential(new SequentialArgs | |||||
{ | |||||
Layers = layers.ToList() | |||||
}); | |||||
/// <summary> | /// <summary> | ||||
/// `Model` groups layers into an object with training and inference features. | /// `Model` groups layers into an object with training and inference features. | ||||
/// </summary> | /// </summary> | ||||
@@ -72,7 +72,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="HDF5-CSharp" Version="1.16.3" /> | <PackageReference Include="HDF5-CSharp" Version="1.16.3" /> | ||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | ||||
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||||
<PackageReference Include="SharpZipLib" Version="1.4.2" /> | <PackageReference Include="SharpZipLib" Version="1.4.2" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -0,0 +1,202 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Tensorflow.NumPy; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow; | |||||
using static Tensorflow.Binding; | |||||
using Buffer = Tensorflow.Buffer; | |||||
using TensorFlowNET.Keras.UnitTest; | |||||
namespace TensorFlowNET.UnitTest.Basics | |||||
{ | |||||
[TestClass] | |||||
public class ComplexTest : EagerModeTestBase | |||||
{ | |||||
// Tests for Complex128 | |||||
[TestMethod] | |||||
public void complex128_basic() | |||||
{ | |||||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; | |||||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; | |||||
Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_complex = tf.complex(t_real, t_imag); | |||||
Tensor t_real_result = tf.math.real(t_complex); | |||||
Tensor t_imag_result = tf.math.imag(t_complex); | |||||
NDArray n_real_result = t_real_result.numpy(); | |||||
NDArray n_imag_result = t_imag_result.numpy(); | |||||
double[] d_real_result =n_real_result.ToArray<double>(); | |||||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||||
} | |||||
[TestMethod] | |||||
public void complex128_abs() | |||||
{ | |||||
tf.enable_eager_execution(); | |||||
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; | |||||
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; | |||||
double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_complex = tf.complex(t_real, t_imag); | |||||
Tensor t_abs_result = tf.abs(t_complex); | |||||
double[] d_abs_result = t_abs_result.numpy().ToArray<double>(); | |||||
Assert.IsTrue(base.Equal(d_abs_result, d_abs)); | |||||
} | |||||
[TestMethod] | |||||
public void complex128_conj() | |||||
{ | |||||
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 }; | |||||
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 }; | |||||
double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 }; | |||||
double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); | |||||
Tensor t_result = tf.math.conj(t_complex); | |||||
NDArray n_real_result = tf.math.real(t_result).numpy(); | |||||
NDArray n_imag_result = tf.math.imag(t_result).numpy(); | |||||
double[] d_real_result = n_real_result.ToArray<double>(); | |||||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||||
Assert.IsTrue(base.Equal(d_real_result, d_real_expected)); | |||||
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected)); | |||||
} | |||||
[TestMethod] | |||||
public void complex128_angle() | |||||
{ | |||||
double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 }; | |||||
double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 }; | |||||
double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128); | |||||
Tensor t_result = tf.math.angle(t_complex); | |||||
NDArray n_result = t_result.numpy(); | |||||
double[] d_result = n_result.ToArray<double>(); | |||||
Assert.IsTrue(base.Equal(d_result, d_expected)); | |||||
} | |||||
// Tests for Complex64 | |||||
[TestMethod] | |||||
public void complex64_basic() | |||||
{ | |||||
tf.init_scope(); | |||||
float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | |||||
float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_complex = tf.complex(t_real, t_imag); | |||||
Tensor t_real_result = tf.math.real(t_complex); | |||||
Tensor t_imag_result = tf.math.imag(t_complex); | |||||
// Convert the EagerTensors to NumPy arrays directly | |||||
float[] d_real_result = t_real_result.numpy().ToArray<float>(); | |||||
float[] d_imag_result = t_imag_result.numpy().ToArray<float>(); | |||||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||||
} | |||||
[TestMethod] | |||||
public void complex64_abs() | |||||
{ | |||||
tf.enable_eager_execution(); | |||||
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; | |||||
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f }; | |||||
float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); | |||||
Tensor t_abs_result = tf.abs(t_complex); | |||||
NDArray n_abs_result = t_abs_result.numpy(); | |||||
float[] d_abs_result = n_abs_result.ToArray<float>(); | |||||
Assert.IsTrue(base.Equal(d_abs_result, d_abs)); | |||||
} | |||||
[TestMethod] | |||||
public void complex64_conj() | |||||
{ | |||||
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; | |||||
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f }; | |||||
float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f }; | |||||
float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); | |||||
Tensor t_result = tf.math.conj(t_complex); | |||||
NDArray n_real_result = tf.math.real(t_result).numpy(); | |||||
NDArray n_imag_result = tf.math.imag(t_result).numpy(); | |||||
float[] d_real_result = n_real_result.ToArray<float>(); | |||||
float[] d_imag_result = n_imag_result.ToArray<float>(); | |||||
Assert.IsTrue(base.Equal(d_real_result, d_real_expected)); | |||||
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected)); | |||||
} | |||||
[TestMethod] | |||||
public void complex64_angle() | |||||
{ | |||||
float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f }; | |||||
float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f }; | |||||
float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT); | |||||
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64); | |||||
Tensor t_result = tf.math.angle(t_complex); | |||||
NDArray n_result = t_result.numpy(); | |||||
float[] d_result = n_result.ToArray<float>(); | |||||
Assert.IsTrue(base.Equal(d_result, d_expected)); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,103 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Tensorflow.NumPy; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Tensorflow; | |||||
using static Tensorflow.Binding; | |||||
using Buffer = Tensorflow.Buffer; | |||||
using TensorFlowNET.Keras.UnitTest; | |||||
namespace TensorFlowNET.UnitTest.Basics | |||||
{ | |||||
[TestClass] | |||||
public class SignalTest : EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void fft() | |||||
{ | |||||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; | |||||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_complex = tf.complex(t_real, t_imag); | |||||
Tensor t_frequency_domain = tf.signal.fft(t_complex); | |||||
Tensor f_time_domain = tf.signal.ifft(t_frequency_domain); | |||||
Tensor t_real_result = tf.math.real(f_time_domain); | |||||
Tensor t_imag_result = tf.math.imag(f_time_domain); | |||||
NDArray n_real_result = t_real_result.numpy(); | |||||
NDArray n_imag_result = t_imag_result.numpy(); | |||||
double[] d_real_result = n_real_result.ToArray<double>(); | |||||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||||
} | |||||
[TestMethod] | |||||
public void fft2d() | |||||
{ | |||||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 }; | |||||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 }; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_complex = tf.complex(t_real, t_imag); | |||||
Tensor t_complex_2d = tf.reshape(t_complex,new int[] { 2, 2 }); | |||||
Tensor t_frequency_domain_2d = tf.signal.fft2d(t_complex_2d); | |||||
Tensor t_time_domain_2d = tf.signal.ifft2d(t_frequency_domain_2d); | |||||
Tensor t_time_domain = tf.reshape(t_time_domain_2d, new int[] { 4 }); | |||||
Tensor t_real_result = tf.math.real(t_time_domain); | |||||
Tensor t_imag_result = tf.math.imag(t_time_domain); | |||||
NDArray n_real_result = t_real_result.numpy(); | |||||
NDArray n_imag_result = t_imag_result.numpy(); | |||||
double[] d_real_result = n_real_result.ToArray<double>(); | |||||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||||
} | |||||
[TestMethod] | |||||
public void fft3d() | |||||
{ | |||||
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0, -3.0, -2.0, -1.0, -4.0 }; | |||||
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0, 6.0, 4.0, 2.0, 0.0}; | |||||
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE); | |||||
Tensor t_complex = tf.complex(t_real, t_imag); | |||||
Tensor t_complex_3d = tf.reshape(t_complex, new int[] { 2, 2, 2 }); | |||||
Tensor t_frequency_domain_3d = tf.signal.fft2d(t_complex_3d); | |||||
Tensor t_time_domain_3d = tf.signal.ifft2d(t_frequency_domain_3d); | |||||
Tensor t_time_domain = tf.reshape(t_time_domain_3d, new int[] { 8 }); | |||||
Tensor t_real_result = tf.math.real(t_time_domain); | |||||
Tensor t_imag_result = tf.math.imag(t_time_domain); | |||||
NDArray n_real_result = t_real_result.numpy(); | |||||
NDArray n_imag_result = t_imag_result.numpy(); | |||||
double[] d_real_result = n_real_result.ToArray<double>(); | |||||
double[] d_imag_result = n_imag_result.ToArray<double>(); | |||||
Assert.IsTrue(base.Equal(d_real_result, d_real)); | |||||
Assert.IsTrue(base.Equal(d_imag_result, d_imag)); | |||||
} | |||||
} | |||||
} |
@@ -36,6 +36,7 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | ||||
<ProjectReference Include="..\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
</Project> | </Project> |