
refactor gradient tape.

Tag: tags/TensorFlowOpLayer
Oceania2018 committed 4 years ago · parent commit 05443ead2b
27 changed files with 203 additions and 306 deletions
 1. src/TensorFlowNET.Core/APIs/tf.gradients.cs (+12 -2)
 2. src/TensorFlowNET.Core/Binding.cs (+1 -1)
 3. src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs (+14 -0)
 4. src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs (+1 -26)
 5. src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs (+1 -1)
 6. src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs (+4 -6)
 7. src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs (+4 -7)
 8. src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorFromTensor.cs (+0 -12)
 9. src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorsFromTensorSequence.cs (+0 -18)
10. src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs (+5 -5)
11. src/TensorFlowNET.Core/Gradients/GradientTape.cs (+38 -61)
12. src/TensorFlowNET.Core/Gradients/ITape.cs (+9 -9)
13. src/TensorFlowNET.Core/Gradients/OpTape.cs (+1 -1)
14. src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs (+5 -3)
15. src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs (+18 -18)
16. src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs (+5 -5)
17. src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs (+9 -19)
18. src/TensorFlowNET.Core/Gradients/Tape.cs (+44 -42)
19. src/TensorFlowNET.Core/Gradients/TapeTensor.cs (+8 -8)
20. src/TensorFlowNET.Core/Gradients/TensorTape.cs (+2 -2)
21. src/TensorFlowNET.Core/Gradients/gradients_util.cs (+1 -11)
22. src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs (+10 -0)
23. src/TensorFlowNET.Core/Operations/Operation.Input.cs (+1 -0)
24. src/TensorFlowNET.Core/Operations/Operation.Output.cs (+2 -1)
25. src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs (+1 -1)
26. src/TensorFlowNET.Core/tensorflow.cs (+2 -37)
27. test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs (+5 -10)

src/TensorFlowNET.Core/APIs/tf.gradients.cs (+12 -2)

@@ -14,22 +14,32 @@
 limitations under the License.
 ******************************************************************************/

+using System.Collections.Generic;
 using Tensorflow.Gradients;

 namespace Tensorflow
 {
     public partial class tensorflow
     {
+        GradientTape _tapeSet;
+
         /// <summary>
         /// Record operations for automatic differentiation.
         /// </summary>
         /// <param name="persistent"></param>
         /// <param name="watch_accessed_variables"></param>
-        /// <returns></returns>
+        /// <returns>Tape set</returns>
         public GradientTape GradientTape(bool persistent = false,
             bool watch_accessed_variables = true)
-            => new GradientTape(persistent: persistent,
+        {
+            var tape = _tapeSet.PushTape(persistent: persistent,
                 watch_accessed_variables: watch_accessed_variables);
+            tape.StartRecord();
+            return _tapeSet;
+        }
+
+        public Stack<ITape> GetTapeSet()
+            => _tapeSet.GetTapeSet();

         public Tensor[] gradients(Tensor[] ys,
             Tensor[] xs,
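With this change, tf.GradientTape() no longer constructs an independent GradientTape per call; it pushes a fresh Tape onto a shared tape set, starts recording, and returns the set. Typical call sites are unaffected. A minimal sketch of the resulting usage, following the pattern in the unit tests below (values illustrative):

    using static Tensorflow.Binding;

    var w = tf.constant(1.5f);
    using var tape = tf.GradientTape();  // pushes a new tape and starts recording
    tape.watch(w);                       // w is a constant, so it must be watched explicitly
    var loss = w * w;
    var grad = tape.gradient(loss, w);   // pops the non-persistent tape and differentiates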


src/TensorFlowNET.Core/Binding.cs (+1 -1)

@@ -4,7 +4,7 @@ namespace Tensorflow
 {
     public static partial class Binding
     {
-        public static tensorflow tf { get; } = New<tensorflow>();
+        public static tensorflow tf { get; } = new tensorflow();

         /// <summary>
         /// Alias to null, similar to python's None.


src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs (+14 -0)

@@ -11,5 +11,19 @@ namespace Tensorflow.Eager
         {
             return HasGradientTape();
         }
+
+        private bool ShouldRecord(Tensor[] inputs)
+        {
+            bool should_record = false;
+            foreach (var tape in tf.GetTapeSet())
+            {
+                if (tape.ShouldRecord(inputs))
+                {
+                    should_record = true;
+                    break;
+                }
+            }
+            return should_record;
+        }
     }
 }
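The new ShouldRecord helper centralizes the tape-set query that RecordGradient previously inlined. Functionally it is equivalent to a LINQ Any over the stack; a one-line sketch, assuming using System.Linq; and the same tf.GetTapeSet() accessor:

    private bool ShouldRecord(Tensor[] inputs)
        => tf.GetTapeSet().Any(tape => tape.ShouldRecord(inputs));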

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs (+1 -26)

@@ -2,7 +2,6 @@
 using System.Linq;
 using Tensorflow.Gradients;
 using static Tensorflow.Binding;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Eager
 {
@@ -14,18 +13,7 @@ namespace Tensorflow.Eager
             Tensor[] results,
             Func<BackwardFunction> getBackwardFunction = null)
         {
-            var input_ids = MakeTensorIDList(inputs);
-            var input_dtypes = MakeTensorDtypeList(inputs);
-
-            bool should_record = false;
-            foreach (var tape in tf.GetTapeSet())
-            {
-                if (tape.ShouldRecord(input_ids, input_dtypes))
-                {
-                    should_record = true;
-                    break;
-                }
-            }
+            bool should_record = ShouldRecord(inputs);

             if (!should_record)
             {
@@ -43,9 +31,6 @@ namespace Tensorflow.Eager
             tf.Logger.Debug($"RecordGradient: op_name={op_name}");

             Tensor[] op_outputs;
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
-            bool op_outputs_tuple_created = false;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
             var unused_output_indices = gradient_exclustions.OpGradientUnusedOutputIndices(op_name);
             if (unused_output_indices != null)
             {
@@ -53,7 +38,6 @@ namespace Tensorflow.Eager
                     op_outputs = new Tensor[0];
                 else
                 {
-                    op_outputs_tuple_created = true;
                     // op_outputs = CopySequenceSettingIndicesToNull(results, *unused_output_indices);
                 }
             }
@@ -61,9 +45,6 @@ namespace Tensorflow.Eager
                 op_outputs = results;

             Tensor[] op_inputs;
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
-            bool op_inputs_tuple_created = false;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
             var unused_input_indices = gradient_exclustions.OpGradientUnusedInputIndices(op_name);
             if (unused_input_indices != null)
             {
@@ -71,7 +52,6 @@ namespace Tensorflow.Eager
                     op_inputs = new Tensor[0];
                 else
                 {
-                    op_inputs_tuple_created = true;
                     // op_inputs = CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
                 }
             }
@@ -125,11 +105,6 @@ namespace Tensorflow.Eager
             return HasGradientTape();
         }

-        long[] MakeTensorIDList(Tensor[] tensors)
-        {
-            return tensors.Select(x => x.Id).ToArray();
-        }
-
         TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
         {
             return tensors.Select(x => x.dtype).ToArray();


src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs (+1 -1)

@@ -310,7 +310,7 @@ namespace Tensorflow.Eager
                 for (int i = 0; i < num_values; ++i)
                 {
                     dims[i] = Marshal.AllocHGlobal(sizeof(long) * values1[i].ndim);
-                    tf.memcpy(dims[i], values1[i].dims.Select(x => (long)x).ToArray(), values1[i].ndim * sizeof(long));
+                    tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long));
                 }

                 c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);


src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs (+4 -6)

@@ -14,18 +14,16 @@ namespace Tensorflow.Eager
             Tensor[] sources,
             Tensor[] output_gradients)
         {
-            var target_vec = MakeTensorIDList(target);
-            var sources_vec = MakeTensorIDList(sources);
+            var target_vec = target;
+            var sources_vec = sources;
             var sources_set = sources_vec;

             var seq_array = target;
-            var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();
+            var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();

             for (int i = 0; i < target.Length; ++i)
             {
-                var target_id = target_vec[i];
-                var tensor = seq_array[i];
-                source_tensors_that_are_targets.Add(target_id, TapeTensorFromTensor(tensor));
+                source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
             }

             if (output_gradients != null)


src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs (+4 -7)

@@ -1,7 +1,7 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using Tensorflow.Gradients;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Eager
 {
@@ -12,16 +12,13 @@ namespace Tensorflow.Eager
             Tensor[] output_tensors,
             Func<BackwardFunction> backward_function_getter)
         {
-            var output_info = new List<TapeTensor>();
+            var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();

-            if (!TapeTensorsFromTensorSequence(output_tensors, output_info))
-                return false;
-
-            if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(),
+            if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
                 backward_function_getter))
                 return false;

-            TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(),
+            TapeSetRecordBackprop(op_type, input_tensors, output_info,
                 backward_function_getter);

             return true;


src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorFromTensor.cs (+0 -12, file deleted)

@@ -1,12 +0,0 @@
-using Tensorflow.Gradients;
-
-namespace Tensorflow.Eager
-{
-    public partial class EagerRunner
-    {
-        TapeTensor TapeTensorFromTensor(Tensor tensor)
-        {
-            return new TapeTensor(tensor.Id, tensor.dtype, tensor.shape);
-        }
-    }
-}

src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorsFromTensorSequence.cs (+0 -18, file deleted)

@@ -1,18 +0,0 @@
-using System.Collections.Generic;
-using Tensorflow.Gradients;
-
-namespace Tensorflow.Eager
-{
-    public partial class EagerRunner
-    {
-        bool TapeTensorsFromTensorSequence(Tensor[] output_seq,
-            List<TapeTensor> output_info)
-        {
-            for (var i = 0; i < output_seq.Length; ++i)
-            {
-                output_info.Add(TapeTensorFromTensor(output_seq[i]));
-            }
-            return true;
-        }
-    }
-}

src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs (+5 -5)

@@ -7,21 +7,21 @@ namespace Tensorflow.Gradients
     {
         public OpTape<BackwardFunction, TapeTensor> op_tape { get; set; }
         /// <summary>
-        /// Map from tensor ID to how many references still exist for this tensor in
+        /// Map from tensor to how many references still exist for this tensor in
         /// the tape.
         /// </summary>
-        public UnorderedMap<long, long> tensor_usage_counts { get; set; }
+        public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; }
         /// <summary>
         /// Maps from op ID to how many output tensors of this op still need to have
         /// their gradients computed.
        /// </summary>
-        public UnorderedMap<long, long> op_missing_tensor { get; set; }
+        public UnorderedMap<Tensor, long> op_missing_tensor { get; set; }

        public BackpropInitialState()
        {
            op_tape = new OpTape<BackwardFunction, TapeTensor>();
-            tensor_usage_counts = new UnorderedMap<long, long>();
-            op_missing_tensor = new UnorderedMap<long, long>();
+            tensor_usage_counts = new UnorderedMap<Tensor, long>();
+            op_missing_tensor = new UnorderedMap<Tensor, long>();
        }
    }
}
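Keying these maps by Tensor instead of long means lookups now go through the Tensor object's hash code and equality rather than a numeric ID. The UnorderedMap in Tensorflow.Util behaves like a Dictionary whose indexer default-constructs missing values, which is what keeps patterns like tensor_usage_counts[t]++ safe on first touch. A simplified sketch of that behavior (hypothetical type for illustration, not the library's actual implementation):

    using System.Collections.Generic;

    // Hypothetical: a dictionary whose getter materializes a default value
    // for absent keys, mimicking C++ std::unordered_map::operator[].
    public class DefaultingMap<TKey, TValue> : Dictionary<TKey, TValue>
        where TValue : new()
    {
        public new TValue this[TKey key]
        {
            get => TryGetValue(key, out var v) ? v : base[key] = new TValue();
            set => base[key] = value;
        }
    }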

src/TensorFlowNET.Core/Gradients/GradientTape.cs (+38 -61)

@@ -6,6 +6,7 @@ using static Tensorflow.Binding;
 namespace Tensorflow.Gradients
 {
     /// <summary>
+    /// Gradient Tape Set
     /// Record operations for automatic differentiation.
     ///
     /// Operations are recorded if they are executed within this context manager and
@@ -18,54 +19,35 @@ namespace Tensorflow.Gradients
     /// </summary>
     public class GradientTape : IDisposable
     {
-        bool _recording;
-        public bool Recording => _recording;
-        bool _persistent;
-        bool _watch_accessed_variables;
-        ResourceVariable[] _watched_variables;
-        bool _created_eagerly;
-        ITape _tape;
-
-        public GradientTape(bool persistent = false,
-            bool watch_accessed_variables = true)
+        int _nextTapeId;
+        ITape _tape => _tapeSet.Peek();
+        Stack<ITape> _tapeSet;
+
+        public GradientTape()
         {
-            _persistent = persistent;
-            _watch_accessed_variables = watch_accessed_variables;
-            _created_eagerly = tf.Context.executing_eagerly();
-            _recording = false;
-            _created_eagerly = tf.Context.executing_eagerly();
-            // Enters a context inside which operations are recorded on this tape.
-            if (_created_eagerly)
-            {
-                tf.Context.ensure_initialized();
-                tf.Context.start_step();
-            }
-            _push_tape();
+            _tapeSet = new Stack<ITape>();
        }

        /// <summary>
-        /// Pushes a new tape onto the tape stack.
+        /// New tape onto the tape stack.
        /// </summary>
-        private void _push_tape()
+        public ITape PushTape(bool persistent = false,
+            bool watch_accessed_variables = true)
        {
-            if (_recording)
-                throw new ValueError("Tape is still recording, This can happen if you try to " +
-                    "re-enter an already-active tape.");
-
-            if (_tape == null)
-                _tape = new Tape(_persistent, _watch_accessed_variables);
-            else
-                tf.GetTapeSet().Add(_tape);
+            // Enters a context inside which operations are recorded on this tape.
+            if (tf.Context.executing_eagerly())
+                tf.Context.ensure_initialized();

-            _recording = true;
+            var tape = new Tape(persistent, watch_accessed_variables);
+            tape.SetTapeId(_nextTapeId++);
+            _tapeSet.Push(tape);
+            return tape;
        }

-        private void _pop_tape()
+        ITape PopTape()
        {
-            if (!_recording)
-                throw new ValueError("Tape is not recording.");
-            _tape.PopTape(_tape);
-            _recording = false;
+            _tape.StopRecord();
+            return _tapeSet.Pop();
        }

        /// <summary>
@@ -74,7 +56,9 @@
        /// <param name="x"></param>
        public void watch(Tensor x)
        {
-            _tape.Watch(x.Id);
+            if (!_tapeSet.Any())
+                return;
+            _tape.Watch(x);
        }

        /// <summary>
@@ -85,13 +69,9 @@
        /// <returns></returns>
        public Tensor gradient(Tensor target, Tensor source)
        {
-            if (_recording)
-            {
-                if (!_persistent)
-                    _pop_tape();
-            }
+            ITape tape = stop_recording();

-            var results = tf.Runner.TFE_TapeGradient(_tape,
+            var results = tf.Runner.TFE_TapeGradient(tape,
                new[] { target },
                new[] { source },
                null);
@@ -115,22 +95,17 @@
        public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
        {
-            if (_recording)
-            {
-                if (!_persistent)
-                    _pop_tape();
-            }
+            var tape = stop_recording();

-            var results = tf.Runner.TFE_TapeGradient(_tape,
+            var results = tf.Runner.TFE_TapeGradient(tape,
                new[] { target },
                sources.Select(x => x.Handle).ToArray(),
                null);

-            if (!_persistent)
+            if (!tape.Persistent)
            {
                // Keep track of watched variables before setting tape to None
-                _watched_variables = _tape.WatchedVariables();
-                _tape = null;
+                // _watched_variables = _tape.WatchedVariables();
            }

            return results;
@@ -139,18 +114,20 @@
        /// <summary>
        /// Temporarily stops recording operations on this tape.
        /// </summary>
-        public void stop_recording()
+        public ITape stop_recording()
        {
-            _pop_tape();
+            var tape = _tape;
+            if (!tape.Persistent)
+                tape = PopTape();
+            return tape;
        }

+        public Stack<ITape> GetTapeSet()
+            => _tapeSet;
+
        public void Dispose()
        {
-            if (_recording)
-                _pop_tape();
-
-            if (_created_eagerly)
-                tf.Context.end_step();
+            _tapeSet.Clear();
        }
    }
}
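Because stop_recording now pops only non-persistent tapes, a persistent tape survives its first gradient call and can be queried repeatedly; Dispose simply clears the whole tape set instead of ending a per-instance recording. A sketch under the new API (mirrors the diff's Persistent flag; tensor values are illustrative):

    var x = tf.constant(new float[] { 1, 2, 3, 4 });
    using var tape = tf.GradientTape(persistent: true);
    tape.watch(x);
    var y = tf.reduce_sum(x);
    var z = tf.multiply(y, y);
    var dz_dx = tape.gradient(z, x);  // persistent tape stays on the stack
    var dy_dx = tape.gradient(y, x);  // so a second gradient call is valid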

src/TensorFlowNET.Core/Gradients/ITape.cs (+9 -9)

@@ -1,15 +1,15 @@
 using System;
 using Tensorflow.Util;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Gradients
 {
     public interface ITape
     {
-        void PopTape(ITape tape);
-
-        bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes);
-
+        void SetTapeId(int id);
+        bool ShouldRecord(Tensor[] tensors);
+        void StartRecord();
+        void StopRecord();
+        bool Persistent { get; }
         void RecordOperation(string op_type,
             Tensor[] input_tensors,
             TapeTensor[] output_tensors,
@@ -17,13 +17,13 @@ namespace Tensorflow.Gradients

         void VariableAccessed(ResourceVariable variable);

-        void Watch(long tensor_id);
+        void Watch(Tensor x);

         ResourceVariable[] WatchedVariables();

-        Tensor[] ComputeGradient(long[] target_tensor_ids,
-            long[] source_tensor_ids,
-            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+        Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
+            Tensor[] source_tensor_ids,
+            UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
             Tensor[] output_gradients);
     }
 }

src/TensorFlowNET.Core/Gradients/OpTape.cs (+1 -1)

@@ -8,7 +8,7 @@ namespace Tensorflow.Gradients
     /// <typeparam name="BackwardFunction"></typeparam>
     /// <typeparam name="TapeTensor"></typeparam>
     public class OpTape<BackwardFunction, TapeTensor> :
-        UnorderedMap<long, OpTapeEntry<BackwardFunction, TapeTensor>>
+        UnorderedMap<Tensor, OpTapeEntry<BackwardFunction, TapeTensor>>
     {

     }


src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs (+5 -3)

@@ -1,4 +1,6 @@
-namespace Tensorflow.Gradients
+using System.Linq;
+
+namespace Tensorflow.Gradients
 {
     /// <summary>
     /// Represents an entry in the tape.
@@ -9,9 +11,9 @@
     {
         public string op_type { get; set; }
         public TapeTensor[] output_tensor_info { get; set; }
-        public long[] input_tensor_id { get; set; }
+        public Tensor[] input_tensor_id { get; set; }
         public BackwardFunction backward_function { get; set; }
         public override string ToString()
-            => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
+            => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
     }
 }

src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs (+18 -18)

@@ -11,17 +11,17 @@ namespace Tensorflow.Gradients
         int kMinAggregateCount = 4;
         int kMinAggregateBytes = 128 * 1024 * 1024;

-        public Tensor[] ComputeGradient(long[] target_tensor_ids,
-            long[] source_tensor_ids,
-            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+        public Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
+            Tensor[] source_tensor_ids,
+            UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
             Tensor[] output_gradients)
         {
             var result = new List<Tensor>(source_tensor_ids.Length);
-            var sources_set = new UnorderedSet<long>(source_tensor_ids);
-            var gradients_size = new UnorderedMap<long, long>();
+            var sources_set = new UnorderedSet<Tensor>(source_tensor_ids);
+            var gradients_size = new UnorderedMap<Tensor, long>();

             var state = PrepareBackprop(
-                target_tensor_ids, tensor_tape_, op_tape_, sources_set, persistent_);
+                target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent);
             var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
             var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
                 output_gradients,
@@ -51,7 +51,7 @@ namespace Tensorflow.Gradients
                 var zero_indices = new List<int>();
                 for (int i = 0; i < trace.output_tensor_info.Length; ++i)
                 {
-                    var id = trace.output_tensor_info[i].GetID();
+                    var id = trace.output_tensor_info[i].GetTensor();
                     if (!gradients.find(id, out var grad_it))
                     {
                         if (FunctionsAcceptingNoneForIndicesMap().find(trace.op_type, out var func_name_it) &&
@@ -96,7 +96,7 @@ namespace Tensorflow.Gradients

                 if (in_gradients.Count() != trace.input_tensor_id.Count())
                     throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
-                if (!persistent_)
+                if (!_persistent)
                 {
                     // trace.backward_function_deleter(trace.backward_function);
                 }
@@ -147,7 +147,7 @@ namespace Tensorflow.Gradients
                     }

                     var op_id = tape_it;
-                    if (op_id == -1)
+                    if (op_id == null)
                         continue;

                     if (state.op_missing_tensor.find(op_id, out var missing_it))
@@ -162,7 +162,7 @@ namespace Tensorflow.Gradients
             if (state.op_tape.Count > 0)
                 throw new RuntimeError("Invalid tape state.");

-            var used_gradient_ids = new List<long>(source_tensor_ids.Length);
+            var used_gradient_ids = new List<Tensor>(source_tensor_ids.Length);
             foreach (var id in source_tensor_ids)
             {
                 if (!gradients.find(id, out var grad_it))
@@ -203,19 +203,19 @@ namespace Tensorflow.Gradients
             return m;
         }

-        UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
-            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+        UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids,
+            UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
             Tensor[] output_gradients,
             TensorTape tensor_tape,
             OpTape<BackwardFunction, TapeTensor> op_tape)
         {
-            var result = new UnorderedMapEnumerable<long, List<Tensor>>();
+            var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>();
             for (int i = 0; i < target_tensor_ids.Length; ++i)
             {
                 var id = target_tensor_ids[i];
                 if (output_gradients.Length == 0 || output_gradients[i] == null)
                 {
-                    if (tensor_tape.find(id, out var tensor_id) && tensor_id != -1)
+                    if (tensor_tape.find(id, out var tensor_id) && tensor_id != null)
                     {
                         if (!op_tape.find(tensor_tape[id], out var op_it))
                             throw new RuntimeError("Internal state of the gradient tape is invalid: " +
@@ -223,7 +223,7 @@ namespace Tensorflow.Gradients
                         bool found = false;
                         for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
                         {
-                            if (op_it.output_tensor_info[j].GetID() == id)
+                            if (op_it.output_tensor_info[j].GetTensor() == id)
                             {
                                 found = true;
                                 var ones = op_it.output_tensor_info[j].OnesLike();
@@ -253,10 +253,10 @@ namespace Tensorflow.Gradients
             return result;
         }

-        Queue<long> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape,
-            UnorderedMap<long, long> op_missing_tensor)
+        Queue<Tensor> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape,
+            UnorderedMap<Tensor, long> op_missing_tensor)
         {
-            var result = new Queue<long>();
+            var result = new Queue<Tensor>();
             foreach (var op_entry in op_tape)
             {
                 if (!op_missing_tensor.find(op_entry.Key))
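Apart from the key-type change, the algorithm is untouched: PrepareBackprop counts, for each recorded op, how many of its outputs still lack gradients; InitialStack seeds a ready queue with ops missing none; the main loop pops an op, invokes its backward function, accumulates gradients into its inputs, and enqueues producer ops as their counts reach zero. A self-contained toy of the same reverse-mode accumulation, simplified to a linear tape walked in reverse (illustrative only; the real code uses the ready-queue form, which tolerates out-of-order readiness):

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class TapeSweepDemo
    {
        // One recorded op: the name of its output, its inputs, and a backward
        // function mapping the output gradient to per-input gradients.
        record Entry(string Output, string[] Inputs, Func<double, double[]> Backward);

        static void Main()
        {
            // Forward pass (a = 3, b = 4):  t = a * b;  y = t + a
            var tape = new List<Entry>
            {
                new("t", new[] { "a", "b" }, g => new[] { g * 4.0, g * 3.0 }), // d(a*b) = (b, a)
                new("y", new[] { "t", "a" }, g => new[] { g, g }),             // d(t+a) = (1, 1)
            };

            var grads = new Dictionary<string, double> { ["y"] = 1.0 };

            // Ops were recorded in execution order, so reverse order guarantees an
            // op's output gradient is complete before its backward function runs.
            foreach (var e in Enumerable.Reverse(tape))
            {
                if (!grads.TryGetValue(e.Output, out var g)) continue;
                var inGrads = e.Backward(g);
                for (int i = 0; i < e.Inputs.Length; i++)
                    grads[e.Inputs[i]] = grads.GetValueOrDefault(e.Inputs[i]) + inGrads[i];
            }

            Console.WriteLine($"dy/da = {grads["a"]}, dy/db = {grads["b"]}"); // 5, 3
        }
    }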


src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs (+5 -5)

@@ -6,14 +6,14 @@ namespace Tensorflow.Gradients
 {
     public partial class Tape
     {
-        public BackpropInitialState PrepareBackprop(long[] target,
+        public BackpropInitialState PrepareBackprop(Tensor[] target,
             TensorTape tensor_tape,
             OpTape<BackwardFunction, TapeTensor> op_tape,
-            UnorderedSet<long> sources_set,
+            UnorderedSet<Tensor> sources_set,
             bool persistent_tape)
         {
             BackpropInitialState result = new BackpropInitialState();
-            var tensor_stack = new Queue<long>(target);
+            var tensor_stack = new Queue<Tensor>(target);
             while (tensor_stack.Count > 0)
             {
                 var tensor_id = tensor_stack.Dequeue();
@@ -21,7 +21,7 @@ namespace Tensorflow.Gradients
                 if (!tensor_tape.find(tensor_id, out var op_id))
                     continue;

-                if (op_id == -1 ||
+                if (op_id == null ||
                     !op_tape.find(op_id, out var op_it) ||
                     result.op_tape.find(op_id, out var result_op_it))
                     continue;
@@ -46,7 +46,7 @@ namespace Tensorflow.Gradients

             foreach (var pair in result.tensor_usage_counts)
             {
-                if (tensor_tape.find(pair.Key, out var it) && it != -1)
+                if (tensor_tape.find(pair.Key, out var it) && it != null)
                     result.op_missing_tensor[it] += 1;
             }




src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs (+9 -19)

@@ -4,49 +4,39 @@ using Tensorflow.Util;
 using static Tensorflow.tensorflow;
 using static Tensorflow.Binding;
 using System.Linq;
+using Tensorflow.Eager;

 namespace Tensorflow.Gradients
 {
     public partial class Tape
     {
         long next_op_id_ = 0;
-        UnorderedMap<long, long> tensor_usage_;
+        UnorderedMap<Tensor, long> tensor_usage_;

         public void RecordOperation(string op_type,
             Tensor[] input_tensors,
             TapeTensor[] output_tensors,
             Func<BackwardFunction> backward_function_getter)
         {
-            var input_ids = input_tensors.Select(x => x.Id).ToArray();
-            var input_dtypes = input_tensors.Select(x => x.dtype).ToArray();
-
-            if (!ShouldRecord(input_ids, input_dtypes))
-            {
+            if (!ShouldRecord(input_tensors))
                 return;
-            }

-            long op_id = next_op_id_++;
-            var ids = new List<long>(input_ids.Length);
-            foreach (var i in input_ids)
-            {
+            var op_id = new EagerTensor(next_op_id_++);
+            foreach (var i in input_tensors)
                 tensor_usage_[i]++;
-                ids.Add(i);
-            }

-            var tensors = new List<TapeTensor>(output_tensors.Length);
             foreach (var o in output_tensors)
             {
-                tensor_tape_[o.GetID()] = op_id;
                 tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
-                tensor_usage_[o.GetID()] = 1;
-                tensors.Add(o);
+                tensor_tape_[o.GetTensor()] = op_id;
+                tensor_usage_[o.GetTensor()] = 1;
             }

             op_tape_[op_id] = new OpTapeEntry<BackwardFunction, TapeTensor>
             {
                 op_type = op_type,
-                output_tensor_info = tensors.ToArray(),
-                input_tensor_id = ids.ToArray(),
+                output_tensor_info = output_tensors,
+                input_tensor_id = input_tensors,
                 backward_function = backward_function_getter()
             };
         }


src/TensorFlowNET.Core/Gradients/Tape.cs (+44 -42)

@@ -1,57 +1,56 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using Tensorflow.Util;
 using static Tensorflow.Binding;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Gradients
 {
     public partial class Tape : ITape
     {
-        int nesting_id;
-        static int tape_nesting_id_counter = 0;
-        bool persistent_;
-        bool watch_accessed_variables;
+        int _id;
+        // static int tape_nesting_id_counter = 0;
+        bool _persistent;
+        public bool Persistent => _persistent;
+        bool _recording;
+        bool _created_eagerly;
         TensorTape tensor_tape_;
         OpTape<BackwardFunction, TapeTensor> op_tape_;
         /// <summary>
         /// A deque-backed stack, whose element references are not invalidated by
         /// pushes and pops at the back.
         /// </summary>
-        Stack<AccumulatorCallState> call_state_;
+        // Stack<AccumulatorCallState> call_state_;

         public Tape(bool persistent, bool watch_accessed_variables)
         {
-            this.persistent_ = persistent;
-            this.watch_accessed_variables = watch_accessed_variables;
-
+            _persistent = persistent;
+            _created_eagerly = tf.Context.executing_eagerly();
             tensor_tape_ = new TensorTape();
             op_tape_ = new OpTape<BackwardFunction, TapeTensor>();
-            tensor_usage_ = new UnorderedMap<long, long>();
-            nesting_id = ++tape_nesting_id_counter;
-            tf.GetTapeSet().Add(this);
+            tensor_usage_ = new UnorderedMap<Tensor, long>();
+            if (_created_eagerly)
+                tf.Context.start_step();
+            // nesting_id = ++tape_nesting_id_counter;
         }

        /// <summary>
        /// Marks this tensor to be watched by the given tape.
        /// </summary>
        /// <param name="x"></param>
-        public void Watch(long tensor_id)
+        public void Watch(Tensor x)
        {
-            if (!CouldBackprop())
-                return;
-
-            tf.Logger.Debug($"Watch tensor_id={tensor_id}");
-            tensor_tape_.emplace(tensor_id, -1);
+            tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
+            tensor_tape_.emplace(x, null);
        }

-        public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes)
+        public bool ShouldRecord(Tensor[] tensors)
        {
-            for (int i = 0; i < tensor_ids.Length; ++i)
+            var dtypes = tensors.Select(x => x.dtype).ToArray();
+            for (int i = 0; i < tensors.Length; ++i)
            {
-                if (tensor_tape_.find(tensor_ids[i]))
+                if (tensor_tape_.find(tensors[i]))
                {
                    if (IsDtypeTrainable(dtypes[i]))
                        return true;
@@ -60,18 +59,9 @@ namespace Tensorflow.Gradients
             return false;
         }

-        /// <summary>
-        /// Pops the given tape in the stack.
-        /// </summary>
-        /// <param name="tape"></param>
-        public void PopTape(ITape tape)
-        {
-            tf.GetTapeSet().Remove(tape);
-        }
-
         public void VariableAccessed(ResourceVariable variable)
         {
-            Watch(variable.Handle.Id);
+            Watch(variable.Handle);
         }

         public ResourceVariable[] WatchedVariables()
@@ -97,17 +87,29 @@ namespace Tensorflow.Gradients
             }
         }

-        bool CouldForwardprop()
-            => HasAccumulator();
+        public void StartRecord()
+        {
+            if (_recording)
+                throw new ValueError("Tape is still recording, This can happen if you try to " +
+                    "re-enter an already-active tape.");
+            _recording = true;
+        }

-        bool CouldBackprop()
-            => HasGradientTape();
+        public void StopRecord()
+        {
+            if (!_recording)
+                throw new ValueError("Tape is not recording.");
+            if (_created_eagerly)
+                tf.Context.end_step();
+            _recording = false;
+        }

-        bool HasAccumulator()
-            //return !GetAccumulatorSet()->empty();
-            => false;
+        public void SetTapeId(int id)
+        {
+            _id = id;
+        }

-        bool HasGradientTape()
-            => tf.GetTapeSet().Count > 0;
+        public override string ToString()
+            => $"Tape {_id} {(_recording ? "Recording" : "Stopped")}";
     }
 }

src/TensorFlowNET.Core/Gradients/TapeTensor.cs (+8 -8)

@@ -4,18 +4,18 @@ namespace Tensorflow.Gradients
 {
     public class TapeTensor
     {
-        long id;
-        TF_DataType dtype;
-        Shape shape;
+        Tensor tensor;
+        long id => tensor.Id;
+        TF_DataType dtype => tensor.dtype;
+        Shape shape => tensor.shape;

-        public TapeTensor(long id, TF_DataType dtype, Shape shape)
+        public TapeTensor(Tensor tensor)
         {
-            this.id = id;
-            this.dtype = dtype;
-            this.shape = shape;
+            this.tensor = tensor;
         }

-        public long GetID() => id;
+        public long GetID() => tensor.Id;
+        public Tensor GetTensor() => tensor;

         public Tensor ZerosLike()
             => tf.zeros(shape: shape, dtype: dtype);
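TapeTensor is now a thin wrapper over the live Tensor: id, dtype, and shape are derived properties rather than a snapshot captured at record time, so ZerosLike/OnesLike always see the tensor's current metadata. A small sketch of the wrapper in use (values illustrative):

    var t = tf.constant(new float[] { 1f, 2f, 3f });
    var tt = new TapeTensor(t);     // wraps the tensor itself, not (id, dtype, shape) copies
    var zeros = tt.ZerosLike();     // tf.zeros with t's dtype and shape
    long id = tt.GetID();           // still available, now forwarded to t.Id
    Tensor back = tt.GetTensor();   // the new accessor used by ComputeGradient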


src/TensorFlowNET.Core/Gradients/TensorTape.cs (+2 -2)

@@ -3,11 +3,11 @@
 namespace Tensorflow.Gradients
 {
     /// <summary>
-    /// Map from tensor_id to internally-defined operation-id of the operation which
+    /// Map from tensor to internally-defined operation-id of the operation which
     /// produced this tensor. A value of -1 means that the tensor was directly
     /// watched and not the result of any operation in the tape.
     /// </summary>
-    public class TensorTape : UnorderedMap<long, long>
+    public class TensorTape : UnorderedMap<Tensor, Tensor>
     {

     }


src/TensorFlowNET.Core/Gradients/gradients_util.cs (+1 -11)

@@ -543,7 +543,7 @@ namespace Tensorflow
                     {
                         if (_IsBackpropagatable(output))
                         {
-                            var c = _Consumers(output, func_graphs).ToList();
+                            var c = output.consumers().ToList();
                             c.ForEach(x => queue.Enqueue(x));
                         }
                     }
@@ -551,16 +551,6 @@ namespace Tensorflow
             }
         }

-        /// <summary>
-        /// Returns the consumers of t, crossing closure boundaries where necessary.
-        /// </summary>
-        /// <param name="t"></param>
-        /// <param name="func_graphs"></param>
-        private static Operation[] _Consumers(Tensor t, List<FuncGraph> func_graphs)
-        {
-            return t.consumers();
-        }
-
         private static bool _IsBackpropagatable(Tensor tensor)
         {
             if (_IsTrainable(tensor))


src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs (+10 -0)

@@ -12,6 +12,7 @@ namespace Tensorflow.NumPy
         {
             TF_DataType.TF_UINT8 => Scalar<T>(*(byte*)nd.data),
             TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
+            TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),
             TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data),
             _ => throw new NotImplementedException("")
         };
@@ -34,6 +35,15 @@ namespace Tensorflow.NumPy
             _ => throw new NotImplementedException("")
         };

+        static T Scalar<T>(int input)
+            => Type.GetTypeCode(typeof(T)) switch
+            {
+                TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
+                TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64),
+                TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
+                _ => throw new NotImplementedException("")
+            };
+
         static T Scalar<T>(long input)
             => Type.GetTypeCode(typeof(T)) switch
             {


src/TensorFlowNET.Core/Operations/Operation.Input.cs (+1 -0)

@@ -98,6 +98,7 @@ namespace Tensorflow
                     var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i;
                     control_inputs[i] = new Operation(*(IntPtr*)handle);
                 }
+                Marshal.FreeHGlobal(control_input_handle);
             }

             return control_inputs;


src/TensorFlowNET.Core/Operations/Operation.Output.cs (+2 -1)

@@ -66,7 +66,7 @@ namespace Tensorflow
             var inputptr = (TF_Input*)handle;
             for (int i = 0; i < num; i++)
                 consumers[i] = *(inputptr + i);
-
+            Marshal.FreeHGlobal(handle);
             return consumers;
         }

@@ -83,6 +83,7 @@ namespace Tensorflow
                 var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
                 control_outputs[i] = new Operation(*(IntPtr*)handle);
             }
+            Marshal.FreeHGlobal(control_output_handle);
         }

         return control_outputs;
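These Marshal.FreeHGlobal additions (here, in Operation.Input.cs, and in c_api.tf_session_helper.cs) plug native memory leaks: each Marshal.AllocHGlobal buffer handed to the C API was previously never released after its contents were copied into managed arrays. The defensive form of the pattern wraps the copy in try/finally; a generic sketch (the hypothetical fillNative delegate stands in for whichever TF_* call populates the buffer):

    using System;
    using System.Runtime.InteropServices;

    static IntPtr[] ReadPointerArray(Action<IntPtr> fillNative, int count)
    {
        var buffer = Marshal.AllocHGlobal(Marshal.SizeOf<IntPtr>() * count);
        try
        {
            fillNative(buffer);                      // native side writes `count` pointers
            var result = new IntPtr[count];
            Marshal.Copy(buffer, result, 0, count);  // copy into managed memory
            return result;
        }
        finally
        {
            Marshal.FreeHGlobal(buffer);             // always release, even on failure
        }
    }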


src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs (+1 -1)

@@ -36,7 +36,7 @@ namespace Tensorflow
                     consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper));
                 }
             }
-
+            Marshal.FreeHGlobal(handle);
             return consumers;
         }
     }

src/TensorFlowNET.Core/tensorflow.cs (+2 -37)

@@ -25,7 +25,7 @@ namespace Tensorflow
 {
     public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients);

-    public partial class tensorflow : ITensorFlowObject
+    public partial class tensorflow
     {
         public TF_DataType byte8 = TF_DataType.TF_UINT8;
         public TF_DataType int8 = TF_DataType.TF_INT8;
@@ -64,6 +64,7 @@ namespace Tensorflow

         private void InitGradientEnvironment()
         {
+            _tapeSet = new GradientTape();
             ops.RegisterFromAssembly();
         }

@@ -106,41 +107,5 @@ namespace Tensorflow
         {
             return new Session(null, config).as_default();
         }
-
-        List<ITape> tape_set;
-        public List<ITape> GetTapeSet()
-        {
-            if (tape_set == null)
-            {
-                tape_set = new List<ITape>();
-            }
-
-            return tape_set;
-        }
-
-        public void __init__()
-        {
-
-        }
-
-        public void __enter__()
-        {
-
-        }
-
-        public void __exit__()
-        {
-
-        }
-
-        public void __del__()
-        {
-
-        }
-
-        public void Dispose()
-        {
-
-        }
     }
 }

test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs (+5 -10)

@@ -16,9 +16,9 @@ namespace TensorFlowNET.UnitTest.Gradient
         {
             // Calcute the gradient of w * w
             // by Automatic Differentiation in Eager mode
-            // in tensorflow.net 2.x that is in development intensively
             var w = tf.constant(1.5f);
             using var tape = tf.GradientTape();
+            // w is defined before tape is recording
             tape.watch(w);
             var loss = w * w;
             var grad = tape.gradient(loss, w);
@@ -56,8 +56,6 @@ namespace TensorFlowNET.UnitTest.Gradient
             }
         }

-
-        [Ignore]
         [TestMethod]
         public void SquaredDifference_1D()
         {
@@ -66,14 +64,15 @@ namespace TensorFlowNET.UnitTest.Gradient
             // Expected is 2*(abs(x1-x2))
             Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
             Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
-            float[] expected = new float[] {
+            float[] expected = new float[]
+            {
                 (29-1) * 2,
                 (27-3) * 2,
                 (23-5) * 2,
                 (7-21) * 2,
                 (11-19) * 2,
                 (13-17) * 2
             };

             // Sanity check
             using (var tape = tf.GradientTape())
@@ -100,7 +99,7 @@ namespace TensorFlowNET.UnitTest.Gradient

         /// <summary>
-        /// Calcute the gradient of w * w * w
+        /// Calcute the higher derivative gradient of w * w * w
         /// 高阶梯度 (higher-order gradient)
         /// </summary>
         [TestMethod]
@@ -110,10 +109,8 @@ namespace TensorFlowNET.UnitTest.Gradient
             using var tape1 = tf.GradientTape();
             using var tape2 = tf.GradientTape();
             var y = x * x * x;
-            tape2.Dispose();
             var dy_dx = tape2.gradient(y, x);
             Assert.AreEqual((float)dy_dx, 3.0f);
-            tape1.Dispose();
             var d2y_d2x = tape1.gradient(dy_dx, x);
             Assert.AreEqual((float)d2y_d2x, 6.0f);
         }
@@ -140,8 +137,6 @@ namespace TensorFlowNET.UnitTest.Gradient
             tape.watch(x);
             var y = tf.reduce_sum(x);
             var z = tf.multiply(y, y);
-            tape.Dispose();

             var dz_dx = tape.gradient(z, x);

             var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
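The deleted Dispose calls are the behavioral payoff of the refactor: gradient() now pops a non-persistent tape itself (via stop_recording), so tests no longer need to dispose a tape before asking it for gradients, and nested tapes unwind naturally in higher-order cases. Condensed from the updated HigherGradient test above (the definition of x and any watch calls sit outside the visible hunk and are assumed here):

    var x = tf.constant(1.0f);
    using var tape1 = tf.GradientTape();       // outer tape, pushed first
    using var tape2 = tf.GradientTape();       // inner tape, top of the stack
    var y = x * x * x;
    var dy_dx = tape2.gradient(y, x);          // pops the inner tape; no Dispose needed
    var d2y_dx2 = tape1.gradient(dy_dx, x);    // outer tape is now on top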

