Browse Source

Fix Operation.get_attr #115

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
1876cc982f
5 changed files with 55 additions and 12 deletions
  1. +3
    -3
      README.md
  2. +17
    -0
      src/TensorFlowNET.Core/Eager/Execute.cs
  3. +18
    -0
      src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
  4. +13
    -8
      src/TensorFlowNET.Core/Operations/Operation.cs
  5. +4
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs

+ 3
- 3
README.md View File

@@ -12,15 +12,15 @@ TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) st
![tensors_flowing](docs/assets/tensors_flowing.gif)

### How to use
Download the pre-compiled dll [here](tensorflow.so) and place it in the working folder.
This is only need for Linux and Mac OS, and already packed for Windows.

Install TensorFlow.NET through NuGet.
```sh
PM> Install-Package TensorFlow.NET
```

If you are using Linux or Mac OS, please download the pre-compiled dll [here](tensorflow.so) and place it in the working folder. This is only need for Linux and Mac OS, and already packed into NuGet for Windows.

Import tensorflow.net.

```cs
using Tensorflow;
```


+ 17
- 0
src/TensorFlowNET.Core/Eager/Execute.cs View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public class Execute
{
public void record_gradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
{
if (inputs == null)
inputs = new Tensor[0];

pywrap_tfe_src.RecordGradient(op_name, inputs, attrs, results, name);
}
}
}

+ 18
- 0
src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs View File

@@ -0,0 +1,18 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Eager
{
/// <summary>
/// python\eager\pywrap_tfe_src.cc
/// </summary>
public class pywrap_tfe_src
{
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
{
}
}
}

+ 13
- 8
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -113,7 +113,7 @@ namespace Tensorflow
op_def = g.GetOpDef(node_def.Op);

_handle = ops._create_c_op(g, node_def, inputs);
_outputs = new Tensor[NumOutputs];
output_types = new TF_DataType[NumOutputs];

@@ -128,21 +128,26 @@ namespace Tensorflow

public object get_attr(string name)
{
object ret = null;
AttrValue x = null;

var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" };

using (var buf = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true);
x = AttrValue.Parser.ParseFrom(buf);
}

switch (name)
{
case "dtype":
ret = _outputs[0];
break;
return x.Type;
case "shape":
ret = new TensorShapeProto();
break;
return x.Shape;
default:
throw new NotImplementedException($"{name}");
}

return ret;
}

public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)


+ 4
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -3,14 +3,16 @@ using System.Collections.Generic;
using System.IO;
using System.Text;
using Tensorflow;
using Tensorflow.Eager;

namespace Tensorflow
{
public static class gen_array_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
public static Execute _execute = new Execute();

public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = "")
{
var keywords = new Dictionary<string, object>();
keywords.Add("dtype", dtype);
@@ -24,6 +26,7 @@ namespace Tensorflow
_attrs["dtype"] = _op.get_attr("dtype");
_attrs["shape"] = _op.get_attr("shape");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);
return new Tensor(_op, 0, dtype);
}



Loading…
Cancel
Save