@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Google.Protobuf; | |||
using Protobuf.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Contexts | |||
@@ -12,18 +12,36 @@ namespace Tensorflow.Eager | |||
return HasGradientTape(); | |||
} | |||
public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors) | |||
{ | |||
var tape_set = tf.GetTapeSet(); | |||
var input_ids = MakeTensorIDList(tensors); | |||
var input_dtypes = MakeTensorDtypeList(tensors); | |||
bool some_tape_watching = false; | |||
if (tape_set is not null && tape_set.Count > 0) | |||
{ | |||
foreach (var tape in tape_set) | |||
{ | |||
if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
{ | |||
if (tape.Persistent || some_tape_watching) | |||
{ | |||
return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER; | |||
} | |||
some_tape_watching = true; | |||
} | |||
} | |||
} | |||
// skip the forward_accumulators. | |||
if (some_tape_watching) | |||
{ | |||
return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER; | |||
} | |||
else | |||
{ | |||
return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE; | |||
} | |||
} | |||
} | |||
} |
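For reference, a minimal sketch of how a caller typically branches on the returned gradient type. The tensors `x` and `y` are placeholders, and the constants are the ones referenced above:

    // Sketch: consuming the possible-gradient-types query.
    Tensor x = tf.constant(1.0f), y = tf.constant(2.0f);
    int grad_type = tf.Runner.TFE_TapeSetPossibleGradientTypes(new[] { x, y });
    if (grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE)
    {
        // No tape is watching these tensors; no backward function is needed.
    }
    else if (grad_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
    {
        // A non-persistent tape is watching; a first-order backward function suffices.
    }
    else // POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
    {
        // A persistent tape (or nested tapes) may request gradients of gradients.
    }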
@@ -13,7 +13,17 @@ namespace Tensorflow.Eager | |||
Tensor[] results, | |||
BackwardFunction backwardFunction = null) | |||
{ | |||
var input_ids = MakeTensorIDList(inputs); | |||
var input_dtypes = MakeTensorDtypeList(inputs); | |||
bool should_record = false; | |||
foreach (var tape in tf.GetTapeSet()) | |||
{ | |||
if (tape.ShouldRecord(input_ids, input_dtypes)) | |||
{ | |||
should_record = true; | |||
break; | |||
} | |||
} | |||
if (!should_record) | |||
{ | |||
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager | |||
op_inputs = inputs;*/ | |||
backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); | |||
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction); | |||
return true; | |||
} | |||
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager | |||
{ | |||
return HasGradientTape(); | |||
} | |||
} | |||
} |
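For illustration, a hedged sketch of calling RecordGradient directly (normally the eager execution path does this), following the signature shown above. The tensors and the backward function are assumptions for the example:

    // Sketch: record z = Square(x) so a watching tape can later chain
    // d(z)/d(x) = 2 * x into incoming gradients. Assumes eager mode and an
    // active tape watching x.
    Tensor x = tf.constant(3.0f);
    Tensor z = tf.square(x);
    BackwardFunction square_grad = (grads, unneeded_gradients) =>
        new Tensor[] { grads[0] * 2.0f * x };
    tf.Runner.RecordGradient("Square", new[] { x }, new object[0], new[] { z }, square_grad);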
@@ -1,6 +1,8 @@ | |||
using OneOf.Types;
using System;
using Tensorflow.Gradients; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Eager | |||
{ | |||
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager | |||
/// </summary> | |||
public partial class EagerRunner | |||
{ | |||
/// <summary>
/// Computes gradients of <paramref name="target"/> with respect to <paramref name="sources"/> on the given tape.
/// </summary>
/// <param name="tape"></param> | |||
/// <param name="target"></param> | |||
/// <param name="sources"></param> | |||
/// <param name="output_gradients"></param> | |||
/// <param name="unconnected_gradients">determines the value returned if the target and
/// sources are unconnected. When 'none', the value returned is null, whereas when
/// 'zero', a zero tensor with the same shape as the sources is returned.</param>
/// <returns></returns> | |||
/// <exception cref="RuntimeError"></exception> | |||
public Tensor[] TFE_TapeGradient(ITape tape, | |||
Tensor[] target, | |||
Tensor[] sources, | |||
List<Tensor> output_gradients, | |||
Tensor[] sources_raw, | |||
string unconnected_gradients) | |||
{ | |||
if (!tape.Persistent) | |||
{ | |||
var tape_set = tf.GetTapeSet(); | |||
if (tape_set.Contains(tape)) | |||
{ | |||
throw new RuntimeError("gradient() cannot be invoked within the " + | |||
"GradientTape context (i.e., while operations are being " + | |||
"recorded). Either move the call to gradient() to be " + | |||
"outside the 'with tf.GradientTape' block, or " + | |||
"use a persistent tape: " + | |||
"'with tf.GradientTape(persistent=true)'"); | |||
} | |||
} | |||
var target_vec = MakeTensorIDList(target); | |||
var sources_vec = MakeTensorIDList(sources); | |||
HashSet<long> sources_set = new HashSet<long>(sources_vec); | |||
var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>(); | |||
int len = target.Length; | |||
for(int i = 0; i < len; i++) | |||
{ | |||
var target_id = target_vec[i]; | |||
if (sources_set.Contains(target_id)) | |||
{ | |||
var tensor = target[i]; | |||
source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor); | |||
} | |||
} | |||
List<Tensor> outgrad_vec = new(); | |||
if(output_gradients is not null) | |||
{ | |||
outgrad_vec = output_gradients.ToList(); | |||
} | |||
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false); | |||
bool unconnected_gradients_zero = unconnected_gradients == "zero"; | |||
Tensor[] sources_obj = null; | |||
if (unconnected_gradients_zero) | |||
{ | |||
sources_obj = MakeTensorList(sources_raw); | |||
} | |||
if (result.Length > 0) | |||
{ | |||
for(int i = 0; i < result.Length; i++) | |||
{ | |||
if (result[i] is null && unconnected_gradients_zero) | |||
{ | |||
var dtype = sources_obj[i].dtype; | |||
result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike(); | |||
} | |||
} | |||
} | |||
return result; | |||
} | |||
Tensor[] MakeTensorList(IEnumerable<Tensor> tensors) | |||
{ | |||
return tensors.ToArray(); | |||
} | |||
long[] MakeTensorIDList(Tensor[] tensors) | |||
{ | |||
int len = tensors.Length; | |||
long[] ids = new long[len]; | |||
for(int i = 0; i < len; i++) | |||
{ | |||
var tensor = tensors[i]; | |||
ids[i] = tensor.Id; | |||
} | |||
return ids; | |||
} | |||
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors) | |||
{ | |||
int len = tensors.Length; | |||
TF_DataType[] dtypes = new TF_DataType[len]; | |||
for (int i = 0; i < len; i++) | |||
{ | |||
var tensor = tensors[i]; | |||
dtypes[i] = tensor.dtype; | |||
} | |||
return dtypes; | |||
} | |||
TapeTensor TapeTensorFromTensor(Tensor tensor) | |||
{ | |||
long id = tensor.Id; | |||
var dtype = tensor.dtype; | |||
if (tensor is EagerTensor) | |||
{ | |||
var handle = tensor.EagerTensorHandle; | |||
if (DTypeNeedsHandleData(dtype)) | |||
{ | |||
return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor); | |||
} | |||
Status status = new(); | |||
int num_dims = c_api.TFE_TensorHandleNumDims(handle, status); | |||
long[] dims = new long[num_dims]; | |||
for(int i = 0; i < num_dims; i++) | |||
{ | |||
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); | |||
} | |||
Shape tensor_shape = new(dims); | |||
if(status.Code != TF_Code.TF_OK) | |||
{ | |||
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null); | |||
} | |||
else | |||
{ | |||
return new TapeTensor(id, dtype, tensor_shape); | |||
} | |||
} | |||
var shape_tuple = tensor.shape.dims; | |||
if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype)) | |||
{ | |||
return new TapeTensor(id, dtype, tensor); | |||
} | |||
long[] l = new long[shape_tuple.Length]; | |||
for(int i = 0; i < shape_tuple.Length; i++) | |||
{ | |||
if (shape_tuple[i] < 0) | |||
{ | |||
l[i] = 0; | |||
} | |||
else | |||
{ | |||
l[i] = shape_tuple[i]; | |||
} | |||
} | |||
return new TapeTensor(id, dtype, new Shape(l)); | |||
} | |||
bool DTypeNeedsHandleData(TF_DataType dtype) | |||
{ | |||
return dtype == dtypes.variant || dtype == dtypes.resource; | |||
} | |||
bool ListContainNone(long[] list) | |||
{ | |||
int len = list.Length; | |||
if(len == 0) | |||
{ | |||
return true; | |||
} | |||
for(int i = 0; i < len; i++) | |||
{ | |||
if (list[i] == -1) | |||
{ | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
} | |||
} |
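These helpers back the public GradientTape.gradient overloads shown later in this diff. A usage sketch of the new "zero" unconnected-gradients mode (variable names are illustrative):

    // Sketch: `y` depends only on `a`, so with unconnected_gradients "zero" the
    // gradient for `b` comes back as a zeros tensor instead of null.
    var a = tf.Variable(2.0f);
    var b = tf.Variable(5.0f);
    using var tape = tf.GradientTape();
    var y = tf.square(a.AsTensor());
    var grads = tape.gradient(y, new List<IVariableV1> { a, b },
        unconnected_gradients: "zero"); // grads[1] is zeros, not null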
@@ -7,8 +7,9 @@ namespace Tensorflow.Eager | |||
public partial class EagerRunner | |||
{ | |||
void TapeSetRecordBackprop(string op_type, | |||
TapeTensor[] output_info, | |||
long[] input_ids, | |||
TF_DataType[] input_dtypes,
BackwardFunction backward_function) | |||
{ | |||
if (!CouldBackprop()) | |||
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager | |||
foreach (var tape in tf.GetTapeSet()) | |||
{ | |||
tape.RecordOperation(op_type, output_info, input_ids, input_dtypes, backward_function);
} | |||
} | |||
} | |||
@@ -10,18 +10,28 @@ namespace Tensorflow.Eager | |||
public bool TapeSetRecordOperation(string op_type, | |||
Tensor[] input_tensors, | |||
Tensor[] output_tensors, | |||
long[] input_ids, | |||
TF_DataType[] input_dtypes, | |||
BackwardFunction backward_function) | |||
{ | |||
var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray(); | |||
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, | |||
backward_function)) | |||
return false; | |||
TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes, | |||
backward_function); | |||
return true; | |||
} | |||
public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
Tensor[] input_tensors, BackwardFunction backward_function) | |||
{ | |||
var input_ids = MakeTensorIDList(input_tensors); | |||
var input_dtypes = MakeTensorDtypeList(input_tensors); | |||
TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes, | |||
backward_function); | |||
} | |||
} | |||
} |
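A sketch of the new convenience entry point; it derives the ids and dtypes itself before fanning out to every active tape. The tensors and the backward function are assumptions for illustration:

    // Sketch: record z = Exp(x) on all active tapes in one call.
    Tensor x = tf.constant(1.0f);
    Tensor z = tf.exp(x);
    tf.Runner.TFE_TapeSetRecordOperation("Exp",
        new[] { z },   // outputs
        new[] { x },   // inputs
        (grads, unneeded_gradients) => new Tensor[] { grads[0] * z }); // d(e^x)/dx = e^x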
@@ -29,7 +29,14 @@ namespace Tensorflow.Eager | |||
Tensor[] TFE_TapeGradient(ITape tape, | |||
Tensor[] target, | |||
Tensor[] sources, | |||
List<Tensor> output_gradients, | |||
Tensor[] sources_raw, | |||
string unconnected_gradients); | |||
void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors, | |||
Tensor[] input_tensors, BackwardFunction backward_function); | |||
int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors); | |||
bool RecordGradient(string op_name, | |||
Tensor[] inputs, | |||
@@ -18,12 +18,13 @@ namespace Tensorflow.Functions | |||
public class ConcreteFunction: Trackable | |||
{ | |||
protected IEnumerable<Tensor> _captured_inputs; | |||
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||
protected Dictionary<string, AttrValue> _attrs; | |||
protected FunctionSpec _function_spec; | |||
protected FunctionSpec _pre_initialized_function_spec = null; | |||
protected EagerDefinedFunction _inference_function; | |||
protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new(); | |||
internal FuncGraph func_graph; | |||
internal ForwardBackwardCall forward_backward; | |||
public Tensor[] Inputs => func_graph.Inputs; | |||
public Tensor[] CapturedInputs => func_graph.external_captures; | |||
@@ -156,6 +157,17 @@ namespace Tensorflow.Functions | |||
{ | |||
var executing_eagerly = tf.Context.executing_eagerly(); | |||
var default_graph = ops.get_default_graph(); | |||
// TODO(Rinne): deal with `default_graph.building_function` | |||
if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph) | |||
{ | |||
foreach(var v in this.func_graph.Variables) | |||
{ | |||
resource_variable_ops.variable_accessed(v); | |||
} | |||
} | |||
var tensor_inputs = new Tensors(); | |||
foreach (var (i, arg) in enumerate(args)) | |||
{ | |||
@@ -223,11 +235,16 @@ namespace Tensorflow.Functions | |||
{ | |||
input_tangents = new TangentInfo(); | |||
} | |||
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER) | |||
{ | |||
if(input_tangents.Indices is not null || executing_eagerly) | |||
{ | |||
string cache_key = "first_order"; | |||
if(!_tape_functions_cache.TryGetValue(cache_key, out var functions)) | |||
{ | |||
functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
_tape_functions_cache[cache_key] = functions; | |||
} | |||
return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
} | |||
else | |||
@@ -241,7 +258,7 @@ namespace Tensorflow.Functions | |||
} | |||
// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. | |||
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false); | |||
} | |||
internal void set_variables(IEnumerable<IVariableV1> variables) | |||
@@ -124,17 +124,16 @@ namespace Tensorflow.Functions | |||
// TODO(Rinne): Add arg `CancellationManager`. | |||
// TODO(Rinne): Check the arg length. | |||
var function_call_options = tf.Context.FunctionCallOptions; | |||
string config = ""; // TODO(Rinne): revise it. The commented code below should work, but it fails for unclear reasons.
//if (function_call_options.config_proto_serialized().Length == 0) | |||
//{ | |||
// config = function_utils.get_disabled_rewriter_config().ToStringUtf8(); | |||
//} | |||
//else | |||
//{ | |||
// config = function_call_options.config_proto_serialized().ToStringUtf8(); | |||
//} | |||
string executor_type = function_call_options.ExecutorType ?? ""; | |||
var executing_eagerly = tf.Context.executing_eagerly(); | |||
@@ -14,12 +14,11 @@ namespace Tensorflow.Functions | |||
} | |||
public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
ForwardAndBackwardFunctions(Tensors inference_args) | |||
{ | |||
var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray(); | |||
return BuildFunctionsForOutputs(outputs, inference_args); | |||
} | |||
} | |||
} |
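The override now surfaces all five artifacts instead of only the forward function. A sketch of consuming the tuple (the receiver name is an assumption; the element names mirror the fields cached by TapeGradientFunctions):

    // Sketch: destructuring the forward/backward build results.
    var (forward, forward_graph, backward, forwardprop_output_indices, num_forwardprop_outputs)
        = functions.ForwardAndBackwardFunctions(inference_args);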
@@ -14,7 +14,6 @@ namespace Tensorflow | |||
protected ConcreteFunction _concrete_variable_creation_fn; | |||
protected bool _autograph; | |||
protected TracingCompiler _variable_creation_fn; | |||
public string Name { get; set; } | |||
public Function(Func<Tensor[], Tensor[]> csharp_function, | |||
string name, bool auto_graph = true) | |||
@@ -22,7 +21,6 @@ namespace Tensorflow | |||
_csharp_function = csharp_function; | |||
Name = name; | |||
_autograph = auto_graph; | |||
} | |||
public virtual Tensors Apply(Tensors inputs) | |||
@@ -38,10 +36,11 @@ namespace Tensorflow | |||
protected virtual Tensors _call(Tensors inputs) | |||
{ | |||
if(_variable_creation_fn is not null) | |||
{ | |||
return _variable_creation_fn.Apply(inputs); | |||
} | |||
_initialize(inputs); | |||
return _concrete_variable_creation_fn.CallFlat(inputs, | |||
_concrete_variable_creation_fn.CapturedInputs); | |||
@@ -63,7 +62,6 @@ namespace Tensorflow | |||
_variable_creation_fn = _compiler(_csharp_function); | |||
_variable_creation_fn._name = this.Name; | |||
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); | |||
} | |||
} | |||
} |
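With the `_has_initialized` flag gone, the nullability of `_variable_creation_fn` itself gates initialization. A sketch of the resulting call pattern (the function body and name are placeholders):

    // Sketch: the first Apply traces and builds the concrete function; later
    // calls take the cached TracingCompiler path.
    var f = new Function(xs => new[] { tf.square(xs[0]) }, "square_fn");
    var r1 = f.Apply(tf.constant(2.0f)); // _variable_creation_fn is null: _initialize + CallFlat
    var r2 = f.Apply(tf.constant(3.0f)); // _variable_creation_fn set: Apply on the compiled fn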
@@ -24,23 +24,40 @@ namespace Tensorflow.Functions | |||
protected string _INFERENCE_PREFIX = "__inference_"; | |||
protected FuncGraph _func_graph; | |||
protected EagerDefinedFunction _forward; | |||
protected FuncGraph _forward_graph; | |||
protected List<int> _forwardprop_output_indices; | |||
protected int _num_forwardprop_outputs; | |||
protected int _num_inference_outputs; | |||
protected int _num_outputs; | |||
protected int _num_trainable_inference_outputs; | |||
protected ConcreteFunction _backward; | |||
BackwardFunction _backward_function_wrapper; | |||
public TapeGradientFunctions(FuncGraph func_graph, | |||
bool need_gradients_for_jvps) | |||
{ | |||
_func_graph = func_graph; | |||
_forward_graph = null; | |||
_forward = null; | |||
_backward = null; | |||
_num_outputs = func_graph.Outputs.Length; | |||
_forwardprop_output_indices = null; | |||
_num_forwardprop_outputs = 0; | |||
_num_inference_outputs = func_graph.Outputs.Length; | |||
_num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count(); | |||
} | |||
public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) | |||
{ | |||
// TODO(Rinne): add input_tangents arg. | |||
if(_forward is null) | |||
{ | |||
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||
= ForwardAndBackwardFunctions(inference_args); | |||
} | |||
return _forward; | |||
} | |||
/// <summary> | |||
@@ -51,9 +68,13 @@ namespace Tensorflow.Functions | |||
public virtual void Record(Tensors flat_outputs, Tensors inference_args) | |||
{ | |||
// TODO(Rinne): add arg `input_tagents`. | |||
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||
if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0) | |||
{ | |||
// TODO(Rinne): implement it. | |||
throw new NotImplementedException(); | |||
} | |||
tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function); | |||
} | |||
/// <summary> | |||
@@ -65,66 +86,95 @@ namespace Tensorflow.Functions | |||
/// <returns></returns> | |||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
{ | |||
var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs) | |||
.ToDictionary(x => x.Item1, x => x.Item2); | |||
var captured_inputs = backward.CapturedInputs; | |||
var remapped_captures = captured_inputs.Select(c => | |||
{ | |||
if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value)) | |||
{ | |||
return value; | |||
} | |||
else | |||
{ | |||
return c; | |||
} | |||
}).ToArray(); | |||
if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph)) | |||
{ | |||
var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph == forward_graph);
throw new RuntimeError($"Failed to map all backward graph captures to " + | |||
$"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}"); | |||
} | |||
Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>(); | |||
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | |||
var recorded_outputs = new Tensors(); | |||
int trainable_recorded_outputs = 0; | |||
var skip_positions = new HashSet<int>(); | |||
var relevant_outputs = outputs; | |||
foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||
{ | |||
if (trainable_recorded_outputs < backward_function_inputs) | |||
recorded_outputs.Add(output); | |||
if (backprop_util.IsTrainable(output)) | |||
trainable_recorded_outputs++; | |||
else | |||
skip_positions.Add(output_index); | |||
if (output.dtype == dtypes.variant) | |||
variant_zeros_like[output_index] = default_gradient.zeros_like(output); | |||
} | |||
_backward_function_wrapper = (args, unneeded_gradients) =>
{
if (backward.Outputs is null || backward.Outputs.Length == 0)
{
return backward.FlatStructuredOutputs;
}
var processed_args = new Tensors(); | |||
int input_index = 0; | |||
foreach (var (output_index, arg) in enumerate(args)) | |||
{ | |||
if (skip_positions.Contains(output_index)) | |||
continue; | |||
if (arg is null) | |||
{ | |||
var input_placeholder = backward.Inputs[input_index]; | |||
Tensor variant_arg; | |||
if (input_placeholder.dtype == dtypes.variant) | |||
{ | |||
variant_arg = variant_zeros_like[output_index]; | |||
} | |||
else | |||
{ | |||
var (shape, type) = default_gradient.shape_and_dtype(input_placeholder); | |||
variant_arg = array_ops.zeros(shape, type); | |||
} | |||
processed_args.Add(variant_arg); | |||
} | |||
else | |||
{ | |||
processed_args.Add(arg); | |||
} | |||
input_index++; | |||
if (input_index >= backward_function_inputs) | |||
break; | |||
} | |||
tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
var gradients = backward.CallFlat(processed_args, remapped_captures); | |||
foreach (var unneeded_gradient_index in unneeded_gradients) | |||
{ | |||
var index = Convert.ToInt32(unneeded_gradient_index); | |||
if (gradients.Length <= index) | |||
gradients.Insert(index, null); | |||
} | |||
return gradients; | |||
}; | |||
return (_backward_function_wrapper, recorded_outputs); | |||
} | |||
@@ -143,7 +193,7 @@ namespace Tensorflow.Functions | |||
} | |||
} | |||
var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name)); | |||
backwards_graph.as_default(); | |||
var gradients_wrt_outputs = new List<Tensor>(); | |||
foreach (var output in trainable_outputs) | |||
@@ -153,6 +203,7 @@ namespace Tensorflow.Functions | |||
gradients_wrt_outputs.Add(gradient_placeholder); | |||
handle_data_util.copy_handle_data(output, gradient_placeholder); | |||
} | |||
// TODO(Rinne): with ops.device(None) | |||
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
_func_graph.Inputs, | |||
grad_ys: gradients_wrt_outputs.ToArray(), | |||
@@ -175,7 +226,8 @@ namespace Tensorflow.Functions | |||
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); | |||
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); | |||
var (wrapped_forward_function, wrapped_backward_function) = | |||
monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph); | |||
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; | |||
//var backward_function_attr = new Dictionary<string, string>(); | |||
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
@@ -189,10 +241,11 @@ namespace Tensorflow.Functions | |||
// _func_graph.Inputs, _func_graph.Outputs, | |||
// monomorphic_function_utils._parse_func_attrs(forward_function_attr)); | |||
return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0); | |||
} | |||
public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
ForwardAndBackwardFunctions(Tensors inference_args) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
@@ -73,12 +73,12 @@ namespace Tensorflow.Functions | |||
private static string make_cache_key(Tensor[] inputs)
{ | |||
//string res = ""; | |||
//foreach (var input in inputs) | |||
//{ | |||
// res += $"{input.name}_{input.Id}"; | |||
//} | |||
return inputs.Length.ToString(); | |||
} | |||
} | |||
} |
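Note the simplified key means every call with the same number of inputs shares one cache slot. If that proves too coarse, a hypothetical key folding in dtypes and shapes (not part of this change) could look like:

    // Hypothetical alternative cache key: arity plus dtype/shape of each input.
    private static string make_cache_key_with_signature(Tensor[] inputs)
    {
        var sb = new System.Text.StringBuilder();
        sb.Append(inputs.Length);
        foreach (var t in inputs)
            sb.Append('|').Append(t.dtype).Append(':').Append(t.shape);
        return sb.ToString();
    }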
@@ -153,7 +153,7 @@ namespace Tensorflow.Functions | |||
foreach(var tape in tf.GetTapeSet()) | |||
{ | |||
tape.RecordOperation(_inference_function.Signature.Name, to_record, | |||
inference_args, backward_function); | |||
} | |||
} | |||
@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients | |||
/// Map from tensor to how many references still exist for this tensor in | |||
/// the tape. | |||
/// </summary> | |||
public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; } | |||
public UnorderedMap<long, long> tensor_usage_counts { get; set; } | |||
/// <summary> | |||
/// Maps from op ID to how many output tensors of this op still need to have | |||
/// their gradients computed. | |||
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients | |||
public BackpropInitialState() | |||
{ | |||
op_tape = new OpTape(); | |||
tensor_usage_counts = new UnorderedMap<Tensor, long>(); | |||
tensor_usage_counts = new UnorderedMap<long, long>(); | |||
op_missing_tensor = new UnorderedMap<long, long>(); | |||
} | |||
} | |||
@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients | |||
/// <param name="target"></param> | |||
/// <param name="source"></param> | |||
/// <returns></returns> | |||
public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
if(_tape is null) | |||
{ | |||
throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
"compute one set of gradients (or jacobians)."); | |||
} | |||
ITape tape = stop_recording(); | |||
var results = tf.Runner.TFE_TapeGradient(tape, | |||
new[] { target }, | |||
new[] { source }, | |||
output_gradients, | |||
new[] { source }, | |||
unconnected_gradients); | |||
return results[0]; | |||
} | |||
public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients); | |||
return results[0]; | |||
} | |||
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients); | |||
return (results[0], results[1]); | |||
} | |||
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null, | |||
string unconnected_gradients = null) | |||
{ | |||
if (_tape is null) | |||
{ | |||
throw new RuntimeError("A non-persistent GradientTape can only be used to " + | |||
"compute one set of gradients (or jacobians)."); | |||
} | |||
var tape = stop_recording(); | |||
var results = tf.Runner.TFE_TapeGradient(tape, | |||
new[] { target }, | |||
sources.Select(x => x.Handle).ToArray(), | |||
output_gradients, | |||
sources.Select(x => x.Handle).ToArray(), | |||
unconnected_gradients); | |||
if (!tape.Persistent) | |||
{ | |||
@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients | |||
public interface ITape | |||
{ | |||
void SetTapeId(int id); | |||
bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes); | |||
void StartRecord(); | |||
void StopRecord(); | |||
bool Persistent { get; } | |||
void RecordOperation(string op_type, | |||
TapeTensor[] output_tensors, | |||
long[] input_tensor_id, | |||
TF_DataType[] input_dtypes, | |||
BackwardFunction backward_function); | |||
void RecordOperation(string op_type, | |||
Tensor[] outputs, | |||
Tensor[] inputs, | |||
BackwardFunction backward_function); | |||
void VariableAccessed(IVariableV1 variable); | |||
void Watch(Tensor x); | |||
IVariableV1[] WatchedVariables(); | |||
Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
long[] source_tensor_ids, | |||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
List<Tensor> output_gradients, | |||
bool build_default_zeros_grads); | |||
} | |||
} |
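The id/dtype-based interface keeps tapes from pinning Tensor objects. The methods above are normally exercised indirectly through the public tape, e.g.:

    // Sketch: Watch stores tensor_tape_[x.Id] = -1; the multiply is recorded
    // via RecordOperation; gradient() drives ComputeGradient over tensor ids.
    var x = tf.constant(3.0f);
    using var tape = tf.GradientTape();
    tape.watch(x);
    var y = x * x;
    var dy_dx = tape.gradient(y, x); // 6.0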
@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients | |||
{ | |||
public string op_type { get; set; } | |||
public TapeTensor[] output_tensor_info { get; set; } | |||
public long[] input_tensor_id { get; set; } | |||
public BackwardFunction backward_function { get; set; } | |||
public override string ToString()
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
} | |||
} |
@@ -2,235 +2,246 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Gradients | |||
{ | |||
public partial class Tape | |||
{ | |||
static readonly int kMinAggregateCount = 4; | |||
static readonly int kMinAggregateBytes = 128 * 1024 * 1024; | |||
private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap; | |||
static Tape() | |||
{ | |||
_functionsAcceptingNoneForIndicesMap = new(); | |||
_functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
_functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 })); | |||
_functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 })); | |||
} | |||
public Tensor[] ComputeGradient(long[] target_tensor_ids, | |||
long[] source_tensor_ids, | |||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
List<Tensor> output_gradients, | |||
bool build_default_zeros_grads) | |||
{ | |||
UnorderedSet<long> sources_set = new(source_tensor_ids); | |||
BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent); | |||
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor); | |||
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape); | |||
UnorderedMap<long, long> gradients_size = new(); | |||
while(op_stack.Count > 0) | |||
{ | |||
long op = op_stack.Dequeue(); | |||
if(!state.op_tape.TryGetValue(op, out var op_it)) | |||
{ | |||
continue; | |||
} | |||
var trace = op_it; | |||
state.op_tape.erase(op); | |||
List<Tensor> out_gradients = new(); | |||
List<long> unneeded_gradients = new(); | |||
for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++) | |||
{ | |||
long in_tensor_id = trace.input_tensor_id[i]; | |||
if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id)) | |||
{ | |||
unneeded_gradients.Add(i); | |||
} | |||
} | |||
bool any_gradient_nonzero = false; | |||
List<int> zero_indices = new(); | |||
for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++) | |||
{ | |||
long id = trace.output_tensor_info[i].GetID(); | |||
if(!gradients.TryGetValue(id, out var grad_it)) | |||
{ | |||
if (build_default_zeros_grads) | |||
{ | |||
out_gradients.Add(null); | |||
} | |||
else | |||
{ | |||
out_gradients.Add(null); | |||
if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) || | |||
!func_name_it.find(i)) | |||
{ | |||
zero_indices.Add(i); | |||
} | |||
} | |||
} | |||
else | |||
{ | |||
any_gradient_nonzero = true; | |||
Tensor new_gradients; | |||
if (grad_it.Count == 1) | |||
{ | |||
new_gradients = grad_it[0]; | |||
} | |||
else | |||
{ | |||
new_gradients = AggregateGradients(grad_it); | |||
} | |||
if (!sources_set.find(id)) | |||
{ | |||
gradients.Remove(id); | |||
} | |||
else | |||
{ | |||
grad_it.Clear(); | |||
grad_it.Add(new_gradients); | |||
// MarkAsResult | |||
} | |||
out_gradients.Add(new_gradients); | |||
} | |||
} | |||
Tensor[] in_gradients = new Tensor[0]; | |||
if (any_gradient_nonzero) | |||
{ | |||
foreach(var i in zero_indices) | |||
{ | |||
out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); | |||
} | |||
in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients); | |||
} | |||
else | |||
{ | |||
out_gradients.Clear(); | |||
} | |||
for(int i = 0, end = in_gradients.Length; i < end; i++) | |||
{ | |||
long id = trace.input_tensor_id[i]; | |||
if (in_gradients[i] is not null) | |||
{ | |||
var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>()); | |||
unaggregated_grads.Add(in_gradients[i]); | |||
if(unaggregated_grads.Count > kMinAggregateCount) | |||
{ | |||
if(!gradients_size.TryGetValue(id, out var size)) | |||
{ | |||
size = NumElements(unaggregated_grads[0]); | |||
gradients_size.emplace(id, size); | |||
} | |||
if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes) | |||
{ | |||
Tensor grad = AggregateGradients(unaggregated_grads); | |||
unaggregated_grads.Clear(); | |||
unaggregated_grads.Add(grad); | |||
} | |||
} | |||
} | |||
if(!state.tensor_usage_counts.find(id)) | |||
{ | |||
continue; | |||
} | |||
state.tensor_usage_counts[id]--; | |||
if(state.tensor_usage_counts[id] > 0) | |||
{ | |||
continue; | |||
} | |||
if (!tensor_tape_.TryGetValue(id, out var tape_it)) | |||
{ | |||
if (gradients.find(id)) | |||
{ | |||
gradients.erase(id); | |||
} | |||
continue; | |||
} | |||
long op_id = tape_it; | |||
if(op_id == -1) | |||
{ | |||
continue; | |||
} | |||
if(state.op_missing_tensor.find(op_id)) | |||
{ | |||
state.op_missing_tensor[op_id]--; | |||
if(state.op_missing_tensor[op_id] == 0) | |||
{ | |||
op_stack.Enqueue(op_id); | |||
} | |||
} | |||
} | |||
} | |||
if(state.op_tape.Count > 0) | |||
{ | |||
throw new RuntimeError("Invalid tape state."); | |||
} | |||
Tensor[] result = new Tensor[source_tensor_ids.Length]; | |||
for(int i = 0; i < source_tensor_ids.Length; i++) | |||
{ | |||
long tensor_id = source_tensor_ids[i]; | |||
if(!gradients.TryGetValue(tensor_id, out var grad_it)) | |||
{ | |||
result[i] = null; | |||
} | |||
else | |||
{ | |||
if(grad_it.Count > 1) | |||
{ | |||
Tensor grad = AggregateGradients(grad_it); | |||
grad_it.Clear(); | |||
grad_it.Add(grad); | |||
} | |||
result[i] = grad_it[0]; | |||
} | |||
} | |||
return result; | |||
} | |||
UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() | |||
{ | |||
return _functionsAcceptingNoneForIndicesMap; | |||
} | |||
UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids, | |||
UnorderedMap<long, TapeTensor> sources_that_are_targets, | |||
List<Tensor> output_gradients, | |||
TensorTape tensor_tape, | |||
OpTape op_tape) | |||
{ | |||
var result = new UnorderedMap<long, List<Tensor>>(); | |||
for(int i = 0, end = target_tensor_ids.Length; i < end; i++) | |||
{ | |||
long id = target_tensor_ids[i];
if (output_gradients is null || output_gradients.Count == 0 || output_gradients[i] is null)
{ | |||
if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1) | |||
{ | |||
if(!op_tape.TryGetValue(tensor_it, out var op_it)) | |||
{ | |||
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
"failed to find operation producing a tensor.");
} | |||
bool found = false; | |||
for(int j = 0; j < op_it.output_tensor_info.Length; j++) | |||
{ | |||
if (op_it.output_tensor_info[j].GetID() == id) | |||
{ | |||
found = true; | |||
Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]); | |||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
break; | |||
} | |||
} | |||
if (!found) | |||
{ | |||
throw new RuntimeError("Internal state of the gradient tape is invalid: " + | |||
"none of operations outputs match expected tensor."); | |||
} | |||
} | |||
else | |||
{ | |||
if(sources_that_are_targets.TryGetValue(id, out var source_tensor)) | |||
{ | |||
Tensor ones_like = BuildOnesLike(source_tensor); | |||
result.SetDefault(id, new List<Tensor>()).Add(ones_like); | |||
} | |||
} | |||
} | |||
else | |||
{ | |||
result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]); | |||
} | |||
} | |||
@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients | |||
} | |||
return result; | |||
} | |||
Tensor BuildOnesLike(TapeTensor t) | |||
{ | |||
return t.OnesLike(); | |||
} | |||
Tensor AggregateGradients(List<Tensor> gradient_tensors) | |||
{ | |||
if (gradient_tensors.Count == 1)
{
return gradient_tensors[0];
}
return tf.add_n(gradient_tensors.ToArray()); | |||
} | |||
void DeleteGradient(Tensor gradient) | |||
{ | |||
// Do not do anything here. Because GC will collect it when it has no reference. | |||
} | |||
// Placeholder: returning 1 effectively disables the size-based aggregation above.
long NumElements(Tensor tensor) => 1;
} | |||
} |
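As a worked example of the loop above, consider y = x * x with x watched (all ids hypothetical):

    // Worked trace of ComputeGradient for y = x * x, x = 3.0 (ids: x=1, y=2, op=0):
    //   tensor_tape_ : { 1 -> -1 (watched), 2 -> 0 (produced by op 0) }
    //   InitialGradients: gradients[2] = [ ones_like(y) ]              // seed = 1.0
    //   pop op 0: out_gradients = [1.0]; backward_function -> in_gradients = [2 * x * 1.0]
    //   gradients[1] = [ 6.0 ]; usage counts reach zero, the stack drains
    //   result = [ gradients[1][0] ] = [ 6.0 ]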
@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients | |||
{ | |||
public partial class Tape | |||
{ | |||
public BackpropInitialState PrepareBackprop(Tensor[] target, | |||
public BackpropInitialState PrepareBackprop(long[] target, | |||
TensorTape tensor_tape, | |||
OpTape op_tape, | |||
UnorderedSet<Tensor> sources_set, | |||
UnorderedSet<long> sources_set, | |||
bool persistent_tape) | |||
{ | |||
Stack<long> tensor_stack = new Stack<long>(); | |||
foreach(var t in target) | |||
{ | |||
tensor_stack.Push(t); | |||
} | |||
BackpropInitialState result = new BackpropInitialState(); | |||
while(tensor_stack.Count > 0) | |||
{ | |||
long tensor_id = tensor_stack.Pop(); | |||
if(!tensor_tape.TryGetValue(tensor_id, out var op_id)) | |||
{ | |||
continue; | |||
} | |||
if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it) | |||
|| result.op_tape.find(op_id)) | |||
{ | |||
continue; | |||
} | |||
result.op_tape.emplace(op_id, op_it); | |||
foreach(var it in op_it.input_tensor_id) | |||
{ | |||
if(result.tensor_usage_counts.find(it)) | |||
{ | |||
result.tensor_usage_counts[it]++; | |||
} | |||
else | |||
{ | |||
result.tensor_usage_counts[it] = 1; | |||
if (tensor_tape.find(it)) | |||
{ | |||
tensor_stack.Push(it); | |||
} | |||
} | |||
} | |||
if (!persistent_tape) | |||
{ | |||
op_tape.erase(op_id); | |||
} | |||
} | |||
foreach(var pair in result.tensor_usage_counts) | |||
{ | |||
if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1) | |||
{ | |||
result.op_missing_tensor[it]++; | |||
} | |||
} | |||
if (!persistent_tape) | |||
{ | |||
// Call destructors for all unneeded gradient functions and | |||
// clear the op_tape. We can clear the tape because ownership of | |||
// backward functions that will be used for gradient computation | |||
// has been transferred to `result`. | |||
/*for (const auto&op_pair : *op_tape) { | |||
op_pair.second.backward_function_deleter( | |||
op_pair.second.backward_function); | |||
}*/ | |||
op_tape.Clear(); | |||
} | |||
return result; | |||
} | |||
} | |||
@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients | |||
public partial class Tape | |||
{ | |||
long next_op_id_ = 0; | |||
UnorderedMap<Tensor, long> tensor_usage_; | |||
UnorderedMap<long, long> tensor_usage_; | |||
public void RecordOperation(string op_type, | |||
TapeTensor[] output_tensors, | |||
long[] input_tensor_id, | |||
TF_DataType[] input_dtypes, | |||
BackwardFunction backward_function) | |||
{ | |||
if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||
return; | |||
foreach (var i in input_tensor_id) | |||
{ | |||
tensor_usage_[i]++; | |||
} | |||
long op_id = next_op_id_++; | |||
foreach (var o in output_tensors) | |||
{ | |||
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); | |||
tensor_tape_[o.GetID()] = op_id; | |||
tensor_usage_[o.GetID()] = 1; | |||
} | |||
op_tape_[op_id] = new OpTapeEntry | |||
{ | |||
op_type = op_type, | |||
output_tensor_info = output_tensors.ToArray(), | |||
input_tensor_id = input_tensor_id.ToArray(), | |||
backward_function = backward_function | |||
}; | |||
} | |||
public void RecordOperation(string op_type, | |||
Tensor[] outputs, | |||
Tensor[] inputs, | |||
BackwardFunction backward_function) | |||
{ | |||
tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function); | |||
} | |||
} | |||
} |
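A sketch of recording one op on a single tape through the id/dtype overload; ordinarily the runner does this. The tensors x, y, z are placeholders:

    // Sketch: manually record z = x + y. Addition back-propagates the incoming
    // gradient unchanged to both inputs.
    tape.RecordOperation("AddV2",
        new[] { new TapeTensor(z) },
        new[] { x.Id, y.Id },
        new[] { x.dtype, y.dtype },
        (grads, unneeded_gradients) => new Tensor[] { grads[0], grads[0] });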
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients | |||
_created_eagerly = tf.Context.executing_eagerly(); | |||
tensor_tape_ = new TensorTape(); | |||
op_tape_ = new OpTape(); | |||
tensor_usage_ = new UnorderedMap<Tensor, long>(); | |||
tensor_usage_ = new UnorderedMap<long, long>(); | |||
if(_created_eagerly) | |||
tf.Context.start_step(); | |||
// nesting_id = ++tape_nesting_id_counter; | |||
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients | |||
public void Watch(Tensor x) | |||
{ | |||
tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); | |||
tensor_tape_.emplace(x.Id, -1); | |||
} | |||
public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes) | |||
{ | |||
Debug.Assert(tensor_ids.Length == tensor_dtypes.Length); | |||
for (int i = 0; i < tensor_ids.Length; ++i) | |||
{ | |||
if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i])) | |||
{ | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
public void VariableAccessed(IVariableV1 variable) | |||
{ | |||
Watch(variable.Handle); | |||
} | |||
public IVariableV1[] WatchedVariables() | |||
{ | |||
return null; // watched-variable tracking is not implemented yet
} | |||
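A sketch of the contract above: a tape records only when some input is both watched and of a trainable dtype. Here `x` is a hypothetical float32 tensor and `idx` a hypothetical int32 tensor:

    // Sketch: ShouldRecord over parallel id/dtype arrays.
    tape.Watch(x);
    bool record = tape.ShouldRecord(
        new[] { x.Id }, new[] { x.dtype });     // true: watched and trainable
    bool skip = tape.ShouldRecord(
        new[] { idx.Id }, new[] { idx.dtype }); // false: int32 is not trainable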
@@ -1,27 +1,63 @@ | |||
using OneOf; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Gradients | |||
{ | |||
public class TapeTensor | |||
{ | |||
internal Tensor tensor; | |||
internal long id; | |||
internal TF_DataType dtype; | |||
internal OneOf<Shape, Tensor> shape; | |||
public TapeTensor(long id, TF_DataType dtype, Shape shape) | |||
{ | |||
this.id = id; | |||
this.dtype = dtype; | |||
this.shape = shape; | |||
} | |||
public TapeTensor(long id, TF_DataType dtype, Tensor shape) | |||
{ | |||
this.id = id; | |||
this.dtype = dtype; | |||
this.shape = shape; | |||
} | |||
public TapeTensor(Tensor tensor) | |||
{ | |||
this.id = tensor.Id; | |||
this.dtype = tensor.dtype; | |||
this.shape = tensor.shape; | |||
this.tensor = tensor; | |||
} | |||
public Tensor GetTensor() => tensor; | |||
public long GetID() => id; | |||
public Tensor ZerosLike() | |||
{ | |||
if(dtype == dtypes.resource) | |||
{ | |||
return null; | |||
} | |||
if(shape.Index == 1) | |||
{ | |||
return tf.zeros_like(shape.AsT1); | |||
} | |||
return tf.zeros(shape.AsT0, dtype); | |||
} | |||
public Tensor OnesLike() | |||
{ | |||
if (shape.Index == 1) | |||
{ | |||
return tf.ones_like(shape.AsT1); | |||
} | |||
return tf.ones(shape.AsT0, dtype); | |||
} | |||
public override string ToString() | |||
=> $"{id}, {shape}, {dtype.as_numpy_name()}"; | |||
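The OneOf&lt;Shape, Tensor&gt; shape lets the zero/one builders fall back to a runtime zeros_like/ones_like when only a reference tensor is known. A sketch (ids and `someTensor` are illustrative):

    // Sketch: static-shape vs tensor-backed TapeTensor.
    var by_shape = new TapeTensor(7, TF_DataType.TF_FLOAT, new Shape(2, 3));
    var z1 = by_shape.ZerosLike();  // tf.zeros((2, 3), float32) from the static shape
    var by_tensor = new TapeTensor(8, TF_DataType.TF_FLOAT, someTensor);
    var z2 = by_tensor.ZerosLike(); // tf.zeros_like(someTensor) at runtime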
@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients | |||
/// produced this tensor. A value of -1 means that the tensor was directly | |||
/// watched and not the result of any operation in the tape. | |||
/// </summary> | |||
public class TensorTape : UnorderedMap<long, long> | |||
{ | |||
} | |||
@@ -704,32 +704,7 @@ namespace Tensorflow | |||
public static int PossibleTapeGradientTypes(Tensor[] tensors) | |||
{ | |||
return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors); | |||
} | |||
/// <summary> | |||
@@ -215,6 +215,16 @@ public class FuncGraph : Graph, IDisposable | |||
return tensor; | |||
} | |||
public void watch_variable(IVariableV1 v) | |||
{ | |||
if (_resource_tensor_inputs.Contains(v.Handle)) | |||
{ | |||
return; | |||
} | |||
_watched_variables.Add(new WeakReference<IVariableV1>(v)); | |||
} | |||
Tensor capture_eager_tensor(Tensor tensor, string name) | |||
{ | |||
Tensor graph_const = null; | |||
@@ -4,10 +4,10 @@ public interface IOptimizer | |||
{ | |||
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); | |||
Tensor[] clip_gradients(Tensor[] grads); | |||
void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||
void apply_gradients((Tensor, IVariableV1) grads_and_vars, | |||
string name = null, | |||
bool experimental_aggregate_gradients = true); | |||
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars, | |||
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
string name = null, | |||
bool experimental_aggregate_gradients = true); | |||
} |
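Widening the interface to IVariableV1 lets any variable implementation flow through the optimizer. A typical training-step sketch (tape, loss, model, and optimizer are placeholders):

    // Sketch: apply tape gradients through the widened interface.
    var grads = tape.gradient(loss, model.TrainableVariables);
    optimizer.apply_gradients(zip(grads,
        model.TrainableVariables.Select(v => (IVariableV1)v)));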
@@ -208,9 +208,9 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); | |||
} | |||
} |
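These two entry points round-trip the `HandleData` proto; both call sites appear later in this diff (`handle_data_util` and `ops.get_resource_handle_data`). A condensed sketch of the round trip, with error handling reduced to the `Status` check and `tensor` assumed to be a graph tensor carrying resource handle data:

```csharp
// Read: fetch the serialized HandleData for one graph output and parse it.
var raw = c_api.TFC_GetHandleShapeAndType(tensor.graph.c_graph, tensor._as_tf_output());
var handle_data = HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(raw)));

// Write: serialize it back onto the output, mirroring handle_data_util below.
var proto = handle_data.ToByteArray();
Status status = new();
c_api.TFC_SetHandleShapeAndType(tensor.graph.c_graph, tensor._as_tf_output(), proto, proto.Length, status);
status.Check(true);
```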
@@ -39,7 +39,7 @@ namespace Tensorflow
if (config is null)
{
config = function_utils.get_disabled_rewriter_config().ToString();
config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
}
if (executor_type is null)
@@ -49,6 +49,8 @@ namespace Tensorflow
if (executing_eagerly)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
}
@@ -17,6 +17,7 @@
using System;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -210,7 +211,51 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `value`.</returns>
public static Tensor fill<T>(Tensor dims, T value, string name = null)
=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value));
return _result[0];
}
catch (Exception)
{
// Fast path failed; try the eager fallback below.
}
try
{
return fill_eager_fallback(dims, value as Tensor, name, ctx);
}
catch (Exception)
{
// Eager fallback failed as well; fall through to graph-mode execution.
}
}
Dictionary<string, object> attrs = new Dictionary<string, object>();
attrs["dims"] = dims;
attrs["value"] = value;
var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result.output;
}
public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx)
{
object[] attrs = new object[] { "T", value.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() };
var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return _result[0];
}
//=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
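The rewritten `fill` follows the pattern of Python's generated op wrappers: eager fast path, then eager fallback, then graph-mode `_apply_op_helper`. A hedged usage example (assuming, as the surrounding methods suggest, that this lives in `gen_array_ops`; values are illustrative):

```csharp
// In eager mode the fast path returns immediately; if it throws,
// fill_eager_fallback runs; in graph mode the call falls through to
// _apply_op_helper("Fill", ...).
var dims = tf.constant(new[] { 2, 3 });            // target shape [2, 3]
var filled = gen_array_ops.fill(dims, tf.constant(9f));
print(filled);                                      // [[9, 9, 9], [9, 9, 9]]
```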
/// <summary>
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
@@ -49,8 +49,10 @@ namespace Tensorflow.Operations
target_t.HandleData = handle_data;
return;
}
// TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
//c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
Status status = new();
var proto = handle_data.ToByteArray();
c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
status.Check(true);
}
public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);
@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
using Tensorflow.Operations;
using System.Buffers;
using Tensorflow.Eager;
using Tensorflow.Graphs;
namespace Tensorflow
{
@@ -302,5 +303,18 @@ namespace Tensorflow
// return handle_data_util.get_resource_handle_data(handle);
//}
}
public static void variable_accessed(IVariableV1 variable)
{
if (ops.get_default_graph() is FuncGraph func_graph)
{
func_graph.watch_variable(variable);
}
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
tape.VariableAccessed(variable);
}
}
}
}
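`variable_accessed` is the hook that connects variable reads to open tapes: inside a `FuncGraph` the graph watches the variable, and every active tape records the access. A conceptual sketch of the path it completes (my illustration, not PR code; the exact `gradient` overload is an assumption):

```csharp
// Assumed setup: eager mode. Reading the variable's value triggers
// variable_accessed(v), so every active tape calls VariableAccessed(v)
// and the gradient below works without an explicit tape.watch.
var v = tf.Variable(1.0f);
using var tape = tf.GradientTape();
var y = v.AsTensor() * 2.0f;
var grad = tape.gradient(y, v);  // d(2v)/dv == 2
```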
@@ -110,7 +110,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.6.2" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>
@@ -30,7 +30,7 @@ namespace Tensorflow
{
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);
public Tensor()
protected Tensor()
{
}
@@ -108,6 +108,7 @@ namespace Tensorflow
protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);
_id = ops.uid();
}
protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -116,6 +117,7 @@ namespace Tensorflow
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
else
_handle = TF_NewTensor(bytes, shape, dtype);
_id = ops.uid();
}
protected unsafe void InitTensor(Array array, Shape? shape = null)
@@ -166,6 +168,8 @@ namespace Tensorflow
string[] val => StringTensor(val, shape),
_ => throw new NotImplementedException("")
};
_id = ops.uid();
}
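Every construction path now stamps a unique id because the tape keys tensors by `long` id (see the `TensorTape` change above). A uid generator of this kind is typically just an atomic counter; the following is an illustrative sketch, not the library's `ops.uid()` implementation:

```csharp
using System.Threading;

// Sketch of a process-wide unique-id source: monotonically increasing,
// safe under concurrent tensor construction.
static class UidSketch
{
    static long _counter = 0;

    public static long Next() => Interlocked.Increment(ref _counter);
}
```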
unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged
@@ -462,6 +462,7 @@ namespace Tensorflow.Training.Saving.SavedModel
{
IEnumerable<ConcreteFunction> _concrete_functions;
FunctionSpec _function_spec;
public IEnumerable<ConcreteFunction> ConcreteFunctions => _concrete_functions;
public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec,
IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false)
{
@@ -25,6 +25,19 @@ namespace Tensorflow.Util
}
}
public Tv SetDefault(Tk key, Tv default_value)
{
if(TryGetValue(key, out var res))
{
return res;
}
else
{
base[key] = default_value;
return base[key];
}
}
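`SetDefault` mirrors Python's `dict.setdefault`: return the existing value if the key is present, otherwise store and return the supplied default. Usage example (the type arguments match the `TensorTape` alias introduced above):

```csharp
var map = new UnorderedMap<long, long>();
var a = map.SetDefault(42, -1);  // key absent: stores -1, returns -1
var b = map.SetDefault(42, 7);   // key present: returns -1, value unchanged
```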
public void push_back(Tk key, Tv value)
=> this[key] = value;
@@ -9,6 +9,7 @@ using System.Diagnostics;
using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel;
using OneOf;
using Tensorflow.Graphs;
namespace Tensorflow
{
@@ -193,6 +194,10 @@ namespace Tensorflow
/// </summary>
void variable_accessed(BaseResourceVariable variable)
{
if(ops.get_default_graph() is FuncGraph func_graph)
{
func_graph.watch_variable(variable as IVariableV1);
}
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
@@ -575,12 +575,8 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
throw new NotImplementedException();
// This implementation hasn't been fully verified yet.
// If it throws an exception in the future, please review it.
//var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
//return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
}
public static void dismantle_graph(Graph graph)
@@ -27,6 +27,7 @@ using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -50,7 +51,17 @@ namespace Tensorflow.Keras.Engine
/// the layer's weights.
/// </summary>
protected bool built;
public bool Built => built;
public bool Built
{
get
{
return built;
}
internal set
{
built = value;
}
}
public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType;
public bool AutoCast => args.Autocast;
@@ -179,6 +190,11 @@ namespace Tensorflow.Keras.Engine
}
protected List<ILayer> _self_tracked_trackables;
/// <summary>
/// If set, calls to this layer bypass the normal `Call` body and invoke this function directly.
/// </summary>
public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null;
public Layer(LayerArgs args)
{
Initialize(args);
@@ -259,6 +275,10 @@ namespace Tensorflow.Keras.Engine
/// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if(ReplacedCall is not null)
{
return ReplacedCall(inputs);
}
return inputs;
}
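A hedged sketch of what `ReplacedCall` enables: swapping a layer's forward pass at runtime, which the SavedModel loader below uses for restored layers. `layer` is assumed to be any `Layer` instance:

```csharp
// The lambda stands in for a restored forward pass; the loader wires in a
// RestoredFunction's Apply here. Layer.Call now short-circuits to
// ReplacedCall before its own body runs.
layer.ReplacedCall = inputs => inputs;
var outputs = layer.Apply(tf.ones((1, 4)));
```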
@@ -35,10 +35,6 @@ namespace Tensorflow.Keras.Engine
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
//foreach (var variable in TrainableVariables)
//{
// tape.watch(variable.Handle);
//}
var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred);
@@ -70,7 +66,7 @@ namespace Tensorflow.Keras.Engine
gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables));
gradients = optimizer.clip_gradients(gradients);
optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)),
optimizer.apply_gradients(zip(gradients, trainable_variables),
experimental_aggregate_gradients: false);
}
}
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Optimizers
_set_hyper("decay", args.InitialDecay);
}
public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
public void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
=> apply_gradients(new[] { grads_and_vars },
@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Optimizers
/// <param name="grads_and_vars"></param>
/// <param name="name"></param>
/// <param name="experimental_aggregate_gradients"></param>
public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
{
@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Optimizers
});
}
void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{
_resource_apply_dense(var, grad, apply_state);
// if var.constraint is not None:
@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Optimizers
throw new NotImplementedException("_resource_apply_dense");
}
void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name,
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
{
@@ -255,6 +255,25 @@ namespace Tensorflow.Keras.Saving
/// <param name="layers"></param>
private void _finalize_saved_model_layers(List<Layer> layers)
{
foreach(var layer in layers)
{
layer.Built = true;
var keras_attr = _get_keras_attr(layer);
if(keras_attr is not Trackable trackable)
{
continue;
}
if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call))
{
Debug.Assert(layer_call is RestoredFunction);
var concrete_functions = ((RestoredFunction)layer_call).ConcreteFunctions;
if (concrete_functions is not null && concrete_functions.Count() > 0)
{
layer.ReplacedCall = use_wrapped_call(layer, ((RestoredFunction)layer_call).Apply);
}
}
}
foreach(var layer in layers)
{
// TODO(Rinne): deal with `RevivedNetwork`.
@@ -265,6 +284,12 @@ namespace Tensorflow.Keras.Saving
}
}
private Func<Tensors, Tensors> use_wrapped_call(Layer layer, Func<Tensors, Tensors> call)
{
// TODO(Rinne): revise it.
return call;
}
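A hedged sketch of what a fuller `use_wrapped_call` might look like; the PR's version intentionally returns `call` unchanged per the TODO, so this is an assumption about a later revision, not the shipped behavior:

```csharp
// Assumption: a later revision wraps the restored call with SavedModel
// bookkeeping (e.g. splitting off conditional losses) before delegating.
Func<Tensors, Tensors> UseWrappedCallSketch(Layer layer, Func<Tensors, Tensors> call)
{
    return inputs =>
    {
        var outputs = call(inputs);
        // room for layer-specific post-processing of restored outputs
        return outputs;
    };
}
```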
private void _restore_layer_unconditional_losses(Layer layer)
{
// TODO(Rinne): implement it.
@@ -85,16 +85,16 @@ namespace Tensorflow.Keras.Saving.SavedModel
return _config;
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
{
return base.Call(inputs, state, training);
}
else
{
return (func as Function).Apply(inputs);
}
}
//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
//{
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
// {
// return base.Call(inputs, state, training);
// }
// else
// {
// return (func as Function).Apply(inputs);
// }
//}
}
}
@@ -223,7 +223,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
//base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
// functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}),
functions.Concat(new string[] { }))
functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }))
{
}
@@ -64,23 +64,19 @@ public class SequentialModelLoad
var model = tf.keras.models.load_model(@"Assets/python_func_model");
model.summary();
var x = tf.random.uniform((8, 784), -1, 1);
var y = model.Apply(x);
Console.WriteLine(y);
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
//model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
//var data_loader = new MnistModelLoader();
//var num_epochs = 1;
//var batch_size = 8;
var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 8;
//var dataset = data_loader.LoadAsync(new ModelLoadSetting
//{
// TrainDir = "mnist",
// OneHot = false,
// ValidationSize = 58000,
//}).Result;
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 55000,
}).Result;
//model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}
}