@@ -0,0 +1,27 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Eager | |||||
{ | |||||
/// <summary>
/// Hosts the static check that classifies a <see cref="DataType"/> as trainable or not.
/// </summary>
public class Tape
{
    /// <summary>
    /// Reports whether <paramref name="dtype"/> belongs to the set of data types
    /// this class treats as trainable.
    /// </summary>
    /// <param name="dtype">The data type to classify.</param>
    /// <returns>
    /// <c>true</c> for half, bfloat16, float, double, complex64, complex128,
    /// resource and variant; <c>false</c> for every other value.
    /// </returns>
    public static bool IsDtypeTrainable(DataType dtype)
    {
        // Floating-point and complex types first, then the two handle-like types.
        return dtype == DataType.DtHalf
            || dtype == DataType.DtBfloat16
            || dtype == DataType.DtFloat
            || dtype == DataType.DtDouble
            || dtype == DataType.DtComplex64
            || dtype == DataType.DtComplex128
            || dtype == DataType.DtResource
            || dtype == DataType.DtVariant;
    }
}
} |
@@ -12,7 +12,19 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
/// <summary>
/// Inspects the inputs of an executed op and checks whether at least one of them
/// has a dtype that <c>Tape.IsDtypeTrainable</c> accepts. The method returns right
/// after that check in both cases; no recording logic follows it here.
/// </summary>
/// <param name="op_name">Name of the op; not read by this method.</param>
/// <param name="inputs">Input tensors whose ids and dtypes are collected.</param>
/// <param name="attrs">Op attributes; not read by this method.</param>
/// <param name="results">Op outputs; not read by this method.</param>
/// <param name="name">Optional name; not read by this method.</param>
public static void RecordGradient(string op_name, Tensor[] inputs, Dictionary<string, object> attrs, Tensor[] results, string name = "")
{
    // Both projections are materialised up front for every input.
    // The ids array is not consumed anywhere in this method.
    var ids = inputs.Select(t => t.Id).ToArray();
    var dtypes = inputs.Select(t => t.dtype).ToArray();

    // Any() stops at the first trainable dtype, mirroring a flag-and-break loop.
    bool anyTrainable = dtypes.Any(d => Tape.IsDtypeTrainable(d.as_datatype_enum()));
    if (!anyTrainable)
    {
        return;
    }
}
} | } | ||||
} | } |
@@ -1,5 +1,6 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Text; | using System.Text; | ||||
@@ -126,11 +127,11 @@ namespace Tensorflow | |||||
Graph._add_op(this); | Graph._add_op(this); | ||||
} | } | ||||
public object get_attr(string name) | |||||
public object get_attr<T>(string name) | |||||
{ | { | ||||
AttrValue x = null; | AttrValue x = null; | ||||
var fields = new string[] { "s", "i", "f", "b", "type", "shape", "tensor", "func" }; | |||||
var fields = new string[] { "s", "i", "f", "b", "Type", "Shape", "Tensor", "func" }; | |||||
using (var buf = new Buffer()) | using (var buf = new Buffer()) | ||||
{ | { | ||||
@@ -141,12 +142,21 @@ namespace Tensorflow | |||||
switch (name) | switch (name) | ||||
{ | { | ||||
case "T": | |||||
case "dtype": | case "dtype": | ||||
return x.Type; | return x.Type; | ||||
case "shape": | case "shape": | ||||
return x.Shape; | return x.Shape; | ||||
default: | default: | ||||
throw new NotImplementedException($"{name}"); | |||||
switch (typeof(T).Name) | |||||
{ | |||||
case "Boolean": | |||||
return x.B; | |||||
case "String": | |||||
return x.S; | |||||
default: | |||||
throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -21,12 +21,13 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Placeholder", keywords: keywords); | ||||
var _result = _op.outputs; | var _result = _op.outputs; | ||||
var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
var _attrs = new Dictionary<string, object>(); | |||||
_attrs["dtype"] = _op.get_attr("dtype"); | |||||
_attrs["shape"] = _op.get_attr("shape"); | |||||
var _attrs = new Dictionary<string, object>(); | |||||
_attrs["dtype"] = _op.get_attr<DataType>("dtype"); | |||||
_attrs["shape"] = _op.get_attr<int[]>("shape"); | |||||
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | _execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | ||||
return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
} | } | ||||
@@ -16,6 +16,9 @@ namespace Tensorflow | |||||
{ | { | ||||
private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
private int _id; | |||||
public int Id => _id; | |||||
public Graph Graph => op.Graph; | public Graph Graph => op.Graph; | ||||
public Operation op { get; } | public Operation op { get; } | ||||
@@ -2,12 +2,14 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Eager; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class gen_state_ops | public class gen_state_ops | ||||
{ | { | ||||
public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
public static Execute _execute = new Execute(); | |||||
/// <summary> | /// <summary> | ||||
/// Holds state in the form of a tensor that persists across steps. | /// Holds state in the form of a tensor that persists across steps. | ||||
@@ -32,6 +34,14 @@ namespace Tensorflow | |||||
var _result = _op.outputs; | var _result = _op.outputs; | ||||
var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
var _attrs = new Dictionary<string, object>(); | |||||
_attrs["dtype"] = _op.get_attr<DataType>("dtype"); | |||||
_attrs["shape"] = _op.get_attr<int[]>("shape"); | |||||
_attrs["container"] = _op.get_attr<string>("container"); | |||||
_attrs["shared_name"] = _op.get_attr<string>("shared_name"); | |||||
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||||
return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
} | } | ||||
@@ -56,9 +66,17 @@ namespace Tensorflow | |||||
var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | var _op = _op_def_lib._apply_op_helper("Assign", name: name, keywords: keywords); | ||||
var _result = _op.outputs[0]; | |||||
var _result = _op.outputs; | |||||
var _inputs_flat = _op.inputs; | var _inputs_flat = _op.inputs; | ||||
return _result; | |||||
var _attrs = new Dictionary<string, object>(); | |||||
_attrs["T"] = _op.get_attr<DataType>("T"); | |||||
_attrs["validate_shape"] = _op.get_attr<bool>("validate_shape"); | |||||
_attrs["use_locking"] = _op.get_attr<bool>("use_locking"); | |||||
_execute.record_gradient("Placeholder", _inputs_flat, _attrs, _result, name); | |||||
return _result[0]; | |||||
} | } | ||||
} | } | ||||
} | } |