From 1876cc982f7286bc187550f4a57890777a6c93f9 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 20 Jan 2019 22:13:45 -0600 Subject: [PATCH] Fix Operation.get_attr #115 --- README.md | 6 +++--- src/TensorFlowNET.Core/Eager/Execute.cs | 17 +++++++++++++++ .../Eager/pywrap_tfe_src.cs | 18 ++++++++++++++++ .../Operations/Operation.cs | 21 ++++++++++++------- .../Operations/gen_array_ops.cs | 5 ++++- 5 files changed, 55 insertions(+), 12 deletions(-) create mode 100644 src/TensorFlowNET.Core/Eager/Execute.cs create mode 100644 src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs diff --git a/README.md b/README.md index e740b0ed..81804ccb 100644 --- a/README.md +++ b/README.md @@ -12,15 +12,15 @@ TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) st ![tensors_flowing](docs/assets/tensors_flowing.gif) ### How to use -Download the pre-compiled dll [here](tensorflow.so) and place it in the working folder. -This is only need for Linux and Mac OS, and already packed for Windows. - Install TensorFlow.NET through NuGet. ```sh PM> Install-Package TensorFlow.NET ``` +If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflow.so) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows. + Import tensorflow.net. + ```cs using Tensorflow; ``` diff --git a/src/TensorFlowNET.Core/Eager/Execute.cs b/src/TensorFlowNET.Core/Eager/Execute.cs new file mode 100644 index 00000000..ed8c7839 --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/Execute.cs @@ -0,0 +1,17 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Eager +{ + public class Execute + { + public void record_gradient(string op_name, Tensor[] inputs, Dictionary attrs, Tensor[] results, string name = "") + { + if (inputs == null) + inputs = new Tensor[0]; + + pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs new file mode 100644 index 00000000..b972e4ec --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs @@ -0,0 +1,18 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.Eager +{ + /// + /// python\eager\pywrap_tfe_src.cc + /// + public class pywrap_tfe_src + { + public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary attrs, Tensor[] results, string name = "") + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 783a8ef7..81743336 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -113,7 +113,7 @@ namespace Tensorflow op_def = g.GetOpDef(node_def.Op); _handle = ops._create_c_op(g, node_def, inputs); - + _outputs = new Tensor[NumOutputs]; output_types = new TF_DataType[NumOutputs]; @@ -128,21 +128,26 @@ namespace Tensorflow public object get_attr(string name) { - object ret = null; + AttrValue x = null; var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; + using (var buf = new Buffer()) + { + c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); + status.Check(true); + x = AttrValue.Parser.ParseFrom(buf); + } + switch (name) { case "dtype": - ret = _outputs[0]; - break; + return x.Type; case "shape": - ret = new TensorShapeProto(); - break; + return x.Shape; + default: + throw new NotImplementedException($"{name}"); } - - return ret; } public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 560f6173..80fd60a9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -3,14 +3,16 @@ using System.Collections.Generic; using System.IO; using System.Text; using Tensorflow; +using Tensorflow.Eager; namespace Tensorflow { public static class gen_array_ops { public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + public static Execute _execute = new Execute(); - public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null) + public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = "") { var keywords = new Dictionary(); keywords.Add("dtype", dtype); @@ -24,6 +26,7 @@ namespace Tensorflow _attrs["dtype"] = _op.get_attr("dtype"); _attrs["shape"] = _op.get_attr("shape"); + _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); return new Tensor(_op, 0, dtype); }