
Implement Tape.

tags/v0.20
Oceania2018, 5 years ago
commit 9dbd51bd26
6 changed files with 472 additions and 16 deletions
  1. BIN      docs/assets/performance-comparison.jpg
  2. +22 -0   src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs
  3. +249 -0  src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
  4. +72 -0   src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
  5. +51 -0   src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
  6. +78 -16  src/TensorFlowNET.Core/Gradients/Tape.cs

BIN  docs/assets/performance-comparison.jpg

Width: 666  |  Height: 1525  |  Size: 213 kB

+22 -0  src/TensorFlowNET.Core/Gradients/Tape.CallBackwardFunction.cs

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.tensorflow;

namespace Tensorflow.Gradients
{
    public partial class Tape
    {
        public Tensor[] CallBackwardFunction(BackwardFunction backward_function,
            List<long> unneeded_gradients,
            List<Tensor> output_gradients)
        {
            var grads = new Tensor[output_gradients.Count];
            var result = backward_function(output_gradients.ToArray(),
                unneeded_gradients.ToArray());

            return result;
        }
    }
}
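
For reference, a minimal sketch of the delegate shape this call site implies. The shape is inferred from the arguments above, and the pass-through AddGrad is a hypothetical gradient for z = x + y, not one of the library's registered gradients.

using Tensorflow;

// Shape inferred from the call site above: upstream output gradients in,
// plus the indices of inputs whose gradients may be skipped; one gradient
// per op input out. (The library declares its own delegate inside the
// tensorflow class, hence `using static Tensorflow.tensorflow` above.)
public delegate Tensor[] BackwardFunction(Tensor[] output_grads, long[] unneeded_gradients);

public static class BackwardFunctionSketch
{
    // Hypothetical backward function for z = x + y: addition routes the
    // upstream gradient to both inputs unchanged.
    public static readonly BackwardFunction AddGrad =
        (output_grads, unneeded_gradients) =>
            new Tensor[] { output_grads[0], output_grads[0] };
}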

+249 -0  src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs

@@ -0,0 +1,249 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.tensorflow;

namespace Tensorflow.Gradients
{
    public partial class Tape
    {
        int kMinAggregateCount = 4;
        int kMinAggregateBytes = 128 * 1024 * 1024;

        public Tensor[] ComputeGradient(long[] target_tensor_ids,
            long[] source_tensor_ids,
            UnorderedMap<long, TapeTensor> sources_that_are_targets,
            Tensor[] output_gradients)
        {
            var result = new List<Tensor>(source_tensor_ids.Length);
            var sources_set = new UnorderedSet<long>(source_tensor_ids);
            var gradients_size = new UnorderedMap<long, long>();

            var state = PrepareBackprop(
                target_tensor_ids, tensor_tape_, op_tape_, sources_set, persistent_);
            var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
            var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
                output_gradients,
                tensor_tape_,
                state.op_tape);

            while (op_stack.Count > 0)
            {
                var op = op_stack.Dequeue();
                if (!state.op_tape.find(op, out var trace))
                    continue;

                state.op_tape.erase(op);

                var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
                var unneeded_gradients = new List<long>();
                for (int i = 0; i < trace.input_tensor_id.Length; i++)
                {
                    var in_tensor_id = trace.input_tensor_id[i];
                    if (!tensor_tape_.find(in_tensor_id) &&
                        !sources_set.find(in_tensor_id))
                        unneeded_gradients.Add(i);
                }

                bool any_gradient_nonzero = false;
                var zero_indices = new List<int>();
                // Aggregate the accumulated contributions for each output of this op.
                for (int i = 0; i < trace.output_tensor_info.Length; ++i)
                {
                    var id = trace.output_tensor_info[i].GetID();
                    if (!gradients.find(id, out var grad_it))
                    {
                        throw new NotImplementedException("FunctionsAcceptingNoneForIndicesMap");
                    }
                    else
                    {
                        any_gradient_nonzero = true;
                        var new_gradients = grad_it.Count == 1 ?
                            grad_it[0] :
                            gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients

                        if (!sources_set.find(id))
                            gradients.Remove(id);
                        else
                        {
                            grad_it.Clear();
                            grad_it.Add(new_gradients);
                            // vspace.MarkAsResult(new_gradients);
                        }
                        out_gradients.Add(new_gradients);
                    }
                }

                Tensor[] in_gradients;
                if (any_gradient_nonzero)
                {
                    foreach (var i in zero_indices)
                        out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

                    // Propagate the output gradients to the op's inputs.
                    in_gradients = CallBackwardFunction(trace.backward_function,
                        unneeded_gradients,
                        out_gradients);

                    if (in_gradients.Count() != trace.input_tensor_id.Count())
                        throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
                    if (!persistent_)
                    {
                        // trace.backward_function_deleter(trace.backward_function);
                    }
                }
                else
                {
                    throw new NotImplementedException("");
                }

                // Accumulate per-input gradients, and schedule a producer op
                // once gradients for all of its consumed outputs are available.
                for (int i = 0; i < in_gradients.Length; ++i)
                {
                    var id = trace.input_tensor_id[i];
                    if (in_gradients[i] != null)
                    {
                        var unaggregated_grads = gradients[id];
                        unaggregated_grads.Add(in_gradients[i]);
                        if (unaggregated_grads.Count > kMinAggregateCount)
                        {
                            if (!gradients_size.ContainsKey(id))
                            {
                            }
                            else
                            {
                            }

                            throw new NotImplementedException("");
                        }
                    }

                    if (!state.tensor_usage_counts.find(id))
                        continue;

                    state.tensor_usage_counts[id]--;
                    if (state.tensor_usage_counts[id] > 0)
                        continue;

                    if (!tensor_tape_.find(id, out var tape_it))
                    {
                        if (gradients.find(id, out var grad_it))
                        {
                            // foreach (var g in grad_it)
                            //     DeleteGradient(g);
                            gradients.erase(id);
                        }
                        continue;
                    }

                    var op_id = tape_it;
                    if (op_id == -1)
                        continue;

                    if (state.op_missing_tensor.find(op_id, out var missing_it))
                    {
                        state.op_missing_tensor[op_id]--;
                        if (state.op_missing_tensor[op_id] == 0)
                            op_stack.Enqueue(op_id);
                    }
                }
            }

            if (state.op_tape.Count > 0)
                throw new RuntimeError("Invalid tape state.");

            // Collect the final (aggregated) gradient for each requested source.
            var used_gradient_ids = new List<long>(source_tensor_ids.Length);
            foreach (var id in source_tensor_ids)
            {
                if (!gradients.find(id, out var grad_it))
                    result.Add(null);
                else
                {
                    if (grad_it.Count > 1)
                    {
                        var grad = gen_math_ops.add_n(grad_it.ToArray());
                        grad_it.Clear();
                        grad_it.Add(grad);
                    }
                    result.Add(grad_it[0]);
                    used_gradient_ids.Add(id);
                }
            }

            /*foreach (var grad_pair in gradients)
            {
                if (!used_gradient_ids.Contains(grad_pair.Key))
                {
                    foreach (var g in grad_pair.Value)
                    {
                        vspace.DeleteGradient(g);
                    }
                }
            }*/

            return result.ToArray();
        }

        UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
            UnorderedMap<long, TapeTensor> sources_that_are_targets,
            Tensor[] output_gradients,
            TensorTape tensor_tape,
            OpTape<BackwardFunction, TapeTensor> op_tape)
        {
            var result = new UnorderedMapEnumerable<long, List<Tensor>>();
            for (int i = 0; i < target_tensor_ids.Length; ++i)
            {
                var id = target_tensor_ids[i];
                if (output_gradients.Length == 0 || output_gradients[i] == null)
                {
                    if (tensor_tape.find(id, out var tensor_id) && tensor_id != -1)
                    {
                        if (!op_tape.find(tensor_tape[id], out var op_it))
                            throw new RuntimeError("Internal state of the gradient tape is invalid: " +
                                "failed to find operation producing a tensor");
                        bool found = false;
                        for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
                        {
                            if (op_it.output_tensor_info[j].GetID() == id)
                            {
                                found = true;
                                var ones = op_it.output_tensor_info[j].OnesLike();
                                result[id].Add(ones);
                                break;
                            }
                        }

                        if (!found)
                        {
                            throw new ValueError("Internal state of the gradient tape is invalid: " +
                                "none of operations outputs match expected tensor");
                        }
                    }
                    else
                    {
                        if (sources_that_are_targets.find(id, out var source_tensor))
                            result[id].Add(source_tensor.OnesLike());
                    }
                }
                else
                {
                    result[id].Add(output_gradients[i]);
                }
            }

            return result;
        }

        Queue<long> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape,
            UnorderedMap<long, long> op_missing_tensor)
        {
            var result = new Queue<long>();
            foreach (var op_entry in op_tape)
            {
                if (!op_missing_tensor.find(op_entry.Key))
                    result.Enqueue(op_entry.Key);
            }
            return result;
        }
    }
}
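
To make the traversal concrete, here is a self-contained sketch of the same reverse walk on plain doubles (hypothetical tensor ids, no TensorFlow types): seed the target's gradient with one, process each recorded op once its output gradients are ready, and sum per-input contributions the way the unaggregated_grads list does.

using System;
using System.Collections.Generic;

class ToyBackprop
{
    static void Main()
    {
        // Forward pass: a = x + x; y = a * x, i.e. y = 2x^2, so dy/dx = 4x.
        double x = 3.0;
        double a = x + x;

        // grads[id] accumulates gradient contributions per tensor id,
        // like the `gradients` map in ComputeGradient.
        var grads = new Dictionary<string, double> { ["y"] = 1.0 }; // seeded with ones

        // Ops are processed in reverse once all their output gradients
        // are available, as ComputeGradient's op_stack does.
        // op1: y = a * x  =>  dy/da = x, dy/dx = a
        double gy = grads["y"];
        Accumulate(grads, "a", gy * x);
        Accumulate(grads, "x", gy * a);
        // op0: a = x + x  =>  da/dx = 1 through each use of x
        double ga = grads["a"];
        Accumulate(grads, "x", ga);
        Accumulate(grads, "x", ga);

        Console.WriteLine(grads["x"]); // 12, i.e. 4x at x = 3
    }

    // Mirrors the unaggregated_grads accumulation: a tensor consumed in
    // several places sums the gradient contributions from each consumer.
    static void Accumulate(Dictionary<string, double> grads, string id, double g)
        => grads[id] = grads.TryGetValue(id, out var cur) ? cur + g : g;
}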

+72 -0  src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs

@@ -0,0 +1,72 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.tensorflow;

namespace Tensorflow.Gradients
{
    public partial class Tape
    {
        public BackpropInitialState PrepareBackprop(long[] target,
            TensorTape tensor_tape,
            OpTape<BackwardFunction, TapeTensor> op_tape,
            UnorderedSet<long> sources_set,
            bool persistent_tape)
        {
            var result = new BackpropInitialState();
            var tensor_stack = new Queue<long>(target);
            while (tensor_stack.Count > 0)
            {
                var tensor_id = tensor_stack.Dequeue();

                if (!tensor_tape.find(tensor_id, out var op_id))
                    continue;

                if (op_id == -1 ||
                    !op_tape.find(op_id, out var op_it) ||
                    result.op_tape.find(op_id, out var result_op_it))
                    continue;

                result.op_tape.emplace(op_id, op_it);

                foreach (var it in op_it.input_tensor_id)
                {
                    if (result.tensor_usage_counts.find(it))
                        result.tensor_usage_counts[it]++;
                    else
                    {
                        result.tensor_usage_counts[it] = 1;
                        if (tensor_tape.find(it))
                            tensor_stack.Enqueue(it);
                    }
                }

                if (!persistent_tape)
                    op_tape.Remove(op_id);
            }

            foreach (var pair in result.tensor_usage_counts)
            {
                if (tensor_tape.find(pair.Key, out var it) && it != -1)
                    result.op_missing_tensor[it] += 1;
            }

            if (!persistent_tape)
            {
                // Call destructors for all unneeded gradient functions and
                // clear the op_tape. We can clear the tape because ownership of
                // backward functions that will be used for gradient computation
                // has been transferred to `result`.
                /*for (const auto& op_pair : *op_tape) {
                    op_pair.second.backward_function_deleter(
                        op_pair.second.backward_function);
                }*/
                op_tape.Clear();
            }

            return result;
        }
    }
}
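
A hedged illustration of what tensor_usage_counts and op_missing_tensor end up holding for a two-op chain, using plain dictionaries and hypothetical ids in place of the tape types:

using System;
using System.Collections.Generic;

class PrepareBackpropSketch
{
    static void Main()
    {
        // Toy tape: op0 produces t1 from t0 (a watched source); op1 produces t2 from t1.
        // tensorTape maps tensor id -> producing op id, with -1 marking a watched source.
        var tensorTape = new Dictionary<long, long> { [0] = -1, [1] = 0, [2] = 1 };
        var opInputs = new Dictionary<long, long[]> { [0] = new[] { 0L }, [1] = new[] { 1L } };

        var usage = new Dictionary<long, long>();    // tensor_usage_counts
        var missing = new Dictionary<long, long>();  // op_missing_tensor
        var seenOps = new HashSet<long>();
        var stack = new Queue<long>(new[] { 2L });   // walk back from target t2

        while (stack.Count > 0)
        {
            var t = stack.Dequeue();
            if (!tensorTape.TryGetValue(t, out var op) || op == -1 || !seenOps.Add(op))
                continue;
            foreach (var input in opInputs[op])
            {
                // First visit enqueues the input so its producer is visited too.
                usage.TryGetValue(input, out var c);
                usage[input] = c + 1;
                if (c == 0 && tensorTape.ContainsKey(input))
                    stack.Enqueue(input);
            }
        }

        // An op owes one gradient per consumed tensor it produces before
        // ComputeGradient's op_stack may schedule it.
        foreach (var pair in usage)
            if (tensorTape.TryGetValue(pair.Key, out var producer) && producer != -1)
            {
                missing.TryGetValue(producer, out var m);
                missing[producer] = m + 1;
            }

        Console.WriteLine($"usage[t1]={usage[1]}, missing[op0]={missing[0]}"); // both 1
    }
}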

+51 -0  src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs

@@ -0,0 +1,51 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.tensorflow;

namespace Tensorflow.Gradients
{
    public partial class Tape
    {
        long next_op_id_ = 0;
        UnorderedMap<long, long> tensor_usage_;

        public void RecordOperation(string op_type,
            Tensor[] input_tensors,
            TapeTensor[] output_tensors,
            long[] input_tensor_id,
            TF_DataType[] input_dtypes,
            Func<BackwardFunction> backward_function_getter)
        {
            if (!ShouldRecord(input_tensor_id, input_dtypes))
            {
                return;
            }

            long op_id = next_op_id_++;
            var ids = new List<long>(input_tensor_id.Length);
            foreach (var i in input_tensor_id)
            {
                tensor_usage_[i]++;
                ids.Add(i);
            }

            var tensors = new List<TapeTensor>(output_tensors.Length);
            foreach (var o in output_tensors)
            {
                tensor_tape_[o.GetID()] = op_id;
                tensor_usage_[o.GetID()] = 1;
                tensors.Add(o);
            }

            op_tape_[op_id] = new OpTapeEntry<BackwardFunction, TapeTensor>
            {
                op_type = op_type,
                output_tensor_info = tensors.ToArray(),
                input_tensor_id = ids.ToArray(),
                backward_function = backward_function_getter()
            };
        }
    }
}
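
For orientation, a hedged sketch of the call an eager op might make into the tape. Here x, y, z are assumed in-scope Tensors, the TapeTensor construction is hypothetical, and the pass-through lambda stands in for the registered gradient of AddV2; only when ShouldRecord confirms recording does the getter build the backward closure.

// Hypothetical call site for z = x + y.
tape.RecordOperation(
    op_type: "AddV2",
    input_tensors: new[] { x, y },
    output_tensors: new[] { new TapeTensor(z) },      // assumed constructor
    input_tensor_id: new[] { x.Id, y.Id },
    input_dtypes: new[] { x.dtype, y.dtype },
    backward_function_getter: () => (out_grads, unneeded) =>
        // d(x + y)/dx = d(x + y)/dy = 1: pass the upstream gradient through.
        new Tensor[] { out_grads[0], out_grads[0] });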

+78 -16  src/TensorFlowNET.Core/Gradients/Tape.cs

@@ -1,50 +1,112 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Util;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;

namespace Tensorflow.Gradients
{
    public partial class Tape : ITape
    {
        int nesting_id;
        static int tape_nesting_id_counter = 0;
        bool persistent_;
        bool watch_accessed_variables;
        TensorTape tensor_tape_;
        OpTape<BackwardFunction, TapeTensor> op_tape_;

        /// <summary>
        /// A deque-backed stack, whose element references are not invalidated by
        /// pushes and pops at the back.
        /// </summary>
        Stack<AccumulatorCallState> call_state_;

        public Tape(bool persistent, bool watch_accessed_variables)
        {
            this.persistent_ = persistent;
            this.watch_accessed_variables = watch_accessed_variables;

            tensor_tape_ = new TensorTape();
            op_tape_ = new OpTape<BackwardFunction, TapeTensor>();
            tensor_usage_ = new UnorderedMap<long, long>();

            nesting_id = ++tape_nesting_id_counter;
            tf.GetTapeSet().Add(this);
        }

        /// <summary>
        /// Marks this tensor to be watched by the given tape.
        /// </summary>
        /// <param name="tensor_id"></param>
        public void Watch(long tensor_id)
        {
            if (!CouldBackprop())
                return;

            tensor_tape_.emplace(tensor_id, -1);
        }

        public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes)
        {
            for (int i = 0; i < tensor_ids.Length; ++i)
            {
                if (tensor_tape_.find(tensor_ids[i]))
                    if (IsDtypeTrainable(dtypes[i]))
                        return true;
            }
            return false;
        }

        /// <summary>
        /// Pops the given tape off the stack.
        /// </summary>
        /// <param name="tape"></param>
        public void PopTape(ITape tape)
        {
            tf.GetTapeSet().Remove(tape);
        }

        public void VariableAccessed(ResourceVariable variable)
        {
            Watch(variable.Handle.Id);
        }

        public ResourceVariable[] WatchedVariables()
        {
            return null;
        }

        public bool IsDtypeTrainable(TF_DataType dtype)
        {
            switch (dtype)
            {
                case TF_DataType.TF_HALF:
                case TF_DataType.TF_BFLOAT16:
                case TF_DataType.TF_FLOAT:
                case TF_DataType.TF_DOUBLE:
                case TF_DataType.TF_COMPLEX64:
                case TF_DataType.TF_COMPLEX128:
                case TF_DataType.TF_RESOURCE:
                case TF_DataType.TF_VARIANT:
                    return true;
                default:
                    return false;
            }
        }

        bool CouldForwardprop()
            => HasAccumulator();

        bool CouldBackprop()
            => HasGradientTape();

        bool HasAccumulator()
            //return !GetAccumulatorSet()->empty();
            => false;

        bool HasGradientTape()
            => tf.GetTapeSet().Count > 0;
    }
}
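
Putting the pieces together, a hedged sketch of the eager-mode flow this class is meant to back, assuming the user-facing tf.GradientTape() wrapper is wired to these methods: constructing the tape pushes it onto the thread's tape set, each recorded op feeds RecordOperation, and requesting a gradient resolves through ComputeGradient.

using static Tensorflow.Binding;

// Assumed end-to-end usage; names follow TensorFlow.NET's Python-style
// surface API and illustrate how this low-level Tape is driven.
var w = tf.Variable(2.0f);
using (var tape = tf.GradientTape())   // constructor adds the tape to tf.GetTapeSet()
{
    var loss = w * w;                  // each eager op calls RecordOperation
    var grad = tape.gradient(loss, w); // resolves through ComputeGradient
    print(grad);                       // dloss/dw = 2w = 4
}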
