@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Text;
 using Google.Protobuf;
+using Protobuf.Text;
 using static Tensorflow.Binding;

 namespace Tensorflow.Contexts
@@ -12,18 +12,36 @@ namespace Tensorflow.Eager
             return HasGradientTape();
         }

-        private bool ShouldRecord(Tensor[] inputs)
+        public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors)
         {
-            bool should_record = false;
-            foreach (var tape in tf.GetTapeSet())
+            var tape_set = tf.GetTapeSet();
+            var input_ids = MakeTensorIDList(tensors);
+            var input_dtypes = MakeTensorDtypeList(tensors);
+            bool some_tape_watching = false;
+            if (tape_set is not null && tape_set.Count > 0)
             {
-                if (tape.ShouldRecord(inputs))
+                foreach (var tape in tape_set)
                 {
-                    should_record = true;
-                    break;
+                    if (tape.ShouldRecord(input_ids, input_dtypes))
+                    {
+                        if (tape.Persistent || some_tape_watching)
+                        {
+                            return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
+                        }
+                        some_tape_watching = true;
+                    }
                 }
             }
-            return should_record;
+            // Skip the forward accumulators.
+            if (some_tape_watching)
+            {
+                return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
+            }
+            else
+            {
+                return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE;
+            }
         }
     }
 }

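Note (illustration, not part of the diff): the three `POSSIBLE_GRADIENT_TYPES_*` values mirror TensorFlow's C++ tape logic. No watching tape means no gradient can be requested; a single non-persistent watching tape allows first-order gradients only; a persistent tape or nested tapes make higher-order gradients possible. A minimal self-contained sketch of that decision; `TapeInfo` and the constant values are illustrative stand-ins, not TF.NET API:

```csharp
using System.Collections.Generic;

// Stand-in for a tape that may be persistent and may be watching the inputs.
record TapeInfo(bool Persistent, bool WatchesInputs);

static class PossibleGradientTypes
{
    public const int None = 0, FirstOrder = 1, HigherOrder = 2;

    // Same control flow as TFE_TapeSetPossibleGradientTypes above: a second
    // watching tape, or one persistent watching tape, can replay the backward
    // pass, so gradients of gradients become possible.
    public static int Compute(IEnumerable<TapeInfo> tapes)
    {
        bool someTapeWatching = false;
        foreach (var tape in tapes)
        {
            if (!tape.WatchesInputs)
                continue;
            if (tape.Persistent || someTapeWatching)
                return HigherOrder;
            someTapeWatching = true;
        }
        return someTapeWatching ? FirstOrder : None;
    }
}
```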
@@ -13,7 +13,17 @@ namespace Tensorflow.Eager
             Tensor[] results,
             BackwardFunction backwardFunction = null)
         {
-            bool should_record = ShouldRecord(inputs);
+            var input_ids = MakeTensorIDList(inputs);
+            var input_dtypes = MakeTensorDtypeList(inputs);
+            bool should_record = false;
+            foreach (var tape in tf.GetTapeSet())
+            {
+                if (tape.ShouldRecord(input_ids, input_dtypes))
+                {
+                    should_record = true;
+                    break;
+                }
+            }

             if (!should_record)
             {
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager
                 op_inputs = inputs;*/

             backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results);
-            TapeSetRecordOperation(op_name, inputs, results, backwardFunction);
+            TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction);

             return true;
         }
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager
         {
             return HasGradientTape();
         }
-
-        TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
-        {
-            return tensors.Select(x => x.dtype).ToArray();
-        }
     }
 }

@@ -1,6 +1,8 @@
+using OneOf.Types;
 using System;
 using Tensorflow.Gradients;
 using Tensorflow.Util;
+using static Tensorflow.Binding;

 namespace Tensorflow.Eager
 {
@@ -9,40 +9,183 @@ namespace Tensorflow.Eager
     /// </summary>
     public partial class EagerRunner
     {
+        /// <summary>
+        /// Computes the gradient of the targets with respect to the sources recorded on the tape.
+        /// </summary>
+        /// <param name="tape"></param>
+        /// <param name="target"></param>
+        /// <param name="sources"></param>
+        /// <param name="output_gradients"></param>
+        /// <param name="unconnected_gradients">Determines the value returned if the target and
+        /// sources are unconnected. When 'none', null is returned, whereas when 'zero' a zero
+        /// tensor with the same shape as the source is returned.</param>
+        /// <returns></returns>
+        /// <exception cref="RuntimeError"></exception>
         public Tensor[] TFE_TapeGradient(ITape tape,
             Tensor[] target,
             Tensor[] sources,
-            Tensor[] output_gradients)
+            List<Tensor> output_gradients,
+            Tensor[] sources_raw,
+            string unconnected_gradients)
         {
-            var target_vec = target;
-            var sources_vec = sources;
-            var sources_set = sources_vec;
-            var seq_array = target;
-            var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();
-            for (int i = 0; i < target.Length; ++i)
-            {
-                source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
-            }
-            if (output_gradients != null)
-            {
-                throw new NotImplementedException("");
-            }
-            else
-            {
-                output_gradients = new Tensor[0];
-            }
-            var outgrad_vec = MakeTensorList(output_gradients);
-            return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec);
+            if (!tape.Persistent)
+            {
+                var tape_set = tf.GetTapeSet();
+                if (tape_set.Contains(tape))
+                {
+                    throw new RuntimeError("gradient() cannot be invoked within the " +
+                        "GradientTape context (i.e., while operations are being " +
+                        "recorded). Either move the call to gradient() to be " +
+                        "outside the 'with tf.GradientTape' block, or " +
+                        "use a persistent tape: " +
+                        "'with tf.GradientTape(persistent=true)'");
+                }
+            }
+
+            var target_vec = MakeTensorIDList(target);
+            var sources_vec = MakeTensorIDList(sources);
+            HashSet<long> sources_set = new HashSet<long>(sources_vec);
+            var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();
+            int len = target.Length;
+            for (int i = 0; i < len; i++)
+            {
+                var target_id = target_vec[i];
+                if (sources_set.Contains(target_id))
+                {
+                    var tensor = target[i];
+                    source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor);
+                }
+            }
+
+            List<Tensor> outgrad_vec = new();
+            if (output_gradients is not null)
+            {
+                outgrad_vec = output_gradients.ToList();
+            }
+            var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
+
+            bool unconnected_gradients_zero = unconnected_gradients == "zero";
+            Tensor[] sources_obj = null;
+            if (unconnected_gradients_zero)
+            {
+                sources_obj = MakeTensorList(sources_raw);
+            }
+            if (result.Length > 0)
+            {
+                for (int i = 0; i < result.Length; i++)
+                {
+                    if (result[i] is null && unconnected_gradients_zero)
+                    {
+                        var dtype = sources_obj[i].dtype;
+                        result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike();
+                    }
+                }
+            }
+            return result;
         }

-        Tensor[] MakeTensorList(Tensor[] tensors)
+        Tensor[] MakeTensorList(IEnumerable<Tensor> tensors)
         {
-            return tensors;
+            return tensors.ToArray();
         }
+
+        long[] MakeTensorIDList(Tensor[] tensors)
+        {
+            int len = tensors.Length;
+            long[] ids = new long[len];
+            for (int i = 0; i < len; i++)
+            {
+                var tensor = tensors[i];
+                ids[i] = tensor.Id;
+            }
+            return ids;
+        }
+
+        TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
+        {
+            int len = tensors.Length;
+            TF_DataType[] dtypes = new TF_DataType[len];
+            for (int i = 0; i < len; i++)
+            {
+                var tensor = tensors[i];
+                dtypes[i] = tensor.dtype;
+            }
+            return dtypes;
+        }
+
+        TapeTensor TapeTensorFromTensor(Tensor tensor)
+        {
+            long id = tensor.Id;
+            var dtype = tensor.dtype;
+            if (tensor is EagerTensor)
+            {
+                var handle = tensor.EagerTensorHandle;
+                if (DTypeNeedsHandleData(dtype))
+                {
+                    return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor);
+                }
+                Status status = new();
+                int num_dims = c_api.TFE_TensorHandleNumDims(handle, status);
+                long[] dims = new long[num_dims];
+                for (int i = 0; i < num_dims; i++)
+                {
+                    dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
+                }
+                Shape tensor_shape = new(dims);
+
+                if (status.Code != TF_Code.TF_OK)
+                {
+                    return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null);
+                }
+                else
+                {
+                    return new TapeTensor(id, dtype, tensor_shape);
+                }
+            }
+            var shape_tuple = tensor.shape.dims;
+            if (ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype))
+            {
+                return new TapeTensor(id, dtype, tensor);
+            }
+            long[] l = new long[shape_tuple.Length];
+            for (int i = 0; i < shape_tuple.Length; i++)
+            {
+                // Treat unknown dimensions (-1) as zero when building the shape.
+                l[i] = shape_tuple[i] < 0 ? 0 : shape_tuple[i];
+            }
+            return new TapeTensor(id, dtype, new Shape(l));
+        }
+
+        bool DTypeNeedsHandleData(TF_DataType dtype)
+        {
+            return dtype == dtypes.variant || dtype == dtypes.resource;
+        }
+
+        bool ListContainNone(long[] list)
+        {
+            int len = list.Length;
+            if (len == 0)
+            {
+                return true;
+            }
+            for (int i = 0; i < len; i++)
+            {
+                if (list[i] == -1)
+                {
+                    return true;
+                }
+            }
+            return false;
+        }
     }
 }

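The `unconnected_gradients == "zero"` branch above replaces null results with zeros shaped like the corresponding source. A self-contained sketch of just that post-processing step, with `double[]` standing in for `Tensor` (illustrative, not TF.NET API):

```csharp
static class UnconnectedGradients
{
    // A null gradient means the target and that source are unconnected; under
    // "zero" the caller receives a zero tensor shaped like the source instead.
    public static double[][] FillUnconnected(double[][] grads, double[][] sources, string unconnectedGradients)
    {
        if (unconnectedGradients != "zero")
            return grads; // "none" (the default): leave unconnected entries null
        for (int i = 0; i < grads.Length; i++)
            grads[i] ??= new double[sources[i].Length]; // zeros like the source
        return grads;
    }
}
```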
@@ -7,8 +7,9 @@ namespace Tensorflow.Eager
     public partial class EagerRunner
     {
         void TapeSetRecordBackprop(string op_type,
-            Tensor[] input_tensors,
-            TapeTensor[] output_tensors,
+            TapeTensor[] output_info,
+            long[] input_ids,
+            TF_DataType[] input_dtypes,
             BackwardFunction backward_function)
         {
             if (!CouldBackprop())
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager

             foreach (var tape in tf.GetTapeSet())
             {
-                tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function);
+                tape.RecordOperation(op_type, output_info, input_ids, input_dtypes, backward_function);
             }
         }
     }

@@ -10,18 +10,28 @@ namespace Tensorflow.Eager
         public bool TapeSetRecordOperation(string op_type,
             Tensor[] input_tensors,
             Tensor[] output_tensors,
+            long[] input_ids,
+            TF_DataType[] input_dtypes,
             BackwardFunction backward_function)
         {
-            var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();
+            var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray();

             if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
                 backward_function))
                 return false;

-            TapeSetRecordBackprop(op_type, input_tensors, output_info,
+            TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes,
                 backward_function);

             return true;
         }
+
+        public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
+            Tensor[] input_tensors, BackwardFunction backward_function)
+        {
+            var input_ids = MakeTensorIDList(input_tensors);
+            var input_dtypes = MakeTensorDtypeList(input_tensors);
+            TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes,
+                backward_function);
+        }
     }
 }

@@ -29,7 +29,14 @@ namespace Tensorflow.Eager
         Tensor[] TFE_TapeGradient(ITape tape,
             Tensor[] target,
             Tensor[] sources,
-            Tensor[] output_gradients);
+            List<Tensor> output_gradients,
+            Tensor[] sources_raw,
+            string unconnected_gradients);
+
+        void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
+            Tensor[] input_tensors, BackwardFunction backward_function);
+
+        int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors);

         bool RecordGradient(string op_name,
             Tensor[] inputs,
@@ -18,12 +18,13 @@ namespace Tensorflow.Functions
     public class ConcreteFunction : Trackable
     {
         protected IEnumerable<Tensor> _captured_inputs;
-        internal FuncGraph func_graph;
         protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
         protected Dictionary<string, AttrValue> _attrs;
         protected FunctionSpec _function_spec;
         protected FunctionSpec _pre_initialized_function_spec = null;
         protected EagerDefinedFunction _inference_function;
+        protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new();
+        internal FuncGraph func_graph;
         internal ForwardBackwardCall forward_backward;

         public Tensor[] Inputs => func_graph.Inputs;
         public Tensor[] CapturedInputs => func_graph.external_captures;
@@ -156,6 +157,17 @@ namespace Tensorflow.Functions
         {
             var executing_eagerly = tf.Context.executing_eagerly();
             var default_graph = ops.get_default_graph();
+            // TODO(Rinne): deal with `default_graph.building_function`
+            var tempvv = func_graph.Variables;
+            if (tf.GetTapeSet().Count > 0 || default_graph is FuncGraph)
+            {
+                foreach (var v in this.func_graph.Variables)
+                {
+                    resource_variable_ops.variable_accessed(v);
+                }
+            }
+
             var tensor_inputs = new Tensors();
             foreach (var (i, arg) in enumerate(args))
             {
@@ -223,11 +235,16 @@ namespace Tensorflow.Functions
             {
                 input_tangents = new TangentInfo();
             }
-            if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient())
+            if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
             {
                 if (input_tangents.Indices is not null || executing_eagerly)
                 {
-                    var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
+                    string cache_key = "first_order";
+                    if (!_tape_functions_cache.TryGetValue(cache_key, out var functions))
+                    {
+                        functions = new FirstOrderTapeGradientFunctions(func_graph, false);
+                        _tape_functions_cache[cache_key] = functions;
+                    }
                     return new ForwardBackwardCall(functions, args, tape_watching: true);
                 }
                 else
@@ -241,7 +258,7 @@ namespace Tensorflow.Functions
             }

             // TODO(Rinne): add arg "input_tangents" for ForwardBackwardCall.
-            return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient());
+            return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
         }

         internal void set_variables(IEnumerable<IVariableV1> variables)

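The `_tape_functions_cache` added above avoids rebuilding `FirstOrderTapeGradientFunctions` on every call; the lookup is the standard get-or-create idiom. A generic self-contained version of that idiom (names are illustrative, not TF.NET API):

```csharp
using System;
using System.Collections.Generic;

static class CacheExtensions
{
    // Build the value once per key, reuse it on later calls, mirroring the
    // _tape_functions_cache lookup in ForwardBackwardCall selection above.
    public static TValue GetOrCreate<TKey, TValue>(
        this Dictionary<TKey, TValue> cache, TKey key, Func<TValue> factory)
    {
        if (!cache.TryGetValue(key, out var value))
        {
            value = factory();
            cache[key] = value;
        }
        return value;
    }
}
```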
@@ -124,17 +124,16 @@ namespace Tensorflow.Functions
             // TODO(Rinne): Add arg `CancellationManager`.
             // TODO(Rinne): Check the arg length.
             var function_call_options = tf.Context.FunctionCallOptions;
-            string config;
-            if (function_call_options.config_proto_serialized().Length == 0)
-            {
-                config = function_utils.get_disabled_rewriter_config().ToString();
-            }
-            else
-            {
-                config = function_call_options.config_proto_serialized().ToString();
-            }
+            string config = ""; // TODO(Rinne): revise it. The following code should work, but it does not, for unclear reasons.
+            //if (function_call_options.config_proto_serialized().Length == 0)
+            //{
+            //    config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
+            //}
+            //else
+            //{
+            //    config = function_call_options.config_proto_serialized().ToStringUtf8();
+            //}
             string executor_type = function_call_options.ExecutorType ?? "";

             var executing_eagerly = tf.Context.executing_eagerly();

@@ -14,12 +14,11 @@ namespace Tensorflow.Functions
         }

-        public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
+        public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
+            ForwardAndBackwardFunctions(Tensors inference_args)
         {
-            var outputs = _func_graph.Outputs;
-            (_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs)
-                = BuildFunctionsForOutputs(outputs, inference_args);
-            return _forward_function;
+            var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray();
+            return BuildFunctionsForOutputs(outputs, inference_args);
         }
     }
 }

@@ -14,7 +14,6 @@ namespace Tensorflow
         protected ConcreteFunction _concrete_variable_creation_fn;
         protected bool _autograph;
         protected TracingCompiler _variable_creation_fn;
-        protected bool _has_initialized;
         public string Name { get; set; }

         public Function(Func<Tensor[], Tensor[]> csharp_function,
             string name, bool auto_graph = true)
@@ -22,7 +21,6 @@ namespace Tensorflow
             _csharp_function = csharp_function;
             Name = name;
             _autograph = auto_graph;
-            _has_initialized = false;
         }

         public virtual Tensors Apply(Tensors inputs)
@@ -38,10 +36,11 @@ namespace Tensorflow
         protected virtual Tensors _call(Tensors inputs)
         {
-            if (!_has_initialized)
+            if (_variable_creation_fn is not null)
             {
-                _initialize(inputs);
+                return _variable_creation_fn.Apply(inputs);
             }

+            _initialize(inputs);
             return _concrete_variable_creation_fn.CallFlat(inputs,
                 _concrete_variable_creation_fn.CapturedInputs);
@@ -63,7 +62,6 @@ namespace Tensorflow
             _variable_creation_fn = _compiler(_csharp_function);
             _variable_creation_fn._name = this.Name;
             _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
-            _has_initialized = true;
         }
     }
 }

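The reworked `_call` drops the `_has_initialized` flag: when the tracing compiler already exists it dispatches through it, otherwise it traces once and calls the concrete function. A minimal sketch of that initialize-once dispatch, assuming plain delegates as stand-ins for TF.NET's traced functions:

```csharp
using System;

sealed class LazyTracedFunction
{
    readonly Func<float[], float[]> _body;
    Func<float[], float[]> _traced; // null until the first call

    public LazyTracedFunction(Func<float[], float[]> body) => _body = body;

    public float[] Call(float[] inputs)
    {
        if (_traced is not null)
            return _traced(inputs);  // fast path, like _variable_creation_fn.Apply
        _traced = _body;             // stands in for tracing/initialization
        return _traced(inputs);      // like CallFlat on the concrete function
    }
}
```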
@@ -24,23 +24,40 @@ namespace Tensorflow.Functions
         protected string _INFERENCE_PREFIX = "__inference_";

         protected FuncGraph _func_graph;
-        protected EagerDefinedFunction _forward_function;
+        protected EagerDefinedFunction _forward;
         protected FuncGraph _forward_graph;
+        protected List<int> _forwardprop_input_indices;
         protected List<int> _forwardprop_output_indices;
         protected int _num_forwardprop_outputs;
-        protected ConcreteFunction _backward_function;
+        protected int _num_inference_outputs;
+        protected int _num_outputs;
+        protected int _num_trainable_inference_outputs;
+        protected ConcreteFunction _backward;
         BackwardFunction _backward_function_wrapper;

         public TapeGradientFunctions(FuncGraph func_graph,
             bool need_gradients_for_jvps)
         {
             _func_graph = func_graph;
+            _forward_graph = null;
+            _forward = null;
+            _backward = null;
+            _num_outputs = func_graph.Outputs.Length;
+            _forwardprop_output_indices = null;
+            _num_forwardprop_outputs = 0;
+            _num_inference_outputs = func_graph.Outputs.Length;
+            _num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count();
         }

         public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null)
         {
             // TODO(Rinne): add input_tangents arg.
-            return ForwardAndBackwardFunctions(inference_args);
+            if (_forward is null)
+            {
+                (_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
+                    = ForwardAndBackwardFunctions(inference_args);
+            }
+            return _forward;
         }

         /// <summary>
@@ -51,9 +68,13 @@ namespace Tensorflow.Functions
         public virtual void Record(Tensors flat_outputs, Tensors inference_args)
         {
             // TODO(Rinne): add arg `input_tangents`.
-            var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs);
-            tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record,
-                getBackwardFunction: backward_function);
+            var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
+            if (_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0)
+            {
+                // TODO(Rinne): implement it.
+                throw new NotImplementedException();
+            }
+            tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function);
         }

         /// <summary>
@@ -65,66 +86,95 @@ namespace Tensorflow.Functions
         /// <returns></returns>
         (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
         {
+            var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs)
+                .ToDictionary(x => x.Item1, x => x.Item2);
+            var captured_inputs = backward.CapturedInputs;
+            var remapped_captures = captured_inputs.Select(c =>
+            {
+                if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value))
+                {
+                    return value;
+                }
+                else
+                {
+                    return c;
+                }
+            }).ToArray();
+            if (remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph))
+            {
+                var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph == forward_graph);
+                throw new RuntimeError($"Failed to map all backward graph captures to " +
+                    $"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}");
+            }
+
+            Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>();
             var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
             var recorded_outputs = new Tensors();
-            var trainable_recorded_outputs = 0;
-            foreach (var (output_index, output) in enumerate(outputs))
+            int trainable_recorded_outputs = 0;
+            var skip_positions = new HashSet<int>();
+            var relevant_outputs = outputs;
+            foreach (var (output_index, output) in enumerate(relevant_outputs))
             {
                 if (trainable_recorded_outputs < backward_function_inputs)
                     recorded_outputs.Add(output);
-                if (gradients_util.IsTrainable(output))
-                    trainable_recorded_outputs += 1;
+                if (backprop_util.IsTrainable(output))
+                    trainable_recorded_outputs++;
+                else
+                    skip_positions.Add(output_index);
+                if (output.dtype == dtypes.variant)
+                    variant_zeros_like[output_index] = default_gradient.zeros_like(output);
             }

-            if (_backward_function_wrapper == null)
+            _backward_function_wrapper = (args, unneeded_gradients) =>
             {
-                var capture_mapping = new Dictionary<long, Tensor>();
-                foreach (var (i, output) in enumerate(outputs))
-                    capture_mapping[forward_graph.Outputs[i].Id] = output;
-
-                var remapped_captures = new Tensors();
-                foreach (var capture in backward.CapturedInputs)
-                {
-                    if (capture_mapping.ContainsKey(capture.Id))
-                        remapped_captures.Add(capture_mapping[capture.Id]);
-                }
-
-                var skip_positions = new List<int>();
-                foreach (var (output_index, output) in enumerate(outputs))
+                if (backward.Outputs is null || backward.Outputs.Length == 0)
                 {
-                    if (!gradients_util.IsTrainable(output))
-                        skip_positions.Add(output_index);
+                    return backward.FlatStructuredOutputs;
                 }

-                _backward_function_wrapper = (args, unneeded_gradients) =>
+                var processed_args = new Tensors();
+                int input_index = 0;
+                foreach (var (output_index, arg) in enumerate(args))
                 {
-                    var processed_args = new Tensors();
-                    var input_index = 0;
-                    foreach (var (output_index, arg) in enumerate(args))
+                    if (skip_positions.Contains(output_index))
+                        continue;
+                    if (arg is null)
+                    {
+                        var input_placeholder = backward.Inputs[input_index];
+                        Tensor variant_arg;
+                        if (input_placeholder.dtype == dtypes.variant)
+                        {
+                            variant_arg = variant_zeros_like[output_index];
+                        }
+                        else
+                        {
+                            var (shape, type) = default_gradient.shape_and_dtype(input_placeholder);
+                            variant_arg = array_ops.zeros(shape, type);
+                        }
+                        processed_args.Add(variant_arg);
+                    }
+                    else
                     {
-                        if (skip_positions.Contains(output_index))
-                            continue;
-                        if (arg == null)
-                            throw new NotImplementedException("");
                         processed_args.Add(arg);
-                        input_index += 1;
-                        if (input_index >= backward_function_inputs)
-                            break;
                     }
+                    input_index++;
+                    if (input_index >= backward_function_inputs)
+                        break;
+                }

-                    tf.Logger.Debug($"Invoke backward function: {backward.Name}");
-                    var gradients = backward.CallFlat(processed_args, remapped_captures);
+                tf.Logger.Debug($"Invoke backward function: {backward.Name}");
+                var gradients = backward.CallFlat(processed_args, remapped_captures);

-                    foreach (var unneeded_gradient_index in unneeded_gradients)
-                    {
-                        var index = Convert.ToInt32(unneeded_gradient_index);
-                        if (gradients.Length <= index)
-                            gradients.Insert(index, null);
-                    }
+                foreach (var unneeded_gradient_index in unneeded_gradients)
+                {
+                    var index = Convert.ToInt32(unneeded_gradient_index);
+                    if (gradients.Length <= index)
+                        gradients.Insert(index, null);
+                }

-                    return gradients;
-                };
-            }
+                return gradients;
+            };

             return (_backward_function_wrapper, recorded_outputs);
         }

@@ -143,7 +193,7 @@ namespace Tensorflow.Functions
                 }
             }

-            var backwards_graph = new FuncGraph(_func_graph.Name);
+            var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name));
             backwards_graph.as_default();
             var gradients_wrt_outputs = new List<Tensor>();
             foreach (var output in trainable_outputs)
@@ -153,6 +203,7 @@ namespace Tensorflow.Functions
                 gradients_wrt_outputs.Add(gradient_placeholder);
                 handle_data_util.copy_handle_data(output, gradient_placeholder);
             }
+            // TODO(Rinne): with ops.device(None)
             var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
                 _func_graph.Inputs,
                 grad_ys: gradients_wrt_outputs.ToArray(),
@@ -175,7 +226,8 @@ namespace Tensorflow.Functions
             backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray();
             backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null));

-            var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
+            var (wrapped_forward_function, wrapped_backward_function) =
+                monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
             //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
             //var backward_function_attr = new Dictionary<string, string>();
             //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
@@ -189,10 +241,11 @@ namespace Tensorflow.Functions
             //    _func_graph.Inputs, _func_graph.Outputs,
             //    monomorphic_function_utils._parse_func_attrs(forward_function_attr));

-            return (forward_function, _func_graph, backward_function, null, 0);
+            return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0);
         }

-        public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
+        public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
+            ForwardAndBackwardFunctions(Tensors inference_args)
         {
             throw new NotImplementedException("");
         }
@@ -73,12 +73,12 @@ namespace Tensorflow.Functions
         private static string male_cache_key(Tensor[] inputs)
         {
-            string res = "";
-            foreach (var input in inputs)
-            {
-                res += $"{input.name}_{input.Id}";
-            }
-            return res;
+            //string res = "";
+            //foreach (var input in inputs)
+            //{
+            //    res += $"{input.name}_{input.Id}";
+            //}
+            return inputs.Length.ToString();
         }
     }
 }

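One step in `_wrap_backward_function` worth spelling out is the capture remapping: backward-graph captures that alias forward-graph outputs must be swapped for the live outputs of the actual forward run, keyed by tensor id. A self-contained sketch with `long` ids standing in for tensors (illustrative only):

```csharp
using System.Collections.Generic;
using System.Linq;

static class CaptureRemap
{
    // forwardOutputToLive maps a forward-graph output's id to the id of the
    // tensor produced by the real forward call; captures not in the map
    // (ordinary constants, variables) pass through unchanged.
    public static long[] Remap(long[] capturedIds, Dictionary<long, long> forwardOutputToLive)
        => capturedIds
            .Select(id => forwardOutputToLive.TryGetValue(id, out var live) ? live : id)
            .ToArray();
}
```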
@@ -153,7 +153,7 @@ namespace Tensorflow.Functions
             foreach (var tape in tf.GetTapeSet())
             {
                 tape.RecordOperation(_inference_function.Signature.Name, to_record,
-                    inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function);
+                    inference_args, backward_function);
             }
         }

@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients
         /// Map from tensor to how many references still exist for this tensor in
        /// the tape.
         /// </summary>
-        public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; }
+        public UnorderedMap<long, long> tensor_usage_counts { get; set; }
         /// <summary>
         /// Maps from op ID to how many output tensors of this op still need to have
         /// their gradients computed.
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients
         public BackpropInitialState()
         {
             op_tape = new OpTape();
-            tensor_usage_counts = new UnorderedMap<Tensor, long>();
+            tensor_usage_counts = new UnorderedMap<long, long>();
             op_missing_tensor = new UnorderedMap<long, long>();
         }
     }

@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients
         /// <param name="target"></param>
         /// <param name="source"></param>
         /// <returns></returns>
-        public Tensor gradient(Tensor target, Tensor source)
+        public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null,
+            string unconnected_gradients = null)
         {
+            if (_tape is null)
+            {
+                throw new RuntimeError("A non-persistent GradientTape can only be used to " +
+                    "compute one set of gradients (or jacobians).");
+            }
             ITape tape = stop_recording();

             var results = tf.Runner.TFE_TapeGradient(tape,
                 new[] { target },
                 new[] { source },
-                null);
+                output_gradients,
+                new[] { source },
+                unconnected_gradients);

             return results[0];
         }

-        public Tensor gradient(Tensor target, ResourceVariable source)
+        public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null,
+            string unconnected_gradients = null)
         {
-            var results = gradient(target, new List<IVariableV1> { source });
+            var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients);

             return results[0];
         }

-        public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
+        public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null,
+            string unconnected_gradients = null)
         {
-            var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 });
+            var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients);

             return (results[0], results[1]);
         }

-        public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
+        public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null,
+            string unconnected_gradients = null)
         {
+            if (_tape is null)
+            {
+                throw new RuntimeError("A non-persistent GradientTape can only be used to " +
+                    "compute one set of gradients (or jacobians).");
+            }
             var tape = stop_recording();

             var results = tf.Runner.TFE_TapeGradient(tape,
                 new[] { target },
                 sources.Select(x => x.Handle).ToArray(),
-                null);
+                output_gradients,
+                sources.Select(x => x.Handle).ToArray(),
+                unconnected_gradients);

             if (!tape.Persistent)
             {

@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients
     public interface ITape
     {
         void SetTapeId(int id);
-        bool ShouldRecord(Tensor[] tensors);
+        bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes);
         void StartRecord();
         void StopRecord();
         bool Persistent { get; }
         void RecordOperation(string op_type,
-            Tensor[] input_tensors,
             TapeTensor[] output_tensors,
+            long[] input_tensor_id,
+            TF_DataType[] input_dtypes,
             BackwardFunction backward_function);

-        void VariableAccessed(ResourceVariable variable);
+        void RecordOperation(string op_type,
+            Tensor[] outputs,
+            Tensor[] inputs,
+            BackwardFunction backward_function);
+
+        void VariableAccessed(IVariableV1 variable);

         void Watch(Tensor x);

-        ResourceVariable[] WatchedVariables();
+        IVariableV1[] WatchedVariables();

-        Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
-            Tensor[] source_tensor_ids,
-            UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
-            Tensor[] output_gradients);
+        Tensor[] ComputeGradient(long[] target_tensor_ids,
+            long[] source_tensor_ids,
+            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+            List<Tensor> output_gradients,
+            bool build_default_zeros_grads);
     }
 }

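For context, the new `ShouldRecord(long[], TF_DataType[])` contract asks the tape to decide from ids and dtypes alone, without touching tensors. A self-contained sketch of one plausible implementation; `DType` and the watched set are local stand-ins, not the real `Tape` internals:

```csharp
using System.Collections.Generic;

// Local stand-in for TF_DataType; only trainable dtypes trigger recording.
enum DType { Float, Double, Int32, Resource }

static class TapeSketch
{
    // A tape records an op when at least one input tensor id is watched
    // and that input's dtype is one gradients can flow through.
    public static bool ShouldRecord(HashSet<long> watchedIds, long[] ids, DType[] dtypes)
    {
        for (int i = 0; i < ids.Length; i++)
        {
            bool trainable = dtypes[i] == DType.Float || dtypes[i] == DType.Double;
            if (trainable && watchedIds.Contains(ids[i]))
                return true;
        }
        return false;
    }
}
```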
@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients
     {
         public string op_type { get; set; }
         public TapeTensor[] output_tensor_info { get; set; }
-        public Tensor[] input_tensor_id { get; set; }
+        public long[] input_tensor_id { get; set; }
         public BackwardFunction backward_function { get; set; }

         public override string ToString()
-            => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
+            => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
     }
 }

@@ -2,235 +2,246 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
{ | { | ||||
public partial class Tape | public partial class Tape | ||||
{ | { | ||||
// int kMinAggregateCount = 4; | |||||
// int kMinAggregateBytes = 128 * 1024 * 1024; | |||||
static readonly int kMinAggregateCount = 4; | |||||
static readonly int kMinAggregateBytes = 128 * 1024 * 1024; | |||||
private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap; | |||||
public Tensor[] ComputeGradient(Tensor[] target_tensor_ids, | |||||
Tensor[] source_tensor_ids, | |||||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||||
Tensor[] output_gradients) | |||||
static Tape() | |||||
{ | { | ||||
var sources_set = new UnorderedSet<Tensor>(source_tensor_ids); | |||||
// var gradients_size = new UnorderedMap<Tensor, long>(); | |||||
var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap(); | |||||
var state = PrepareBackprop( | |||||
target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent); | |||||
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||||
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, | |||||
output_gradients, | |||||
tensor_tape_, | |||||
state.op_tape); | |||||
_functionsAcceptingNoneForIndicesMap = new(); | |||||
_functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
_functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
_functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||||
} | |||||
while (!op_stack.empty()) | |||||
public Tensor[] ComputeGradient(long[] target_tensor_ids, | |||||
long[] source_tensor_ids, | |||||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||||
List<Tensor> output_gradients, | |||||
bool build_default_zeros_grads) | |||||
{ | |||||
UnorderedSet<long> sources_set = new(source_tensor_ids); | |||||
BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); | |||||
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||||
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); | |||||
UnorderedMap<long, long> gradients_size = new(); | |||||
while(op_stack.Count > 0) | |||||
{ | { | ||||
var op = op_stack.Dequeue(); | |||||
if (!state.op_tape.find(op, out var trace)) | |||||
long op = op_stack.Dequeue(); | |||||
if(!state.op_tape.TryGetValue(op, out var op_it)) | |||||
{ | |||||
continue; | continue; | ||||
// Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}"); | |||||
} | |||||
var trace = op_it; | |||||
state.op_tape.erase(op); | state.op_tape.erase(op); | ||||
var out_gradients = new List<Tensor>(trace.output_tensor_info.Length); | |||||
var unneeded_gradients = new List<long>(); | |||||
for (int i = 0; i < trace.input_tensor_id.Length; i++) | |||||
List<Tensor> out_gradients = new(); | |||||
List<long> unneeded_gradients = new(); | |||||
for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) | |||||
{ | { | ||||
var in_tensor_id = trace.input_tensor_id[i]; | |||||
if (!tensor_tape_.find(in_tensor_id) && | |||||
!sources_set.find(in_tensor_id)) | |||||
long in_tensor_id = trace.input_tensor_id[i]; | |||||
if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) | |||||
{ | |||||
unneeded_gradients.Add(i); | unneeded_gradients.Add(i); | ||||
} | |||||
} | } | ||||
bool any_gradient_nonzero = false; | bool any_gradient_nonzero = false; | ||||
var zero_indices = new List<int>(); | |||||
for (int i = 0; i < trace.output_tensor_info.Length; ++i) | |||||
List<int> zero_indices = new(); | |||||
for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) | |||||
{ | { | ||||
var id = trace.output_tensor_info[i].GetTensor(); | |||||
if (!gradients.find(id, out var grad_it)) | |||||
long id = trace.output_tensor_info[i].GetID(); | |||||
if(!gradients.TryGetValue(id, out var grad_it)) | |||||
{ | { | ||||
if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) && | |||||
func_name_it.find(i)) | |||||
out_gradients.Add(null); | |||||
if (build_default_zeros_grads) | |||||
{ | { | ||||
out_gradients.Add(null); | |||||
} | |||||
else | |||||
{ | |||||
out_gradients.Add(null); | |||||
zero_indices.Add(i); | |||||
if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || | |||||
!func_name_it.find(i)) | |||||
{ | |||||
zero_indices.Add(i); | |||||
} | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
any_gradient_nonzero = true; | any_gradient_nonzero = true; | ||||
var new_gradients = grad_it.Count == 1 ? | |||||
grad_it[0] : | |||||
gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients | |||||
Tensor new_gradients; | |||||
if (grad_it.Count == 1) | |||||
{ | |||||
new_gradients = grad_it[0]; | |||||
} | |||||
else | |||||
{ | |||||
new_gradients = AggregateGradients(grad_it); | |||||
} | |||||
if (!sources_set.find(id)) | if (!sources_set.find(id)) | ||||
{ | |||||
gradients.Remove(id); | gradients.Remove(id); | ||||
} | |||||
else | else | ||||
{ | { | ||||
// grad_it.Clear(); | |||||
// grad_it.Add(new_gradients); | |||||
// vspace.MarkAsResult(new_gradients); | |||||
grad_it.Clear(); | |||||
grad_it.Add(new_gradients); | |||||
// MarkAsResult | |||||
} | } | ||||
out_gradients.Add(new_gradients); | out_gradients.Add(new_gradients); | ||||
} | } | ||||
} | } | ||||
Tensor[] in_gradients; | |||||
Tensor[] in_gradients = new Tensor[0]; | |||||
if (any_gradient_nonzero) | if (any_gradient_nonzero) | ||||
{ | { | ||||
// foreach (var i in zero_indices) | |||||
// out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||||
in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray()); | |||||
if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length) | |||||
throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}"); | |||||
if (!_persistent) | |||||
foreach(var i in zero_indices) | |||||
{ | { | ||||
// trace.backward_function_deleter(trace.backward_function); | |||||
trace.backward_function = null; | |||||
out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||||
} | } | ||||
in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
in_gradients = new Tensor[trace.input_tensor_id.Length]; | |||||
out_gradients.Clear(); | |||||
} | } | ||||
bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length; | |||||
for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k) | |||||
for(int i = 0, end = in_gradients.Length; i < end; i++) | |||||
{ | { | ||||
if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k; | |||||
var id = trace.input_tensor_id[k]; | |||||
if (in_gradients[i] != null) | |||||
long id = trace.input_tensor_id[i]; | |||||
if (in_gradients[i] is not null) | |||||
{ | { | ||||
var unaggregated_grads = gradients[id]; | |||||
var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>()); | |||||
unaggregated_grads.Add(in_gradients[i]); | unaggregated_grads.Add(in_gradients[i]); | ||||
/*if (unaggregated_grads.Count > kMinAggregateCount) | |||||
if(unaggregated_grads.Count > kMinAggregateCount) | |||||
{ | { | ||||
if (!gradients_size.find(id, out var size)) | |||||
if(!gradients_size.TryGetValue(id, out var size)) | |||||
{ | { | ||||
size = (long)unaggregated_grads[0].size; | |||||
size = NumElements(unaggregated_grads[0]); | |||||
gradients_size.emplace(id, size); | gradients_size.emplace(id, size); | ||||
} | } | ||||
if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||||
if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||||
{ | { | ||||
throw new NotImplementedException(""); | |||||
Tensor grad = AggregateGradients(unaggregated_grads); | |||||
unaggregated_grads.Clear(); | |||||
unaggregated_grads.Add(grad); | |||||
} | } | ||||
}*/ | |||||
} | |||||
} | } | ||||
if (!state.tensor_usage_counts.find(id)) | |||||
if(!state.tensor_usage_counts.find(id)) | |||||
{ | |||||
continue; | continue; | ||||
} | |||||
state.tensor_usage_counts[id]--; | state.tensor_usage_counts[id]--; | ||||
if (state.tensor_usage_counts[id] > 0) | |||||
if(state.tensor_usage_counts[id] > 0) | |||||
{ | |||||
continue; | continue; | ||||
if (!tensor_tape_.find(id, out var tape_it)) | |||||
} | |||||
if (!tensor_tape_.TryGetValue(id, out var tape_it)) | |||||
{ | { | ||||
if (gradients.find(id, out var grad_it)) | |||||
if (gradients.find(id)) | |||||
{ | { | ||||
// foreach (var g in grad_it) | |||||
// DeleteGradient(g); | |||||
gradients.erase(id); | gradients.erase(id); | ||||
} | } | ||||
continue; | continue; | ||||
} | } | ||||
var op_id = tape_it; | |||||
if (op_id == -1) | |||||
long op_id = tape_it; | |||||
if(op_id == -1) | |||||
{ | |||||
continue; | continue; | ||||
if (state.op_missing_tensor.find(op_id, out var missing_it)) | |||||
} | |||||
if(state.op_missing_tensor.find(op_id)) | |||||
{ | { | ||||
state.op_missing_tensor[op_id]--; | state.op_missing_tensor[op_id]--; | ||||
if (state.op_missing_tensor[op_id] == 0) | |||||
if(state.op_missing_tensor[op_id] == 0) | |||||
{ | |||||
op_stack.Enqueue(op_id); | op_stack.Enqueue(op_id); | ||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
if (state.op_tape.Count > 0) | |||||
if(state.op_tape.Count > 0) | |||||
{ | |||||
throw new RuntimeError("Invalid tape state."); | throw new RuntimeError("Invalid tape state."); | ||||
var result = new Tensor[source_tensor_ids.Length]; | |||||
var j = 0; | |||||
foreach (var id in source_tensor_ids) | |||||
} | |||||
Tensor[] result = new Tensor[source_tensor_ids.Length]; | |||||
for(int i = 0; i < source_tensor_ids.Length; i++) | |||||
{ | { | ||||
if (gradients.find(id, out var grad_it)) | |||||
long tensor_id = source_tensor_ids[i]; | |||||
if(!gradients.TryGetValue(tensor_id, out var grad_it)) | |||||
{ | { | ||||
if (grad_it.Count > 1) | |||||
result[j] = gen_math_ops.add_n(grad_it.ToArray()); | |||||
else | |||||
result[j] = grad_it[0]; | |||||
result[i] = null; | |||||
} | |||||
else | |||||
{ | |||||
if(grad_it.Count > 1) | |||||
{ | |||||
Tensor grad = AggregateGradients(grad_it); | |||||
grad_it.Clear(); | |||||
grad_it.Add(grad); | |||||
} | |||||
result[i] = grad_it[0]; | |||||
} | } | ||||
j++; | |||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | ||||
{ | { | ||||
var m = new UnorderedMap<string, UnorderedSet<int>>(); | |||||
m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||||
m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||||
return m; | |||||
return _functionsAcceptingNoneForIndicesMap; | |||||
} | } | ||||
UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids, | |||||
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets, | |||||
Tensor[] output_gradients, | |||||
UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||||
List<Tensor> output_gradients, | |||||
TensorTape tensor_tape, | TensorTape tensor_tape, | ||||
OpTape op_tape) | OpTape op_tape) | ||||
{ | { | ||||
var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>(); | |||||
for (int i = 0; i < target_tensor_ids.Length; ++i) | |||||
var result = new UnorderedMap<long, List<Tensor>>(); | |||||
for(int i = 0, end = target_tensor_ids.Length; i < end; i++) | |||||
{ | { | ||||
var id = target_tensor_ids[i]; | |||||
if (output_gradients.Length == 0 || output_gradients[i] == null) | |||||
long id = target_tensor_ids[i]; | |||||
if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null) | |||||
{ | { | ||||
if (tensor_tape.find(id, out var tensor_id) && tensor_id != null) | |||||
if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) | |||||
{ | { | ||||
if (!op_tape.find(tensor_tape[id], out var op_it)) | |||||
if(!op_tape.TryGetValue(tensor_it, out var op_it)) | |||||
{ | |||||
throw new RuntimeError("Internal state of the gradient tape is invalid: " + | throw new RuntimeError("Internal state of the gradient tape is invalid: " + | ||||
"failed to find operation producing a tensor"); | |||||
"failed to find operation producing a tensor."); | |||||
} | |||||
bool found = false; | bool found = false; | ||||
for (int j = 0; j < op_it.output_tensor_info.Length; ++j) | |||||
for(int j = 0; j < op_it.output_tensor_info.Length; j++) | |||||
{ | { | ||||
if (op_it.output_tensor_info[j].GetTensor() == id) | |||||
if (op_it.output_tensor_info[j].GetID() == id) | |||||
{ | { | ||||
found = true; | found = true; | ||||
var ones = op_it.output_tensor_info[j].OnesLike(); | |||||
result[id].Add(ones); | |||||
Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); | |||||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||||
break; | break; | ||||
} | } | ||||
} | } | ||||
if (!found) | if (!found) | ||||
{ | { | ||||
throw new ValueError("Internal state of the gradient tape is invalid: " + | |||||
"none of operations outputs match expected tensor"); | |||||
throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||||
"none of operations outputs match expected tensor."); | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
if (sources_that_are_targets.find(id, out var source_tensor)) | |||||
result[id].Add(source_tensor.OnesLike()); | |||||
if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) | |||||
{ | |||||
Tensor ones_like = BuildOnesLike(source_tensor); | |||||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||||
} | |||||
} | } | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
result[id].Add(output_gradients[i]); | |||||
result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]); | |||||
} | } | ||||
} | } | ||||
@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients | |||||
} | } | ||||
return result; | return result; | ||||
} | } | ||||
+     Tensor BuildOnesLike(TapeTensor t)
+     {
+         return t.OnesLike();
+     }
+     Tensor AggregateGradients(List<Tensor> gradient_tensors)
+     {
+         if (gradient_tensors.Count == 1)
+         {
+             // A single incoming gradient needs no reduction.
+             return gradient_tensors[0];
+         }
+         return tf.add_n(gradient_tensors.ToArray());
+     }
+     void DeleteGradient(Tensor gradient)
+     {
+         // Intentionally empty: the GC reclaims the gradient once nothing references it.
+     }
+     // Simplified placeholder; upstream uses the true element count when choosing how to aggregate.
+     long NumElements(Tensor tensor) => 1;
  }
}
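For reference, a minimal sketch of how these helpers compose (the ids and shapes are hypothetical): each target id is seeded with a ones tensor unless the caller supplied an explicit output gradient, and an add_n is only paid for when more than one gradient accumulates under the same id.

var grads = new UnorderedMap<long, List<Tensor>>();
long id = 42;                                                             // hypothetical tensor id
grads.SetDefault(id, new List<Tensor>()).Add(tf.ones(new Shape(2, 3)));   // seed gradient
grads.SetDefault(id, new List<Tensor>()).Add(tf.ones(new Shape(2, 3)));   // a second incoming gradient
var total = grads[id].Count == 1
    ? grads[id][0]                       // single gradient: returned as-is
    : tf.add_n(grads[id].ToArray());     // several gradients: summed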
@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients
  {
      public partial class Tape
      {
-         public BackpropInitialState PrepareBackprop(Tensor[] target,
+         public BackpropInitialState PrepareBackprop(long[] target,
              TensorTape tensor_tape,
              OpTape op_tape,
-             UnorderedSet<Tensor> sources_set,
+             UnorderedSet<long> sources_set,
              bool persistent_tape)
          {
+             Stack<long> tensor_stack = new Stack<long>();
+             foreach (var t in target)
+             {
+                 tensor_stack.Push(t);
+             }
              BackpropInitialState result = new BackpropInitialState();
-             var tensor_stack = new Queue<Tensor>(target);
              while (tensor_stack.Count > 0)
              {
-                 var tensor_id = tensor_stack.Dequeue();
-                 if (!tensor_tape.find(tensor_id, out var op_id))
+                 long tensor_id = tensor_stack.Pop();
+                 if (!tensor_tape.TryGetValue(tensor_id, out var op_id))
+                 {
                      continue;
-                 if (op_id == -1 ||
-                     !op_tape.find(op_id, out var op_it) ||
-                     result.op_tape.find(op_id, out var result_op_it))
+                 }
+                 if (op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it)
+                     || result.op_tape.find(op_id))
+                 {
                      continue;
+                 }
                  result.op_tape.emplace(op_id, op_it);
                  foreach (var it in op_it.input_tensor_id)
                  {
                      if (result.tensor_usage_counts.find(it))
+                     {
                          result.tensor_usage_counts[it]++;
+                     }
                      else
                      {
                          result.tensor_usage_counts[it] = 1;
                          if (tensor_tape.find(it))
-                             tensor_stack.Enqueue(it);
+                         {
+                             tensor_stack.Push(it);
+                         }
                      }
                  }
                  if (!persistent_tape)
-                     op_tape.Remove(op_id);
+                 {
+                     op_tape.erase(op_id);
+                 }
              }
              foreach (var pair in result.tensor_usage_counts)
              {
-                 if (tensor_tape.find(pair.Key, out var it) && it != -1)
-                     result.op_missing_tensor[it] += 1;
+                 if (tensor_tape.TryGetValue(pair.Key, out var it) && it != -1)
+                 {
+                     result.op_missing_tensor[it]++;
+                 }
              }
              if (!persistent_tape)
              {
-                 // Call destructors for all unneeded gradient functions and
-                 // clear the op_tape. We can clear the tape because ownership of
-                 // backward functions that will be used for gradient computation
-                 // has been transferred to `result`.
-                 /*for (const auto&op_pair : *op_tape) {
-                     op_pair.second.backward_function_deleter(
-                         op_pair.second.backward_function);
-                 }*/
                  op_tape.Clear();
              }
              return result;
          }
      }
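A rough sketch of how a caller might consume the returned state (paraphrased, not the full ComputeGradient implementation; target_ids and sources_set are assumed prepared by the caller): an op becomes ready once no other recorded op still owes it a gradient.

var state = PrepareBackprop(target_ids, tensor_tape_, op_tape_, sources_set, persistent_tape: false);
var ready_ops = new Queue<long>();
foreach (var pair in state.op_tape)
{
    // No entry in op_missing_tensor means the op is not waiting on any gradient.
    if (!state.op_missing_tensor.find(pair.Key))
        ready_ops.Enqueue(pair.Key);
}
while (ready_ops.Count > 0)
{
    long op_id = ready_ops.Dequeue();
    var entry = state.op_tape[op_id];
    // Invoke entry.backward_function, then decrement op_missing_tensor for the
    // producers of entry.input_tensor_id, enqueueing any op that reaches zero.
}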
@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients
      public partial class Tape
      {
          long next_op_id_ = 0;
-         UnorderedMap<Tensor, long> tensor_usage_;
+         UnorderedMap<long, long> tensor_usage_;
          public void RecordOperation(string op_type,
-             Tensor[] input_tensors,
              TapeTensor[] output_tensors,
+             long[] input_tensor_id,
+             TF_DataType[] input_dtypes,
              BackwardFunction backward_function)
          {
-             if (!ShouldRecord(input_tensors))
+             if (!ShouldRecord(input_tensor_id, input_dtypes))
                  return;
-             var op_id = next_op_id_++;
-             foreach (var i in input_tensors)
+             foreach (var i in input_tensor_id)
+             {
                  tensor_usage_[i]++;
+             }
+             long op_id = next_op_id_++;
              foreach (var o in output_tensors)
              {
                  tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
-                 tensor_tape_[o.GetTensor()] = op_id;
-                 tensor_usage_[o.GetTensor()] = 1;
+                 tensor_tape_[o.GetID()] = op_id;
+                 tensor_usage_[o.GetID()] = 1;
              }
              op_tape_[op_id] = new OpTapeEntry
              {
                  op_type = op_type,
-                 output_tensor_info = output_tensors,
-                 input_tensor_id = input_tensors,
+                 output_tensor_info = output_tensors.ToArray(),
+                 input_tensor_id = input_tensor_id.ToArray(),
                  backward_function = backward_function
              };
          }
+         public void RecordOperation(string op_type,
+             Tensor[] outputs,
+             Tensor[] inputs,
+             BackwardFunction backward_function)
+         {
+             tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function);
+         }
      }
  }
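Usage sketch for the new convenience overload (the op and its gradient are illustrative, and this assumes BackwardFunction maps upstream output gradients to input gradients, as the eager runner uses it):

Tensor x = tf.constant(3.0f);
Tensor y = tf.multiply(x, x);                        // forward result, y = x * x
BackwardFunction square_grad = (output_grads, unneeded) =>
    new Tensor[] { output_grads[0] * 2.0f * x };     // dy/dx = 2x
tf.Runner.TFE_TapeSetRecordOperation("Square", new[] { y }, new[] { x }, square_grad);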
@@ -1,5 +1,6 @@
  using System;
  using System.Collections.Generic;
+ using System.Diagnostics;
  using System.Linq;
  using Tensorflow.Util;
  using static Tensorflow.Binding;
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients
              _created_eagerly = tf.Context.executing_eagerly();
              tensor_tape_ = new TensorTape();
              op_tape_ = new OpTape();
-             tensor_usage_ = new UnorderedMap<Tensor, long>();
+             tensor_usage_ = new UnorderedMap<long, long>();
              if (_created_eagerly)
                  tf.Context.start_step();
              // nesting_id = ++tape_nesting_id_counter;
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients
          public void Watch(Tensor x)
          {
              tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
-             tensor_tape_.emplace(x, -1);
+             tensor_tape_.emplace(x.Id, -1);
          }
-         public bool ShouldRecord(Tensor[] tensors)
+         public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes)
          {
-             var dtypes = tensors.Select(x => x.dtype).ToArray();
-             for (int i = 0; i < tensors.Length; ++i)
+             Debug.Assert(tensor_ids.Length == tensor_dtypes.Length);
+             for (int i = 0; i < tensor_ids.Length; ++i)
              {
-                 if (tensor_tape_.find(tensors[i]))
+                 if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i]))
                  {
-                     if (IsDtypeTrainable(dtypes[i]))
-                         return true;
+                     return true;
                  }
              }
              return false;
          }
-         public void VariableAccessed(ResourceVariable variable)
+         public void VariableAccessed(IVariableV1 variable)
          {
              Watch(variable.Handle);
          }
-         public ResourceVariable[] WatchedVariables()
+         public IVariableV1[] WatchedVariables()
          {
              return null;
          }
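A small sketch of the id-based contract (the Tape constructor shown is assumed; obtain the active tape however the runtime exposes it): recording happens only when an input is both watched and of a trainable dtype.

var tape = new Tape(persistent: false, watch_accessed_variables: true);  // assumed ctor
var x = tf.constant(1.0f);
tape.Watch(x);                                       // tensor_tape_[x.Id] = -1
bool record = tape.ShouldRecord(
    new[] { x.Id }, new[] { TF_DataType.TF_FLOAT }); // true: watched, and float is trainable
bool skip = tape.ShouldRecord(
    new[] { x.Id }, new[] { TF_DataType.TF_INT32 }); // false: int32 is not trainable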
@@ -1,27 +1,63 @@
- using static Tensorflow.Binding;
+ using OneOf;
+ using static Tensorflow.Binding;
  namespace Tensorflow.Gradients
  {
      public class TapeTensor
      {
-         Tensor tensor;
-         long id => tensor.Id;
-         TF_DataType dtype => tensor.dtype;
-         Shape shape => tensor.shape;
+         internal Tensor tensor;
+         internal long id;
+         internal TF_DataType dtype;
+         internal OneOf<Shape, Tensor> shape;
+         public TapeTensor(long id, TF_DataType dtype, Shape shape)
+         {
+             this.id = id;
+             this.dtype = dtype;
+             this.shape = shape;
+         }
+         public TapeTensor(long id, TF_DataType dtype, Tensor shape)
+         {
+             this.id = id;
+             this.dtype = dtype;
+             this.shape = shape;
+         }
          public TapeTensor(Tensor tensor)
          {
+             this.id = tensor.Id;
+             this.dtype = tensor.dtype;
+             this.shape = tensor.shape;
              this.tensor = tensor;
          }
-         public long GetID() => tensor.Id;
-         public Tensor GetTensor() => tensor;
+         public long GetID() => id;
          public Tensor ZerosLike()
-             => tf.zeros(shape: shape, dtype: dtype);
+         {
+             if (dtype == dtypes.resource)
+             {
+                 return null;
+             }
+             if (shape.Index == 1)
+             {
+                 return tf.zeros_like(shape.AsT1);
+             }
+             return tf.zeros(shape.AsT0, dtype);
+         }
          public Tensor OnesLike()
-             => tf.ones(shape: shape, dtype: dtype);
+         {
+             if (shape.Index == 1)
+             {
+                 return tf.ones_like(shape.AsT1);
+             }
+             return tf.ones(shape.AsT0, dtype);
+         }
+         //public Tensor OnesLike()
+         //    => tf.ones(shape: shape, dtype: dtype);
          public override string ToString()
              => $"{id}, {shape}, {dtype.as_numpy_name()}";
@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients
      /// produced this tensor. A value of -1 means that the tensor was directly
      /// watched and not the result of any operation in the tape.
      /// </summary>
-     public class TensorTape : UnorderedMap<Tensor, long>
+     public class TensorTape : UnorderedMap<long, long>
      {
      }
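Concretely (ids hypothetical), the retyped map records, per tensor id, which recorded op produced it:

var tensor_tape = new TensorTape();
tensor_tape.emplace(101L, -1);    // id 101: watched directly, no producing op
tensor_tape[102L] = 5;            // id 102: produced by the entry at op_tape_[5]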
@@ -704,32 +704,7 @@ namespace Tensorflow
          public static int PossibleTapeGradientTypes(Tensor[] tensors)
          {
-             var tape_set = tf.GetTapeSet();
-             bool some_tape_watching = false;
-             if (tape_set is not null && tape_set.Count > 0)
-             {
-                 foreach (var tape in tape_set)
-                 {
-                     if (tape.ShouldRecord(tensors))
-                     {
-                         if (tape.Persistent || some_tape_watching)
-                         {
-                             return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
-                         }
-                         some_tape_watching = true;
-                     }
-                 }
-             }
-             // skip the forward_accumulators.
-             if (some_tape_watching)
-             {
-                 return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
-             }
-             else
-             {
-                 return POSSIBLE_GRADIENT_TYPES_NONE;
-             }
+             return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors);
          }
          /// <summary>
@@ -215,6 +215,16 @@ public class FuncGraph : Graph, IDisposable
          return tensor;
      }
+     public void watch_variable(IVariableV1 v)
+     {
+         if (_resource_tensor_inputs.Contains(v.Handle))
+         {
+             return;
+         }
+         _watched_variables.Add(new WeakReference<IVariableV1>(v));
+         // TODO: also walk the outer graphs, as the Python implementation does
+         // (it loops with `self = self.outer_graph`).
+     }
      Tensor capture_eager_tensor(Tensor tensor, string name)
      {
          Tensor graph_const = null;
@@ -4,10 +4,10 @@ public interface IOptimizer
  {
      Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars);
      Tensor[] clip_gradients(Tensor[] grads);
-     void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
+     void apply_gradients((Tensor, IVariableV1) grads_and_vars,
          string name = null,
          bool experimental_aggregate_gradients = true);
-     void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+     void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
          string name = null,
          bool experimental_aggregate_gradients = true);
  }
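With the signatures widened to IVariableV1, a training loop no longer needs a ResourceVariable down-cast; a sketch (model and gradients assumed in scope):

var optimizer = tf.keras.optimizers.Adam();
var grads_and_vars = zip(gradients, model.TrainableVariables);   // IEnumerable<(Tensor, IVariableV1)>
optimizer.apply_gradients(grads_and_vars, experimental_aggregate_gradients: false);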
@@ -208,9 +208,9 @@ namespace Tensorflow
      [DllImport(TensorFlowLibName)]
      public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);
-     //[DllImport(TensorFlowLibName)]
-     //public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
-     //[DllImport(TensorFlowLibName)]
-     //public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data);
+     [DllImport(TensorFlowLibName)]
+     public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
+     [DllImport(TensorFlowLibName)]
+     public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
  }
}
@@ -39,7 +39,7 @@ namespace Tensorflow
          if (config is null)
          {
-             config = function_utils.get_disabled_rewriter_config().ToString();
+             config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
          }
          if (executor_type is null)
@@ -49,6 +49,8 @@ namespace Tensorflow
          if (executing_eagerly)
          {
+             // TODO(Rinne): implement it.
              throw new NotImplementedException();
          }
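The ToStringUtf8 switch matters because a protobuf ByteString does not render its payload through ToString; a sketch of the difference (assuming get_disabled_rewriter_config returns a ByteString):

var config_bytes = function_utils.get_disabled_rewriter_config();  // ByteString
string debug_text = config_bytes.ToString();      // object text, not the config payload
string payload = config_bytes.ToStringUtf8();     // the serialized RewriterConfig bytes as a string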
@@ -17,6 +17,7 @@
  using System;
  using System.Linq;
  using Tensorflow.Contexts;
+ using Tensorflow.Eager;
  using static Tensorflow.Binding;
  namespace Tensorflow
@@ -210,7 +211,51 @@ namespace Tensorflow
          /// <param name="name">A name for the operation (optional).</param>
          /// <returns>A `Tensor`. Has the same type as `value`.</returns>
          public static Tensor fill<T>(Tensor dims, T value, string name = null)
-             => tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
+         {
+             var ctx = tf.Context;
+             if (ctx.executing_eagerly())
+             {
+                 try
+                 {
+                     var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value));
+                     return _result[0];
+                 }
+                 catch (Exception)
+                 {
+                     // Fast path failed; fall through to the eager fallback.
+                 }
+                 try
+                 {
+                     return fill_eager_fallback(dims, value as Tensor, name, ctx);
+                 }
+                 catch (Exception)
+                 {
+                     // Eager fallback failed; fall through to graph-mode construction.
+                 }
+             }
+             Dictionary<string, object> attrs = new Dictionary<string, object>();
+             attrs["dims"] = dims;
+             attrs["value"] = value;
+             var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs);
+             if (execute.must_record_gradient())
+             {
+                 throw new NotImplementedException();
+             }
+             return result.output;
+         }
+         public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx)
+         {
+             // "T" is the dtype of `value`; only "index_type" comes from `dims`.
+             object[] attrs = new object[] { "T", value.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() };
+             var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name);
+             if (execute.must_record_gradient())
+             {
+                 throw new NotImplementedException();
+             }
+             return _result[0];
+         }
+         //=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
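Call-site sketch (the containing class is assumed to be gen_array_ops, matching the upstream layout): the eager fast path is tried first, then the eager fallback, then graph-mode construction.

var dims = tf.constant(new[] { 2, 3 });    // target shape as a 1-D int tensor
var nines = gen_array_ops.fill(dims, 9f);  // 2x3 tensor filled with 9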
          /// <summary>
          /// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
@@ -49,8 +49,10 @@ namespace Tensorflow.Operations
              target_t.HandleData = handle_data;
              return;
          }
-         // TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
-         //c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
+         Status status = new();
+         var proto = handle_data.ToByteArray();
+         c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
+         status.Check(true);
      }
      public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);
@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
  using Tensorflow.Operations;
  using System.Buffers;
  using Tensorflow.Eager;
+ using Tensorflow.Graphs;
  namespace Tensorflow
  {
@@ -302,5 +303,18 @@ namespace Tensorflow
              // return handle_data_util.get_resource_handle_data(handle);
              //}
          }
+         public static void variable_accessed(IVariableV1 variable)
+         {
+             if (ops.get_default_graph() is FuncGraph func_graph)
+             {
+                 func_graph.watch_variable(variable);
+             }
+             if (variable.Trainable)
+             {
+                 foreach (var tape in tf.GetTapeSet())
+                     tape.VariableAccessed(variable);
+             }
+         }
      }
  }
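The effect (variable creation illustrative): touching a trainable variable both registers it on the FuncGraph currently being traced and watches it on every active tape.

var v = tf.Variable(1.0f, trainable: true);
ops.variable_accessed(v);   // func_graph.watch_variable(v) when tracing,
                            // then tape.VariableAccessed(v) on each active tape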
@@ -110,7 +110,7 @@ https://tensorflownet.readthedocs.io</Description>
    <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
    <PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
    <PackageReference Include="OneOf" Version="3.0.223" />
-   <PackageReference Include="Protobuf.Text" Version="0.6.2" />
+   <PackageReference Include="Protobuf.Text" Version="0.7.0" />
    <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
  </ItemGroup>
@@ -30,7 +30,7 @@ namespace Tensorflow
  {
      public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);
-     public Tensor()
+     protected Tensor()
      {
      }
@@ -108,6 +108,7 @@ namespace Tensorflow
      protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
      {
          _handle = TF_NewTensor(shape, dtype, null);
+         _id = ops.uid();
      }
      protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -116,6 +117,7 @@ namespace Tensorflow
              _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
          else
              _handle = TF_NewTensor(bytes, shape, dtype);
+         _id = ops.uid();
      }
      protected unsafe void InitTensor(Array array, Shape? shape = null)
@@ -166,6 +168,8 @@ namespace Tensorflow
              string[] val => StringTensor(val, shape),
              _ => throw new NotImplementedException("")
          };
+         _id = ops.uid();
      }
      unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged
@@ -462,6 +462,7 @@ namespace Tensorflow.Training.Saving.SavedModel
  {
      IEnumerable<ConcreteFunction> _concrete_functions;
      FunctionSpec _function_spec;
+     public IEnumerable<ConcreteFunction> ConcreteFunctions => _concrete_functions;
      public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec,
          IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false)
      {
@@ -25,6 +25,19 @@ namespace Tensorflow.Util
          }
      }
+     public Tv SetDefault(Tk key, Tv default_value)
+     {
+         if (TryGetValue(key, out var res))
+         {
+             return res;
+         }
+         else
+         {
+             base[key] = default_value;
+             return base[key];
+         }
+     }
      public void push_back(Tk key, Tv value)
          => this[key] = value;
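SetDefault mirrors C++'s insert-or-get and Python's dict.setdefault: it returns the stored value, inserting default_value only on a miss. For example (g1 and g2 are hypothetical gradient tensors):

var bucket = new UnorderedMap<long, List<Tensor>>();
bucket.SetDefault(3L, new List<Tensor>()).Add(g1);   // miss: inserts the list, then appends
bucket.SetDefault(3L, new List<Tensor>()).Add(g2);   // hit: reuses the stored list
// bucket[3L] now holds { g1, g2 }; the second default instance was discarded.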
@@ -9,6 +9,7 @@ using System.Diagnostics;
  using Tensorflow.Checkpoint;
  using Tensorflow.Training.Saving.SavedModel;
  using OneOf;
+ using Tensorflow.Graphs;
  namespace Tensorflow
  {
@@ -193,6 +194,10 @@ namespace Tensorflow
          /// </summary>
          void variable_accessed(BaseResourceVariable variable)
          {
+             if (ops.get_default_graph() is FuncGraph func_graph)
+             {
+                 func_graph.watch_variable(variable as IVariableV1);
+             }
              if (variable.Trainable)
              {
                  foreach (var tape in tf.GetTapeSet())
@@ -575,12 +575,8 @@ namespace Tensorflow
      public static HandleData get_resource_handle_data(Tensor graph_op)
      {
-         throw new NotImplementedException();
-         // This implementation hasn't been checked for some reasons.
-         // If it throws an exception in the future, please check it.
-         //var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
-         //return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
+         var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
+         return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
      }
      public static void dismantle_graph(Graph graph)
@@ -27,6 +27,7 @@ using Tensorflow.Keras.Utils;
  using Tensorflow.NumPy;
  using Tensorflow.Train;
  using Tensorflow.Training;
+ using Tensorflow.Training.Saving.SavedModel;
  using Tensorflow.Util;
  using static Tensorflow.Binding;
@@ -50,7 +51,17 @@ namespace Tensorflow.Keras.Engine
      /// the layer's weights.
      /// </summary>
      protected bool built;
-     public bool Built => built;
+     public bool Built
+     {
+         get
+         {
+             return built;
+         }
+         internal set
+         {
+             built = value;
+         }
+     }
      public bool Trainable => args.Trainable;
      public TF_DataType DType => args.DType;
      public bool AutoCast => args.Autocast;
@@ -179,6 +190,11 @@ namespace Tensorflow.Keras.Engine
      }
      protected List<ILayer> _self_tracked_trackables;
+     /// <summary>
+     /// If set, calling the layer routes directly to this function instead of the built-in Call.
+     /// </summary>
+     public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null;
      public Layer(LayerArgs args)
      {
          Initialize(args);
@@ -259,6 +275,10 @@ namespace Tensorflow.Keras.Engine
      /// <returns></returns>
      protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
      {
+         if (ReplacedCall is not null)
+         {
+             return ReplacedCall(inputs);
+         }
          return inputs;
      }
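A sketch of the hook in use (names illustrative): when a SavedModel restores a traced call function, assigning it here short-circuits the default Call body.

// After loading, route the layer's forward pass through the restored function.
layer.ReplacedCall = inputs => restored_function.Apply(inputs);
var outputs = layer.Apply(some_inputs);   // now executes restored_function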
@@ -35,10 +35,6 @@ namespace Tensorflow.Keras.Engine
      {
          (x, y) = data_handler.DataAdapter.Expand1d(x, y);
          using var tape = tf.GradientTape();
-         //foreach (var variable in TrainableVariables)
-         //{
-         //    tape.watch(variable.Handle);
-         //}
          var y_pred = Apply(x, training: true);
          var loss = compiled_loss.Call(y, y_pred);
@@ -70,7 +66,7 @@ namespace Tensorflow.Keras.Engine
          gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables));
          gradients = optimizer.clip_gradients(gradients);
-         optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)),
+         optimizer.apply_gradients(zip(gradients, trainable_variables),
              experimental_aggregate_gradients: false);
      }
  }
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Optimizers
          _set_hyper("decay", args.InitialDecay);
      }
-     public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
+     public void apply_gradients((Tensor, IVariableV1) grads_and_vars,
          string name = null,
          bool experimental_aggregate_gradients = true)
          => apply_gradients(new[] { grads_and_vars },
@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Optimizers
      /// <param name="grads_and_vars"></param>
      /// <param name="name"></param>
      /// <param name="experimental_aggregate_gradients"></param>
-     public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+     public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
          string name = null,
          bool experimental_aggregate_gradients = true)
      {
@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Optimizers
          });
      }
-     void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
+     void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
      {
          _resource_apply_dense(var, grad, apply_state);
          // if var.constraint is not None:
@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Optimizers
          throw new NotImplementedException("_resource_apply_dense");
      }
-     void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
+     void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
          string name,
          Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
      {
@@ -255,6 +255,25 @@ namespace Tensorflow.Keras.Saving
      /// <param name="layers"></param>
      private void _finalize_saved_model_layers(List<Layer> layers)
      {
+         foreach (var layer in layers)
+         {
+             layer.Built = true;
+             var keras_attr = _get_keras_attr(layer);
+             if (keras_attr is not Trackable trackable)
+             {
+                 continue;
+             }
+             if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call))
+             {
+                 Debug.Assert(layer_call is RestoredFunction);
+                 var concrete_functions = ((RestoredFunction)layer_call).ConcreteFunctions;
+                 if (concrete_functions is not null && concrete_functions.Count() > 0)
+                 {
+                     layer.ReplacedCall = use_wrapped_call(layer, ((RestoredFunction)layer_call).Apply);
+                 }
+             }
+         }
          foreach (var layer in layers)
          {
              // TODO(Rinne): deal with `RevivedNetwork`.
@@ -265,6 +284,12 @@ namespace Tensorflow.Keras.Saving
          }
      }
+     private Func<Tensors, Tensors> use_wrapped_call(Layer layer, Func<Tensors, Tensors> call)
+     {
+         // TODO(Rinne): revise it.
+         return call;
+     }
      private void _restore_layer_unconditional_losses(Layer layer)
      {
          // TODO(Rinne): implement it.
@@ -85,16 +85,16 @@ namespace Tensorflow.Keras.Saving.SavedModel
          return _config;
      }
-     protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
-     {
-         if (SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
-         {
-             return base.Call(inputs, state, training);
-         }
-         else
-         {
-             return (func as Function).Apply(inputs);
-         }
-     }
+     //protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+     //{
+     //    if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
+     //    {
+     //        return base.Call(inputs, state, training);
+     //    }
+     //    else
+     //    {
+     //        return (func as Function).Apply(inputs);
+     //    }
+     //}
  }
}
@@ -223,7 +223,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
      //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
      //    functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
      base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}),
-         functions.Concat(new string[] { }))
+         functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }))
      {
      }
@@ -64,23 +64,19 @@ public class SequentialModelLoad
      var model = tf.keras.models.load_model(@"Assets/python_func_model");
      model.summary();
-     var x = tf.random.uniform((8, 784), -1, 1);
-     var y = model.Apply(x);
-     Console.WriteLine(y);
+     model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
-     //model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
-     //var data_loader = new MnistModelLoader();
-     //var num_epochs = 1;
-     //var batch_size = 8;
+     var data_loader = new MnistModelLoader();
+     var num_epochs = 1;
+     var batch_size = 8;
-     //var dataset = data_loader.LoadAsync(new ModelLoadSetting
-     //{
-     //    TrainDir = "mnist",
-     //    OneHot = false,
-     //    ValidationSize = 58000,
-     //}).Result;
+     var dataset = data_loader.LoadAsync(new ModelLoadSetting
+     {
+         TrainDir = "mnist",
+         OneHot = false,
+         ValidationSize = 55000,
+     }).Result;
-     //model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
+     model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
  }
}