diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
index 923cf581..fd751322 100644
--- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
@@ -6,6 +6,10 @@ namespace Tensorflow
{
public static partial class tf
{
- public static unsafe Tensor matmul(Tensor a, Tensor b) => gen_math_ops.mat_mul(a, b);
+ public static Tensor diag(Tensor diagonal, string name = null)
+ => gen_array_ops.diag(diagonal, name: name);
+
+ public static Tensor matmul(Tensor a, Tensor b)
+ => gen_math_ops.mat_mul(a, b);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 07525318..1f369b9a 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -26,6 +26,44 @@ namespace Tensorflow
return _op.outputs[0];
}
+ ///
+ /// Returns a diagonal tensor with a given diagonal values.
+ ///
+ ///
+ /// Rank k tensor where k is at most 1.
+ ///
+ ///
+ /// If specified, the created operation in the graph will be this one, otherwise it will be named 'Diag'.
+ ///
+ ///
+ /// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result.
+ ///
+ ///
+ /// Given a diagonal, this operation returns a tensor with the diagonal and
+ /// everything else padded with zeros. The diagonal is computed as follows:
+ ///
+ /// Assume diagonal has dimensions [D1,..., Dk], then the output is a tensor of
+ /// rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where:
+ ///
+ /// output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik] and 0 everywhere else.
+ ///
+ /// For example:
+ ///
+ ///
+ /// # 'diagonal' is [1, 2, 3, 4]
+ /// tf.diag(diagonal) ==> [[1, 0, 0, 0]
+ /// [0, 2, 0, 0]
+ /// [0, 0, 3, 0]
+ /// [0, 0, 0, 4]]
+ ///
+ ///
+ public static Tensor diag(Tensor diagonal, string name = null)
+ {
+ var op = _op_def_lib._apply_op_helper("Diag", name: name, args: new { diagonal });
+
+ return op.output;
+ }
+
public static Tensor expand_dims(Tensor input, int axis, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis });