@@ -24,7 +24,7 @@ namespace Tensorflow | |||||
_ops[op_def.Name] = op_def; | _ops[op_def.Name] = op_def; | ||||
} | } | ||||
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", DataType? dtype = null, TensorShape shape = null) | |||||
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null) | |||||
{ | { | ||||
var op_def = _ops[op_type_name]; | var op_def = _ops[op_type_name]; | ||||
@@ -46,9 +46,30 @@ namespace Tensorflow | |||||
var key = attr_def.Name; | var key = attr_def.Name; | ||||
} | } | ||||
foreach(var input_arg in op_def.InputArg) | |||||
var attrs = new Dictionary<string, object>(); | |||||
var inputs = new List<Tensor>(); | |||||
var input_types = new List<DataType>(); | |||||
foreach (var attr in op_def.Attr) | |||||
{ | |||||
if (keywords.ContainsKey(attr.Name)) | |||||
{ | |||||
attrs[attr.Name] = keywords[attr.Name]; | |||||
} | |||||
} | |||||
foreach (var input_arg in op_def.InputArg) | |||||
{ | { | ||||
var input_name = input_arg.Name; | |||||
if (keywords.ContainsKey(input_name)) | |||||
{ | |||||
inputs.Add(keywords[input_name] as Tensor); | |||||
} | |||||
if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||||
{ | |||||
attrs[input_arg.TypeAttr] = DataType.DtFloat; | |||||
} | |||||
} | } | ||||
var attr_protos = new Dictionary<string, AttrValue>(); | var attr_protos = new Dictionary<string, AttrValue>(); | ||||
@@ -60,7 +81,7 @@ namespace Tensorflow | |||||
switch (attr_def.Type) | switch (attr_def.Type) | ||||
{ | { | ||||
case "type": | case "type": | ||||
attr_value.Type = dtype.Value; | |||||
attr_value.Type = (DataType)keywords["dtype"]; | |||||
break; | break; | ||||
case "shape": | case "shape": | ||||
attr_value.Shape = new TensorShapeProto(); | attr_value.Shape = new TensorShapeProto(); | ||||
@@ -84,9 +105,9 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
var op = g.create_op(op_type_name, null, output_types.ToArray(), | |||||
var op = g.create_op(op_type_name, inputs, output_types.ToArray(), | |||||
name: scope, | name: scope, | ||||
input_types: new DataType[] { }, | |||||
input_types: input_types.ToArray(), | |||||
attrs: attr_protos, | attrs: attr_protos, | ||||
op_def: op_def); | op_def: op_def); | ||||
@@ -27,7 +27,10 @@ | |||||
<None Update="tensorflow.dll"> | <None Update="tensorflow.dll"> | ||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
</None> | </None> | ||||
<None Update="Tensorflow\op_list_proto_bytes.bin"> | |||||
<None Update="Tensorflow\op_list_proto_array.bin"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | |||||
</None> | |||||
<None Update="Tensorflow\op_list_proto_math.bin"> | |||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | <CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory> | ||||
</None> | </None> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -8,11 +8,22 @@ namespace Tensorflow | |||||
{ | { | ||||
public static class gen_array_ops | public static class gen_array_ops | ||||
{ | { | ||||
public static OpDefLibrary _op_def_lib => _InitOpDefLibrary(); | |||||
public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); | |||||
public static Tensor placeholder(DataType dtype, TensorShape shape = null) | public static Tensor placeholder(DataType dtype, TensorShape shape = null) | ||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Placeholder", dtype: dtype, shape: shape); | |||||
/*var g = ops.get_default_graph(); | |||||
var op = new Operation(g, "Placeholder", "feed"); | |||||
var tensor = new Tensor(op, 0, dtype); | |||||
return tensor;*/ | |||||
var keywords = new Dictionary<string, object>(); | |||||
keywords.Add("dtype", dtype); | |||||
keywords.Add("shape", shape); | |||||
var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); | |||||
var _result = _op.outputs; | var _result = _op.outputs; | ||||
var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
var _attrs = new Dictionary<string, object>(); | var _attrs = new Dictionary<string, object>(); | ||||
@@ -27,7 +38,7 @@ namespace Tensorflow | |||||
private static OpDefLibrary _InitOpDefLibrary() | private static OpDefLibrary _InitOpDefLibrary() | ||||
{ | { | ||||
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | ||||
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_bytes.bin"); | |||||
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin"); | |||||
var op_list = OpList.Parser.ParseFrom(bytes); | var op_list = OpList.Parser.ParseFrom(bytes); | ||||
var op_def_lib = new OpDefLibrary(); | var op_def_lib = new OpDefLibrary(); | ||||
op_def_lib.add_op_list(op_list); | op_def_lib.add_op_list(op_list); | ||||
@@ -1,15 +1,34 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.IO; | |||||
using System.Text; | using System.Text; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static class gen_math_ops | public static class gen_math_ops | ||||
{ | { | ||||
public static Tensor add(Tensor a, Tensor b, string name = "") | |||||
public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); | |||||
public static Tensor add(Tensor a, Tensor b) | |||||
{ | { | ||||
var keywords = new Dictionary<string, object>(); | |||||
keywords.Add("x", a); | |||||
keywords.Add("y", b); | |||||
var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | |||||
return null; | return null; | ||||
} | } | ||||
private static OpDefLibrary _InitOpDefLibrary() | |||||
{ | |||||
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); | |||||
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_math.bin"); | |||||
var op_list = OpList.Parser.ParseFrom(bytes); | |||||
var op_def_lib = new OpDefLibrary(); | |||||
op_def_lib.add_op_list(op_list); | |||||
return op_def_lib; | |||||
} | |||||
} | } | ||||
} | } |
@@ -20,18 +20,11 @@ namespace Tensorflow | |||||
public static unsafe Tensor add(Tensor a, Tensor b) | public static unsafe Tensor add(Tensor a, Tensor b) | ||||
{ | { | ||||
return null; | |||||
return gen_math_ops.add(a, b); | |||||
} | } | ||||
public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) | public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) | ||||
{ | { | ||||
/*var g = ops.get_default_graph(); | |||||
var op = new Operation(g, "Placeholder", "feed"); | |||||
var tensor = new Tensor(op, 0, dtype); | |||||
return tensor;*/ | |||||
return gen_array_ops.placeholder(dtype, shape); | return gen_array_ops.placeholder(dtype, shape); | ||||
} | } | ||||