From bc13c9a932207d02016ab8838d319f62ff3c59d0 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 19 Dec 2018 07:32:02 -0600 Subject: [PATCH] add op_list_proto_math --- src/TensorFlowNET.Core/OpDefLibrary.cs | 31 +++++++++++++++--- .../TensorFlowNET.Core.csproj | 5 ++- ...roto_bytes.bin => op_list_proto_array.bin} | Bin .../Tensorflow/op_list_proto_math.bin | Bin 0 -> 11722 bytes src/TensorFlowNET.Core/ops/gen_array_ops.cs | 17 ++++++++-- src/TensorFlowNET.Core/ops/gen_math_ops.cs | 21 +++++++++++- src/TensorFlowNET.Core/tf.cs | 9 +---- 7 files changed, 65 insertions(+), 18 deletions(-) rename src/TensorFlowNET.Core/Tensorflow/{op_list_proto_bytes.bin => op_list_proto_array.bin} (100%) create mode 100644 src/TensorFlowNET.Core/Tensorflow/op_list_proto_math.bin diff --git a/src/TensorFlowNET.Core/OpDefLibrary.cs b/src/TensorFlowNET.Core/OpDefLibrary.cs index e891e58b..7d0d6e12 100644 --- a/src/TensorFlowNET.Core/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/OpDefLibrary.cs @@ -24,7 +24,7 @@ namespace Tensorflow _ops[op_def.Name] = op_def; } - public unsafe Operation _apply_op_helper(string op_type_name, string name = "", DataType? dtype = null, TensorShape shape = null) + public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary keywords = null) { var op_def = _ops[op_type_name]; @@ -46,9 +46,30 @@ namespace Tensorflow var key = attr_def.Name; } - foreach(var input_arg in op_def.InputArg) + var attrs = new Dictionary(); + var inputs = new List(); + var input_types = new List(); + + foreach (var attr in op_def.Attr) + { + if (keywords.ContainsKey(attr.Name)) + { + attrs[attr.Name] = keywords[attr.Name]; + } + } + + foreach (var input_arg in op_def.InputArg) { + var input_name = input_arg.Name; + if (keywords.ContainsKey(input_name)) + { + inputs.Add(keywords[input_name] as Tensor); + } + if (!String.IsNullOrEmpty(input_arg.TypeAttr)) + { + attrs[input_arg.TypeAttr] = DataType.DtFloat; + } } var attr_protos = new Dictionary(); @@ -60,7 +81,7 @@ namespace Tensorflow switch (attr_def.Type) { case "type": - attr_value.Type = dtype.Value; + attr_value.Type = (DataType)keywords["dtype"]; break; case "shape": attr_value.Shape = new TensorShapeProto(); @@ -84,9 +105,9 @@ namespace Tensorflow } } - var op = g.create_op(op_type_name, null, output_types.ToArray(), + var op = g.create_op(op_type_name, inputs, output_types.ToArray(), name: scope, - input_types: new DataType[] { }, + input_types: input_types.ToArray(), attrs: attr_protos, op_def: op_def); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 6f976409..bf591c6c 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -27,7 +27,10 @@ PreserveNewest - + + PreserveNewest + + PreserveNewest diff --git a/src/TensorFlowNET.Core/Tensorflow/op_list_proto_bytes.bin b/src/TensorFlowNET.Core/Tensorflow/op_list_proto_array.bin similarity index 100% rename from src/TensorFlowNET.Core/Tensorflow/op_list_proto_bytes.bin rename to src/TensorFlowNET.Core/Tensorflow/op_list_proto_array.bin diff --git a/src/TensorFlowNET.Core/Tensorflow/op_list_proto_math.bin b/src/TensorFlowNET.Core/Tensorflow/op_list_proto_math.bin new file mode 100644 index 0000000000000000000000000000000000000000..c94552c83b513ca8a1120349f429291d1fb8b984 GIT binary patch literal 11722 zcmdT~+jHA!6-V*?I7t(^X;Z_{T?0eeOG(Ujm(6Y|NpECmLy66nVTSR@mg;CLORglR zP8kMfc!gnkX7~?y=YAM>3w@lP5d-RoMx=96IaFmH7%-;e3sv6AZU*@0X@QS(ka&*ob ztu~kWoP4p6tCZyuJq2;9dXQNNZ$_J4q03D#1>s>IRst~WQmK0g_Xs1iG(yYeF z^0znwUxp!9s37#SN%A#Xw>5Rtb9Kw?=%%XoG^deo>*}~pRJ2Ae)VN8|eC4^O4Lhnn zaO8Z~vW)uDjW0DeNlf)@qQ111+bL4M$u`jjv1Z@UWb)5q`N1x{9dQC4V)lf!Wp8a+ z7Q5nexU1xSk(P+*H&0#%{W`7f4u`fi)(0@^`!7a{k&=4><^b88ZpGKA;#!7gD`pRt z{LHWvw_fabd~pRDrZL(}8AujckPSt9o+murFpHV5XiL z*D1bLDDq0E$bdTCyol6sByd2Q?-F(2QyX(50TyJA1bl$2oZ*1?VE9-dyp)4RoHP+U zYLSu7b^4V0YAe zuDMtoH~3gI%uQX{xJJ|k4XR@6nj=>Y-EnUSSN|3rQ2wssxQJWGr}>uMYl9A?l>9x% z<(w#nJm2Ft9`)ZYQ@L%A%pPJ;Oeb$+KTQwS9$Nh*N5ZsBs@^s9;oiw7iZK!jaapi6 zTJ2%DbIhZifo@J2R7Cbq7}Kf24v~NTt~D4cwzg}kd%8PRde7sML}IMCN4Dl1S%!*3 zW0K9M{FdMxoLK4SRN`F?Z44>qZA~$lUOXZ`_j#8pyrvdmB8oP)+7|0Uc~B)<|Bnz^#PLBL*jr(m0D*b7MIT|Oo|u-T)+UE6X5 zbc*{~9bSTL$3OtWOaUAZlT=pMfyw9(5_B;d@bn-cz{&E;SPzf`@E+>|8v!?HeDGA$ngcgJLk>{w`I$MGtL7-DQzwUHKlByp;(wa@u(<%I_l;CF1?L>z$jF z+c(3Q5IWaPbo(JlH;K&@^+Qm)%T`E2bZ7tCGY^(2=nB9D{xW1Y05CDp{P1e(1Okww@G7$ZD%@aZrN+46k<|~R657H{ zbrvvtJ9GT}IuYjUJ;eZOlOgIcnggDee8F?D2UhcE9ayeE%hx4+4WjpU_>nzxgCP~8 zk|l6R2Ko~K_a#9VMdihzm+IASz6C&dL$Gyqw8@Vpz;xw zK#e|#0hJ({<5vW;X&TK{Sp2_}DF0YPj*m4}=5=uX-^>WvPBaESXW1fvEhgJrFq1jS zdD6Ucm0At{O1H6WN&z0|axz$X2+BiktWgB`11;^;GXwm6D#P>UhPN!B?R}~|v5b>2 z&;_0t1iV>E^?k}cv0er-D2&Vjw<)rv8hU|9E^N3@xL&v*hWi(UvLBBWldJrD)=LvH zuXSYW=JP`aFY=|JXIX5EKB>Gy74g7z@}?AfRmPiAT@K5-GRhv&lS8YM{1%;95`X74 zfIGq7UO&A;rDg$r?@Mb)aN#x{7}|+G(%&Rn4k%Z-M_uAvz0}_3hZ8(B8F%5#4o(q| z$~b()=t(q+M?_u3oDx^K=<}i}Ay&S^2Y$p#Jfeo|^8)rj_2o4^>MmHM8Bcv1d7sN(F&3 z1Z$}P3G6KbIM$L!2^dkpc=Q9yG%iH#J-0gR2h3~9Lij&42hS)wG)1VXSv>F7?!5vRn8Xng2S z{cr_f4_kBl4rw^JQCY>^BFIUP(<@r{SeW7Rm#CEZvjHG?JoKPtn5bRT{k=JjlHW_T zkvd%11>B&I-WZ3vQwd;-{5*tf_^oqM^iPT2NKy1ucaIK617AQ++(%C*?klv6rc%d` z)=p!&jo099aoORo1c{wO%|^}GUyz(&rG>0@a3yuQxE`csaat6QQBya|Rw?k!~zEC>bMgj&7 zx1&ZrGOp-h&%~b(5wFD(^`8FhnPzMJPbIA4h7kWIZV+be!nHkaFxjshfu4uEqf zO( _InitOpDefLibrary(); + public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); public static Tensor placeholder(DataType dtype, TensorShape shape = null) { - var _op = _op_def_lib._apply_op_helper("Placeholder", dtype: dtype, shape: shape); + /*var g = ops.get_default_graph(); + var op = new Operation(g, "Placeholder", "feed"); + + var tensor = new Tensor(op, 0, dtype); + + return tensor;*/ + + var keywords = new Dictionary(); + keywords.Add("dtype", dtype); + keywords.Add("shape", shape); + + var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); var _result = _op.outputs; var _inputs_flat = _op.inputs; var _attrs = new Dictionary(); @@ -27,7 +38,7 @@ namespace Tensorflow private static OpDefLibrary _InitOpDefLibrary() { // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); - var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_bytes.bin"); + var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin"); var op_list = OpList.Parser.ParseFrom(bytes); var op_def_lib = new OpDefLibrary(); op_def_lib.add_op_list(op_list); diff --git a/src/TensorFlowNET.Core/ops/gen_math_ops.cs b/src/TensorFlowNET.Core/ops/gen_math_ops.cs index c48f4737..eece48e0 100644 --- a/src/TensorFlowNET.Core/ops/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/ops/gen_math_ops.cs @@ -1,15 +1,34 @@ using System; using System.Collections.Generic; +using System.IO; using System.Text; namespace Tensorflow { public static class gen_math_ops { - public static Tensor add(Tensor a, Tensor b, string name = "") + public static OpDefLibrary _op_def_lib = _InitOpDefLibrary(); + + public static Tensor add(Tensor a, Tensor b) { + var keywords = new Dictionary(); + keywords.Add("x", a); + keywords.Add("y", b); + + var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); return null; } + + private static OpDefLibrary _InitOpDefLibrary() + { + // c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle); + var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_math.bin"); + var op_list = OpList.Parser.ParseFrom(bytes); + var op_def_lib = new OpDefLibrary(); + op_def_lib.add_op_list(op_list); + + return op_def_lib; + } } } diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 582f7b5a..2c035bd7 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -20,18 +20,11 @@ namespace Tensorflow public static unsafe Tensor add(Tensor a, Tensor b) { - return null; + return gen_math_ops.add(a, b); } public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null) { - /*var g = ops.get_default_graph(); - var op = new Operation(g, "Placeholder", "feed"); - - var tensor = new Tensor(op, 0, dtype); - - return tensor;*/ - return gen_array_ops.placeholder(dtype, shape); }