@@ -39,6 +39,9 @@ namespace Tensorflow | |||||
public Tensor matmul(Tensor a, Tensor b) | public Tensor matmul(Tensor a, Tensor b) | ||||
=> math_ops.matmul(a, b); | => math_ops.matmul(a, b); | ||||
public Tensor norm(Tensor a, string ord = "euclidean", Axis axis = null, string name = null) | |||||
=> ops.norm(a, ord: ord, axis: axis, name: name); | |||||
public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) | public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) | ||||
=> math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); | => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); | ||||
@@ -166,6 +166,8 @@ namespace Tensorflow | |||||
return arr.Count; | return arr.Count; | ||||
case IEnumerable enumerable: | case IEnumerable enumerable: | ||||
return enumerable.OfType<object>().Count(); | return enumerable.OfType<object>().Count(); | ||||
case Axis axis: | |||||
return axis.size; | |||||
case Shape arr: | case Shape arr: | ||||
return arr.ndim; | return arr.ndim; | ||||
} | } | ||||
@@ -1,14 +1,26 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.NumPy | namespace Tensorflow.NumPy | ||||
{ | { | ||||
public class LinearAlgebraImpl | public class LinearAlgebraImpl | ||||
{ | { | ||||
[AutoNumPy] | |||||
public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn") | public NDArray lstsq(NDArray a, NDArray b, string rcond = "warn") | ||||
=> new NDArray(tf.linalg.lstsq(a, b)); | |||||
[AutoNumPy] | |||||
public NDArray norm(NDArray a, Axis axis = null) | |||||
{ | { | ||||
return a; | |||||
if (a.dtype.is_integer()) | |||||
{ | |||||
var float_a = math_ops.cast(a, dtype: tf.float32); | |||||
return new NDArray(tf.linalg.norm(float_a, axis: axis)); | |||||
} | |||||
return new NDArray(tf.linalg.norm(a, axis: axis)); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -52,6 +52,21 @@ namespace Tensorflow | |||||
return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer); | return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer); | ||||
} | } | ||||
public Tensor norm(Tensor tensor, string ord = "euclidean", Axis axis = null, string name = null, bool keepdims = true) | |||||
{ | |||||
var is_matrix_norm = axis != null && len(axis) == 2; | |||||
return tf_with(ops.name_scope(name, default_name: "norm", tensor), scope => | |||||
{ | |||||
if (is_matrix_norm) | |||||
throw new NotImplementedException(""); | |||||
var result = math_ops.sqrt(math_ops.reduce_sum(tensor * math_ops.conj(tensor), axis, keepdims: true)); | |||||
if(!keepdims) | |||||
result = array_ops.squeeze(result, axis); | |||||
return result; | |||||
}); | |||||
} | |||||
Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) | Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) | ||||
{ | { | ||||
Shape matrix_shape = matrix.shape.dims.Skip(matrix.shape.ndim - 2).ToArray(); | Shape matrix_shape = matrix.shape.dims.Skip(matrix.shape.ndim - 2).ToArray(); | ||||
@@ -109,6 +109,16 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor normalize(Tensor tensor, string ord = "euclidean", Axis axis = null, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "normalize", tensor), scope => | |||||
{ | |||||
var norm = tf.linalg.norm(tensor, ord: ord, axis: axis, name: name); | |||||
var normalized = tensor / norm; | |||||
return normalized; | |||||
}); | |||||
} | |||||
public static Tensor batch_normalization(Tensor x, | public static Tensor batch_normalization(Tensor x, | ||||
Tensor mean, | Tensor mean, | ||||
Tensor variance, | Tensor variance, | ||||
@@ -19,5 +19,13 @@ namespace TensorFlowNET.UnitTest.NumPy | |||||
{ | { | ||||
} | } | ||||
[TestMethod] | |||||
public void norm() | |||||
{ | |||||
var x = np.arange(9) - 4; | |||||
var y = x.reshape((3, 3)); | |||||
var norm = np.linalg.norm(y); | |||||
} | |||||
} | } | ||||
} | } |