
Fix the backward pass error of loaded function models.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
commit 9420ba3243
GPG Key ID: E86D01E1809BD23E (no known key found for this signature in database)
45 changed files with 870 additions and 409 deletions
  1. +1 −0 src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
  2. +25 −7 src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
  3. +12 −7 src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  4. +162 −17 src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  5. +4 −3 src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs
  6. +13 −3 src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
  7. +8 −1 src/TensorFlowNET.Core/Eager/IEagerRunner.cs
  8. +21 −4 src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  9. +9 −10 src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  10. +4 −5 src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
  11. +3 −5 src/TensorFlowNET.Core/Functions/Function.cs
  12. +105 −52 src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  13. +6 −6 src/TensorFlowNET.Core/Functions/TracingCompiler.cs
  14. +1 −1 src/TensorFlowNET.Core/Functions/monomorphic_function.cs
  15. +2 −2 src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
  16. +27 −8 src/TensorFlowNET.Core/Gradients/GradientTape.cs
  17. +15 −8 src/TensorFlowNET.Core/Gradients/ITape.cs
  18. +2 −2 src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
  19. +157 −125 src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
  20. +31 −32 src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
  21. +21 −10 src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
  22. +10 −10 src/TensorFlowNET.Core/Gradients/Tape.cs
  23. +45 −9 src/TensorFlowNET.Core/Gradients/TapeTensor.cs
  24. +1 −1 src/TensorFlowNET.Core/Gradients/TensorTape.cs
  25. +1 −26 src/TensorFlowNET.Core/Gradients/gradients_util.cs
  26. +10 −0 src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  27. +2 −2 src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs
  28. +4 −4 src/TensorFlowNET.Core/Operations/c_api.ops.cs
  29. +3 −1 src/TensorFlowNET.Core/Operations/functional_ops.cs
  30. +46 −1 src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  31. +4 −2 src/TensorFlowNET.Core/Operations/handle_data_util.cs
  32. +14 −0 src/TensorFlowNET.Core/Operations/resource_variable_ops.cs
  33. +1 −1 src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  34. +5 −1 src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  35. +1 −0 src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  36. +13 −0 src/TensorFlowNET.Core/Util/UnorderedMap.cs
  37. +5 −0 src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  38. +2 −6 src/TensorFlowNET.Core/ops.cs
  39. +21 −1 src/TensorFlowNET.Keras/Engine/Layer.cs
  40. +1 −5 src/TensorFlowNET.Keras/Engine/Model.Train.cs
  41. +4 −4 src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs
  42. +25 −0 src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  43. +11 −11 src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs
  44. +1 −1 src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  45. +11 −15 test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

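Taken together, these changes restore the backward pass for functions deserialized from a SavedModel, so a loaded Keras model can be trained again. A minimal end-to-end sketch of the scenario this commit fixes (paths, data, and layer shapes are hypothetical; the API is TensorFlow.NET's public Keras surface):

    using static Tensorflow.Binding;
    using static Tensorflow.KerasApi;
    using Tensorflow;
    using Tensorflow.NumPy;

    // Hypothetical: a model previously saved via model.save("./saved_model").
    var model = keras.models.load_model("./saved_model");

    // Hypothetical toy batch; before this commit, fit() on a loaded model
    // failed while building the backward function.
    var x = np.ones(new Shape(16, 28, 28), np.float32);
    var y = np.zeros(new Shape(16), np.int32);

    model.compile(keras.optimizers.RMSprop(1e-3f),
        keras.losses.SparseCategoricalCrossentropy(from_logits: true),
        new[] { "accuracy" });
    model.fit(x, y, batch_size: 8, epochs: 1);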
+1 −0 src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Google.Protobuf;
using Protobuf.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Contexts


+25 −7 src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs

@@ -12,18 +12,36 @@ namespace Tensorflow.Eager
return HasGradientTape();
}

private bool ShouldRecord(Tensor[] inputs)
public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors)
{
bool should_record = false;
foreach (var tape in tf.GetTapeSet())
var tape_set = tf.GetTapeSet();
var input_ids = MakeTensorIDList(tensors);
var input_dtypes = MakeTensorDtypeList(tensors);
bool some_tape_watching = false;
if (tape_set is not null && tape_set.Count > 0)
{
if (tape.ShouldRecord(inputs))
foreach (var tape in tape_set)
{
should_record = true;
break;
if (tape.ShouldRecord(input_ids, input_dtypes))
{
if (tape.Persistent || some_tape_watching)
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
}
some_tape_watching = true;
}
}
}
return should_record;
// skip the forward_accumulators.

if (some_tape_watching)
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
}
else
{
return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE;
}
}
}
}
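The constants here live in gradients_util, and the method mirrors Python's TFE_TapeSetPossibleGradientTypes. A quick sketch of the first-order case through the public tape API (the runner call is the method added above):

    using Tensorflow;
    using static Tensorflow.Binding;

    var x = tf.constant(3.0f);
    using var tape = tf.GradientTape();   // one non-persistent tape
    tape.watch(x);
    int t = tf.Runner.TFE_TapeSetPossibleGradientTypes(new[] { x });
    // t == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER here; a
    // persistent tape (or nested tapes) yields POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER,
    // and no active tape yields POSSIBLE_GRADIENT_TYPES_NONE.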

+12 −7 src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

@@ -13,7 +13,17 @@ namespace Tensorflow.Eager
Tensor[] results,
BackwardFunction backwardFunction = null)
{
bool should_record = ShouldRecord(inputs);
var input_ids = MakeTensorIDList(inputs);
var input_dtypes = MakeTensorDtypeList(inputs);
bool should_record = false;
foreach (var tape in tf.GetTapeSet())
{
if (tape.ShouldRecord(input_ids, input_dtypes))
{
should_record = true;
break;
}
}

if (!should_record)
{
@@ -59,7 +69,7 @@ namespace Tensorflow.Eager
op_inputs = inputs;*/

backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results);
TapeSetRecordOperation(op_name, inputs, results, backwardFunction);
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction);

return true;
}
@@ -129,10 +139,5 @@ namespace Tensorflow.Eager
{
return HasGradientTape();
}

TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
{
return tensors.Select(x => x.dtype).ToArray();
}
}
}

+162 −17 src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs

@@ -1,6 +1,8 @@
using System;
using OneOf.Types;
using System;
using Tensorflow.Gradients;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager
/// </summary>
public partial class EagerRunner
{
/// <summary>
///
/// </summary>
/// <param name="tape"></param>
/// <param name="target"></param>
/// <param name="sources"></param>
/// <param name="output_gradients"></param>
/// <param name="unconnected_gradients">determines the value returned if the target and
/// sources are unconnected.When 'none' the value returned is None wheras when
/// 'zero' a zero tensor in the same shape as the sources is returned.</param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target,
Tensor[] sources,
Tensor[] output_gradients)
List<Tensor> output_gradients,
Tensor[] sources_raw,
string unconnected_gradients)
{
var target_vec = target;
var sources_vec = sources;
var sources_set = sources_vec;
if (!tape.Persistent)
{
var tape_set = tf.GetTapeSet();
if (tape_set.Contains(tape))
{
throw new RuntimeError("gradient() cannot be invoked within the " +
"GradientTape context (i.e., while operations are being " +
"recorded). Either move the call to gradient() to be " +
"outside the 'with tf.GradientTape' block, or " +
"use a persistent tape: " +
"'with tf.GradientTape(persistent=true)'");
}
}

var target_vec = MakeTensorIDList(target);
var sources_vec = MakeTensorIDList(sources);
HashSet<long> sources_set = new HashSet<long>(sources_vec);
var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();

int len = target.Length;
for(int i = 0; i < len; i++)
{
var target_id = target_vec[i];
if (sources_set.Contains(target_id))
{
var tensor = target[i];
source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor);
}
}

List<Tensor> outgrad_vec = new();
if(output_gradients is not null)
{
outgrad_vec = output_gradients.ToList();
}
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);

var seq_array = target;
var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();

for (int i = 0; i < target.Length; ++i)
bool unconnected_gradients_zero = unconnected_gradients == "zero";
Tensor[] sources_obj = null;
if (unconnected_gradients_zero)
{
source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
sources_obj = MakeTensorList(sources_raw);
}

if (output_gradients != null)
if (result.Length > 0)
{
throw new NotImplementedException("");
for(int i = 0; i < result.Length; i++)
{
if (result[i] is null && unconnected_gradients_zero)
{
var dtype = sources_obj[i].dtype;
result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike();
}
}
}
else
return result;
}

Tensor[] MakeTensorList(IEnumerable<Tensor> tensors)
{
return tensors.ToArray();
}

long[] MakeTensorIDList(Tensor[] tensors)
{
int len = tensors.Length;
long[] ids = new long[len];
for(int i = 0; i < len; i++)
{
var tensor = tensors[i];
ids[i] = tensor.Id;
}
return ids;
}

TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
{
int len = tensors.Length;
TF_DataType[] dtypes = new TF_DataType[len];
for (int i = 0; i < len; i++)
{
output_gradients = new Tensor[0];
var tensor = tensors[i];
dtypes[i] = tensor.dtype;
}
return dtypes;
}

var outgrad_vec = MakeTensorList(output_gradients);
TapeTensor TapeTensorFromTensor(Tensor tensor)
{
long id = tensor.Id;
var dtype = tensor.dtype;
if (tensor is EagerTensor)
{
var handle = tensor.EagerTensorHandle;
if (DTypeNeedsHandleData(dtype))
{
return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor);
}

Status status = new();
int num_dims = c_api.TFE_TensorHandleNumDims(handle, status);
long[] dims = new long[num_dims];
for(int i = 0; i < num_dims; i++)
{
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
}
Shape tensor_shape = new(dims);

if(status.Code != TF_Code.TF_OK)
{
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null);
}
else
{
return new TapeTensor(id, dtype, tensor_shape);
}
}
var shape_tuple = tensor.shape.dims;
if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype))
{
return new TapeTensor(id, dtype, tensor);
}
long[] l = new long[shape_tuple.Length];
for(int i = 0; i < shape_tuple.Length; i++)
{
if (shape_tuple[i] < 0)
{
l[i] = 0;
}
else
{
l[i] = shape_tuple[i];
}
}
return new TapeTensor(id, dtype, new Shape(l));
}

return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec);
bool DTypeNeedsHandleData(TF_DataType dtype)
{
return dtype == dtypes.variant || dtype == dtypes.resource;
}

Tensor[] MakeTensorList(Tensor[] tensors)
bool ListContainNone(long[] list)
{
return tensors;
int len = list.Length;
if(len == 0)
{
return true;
}
for(int i = 0; i < len; i++)
{
if (list[i] == -1)
{
return true;
}
}
return false;
}
}
}
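The new guard mirrors Python's behavior: calling gradient() while a non-persistent tape is still recording raises RuntimeError. A persistent tape permits repeated gradient() calls (a sketch; the wrapper behavior follows GradientTape.cs later in this diff):

    using static Tensorflow.Binding;

    var x = tf.constant(3.0f);
    using var tape = tf.GradientTape(persistent: true);
    tape.watch(x);
    var y = tf.multiply(x, x);        // x^2
    var z = tf.multiply(y, y);        // x^4
    var dy = tape.gradient(y, x);     // 2x   = 6
    var dz = tape.gradient(z, x);     // 4x^3 = 108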

+4 −3 src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs

@@ -7,8 +7,9 @@ namespace Tensorflow.Eager
public partial class EagerRunner
{
void TapeSetRecordBackprop(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
TapeTensor[] output_info,
long[] input_ids,
TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
if (!CouldBackprop())
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager

foreach (var tape in tf.GetTapeSet())
{
tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function);
tape.RecordOperation(op_type, output_info, input_ids, input_dtypes, backward_function);
}
}
}


+13 −3 src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs

@@ -10,18 +10,28 @@ namespace Tensorflow.Eager
public bool TapeSetRecordOperation(string op_type,
Tensor[] input_tensors,
Tensor[] output_tensors,
long[] input_ids,
TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();

var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray();
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
backward_function))
return false;

TapeSetRecordBackprop(op_type, input_tensors, output_info,
TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes,
backward_function);

return true;
}

public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
Tensor[] input_tensors, BackwardFunction backward_function)
{
var input_ids = MakeTensorIDList(input_tensors);
var input_dtypes = MakeTensorDtypeList(input_tensors);
TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes,
backward_function);
}
}
}
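TFE_TapeSetRecordOperation is the new public entry point that TapeGradientFunctions.Record uses later in this diff. A hedged sketch of recording a custom op on all active tapes; the delegate shape (output gradients and unneeded-gradient indices in, input gradients out) is an assumption drawn from this repository's BackwardFunction usage:

    using Tensorflow;
    using Tensorflow.Gradients;
    using static Tensorflow.Binding;

    var x = tf.constant(2.0f);
    using var tape = tf.GradientTape();
    tape.watch(x);
    var y = tf.add(x, tf.constant(0.0f));   // stand-in for a custom forward op

    // Identity-like backward: pass the output gradients straight through.
    BackwardFunction bwd = (output_grads, unneeded_gradients) => output_grads;
    tf.Runner.TFE_TapeSetRecordOperation("CustomIdentity", new[] { y }, new[] { x }, bwd);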

+8 −1 src/TensorFlowNET.Core/Eager/IEagerRunner.cs

@@ -29,7 +29,14 @@ namespace Tensorflow.Eager
Tensor[] TFE_TapeGradient(ITape tape,
Tensor[] target,
Tensor[] sources,
Tensor[] output_gradients);
List<Tensor> output_gradients,
Tensor[] sources_raw,
string unconnected_gradients);

void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
Tensor[] input_tensors, BackwardFunction backward_function);

int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors);

bool RecordGradient(string op_name,
Tensor[] inputs,


+21 −4 src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

@@ -18,12 +18,13 @@ namespace Tensorflow.Functions
public class ConcreteFunction: Trackable
{
protected IEnumerable<Tensor> _captured_inputs;
internal FuncGraph func_graph;
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
protected Dictionary<string, AttrValue> _attrs;
protected FunctionSpec _function_spec;
protected FunctionSpec _pre_initialized_function_spec = null;
protected EagerDefinedFunction _inference_function;
protected Dictionary<string, TapeGradientFunctions> _tape_functions_cache = new();
internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;
@@ -156,6 +157,17 @@ namespace Tensorflow.Functions
{
var executing_eagerly = tf.Context.executing_eagerly();
var default_graph = ops.get_default_graph();
// TODO(Rinne): deal with `default_graph.building_function`

var tempvv = func_graph.Variables;
if(tf.GetTapeSet().Count > 0 || default_graph is FuncGraph)
{
foreach(var v in this.func_graph.Variables)
{
resource_variable_ops.variable_accessed(v);
}
}

var tensor_inputs = new Tensors();
foreach (var (i, arg) in enumerate(args))
{
@@ -223,11 +235,16 @@ namespace Tensorflow.Functions
{
input_tangents = new TangentInfo();
}
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER || tf.Runner.MustRecordGradient())
if(possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER)
{
if(input_tangents.Indices is not null || executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
string cache_key = "first_order";
if(!_tape_functions_cache.TryGetValue(cache_key, out var functions))
{
functions = new FirstOrderTapeGradientFunctions(func_graph, false);
_tape_functions_cache[cache_key] = functions;
}
return new ForwardBackwardCall(functions, args, tape_watching: true);
}
else
@@ -241,7 +258,7 @@ namespace Tensorflow.Functions
}

// TODO(Rinne): add arg "input_tagents" for ForwardBackwardCall.
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: tf.Runner.MustRecordGradient());
return new ForwardBackwardCall(_delayed_rewrite_functions, args, tape_watching: false);
}

internal void set_variables(IEnumerable<IVariableV1> variables)


+9 −10 src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs

@@ -124,17 +124,16 @@ namespace Tensorflow.Functions
// TODO(Rinne): Add arg `CancellationManager`.
// TODO(Rinne): Check the arg length.
var function_call_options = tf.Context.FunctionCallOptions;
string config;
if (function_call_options.config_proto_serialized().Length == 0)
{
config = function_utils.get_disabled_rewriter_config().ToString();
}
else
{
config = function_call_options.config_proto_serialized().ToString();
}
string config = ""; // TODO(Rinne): revise it. The following code should work but not, for unclear reasons.

config = ""; // TODO(Rinne): revise it.
//if (function_call_options.config_proto_serialized().Length == 0)
//{
// config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
//}
//else
//{
// config = function_call_options.config_proto_serialized().ToStringUtf8();
//}

string executor_type = function_call_options.ExecutorType ?? "";
var executing_eagerly = tf.Context.executing_eagerly();


+4 −5 src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs

@@ -14,12 +14,11 @@ namespace Tensorflow.Functions

}

public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
public override (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
ForwardAndBackwardFunctions(Tensors inference_args)
{
var outputs = _func_graph.Outputs;
(_forward_function, _forward_graph, _backward_function, _forwardprop_output_indices, _num_forwardprop_outputs)
= BuildFunctionsForOutputs(outputs, inference_args);
return _forward_function;
var outputs = _func_graph.Outputs.Take(_num_inference_outputs).ToArray();
return BuildFunctionsForOutputs(outputs, inference_args);
}
}
}

+3 −5 src/TensorFlowNET.Core/Functions/Function.cs

@@ -14,7 +14,6 @@ namespace Tensorflow
protected ConcreteFunction _concrete_variable_creation_fn;
protected bool _autograph;
protected TracingCompiler _variable_creation_fn;
protected bool _has_initialized;
public string Name { get; set; }
public Function(Func<Tensor[], Tensor[]> csharp_function,
string name, bool auto_graph = true)
@@ -22,7 +21,6 @@ namespace Tensorflow
_csharp_function = csharp_function;
Name = name;
_autograph = auto_graph;
_has_initialized = false;
}

public virtual Tensors Apply(Tensors inputs)
@@ -38,10 +36,11 @@ namespace Tensorflow

protected virtual Tensors _call(Tensors inputs)
{
if (!_has_initialized)
if(_variable_creation_fn is not null)
{
_initialize(inputs);
return _variable_creation_fn.Apply(inputs);
}
_initialize(inputs);

return _concrete_variable_creation_fn.CallFlat(inputs,
_concrete_variable_creation_fn.CapturedInputs);
@@ -63,7 +62,6 @@ namespace Tensorflow
_variable_creation_fn = _compiler(_csharp_function);
_variable_creation_fn._name = this.Name;
_concrete_variable_creation_fn = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args);
_has_initialized = true;
}
}
}
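With the _has_initialized flag gone, initialization state is inferred from _variable_creation_fn itself: the first Apply traces the C# function, and subsequent calls reuse the compiled trace. A sketch against this class's constructor:

    using Tensorflow;
    using static Tensorflow.Binding;

    var square = new Function(xs => new[] { tf.multiply(xs[0], xs[0]) }, "square");
    var a = square.Apply(tf.constant(2.0f));   // first call: _initialize + trace
    var b = square.Apply(tf.constant(3.0f));   // later calls: reuse _variable_creation_fn

Note that TracingCompiler's cache key is reduced to the input count in this same commit, so both calls above resolve to the same traced function.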

+105 −52 src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

@@ -24,23 +24,40 @@ namespace Tensorflow.Functions
protected string _INFERENCE_PREFIX = "__inference_";

protected FuncGraph _func_graph;
protected EagerDefinedFunction _forward_function;
protected EagerDefinedFunction _forward;
protected FuncGraph _forward_graph;
protected List<int> _forwardprop_input_indices;
protected List<int> _forwardprop_output_indices;
protected int _num_forwardprop_outputs;
protected ConcreteFunction _backward_function;
protected int _num_inference_outputs;
protected int _num_outputs;
protected int _num_trainable_inference_outputs;
protected ConcreteFunction _backward;
BackwardFunction _backward_function_wrapper;

public TapeGradientFunctions(FuncGraph func_graph,
bool need_gradients_for_jvps)
{
_func_graph = func_graph;
_forward_graph = null;
_forward = null;
_backward = null;
_num_outputs = func_graph.Outputs.Length;
_forwardprop_output_indices = null;
_num_forwardprop_outputs = 0;
_num_inference_outputs = func_graph.Outputs.Length;
_num_trainable_inference_outputs = func_graph.Outputs.Where(t => backprop_util.IsTrainable(t)).Count();
}

public virtual EagerDefinedFunction Forward(Tensors inference_args, Tensors input_tangents = null)
{
// TODO(Rinne): add input_tangents arg.
return ForwardAndBackwardFunctions(inference_args);
if(_forward is null)
{
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
= ForwardAndBackwardFunctions(inference_args);
}
return _forward;
}

/// <summary>
@@ -51,9 +68,13 @@ namespace Tensorflow.Functions
public virtual void Record(Tensors flat_outputs, Tensors inference_args)
{
// TODO(Rinne): add arg `input_tagents`.
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward_function, flat_outputs);
tf.Runner.RecordGradient(_forward_function.Name, inference_args, new object[0], to_record,
getBackwardFunction: backward_function);
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
if(_forwardprop_output_indices is not null && _forwardprop_output_indices.Count > 0)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
}
tf.Runner.TFE_TapeSetRecordOperation(_forward.Signature.Name, to_record, inference_args, backward_function);
}

/// <summary>
@@ -65,66 +86,95 @@ namespace Tensorflow.Functions
/// <returns></returns>
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{
var capture_mapping = zip(forward_graph.Outputs.Select(t => ops.tensor_id(t)), outputs)
.ToDictionary(x => x.Item1, x => x.Item2);
var captured_inputs = backward.CapturedInputs;
var remapped_captures = captured_inputs.Select(c =>
{
if (capture_mapping.TryGetValue(ops.tensor_id(c), out var value))
{
return value;
}
else
{
return c;
}
}).ToArray();
if(remapped_captures.Where(t => t is not EagerTensor).Any(t => t.graph == forward_graph))
{
var incorrect_mapping = remapped_captures.Where(t => t is not EagerTensor && t.graph != forward_graph);
throw new RuntimeError($"Failed to map all backward graph captures to " +
$"the forward graph. Incorrectly mapped: {string.Join(", ", incorrect_mapping)}");
}

Dictionary<int, Tensor> variant_zeros_like = new Dictionary<int, Tensor>();
var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
var recorded_outputs = new Tensors();
var trainable_recorded_outputs = 0;
foreach (var (output_index, output) in enumerate(outputs))
int trainable_recorded_outputs = 0;
var skip_positions = new HashSet<int>();
var relevant_outputs = outputs;
foreach (var (output_index, output) in enumerate(relevant_outputs))
{
if (trainable_recorded_outputs < backward_function_inputs)
recorded_outputs.Add(output);
if (gradients_util.IsTrainable(output))
trainable_recorded_outputs += 1;
if (backprop_util.IsTrainable(output))
trainable_recorded_outputs++;
else
skip_positions.Add(output_index);
if (output.dtype == dtypes.variant)
variant_zeros_like[output_index] = default_gradient.zeros_like(output);
}

if(_backward_function_wrapper == null)
_backward_function_wrapper = (args, unneeded_gradients) =>
{
var capture_mapping = new Dictionary<long, Tensor>();
foreach (var (i, output) in enumerate(outputs))
capture_mapping[forward_graph.Outputs[i].Id] = output;

var remapped_captures = new Tensors();
foreach (var capture in backward.CapturedInputs)
{
if (capture_mapping.ContainsKey(capture.Id))
remapped_captures.Add(capture_mapping[capture.Id]);
}

var skip_positions = new List<int>();
foreach (var (output_index, output) in enumerate(outputs))
if(backward.Outputs is null || backward.Outputs.Length == 0)
{
if (!gradients_util.IsTrainable(output))
skip_positions.Add(output_index);
return backward.FlatStructuredOutputs;
}

_backward_function_wrapper = (args, unneeded_gradients) =>
var processed_args = new Tensors();
int input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
{
var processed_args = new Tensors();
var input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
if (skip_positions.Contains(output_index))
continue;
if (arg is null)
{
var input_placeholder = backward.Inputs[input_index];
Tensor variant_arg;
if (input_placeholder.dtype == dtypes.variant)
{
variant_arg = variant_zeros_like[output_index];
}
else
{
var (shape, type) = default_gradient.shape_and_dtype(input_placeholder);

variant_arg = array_ops.zeros(shape, type);
}
processed_args.Add(variant_arg);
}
else
{
if (skip_positions.Contains(output_index))
continue;
if (arg == null)
throw new NotImplementedException("");
processed_args.Add(arg);
input_index += 1;
if (input_index >= backward_function_inputs)
break;
}
input_index++;
if (input_index >= backward_function_inputs)
break;
}

tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);
tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);

foreach (var unneeded_gradient_index in unneeded_gradients)
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
}
foreach (var unneeded_gradient_index in unneeded_gradients)
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
}

return gradients;
};
}
return gradients;
};

return (_backward_function_wrapper, recorded_outputs);
}
@@ -143,7 +193,7 @@ namespace Tensorflow.Functions
}
}

var backwards_graph = new FuncGraph(_func_graph.Name);
var backwards_graph = new FuncGraph(monomorphic_function_utils._backward_name(_func_graph.Name));
backwards_graph.as_default();
var gradients_wrt_outputs = new List<Tensor>();
foreach (var output in trainable_outputs)
@@ -153,6 +203,7 @@ namespace Tensorflow.Functions
gradients_wrt_outputs.Add(gradient_placeholder);
handle_data_util.copy_handle_data(output, gradient_placeholder);
}
// TODO(Rinne): with ops.device(None)
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
_func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(),
@@ -175,7 +226,8 @@ namespace Tensorflow.Functions
backwards_graph.Inputs = gradients_wrt_outputs.Concat(backwards_graph.internal_captures).ToArray();
backwards_graph.Outputs.AddRange(gradients_wrt_inputs.Where(x => x is not null));

var (forward_function, backward_function) = monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
var (wrapped_forward_function, wrapped_backward_function) =
monomorphic_function_utils._create_forward_backward_with_graph(null, _func_graph, backwards_graph);
//var forward_function_name = $"{_FORWARD_PREFIX}_{_func_graph.FuncName}_{ops.uid()}";
//var backward_function_attr = new Dictionary<string, string>();
//backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
@@ -189,10 +241,11 @@ namespace Tensorflow.Functions
// _func_graph.Inputs, _func_graph.Outputs,
// monomorphic_function_utils._parse_func_attrs(forward_function_attr));
return (forward_function, _func_graph, backward_function, null, 0);
return (wrapped_forward_function, _func_graph, wrapped_backward_function, null, 0);
}

public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
public virtual (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
ForwardAndBackwardFunctions(Tensors inference_args)
{
throw new NotImplementedException("");
}


+6 −6 src/TensorFlowNET.Core/Functions/TracingCompiler.cs

@@ -73,12 +73,12 @@ namespace Tensorflow.Functions

private static string make_cache_key(Tensor[] inputs)
{
string res = "";
foreach (var input in inputs)
{
res += $"{input.name}_{input.Id}";
}
return res;
//string res = "";
//foreach (var input in inputs)
//{
// res += $"{input.name}_{input.Id}";
//}
return inputs.Length.ToString();
}
}
}

+1 −1 src/TensorFlowNET.Core/Functions/monomorphic_function.cs

@@ -153,7 +153,7 @@ namespace Tensorflow.Functions
foreach(var tape in tf.GetTapeSet())
{
tape.RecordOperation(_inference_function.Signature.Name, to_record,
inference_args.Select(t => new TapeTensor(t)).ToArray(), backward_function);
inference_args, backward_function);
}
}



+2 −2 src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs

@@ -9,7 +9,7 @@ namespace Tensorflow.Gradients
/// Map from tensor to how many references still exist for this tensor in
/// the tape.
/// </summary>
public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; }
public UnorderedMap<long, long> tensor_usage_counts { get; set; }
/// <summary>
/// Maps from op ID to how many output tensors of this op still need to have
/// their gradients computed.
@@ -19,7 +19,7 @@ namespace Tensorflow.Gradients
public BackpropInitialState()
{
op_tape = new OpTape();
tensor_usage_counts = new UnorderedMap<Tensor, long>();
tensor_usage_counts = new UnorderedMap<long, long>();
op_missing_tensor = new UnorderedMap<long, long>();
}
}


+27 −8 src/TensorFlowNET.Core/Gradients/GradientTape.cs

@@ -67,40 +67,59 @@ namespace Tensorflow.Gradients
/// <param name="target"></param>
/// <param name="source"></param>
/// <returns></returns>
public Tensor gradient(Tensor target, Tensor source)
public Tensor gradient(Tensor target, Tensor source, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
if(_tape is null)
{
throw new RuntimeError("A non-persistent GradientTape can only be used to " +
"compute one set of gradients (or jacobians).");
}
ITape tape = stop_recording();

var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
new[] { source },
null);
output_gradients,
new[] { source },
unconnected_gradients);

return results[0];
}

public Tensor gradient(Tensor target, ResourceVariable source)
public Tensor gradient(Tensor target, ResourceVariable source, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
var results = gradient(target, new List<IVariableV1> { source });
var results = gradient(target, new List<IVariableV1> { source }, output_gradients, unconnected_gradients);

return results[0];
}

public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 });
var results = gradient(target, new List<IVariableV1> { sources.Item1, sources.Item2 }, output_gradients, unconnected_gradients);

return (results[0], results[1]);
}

public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources, List<Tensor> output_gradients = null,
string unconnected_gradients = null)
{
if (_tape is null)
{
throw new RuntimeError("A non-persistent GradientTape can only be used to " +
"compute one set of gradients (or jacobians).");
}
var tape = stop_recording();

var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
sources.Select(x => x.Handle).ToArray(),
null);
output_gradients,
sources.Select(x => x.Handle).ToArray(),
unconnected_gradients);

if (!tape.Persistent)
{

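These optional parameters surface TensorFlow's UnconnectedGradients semantics: with "zero", a source that does not participate in the target's computation yields a zeros tensor instead of null. Sketch:

    using static Tensorflow.Binding;

    var x = tf.constant(3.0f);
    var y = tf.constant(5.0f);               // never used below: "unconnected"
    using var tape = tf.GradientTape();
    tape.watch(x);
    tape.watch(y);
    var target = tf.multiply(x, x);
    var g = tape.gradient(target, y, unconnected_gradients: "zero");   // scalar 0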

+15 −8 src/TensorFlowNET.Core/Gradients/ITape.cs

@@ -6,24 +6,31 @@ namespace Tensorflow.Gradients
public interface ITape
{
void SetTapeId(int id);
bool ShouldRecord(Tensor[] tensors);
bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes);
void StartRecord();
void StopRecord();
bool Persistent { get; }
void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
BackwardFunction backward_function);

void VariableAccessed(ResourceVariable variable);
void RecordOperation(string op_type,
Tensor[] outputs,
Tensor[] inputs,
BackwardFunction backward_function);

void VariableAccessed(IVariableV1 variable);

void Watch(Tensor x);

ResourceVariable[] WatchedVariables();
IVariableV1[] WatchedVariables();

Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
Tensor[] source_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients);
Tensor[] ComputeGradient(long[] target_tensor_ids,
long[] source_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
bool build_default_zeros_grads);
}
}

+2 −2 src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs

@@ -9,9 +9,9 @@ namespace Tensorflow.Gradients
{
public string op_type { get; set; }
public TapeTensor[] output_tensor_info { get; set; }
public Tensor[] input_tensor_id { get; set; }
public long[] input_tensor_id { get; set; }
public BackwardFunction backward_function { get; set; }
public override string ToString()
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
=> $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
}
}

+157 −125 src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs

@@ -2,235 +2,246 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
public partial class Tape
{
// int kMinAggregateCount = 4;
// int kMinAggregateBytes = 128 * 1024 * 1024;
static readonly int kMinAggregateCount = 4;
static readonly int kMinAggregateBytes = 128 * 1024 * 1024;
private static UnorderedMap<string, UnorderedSet<int>> _functionsAcceptingNoneForIndicesMap;

public Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
Tensor[] source_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients)
static Tape()
{
var sources_set = new UnorderedSet<Tensor>(source_tensor_ids);
// var gradients_size = new UnorderedMap<Tensor, long>();
var functionsAcceptingNoneForIndicesMap = FunctionsAcceptingNoneForIndicesMap();
var state = PrepareBackprop(
target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent);
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
output_gradients,
tensor_tape_,
state.op_tape);
_functionsAcceptingNoneForIndicesMap = new();
_functionsAcceptingNoneForIndicesMap.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
_functionsAcceptingNoneForIndicesMap.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
_functionsAcceptingNoneForIndicesMap.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
}

while (!op_stack.empty())
public Tensor[] ComputeGradient(long[] target_tensor_ids,
long[] source_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
bool build_default_zeros_grads)
{
UnorderedSet<long> sources_set = new(source_tensor_ids);
BackpropInitialState state = PrepareBackprop(target_tensor_ids, tensor_tape_, op_tape_, sources_set, Persistent);
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets, output_gradients, tensor_tape_, state.op_tape);
UnorderedMap<long, long> gradients_size = new();
while(op_stack.Count > 0)
{
var op = op_stack.Dequeue();
if (!state.op_tape.find(op, out var trace))
long op = op_stack.Dequeue();
if(!state.op_tape.TryGetValue(op, out var op_it))
{
continue;

// Console.WriteLine($"ComputeGradient: {state.op_tape[op].op_type}");
}
var trace = op_it;
state.op_tape.erase(op);

var out_gradients = new List<Tensor>(trace.output_tensor_info.Length);
var unneeded_gradients = new List<long>();
for (int i = 0; i < trace.input_tensor_id.Length; i++)
List<Tensor> out_gradients = new();
List<long> unneeded_gradients = new();
for(int i = 0, end = trace.input_tensor_id.Length; i < end; i++)
{
var in_tensor_id = trace.input_tensor_id[i];
if (!tensor_tape_.find(in_tensor_id) &&
!sources_set.find(in_tensor_id))
long in_tensor_id = trace.input_tensor_id[i];
if(!tensor_tape_.find(in_tensor_id) && !sources_set.find(in_tensor_id))
{
unneeded_gradients.Add(i);
}
}

bool any_gradient_nonzero = false;
var zero_indices = new List<int>();
for (int i = 0; i < trace.output_tensor_info.Length; ++i)
List<int> zero_indices = new();
for(int i = 0, end = trace.output_tensor_info.Length; i < end; i++)
{
var id = trace.output_tensor_info[i].GetTensor();
if (!gradients.find(id, out var grad_it))
long id = trace.output_tensor_info[i].GetID();
if(!gradients.TryGetValue(id, out var grad_it))
{
if (functionsAcceptingNoneForIndicesMap.find(trace.op_type, out var func_name_it) &&
func_name_it.find(i))
out_gradients.Add(null);
if (build_default_zeros_grads)
{
out_gradients.Add(null);
}
else
{
out_gradients.Add(null);
zero_indices.Add(i);
if(!_functionsAcceptingNoneForIndicesMap.TryGetValue(trace.op_type, out var func_name_it) ||
!func_name_it.find(i))
{
zero_indices.Add(i);
}
}
}
else
{
any_gradient_nonzero = true;
var new_gradients = grad_it.Count == 1 ?
grad_it[0] :
gen_math_ops.add_n(grad_it.ToArray()); // vspace.AggregateGradients

Tensor new_gradients;
if (grad_it.Count == 1)
{
new_gradients = grad_it[0];
}
else
{
new_gradients = AggregateGradients(grad_it);
}
if (!sources_set.find(id))
{
gradients.Remove(id);
}
else
{
// grad_it.Clear();
// grad_it.Add(new_gradients);
// vspace.MarkAsResult(new_gradients);
grad_it.Clear();
grad_it.Add(new_gradients);
// MarkAsResult
}
out_gradients.Add(new_gradients);
}
}

Tensor[] in_gradients;
Tensor[] in_gradients = new Tensor[0];
if (any_gradient_nonzero)
{
// foreach (var i in zero_indices)
// out_gradients[i] = trace.output_tensor_info[i].ZerosLike();

in_gradients = trace.backward_function(out_gradients.ToArray(), unneeded_gradients.ToArray());

if (in_gradients.Length != trace.input_tensor_id.Length && in_gradients.Length + unneeded_gradients.Count != trace.input_tensor_id.Length)
throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
if (!_persistent)
foreach(var i in zero_indices)
{
// trace.backward_function_deleter(trace.backward_function);
trace.backward_function = null;
out_gradients[i] = trace.output_tensor_info[i].ZerosLike();
}
in_gradients = CallBackwardFunction(trace.backward_function, unneeded_gradients, out_gradients);
}
else
{
in_gradients = new Tensor[trace.input_tensor_id.Length];
out_gradients.Clear();
}

bool skip_unneeded_id = trace.input_tensor_id.Length > in_gradients.Length;
for (int i = 0, k = 0; i < in_gradients.Length && k < trace.input_tensor_id.Count(); ++i, ++k)
for(int i = 0, end = in_gradients.Length; i < end; i++)
{
if (skip_unneeded_id && unneeded_gradients.Contains(k)) ++k;
var id = trace.input_tensor_id[k];
if (in_gradients[i] != null)
long id = trace.input_tensor_id[i];
if (in_gradients[i] is not null)
{
var unaggregated_grads = gradients[id];
var unaggregated_grads = gradients.SetDefault(id, new List<Tensor>());
unaggregated_grads.Add(in_gradients[i]);
/*if (unaggregated_grads.Count > kMinAggregateCount)
if(unaggregated_grads.Count > kMinAggregateCount)
{
if (!gradients_size.find(id, out var size))
if(!gradients_size.TryGetValue(id, out var size))
{
size = (long)unaggregated_grads[0].size;
size = NumElements(unaggregated_grads[0]);
gradients_size.emplace(id, size);
}

if (unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
if(unaggregated_grads.Count * size * 4 > kMinAggregateBytes)
{
throw new NotImplementedException("");
Tensor grad = AggregateGradients(unaggregated_grads);
unaggregated_grads.Clear();
unaggregated_grads.Add(grad);
}
}*/
}
}
if (!state.tensor_usage_counts.find(id))
if(!state.tensor_usage_counts.find(id))
{
continue;
}
state.tensor_usage_counts[id]--;
if (state.tensor_usage_counts[id] > 0)
if(state.tensor_usage_counts[id] > 0)
{
continue;
if (!tensor_tape_.find(id, out var tape_it))
}
if (!tensor_tape_.TryGetValue(id, out var tape_it))
{
if (gradients.find(id, out var grad_it))
if (gradients.find(id))
{
// foreach (var g in grad_it)
// DeleteGradient(g);
gradients.erase(id);
}
continue;
}
var op_id = tape_it;
if (op_id == -1)
long op_id = tape_it;
if(op_id == -1)
{
continue;
if (state.op_missing_tensor.find(op_id, out var missing_it))
}
if(state.op_missing_tensor.find(op_id))
{
state.op_missing_tensor[op_id]--;
if (state.op_missing_tensor[op_id] == 0)
if(state.op_missing_tensor[op_id] == 0)
{
op_stack.Enqueue(op_id);
}
}
}
}

if (state.op_tape.Count > 0)
if(state.op_tape.Count > 0)
{
throw new RuntimeError("Invalid tape state.");

var result = new Tensor[source_tensor_ids.Length];
var j = 0;
foreach (var id in source_tensor_ids)
}
Tensor[] result = new Tensor[source_tensor_ids.Length];
for(int i = 0; i < source_tensor_ids.Length; i++)
{
if (gradients.find(id, out var grad_it))
long tensor_id = source_tensor_ids[i];
if(!gradients.TryGetValue(tensor_id, out var grad_it))
{
if (grad_it.Count > 1)
result[j] = gen_math_ops.add_n(grad_it.ToArray());
else
result[j] = grad_it[0];
result[i] = null;
}
else
{
if(grad_it.Count > 1)
{
Tensor grad = AggregateGradients(grad_it);
grad_it.Clear();
grad_it.Add(grad);
}
result[i] = grad_it[0];
}
j++;
}

return result;
}

UnorderedMap<string, UnorderedSet<int>> FunctionsAcceptingNoneForIndicesMap()
{
var m = new UnorderedMap<string, UnorderedSet<int>>();
m.Add("SoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
m.Add("SparseSoftmaxCrossEntropyWithLogits", new UnorderedSet<int>(new[] { 1 }));
m.Add("FusedBatchNorm", new UnorderedSet<int>(new[] { 1, 2, 3, 4 }));
return m;
return _functionsAcceptingNoneForIndicesMap;
}

UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids,
UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
Tensor[] output_gradients,
UnorderedMap<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
UnorderedMap<long, TapeTensor> sources_that_are_targets,
List<Tensor> output_gradients,
TensorTape tensor_tape,
OpTape op_tape)
{
var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>();
for (int i = 0; i < target_tensor_ids.Length; ++i)
var result = new UnorderedMap<long, List<Tensor>>();
for(int i = 0, end = target_tensor_ids.Length; i < end; i++)
{
var id = target_tensor_ids[i];
if (output_gradients.Length == 0 || output_gradients[i] == null)
long id = target_tensor_ids[i];
if (output_gradients is null || output_gradients.Count == 0 || output_gradients[i] is null)
{
if (tensor_tape.find(id, out var tensor_id) && tensor_id != null)
if(tensor_tape.TryGetValue(id, out var tensor_it) && tensor_it != -1)
{
if (!op_tape.find(tensor_tape[id], out var op_it))
if(!op_tape.TryGetValue(tensor_it, out var op_it))
{
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
"failed to find operation producing a tensor");
"failed to find operation producing a tensor.");
}
bool found = false;
for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
for(int j = 0; j < op_it.output_tensor_info.Length; j++)
{
if (op_it.output_tensor_info[j].GetTensor() == id)
if (op_it.output_tensor_info[j].GetID() == id)
{
found = true;
var ones = op_it.output_tensor_info[j].OnesLike();
result[id].Add(ones);
Tensor ones_like = BuildOnesLike(op_it.output_tensor_info[j]);
result.SetDefault(id, new List<Tensor>()).Add(ones_like);
break;
}
}

if (!found)
{
throw new ValueError("Internal state of the gradient tape is invalid: " +
"none of operations outputs match expected tensor");
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
"none of operations outputs match expected tensor.");
}
}
else
{
if (sources_that_are_targets.find(id, out var source_tensor))
result[id].Add(source_tensor.OnesLike());
if(sources_that_are_targets.TryGetValue(id, out var source_tensor))
{
Tensor ones_like = BuildOnesLike(source_tensor);
result.SetDefault(id, new List<Tensor>()).Add(ones_like);
}
}
}
else
{
result[id].Add(output_gradients[i]);
result.SetDefault(id, new List<Tensor>()).Add(output_gradients[i]);
}
}

@@ -248,5 +259,26 @@ namespace Tensorflow.Gradients
}
return result;
}

Tensor BuildOnesLike(TapeTensor t)
{
return t.OnesLike();
}

Tensor AggregateGradients(List<Tensor> gradient_tensors)
{
if(gradient_tensors.Count == 1)
{
return gradient_tensors[0];
}
return tf.add_n(gradient_tensors.ToArray());
}

void DeleteGradient(Tensor gradient)
{
// Do not do anything here. Because GC will collect it when it has no reference.
}

long NumElements(Tensor tensor) => 1;
}
}
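A quick numerical check of the rewritten loop, exercising the initial ones-like gradient, per-op backward functions, and gradient aggregation:

    using static Tensorflow.Binding;

    var x = tf.constant(3.0f);
    using var tape = tf.GradientTape();
    tape.watch(x);
    var y = tf.add(tf.multiply(x, x), x);   // y = x^2 + x
    var g = tape.gradient(y, x);            // expected: 2*3 + 1 = 7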

+31 −32 src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs

@@ -5,63 +5,62 @@ namespace Tensorflow.Gradients
{
public partial class Tape
{
public BackpropInitialState PrepareBackprop(Tensor[] target,
public BackpropInitialState PrepareBackprop(long[] target,
TensorTape tensor_tape,
OpTape op_tape,
UnorderedSet<Tensor> sources_set,
UnorderedSet<long> sources_set,
bool persistent_tape)
{
Stack<long> tensor_stack = new Stack<long>();
foreach(var t in target)
{
tensor_stack.Push(t);
}
BackpropInitialState result = new BackpropInitialState();
var tensor_stack = new Queue<Tensor>(target);
while (tensor_stack.Count > 0)
while(tensor_stack.Count > 0)
{
var tensor_id = tensor_stack.Dequeue();
if (!tensor_tape.find(tensor_id, out var op_id))
long tensor_id = tensor_stack.Pop();
if(!tensor_tape.TryGetValue(tensor_id, out var op_id))
{
continue;
if (op_id == -1 ||
!op_tape.find(op_id, out var op_it) ||
result.op_tape.find(op_id, out var result_op_it))
}
if(op_id == -1 || !op_tape.TryGetValue(op_id, out var op_it)
|| result.op_tape.find(op_id))
{
continue;
}
result.op_tape.emplace(op_id, op_it);

foreach (var it in op_it.input_tensor_id)
foreach(var it in op_it.input_tensor_id)
{
if (result.tensor_usage_counts.find(it))
if(result.tensor_usage_counts.find(it))
{
result.tensor_usage_counts[it]++;
}
else
{
result.tensor_usage_counts[it] = 1;
if (tensor_tape.find(it))
tensor_stack.Enqueue(it);
{
tensor_stack.Push(it);
}
}
}

if (!persistent_tape)
op_tape.Remove(op_id);
{
op_tape.erase(op_id);
}
}

foreach (var pair in result.tensor_usage_counts)
foreach(var pair in result.tensor_usage_counts)
{
if (tensor_tape.find(pair.Key, out var it) && it != -1)
result.op_missing_tensor[it] += 1;
if(tensor_tape.TryGetValue(pair.Key, out var it) && it != -1)
{
result.op_missing_tensor[it]++;
}
}

if (!persistent_tape)
{
// Call destructors for all unneeded gradient functions and
// clear the op_tape. We can clear the tape because ownership of
// backward functions that will be used for gradient computation
// has been transferred to `result`.
/*for (const auto&op_pair : *op_tape) {
op_pair.second.backward_function_deleter(
op_pair.second.backward_function);
}*/
op_tape.Clear();
}

return result;
}
}


+21 −10 src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs

@@ -8,34 +8,45 @@ namespace Tensorflow.Gradients
public partial class Tape
{
long next_op_id_ = 0;
UnorderedMap<Tensor, long> tensor_usage_;
UnorderedMap<long, long> tensor_usage_;

public void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
BackwardFunction backward_function)
{
if (!ShouldRecord(input_tensors))
if (!ShouldRecord(input_tensor_id, input_dtypes))
return;

var op_id = next_op_id_++;
foreach (var i in input_tensors)
foreach (var i in input_tensor_id)
{
tensor_usage_[i]++;

}
long op_id = next_op_id_++;
foreach (var o in output_tensors)
{
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
tensor_tape_[o.GetTensor()] = op_id;
tensor_usage_[o.GetTensor()] = 1;
tensor_tape_[o.GetID()] = op_id;
tensor_usage_[o.GetID()] = 1;
}

op_tape_[op_id] = new OpTapeEntry
{
op_type = op_type,
output_tensor_info = output_tensors,
input_tensor_id = input_tensors,
output_tensor_info = output_tensors.ToArray(),
input_tensor_id = input_tensor_id.ToArray(),
backward_function = backward_function
};
}

public void RecordOperation(string op_type,
Tensor[] outputs,
Tensor[] inputs,
BackwardFunction backward_function)
{
tf.Runner.TFE_TapeSetRecordOperation(op_type, outputs, inputs, backward_function);
}
}
}

+10 −10 src/TensorFlowNET.Core/Gradients/Tape.cs

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;
@@ -29,7 +30,7 @@ namespace Tensorflow.Gradients
_created_eagerly = tf.Context.executing_eagerly();
tensor_tape_ = new TensorTape();
op_tape_ = new OpTape();
tensor_usage_ = new UnorderedMap<Tensor, long>();
tensor_usage_ = new UnorderedMap<long, long>();
if(_created_eagerly)
tf.Context.start_step();
// nesting_id = ++tape_nesting_id_counter;
@@ -42,29 +43,28 @@ namespace Tensorflow.Gradients
public void Watch(Tensor x)
{
tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
tensor_tape_.emplace(x, -1);
tensor_tape_.emplace(x.Id, -1);
}

public bool ShouldRecord(Tensor[] tensors)
public bool ShouldRecord(long[] tensor_ids, TF_DataType[] tensor_dtypes)
{
var dtypes = tensors.Select(x => x.dtype).ToArray();
for (int i = 0; i < tensors.Length; ++i)
Debug.Assert(tensor_ids.Length == tensor_dtypes.Length);
for (int i = 0; i < tensor_ids.Length; ++i)
{
if (tensor_tape_.find(tensors[i]))
if (tensor_tape_.find(tensor_ids[i]) && IsDtypeTrainable(tensor_dtypes[i]))
{
if (IsDtypeTrainable(dtypes[i]))
return true;
return true;
}
}
return false;
}

public void VariableAccessed(ResourceVariable variable)
public void VariableAccessed(IVariableV1 variable)
{
Watch(variable.Handle);
}

public ResourceVariable[] WatchedVariables()
public IVariableV1[] WatchedVariables()
{
return null;
}


+45 −9 src/TensorFlowNET.Core/Gradients/TapeTensor.cs

@@ -1,27 +1,63 @@
using static Tensorflow.Binding;
using OneOf;
using static Tensorflow.Binding;

namespace Tensorflow.Gradients
{
public class TapeTensor
{
Tensor tensor;
long id => tensor.Id;
TF_DataType dtype => tensor.dtype;
Shape shape => tensor.shape;
internal Tensor tensor;
internal long id;
internal TF_DataType dtype;
internal OneOf<Shape, Tensor> shape;

public TapeTensor(long id, TF_DataType dtype, Shape shape)
{
this.id = id;
this.dtype = dtype;
this.shape = shape;
}

public TapeTensor(long id, TF_DataType dtype, Tensor shape)
{
this.id = id;
this.dtype = dtype;
this.shape = shape;
}

public TapeTensor(Tensor tensor)
{
this.id = tensor.Id;
this.dtype = tensor.dtype;
this.shape = tensor.shape;
this.tensor = tensor;
}

public long GetID() => tensor.Id;
public Tensor GetTensor() => tensor;
public long GetID() => id;

public Tensor ZerosLike()
=> tf.zeros(shape: shape, dtype: dtype);
{
if(dtype == dtypes.resource)
{
return null;
}
if(shape.Index == 1)
{
return tf.zeros_like(shape.AsT1);
}
return tf.zeros(shape.AsT0, dtype);
}

public Tensor OnesLike()
=> tf.ones(shape: shape, dtype: dtype);
{
if (shape.Index == 1)
{
return tf.ones_like(shape.AsT1);
}
return tf.ones(shape.AsT0, dtype);
}

//public Tensor OnesLike()
// => tf.ones(shape: shape, dtype: dtype);

public override string ToString()
=> $"{id}, {shape}, {dtype.as_numpy_name()}";
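The shape field is now a OneOf<Shape, Tensor> union (OneOf is already referenced in Tensorflow.Binding.csproj): Index 0 carries a statically known Shape, Index 1 a Tensor standing in for a dynamic shape. A small sketch of the same dispatch the methods above perform:

    using OneOf;
    using Tensorflow;
    using static Tensorflow.Binding;

    OneOf<Shape, Tensor> shape = new Shape(2, 3);      // static case (Index == 0)
    var zeros = shape.Match(
        s => tf.zeros(s, TF_DataType.TF_FLOAT),        // known shape
        t => tf.zeros_like(t));                        // dynamic: mimic the tensor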


+1 −1 src/TensorFlowNET.Core/Gradients/TensorTape.cs

@@ -7,7 +7,7 @@ namespace Tensorflow.Gradients
/// produced this tensor. A value of -1 means that the tensor was directly
/// watched and not the result of any operation in the tape.
/// </summary>
public class TensorTape : UnorderedMap<Tensor, long>
public class TensorTape : UnorderedMap<long, long>
{

}


+1 −26 src/TensorFlowNET.Core/Gradients/gradients_util.cs

@@ -704,32 +704,7 @@ namespace Tensorflow

public static int PossibleTapeGradientTypes(Tensor[] tensors)
{
var tape_set = tf.GetTapeSet();
bool some_tape_watching = false;
if(tape_set is not null && tape_set.Count > 0)
{
foreach(var tape in tape_set)
{
if (tape.ShouldRecord(tensors))
{
if(tape.Persistent || some_tape_watching)
{
return POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
}
some_tape_watching = true;
}
}
}
// skip the forward_accumulators.

if (some_tape_watching)
{
return POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
}
else
{
return POSSIBLE_GRADIENT_TYPES_NONE;
}
return tf.Runner.TFE_TapeSetPossibleGradientTypes(tensors);
}

/// <summary>


+10 −0 src/TensorFlowNET.Core/Graphs/FuncGraph.cs

@@ -215,6 +215,16 @@ public class FuncGraph : Graph, IDisposable
return tensor;
}

public void watch_variable(IVariableV1 v)
{
if (_resource_tensor_inputs.Contains(v.Handle))
{
return;
}
_watched_variables.Add(new WeakReference<IVariableV1>(v));
//this = this.outer_graph;
}

Tensor capture_eager_tensor(Tensor tensor, string name)
{
Tensor graph_const = null;


+2 −2 src/TensorFlowNET.Core/Keras/Engine/IOptimizer.cs

@@ -4,10 +4,10 @@ public interface IOptimizer
{
Tensor[] aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars);
Tensor[] clip_gradients(Tensor[] grads);
void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true);
}

+4 −4 src/TensorFlowNET.Core/Operations/c_api.ops.cs

@@ -208,9 +208,9 @@ namespace Tensorflow

[DllImport(TensorFlowLibName)]
public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status);
//[DllImport(TensorFlowLibName)]
//public static extern IntPtr GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
//[DllImport(TensorFlowLibName)]
//public static extern void SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}

+3 −1 src/TensorFlowNET.Core/Operations/functional_ops.cs

@@ -39,7 +39,7 @@ namespace Tensorflow

if (config is null)
{
config = function_utils.get_disabled_rewriter_config().ToString();
config = function_utils.get_disabled_rewriter_config().ToStringUtf8();
}

if (executor_type is null)
@@ -49,6 +49,8 @@ namespace Tensorflow

if (executing_eagerly)
{
// TODO(Rinne): implement it.
throw new NotImplementedException();
}
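The ToString to ToStringUtf8 change matters because the runtime expects the raw serialized RewriterConfig bytes: ByteString.ToStringUtf8() decodes the payload, while ToString() does not. Sketch (assuming, as this call site implies, that get_disabled_rewriter_config returns a Protobuf ByteString; namespace per functional_ops.cs above):

    using Google.Protobuf;

    ByteString disabled = function_utils.get_disabled_rewriter_config();
    string config = disabled.ToStringUtf8();   // raw proto bytes as a string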



+46 −1 src/TensorFlowNET.Core/Operations/gen_array_ops.cs

@@ -17,6 +17,7 @@
using System;
using System.Linq;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -210,7 +211,51 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `value`.</returns>
public static Tensor fill<T>(Tensor dims, T value, string name = null)
=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));
{
var ctx = tf.Context;
if (ctx.executing_eagerly())
{
try
{
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Fill", name, dims, value));
return _result[0];
}
catch (Exception)
{
}
try
{
return fill_eager_fallback(dims, value as Tensor, name, ctx);
}
catch (Exception)
{

}
}
Dictionary<string, object> attrs = new Dictionary<string, object>();
attrs["dims"] = dims;
attrs["value"] = value;
var result = tf.OpDefLib._apply_op_helper("Fill", name, attrs);
if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return result.output;
}

public static Tensor fill_eager_fallback(Tensor dims, Tensor value, string name, Context ctx)
{
object[] attrs = new object[] { "T", dims.dtype.as_datatype_enum(), "index_type", dims.dtype.as_datatype_enum() };
var _result = execute.executes("Fill", 1, new Tensor[] { dims, value }, attrs, ctx, name);

if (execute.must_record_gradient())
{
throw new NotImplementedException();
}
return _result[0];
}
//=> tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value));

/// <summary>
/// Return the reduction indices for computing gradients of s0 op s1 with broadcast.
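A usage check for the regenerated Fill wrapper above (eager fast path first, then the eager fallback, then graph mode):

    using Tensorflow;
    using static Tensorflow.Binding;

    var dims = tf.constant(new[] { 2, 3 });
    var filled = gen_array_ops.fill(dims, 7.0f);   // 2x3 tensor, every element 7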


+4 −2 src/TensorFlowNET.Core/Operations/handle_data_util.cs

@@ -49,8 +49,10 @@ namespace Tensorflow.Operations
target_t.HandleData = handle_data;
return;
}
// TODO(Rinne): enable it. (currently the internal c api cannot be invoked.)
//c_api.SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), handle_data.ToByteArray());
Status status = new();
var proto = handle_data.ToByteArray();
c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
status.Check(true);
}

public static HandleData get_resource_handle_data(Tensor graph_op) => ops.get_resource_handle_data(graph_op);


+14 −0 src/TensorFlowNET.Core/Operations/resource_variable_ops.cs

@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
using Tensorflow.Operations;
using System.Buffers;
using Tensorflow.Eager;
using Tensorflow.Graphs;

namespace Tensorflow
{
@@ -302,5 +303,18 @@ namespace Tensorflow
// return handle_data_util.get_resource_handle_data(handle);
//}
}

public static void variable_accessed(IVariableV1 variable)
{
if (ops.get_default_graph() is FuncGraph func_graph)
{
func_graph.watch_variable(variable);
}
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())
tape.VariableAccessed(variable);
}
}
}
}
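A sketch of what this hook enables, assuming the usual TF.NET eager surface (tf.Variable, tf.GradientTape, tape.gradient): reading a trainable variable records it on every active tape, so no explicit watch call is needed:

var v = tf.Variable(2.0f);
using var tape = tf.GradientTape();
var y = v.AsTensor() * v.AsTensor(); // the read triggers variable_accessed
var grad = tape.gradient(y, v);      // dy/dv = 2 * v = 4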

+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -110,7 +110,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.6.2" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>



+ 5
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
{
public virtual IntPtr TensorDataPointer => _handle == null ? IntPtr.Zero : TF_TensorData(_handle);

public Tensor()
protected Tensor()
{
}

@@ -108,6 +108,7 @@ namespace Tensorflow
protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);
_id = ops.uid();
}

protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -116,6 +117,7 @@ namespace Tensorflow
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
else
_handle = TF_NewTensor(bytes, shape, dtype);
_id = ops.uid();
}

protected unsafe void InitTensor(Array array, Shape? shape = null)
@@ -166,6 +168,8 @@ namespace Tensorflow
string[] val => StringTensor(val, shape),
_ => throw new NotImplementedException("")
};

_id = ops.uid();
}

unsafe SafeTensorHandle InitTensor<T>(T[] array, Shape shape, TF_DataType dtype) where T : unmanaged
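The new _id assignments matter because the gradient tape keys recorded tensors by their unique id, so every freshly created tensor needs one. A sketch of the invariant, assuming Tensor exposes the id as Id:

var a = tf.constant(1.0f);
var b = tf.constant(1.0f);
System.Diagnostics.Debug.Assert(a.Id != b.Id); // ops.uid() hands out distinct ids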


+ 1
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -462,6 +462,7 @@ namespace Tensorflow.Training.Saving.SavedModel
{
IEnumerable<ConcreteFunction> _concrete_functions;
FunctionSpec _function_spec;
public IEnumerable<ConcreteFunction> ConcreteFunctions => _concrete_functions;
public RestoredFunction(Func<Tensor[], Tensor[]> function, string name, FunctionSpec function_spec,
IEnumerable<ConcreteFunction> concrete_functions): base(function, name, auto_graph: false)
{


+ 13
- 0
src/TensorFlowNET.Core/Util/UnorderedMap.cs View File

@@ -25,6 +25,19 @@ namespace Tensorflow.Util
}
}

public Tv SetDefault(Tk key, Tv default_value)
{
    if (TryGetValue(key, out var res))
    {
        return res;
    }
    base[key] = default_value;
    return default_value;
}

public void push_back(Tk key, Tv value)
=> this[key] = value;
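SetDefault mirrors Python's dict.setdefault: it returns the stored value when the key exists and otherwise inserts and returns the supplied default. A small sketch:

var map = new UnorderedMap<string, List<int>>();
map.SetDefault("grads", new List<int>()).Add(1); // key absent: default inserted, then mutated
map.SetDefault("grads", new List<int>()).Add(2); // key present: the stored list is returned
// map["grads"] now holds [1, 2]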



+ 5
- 0
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -9,6 +9,7 @@ using System.Diagnostics;
using Tensorflow.Checkpoint;
using Tensorflow.Training.Saving.SavedModel;
using OneOf;
using Tensorflow.Graphs;

namespace Tensorflow
{
@@ -193,6 +194,10 @@ namespace Tensorflow
/// </summary>
void variable_accessed(BaseResourceVariable variable)
{
if(ops.get_default_graph() is FuncGraph func_graph)
{
func_graph.watch_variable(variable as IVariableV1);
}
if (variable.Trainable)
{
foreach (var tape in tf.GetTapeSet())


+ 2
- 6
src/TensorFlowNET.Core/ops.cs View File

@@ -575,12 +575,8 @@ namespace Tensorflow

public static HandleData get_resource_handle_data(Tensor graph_op)
{
throw new NotImplementedException();
// This implementation hasn't been checked for some reasons.
// If it throws an exception in the future, please check it.

//var handle_data = c_api.GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
//return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
}

public static void dismantle_graph(Graph graph)


+ 21
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -27,6 +27,7 @@ using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using Tensorflow.Train;
using Tensorflow.Training;
using Tensorflow.Training.Saving.SavedModel;
using Tensorflow.Util;
using static Tensorflow.Binding;

@@ -50,7 +51,17 @@ namespace Tensorflow.Keras.Engine
/// the layer's weights.
/// </summary>
protected bool built;
public bool Built => built;
public bool Built
{
get
{
return built;
}
internal set
{
built = value;
}
}
public bool Trainable => args.Trainable;
public TF_DataType DType => args.DType;
public bool AutoCast => args.Autocast;
@@ -179,6 +190,11 @@ namespace Tensorflow.Keras.Engine
}
protected List<ILayer> _self_tracked_trackables;

/// <summary>
/// If this value is set, calling the layer will invoke this function directly instead of the layer's own <c>Call</c> logic.
/// </summary>
public Func<Tensors, Tensors>? ReplacedCall { get; set; } = null;

public Layer(LayerArgs args)
{
Initialize(args);
@@ -259,6 +275,10 @@ namespace Tensorflow.Keras.Engine
/// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if(ReplacedCall is not null)
{
return ReplacedCall(inputs);
}
return inputs;
}



+ 1
- 5
src/TensorFlowNET.Keras/Engine/Model.Train.cs View File

@@ -35,10 +35,6 @@ namespace Tensorflow.Keras.Engine
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
//foreach (var variable in TrainableVariables)
//{
// tape.watch(variable.Handle);
//}
var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred);

@@ -70,7 +66,7 @@ namespace Tensorflow.Keras.Engine
gradients = optimizer.aggregate_gradients(zip(gradients, trainable_variables));
gradients = optimizer.clip_gradients(gradients);

optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)),
optimizer.apply_gradients(zip(gradients, trainable_variables),
experimental_aggregate_gradients: false);
}
}


+ 4
- 4
src/TensorFlowNET.Keras/Optimizers/OptimizerV2.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Optimizers
_set_hyper("decay", args.InitialDecay);
}

public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
public void apply_gradients((Tensor, IVariableV1) grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
=> apply_gradients(new[] { grads_and_vars },
@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Optimizers
/// <param name="grads_and_vars"></param>
/// <param name="name"></param>
/// <param name="experimental_aggregate_gradients"></param>
public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
public void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name = null,
bool experimental_aggregate_gradients = true)
{
@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Optimizers
});
}

void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
void apply_grad_to_update_var(IVariableV1 var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
{
_resource_apply_dense(var, grad, apply_state);
// if var.constraint is not None:
@@ -93,7 +93,7 @@ namespace Tensorflow.Keras.Optimizers
throw new NotImplementedException("_resource_apply_dense");
}

void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
void _distributed_apply(IEnumerable<(Tensor, IVariableV1)> grads_and_vars,
string name,
Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
{
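With the parameters widened from ResourceVariable to IVariableV1, call sites such as Model.Train above no longer need a downcast. A sketch of the resulting call shape, assuming the tape/zip helpers from Tensorflow.Binding and that gradients align with the trainable variables:

var gradients = tape.gradient(loss, model.TrainableVariables);
optimizer.apply_gradients(zip(gradients, model.TrainableVariables),
    experimental_aggregate_gradients: false);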


+ 25
- 0
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -255,6 +255,25 @@ namespace Tensorflow.Keras.Saving
/// <param name="layers"></param>
private void _finalize_saved_model_layers(List<Layer> layers)
{
foreach(var layer in layers)
{
layer.Built = true;
var keras_attr = _get_keras_attr(layer);
if(keras_attr is not Trackable trackable)
{
continue;
}
if (trackable.CustomizedFields.TryGetValue("call_and_return_conditional_losses", out var layer_call))
{
Debug.Assert(layer_call is RestoredFunction);
var restored_function = (RestoredFunction)layer_call;
var concrete_functions = restored_function.ConcreteFunctions;
if (concrete_functions is not null && concrete_functions.Any())
{
    layer.ReplacedCall = use_wrapped_call(layer, restored_function.Apply);
}
}
}

foreach(var layer in layers)
{
// TODO(Rinne): deal with `RevivedNetwork`.
@@ -265,6 +284,12 @@ namespace Tensorflow.Keras.Saving
}
}

private Func<Tensors, Tensors> use_wrapped_call(Layer layer, Func<Tensors, Tensors> call)
{
// TODO(Rinne): revise it.
return call;
}

private void _restore_layer_unconditional_losses(Layer layer)
{
// TODO(Rinne): implement it.


+ 11
- 11
src/TensorFlowNET.Keras/Saving/SavedModel/RevivedLayer.cs View File

@@ -85,16 +85,16 @@ namespace Tensorflow.Keras.Saving.SavedModel
return _config;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
{
return base.Call(inputs, state, training);
}
else
{
return (func as Function).Apply(inputs);
}
}
//protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
//{
// if(SerializedAttributes is null || !SerializedAttributes.TryGetValue("__call__", out var func) || func is not Function)
// {
// return base.Call(inputs, state, training);
// }
// else
// {
// return (func as Function).Apply(inputs);
// }
//}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -223,7 +223,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
//base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers", "metrics", "layer_regularization_losses", "layer_metrics" }),
// functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" })
base(checkpointable_objects.Concat(new string[] { "non_trainable_variables", "layers"}),
functions.Concat(new string[] { }))
functions.Concat(new string[] { "call_and_return_conditional_losses", "activity_regularizer_fn" }))
{

}


+ 11
- 15
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -64,23 +64,19 @@ public class SequentialModelLoad
var model = tf.keras.models.load_model(@"Assets/python_func_model");
model.summary();

var x = tf.random.uniform((8, 784), -1, 1);
var y = model.Apply(x);
Console.WriteLine(y);
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

//model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

//var data_loader = new MnistModelLoader();
//var num_epochs = 1;
//var batch_size = 8;
var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 8;

//var dataset = data_loader.LoadAsync(new ModelLoadSetting
//{
// TrainDir = "mnist",
// OneHot = false,
// ValidationSize = 58000,
//}).Result;
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 55000,
}).Result;

//model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}
}
