Merge branch 'master' into support_function_load

Tag: v0.100.5-BERT-load
Author: Yaohui Liu, 2 years ago
Parent commit: 4c1878bb62

20 changed files with 668 additions and 113 deletions
  1. src/TensorFlowNET.Core/APIs/tf.math.cs (+15 -3)
  2. src/TensorFlowNET.Core/APIs/tf.signal.cs (+40 -0)
  3. src/TensorFlowNET.Core/Gradients/math_grad.cs (+1 -1)
  4. src/TensorFlowNET.Core/Gradients/nn_grad.cs (+28 -5)
  5. src/TensorFlowNET.Core/Keras/Engine/ICallback.cs (+1 -1)
  6. src/TensorFlowNET.Core/Keras/Engine/IModel.cs (+5 -2)
  7. src/TensorFlowNET.Core/Operations/gen_ops.cs (+26 -62)
  8. src/TensorFlowNET.Core/Operations/math_ops.cs (+4 -2)
  9. src/TensorFlowNET.Keras/Callbacks/CallbackList.cs (+1 -1)
  10. src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs (+2 -2)
  11. src/TensorFlowNET.Keras/Callbacks/History.cs (+2 -2)
  12. src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs (+2 -2)
  13. src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs (+71 -13)
  14. src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+152 -16)
  15. src/TensorFlowNET.Keras/Engine/Model.cs (+6 -0)
  16. src/TensorFlowNET.Keras/KerasInterface.cs (+6 -0)
  17. src/TensorFlowNET.Keras/Tensorflow.Keras.csproj (+0 -1)
  18. test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs (+202 -0)
  19. test/TensorFlowNET.Graph.UnitTest/SignalTest.cs (+103 -0)
  20. test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj (+1 -0)

src/TensorFlowNET.Core/APIs/tf.math.cs (+15 -3)

@@ -1,5 +1,5 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Copyright 2023 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ namespace Tensorflow

public Tensor tanh(Tensor x, string name = null)
=> math_ops.tanh(x, name: name);
/// <summary>
/// Finds values and indices of the `k` largest entries for the last dimension.
/// </summary>
@@ -93,6 +93,16 @@ namespace Tensorflow
bool binary_output = false)
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
dtype: dtype, name: name, axis: axis, binary_output: binary_output);

public Tensor real(Tensor x, string name = null)
=> gen_ops.real(x, x.dtype.real_dtype(), name);
public Tensor imag(Tensor x, string name = null)
=> gen_ops.imag(x, x.dtype.real_dtype(), name);

public Tensor conj(Tensor x, string name = null)
=> gen_ops.conj(x, name);
public Tensor angle(Tensor x, string name = null)
=> gen_ops.angle(x, x.dtype.real_dtype(), name);
}

public Tensor abs(Tensor x, string name = null)
@@ -537,7 +547,7 @@ namespace Tensorflow
public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null,
bool keepdims = false, string name = null)
{
if(keepdims)
if (keepdims)
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name);
else
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices));
@@ -585,5 +595,7 @@ namespace Tensorflow
=> gen_math_ops.square(x, name: name);
public Tensor squared_difference(Tensor x, Tensor y, string name = null)
=> gen_math_ops.squared_difference(x: x, y: y, name: name);
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
string name = null) => gen_ops.complex(real, imag, dtype, name);
}
}
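
Usage sketch for the new complex-number entry points (not part of the diff; it mirrors the ComplexTest cases added later in this commit and assumes eager execution plus using static Tensorflow.Binding):

// --- usage sketch, not part of the diff ---
var re = tf.constant(new double[] { -3.0, -5.0 }, dtype: TF_DataType.TF_DOUBLE);
var im = tf.constant(new double[] { -4.0, 12.0 }, dtype: TF_DataType.TF_DOUBLE);

var z = tf.complex(re, im);          // Tout inferred: TF_DOUBLE inputs -> TF_COMPLEX128

var zr = tf.math.real(z);            // [-3.0, -5.0]
var zi = tf.math.imag(z);            // [-4.0, 12.0]
var zc = tf.math.conj(z);            // conjugate: imaginary parts negated
var ph = tf.math.angle(z);           // element-wise phase in radians, real-valued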

src/TensorFlowNET.Core/APIs/tf.signal.cs (+40 -0)

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2023 Konstantin Balashov All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using Tensorflow.Operations;

namespace Tensorflow
{
public partial class tensorflow
{
public SignalApi signal { get; } = new SignalApi();
public class SignalApi
{
public Tensor fft(Tensor input, string name = null)
=> gen_ops.f_f_t(input, name: name);
public Tensor ifft(Tensor input, string name = null)
=> gen_ops.i_f_f_t(input, name: name);
public Tensor fft2d(Tensor input, string name = null)
=> gen_ops.f_f_t2d(input, name: name);
public Tensor ifft2d(Tensor input, string name = null)
=> gen_ops.i_f_f_t2d(input, name: name);
public Tensor fft3d(Tensor input, string name = null)
=> gen_ops.f_f_t3d(input, name: name);
public Tensor ifft3d(Tensor input, string name = null)
=> gen_ops.i_f_f_t3d(input, name: name);
}
}
}
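
Round-trip sketch for the new signal API (not part of the diff; it follows the SignalTest added in this commit):

// --- usage sketch, not part of the diff ---
var re = tf.constant(new double[] { 1.0, 2.0, 3.0, 4.0 }, dtype: TF_DataType.TF_DOUBLE);
var im = tf.constant(new double[] { -1.0, -3.0, 5.0, 7.0 }, dtype: TF_DataType.TF_DOUBLE);
var signal = tf.complex(re, im);

var spectrum = tf.signal.fft(signal);     // forward 1-D FFT
var restored = tf.signal.ifft(spectrum);  // inverse FFT recovers the input

var restoredRe = tf.math.real(restored);  // matches re up to floating-point error
var restoredIm = tf.math.imag(restored);  // matches im up to floating-point error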

src/TensorFlowNET.Core/Gradients/math_grad.cs (+1 -1)

@@ -840,7 +840,7 @@ namespace Tensorflow.Gradients
/// <param name="x"></param>
/// <param name="y"></param>
/// <returns></returns>
private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
{
Tensor sx, sy;
if (x.shape.IsFullyDefined &&


src/TensorFlowNET.Core/Gradients/nn_grad.cs (+28 -5)

@@ -15,6 +15,7 @@
******************************************************************************/

using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;
@@ -135,13 +136,35 @@ namespace Tensorflow.Gradients
{
Tensor x = op.inputs[0];
Tensor y = op.inputs[1];
var grad = grads[0];
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
return new Tensor[]
var x_grad = math_ops.scalar_mul(scale, grad) * (x - y);
if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad))
{
x_grad,
-x_grad
};
return new Tensor[] { x_grad, -x_grad };
}
var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad);
Debug.Assert(broadcast_info.Length == 2);
var (sx, rx, must_reduce_x) = broadcast_info[0];
var (sy, ry, must_reduce_y) = broadcast_info[1];
Tensor gx, gy;
if (must_reduce_x)
{
gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx);
}
else
{
gx = x_grad;
}
if (must_reduce_y)
{
gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy);
}
else
{
gy = -x_grad;
}
return new Tensor[] { gx, gy };
}

/// <summary>


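The hunk above makes the squared-difference gradient broadcasting-aware: when x and y have different shapes, the raw gradient has the broadcast shape and is summed over the broadcast axes and reshaped back, via the SmartBroadcastGradientArgs helper made public in math_grad.cs. A minimal sketch of the effect, assuming eager mode and the usual tf.GradientTape / tape.watch behavior, with tf.squared_difference as declared in tf.math.cs above; shapes and values are illustrative:

// --- illustration only, not part of the diff ---
var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f } });   // shape (2, 2)
var y = tf.constant(new float[] { 1f, 1f });                    // shape (2,), broadcast across rows

using var tape = tf.GradientTape();
tape.watch(y);
var loss = tf.reduce_sum(tf.squared_difference(x, y));
var gy = tape.gradient(loss, y);                                // shape (2,): reduced over the broadcast axis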
src/TensorFlowNET.Core/Keras/Engine/ICallback.cs (+1 -1)

@@ -15,5 +15,5 @@ public interface ICallback
void on_predict_end();
void on_test_begin();
void on_test_batch_begin(long step);
void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs);
void on_test_batch_end(long end_step, Dictionary<string, float> logs);
}

src/TensorFlowNET.Core/Keras/Engine/IModel.cs (+5 -2)

@@ -22,6 +22,7 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -34,6 +35,7 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -65,7 +67,8 @@ public interface IModel : ILayer
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false,
bool return_dict = false);
bool return_dict = false,
bool is_val = false);

Tensors predict(Tensors x,
int batch_size = -1,
@@ -79,5 +82,5 @@ public interface IModel : ILayer

IKerasConfig get_config();

void set_stopTraining_true();
bool Stop_training { get;set; }
}
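
The interface now exposes the stop flag as a read/write Stop_training property instead of the one-way set_stopTraining_true() method, so a callback can both query and set it. Hypothetical callback fragment (the _parameters field follows EarlyStopping below; the stopping rule and signature details are illustrative):

// --- illustration only, hypothetical callback fragment ---
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
{
    if (epoch_logs.TryGetValue("loss", out var loss) && float.IsNaN(loss))
        _parameters.Model.Stop_training = true;   // previously: _parameters.Model.set_stopTraining_true()
}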

src/TensorFlowNET.Core/Operations/gen_ops.cs (+26 -62)

@@ -730,12 +730,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Angle", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout }));
}

/// <summary>
@@ -4976,15 +4971,14 @@ namespace Tensorflow.Operations
/// tf.complex(real, imag) ==&amp;gt; [[2.25 + 4.75j], [3.25 + 5.75j]]
/// </code>
/// </remarks>
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex")
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex")
{
var dict = new Dictionary<string, object>();
dict["real"] = real;
dict["imag"] = imag;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Complex", name: name, keywords: dict);
return op.output;
TF_DataType Tin = real.GetDataType();
if (a_Tout is null)
{
a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64;
}
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout }));
}

/// <summary>
@@ -5008,12 +5002,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs")
{
var dict = new Dictionary<string, object>();
dict["x"] = x;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("ComplexAbs", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout }));
}

/// <summary>
@@ -5313,10 +5302,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor conj(Tensor input, string name = "Conj")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("Conj", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input }));
}

/// <summary>
@@ -10489,10 +10475,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor f_f_t(Tensor input, string name = "FFT")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("FFT", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("FFT", name, new ExecuteOpArgs(input));
}

/// <summary>
@@ -10519,10 +10502,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor f_f_t2d(Tensor input, string name = "FFT2D")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("FFT2D", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("FFT2D", name, new ExecuteOpArgs(input));
}

/// <summary>
@@ -10549,10 +10529,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor f_f_t3d(Tensor input, string name = "FFT3D")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("FFT3D", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("FFT3D", name, new ExecuteOpArgs(input));
}

/// <summary>
@@ -12875,10 +12852,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor i_f_f_t(Tensor input, string name = "IFFT")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("IFFT", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("IFFT", name, new ExecuteOpArgs(input));
}

/// <summary>
@@ -12905,10 +12879,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor i_f_f_t2d(Tensor input, string name = "IFFT2D")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("IFFT2D", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("IFFT2D", name, new ExecuteOpArgs(input));
}

/// <summary>
@@ -12935,10 +12906,7 @@ namespace Tensorflow.Operations
/// </remarks>
public static Tensor i_f_f_t3d(Tensor input, string name = "IFFT3D")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
var op = tf.OpDefLib._apply_op_helper("IFFT3D", name: name, keywords: dict);
return op.output;
return tf.Context.ExecuteOp("IFFT3D", name, new ExecuteOpArgs(input));
}

/// <summary>
@@ -13325,14 +13293,12 @@ namespace Tensorflow.Operations
/// tf.imag(input) ==&amp;gt; [4.75, 5.75]
/// </code>
/// </remarks>
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag")
public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Imag", name: name, keywords: dict);
return op.output;
TF_DataType Tin = input.GetDataType();
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));

// return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
}

/// <summary>
@@ -23863,14 +23829,12 @@ namespace Tensorflow.Operations
/// tf.real(input) ==&amp;gt; [-2.25, 3.25]
/// </code>
/// </remarks>
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real")
public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real")
{
var dict = new Dictionary<string, object>();
dict["input"] = input;
if (Tout.HasValue)
dict["Tout"] = Tout.Value;
var op = tf.OpDefLib._apply_op_helper("Real", name: name, keywords: dict);
return op.output;
TF_DataType Tin = input.GetDataType();
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));

// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
}

/// <summary>

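Most hunks in this file apply the same mechanical change: the hand-built attribute dictionary and OpDefLib._apply_op_helper call are replaced by a single tf.Context.ExecuteOp call, so the generated wrappers also work under eager execution. A small eager sketch using two of the migrated wrappers (requires using Tensorflow.Operations; values illustrative):

// --- usage sketch, not part of the diff ---
var z = tf.complex(tf.constant(new double[] { 0.0 }), tf.constant(new double[] { 1.0 }));
var conjugate = gen_ops.conj(z);                      // migrated wrapper
var phase = gen_ops.angle(z, z.dtype.real_dtype());   // ~[1.5707963], i.e. pi/2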

src/TensorFlowNET.Core/Operations/math_ops.cs (+4 -2)

@@ -20,6 +20,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -35,8 +36,9 @@ namespace Tensorflow
name = scope;
x = ops.convert_to_tensor(x, name: "x");
if (x.dtype.is_complex())
throw new NotImplementedException("math_ops.abs for dtype.is_complex");
//return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name);
{
return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name);
}
return gen_math_ops._abs(x, name: name);
});
}
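
With complex_abs wired up, math_ops.abs no longer throws for complex inputs: tf.abs on a complex tensor now returns the element-wise magnitude as a real tensor. Sketch mirroring complex128_abs in the new ComplexTest:

// --- usage sketch, not part of the diff ---
var re = tf.constant(new double[] { -3.0, -5.0 }, dtype: TF_DataType.TF_DOUBLE);
var im = tf.constant(new double[] { -4.0, 12.0 }, dtype: TF_DataType.TF_DOUBLE);
var z = tf.complex(re, im);
var magnitude = tf.abs(z);   // [5.0, 13.0], dtype TF_DOUBLE (the real_dtype of complex128)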


src/TensorFlowNET.Keras/Callbacks/CallbackList.cs (+1 -1)

@@ -69,7 +69,7 @@ public class CallbackList
{
callbacks.ForEach(x => x.on_test_batch_begin(step));
}
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
{
callbacks.ForEach(x => x.on_test_batch_end(end_step, logs));
}


src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs (+2 -2)

@@ -95,7 +95,7 @@ public class EarlyStopping: ICallback
if (_wait >= _paitence && epoch > 0)
{
_stopped_epoch = epoch;
_parameters.Model.set_stopTraining_true();
_parameters.Model.Stop_training = true;
if (_restore_best_weights && _best_weights != null)
{
if (_verbose > 0)
@@ -121,7 +121,7 @@ public class EarlyStopping: ICallback
public void on_predict_end() { }
public void on_test_begin() { }
public void on_test_batch_begin(long step) { }
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }
public void on_test_batch_end(long end_step, Dictionary<string, float> logs) { }

float get_monitor_value(Dictionary<string, float> logs)
{


src/TensorFlowNET.Keras/Callbacks/History.cs (+2 -2)

@@ -48,7 +48,7 @@ public class History : ICallback
{
history[log.Key] = new List<float>();
}
history[log.Key].Add((float)log.Value);
history[log.Key].Add(log.Value);
}
}

@@ -78,7 +78,7 @@ public class History : ICallback

}

public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
{
}
}

src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs (+2 -2)

@@ -105,11 +105,11 @@ namespace Tensorflow.Keras.Callbacks
{
_sw.Restart();
}
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
public void on_test_batch_end(long end_step, Dictionary<string, float> logs)
{
_sw.Stop();
var elapse = _sw.ElapsedMilliseconds;
var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}"));
var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {x.Value:F6}"));

Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}");
if (!Console.IsOutputRedirected)


src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs (+71 -13)

@@ -26,6 +26,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="workers"></param>
/// <param name="use_multiprocessing"></param>
/// <param name="return_dict"></param>
/// <param name="is_val"></param>
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
@@ -33,7 +34,9 @@ namespace Tensorflow.Keras.Engine
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false,
bool return_dict = false)
bool return_dict = false,
bool is_val = false
)
{
if (x.dims[0] != y.dims[0])
{
@@ -63,11 +66,11 @@ namespace Tensorflow.Keras.Engine
});
callbacks.on_test_begin();

IEnumerable<(string, Tensor)> logs = null;
//Dictionary<string, float>? logs = null;
var logs = new Dictionary<string, float>();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
@@ -75,19 +78,64 @@ namespace Tensorflow.Keras.Engine
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
callbacks.on_test_batch_end(end_step, logs);
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Item1] = (float)log.Item2;
results[log.Key] = log.Value;
}
return results;
}

public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1)
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(x),
Y = y,
Model = this,
StepsPerExecution = _steps_per_execution
});

var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
callbacks.on_test_begin();

Dictionary<string, float> logs = null;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();

foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
logs = test_step_multi_inputs_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Key] = log.Value;
}
return results;
}


public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -104,7 +152,7 @@ namespace Tensorflow.Keras.Engine
});
callbacks.on_test_begin();

IEnumerable<(string, Tensor)> logs = null;
Dictionary<string, float> logs = null;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
@@ -113,28 +161,38 @@ namespace Tensorflow.Keras.Engine

foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
callbacks.on_test_batch_begin(step);
logs = test_function(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
if (is_val == false)
callbacks.on_test_batch_end(end_step, logs);
}
}

var results = new Dictionary<string, float>();
foreach (var log in logs)
{
results[log.Item1] = (float)log.Item2;
results[log.Key] = log.Value;
}
return results;
}

IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator)
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}

List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y)
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
@@ -142,7 +200,7 @@ namespace Tensorflow.Keras.Engine

compiled_metrics.update_state(y, y_pred);

return metrics.Select(x => (x.Name, x.result())).ToList();
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
}
}
}
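
evaluate now returns a Dictionary<string, float> keyed by metric name (previously a list of (string, Tensor) pairs), and the new is_val flag lets fit run validation without firing the per-batch test callbacks. Caller-side sketch; the model and test arrays are assumed to exist:

// --- usage sketch, not part of the diff ---
Dictionary<string, float> results = model.evaluate(x_test, y_test, batch_size: 32);
foreach (var kv in results)
    Console.WriteLine($"{kv.Key}: {kv.Value:F4}");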

src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+152 -16)

@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="callbacks"></param>
/// <param name="verbose"></param>
/// <param name="validation_split"></param>
/// <param name="validation_data"></param>
/// <param name="shuffle"></param>
public ICallback fit(NDArray x, NDArray y,
int batch_size = -1,
@@ -29,6 +30,7 @@ namespace Tensorflow.Keras.Engine
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -40,11 +42,17 @@ namespace Tensorflow.Keras.Engine
throw new InvalidArgumentError(
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
var train_x = x[new Slice(0, train_count)];
var train_y = y[new Slice(0, train_count)];
var val_x = x[new Slice(train_count)];
var val_y = y[new Slice(train_count)];

var train_x = x;
var train_y = y;

if (validation_split != 0f && validation_data == null)
{
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
train_x = x[new Slice(0, train_count)];
train_y = y[new Slice(0, train_count)];
validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
}

var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -61,7 +69,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}

@@ -71,6 +79,7 @@ namespace Tensorflow.Keras.Engine
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
@@ -85,12 +94,19 @@ namespace Tensorflow.Keras.Engine
$"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
}
}
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
var train_x = x.Select(x => x[new Slice(0, train_count)] as Tensor);
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as Tensor);
var val_y = y[new Slice(train_count)];

var train_x = x;
var train_y = y;
if (validation_split != 0f && validation_data == null)
{
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
validation_data = (val_x, val_y);
}


var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -110,29 +126,29 @@ namespace Tensorflow.Keras.Engine
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
train_step_func: train_step_multi_inputs_function);
}
else
{
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}
}

public History fit(IDatasetV2 dataset,
IDatasetV2 validation_data = null,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
IDatasetV2 validation_data = null,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
Dataset = dataset,
@@ -147,6 +163,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});


return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}
@@ -178,11 +195,13 @@ namespace Tensorflow.Keras.Engine
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
var logs = new Dictionary<string, float>();
long End_step = 0;
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
}

@@ -193,6 +212,123 @@ namespace Tensorflow.Keras.Engine
{
logs["val_" + log.Key] = log.Value;
}
callbacks.on_train_batch_end(End_step, logs);
}


callbacks.on_epoch_end(epoch, logs);

GC.Collect();
GC.WaitForPendingFinalizers();
}

return callbacks.History;
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
_train_counter.assign(0);
var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Epochs = epochs,
Steps = data_handler.Inferredsteps
});

if (callbackList != null)
{
foreach (var callback in callbackList)
callbacks.callbacks.add(callback);
}

callbacks.on_train_begin();

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
var logs = new Dictionary<string, float>();
long End_step = 0;
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
}

if (validation_data != null)
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
}
// because after evaluate, logs add some new log which we need to print
callbacks.on_train_batch_end(End_step, logs);
}

callbacks.on_epoch_end(epoch, logs);

GC.Collect();
GC.WaitForPendingFinalizers();
}

return callbacks.History;
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
_train_counter.assign(0);
var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Epochs = epochs,
Steps = data_handler.Inferredsteps
});

if (callbackList != null)
{
foreach (var callback in callbackList)
callbacks.callbacks.add(callback);
}

callbacks.on_train_begin();

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
var logs = new Dictionary<string, float>();
long End_step = 0;
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
}

if (validation_data != null)
{
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
callbacks.on_train_batch_end(End_step, logs);
}
}

callbacks.on_epoch_end(epoch, logs);


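The fit overloads gain an optional validation_data tuple; when it is supplied (or derived from validation_split), FitInternal evaluates it after each epoch and merges the results into the logs under val_-prefixed keys. Caller-side sketch; the training and validation NDArrays are assumed to exist:

// --- usage sketch, not part of the diff ---
var history = model.fit(x_train, y_train,
                        batch_size: 64,
                        epochs: 10,
                        validation_data: (x_val, y_val));   // new optional (val_x, val_y) tuple
// per-epoch logs now also carry "val_"-prefixed entries for each compiled metric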
src/TensorFlowNET.Keras/Engine/Model.cs (+6 -0)

@@ -46,6 +46,12 @@ namespace Tensorflow.Keras.Engine
set => optimizer = value;
}

public bool Stop_training
{
get => stop_training;
set => stop_training = value;
}

public Model(ModelArgs args)
: base(args)
{


src/TensorFlowNET.Keras/KerasInterface.cs (+6 -0)

@@ -58,6 +58,12 @@ namespace Tensorflow.Keras
Name = name
});

public Sequential Sequential(params ILayer[] layers)
=> new Sequential(new SequentialArgs
{
Layers = layers.ToList()
});

/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
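
The new params overload lets a Sequential model be declared inline instead of building a List<ILayer> first. Sketch assuming the usual keras accessor (using static Tensorflow.KerasApi); the layer choices are illustrative:

// --- usage sketch, not part of the diff ---
var model = keras.Sequential(
    keras.layers.Dense(64),
    keras.layers.Dense(10));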


src/TensorFlowNET.Keras/Tensorflow.Keras.csproj (+0 -1)

@@ -72,7 +72,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<ItemGroup>
<PackageReference Include="HDF5-CSharp" Version="1.16.3" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="SharpZipLib" Version="1.4.2" />
</ItemGroup>



test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs (+202 -0)

@@ -0,0 +1,202 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using Buffer = Tensorflow.Buffer;
using TensorFlowNET.Keras.UnitTest;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class ComplexTest : EagerModeTestBase
{
// Tests for Complex128

[TestMethod]
public void complex128_basic()
{
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };

Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_real_result = tf.math.real(t_complex);
Tensor t_imag_result = tf.math.imag(t_complex);

NDArray n_real_result = t_real_result.numpy();
NDArray n_imag_result = t_imag_result.numpy();

double[] d_real_result =n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
[TestMethod]
public void complex128_abs()
{
tf.enable_eager_execution();

double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };

double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_abs_result = tf.abs(t_complex);

double[] d_abs_result = t_abs_result.numpy().ToArray<double>();
Assert.IsTrue(base.Equal(d_abs_result, d_abs));
}
[TestMethod]
public void complex128_conj()
{
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };

double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 };
double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);

Tensor t_result = tf.math.conj(t_complex);

NDArray n_real_result = tf.math.real(t_result).numpy();
NDArray n_imag_result = tf.math.imag(t_result).numpy();

double[] d_real_result = n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
}
[TestMethod]
public void complex128_angle()
{
double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 };
double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 };

double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);

Tensor t_result = tf.math.angle(t_complex);

NDArray n_result = t_result.numpy();

double[] d_result = n_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_result, d_expected));
}

// Tests for Complex64
[TestMethod]
public void complex64_basic()
{
tf.init_scope();
float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_real_result = tf.math.real(t_complex);
Tensor t_imag_result = tf.math.imag(t_complex);

// Convert the EagerTensors to NumPy arrays directly
float[] d_real_result = t_real_result.numpy().ToArray<float>();
float[] d_imag_result = t_imag_result.numpy().ToArray<float>();

Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
[TestMethod]
public void complex64_abs()
{
tf.enable_eager_execution();
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };

float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);

Tensor t_abs_result = tf.abs(t_complex);

NDArray n_abs_result = t_abs_result.numpy();

float[] d_abs_result = n_abs_result.ToArray<float>();
Assert.IsTrue(base.Equal(d_abs_result, d_abs));

}
[TestMethod]
public void complex64_conj()
{
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };

float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);

Tensor t_result = tf.math.conj(t_complex);

NDArray n_real_result = tf.math.real(t_result).numpy();
NDArray n_imag_result = tf.math.imag(t_result).numpy();

float[] d_real_result = n_real_result.ToArray<float>();
float[] d_imag_result = n_imag_result.ToArray<float>();

Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));

}
[TestMethod]
public void complex64_angle()
{
float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f };
float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f };

float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f };
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);

Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);

Tensor t_result = tf.math.angle(t_complex);

NDArray n_result = t_result.numpy();

float[] d_result = n_result.ToArray<float>();

Assert.IsTrue(base.Equal(d_result, d_expected));
}
}
}

test/TensorFlowNET.Graph.UnitTest/SignalTest.cs (+103 -0)

@@ -0,0 +1,103 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using Buffer = Tensorflow.Buffer;
using TensorFlowNET.Keras.UnitTest;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class SignalTest : EagerModeTestBase
{
[TestMethod]
public void fft()
{
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_frequency_domain = tf.signal.fft(t_complex);
Tensor f_time_domain = tf.signal.ifft(t_frequency_domain);

Tensor t_real_result = tf.math.real(f_time_domain);
Tensor t_imag_result = tf.math.imag(f_time_domain);

NDArray n_real_result = t_real_result.numpy();
NDArray n_imag_result = t_imag_result.numpy();

double[] d_real_result = n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
[TestMethod]
public void fft2d()
{
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_complex_2d = tf.reshape(t_complex,new int[] { 2, 2 });

Tensor t_frequency_domain_2d = tf.signal.fft2d(t_complex_2d);
Tensor t_time_domain_2d = tf.signal.ifft2d(t_frequency_domain_2d);

Tensor t_time_domain = tf.reshape(t_time_domain_2d, new int[] { 4 });

Tensor t_real_result = tf.math.real(t_time_domain);
Tensor t_imag_result = tf.math.imag(t_time_domain);

NDArray n_real_result = t_real_result.numpy();
NDArray n_imag_result = t_imag_result.numpy();

double[] d_real_result = n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
[TestMethod]
public void fft3d()
{
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0, -3.0, -2.0, -1.0, -4.0 };
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0, 6.0, 4.0, 2.0, 0.0};

Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);

Tensor t_complex = tf.complex(t_real, t_imag);

Tensor t_complex_3d = tf.reshape(t_complex, new int[] { 2, 2, 2 });

Tensor t_frequency_domain_3d = tf.signal.fft2d(t_complex_3d);
Tensor t_time_domain_3d = tf.signal.ifft2d(t_frequency_domain_3d);

Tensor t_time_domain = tf.reshape(t_time_domain_3d, new int[] { 8 });

Tensor t_real_result = tf.math.real(t_time_domain);
Tensor t_imag_result = tf.math.imag(t_time_domain);

NDArray n_real_result = t_real_result.numpy();
NDArray n_imag_result = t_imag_result.numpy();

double[] d_real_result = n_real_result.ToArray<double>();
double[] d_imag_result = n_imag_result.ToArray<double>();

Assert.IsTrue(base.Equal(d_real_result, d_real));
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
}
}
}

test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj (+1 -0)

@@ -36,6 +36,7 @@

<ItemGroup>
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
<ProjectReference Include="..\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj" />
</ItemGroup>

</Project>
