diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs new file mode 100644 index 00000000..ccfbbc2b --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static unsafe Tensor matmul(Tensor a, Tensor b) + { + return gen_math_ops.mat_mul(a, b); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index aac58828..66fa8089 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -78,6 +78,9 @@ namespace Tensorflow case "type": attr_value.Type = _MakeType((TF_DataType)value, attr_def); break; + case "bool": + attr_value.B = (bool)value; + break; case "shape": attr_value.Shape = new TensorShapeProto(); break; diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 6e95bc70..25aa2d8a 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -30,5 +30,18 @@ namespace Tensorflow return new Tensor(_op, 0, _op.OutputType(0)); } + + public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) + { + var keywords = new Dictionary(); + keywords.Add("a", a); + keywords.Add("b", b); + keywords.Add("transpose_a", transpose_a); + keywords.Add("transpose_b", transpose_b); + + var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords); + + return new Tensor(_op, 0, _op.OutputType(0)); + } } } diff --git a/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs b/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs index da7294cd..2a9db62d 100644 --- a/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs +++ b/test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs @@ -1,6 +1,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Collections.Generic; +using System.Runtime.InteropServices; using System.Text; using Tensorflow; @@ -63,8 +64,9 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void String() { - //var desc = init("string"); - //c_api.TF_SetAttrString(desc, "v", "bunny", 5); + var desc = init("string"); + var handle = Marshal.StringToHGlobalAnsi("bunny"); + c_api.TF_SetAttrString(desc, "v", handle, 5); //var oper = c_api.TF_FinishOperation(desc, s_); //ASSERT_EQ(TF_Code.TF_OK, s_.Code); diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index 802d4c0c..f227365e 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -23,6 +23,7 @@ +