Browse Source

Add data type when determining which property should be taken out. #115

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
3265a38c08
6 changed files with 80 additions and 9 deletions
  1. +27
    -0
      src/TensorFlowNET.Core/Eager/Tape.cs
  2. +13
    -1
      src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
  3. +13
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +4
    -3
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +3
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +20
    -2
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

+ 27
- 0
src/TensorFlowNET.Core/Eager/Tape.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
    /// <summary>
    /// Static helpers for classifying tensor data types.
    /// </summary>
    public class Tape
    {
        /// <summary>
        /// Reports whether <paramref name="dtype"/> belongs to the fixed set of
        /// data types that this class treats as trainable.
        /// </summary>
        /// <param name="dtype">The data type enum value to classify.</param>
        /// <returns>
        /// <c>true</c> for half, bfloat16, float, double, complex64, complex128,
        /// resource and variant; <c>false</c> for every other value.
        /// </returns>
        public static bool IsDtypeTrainable(DataType dtype)
        {
            // Half, bfloat16, float and double.
            bool isFloatingPoint = dtype == DataType.DtHalf
                || dtype == DataType.DtBfloat16
                || dtype == DataType.DtFloat
                || dtype == DataType.DtDouble;

            // Complex64 and complex128.
            bool isComplex = dtype == DataType.DtComplex64
                || dtype == DataType.DtComplex128;

            // Resource and variant are accepted as well.
            bool isResourceOrVariant = dtype == DataType.DtResource
                || dtype == DataType.DtVariant;

            return isFloatingPoint || isComplex || isResourceOrVariant;
        }
    }
}

+ 13
- 1
src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs View File

@@ -12,7 +12,19 @@ namespace Tensorflow.Eager
{ {
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "") public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
{ {
var input_ids = inputs.Select(x => x.Id).ToArray();
var input_dtypes = inputs.Select(x => x.dtype).ToArray();

bool should_record = false;
foreach (var input_dtype in input_dtypes)
{
if (Tape.IsDtypeTrainable(input_dtype.as_datatype_enum()))
{
should_record = true;
break;
}
}
if (!should_record) return;
} }
} }
} }

+ 13
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;


@@ -126,11 +127,11 @@ namespace Tensorflow
Graph._add_op(this); Graph._add_op(this);
} }


public object get_attr(string name)
public object get_attr<T>(string name)
{ {
AttrValue x = null; AttrValue x = null;


var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" };
var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" };


using (var buf = new Buffer()) using (var buf = new Buffer())
{ {
@@ -141,12 +142,21 @@ namespace Tensorflow


switch (name) switch (name)
{ {
case "T":
case "dtype": case "dtype":
return x.Type; return x.Type;
case "shape": case "shape":
return x.Shape; return x.Shape;
default: default:
throw new NotImplementedException($"{name}");
switch (typeof(T).Name)
{
case "Boolean":
return x.B;
case "String":
return x.S;
default:
throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
}
} }
} }




+ 4
- 3
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -21,12 +21,13 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords);
var _result = _op.outputs; var _result = _op.outputs;
var _inputs_flat = _op.inputs; var _inputs_flat = _op.inputs;
var _attrs = new Dictionary<string, object>();


_attrs["dtype"] = _op.get_attr("dtype");
_attrs["shape"] = _op.get_attr("shape");
var _attrs = new Dictionary<string, object>();
_attrs["dtype"] = _op.get_attr<DataType>("dtype");
_attrs["shape"] = _op.get_attr<int[]>("shape");


_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);

return new Tensor(_op, 0, dtype); return new Tensor(_op, 0, dtype);
} }




+ 3
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -16,6 +16,9 @@ namespace Tensorflow
{ {
private readonly IntPtr _handle; private readonly IntPtr _handle;


private int _id;
public int Id => _id;

public Graph Graph => op.Graph; public Graph Graph => op.Graph;
public Operation op { get; } public Operation op { get; }




+ 20
- 2
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -2,12 +2,14 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Eager;


namespace Tensorflow namespace Tensorflow
{ {
public class gen_state_ops public class gen_state_ops
{ {
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); public static OpDefLibrary _op_def_lib = new OpDefLibrary();
public static Execute _execute = new Execute();


/// <summary> /// <summary>
/// Holds state in the form of a tensor that persists across steps. /// Holds state in the form of a tensor that persists across steps.
@@ -32,6 +34,14 @@ namespace Tensorflow
var _result = _op.outputs; var _result = _op.outputs;
var _inputs_flat = _op.inputs; var _inputs_flat = _op.inputs;


var _attrs = new Dictionary<string, object>();
_attrs["dtype"] = _op.get_attr<DataType>("dtype");
_attrs["shape"] = _op.get_attr<int[]>("shape");
_attrs["container"] = _op.get_attr<string>("container");
_attrs["shared_name"] = _op.get_attr<string>("shared_name");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);

return new Tensor(_op, 0, dtype); return new Tensor(_op, 0, dtype);
} }


@@ -56,9 +66,17 @@ namespace Tensorflow


var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords);


var _result = _op.outputs[0];
var _result = _op.outputs;
var _inputs_flat = _op.inputs; var _inputs_flat = _op.inputs;
return _result;

var _attrs = new Dictionary<string, object>();
_attrs["T"] = _op.get_attr<DataType>("T");
_attrs["validate_shape"] = _op.get_attr<bool>("validate_shape");
_attrs["use_locking"] = _op.get_attr<bool>("use_locking");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);

return _result[0];
} }
} }
} }

Loading…
Cancel
Save