|
|
@@ -52,6 +52,21 @@ namespace Tensorflow |
|
|
|
return _composite_impl(matrix, rhs, l2_regularizer: l2_regularizer); |
|
|
|
} |
|
|
|
|
|
|
|
public Tensor norm(Tensor tensor, string ord = "euclidean", Axis axis = null, string name = null, bool keepdims = true) |
|
|
|
{ |
|
|
|
var is_matrix_norm = axis != null && len(axis) == 2; |
|
|
|
return tf_with(ops.name_scope(name, default_name: "norm", tensor), scope => |
|
|
|
{ |
|
|
|
if (is_matrix_norm) |
|
|
|
throw new NotImplementedException(""); |
|
|
|
var result = math_ops.sqrt(math_ops.reduce_sum(tensor * math_ops.conj(tensor), axis, keepdims: true)); |
|
|
|
|
|
|
|
if(!keepdims) |
|
|
|
result = array_ops.squeeze(result, axis); |
|
|
|
return result; |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
Tensor _composite_impl(Tensor matrix, Tensor rhs, Tensor l2_regularizer = null) |
|
|
|
{ |
|
|
|
Shape matrix_shape = matrix.shape.dims.Skip(matrix.shape.ndim - 2).ToArray(); |
|
|
|