diff --git a/docs/assets/performance-comparison.jpg b/docs/assets/performance-comparison.jpg
new file mode 100644
index 00000000..382f7ab6
Binary files /dev/null and b/docs/assets/performance-comparison.jpg differ
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs b/src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs
new file mode 100644
index 00000000..f4908c71
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs
@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Eager;
+using static Tensorflow.tensorflow;
+
+namespace Tensorflow.Gradients
+{
+    public partial class Tape
+    {
+        public Tensor[] CallBackwardFunction(BackwardFunction backward_function,
+            List<long> unneeded_gradients,
+            List<Tensor> output_gradients)
+        {
+            // Dispatch to the op's registered backward function; it maps the
+            // upstream output gradients to one gradient per op input.
+            var result = backward_function(output_gradients.ToArray(),
+                unneeded_gradients.ToArray());
+
+            return result;
+        }
+    }
+}
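Reviewer note: `BackwardFunction` is the per-op gradient callback the tape stores (pulled in via `using static Tensorflow.tensorflow;`); it receives the upstream gradients plus the indices of inputs whose gradients are not needed. A minimal sketch for z = x * y follows; the captured `x` and `y` and the Tensor operator overloads are illustrative assumptions, not part of this diff.

```csharp
// Sketch of a backward function for z = x * y, matching the delegate shape
// CallBackwardFunction invokes: (Tensor[] grads, long[] unneeded) => Tensor[].
// `x` and `y` stand for the captured forward inputs and are hypothetical;
// Contains needs System.Linq.
BackwardFunction mul_grad = (output_grads, unneeded_gradients) =>
{
    var grad = output_grads[0];  // upstream gradient dL/dz
    return new Tensor[]
    {
        unneeded_gradients.Contains(0) ? null : grad * y, // dL/dx = dL/dz * y
        unneeded_gradients.Contains(1) ? null : grad * x  // dL/dy = dL/dz * x
    };
};
```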
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
new file mode 100644
index 00000000..94e0d3ee
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
@@ -0,0 +1,249 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Util;
+using static Tensorflow.tensorflow;
+
+namespace Tensorflow.Gradients
+{
+    public partial class Tape
+    {
+        // Thresholds mirroring the C++ tape's early-aggregation heuristic.
+        int kMinAggregateCount = 4;
+        int kMinAggregateBytes = 128 * 1024 * 1024;
+
+        public Tensor[] ComputeGradient(long[] target_tensor_ids,
+            long[] source_tensor_ids,
+            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+            Tensor[] output_gradients)
+        {
+            var result = new List<Tensor>(source_tensor_ids.Length);
+            var sources_set = new UnorderedSet<long>(source_tensor_ids);
+            var gradients_size = new UnorderedMap<long, long>();
+
+            var state = PrepareBackprop(
+                target_tensor_ids, tensor_tape_, op_tape_, sources_set, persistent_);
+            var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
+            var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
+                output_gradients,
+                tensor_tape_,
+                state.op_tape);
+
+            while (op_stack.Count > 0)
+            {
+                var op = op_stack.Dequeue();
+                if (!state.op_tape.find(op, out var trace))
+                    continue;
+
+                state.op_tape.erase(op);
+
+                var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
+                var unneeded_gradients = new List<long>();
+                for (int i = 0; i < trace.input_tensor_id.Length; i++)
+                {
+                    var in_tensor_id = trace.input_tensor_id[i];
+                    if (!tensor_tape_.find(in_tensor_id) &&
+                        !sources_set.find(in_tensor_id))
+                        unneeded_gradients.Add(i);
+                }
+
+                bool any_gradient_nonzero = false;
+                var zero_indices = new List<int>();
+                for (int i = 0; i < trace.output_tensor_info.Length; ++i)
+                {
+                    var id = trace.output_tensor_info[i].GetID();
+                    if (!gradients.find(id, out var grad_it))
+                    {
+                        throw new NotImplementedException("FunctionsAcceptingNoneForIndicesMap");
+                    }
+                    else
+                    {
+                        any_gradient_nonzero = true;
+                        var new_gradients = grad_it.Count == 1 ?
+                            grad_it[0] :
+                            gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients
+
+                        if (!sources_set.find(id))
+                            gradients.Remove(id);
+                        else
+                        {
+                            grad_it.Clear();
+                            grad_it.Add(new_gradients);
+                            // vspace.MarkAsResult(new_gradients);
+                        }
+                        out_gradients.Add(new_gradients);
+                    }
+                }
+
+                Tensor[] in_gradients;
+                if (any_gradient_nonzero)
+                {
+                    foreach (var i in zero_indices)
+                        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
+
+                    in_gradients = CallBackwardFunction(trace.backward_function,
+                        unneeded_gradients,
+                        out_gradients);
+
+                    if (in_gradients.Length != trace.input_tensor_id.Length)
+                        throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Length}");
+                    if (!persistent_)
+                    {
+                        // trace.backward_function_deleter(trace.backward_function);
+                    }
+                }
+                else
+                {
+                    throw new NotImplementedException("Zero-gradient propagation is not implemented.");
+                }
+
+                for (int i = 0; i < in_gradients.Length; ++i)
+                {
+                    var id = trace.input_tensor_id[i];
+                    if (in_gradients[i] != null)
+                    {
+                        var unaggregated_grads = gradients[id];
+                        unaggregated_grads.Add(in_gradients[i]);
+                        if (unaggregated_grads.Count > kMinAggregateCount)
+                        {
+                            // TODO: aggregate eagerly once the pending list exceeds
+                            // kMinAggregateCount and kMinAggregateBytes, as the C++
+                            // tape does, to bound peak memory.
+                            throw new NotImplementedException("Early gradient aggregation is not implemented.");
+                        }
+                    }
+
+                    if (!state.tensor_usage_counts.find(id))
+                        continue;
+
+                    state.tensor_usage_counts[id]--;
+                    if (state.tensor_usage_counts[id] > 0)
+                        continue;
+
+                    if (!tensor_tape_.find(id, out var tape_it))
+                    {
+                        if (gradients.find(id, out var grad_it))
+                        {
+                            // foreach (var g in grad_it)
+                            //     DeleteGradient(g);
+                            gradients.erase(id);
+                        }
+                        continue;
+                    }
+
+                    var op_id = tape_it;
+                    if (op_id == -1)
+                        continue;
+
+                    if (state.op_missing_tensor.find(op_id, out var missing_it))
+                    {
+                        state.op_missing_tensor[op_id]--;
+                        if (state.op_missing_tensor[op_id] == 0)
+                            op_stack.Enqueue(op_id);
+                    }
+                }
+            }
+
+            if (state.op_tape.Count > 0)
+                throw new RuntimeError("Invalid tape state.");
+
+            var used_gradient_ids = new List<long>(source_tensor_ids.Length);
+            foreach (var id in source_tensor_ids)
+            {
+                if (!gradients.find(id, out var grad_it))
+                    result.Add(null);
+                else
+                {
+                    if (grad_it.Count > 1)
+                    {
+                        var grad = gen_math_ops.add_n(grad_it.ToArray());
+                        grad_it.Clear();
+                        grad_it.Add(grad);
+                    }
+                    result.Add(grad_it[0]);
+                    used_gradient_ids.Add(id);
+                }
+            }
+
+            /*foreach (var grad_pair in gradients)
+            {
+                if (!used_gradient_ids.Contains(grad_pair.Key))
+                {
+                    foreach (var g in grad_pair.Value)
+                    {
+                        vspace.DeleteGradient(g);
+                    }
+                }
+            }*/
+
+            return result.ToArray();
+        }
+
+        UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
+            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+            Tensor[] output_gradients,
+            TensorTape tensor_tape,
+            OpTape<BackwardFunction, TapeTensor> op_tape)
+        {
+            var result = new UnorderedMapEnumerable<long, List<Tensor>>();
+            for (int i = 0; i < target_tensor_ids.Length; ++i)
+            {
+                var id = target_tensor_ids[i];
+                if (output_gradients.Length == 0 || output_gradients[i] == null)
+                {
+                    if (tensor_tape.find(id, out var tensor_id) && tensor_id != -1)
+                    {
+                        if (!op_tape.find(tensor_tape[id], out var op_it))
+                            throw new RuntimeError("Internal state of the gradient tape is invalid: " +
+                                "failed to find operation producing a tensor");
+                        bool found = false;
+                        for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
+                        {
+                            if (op_it.output_tensor_info[j].GetID() == id)
+                            {
+                                found = true;
+                                var ones = op_it.output_tensor_info[j].OnesLike();
+                                result[id].Add(ones);
+                                break;
+                            }
+                        }
+
+                        if (!found)
+                        {
+                            throw new ValueError("Internal state of the gradient tape is invalid: " +
+                                "none of operations outputs match expected tensor");
+                        }
+                    }
+                    else
+                    {
+                        if (sources_that_are_targets.find(id, out var source_tensor))
+                            result[id].Add(source_tensor.OnesLike());
+                    }
+                }
+                else
+                {
+                    result[id].Add(output_gradients[i]);
+                }
+            }
+
+            return result;
+        }
+
+        Queue<long> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape,
+            UnorderedMap<long, long> op_missing_tensor)
+        {
+            var result = new Queue<long>();
+            foreach (var op_entry in op_tape)
+            {
+                if (!op_missing_tensor.find(op_entry.Key))
+                    result.Enqueue(op_entry.Key);
+            }
+            return result;
+        }
+    }
+}
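Reviewer note: for orientation, a hypothetical end-to-end sketch of how user code reaches `ComputeGradient`. The public surface names (`tf.Variable`, `tf.GradientTape()`, `tape.gradient`) follow the Python-style API TensorFlow.NET exposes and are assumptions at this commit, not code in this diff.

```csharp
// Hypothetical driver for ComputeGradient; API names are assumptions.
var w = tf.Variable(3.0f);
using (var tape = tf.GradientTape())
{
    var loss = w * w;                  // taped via RecordOperation
    var grad = tape.gradient(loss, w); // resolved through ComputeGradient
    // d(w^2)/dw = 2w, so grad evaluates to 6.0
}
```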
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
new file mode 100644
index 00000000..55e6d7e4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
@@ -0,0 +1,72 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Util;
+using static Tensorflow.tensorflow;
+
+namespace Tensorflow.Gradients
+{
+    public partial class Tape
+    {
+        public BackpropInitialState PrepareBackprop(long[] target,
+            TensorTape tensor_tape,
+            OpTape<BackwardFunction, TapeTensor> op_tape,
+            UnorderedSet<long> sources_set,
+            bool persistent_tape)
+        {
+            BackpropInitialState result = new BackpropInitialState();
+            var tensor_stack = new Queue<long>(target);
+            while (tensor_stack.Count > 0)
+            {
+                var tensor_id = tensor_stack.Dequeue();
+
+                if (!tensor_tape.find(tensor_id, out var op_id))
+                    continue;
+
+                if (op_id == -1 ||
+                    !op_tape.find(op_id, out var op_it) ||
+                    result.op_tape.find(op_id, out var result_op_it))
+                    continue;
+
+                result.op_tape.emplace(op_id, op_it);
+
+                foreach (var it in op_it.input_tensor_id)
+                {
+                    if (result.tensor_usage_counts.find(it))
+                        result.tensor_usage_counts[it]++;
+                    else
+                    {
+                        result.tensor_usage_counts[it] = 1;
+                        if (tensor_tape.find(it))
+                            tensor_stack.Enqueue(it);
+                    }
+                }
+
+                if (!persistent_tape)
+                    op_tape.Remove(op_id);
+            }
+
+            foreach (var pair in result.tensor_usage_counts)
+            {
+                if (tensor_tape.find(pair.Key, out var it) && it != -1)
+                    result.op_missing_tensor[it] += 1;
+            }
+
+            if (!persistent_tape)
+            {
+                // Call destructors for all unneeded gradient functions and
+                // clear the op_tape. We can clear the tape because ownership of
+                // backward functions that will be used for gradient computation
+                // has been transferred to `result`.
+                /*for (const auto& op_pair : *op_tape) {
+                    op_pair.second.backward_function_deleter(
+                        op_pair.second.backward_function);
+                }*/
+                op_tape.Clear();
+            }
+
+            return result;
+        }
+    }
+}
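Reviewer note: to make the bookkeeping concrete, here is a worked trace of `PrepareBackprop` plus `InitialStack` for a two-op program. All tensor and op ids are invented for illustration.

```csharp
// Suppose the tape recorded, with hypothetical ids:
//   t = Square(x)   // op 1: input id 10 (x, watched source), output id 11 (t)
//   y = Exp(t)      // op 2: input id 11 (t),                 output id 12 (y)
// PrepareBackprop(target: [12], ...) walks back from y and yields:
//   result.op_tape             == { 2: Exp entry, 1: Square entry }
//   result.tensor_usage_counts == { 11: 1, 10: 1 }
//   result.op_missing_tensor   == { 1: 1 }  // op 1 still awaits t's gradient
//   (id 10 maps to -1 in tensor_tape because x is a watched source,
//    so it contributes no op_missing_tensor entry)
// InitialStack then seeds the ready queue with op 2 only; op 1 becomes
// ready inside ComputeGradient once Exp's backward function supplies dy/dt.
```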
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
new file mode 100644
index 00000000..79595112
--- /dev/null
+++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
@@ -0,0 +1,51 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Util;
+using static Tensorflow.tensorflow;
+
+namespace Tensorflow.Gradients
+{
+    public partial class Tape
+    {
+        long next_op_id_ = 0;
+        UnorderedMap<long, long> tensor_usage_;
+
+        public void RecordOperation(string op_type,
+            Tensor[] input_tensors,
+            TapeTensor[] output_tensors,
+            long[] input_tensor_id,
+            TF_DataType[] input_dtypes,
+            Func<BackwardFunction> backward_function_getter)
+        {
+            if (!ShouldRecord(input_tensor_id, input_dtypes))
+            {
+                return;
+            }
+
+            long op_id = next_op_id_++;
+            var ids = new List<long>(input_tensor_id.Length);
+            foreach (var i in input_tensor_id)
+            {
+                tensor_usage_[i]++;
+                ids.Add(i);
+            }
+
+            var tensors = new List<TapeTensor>(output_tensors.Length);
+            foreach (var o in output_tensors)
+            {
+                tensor_tape_[o.GetID()] = op_id;
+                tensor_usage_[o.GetID()] = 1;
+                tensors.Add(o);
+            }
+
+            op_tape_[op_id] = new OpTapeEntry<BackwardFunction, TapeTensor>
+            {
+                op_type = op_type,
+                output_tensor_info = tensors.ToArray(),
+                input_tensor_id = ids.ToArray(),
+                backward_function = backward_function_getter()
+            };
+        }
+    }
+}
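Reviewer note: a hypothetical call site for `RecordOperation`, as the eager runtime would invoke it after executing an op. The `TapeTensor` constructor shape and the `Id`/`dtype` accessors on `x`, `y`, `z` are assumptions for illustration.

```csharp
// Hypothetical recording of z = x + y; for addition the upstream gradient
// flows through unchanged to both inputs.
BackwardFunction add_grad = (grads, unneeded) => new[] { grads[0], grads[0] };
tape.RecordOperation("AddV2",
    new[] { x, y },                  // forward input tensors
    new[] { new TapeTensor(z) },     // forward outputs (ctor shape assumed)
    new[] { x.Id, y.Id },            // input tensor ids (accessor assumed)
    new[] { x.dtype, y.dtype },      // checked by ShouldRecord
    () => add_grad);                 // getter runs only when the op is taped
```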
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs
index 663b3ef2..af6134c7 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.cs
@@ -1,50 +1,112 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Tensorflow.Util;
+using static Tensorflow.Binding;
+using static Tensorflow.tensorflow;
 
 namespace Tensorflow.Gradients
 {
-    public class Tape : ITape
+    public partial class Tape : ITape
     {
+        int nesting_id;
+        static int tape_nesting_id_counter = 0;
+        bool persistent_;
+        bool watch_accessed_variables;
+        TensorTape tensor_tape_;
+        OpTape<BackwardFunction, TapeTensor> op_tape_;
+
+        /// <summary>
+        /// A deque-backed stack, whose element references are not invalidated by
+        /// pushes and pops at the back.
+        /// </summary>
+        Stack<AccumulatorCallState> call_state_;
+
         public Tape(bool persistent, bool watch_accessed_variables)
         {
+            this.persistent_ = persistent;
+            this.watch_accessed_variables = watch_accessed_variables;
 
-        }
+            tensor_tape_ = new TensorTape();
+            op_tape_ = new OpTape<BackwardFunction, TapeTensor>();
+            tensor_usage_ = new UnorderedMap<long, long>();
 
-        public Tensor[] ComputeGradient(long[] target_tensor_ids, long[] source_tensor_ids, UnorderedMap<long, TapeTensor> sources_that_are_targets, Tensor[] output_gradients)
-        {
-            throw new NotImplementedException();
+            nesting_id = ++tape_nesting_id_counter;
+            tf.GetTapeSet().Add(this);
         }
 
-        public void PopTape(ITape tape)
+        /// <summary>
+        /// Marks this tensor to be watched by the given tape.
+        /// </summary>
+        /// <param name="tensor_id"></param>
+        public void Watch(long tensor_id)
         {
-            throw new NotImplementedException();
+            if (!CouldBackprop())
+                return;
+
+            tensor_tape_.emplace(tensor_id, -1);
         }
 
-        public void RecordOperation(string op_type, Tensor[] input_tensors, TapeTensor[] output_tensors, long[] input_tensor_id, TF_DataType[] input_dtypes, Func<BackwardFunction> backward_function_getter)
+        public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes)
         {
-            throw new NotImplementedException();
+            for (int i = 0; i < tensor_ids.Length; ++i)
+            {
+                if (tensor_tape_.find(tensor_ids[i]))
+                    if (IsDtypeTrainable(dtypes[i]))
+                        return true;
+            }
+            return false;
         }
 
-        public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes)
+        /// <summary>
+        /// Pops the given tape from the stack.
+        /// </summary>
+        /// <param name="tape"></param>
+        public void PopTape(ITape tape)
         {
-            throw new NotImplementedException();
+            tf.GetTapeSet().Remove(tape);
         }
 
         public void VariableAccessed(ResourceVariable variable)
         {
-            throw new NotImplementedException();
+            Watch(variable.Handle.Id);
         }
 
-        public void Watch(long tensor_id)
+        public ResourceVariable[] WatchedVariables()
         {
-            throw new NotImplementedException();
+            return null;
         }
 
-        public ResourceVariable[] WatchedVariables()
+        public bool IsDtypeTrainable(TF_DataType dtype)
         {
-            throw new NotImplementedException();
+            switch (dtype)
+            {
+                case TF_DataType.TF_HALF:
+                case TF_DataType.TF_BFLOAT16:
+                case TF_DataType.TF_FLOAT:
+                case TF_DataType.TF_DOUBLE:
+                case TF_DataType.TF_COMPLEX64:
+                case TF_DataType.TF_COMPLEX128:
+                case TF_DataType.TF_RESOURCE:
+                case TF_DataType.TF_VARIANT:
+                    return true;
+                default:
+                    return false;
+            }
         }
+
+        bool CouldForwardprop()
+            => HasAccumulator();
+
+        bool CouldBackprop()
+            => HasGradientTape();
+
+        bool HasAccumulator()
+            //return !GetAccumulatorSet()->empty();
+            => false;
+
+        bool HasGradientTape()
+            => tf.GetTapeSet().Count > 0;
     }
 }
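Reviewer note: a short walkthrough of the recording policy `Tape.cs` now encodes: an op is taped only if at least one input is already watched and that input's dtype is trainable per `IsDtypeTrainable`. Tensor ids below are invented.

```csharp
// Hypothetical walkthrough of Watch + ShouldRecord (ids invented).
var tape = new Tape(persistent: false, watch_accessed_variables: true);
tape.Watch(42); // pretend 42 is the id of a float tensor
tape.ShouldRecord(new long[] { 42 }, new[] { TF_DataType.TF_FLOAT }); // true
tape.ShouldRecord(new long[] { 42 }, new[] { TF_DataType.TF_INT32 }); // false: int32 not trainable
tape.ShouldRecord(new long[] { 7 },  new[] { TF_DataType.TF_FLOAT }); // false: id 7 never watched
```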