Browse Source

add op_list_proto_math

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
bc13c9a932
7 changed files with 65 additions and 18 deletions
  1. +26
    -5
      src/TensorFlowNET.Core/OpDefLibrary.cs
  2. +4
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  3. +0
    -0
      src/TensorFlowNET.Core/Tensorflow/op_list_proto_array.bin
  4. BIN
      src/TensorFlowNET.Core/Tensorflow/op_list_proto_math.bin
  5. +14
    -3
      src/TensorFlowNET.Core/ops/gen_array_ops.cs
  6. +20
    -1
      src/TensorFlowNET.Core/ops/gen_math_ops.cs
  7. +1
    -8
      src/TensorFlowNET.Core/tf.cs

+ 26
- 5
src/TensorFlowNET.Core/OpDefLibrary.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
_ops[op_def.Name] = op_def;
}

public unsafe Operation _apply_op_helper(string op_type_name, string name = "", DataType? dtype = null, TensorShape shape = null)
public unsafe Operation _apply_op_helper(string op_type_name, string name = "", Dictionary<string, object> keywords = null)
{
var op_def = _ops[op_type_name];

@@ -46,9 +46,30 @@ namespace Tensorflow
var key = attr_def.Name;
}

foreach(var input_arg in op_def.InputArg)
var attrs = new Dictionary<string, object>();
var inputs = new List<Tensor>();
var input_types = new List<DataType>();

foreach (var attr in op_def.Attr)
{
if (keywords.ContainsKey(attr.Name))
{
attrs[attr.Name] = keywords[attr.Name];
}
}

foreach (var input_arg in op_def.InputArg)
{
var input_name = input_arg.Name;
if (keywords.ContainsKey(input_name))
{
inputs.Add(keywords[input_name] as Tensor);
}

if (!String.IsNullOrEmpty(input_arg.TypeAttr))
{
attrs[input_arg.TypeAttr] = DataType.DtFloat;
}
}

var attr_protos = new Dictionary<string, AttrValue>();
@@ -60,7 +81,7 @@ namespace Tensorflow
switch (attr_def.Type)
{
case "type":
attr_value.Type = dtype.Value;
attr_value.Type = (DataType)keywords["dtype"];
break;
case "shape":
attr_value.Shape = new TensorShapeProto();
@@ -84,9 +105,9 @@ namespace Tensorflow
}
}

var op = g.create_op(op_type_name, null, output_types.ToArray(),
var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
name: scope,
input_types: new DataType[] { },
input_types: input_types.ToArray(),
attrs: attr_protos,
op_def: op_def);



+ 4
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -27,7 +27,10 @@
<None Update="tensorflow.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Tensorflow\op_list_proto_bytes.bin">
<None Update="Tensorflow\op_list_proto_array.bin">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Tensorflow\op_list_proto_math.bin">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>


src/TensorFlowNET.Core/Tensorflow/op_list_proto_bytes.bin → src/TensorFlowNET.Core/Tensorflow/op_list_proto_array.bin View File


BIN
src/TensorFlowNET.Core/Tensorflow/op_list_proto_math.bin View File


+ 14
- 3
src/TensorFlowNET.Core/ops/gen_array_ops.cs View File

@@ -8,11 +8,22 @@ namespace Tensorflow
{
public static class gen_array_ops
{
public static OpDefLibrary _op_def_lib => _InitOpDefLibrary();
public static OpDefLibrary _op_def_lib = _InitOpDefLibrary();

public static Tensor placeholder(DataType dtype, TensorShape shape = null)
{
var _op = _op_def_lib._apply_op_helper("Placeholder", dtype: dtype, shape: shape);
/*var g = ops.get_default_graph();
var op = new Operation(g, "Placeholder", "feed");

var tensor = new Tensor(op, 0, dtype);

return tensor;*/

var keywords = new Dictionary<string, object>();
keywords.Add("dtype", dtype);
keywords.Add("shape", shape);

var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords);
var _result = _op.outputs;
var _inputs_flat = _op.inputs;
var _attrs = new Dictionary<string, object>();
@@ -27,7 +38,7 @@ namespace Tensorflow
private static OpDefLibrary _InitOpDefLibrary()
{
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle);
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_bytes.bin");
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_array.bin");
var op_list = OpList.Parser.ParseFrom(bytes);
var op_def_lib = new OpDefLibrary();
op_def_lib.add_op_list(op_list);


+ 20
- 1
src/TensorFlowNET.Core/ops/gen_math_ops.cs View File

@@ -1,15 +1,34 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace Tensorflow
{
public static class gen_math_ops
{
public static Tensor add(Tensor a, Tensor b, string name = "")
public static OpDefLibrary _op_def_lib = _InitOpDefLibrary();

public static Tensor add(Tensor a, Tensor b)
{
var keywords = new Dictionary<string, object>();
keywords.Add("x", a);
keywords.Add("y", b);

var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords);

return null;
}

private static OpDefLibrary _InitOpDefLibrary()
{
// c_api.TF_GraphGetOpDef(g.Handle, op_type_name, buffer.Handle, status.Handle);
var bytes = File.ReadAllBytes("Tensorflow/op_list_proto_math.bin");
var op_list = OpList.Parser.ParseFrom(bytes);
var op_def_lib = new OpDefLibrary();
op_def_lib.add_op_list(op_list);

return op_def_lib;
}
}
}

+ 1
- 8
src/TensorFlowNET.Core/tf.cs View File

@@ -20,18 +20,11 @@ namespace Tensorflow

public static unsafe Tensor add(Tensor a, Tensor b)
{
return null;
return gen_math_ops.add(a, b);
}

public static unsafe Tensor placeholder(DataType dtype, TensorShape shape = null)
{
/*var g = ops.get_default_graph();
var op = new Operation(g, "Placeholder", "feed");

var tensor = new Tensor(op, 0, dtype);
return tensor;*/

return gen_array_ops.placeholder(dtype, shape);
}



Loading…
Cancel
Save