Browse Source

add data type when determine which property should be taken out. #115

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
3265a38c08
6 changed files with 80 additions and 9 deletions
  1. +27
    -0
      src/TensorFlowNET.Core/Eager/Tape.cs
  2. +13
    -1
      src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs
  3. +13
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  4. +4
    -3
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +3
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +20
    -2
      src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs

+ 27
- 0
src/TensorFlowNET.Core/Eager/Tape.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Eager
{
public class Tape
{
public static bool IsDtypeTrainable(DataType dtype)
{
switch (dtype)
{
case DataType.DtHalf:
case DataType.DtBfloat16:
case DataType.DtFloat:
case DataType.DtDouble:
case DataType.DtComplex64:
case DataType.DtComplex128:
case DataType.DtResource:
case DataType.DtVariant:
return true;
default:
return false;
}
}
}
}

+ 13
- 1
src/TensorFlowNET.Core/Eager/pywrap_tfe_src.cs View File

@@ -12,7 +12,19 @@ namespace Tensorflow.Eager
{
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
{
var input_ids = inputs.Select(x => x.Id).ToArray();
var input_dtypes = inputs.Select(x => x.dtype).ToArray();

bool should_record = false;
foreach (var input_dtype in input_dtypes)
{
if (Tape.IsDtypeTrainable(input_dtype.as_datatype_enum()))
{
should_record = true;
break;
}
}
if (!should_record) return;
}
}
}

+ 13
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

@@ -126,11 +127,11 @@ namespace Tensorflow
Graph._add_op(this);
}

public object get_attr(string name)
public object get_attr<T>(string name)
{
AttrValue x = null;

var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" };
var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" };

using (var buf = new Buffer())
{
@@ -141,12 +142,21 @@ namespace Tensorflow

switch (name)
{
case "T":
case "dtype":
return x.Type;
case "shape":
return x.Shape;
default:
throw new NotImplementedException($"{name}");
switch (typeof(T).Name)
{
case "Boolean":
return x.B;
case "String":
return x.S;
default:
throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
}
}
}



+ 4
- 3
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -21,12 +21,13 @@ namespace Tensorflow
var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords);
var _result = _op.outputs;
var _inputs_flat = _op.inputs;
var _attrs = new Dictionary<string, object>();

_attrs["dtype"] = _op.get_attr("dtype");
_attrs["shape"] = _op.get_attr("shape");
var _attrs = new Dictionary<string, object>();
_attrs["dtype"] = _op.get_attr<DataType>("dtype");
_attrs["shape"] = _op.get_attr<int[]>("shape");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);

return new Tensor(_op, 0, dtype);
}



+ 3
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -16,6 +16,9 @@ namespace Tensorflow
{
private readonly IntPtr _handle;

private int _id;
public int Id => _id;

public Graph Graph => op.Graph;
public Operation op { get; }



+ 20
- 2
src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs View File

@@ -2,12 +2,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow
{
public class gen_state_ops
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
public static Execute _execute = new Execute();

/// <summary>
/// Holds state in the form of a tensor that persists across steps.
@@ -32,6 +34,14 @@ namespace Tensorflow
var _result = _op.outputs;
var _inputs_flat = _op.inputs;

var _attrs = new Dictionary<string, object>();
_attrs["dtype"] = _op.get_attr<DataType>("dtype");
_attrs["shape"] = _op.get_attr<int[]>("shape");
_attrs["container"] = _op.get_attr<string>("container");
_attrs["shared_name"] = _op.get_attr<string>("shared_name");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);

return new Tensor(_op, 0, dtype);
}

@@ -56,9 +66,17 @@ namespace Tensorflow

var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords);

var _result = _op.outputs[0];
var _result = _op.outputs;
var _inputs_flat = _op.inputs;
return _result;

var _attrs = new Dictionary<string, object>();
_attrs["T"] = _op.get_attr<DataType>("T");
_attrs["validate_shape"] = _op.get_attr<bool>("validate_shape");
_attrs["use_locking"] = _op.get_attr<bool>("use_locking");

_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name);

return _result[0];
}
}
}

Loading…
Cancel
Save