@@ -8,7 +8,7 @@ TensorFlow.NET is a member project of SciSharp stack. | |||||
### How to use | ### How to use | ||||
```cs | ```cs | ||||
using tf = TensorFlowNET.Core.Tensorflow; | |||||
using TensorFlowNET.Core; | |||||
namespace TensorFlowNET.Examples | namespace TensorFlowNET.Examples | ||||
{ | { | ||||
@@ -0,0 +1,30 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
namespace TensorFlowNET.Core | |||||
{ | |||||
public class Buffer | |||||
{ | |||||
private IntPtr _handle; | |||||
public IntPtr Handle => _handle; | |||||
//public TF_Buffer buffer => Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
public unsafe Buffer() | |||||
{ | |||||
_handle = Marshal.AllocHGlobal(sizeof(TF_Buffer)); | |||||
} | |||||
public byte[] GetBuffer() | |||||
{ | |||||
var buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | |||||
var data = Marshal.AllocHGlobal(buffer.length); | |||||
//var bytes = c_api.TF_GetBuffer(buffer.data); | |||||
return null; | |||||
} | |||||
} | |||||
} |
@@ -32,7 +32,9 @@ namespace TensorFlowNET.Core | |||||
_names_in_use = new Dictionary<string, int>(); | _names_in_use = new Dictionary<string, int>(); | ||||
} | } | ||||
public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, Dictionary<string, AttrValue> attrs = null, string name = "Const") | |||||
public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, | |||||
TF_DataType[] input_types = null, string name = "", | |||||
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | |||||
{ | { | ||||
if (String.IsNullOrEmpty(name)) | if (String.IsNullOrEmpty(name)) | ||||
{ | { | ||||
@@ -0,0 +1,95 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.IO; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
using static Tensorflow.OpDef.Types; | |||||
namespace TensorFlowNET.Core | |||||
{ | |||||
public class OpDefLibrary | |||||
{ | |||||
public Dictionary<string, OpDef> _ops = new Dictionary<string, OpDef>(); | |||||
public void add_op_list(OpList op_list) | |||||
{ | |||||
foreach(var op_def in op_list.Op) | |||||
{ | |||||
add_op(op_def); | |||||
} | |||||
} | |||||
public void add_op(OpDef op_def) | |||||
{ | |||||
_ops[op_def.Name] = op_def; | |||||
} | |||||
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", DataType? dtype = null, TensorShape shape = null) | |||||
{ | |||||
var op_def = _ops[op_type_name]; | |||||
var status = new Status(); | |||||
var buffer = new Buffer(); | |||||
var g = ops.get_default_graph(); | |||||
if (String.IsNullOrEmpty(name)) | |||||
{ | |||||
name = op_type_name; | |||||
} | |||||
foreach(var attr_def in op_def.Attr) | |||||
{ | |||||
if (attr_def.Type != "type") continue; | |||||
var key = attr_def.Name; | |||||
} | |||||
foreach(var input_arg in op_def.InputArg) | |||||
{ | |||||
} | |||||
var attr_protos = new Dictionary<string, AttrValue>(); | |||||
foreach (var attr_def in op_def.Attr) | |||||
{ | |||||
var key = attr_def.Name; | |||||
var attr_value = new AttrValue(); | |||||
switch (attr_def.Type) | |||||
{ | |||||
case "type": | |||||
attr_value.Type = dtype.Value; | |||||
break; | |||||
case "shape": | |||||
attr_value.Shape = new TensorShapeProto(); | |||||
break; | |||||
} | |||||
attr_protos[key] = attr_value; | |||||
} | |||||
var output_types = new List<DataType>(); | |||||
foreach (var arg in op_def.OutputArg) | |||||
{ | |||||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||||
{ | |||||
} | |||||
else if (!String.IsNullOrEmpty(arg.TypeAttr)) | |||||
{ | |||||
output_types.Add(attr_protos[arg.TypeAttr].Type); | |||||
} | |||||
} | |||||
var op = g.create_op(op_type_name, null, output_types.ToArray(), | |||||
name: "Placeholder_1/", | |||||
input_types: new DataType[] { }, | |||||
attrs: null, | |||||
op_def: null); | |||||
return op; | |||||
} | |||||
} | |||||
} |
@@ -25,6 +25,9 @@ | |||||
<None Update="tensorflow.dll"> | <None Update="tensorflow.dll"> | ||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
</None> | </None> | ||||
<None Update="Tensorflow\op_list_proto_bytes.bin"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
</ItemGroup> | </ItemGroup> | ||||
</Project> | </Project> |
@@ -2,7 +2,6 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Text; | using System.Text; | ||||
using size_t = System.IntPtr; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -10,7 +9,7 @@ namespace Tensorflow | |||||
public struct TF_Buffer | public struct TF_Buffer | ||||
{ | { | ||||
public IntPtr data; | public IntPtr data; | ||||
public size_t length; | |||||
public int length; | |||||
public IntPtr data_deallocator; | public IntPtr data_deallocator; | ||||
} | } | ||||
} | } |
@@ -27,9 +27,15 @@ namespace TensorFlowNET.Core | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); | public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern string TF_GetBuffer(IntPtr buffer); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe TF_Code TF_GetCode(TF_Status s); | public static extern unsafe TF_Code TF_GetCode(TF_Status s); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern void TF_GraphGetOpDef(TF_Graph graph, string op_name, IntPtr output_op_def, TF_Status status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe string TF_Message(TF_Status s); | public static extern unsafe string TF_Message(TF_Status s); | ||||
@@ -1,10 +1,31 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow; | |||||
namespace TensorFlowNET.Core | namespace TensorFlowNET.Core | ||||
{ | { | ||||
public static class gen_array_ops | public static class gen_array_ops | ||||
{ | { | ||||
public static OpDefLibrary _op_def_lib => _InitOpDefLibrary(); | |||||
public static Tensor placeholder(DataType dtype, TensorShape shape = null) | |||||
{ | |||||
var op = _op_def_lib._apply_op_helper("Placeholder", dtype: dtype, shape: shape); | |||||
return null; | |||||
} | |||||
private static OpDefLibrary _InitOpDefLibrary() | |||||
{ | |||||
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | |||||
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_bytes.bin"); | |||||
var op_list = OpList.Parser.ParseFrom(bytes); | |||||
var op_def_lib = new OpDefLibrary(); | |||||
op_def_lib.add_op_list(op_list); | |||||
return op_def_lib; | |||||
} | |||||
} | } | ||||
} | } |
@@ -3,7 +3,6 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow; | using Tensorflow; | ||||
using np = NumSharp.Core.NumPy; | |||||
using tensor_pb2 = Tensorflow; | using tensor_pb2 = Tensorflow; | ||||
namespace TensorFlowNET.Core | namespace TensorFlowNET.Core | ||||
@@ -10,14 +10,13 @@ namespace TensorFlowNET.Core | |||||
{ | { | ||||
public static class tf | public static class tf | ||||
{ | { | ||||
public static Type float32 = typeof(float); | |||||
public static DataType float32 = DataType.DtFloat; | |||||
public delegate void Deallocator(IntPtr data, IntPtr size, IntPtr deallocatorData); | public delegate void Deallocator(IntPtr data, IntPtr size, IntPtr deallocatorData); | ||||
public static unsafe Tensor placeholder(Type dtype, TensorShape shape = null) | |||||
public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) | |||||
{ | { | ||||
return null; | |||||
return gen_array_ops.placeholder(dtype, shape); | |||||
} | } | ||||
public static unsafe Tensor constant(object value) | public static unsafe Tensor constant(object value) | ||||