Browse Source

Fix the error in the backward pass of loaded function models.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
9420ba3243
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
45 changed files with 870 additions and 409 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
  2. +25
    -7
      src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
  3. +12
    -7
      src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  4. +162
    -17
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  5. +4
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
  6. +13
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
  7. +8
    -1
      src/TensorFlowNET.Core/Eager/IEagerRunner.cs
  8. +21
    -4
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  9. +9
    -10
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  10. +4
    -5
      src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
  11. +3
    -5
      src/TensorFlowNET.Core/Functions/Function.cs
  12. +105
    -52
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  13. +6
    -6
      src/TensorFlowNET.Core/Functions/TracingCompiler.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Functions/monomorphic_function.cs
  15. +2
    -2
      src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
  16. +27
    -8
      src/TensorFlowNET.Core/Gradients/GradientTape.cs
  17. +15
    -8
      src/TensorFlowNET.Core/Gradients/ITape.cs
  18. +2
    -2
      src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
  19. +157
    -125
      src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
  20. +31
    -32
      src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
  21. +21
    -10
      src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
  22. +10
    -10
      src/TensorFlowNET.Core/Gradients/Tape.cs
  23. +45
    -9
      src/TensorFlowNET.Core/Gradients/TapeTensor.cs
  24. +1
    -1
      src/TensorFlowNET.Core/Gradients/TensorTape.cs
  25. +1
    -26
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  26. +10
    -0
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  27. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
  28. +4
    -4
      src/TensorFlowNET.Core/Operations/c_api.ops.cs
  29. +3
    -1
      src/TensorFlowNET.Core/Operations/functional_ops.cs
  30. +46
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  31. +4
    -2
      src/TensorFlowNET.Core/Operations/handle_data_util.cs
  32. +14
    -0
      src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  33. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  34. +5
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  35. +1
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  36. +13
    -0
      src/TensorFlowNET.Core/Util/UnorderedMap.cs
  37. +5
    -0
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  38. +2
    -6
      src/TensorFlowNET.Core/ops.cs
  39. +21
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  40. +1
    -5
      src/TensorFlowNET.Keras/Engine/Model.Train.cs
  41. +4
    -4
      src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
  42. +25
    -0
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  43. +11
    -11
      src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
  44. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  45. +11
    -15
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 1
- 0
src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Google.Protobuf; using Google.Protobuf;
using Protobuf.Text;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Contexts namespace Tensorflow.Contexts


+ 25
- 7
src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs View File

@@ -12,18 +12,36 @@ namespace Tensorflow.Eager
return HasGradientTape(); return HasGradientTape();
} }


private bool ShouldRecord(Tensor[] inputs)
public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors)
{ {
bool should_record = false;
foreach (var tape in tf.GetTapeSet())
var tape_set = tf.GetTapeSet();
var input_ids = MakeTensorIDList(tensors);
var input_dtypes = MakeTensorDtypeList(tensors);
bool some_tape_watching = false;
if (tape_set is not null && tape_set.Count > 0)
{ {
if (tape.ShouldRecord(inputs))
foreach (var tape in tape_set)
{ {
should_record = true;
break;
if (tape.ShouldRecord(input_ids, input_dtypes))
{
if (tape.Persistent || some_tape_watching)
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
}
some_tape_watching = true;
}
} }
} }
return should_record;
// skip the forward_accumulators.

if (some_tape_watching)
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
}
else
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE;
}
} }
} }
} }

+ 12
- 7
src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs View File

@@ -13,7 +13,17 @@ namespace Tensorflow.Eager
Tensor[] results, Tensor[] results,
BackwardFunction backwardFunction = null) BackwardFunction backwardFunction = null)
{ {
bool should_record = ShouldRecord(inputs);
var input_ids = MakeTensorIDList(inputs);
var input_dtypes = MakeTensorDtypeList(inputs);
bool should_record = false;
foreach (var tape in tf.GetTapeSet())
{
if (tape.ShouldRecord(input_ids, input_dtypes))
{
should_record = true;
break;
}
}


if (!should_record) if (!should_record)
{ {
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager
op_inputs = inputs;*/ op_inputs = inputs;*/


backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results); backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results);
TapeSetRecordOperation(op_name, inputs, results, backwardFunction);
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction);


return true; return true;
} }
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager
{ {
return HasGradientTape(); return HasGradientTape();
} }

TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
{
return tensors.Select(x => x.dtype).ToArray();
}
} }
} }

+ 162
- 17
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs View File

@@ -1,6 +1,8 @@
using System;
using OneOf.Types;
using System;
using Tensorflow.Gradients; using Tensorflow.Gradients;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager
/// </summary> /// </summary>
public partial class EagerRunner public partial class EagerRunner
{ {
/// <summary>
///
/// </summary>
/// <param name="tape"></param>
/// <param name="target"></param>
/// <param name="sources"></param>
/// <param name="output_gradients"></param>
/// <param name="unconnected_gradients">determines the value returned if the target and
/// sources are unconnected. When 'none' the value returned is None whereas when
/// 'zero' a zero tensor in the same shape as the sources is returned.</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public Tensor[] TFE_TapeGradient(ITape tape, public Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target, Tensor[] target,
Tensor[] sources, Tensor[] sources,
Tensor[] output_gradients)
List<Tensor> output_gradients,
Tensor[] sources_raw,
string unconnected_gradients)
{ {
var target_vec = target;
var sources_vec = sources;
var sources_set = sources_vec;
if (!tape.Persistent)
{
var tape_set = tf.GetTapeSet();
if (tape_set.Contains(tape))
{
throw new RuntimeError("gradient() cannot be invoked within the " +
"GradientTape context (i.e., while operations are being " +
"recorded). Either move the call to gradient() to be " +
"outside the 'with tf.GradientTape' block, or " +
"use a persistent tape: " +
"'with tf.GradientTape(persistent=true)'");
}
}

var target_vec = MakeTensorIDList(target);
var sources_vec = MakeTensorIDList(sources);
HashSet<long> sources_set = new HashSet<long>(sources_vec);
var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();

int len = target.Length;
for(int i = 0; i < len; i++)
{
var target_id = target_vec[i];
if (sources_set.Contains(target_id))
{
var tensor = target[i];
source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor);
}
}

List<Tensor> outgrad_vec = new();
if(output_gradients is not null)
{
outgrad_vec = output_gradients.ToList();
}
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);


var seq_array = target;
var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();


for (int i = 0; i < target.Length; ++i)
bool unconnected_gradients_zero = unconnected_gradients == "zero";
Tensor[] sources_obj = null;
if (unconnected_gradients_zero)
{ {
source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
sources_obj = MakeTensorList(sources_raw);
} }


if (output_gradients != null)
if (result.Length > 0)
{ {
throw new NotImplementedException("");
for(int i = 0; i < result.Length; i++)
{
if (result[i] is null && unconnected_gradients_zero)
{
var dtype = sources_obj[i].dtype;
result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike();
}
}
} }
else
return result;
}

Tensor[] MakeTensorList(IEnumerable<Tensor> tensors)
{
return tensors.ToArray();
}

long[] MakeTensorIDList(Tensor[] tensors)
{
int len = tensors.Length;
long[] ids = new long[len];
for(int i = 0; i < len; i++)
{
var tensor = tensors[i];
ids[i] = tensor.Id;
}
return ids;
}

TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
{
int len = tensors.Length;
TF_DataType[] dtypes = new TF_DataType[len];
for (int i = 0; i < len; i++)
{ {
output_gradients = new Tensor[0];
var tensor = tensors[i];
dtypes[i] = tensor.dtype;
} }
return dtypes;
}


var outgrad_vec = MakeTensorList(output_gradients);
TapeTensor TapeTensorFromTensor(Tensor tensor)
{
long id = tensor.Id;
var dtype = tensor.dtype;
if (tensor is EagerTensor)
{
var handle = tensor.EagerTensorHandle;
if (DTypeNeedsHandleData(dtype))
{
return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor);
}

Status status = new();
int num_dims = c_api.TFE_TensorHandleNumDims(handle, status);
long[] dims = new long[num_dims];
for(int i = 0; i < num_dims; i++)
{
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
}
Shape tensor_shape = new(dims);

if(status.Code != TF_Code.TF_OK)
{
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null);
}
else
{
return new TapeTensor(id, dtype, tensor_shape);
}
}
var shape_tuple = tensor.shape.dims;
if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype))
{
return new TapeTensor(id, dtype, tensor);
}
long[] l = new long[shape_tuple.Length];
for(int i = 0; i < shape_tuple.Length; i++)
{
if (shape_tuple[i] < 0)
{
l[i] = 0;
}
else
{
l[i] = shape_tuple[i];
}
}
return new TapeTensor(id, dtype, new Shape(l));
}


return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec);
bool DTypeNeedsHandleData(TF_DataType dtype)
{
return dtype == dtypes.variant || dtype == dtypes.resource;
} }


Tensor[] MakeTensorList(Tensor[] tensors)
bool ListContainNone(long[] list)
{ {
return tensors;
int len = list.Length;
if(len == 0)
{
return true;
}
for(int i = 0; i < len; i++)
{
if (list[i] == -1)
{
return true;
}
}
return false;
} }
} }
} }

+ 4
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs View File

@@ -7,8 +7,9 @@ namespace Tensorflow.Eager
public partial class EagerRunner public partial class EagerRunner
{ {
void TapeSetRecordBackprop(string op_type, void TapeSetRecordBackprop(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
TapeTensor[] output_info,
long[] input_ids,
TF_DataType[] input_detyps,
BackwardFunction backward_function) BackwardFunction backward_function)
{ {
if (!CouldBackprop()) if (!CouldBackprop())
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager


foreach (var tape in tf.GetTapeSet()) foreach (var tape in tf.GetTapeSet())
{ {
tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function);
tape.RecordOperation(op_type, output_info, input_ids, input_detyps, backward_function);
} }
} }
} }


+ 13
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs View File

@@ -10,18 +10,28 @@ namespace Tensorflow.Eager
public bool TapeSetRecordOperation(string op_type, public bool TapeSetRecordOperation(string op_type,
Tensor[] input_tensors, Tensor[] input_tensors,
Tensor[] output_tensors, Tensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
BackwardFunction backward_function) BackwardFunction backward_function)
{ {
var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();

var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray();
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info, if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
backward_function)) backward_function))
return false; return false;


TapeSetRecordBackprop(op_type, input_tensors, output_info,
TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes,
backward_function); backward_function);


return true; return true;
} }

public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
Tensor[] input_tensors, BackwardFunction backward_function)
{
var input_ids = MakeTensorIDList(input_tensors);
var input_dtypes = MakeTensorDtypeList(input_tensors);
TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes,
backward_function);
}
} }
} }

+ 8
- 1
src/TensorFlowNET.Core/Eager/IEagerRunner.cs View File

@@ -29,7 +29,14 @@ namespace Tensorflow.Eager
Tensor[] TFE_TapeGradient(ITape tape, Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target, Tensor[] target,
Tensor[] sources, Tensor[] sources,
Tensor[] output_gradients);
List<Tensor> output_gradients,
Tensor[] sources_raw,
string unconnected_gradients);

void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
Tensor[] input_tensors, BackwardFunction backward_function);

int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors);


bool RecordGradient(string op_name, bool RecordGradient(string op_name,
Tensor[] inputs, Tensor[] inputs,


+ 21
- 4
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -18,12 +18,13 @@ namespace Tensorflow.Functions
public class ConcreteFunction: Trackable public class ConcreteFunction: Trackable
{ {
protected IEnumerable<Tensor> _captured_inputs; protected IEnumerable<Tensor> _captured_inputs;
internal FuncGraph func_graph;
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
protected Dictionary<string, AttrValue> _attrs; protected Dictionary<string, AttrValue> _attrs;
protected FunctionSpec _function_spec; protected FunctionSpec _function_spec;
protected FunctionSpec _pre_initialized_function_spec = null; protected FunctionSpec _pre_initialized_function_spec = null;
protected EagerDefinedFunction _inference_function; protected EagerDefinedFunction _inference_function;
protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new();
internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward; internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs; public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures; public Tensor[] CapturedInputs => func_graph.external_captures;
@@ -156,6 +157,17 @@ namespace Tensorflow.Functions
{ {
var executing_eagerly = tf.Context.executing_eagerly(); var executing_eagerly = tf.Context.executing_eagerly();
var default_graph = ops.get_default_graph(); var default_graph = ops.get_default_graph();
// TODO(Rinne): deal with `default_graph.building_function`

var tempvv = func_graph.Variables;
if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph)
{
foreach(var v in this.func_graph.Variables)
{
resource_variable_ops.variable_accessed(v);
}
}

var tensor_inputs = new Tensors(); var tensor_inputs = new Tensors();
foreach (var (i, arg) in enumerate(args)) foreach (var (i, arg) in enumerate(args))
{ {
@@ -223,11 +235,16 @@ namespace Tensorflow.Functions
{ {
input_tangents = new TangentInfo(); input_tangents = new TangentInfo();
} }
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient())
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
{ {
if(input_tangents.Indices is not null || executing_eagerly) if(input_tangents.Indices is not null || executing_eagerly)
{ {
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
string cache_key = "first_order";
if(!_tape_functions_cache.TryGetValue(cache_key, out var functions))
{
functions = new FirstOrderTapeGradientFunctions(func_graph, false);
_tape_functions_cache[cache_key] = functions;
}
return new ForwardBackwardCall(functions, args, tape_watching: true); return new ForwardBackwardCall(functions, args, tape_watching: true);
} }
else else
@@ -241,7 +258,7 @@ namespace Tensorflow.Functions
} }


// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall. // TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall.
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient());
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
} }


internal void set_variables(IEnumerable<IVariableV1> variables) internal void set_variables(IEnumerable<IVariableV1> variables)


+ 9
- 10
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -124,17 +124,16 @@ namespace Tensorflow.Functions
// TODO(Rinne): Add arg `CancellationManager`. // TODO(Rinne): Add arg `CancellationManager`.
// TODO(Rinne): Check the arg length. // TODO(Rinne): Check the arg length.
var function_call_options = tf.Context.FunctionCallOptions; var function_call_options = tf.Context.FunctionCallOptions;
string config;
if (function_call_options.config_proto_serialized().Length == 0)
{
config = function_utils.get_disabled_rewriter_config().ToString();
}
else
{
config = function_call_options.config_proto_serialized().ToString();
}
string config = ""; // TODO(Rinne): revise it. The following code should work but not, for unclear reasons.


config = ""; // TODO(Rinne): revise it.
//if (function_call_options.config_proto_serialized().Length == 0)
//{
// config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
//}
//else
//{
// config = function_call_options.config_proto_serialized().ToStringUtf8();
//}


string executor_type = function_call_options.ExecutorType ?? ""; string executor_type = function_call_options.ExecutorType ?? "";
var executing_eagerly = tf.Context.executing_eagerly(); var executing_eagerly = tf.Context.executing_eagerly();


+ 4
- 5
src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs View File

@@ -14,12 +14,11 @@ namespace Tensorflow.Functions


} }


public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
ForwardAndBackwardFunctions(Tensors inference_args)
{ {
var outputs = _func_graph.Outputs;
(_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs)
= BuildFunctionsForOutputs(outputs, inference_args);
return _forward_function;
var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray();
return BuildFunctionsForOutputs(outputs, inference_args);
} }
} }
} }

+ 3
- 5
src/TensorFlowNET.Core/Functions/Function.cs View File

@@ -14,7 +14,6 @@ namespace Tensorflow
protected ConcreteFunction _concrete_variable_creation_fn; protected ConcreteFunction _concrete_variable_creation_fn;
protected bool _autograph; protected bool _autograph;
protected TracingCompiler _variable_creation_fn; protected TracingCompiler _variable_creation_fn;
protected bool _has_initialized;
public string Name { get; set; } public string Name { get; set; }
public Function(Func<Tensor[], Tensor[]> csharp_function, public Function(Func<Tensor[], Tensor[]> csharp_function,
string name, bool auto_graph = true) string name, bool auto_graph = true)
@@ -22,7 +21,6 @@ namespace Tensorflow
_csharp_function = csharp_function; _csharp_function = csharp_function;
Name = name; Name = name;
_autograph = auto_graph; _autograph = auto_graph;
_has_initialized = false;
} }


public virtual Tensors Apply(Tensors inputs) public virtual Tensors Apply(Tensors inputs)
@@ -38,10 +36,11 @@ namespace Tensorflow


protected virtual Tensors _call(Tensors inputs) protected virtual Tensors _call(Tensors inputs)
{ {
if (!_has_initialized)
if(_variable_creation_fn is not null)
{ {
_initialize(inputs);
return _variable_creation_fn.Apply(inputs);
} }
_initialize(inputs);


return _concrete_variable_creation_fn.CallFlat(inputs, return _concrete_variable_creation_fn.CallFlat(inputs,
_concrete_variable_creation_fn.CapturedInputs); _concrete_variable_creation_fn.CapturedInputs);
@@ -63,7 +62,6 @@ namespace Tensorflow
_variable_creation_fn = _compiler(_csharp_function); _variable_creation_fn = _compiler(_csharp_function);
_variable_creation_fn._name = this.Name; _variable_creation_fn._name = this.Name;
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); _concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
_has_initialized = true;
} }
} }
} }

+ 105
- 52
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -24,23 +24,40 @@ namespace Tensorflow.Functions
protected string _INFERENCE_PREFIX = "__inference_"; protected string _INFERENCE_PREFIX = "__inference_";


protected FuncGraph _func_graph; protected FuncGraph _func_graph;
protected EagerDefinedFunction _forward_function;
protected EagerDefinedFunction _forward;
protected FuncGraph _forward_graph; protected FuncGraph _forward_graph;
protected List<int> _forwardprop_input_indices;
protected List<int> _forwardprop_output_indices; protected List<int> _forwardprop_output_indices;
protected int _num_forwardprop_outputs; protected int _num_forwardprop_outputs;
protected ConcreteFunction _backward_function;
protected int _num_inference_outputs;
protected int _num_outputs;
protected int _num_trainable_inference_outputs;
protected ConcreteFunction _backward;
BackwardFunction _backward_function_wrapper; BackwardFunction _backward_function_wrapper;


public TapeGradientFunctions(FuncGraph func_graph, public TapeGradientFunctions(FuncGraph func_graph,
bool need_gradients_for_jvps) bool need_gradients_for_jvps)
{ {
_func_graph = func_graph; _func_graph = func_graph;
_forward_graph = null;
_forward = null;
_backward = null;
_num_outputs = func_graph.Outputs.Length;
_forwardprop_output_indices = null;
_num_forwardprop_outputs = 0;
_num_inference_outputs = func_graph.Outputs.Length;
_num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count();
} }


public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null) public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null)
{ {
// TODO(Rinne): add input_tangents arg. // TODO(Rinne): add input_tangents arg.
return ForwardAndBackwardFunctions(inference_args);
if(_forward is null)
{
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
= ForwardAndBackwardFunctions(inference_args);
}
return _forward;
} }


/// <summary> /// <summary>
@@ -51,9 +68,13 @@ namespace Tensorflow.Functions
public virtual void Record(Tensors flat_outputs, Tensors inference_args) public virtual void Record(Tensors flat_outputs, Tensors inference_args)
{ {
// TODO(Rinne): add arg `input_tagents`. // TODO(Rinne): add arg `input_tagents`.
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs);
tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record,
getBackwardFunction: backward_function);
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
}
tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function);
} }


/// <summary> /// <summary>
@@ -65,66 +86,95 @@ namespace Tensorflow.Functions
/// <returns></returns> /// <returns></returns>
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{ {
var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs)
.ToDictionary(x => x.Item1, x => x.Item2);
var captured_inputs = backward.CapturedInputs;
var remapped_captures = captured_inputs.Select(c =>
{
if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value))
{
return value;
}
else
{
return c;
}
}).ToArray();
if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph))
{
var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph);
throw new RuntimeError($"Failed to map all backward graph captures to " +
$"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}");
}

Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>();
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
var recorded_outputs = new Tensors(); var recorded_outputs = new Tensors();
var trainable_recorded_outputs = 0;
foreach (var (output_index, output) in enumerate(outputs))
int trainable_recorded_outputs = 0;
var skip_positions = new HashSet<int>();
var relevant_outputs = outputs;
foreach (var (output_index, output) in enumerate(relevant_outputs))
{ {
if (trainable_recorded_outputs < backward_function_inputs) if (trainable_recorded_outputs < backward_function_inputs)
recorded_outputs.Add(output); recorded_outputs.Add(output);
if (gradients_util.IsTrainable(output))
trainable_recorded_outputs += 1;
if (backprop_util.IsTrainable(output))
trainable_recorded_outputs++;
else
skip_positions.Add(output_index);
if (output.dtype == dtypes.variant)
variant_zeros_like[output_index] = default_gradient.zeros_like(output);
} }


if(_backward_function_wrapper == null)
_backward_function_wrapper = (args, unneeded_gradients) =>
{ {
var capture_mapping = new Dictionary<long, Tensor>();
foreach (var (i, output) in enumerate(outputs))
capture_mapping[forward_graph.Outputs[i].Id] = output;

var remapped_captures = new Tensors();
foreach (var capture in backward.CapturedInputs)
{
if (capture_mapping.ContainsKey(capture.Id))
remapped_captures.Add(capture_mapping[capture.Id]);
}

var skip_positions = new List<int>();
foreach (var (output_index, output) in enumerate(outputs))
if(backward.Outputs is null || backward.Outputs.Length == 0)
{ {
if (!gradients_util.IsTrainable(output))
skip_positions.Add(output_index);
return backward.FlatStructuredOutputs;
} }


_backward_function_wrapper = (args, unneeded_gradients) =>
var processed_args = new Tensors();
int input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
{ {
var processed_args = new Tensors();
var input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
if (skip_positions.Contains(output_index))
continue;
if (arg is null)
{
var input_placeholder = backward.Inputs[input_index];
Tensor variant_arg;
if (input_placeholder.dtype == dtypes.variant)
{
variant_arg = variant_zeros_like[output_index];
}
else
{
var (shape, type) = default_gradient.shape_and_dtype(input_placeholder);

variant_arg = array_ops.zeros(shape, type);
}
processed_args.Add(variant_arg);
}
else
{ {
if (skip_positions.Contains(output_index))
continue;
if (arg == null)
throw new NotImplementedException("");
processed_args.Add(arg); processed_args.Add(arg);
input_index += 1;
if (input_index >= backward_function_inputs)
break;
} }
input_index++;
if (input_index >= backward_function_inputs)
break;
}


tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);
tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);


foreach (var unneeded_gradient_index in unneeded_gradients)
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
}
foreach (var unneeded_gradient_index in unneeded_gradients)
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
}


return gradients;
};
}
return gradients;
};


return (_backward_function_wrapper, recorded_outputs); return (_backward_function_wrapper, recorded_outputs);
} }
@@ -143,7 +193,7 @@ namespace Tensorflow.Functions
} }
} }


var backwards_graph = new FuncGraph(_func_graph.Name);
var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name));
backwards_graph.as_default(); backwards_graph.as_default();
var gradients_wrt_outputs = new List<Tensor>(); var gradients_wrt_outputs = new List<Tensor>();
foreach (var output in trainable_outputs) foreach (var output in trainable_outputs)
@@ -153,6 +203,7 @@ namespace Tensorflow.Functions
gradients_wrt_outputs.Add(gradient_placeholder); gradients_wrt_outputs.Add(gradient_placeholder);
handle_data_util.copy_handle_data(output, gradient_placeholder); handle_data_util.copy_handle_data(output, gradient_placeholder);
} }
// TODO(Rinne): with ops.device(None)
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
_func_graph.Inputs, _func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(), grad_ys: gradients_wrt_outputs.ToArray(),
@@ -175,7 +226,8 @@ namespace Tensorflow.Functions
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray(); backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray();
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null)); backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null));


var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
var (wrapped_forward_function, wrapped_backward_function) =
monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}"; //var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
//var backward_function_attr = new Dictionary<string, string>(); //var backward_function_attr = new Dictionary<string, string>();
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; //backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
@@ -189,10 +241,11 @@ namespace Tensorflow.Functions
// _func_graph.Inputs, _func_graph.Outputs, // _func_graph.Inputs, _func_graph.Outputs,
// monomorphic_function_utils._parse_func_attrs(forward_function_attr)); // monomorphic_function_utils._parse_func_attrs(forward_function_attr));
return (forward_function, _func_graph, backward_function, null, 0);
return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0);
} }


public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
ForwardAndBackwardFunctions(Tensors inference_args)
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


+ 6
- 6
src/TensorFlowNET.Core/Functions/TracingCompiler.cs View File

@@ -73,12 +73,12 @@ namespace Tensorflow.Functions


private static string male_cache_key(Tensor[] inputs) private static string male_cache_key(Tensor[] inputs)
{ {
string res = "";
foreach (var input in inputs)
{
res += $"{input.name}_{input.Id}";
}
return res;
//string res = "";
//foreach (var input in inputs)
//{
// res += $"{input.name}_{input.Id}";
//}
return inputs.Length.ToString();
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Functions/monomorphic_function.cs View File

@@ -153,7 +153,7 @@ namespace Tensorflow.Functions
foreach(var tape in tf.GetTapeSet()) foreach(var tape in tf.GetTapeSet())
{ {
tape.RecordOperation(_inference_function.Signature.Name, to_record, tape.RecordOperation(_inference_function.Signature.Name, to_record,
inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function);
inference_args, backward_function);
} }
} }




+ 2
- 2
src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients
/// Map from tensor to how many references still exist for this tensor in /// Map from tensor to how many references still exist for this tensor in
/// the tape. /// the tape.
/// </summary> /// </summary>
public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; }
public UnorderedMap<long, long> tensor_usage_counts { get; set; }
/// <summary> /// <summary>
/// Maps from op ID to how many output tensors of this op still need to have /// Maps from op ID to how many output tensors of this op still need to have
/// their gradients computed. /// their gradients computed.
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients
public BackpropInitialState() public BackpropInitialState()
{ {
op_tape = new OpTape(); op_tape = new OpTape();
tensor_usage_counts = new UnorderedMap<Tensor, long>();
tensor_usage_counts = new UnorderedMap<long, long>();
op_missing_tensor = new UnorderedMap<long, long>(); op_missing_tensor = new UnorderedMap<long, long>();
} }
} }


+ 27
- 8
src/TensorFlowNET.Core/Gradients/GradientTape.cs View File

@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients
/// <param name="target"></param> /// <param name="target"></param>
/// <param name="source"></param> /// <param name="source"></param>
/// <returns></returns> /// <returns></returns>
public Tensor gradient(Tensor target, Tensor source)
public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{ {
if(_tape is null)
{
throw new RuntimeError("A non-persistent GradientTape can only be used to " +
"compute one set of gradients (or jacobians).");
}
ITape tape = stop_recording(); ITape tape = stop_recording();


var results = tf.Runner.TFE_TapeGradient(tape, var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target }, new[] { target },
new[] { source }, new[] { source },
null);
output_gradients,
new[] { source },
unconnected_gradients);


return results[0]; return results[0];
} }


public Tensor gradient(Tensor target, ResourceVariable source)
public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{ {
var results = gradient(target, new List<IVariableV1> { source });
var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients);


return results[0]; return results[0];
} }


public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{ {
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 });
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients);


return (results[0], results[1]); return (results[0], results[1]);
} }


public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{ {
if (_tape is null)
{
throw new RuntimeError("A non-persistent GradientTape can only be used to " +
"compute one set of gradients (or jacobians).");
}
var tape = stop_recording(); var tape = stop_recording();


var results = tf.Runner.TFE_TapeGradient(tape, var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target }, new[] { target },
sources.Select(x => x.Handle).ToArray(), sources.Select(x => x.Handle).ToArray(),
null);
output_gradients,
sources.Select(x => x.Handle).ToArray(),
unconnected_gradients);


if (!tape.Persistent) if (!tape.Persistent)
{ {


+ 15
- 8
src/TensorFlowNET.Core/Gradients/ITape.cs View File

@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients
public interface ITape public interface ITape
{ {
void SetTapeId(int id); void SetTapeId(int id);
bool ShouldRecord(Tensor[] tensors);
bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes);
void StartRecord(); void StartRecord();
void StopRecord(); void StopRecord();
bool Persistent { get; } bool Persistent { get; }
void RecordOperation(string op_type, void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors, TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
BackwardFunction backward_function); BackwardFunction backward_function);


void VariableAccessed(ResourceVariable variable);
void RecordOperation(string op_type,
Tensor[] outputs,
Tensor[] inputs,
BackwardFunction backward_function);

void VariableAccessed(IVariableV1 variable);


void Watch(Tensor x); void Watch(Tensor x);


ResourceVariable[] WatchedVariables();
IVariableV1[] WatchedVariables();


Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
Tensor[] source_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients);
Tensor[] ComputeGradient(long[] target_tensor_ids,
long[] source_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
bool build_default_zeros_grads);
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs View File

@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients
{ {
public string op_type { get; set; } public string op_type { get; set; }
public TapeTensor[] output_tensor_info { get; set; } public TapeTensor[] output_tensor_info { get; set; }
public Tensor[] input_tensor_id { get; set; }
public long[] input_tensor_id { get; set; }
public BackwardFunction backward_function { get; set; } public BackwardFunction backward_function { get; set; }
public override string ToString() public override string ToString()
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
} }
} }

+ 157
- 125
src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs View File

@@ -2,235 +2,246 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding;


namespace Tensorflow.Gradients namespace Tensorflow.Gradients
{ {
public partial class Tape public partial class Tape
{ {
// int kMinAggregateCount = 4;
// int kMinAggregateBytes = 128 * 1024 * 1024;
static readonly int kMinAggregateCount = 4;
static readonly int kMinAggregateBytes = 128 * 1024 * 1024;
private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap;


public Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
Tensor[] source_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients)
static Tape()
{ {
var sources_set = new UnorderedSet<Tensor>(source_tensor_ids);
// var gradients_size = new UnorderedMap<Tensor, long>();
var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap();
var state = PrepareBackprop(
target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent);
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
output_gradients,
tensor_tape_,
state.op_tape);
_functionsAcceptingNoneForIndicesMap = new();
_functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
_functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
_functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
}


while (!op_stack.empty())
public Tensor[] ComputeGradient(long[] target_tensor_ids,
long[] source_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
bool build_default_zeros_grads)
{
UnorderedSet<long> sources_set = new(source_tensor_ids);
BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent);
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape);
UnorderedMap<long, long> gradients_size = new();
while(op_stack.Count > 0)
{ {
var op = op_stack.Dequeue();
if (!state.op_tape.find(op, out var trace))
long op = op_stack.Dequeue();
if(!state.op_tape.TryGetValue(op, out var op_it))
{
continue; continue;

// Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}");
}
var trace = op_it;
state.op_tape.erase(op); state.op_tape.erase(op);

var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
var unneeded_gradients = new List<long>();
for (int i = 0; i < trace.input_tensor_id.Length; i++)
List<Tensor> out_gradients = new();
List<long> unneeded_gradients = new();
for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++)
{ {
var in_tensor_id = trace.input_tensor_id[i];
if (!tensor_tape_.find(in_tensor_id) &&
!sources_set.find(in_tensor_id))
long in_tensor_id = trace.input_tensor_id[i];
if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id))
{
unneeded_gradients.Add(i); unneeded_gradients.Add(i);
}
} }


bool any_gradient_nonzero = false; bool any_gradient_nonzero = false;
var zero_indices = new List<int>();
for (int i = 0; i < trace.output_tensor_info.Length; ++i)
List<int> zero_indices = new();
for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++)
{ {
var id = trace.output_tensor_info[i].GetTensor();
if (!gradients.find(id, out var grad_it))
long id = trace.output_tensor_info[i].GetID();
if(!gradients.TryGetValue(id, out var grad_it))
{ {
if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) &&
func_name_it.find(i))
out_gradients.Add(null);
if (build_default_zeros_grads)
{ {
out_gradients.Add(null);
}
else
{
out_gradients.Add(null);
zero_indices.Add(i);
if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) ||
!func_name_it.find(i))
{
zero_indices.Add(i);
}
} }
} }
else else
{ {
any_gradient_nonzero = true; any_gradient_nonzero = true;
var new_gradients = grad_it.Count == 1 ?
grad_it[0] :
gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients

Tensor new_gradients;
if (grad_it.Count == 1)
{
new_gradients = grad_it[0];
}
else
{
new_gradients = AggregateGradients(grad_it);
}
if (!sources_set.find(id)) if (!sources_set.find(id))
{
gradients.Remove(id); gradients.Remove(id);
}
else else
{ {
// grad_it.Clear();
// grad_it.Add(new_gradients);
// vspace.MarkAsResult(new_gradients);
grad_it.Clear();
grad_it.Add(new_gradients);
// MarkAsResult
} }
out_gradients.Add(new_gradients); out_gradients.Add(new_gradients);
} }
} }


Tensor[] in_gradients;
Tensor[] in_gradients = new Tensor[0];
if (any_gradient_nonzero) if (any_gradient_nonzero)
{ {
// foreach (var i in zero_indices)
// out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray());

if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length)
throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
if (!_persistent)
foreach(var i in zero_indices)
{ {
// trace.backward_function_deleter(trace.backward_function);
trace.backward_function = null;
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
} }
in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients);
} }
else else
{ {
in_gradients = new Tensor[trace.input_tensor_id.Length];
out_gradients.Clear();
} }

bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length;
for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k)
for(int i = 0, end = in_gradients.Length; i < end; i++)
{ {
if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k;
var id = trace.input_tensor_id[k];
if (in_gradients[i] != null)
long id = trace.input_tensor_id[i];
if (in_gradients[i] is not null)
{ {
var unaggregated_grads = gradients[id];
var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>());
unaggregated_grads.Add(in_gradients[i]); unaggregated_grads.Add(in_gradients[i]);
/*if (unaggregated_grads.Count > kMinAggregateCount)
if(unaggregated_grads.Count > kMinAggregateCount)
{ {
if (!gradients_size.find(id, out var size))
if(!gradients_size.TryGetValue(id, out var size))
{ {
size = (long)unaggregated_grads[0].size;
size = NumElements(unaggregated_grads[0]);
gradients_size.emplace(id, size); gradients_size.emplace(id, size);
} }

if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
{ {
throw new NotImplementedException("");
Tensor grad = AggregateGradients(unaggregated_grads);
unaggregated_grads.Clear();
unaggregated_grads.Add(grad);
} }
}*/
}
} }
if (!state.tensor_usage_counts.find(id))
if(!state.tensor_usage_counts.find(id))
{
continue; continue;
}
state.tensor_usage_counts[id]--; state.tensor_usage_counts[id]--;
if (state.tensor_usage_counts[id] > 0)
if(state.tensor_usage_counts[id] > 0)
{
continue; continue;
if (!tensor_tape_.find(id, out var tape_it))
}
if (!tensor_tape_.TryGetValue(id, out var tape_it))
{ {
if (gradients.find(id, out var grad_it))
if (gradients.find(id))
{ {
// foreach (var g in grad_it)
// DeleteGradient(g);
gradients.erase(id); gradients.erase(id);
} }
continue; continue;
} }
var op_id = tape_it;
if (op_id == -1)
long op_id = tape_it;
if(op_id == -1)
{
continue; continue;
if (state.op_missing_tensor.find(op_id, out var missing_it))
}
if(state.op_missing_tensor.find(op_id))
{ {
state.op_missing_tensor[op_id]--; state.op_missing_tensor[op_id]--;
if (state.op_missing_tensor[op_id] == 0)
if(state.op_missing_tensor[op_id] == 0)
{
op_stack.Enqueue(op_id); op_stack.Enqueue(op_id);
}
} }
} }
} }


if (state.op_tape.Count > 0)
if(state.op_tape.Count > 0)
{
throw new RuntimeError("Invalid tape state."); throw new RuntimeError("Invalid tape state.");

var result = new Tensor[source_tensor_ids.Length];
var j = 0;
foreach (var id in source_tensor_ids)
}
Tensor[] result = new Tensor[source_tensor_ids.Length];
for(int i = 0; i < source_tensor_ids.Length; i++)
{ {
if (gradients.find(id, out var grad_it))
long tensor_id = source_tensor_ids[i];
if(!gradients.TryGetValue(tensor_id, out var grad_it))
{ {
if (grad_it.Count > 1)
result[j] = gen_math_ops.add_n(grad_it.ToArray());
else
result[j] = grad_it[0];
result[i] = null;
}
else
{
if(grad_it.Count > 1)
{
Tensor grad = AggregateGradients(grad_it);
grad_it.Clear();
grad_it.Add(grad);
}
result[i] = grad_it[0];
} }
j++;
} }

return result; return result;
} }


UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap() UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap()
{ {
var m = new UnorderedMap<string, UnorderedSet<int>>();
m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
return m;
return _functionsAcceptingNoneForIndicesMap;
} }


UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients,
UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
TensorTape tensor_tape, TensorTape tensor_tape,
OpTape op_tape) OpTape op_tape)
{ {
var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>();
for (int i = 0; i < target_tensor_ids.Length; ++i)
var result = new UnorderedMap<long, List<Tensor>>();
for(int i = 0, end = target_tensor_ids.Length; i < end; i++)
{ {
var id = target_tensor_ids[i];
if (output_gradients.Length == 0 || output_gradients[i] == null)
long id = target_tensor_ids[i];
if( output_gradients is null ||output_gradients.Count == 0 || output_gradients[i] is null)
{ {
if (tensor_tape.find(id, out var tensor_id) && tensor_id != null)
if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1)
{ {
if (!op_tape.find(tensor_tape[id], out var op_it))
if(!op_tape.TryGetValue(tensor_it, out var op_it))
{
throw new RuntimeError("Internal state of the gradient tape is invalid: " + throw new RuntimeError("Internal state of the gradient tape is invalid: " +
"failed to find operation producing a tensor");
"failed to find operation producing a tensor.");
}
bool found = false; bool found = false;
for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
for(int j = 0; j < op_it.output_tensor_info.Length; j++)
{ {
if (op_it.output_tensor_info[j].GetTensor() == id)
if (op_it.output_tensor_info[j].GetID() == id)
{ {
found = true; found = true;
var ones = op_it.output_tensor_info[j].OnesLike();
result[id].Add(ones);
Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]);
result.SetDefault(id, new List<Tensor>()).Add(ones_like);
break; break;
} }
} }

if (!found) if (!found)
{ {
throw new ValueError("Internal state of the gradient tape is invalid: " +
"none of operations outputs match expected tensor");
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
"none of operations outputs match expected tensor.");
} }
} }
else else
{ {
if (sources_that_are_targets.find(id, out var source_tensor))
result[id].Add(source_tensor.OnesLike());
if(sources_that_are_targets.TryGetValue(id, out var source_tensor))
{
Tensor ones_like = BuildOnesLike(source_tensor);
result.SetDefault(id, new List<Tensor>()).Add(ones_like);
}
} }
} }
else else
{ {
result[id].Add(output_gradients[i]);
result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]);
} }
} }


@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients
} }
return result; return result;
} }

Tensor BuildOnesLike(TapeTensor t)
{
return t.OnesLike();
}

Tensor AggregateGradients(List<Tensor> gradient_tensors)
{
if(gradient_tensors.Count == 0)
{
return gradient_tensors[0];
}
return tf.add_n(gradient_tensors.ToArray());
}

void DeleteGradient(Tensor gradient)
{
// Do not do anything here. Because GC will collect it when it has no reference.
}

long NumElements(Tensor tensor) => 1;
} }
} }

+ 31
- 32
src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs View File

@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients
{ {
public partial class Tape public partial class Tape
{ {
public BackpropInitialState PrepareBackprop(Tensor[] target,
public BackpropInitialState PrepareBackprop(long[] target,
TensorTape tensor_tape, TensorTape tensor_tape,
OpTape op_tape, OpTape op_tape,
UnorderedSet<Tensor> sources_set,
UnorderedSet<long> sources_set,
bool persistent_tape) bool persistent_tape)
{ {
Stack<long> tensor_stack = new Stack<long>();
foreach(var t in target)
{
tensor_stack.Push(t);
}
BackpropInitialState result = new BackpropInitialState(); BackpropInitialState result = new BackpropInitialState();
var tensor_stack = new Queue<Tensor>(target);
while (tensor_stack.Count > 0)
while(tensor_stack.Count > 0)
{ {
var tensor_id = tensor_stack.Dequeue();
if (!tensor_tape.find(tensor_id, out var op_id))
long tensor_id = tensor_stack.Pop();
if(!tensor_tape.TryGetValue(tensor_id, out var op_id))
{
continue; continue;
if (op_id == -1 ||
!op_tape.find(op_id, out var op_it) ||
result.op_tape.find(op_id, out var result_op_it))
}
if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it)
|| result.op_tape.find(op_id))
{
continue; continue;
}
result.op_tape.emplace(op_id, op_it); result.op_tape.emplace(op_id, op_it);

foreach (var it in op_it.input_tensor_id)
foreach(var it in op_it.input_tensor_id)
{ {
if (result.tensor_usage_counts.find(it))
if(result.tensor_usage_counts.find(it))
{
result.tensor_usage_counts[it]++; result.tensor_usage_counts[it]++;
}
else else
{ {
result.tensor_usage_counts[it] = 1; result.tensor_usage_counts[it] = 1;
if (tensor_tape.find(it)) if (tensor_tape.find(it))
tensor_stack.Enqueue(it);
{
tensor_stack.Push(it);
}
} }
} }

if (!persistent_tape) if (!persistent_tape)
op_tape.Remove(op_id);
{
op_tape.erase(op_id);
}
} }

foreach (var pair in result.tensor_usage_counts)
foreach(var pair in result.tensor_usage_counts)
{ {
if (tensor_tape.find(pair.Key, out var it) && it != -1)
result.op_missing_tensor[it] += 1;
if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1)
{
result.op_missing_tensor[it]++;
}
} }

if (!persistent_tape) if (!persistent_tape)
{ {
// Call destructors for all unneeded gradient functions and
// clear the op_tape. We can clear the tape because ownership of
// backward functions that will be used for gradient computation
// has been transferred to `result`.
/*for (const auto&op_pair : *op_tape) {
op_pair.second.backward_function_deleter(
op_pair.second.backward_function);
}*/
op_tape.Clear(); op_tape.Clear();
} }

return result; return result;
} }
} }


+ 21
- 10
src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs View File

@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients
public partial class Tape public partial class Tape
{ {
long next_op_id_ = 0; long next_op_id_ = 0;
UnorderedMap<Tensor, long> tensor_usage_;
UnorderedMap<long, long> tensor_usage_;


public void RecordOperation(string op_type, public void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors, TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
BackwardFunction backward_function) BackwardFunction backward_function)
{ {
if (!ShouldRecord(input_tensors))
if (!ShouldRecord(input_tensor_id, input_dtypes))
return; return;


var op_id = next_op_id_++;
foreach (var i in input_tensors)
foreach (var i in input_tensor_id)
{
tensor_usage_[i]++; tensor_usage_[i]++;

}
long op_id = next_op_id_++;
foreach (var o in output_tensors) foreach (var o in output_tensors)
{ {
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}"); tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
tensor_tape_[o.GetTensor()] = op_id;
tensor_usage_[o.GetTensor()] = 1;
tensor_tape_[o.GetID()] = op_id;
tensor_usage_[o.GetID()] = 1;
} }


op_tape_[op_id] = new OpTapeEntry op_tape_[op_id] = new OpTapeEntry
{ {
op_type = op_type, op_type = op_type,
output_tensor_info = output_tensors,
input_tensor_id = input_tensors,
output_tensor_info = output_tensors.ToArray(),
input_tensor_id = input_tensor_id.ToArray(),
backward_function = backward_function backward_function = backward_function
}; };
} }

public void RecordOperation(string op_type,
Tensor[] outputs,
Tensor[] inputs,
BackwardFunction backward_function)
{
tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function);
}
} }
} }

+ 10
- 10
src/TensorFlowNET.Core/Gradients/Tape.cs View File

@@ -1,5 +1,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients
_created_eagerly = tf.Context.executing_eagerly(); _created_eagerly = tf.Context.executing_eagerly();
tensor_tape_ = new TensorTape(); tensor_tape_ = new TensorTape();
op_tape_ = new OpTape(); op_tape_ = new OpTape();
tensor_usage_ = new UnorderedMap<Tensor, long>();
tensor_usage_ = new UnorderedMap<long, long>();
if(_created_eagerly) if(_created_eagerly)
tf.Context.start_step(); tf.Context.start_step();
// nesting_id = ++tape_nesting_id_counter; // nesting_id = ++tape_nesting_id_counter;
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients
public void Watch(Tensor x) public void Watch(Tensor x)
{ {
tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}"); tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
tensor_tape_.emplace(x, -1);
tensor_tape_.emplace(x.Id, -1);
} }


public bool ShouldRecord(Tensor[] tensors)
public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes)
{ {
var dtypes = tensors.Select(x => x.dtype).ToArray();
for (int i = 0; i < tensors.Length; ++i)
Debug.Assert(tensor_ids.Length == tensor_dtypes.Length);
for (int i = 0; i < tensor_ids.Length; ++i)
{ {
if (tensor_tape_.find(tensors[i]))
if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i]))
{ {
if (IsDtypeTrainable(dtypes[i]))
return true;
return true;
} }
} }
return false; return false;
} }


public void VariableAccessed(ResourceVariable variable)
public void VariableAccessed(IVariableV1 variable)
{ {
Watch(variable.Handle); Watch(variable.Handle);
} }


public ResourceVariable[] WatchedVariables()
public IVariableV1[] WatchedVariables()
{ {
return null; return null;
} }


+ 45
- 9
src/TensorFlowNET.Core/Gradients/TapeTensor.cs View File

@@ -1,27 +1,63 @@
using static Tensorflow.Binding;
using OneOf;
using static Tensorflow.Binding;


namespace Tensorflow.Gradients namespace Tensorflow.Gradients
{ {
public class TapeTensor public class TapeTensor
{ {
Tensor tensor;
long id => tensor.Id;
TF_DataType dtype => tensor.dtype;
Shape shape => tensor.shape;
internal Tensor tensor;
internal long id;
internal TF_DataType dtype;
internal OneOf<Shape, Tensor> shape;

public TapeTensor(long id, TF_DataType dtype, Shape shape)
{
this.id = id;
this.dtype = dtype;
this.shape = shape;
}

public TapeTensor(long id, TF_DataType dtype, Tensor shape)
{
this.id = id;
this.dtype = dtype;
this.shape = shape;
}


public TapeTensor(Tensor tensor) public TapeTensor(Tensor tensor)
{ {
this.id = tensor.Id;
this.dtype = tensor.dtype;
this.shape = tensor.shape;
this.tensor = tensor; this.tensor = tensor;
} }


public long GetID() => tensor.Id;
public Tensor GetTensor() => tensor;
public long GetID() => id;


public Tensor ZerosLike() public Tensor ZerosLike()
=> tf.zeros(shape: shape, dtype: dtype);
{
if(dtype == dtypes.resource)
{
return null;
}
if(shape.Index == 1)
{
return tf.zeros_like(shape.AsT1);
}
return tf.zeros(shape.AsT0, dtype);
}


public Tensor OnesLike() public Tensor OnesLike()
=> tf.ones(shape: shape, dtype: dtype);
{
if (shape.Index == 1)
{
return tf.ones_like(shape.AsT1);
}
return tf.ones(shape.AsT0, dtype);
}

//public Tensor OnesLike()
// => tf.ones(shape: shape, dtype: dtype);


public override string ToString() public override string ToString()
=> $"{id}, {shape}, {dtype.as_numpy_name()}"; => $"{id}, {shape}, {dtype.as_numpy_name()}";


+ 1
- 1
src/TensorFlowNET.Core/Gradients/TensorTape.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients
/// produced this tensor. A value of -1 means that the tensor was directly /// produced this tensor. A value of -1 means that the tensor was directly
/// watched and not the result of any operation in the tape. /// watched and not the result of any operation in the tape.
/// </summary> /// </summary>
public class TensorTape : UnorderedMap<Tensor, long>
public class TensorTape : UnorderedMap<long, long>
{ {


} }


+ 1
- 26
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -704,32 +704,7 @@ namespace Tensorflow


public static int PossibleTapeGradientTypes(Tensor[] tensors) public static int PossibleTapeGradientTypes(Tensor[] tensors)
{ {
var tape_set = tf.GetTapeSet();
bool some_tape_watching = false;
if(tape_set is not null && tape_set.Count > 0)
{
foreach(var tape in tape_set)
{
if (tape.ShouldRecord(tensors))
{
if(tape.Persistent || some_tape_watching)
{
return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
}
some_tape_watching = true;
}
}
}
// skip the forward_accumulators.

if (some_tape_watching)
{
return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
}
else
{
return POSSIBLE_GRADIENT_TYPES_NONE;
}
return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors);
} }


/// <summary> /// <summary>


+ 10
- 0
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -215,6 +215,16 @@ public class FuncGraph : Graph, IDisposable
return tensor; return tensor;
} }


public void watch_variable(IVariableV1 v)
{
if (_resource_tensor_inputs.Contains(v.Handle))
{
return;
}
_watched_variables.Add(new WeakReference<IVariableV1>(v));
//this = this.outer_graph;
}

Tensor capture_eager_tensor(Tensor tensor, string name) Tensor capture_eager_tensor(Tensor tensor, string name)
{ {
Tensor graph_const = null; Tensor graph_const = null;


+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs View File

@@ -4,10 +4,10 @@ public interface IOptimizer
{ {
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars); Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars);
Tensor[] clip_gradients(Tensor[] grads); Tensor[] clip_gradients(Tensor[] grads);
void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null, string name = null,
bool experimental_aggregate_gradients = true); bool experimental_aggregate_gradients = true);
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null, string name = null,
bool experimental_aggregate_gradients = true); bool experimental_aggregate_gradients = true);
} }

+ 4
- 4
src/TensorFlowNET.Core/Operations/c_api.ops.cs View File

@@ -208,9 +208,9 @@ namespace Tensorflow


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);
//[DllImport(TensorFlowLibName)]
//public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
//[DllImport(TensorFlowLibName)]
//public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
} }
} }

+ 3
- 1
src/TensorFlowNET.Core/Operations/functional_ops.cs View File

@@ -39,7 +39,7 @@ namespace Tensorflow


if (config is null) if (config is null)
{ {
config = function_utils.get_disabled_rewriter_config().ToString();
config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
} }


if (executor_type is null) if (executor_type is null)
@@ -49,6 +49,8 @@ namespace Tensorflow


if (executing_eagerly) if (executing_eagerly)
{ {
// TODO(Rinne): implement it.
throw new NotImplementedException(); throw new NotImplementedException();
} }




+ 46
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -17,6 +17,7 @@
using System; using System;
using System.Linq; using System.Linq;
using Tensorflow.Contexts; using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
@@ -210,7 +211,51 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param> /// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `value`.</returns> /// <returns>A `Tensor`. Has the same type as `value`.</returns>
public static Tensor fill<T>(Tensor dims, T value, string name = null) public static Tensor fill<T>(Tensor dims, T value, string name = null)
=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value));
return _result[0];
}
catch (Exception)
{
}
try
{
return fill_eager_fallback(dims, value as Tensor, name, ctx);
}
catch (Exception)
{

}
}
Dictionary<string, object> attrs = new Dictionary<string, object>();
attrs["dims"] = dims;
attrs["value"] = value;
var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result.output;
}

public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx)
{
object[] attrs = new object[] { "T", dims.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() };
var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name);

if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return _result[0];
}
//=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));


/// <summary> /// <summary>
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast. /// Return the reduction indices for computing gradients of s0 op s1 with broadcast.


+ 4
- 2
src/TensorFlowNET.Core/Operations/handle_data_util.cs View File

@@ -49,8 +49,10 @@ namespace Tensorflow.Operations
target_t.HandleData = handle_data; target_t.HandleData = handle_data;
return; return;
} }
// TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
//c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
Status status = new();
var proto = handle_data.ToByteArray();
c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
status.Check(true);
} }


public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op); public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);


+ 14
- 0
src/TensorFlowNET.Core/Operations/resource_variable_ops.cs View File

@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
using Tensorflow.Operations; using Tensorflow.Operations;
using System.Buffers; using System.Buffers;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Graphs;


namespace Tensorflow namespace Tensorflow
{ {
@@ -302,5 +303,18 @@ namespace Tensorflow
// return handle_data_util.get_resource_handle_data(handle); // return handle_data_util.get_resource_handle_data(handle);
//} //}
} }

public static void variable_accessed(IVariableV1 variable)
{
if (ops.get_default_graph() is FuncGraph func_graph)
{
func_graph.watch_variable(variable);
}
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
tape.VariableAccessed(variable);
}
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -110,7 +110,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> <PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="OneOf" Version="3.0.223" /> <PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.6.2" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup> </ItemGroup>




+ 5
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
{ {
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle); public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);


public Tensor()
protected Tensor()
{ {
} }


@@ -108,6 +108,7 @@ namespace Tensorflow
protected unsafe void InitTensor(Shape shape, TF_DataType dtype) protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{ {
_handle = TF_NewTensor(shape, dtype, null); _handle = TF_NewTensor(shape, dtype, null);
_id = ops.uid();
} }


protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -116,6 +117,7 @@ namespace Tensorflow
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
else else
_handle = TF_NewTensor(bytes, shape, dtype); _handle = TF_NewTensor(bytes, shape, dtype);
_id = ops.uid();
} }


protected unsafe void InitTensor(Array array, Shape? shape = null) protected unsafe void InitTensor(Array array, Shape? shape = null)
@@ -166,6 +168,8 @@ namespace Tensorflow
string[] val => StringTensor(val, shape), string[] val => StringTensor(val, shape),
_ => throw new NotImplementedException("") _ => throw new NotImplementedException("")
}; };

_id = ops.uid();
} }


unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged


+ 1
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -462,6 +462,7 @@ namespace Tensorflow.Training.Saving.SavedModel
{ {
IEnumerable<ConcreteFunction> _concrete_functions; IEnumerable<ConcreteFunction> _concrete_functions;
FunctionSpec _function_spec; FunctionSpec _function_spec;
public IEnumerable<ConcreteFunction> ConcreteFunctions => _concrete_functions;
public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec, public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec,
IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false) IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false)
{ {


+ 13
- 0
src/TensorFlowNET.Core/Util/UnorderedMap.cs View File

@@ -25,6 +25,19 @@ namespace Tensorflow.Util
} }
} }


public Tv SetDefault(Tk key, Tv default_value)
{
if(TryGetValue(key, out var res))
{
return res;
}
else
{
base[key] = default_value;
return base[key];
}
}

public void push_back(Tk key, Tv value) public void push_back(Tk key, Tv value)
=> this[key] = value; => this[key] = value;




+ 5
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -9,6 +9,7 @@ using System.Diagnostics;
using Tensorflow.Checkpoint; using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel; using Tensorflow.Training.Saving.SavedModel;
using OneOf; using OneOf;
using Tensorflow.Graphs;


namespace Tensorflow namespace Tensorflow
{ {
@@ -193,6 +194,10 @@ namespace Tensorflow
/// </summary> /// </summary>
void variable_accessed(BaseResourceVariable variable) void variable_accessed(BaseResourceVariable variable)
{ {
if(ops.get_default_graph() is FuncGraph func_graph)
{
func_graph.watch_variable(variable as IVariableV1);
}
if (variable.Trainable) if (variable.Trainable)
{ {
foreach (var tape in tf.GetTapeSet()) foreach (var tape in tf.GetTapeSet())


+ 2
- 6
src/TensorFlowNET.Core/ops.cs View File

@@ -575,12 +575,8 @@ namespace Tensorflow


public static HandleData get_resource_handle_data(Tensor graph_op) public static HandleData get_resource_handle_data(Tensor graph_op)
{ {
throw new NotImplementedException();
// This implementation hasn't been checked for some reasons.
// If it throws an exception in the future, please check it.

//var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
//return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
} }


public static void dismantle_graph(Graph graph) public static void dismantle_graph(Graph graph)


+ 21
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -27,6 +27,7 @@ using Tensorflow.Keras.Utils;
using Tensorflow.NumPy; using Tensorflow.NumPy;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Training; using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -50,7 +51,17 @@ namespace Tensorflow.Keras.Engine
/// the layer's weights. /// the layer's weights.
/// </summary> /// </summary>
protected bool built; protected bool built;
public bool Built => built;
public bool Built
{
get
{
return built;
}
internal set
{
built = value;
}
}
public bool Trainable => args.Trainable; public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType; public TF_DataType DType => args.DType;
public bool AutoCast => args.Autocast; public bool AutoCast => args.Autocast;
@@ -179,6 +190,11 @@ namespace Tensorflow.Keras.Engine
} }
protected List<ILayer> _self_tracked_trackables; protected List<ILayer> _self_tracked_trackables;


/// <summary>
/// If this value is set, the behavior of layer call will be changed to directly calling this function.
/// </summary>
public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null;

public Layer(LayerArgs args) public Layer(LayerArgs args)
{ {
Initialize(args); Initialize(args);
@@ -259,6 +275,10 @@ namespace Tensorflow.Keras.Engine
/// <returns></returns> /// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {
if(ReplacedCall is not null)
{
return ReplacedCall(inputs);
}
return inputs; return inputs;
} }




+ 1
- 5
src/TensorFlowNET.Keras/Engine/Model.Train.cs View File

@@ -35,10 +35,6 @@ namespace Tensorflow.Keras.Engine
{ {
(x, y) = data_handler.DataAdapter.Expand1d(x, y); (x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape(); using var tape = tf.GradientTape();
//foreach (var variable in TrainableVariables)
//{
// tape.watch(variable.Handle);
//}
var y_pred = Apply(x, training: true); var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred); var loss = compiled_loss.Call(y, y_pred);


@@ -70,7 +66,7 @@ namespace Tensorflow.Keras.Engine
gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables)); gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables));
gradients = optimizer.clip_gradients(gradients); gradients = optimizer.clip_gradients(gradients);


optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)),
optimizer.apply_gradients(zip(gradients, trainable_variables),
experimental_aggregate_gradients: false); experimental_aggregate_gradients: false);
} }
} }


+ 4
- 4
src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Optimizers
_set_hyper("decay", args.InitialDecay); _set_hyper("decay", args.InitialDecay);
} }


public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
public void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null, string name = null,
bool experimental_aggregate_gradients = true) bool experimental_aggregate_gradients = true)
=> apply_gradients(new[] { grads_and_vars }, => apply_gradients(new[] { grads_and_vars },
@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Optimizers
/// <param name="grads_and_vars"></param> /// <param name="grads_and_vars"></param>
/// <param name="name"></param> /// <param name="name"></param>
/// <param name="experimental_aggregate_gradients"></param> /// <param name="experimental_aggregate_gradients"></param>
public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null, string name = null,
bool experimental_aggregate_gradients = true) bool experimental_aggregate_gradients = true)
{ {
@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Optimizers
}); });
} }


void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{ {
_resource_apply_dense(var, grad, apply_state); _resource_apply_dense(var, grad, apply_state);
// if var.constraint is not None: // if var.constraint is not None:
@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Optimizers
throw new NotImplementedException("_resource_apply_dense"); throw new NotImplementedException("_resource_apply_dense");
} }


void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name, string name,
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state) Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
{ {


+ 25
- 0
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -255,6 +255,25 @@ namespace Tensorflow.Keras.Saving
/// <param name="layers"></param> /// <param name="layers"></param>
private void _finalize_saved_model_layers(List<Layer> layers) private void _finalize_saved_model_layers(List<Layer> layers)
{ {
foreach(var layer in layers)
{
layer.Built = true;
var keras_attr = _get_keras_attr(layer);
if(keras_attr is not Trackable trackable)
{
continue;
}
if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call))
{
Debug.Assert(layer_call is RestoredFunction);
var concrete_functions = ((RestoredFunction)layer_call).ConcreteFunctions;
if (concrete_functions is not null && concrete_functions.Count() > 0)
{
layer.ReplacedCall = use_wrapped_call(layer, ((RestoredFunction)layer_call).Apply);
}
}
}

foreach(var layer in layers) foreach(var layer in layers)
{ {
// TODO(Rinne): deal with `RevivedNetwork`. // TODO(Rinne): deal with `RevivedNetwork`.
@@ -265,6 +284,12 @@ namespace Tensorflow.Keras.Saving
} }
} }


private Func<Tensors, Tensors> use_wrapped_call(Layer layer, Func<Tensors, Tensors> call)
{
// TODO(Rinne): revise it.
return call;
}

private void _restore_layer_unconditional_losses(Layer layer) private void _restore_layer_unconditional_losses(Layer layer)
{ {
// TODO(Rinne): implement it. // TODO(Rinne): implement it.


+ 11
- 11
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs View File

@@ -85,16 +85,16 @@ namespace Tensorflow.Keras.Saving.SavedModel
return _config; return _config;
} }


protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
{
return base.Call(inputs, state, training);
}
else
{
return (func as Function).Apply(inputs);
}
}
//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
//{
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
// {
// return base.Call(inputs, state, training);
// }
// else
// {
// return (func as Function).Apply(inputs);
// }
//}
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -223,7 +223,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
//base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }), //base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
// functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }) // functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}), base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}),
functions.Concat(new string[] { }))
functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }))
{ {


} }


+ 11
- 15
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -64,23 +64,19 @@ public class SequentialModelLoad
var model = tf.keras.models.load_model(@"Assets/python_func_model"); var model = tf.keras.models.load_model(@"Assets/python_func_model");
model.summary(); model.summary();


var x = tf.random.uniform((8, 784), -1, 1);
var y = model.Apply(x);
Console.WriteLine(y);
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });


//model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

//var data_loader = new MnistModelLoader();
//var num_epochs = 1;
//var batch_size = 8;
var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 8;


//var dataset = data_loader.LoadAsync(new ModelLoadSetting
//{
// TrainDir = "mnist",
// OneHot = false,
// ValidationSize = 58000,
//}).Result;
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 55000,
}).Result;


//model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
} }
} }

Loading…
Cancel
Save