Implement tf.matmul #94tags/v0.1.0-Tensor
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public static partial class tf | |||||
{ | |||||
public static unsafe Tensor matmul(Tensor a, Tensor b) | |||||
{ | |||||
return gen_math_ops.mat_mul(a, b); | |||||
} | |||||
} | |||||
} |
@@ -78,6 +78,9 @@ namespace Tensorflow | |||||
case "type": | case "type": | ||||
attr_value.Type = _MakeType((TF_DataType)value, attr_def); | attr_value.Type = _MakeType((TF_DataType)value, attr_def); | ||||
break; | break; | ||||
case "bool": | |||||
attr_value.B = (bool)value; | |||||
break; | |||||
case "shape": | case "shape": | ||||
attr_value.Shape = new TensorShapeProto(); | attr_value.Shape = new TensorShapeProto(); | ||||
break; | break; | ||||
@@ -30,5 +30,18 @@ namespace Tensorflow | |||||
return new Tensor(_op, 0, _op.OutputType(0)); | return new Tensor(_op, 0, _op.OutputType(0)); | ||||
} | } | ||||
public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) | |||||
{ | |||||
var keywords = new Dictionary<string, object>(); | |||||
keywords.Add("a", a); | |||||
keywords.Add("b", b); | |||||
keywords.Add("transpose_a", transpose_a); | |||||
keywords.Add("transpose_b", transpose_b); | |||||
var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords); | |||||
return new Tensor(_op, 0, _op.OutputType(0)); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,6 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Runtime.InteropServices; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow; | using Tensorflow; | ||||
@@ -63,8 +64,9 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void String() | public void String() | ||||
{ | { | ||||
//var desc = init("string"); | |||||
//c_api.TF_SetAttrString(desc, "v", "bunny", 5); | |||||
var desc = init("string"); | |||||
var handle = Marshal.StringToHGlobalAnsi("bunny"); | |||||
c_api.TF_SetAttrString(desc, "v", handle, 5); | |||||
//var oper = c_api.TF_FinishOperation(desc, s_); | //var oper = c_api.TF_FinishOperation(desc, s_); | ||||
//ASSERT_EQ(TF_Code.TF_OK, s_.Code); | //ASSERT_EQ(TF_Code.TF_OK, s_.Code); | ||||
@@ -23,6 +23,7 @@ | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||