diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index 75253700..0e53d938 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/
+using Tensorflow.NumPy;
using Tensorflow.Operations;
namespace Tensorflow
@@ -42,7 +43,6 @@ namespace Tensorflow
public Tensor multiply(Tensor x, Tensor y, string name = null)
=> math_ops.multiply(x, y, name: name);
-
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
=> math_ops.div_no_nan(a, b);
@@ -452,7 +452,18 @@ namespace Tensorflow
///
public Tensor multiply(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
-
+ ///
+ /// return scalar product
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor dot_prod(Tx x, Ty y, NDArray axes, string name = null)
+ => math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
public Tensor negative(Tensor x, string name = null)
=> gen_math_ops.neg(x, name);
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index c5705930..99ed5c1f 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -486,7 +486,28 @@ namespace Tensorflow
throw new NotImplementedException("");
}
}
-
+ public static NDArray GetFlattenArray(NDArray x)
+ {
+ switch (x.GetDataType())
+ {
+ case TF_DataType.TF_FLOAT:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_DOUBLE:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT16:
+ case TF_DataType.TF_INT32:
+ x = x.ToArray();
+ break;
+ case TF_DataType.TF_INT64:
+ x = x.ToArray();
+ break;
+ default:
+ break;
+ }
+ return x;
+ }
public static TF_DataType GetDataType(this object data)
{
var type = data.GetType();
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index 19f3df9b..ddc72aee 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -60,7 +60,7 @@ public interface IModel : ILayer
bool skip_mismatch = false,
object options = null);
- Dictionary evaluate(NDArray x, NDArray y,
+ Dictionary evaluate(Tensor x, Tensor y,
int batch_size = -1,
int verbose = 1,
int steps = -1,
diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
index ea85048f..5bc97952 100644
--- a/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
+++ b/src/TensorFlowNET.Core/NumPy/Numpy.Math.cs
@@ -49,9 +49,30 @@ namespace Tensorflow.NumPy
[AutoNumPy]
public static NDArray prod(params T[] array) where T : unmanaged
=> new NDArray(tf.reduce_prod(new NDArray(array)));
+ [AutoNumPy]
+ public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null)
+ {
+ //if axes mentioned
+ if (axes != null)
+ {
+ return new NDArray(tf.dot_prod(x1, x2, axes, name));
+ }
+ if (x1.shape.ndim > 1)
+ {
+ x1 = GetFlattenArray(x1);
+ }
+ if (x2.shape.ndim > 1)
+ {
+ x2 = GetFlattenArray(x2);
+ }
+ //if axes not mentioned, default 0,0
+ return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name));
+ }
[AutoNumPy]
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y));
+ [AutoNumPy]
+ public static NDArray square(NDArray x) => new NDArray(tf.square(x));
[AutoNumPy]
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x));
diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs
index 38a3e5dc..2838b000 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensors.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs
@@ -226,62 +226,62 @@ namespace Tensorflow
}
#region Explicit Conversions
- public unsafe static explicit operator bool(Tensors tensor)
+ public static explicit operator bool(Tensors tensor)
{
return (bool)tensor.Single;
}
- public unsafe static explicit operator sbyte(Tensors tensor)
+ public static explicit operator sbyte(Tensors tensor)
{
return (sbyte)tensor.Single;
}
- public unsafe static explicit operator byte(Tensors tensor)
+ public static explicit operator byte(Tensors tensor)
{
return (byte)tensor.Single;
}
- public unsafe static explicit operator ushort(Tensors tensor)
+ public static explicit operator ushort(Tensors tensor)
{
return (ushort)tensor.Single;
}
- public unsafe static explicit operator short(Tensors tensor)
+ public static explicit operator short(Tensors tensor)
{
return (short)tensor.Single;
}
- public unsafe static explicit operator int(Tensors tensor)
+ public static explicit operator int(Tensors tensor)
{
return (int)tensor.Single;
}
- public unsafe static explicit operator uint(Tensors tensor)
+ public static explicit operator uint(Tensors tensor)
{
return (uint)tensor.Single;
}
- public unsafe static explicit operator long(Tensors tensor)
+ public static explicit operator long(Tensors tensor)
{
return (long)tensor.Single;
}
- public unsafe static explicit operator ulong(Tensors tensor)
+ public static explicit operator ulong(Tensors tensor)
{
return (ulong)tensor.Single;
}
- public unsafe static explicit operator float(Tensors tensor)
+ public static explicit operator float(Tensors tensor)
{
return (byte)tensor.Single;
}
- public unsafe static explicit operator double(Tensors tensor)
+ public static explicit operator double(Tensors tensor)
{
return (double)tensor.Single;
}
- public unsafe static explicit operator string(Tensors tensor)
+ public static explicit operator string(Tensors tensor)
{
return (string)tensor.Single;
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
index d807b204..eaa9eb23 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
@@ -1,14 +1,14 @@
-using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
+using Tensorflow;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine.DataAdapters;
-using static Tensorflow.Binding;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
-using Tensorflow;
-using Tensorflow.Keras.Callbacks;
+using Tensorflow.NumPy;
+using static Tensorflow.Binding;
namespace Tensorflow.Keras.Engine
{
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Engine
///
///
///
- public Dictionary evaluate(NDArray x, NDArray y,
+ public Dictionary evaluate(Tensor x, Tensor y,
int batch_size = -1,
int verbose = 1,
int steps = -1,
@@ -64,34 +64,11 @@ namespace Tensorflow.Keras.Engine
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
- callbacks.on_test_begin();
-
- //Dictionary? logs = null;
- var logs = new Dictionary();
- foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
- {
- reset_metrics();
- // data_handler.catch_stop_iteration();
-
- foreach (var step in data_handler.steps())
- {
- callbacks.on_test_batch_begin(step);
- logs = test_function(data_handler, iterator);
- var end_step = step + data_handler.StepIncrement;
- if (is_val == false)
- callbacks.on_test_batch_end(end_step, logs);
- }
- }
- var results = new Dictionary();
- foreach (var log in logs)
- {
- results[log.Key] = log.Value;
- }
- return results;
+ return evaluate(data_handler, callbacks, is_val, test_function);
}
- public Dictionary evaluate(IEnumerable x, NDArray y, int verbose = 1, bool is_val = false)
+ public Dictionary evaluate(IEnumerable x, Tensor y, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
{
@@ -107,34 +84,10 @@ namespace Tensorflow.Keras.Engine
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
- callbacks.on_test_begin();
- Dictionary logs = null;
- foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
- {
- reset_metrics();
- callbacks.on_epoch_begin(epoch);
- // data_handler.catch_stop_iteration();
-
- foreach (var step in data_handler.steps())
- {
- callbacks.on_test_batch_begin(step);
- logs = test_step_multi_inputs_function(data_handler, iterator);
- var end_step = step + data_handler.StepIncrement;
- if (is_val == false)
- callbacks.on_test_batch_end(end_step, logs);
- }
- }
-
- var results = new Dictionary();
- foreach (var log in logs)
- {
- results[log.Key] = log.Value;
- }
- return results;
+ return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
}
-
public Dictionary evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
{
var data_handler = new DataHandler(new DataHandlerArgs
@@ -150,9 +103,24 @@ namespace Tensorflow.Keras.Engine
Verbose = verbose,
Steps = data_handler.Inferredsteps
});
+
+ return evaluate(data_handler, callbacks, is_val, test_function);
+ }
+
+ ///
+ /// Internal bare implementation of evaluate function.
+ ///
+ /// Interations handling objects
+ ///
+ /// The function to be called on each batch of data.
+ /// Whether it is validation or test.
+ ///
+ Dictionary evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func> test_func)
+ {
callbacks.on_test_begin();
- Dictionary logs = null;
+ var results = new Dictionary();
+ var logs = results;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
@@ -162,45 +130,47 @@ namespace Tensorflow.Keras.Engine
foreach (var step in data_handler.steps())
{
callbacks.on_test_batch_begin(step);
- logs = test_function(data_handler, iterator);
+
+ logs = test_func(data_handler, iterator.next());
+
+ tf_with(ops.control_dependencies(Array.Empty