Browse Source

Merge pull request #95 from Esther2013/master

Implement tf.matmul #94
tags/v0.1.0-Tensor
Haiping GitHub 6 years ago
parent
commit
af159f849a
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 35 additions and 2 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/APIs/tf.linalg.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  3. +13
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  4. +4
    -2
      test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs
  5. +1
    -0
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

+ 14
- 0
src/TensorFlowNET.Core/APIs/tf.linalg.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static unsafe Tensor matmul(Tensor a, Tensor b)
{
return gen_math_ops.mat_mul(a, b);
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -78,6 +78,9 @@ namespace Tensorflow
case "type": case "type":
attr_value.Type = _MakeType((TF_DataType)value, attr_def); attr_value.Type = _MakeType((TF_DataType)value, attr_def);
break; break;
case "bool":
attr_value.B = (bool)value;
break;
case "shape": case "shape":
attr_value.Shape = new TensorShapeProto(); attr_value.Shape = new TensorShapeProto();
break; break;


+ 13
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -30,5 +30,18 @@ namespace Tensorflow


return new Tensor(_op, 0, _op.OutputType(0)); return new Tensor(_op, 0, _op.OutputType(0));
} }

public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false)
{
var keywords = new Dictionary<string, object>();
keywords.Add("a", a);
keywords.Add("b", b);
keywords.Add("transpose_a", transpose_a);
keywords.Add("transpose_b", transpose_b);

var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords);

return new Tensor(_op, 0, _op.OutputType(0));
}
} }
} }

+ 4
- 2
test/TensorFlowNET.UnitTest/CApiAttributesTestcs.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text; using System.Text;
using Tensorflow; using Tensorflow;


@@ -63,8 +64,9 @@ namespace TensorFlowNET.UnitTest
[TestMethod] [TestMethod]
public void String() public void String()
{ {
//var desc = init("string");
//c_api.TF_SetAttrString(desc, "v", "bunny", 5);
var desc = init("string");
var handle = Marshal.StringToHGlobalAnsi("bunny");
c_api.TF_SetAttrString(desc, "v", handle, 5);


//var oper = c_api.TF_FinishOperation(desc, s_); //var oper = c_api.TF_FinishOperation(desc, s_);
//ASSERT_EQ(TF_Code.TF_OK, s_.Code); //ASSERT_EQ(TF_Code.TF_OK, s_.Code);


+ 1
- 0
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -23,6 +23,7 @@
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup> </ItemGroup>




Loading…
Cancel
Save