diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
index b5724aaa..d722cb14 100644
--- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
@@ -14,22 +14,32 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
using Tensorflow.Gradients;
namespace Tensorflow
{
public partial class tensorflow
{
+ GradientTape _tapeSet;
+
///
/// Record operations for automatic differentiation.
///
///
///
- ///
+ /// The gradient tape set.
public GradientTape GradientTape(bool persistent = false,
bool watch_accessed_variables = true)
- => new GradientTape(persistent: persistent,
+ {
+ var tape = _tapeSet.PushTape(persistent: persistent,
watch_accessed_variables: watch_accessed_variables);
+ tape.StartRecord();
+ return _tapeSet;
+ }
+
+ public Stack GetTapeSet()
+ => _tapeSet.GetTapeSet();
public Tensor[] gradients(Tensor[] ys,
Tensor[] xs,
diff --git a/src/TensorFlowNET.Core/Binding.cs b/src/TensorFlowNET.Core/Binding.cs
index a257dd6c..004f35a3 100644
--- a/src/TensorFlowNET.Core/Binding.cs
+++ b/src/TensorFlowNET.Core/Binding.cs
@@ -4,7 +4,7 @@ namespace Tensorflow
{
public static partial class Binding
{
- public static tensorflow tf { get; } = New();
+ public static tensorflow tf { get; } = new tensorflow();
///
/// Alias to null, similar to python's None.
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
index a90f673c..c4bce84f 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs
@@ -11,5 +11,19 @@ namespace Tensorflow.Eager
{
return HasGradientTape();
}
+
+ private bool ShouldRecord(Tensor[] inputs)
+ {
+ bool should_record = false;
+ foreach (var tape in tf.GetTapeSet())
+ {
+ if (tape.ShouldRecord(inputs))
+ {
+ should_record = true;
+ break;
+ }
+ }
+ return should_record;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
index d072306a..5682f328 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
@@ -2,7 +2,6 @@
using System.Linq;
using Tensorflow.Gradients;
using static Tensorflow.Binding;
-using static Tensorflow.tensorflow;
namespace Tensorflow.Eager
{
@@ -14,18 +13,7 @@ namespace Tensorflow.Eager
Tensor[] results,
Func getBackwardFunction = null)
{
- var input_ids = MakeTensorIDList(inputs);
- var input_dtypes = MakeTensorDtypeList(inputs);
-
- bool should_record = false;
- foreach (var tape in tf.GetTapeSet())
- {
- if (tape.ShouldRecord(input_ids, input_dtypes))
- {
- should_record = true;
- break;
- }
- }
+ bool should_record = ShouldRecord(inputs);
if (!should_record)
{
@@ -43,9 +31,6 @@ namespace Tensorflow.Eager
tf.Logger.Debug($"RecordGradient: op_name={op_name}");
Tensor[] op_outputs;
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
- bool op_outputs_tuple_created = false;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
var unused_output_indices = gradient_exclustions.OpGradientUnusedOutputIndices(op_name);
if (unused_output_indices != null)
{
@@ -53,7 +38,6 @@ namespace Tensorflow.Eager
op_outputs = new Tensor[0];
else
{
- op_outputs_tuple_created = true;
// op_outputs = CopySequenceSettingIndicesToNull(results, *unused_output_indices);
}
}
@@ -61,9 +45,6 @@ namespace Tensorflow.Eager
op_outputs = results;
Tensor[] op_inputs;
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
- bool op_inputs_tuple_created = false;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
var unused_input_indices = gradient_exclustions.OpGradientUnusedInputIndices(op_name);
if (unused_input_indices != null)
{
@@ -71,7 +52,6 @@ namespace Tensorflow.Eager
op_inputs = new Tensor[0];
else
{
- op_inputs_tuple_created = true;
// op_inputs = CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
}
}
@@ -125,11 +105,6 @@ namespace Tensorflow.Eager
return HasGradientTape();
}
- long[] MakeTensorIDList(Tensor[] tensors)
- {
- return tensors.Select(x => x.Id).ToArray();
- }
-
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
{
return tensors.Select(x => x.dtype).ToArray();
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
index 5a491dd7..1626de22 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
@@ -310,7 +310,7 @@ namespace Tensorflow.Eager
for (int i = 0; i < num_values; ++i)
{
dims[i] = Marshal.AllocHGlobal(sizeof(long) * values1[i].ndim);
- tf.memcpy(dims[i], values1[i].dims.Select(x => (long)x).ToArray(), values1[i].ndim * sizeof(long));
+ tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long));
}
c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
index 3f15ac55..c96d09e5 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
@@ -14,18 +14,16 @@ namespace Tensorflow.Eager
Tensor[] sources,
Tensor[] output_gradients)
{
- var target_vec = MakeTensorIDList(target);
- var sources_vec = MakeTensorIDList(sources);
+ var target_vec = target;
+ var sources_vec = sources;
var sources_set = sources_vec;
var seq_array = target;
- var source_tensors_that_are_targets = new UnorderedMap();
+ var source_tensors_that_are_targets = new UnorderedMap();
for (int i = 0; i < target.Length; ++i)
{
- var target_id = target_vec[i];
- var tensor = seq_array[i];
- source_tensors_that_are_targets.Add(target_id, TapeTensorFromTensor(tensor));
+ source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
}
if (output_gradients != null)
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
index e70a513f..861f26fc 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs
@@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using Tensorflow.Gradients;
-using static Tensorflow.tensorflow;
namespace Tensorflow.Eager
{
@@ -12,16 +12,13 @@ namespace Tensorflow.Eager
Tensor[] output_tensors,
Func backward_function_getter)
{
- var output_info = new List();
+ var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();
- if (!TapeTensorsFromTensorSequence(output_tensors, output_info))
- return false;
-
- if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(),
+ if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
backward_function_getter))
return false;
- TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(),
+ TapeSetRecordBackprop(op_type, input_tensors, output_info,
backward_function_getter);
return true;
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorFromTensor.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorFromTensor.cs
deleted file mode 100644
index 4dabc9a1..00000000
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorFromTensor.cs
+++ /dev/null
@@ -1,12 +0,0 @@
-using Tensorflow.Gradients;
-
-namespace Tensorflow.Eager
-{
- public partial class EagerRunner
- {
- TapeTensor TapeTensorFromTensor(Tensor tensor)
- {
- return new TapeTensor(tensor.Id, tensor.dtype, tensor.shape);
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorsFromTensorSequence.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorsFromTensorSequence.cs
deleted file mode 100644
index 34998c68..00000000
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeTensorsFromTensorSequence.cs
+++ /dev/null
@@ -1,18 +0,0 @@
-using System.Collections.Generic;
-using Tensorflow.Gradients;
-
-namespace Tensorflow.Eager
-{
- public partial class EagerRunner
- {
- bool TapeTensorsFromTensorSequence(Tensor[] output_seq,
- List output_info)
- {
- for (var i = 0; i < output_seq.Length; ++i)
- {
- output_info.Add(TapeTensorFromTensor(output_seq[i]));
- }
- return true;
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
index 0e775d46..06ae7ce7 100644
--- a/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
+++ b/src/TensorFlowNET.Core/Gradients/BackpropInitialState.cs
@@ -7,21 +7,21 @@ namespace Tensorflow.Gradients
{
public OpTape op_tape { get; set; }
///
- /// Map from tensor ID to how many references still exist for this tensor in
+ /// Map from tensor to how many references still exist for this tensor in
/// the tape.
///
- public UnorderedMap tensor_usage_counts { get; set; }
+ public UnorderedMap tensor_usage_counts { get; set; }
///
/// Maps from op ID to how many output tensors of this op still need to have
/// their gradients computed.
///
- public UnorderedMap op_missing_tensor { get; set; }
+ public UnorderedMap op_missing_tensor { get; set; }
public BackpropInitialState()
{
op_tape = new OpTape();
- tensor_usage_counts = new UnorderedMap();
- op_missing_tensor = new UnorderedMap();
+ tensor_usage_counts = new UnorderedMap();
+ op_missing_tensor = new UnorderedMap();
}
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs
index 0987a102..31517e58 100644
--- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs
+++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs
@@ -6,6 +6,7 @@ using static Tensorflow.Binding;
namespace Tensorflow.Gradients
{
///
+ /// Gradient Tape Set
/// Record operations for automatic differentiation.
///
/// Operations are recorded if they are executed within this context manager and
@@ -18,54 +19,35 @@ namespace Tensorflow.Gradients
///
public class GradientTape : IDisposable
{
- bool _recording;
- public bool Recording => _recording;
- bool _persistent;
- bool _watch_accessed_variables;
- ResourceVariable[] _watched_variables;
- bool _created_eagerly;
- ITape _tape;
-
- public GradientTape(bool persistent = false,
- bool watch_accessed_variables = true)
+ int _nextTapeId;
+ ITape _tape => _tapeSet.Peek();
+ Stack _tapeSet;
+
+ public GradientTape()
{
- _persistent = persistent;
- _watch_accessed_variables = watch_accessed_variables;
- _created_eagerly = tf.Context.executing_eagerly();
- _recording = false;
- _created_eagerly = tf.Context.executing_eagerly();
- // Enters a context inside which operations are recorded on this tape.
- if (_created_eagerly)
- {
- tf.Context.ensure_initialized();
- tf.Context.start_step();
- }
- _push_tape();
+ _tapeSet = new Stack();
}
///
- /// Pushes a new tape onto the tape stack.
+ /// Pushes a new tape onto the tape stack.
///
- private void _push_tape()
+ public ITape PushTape(bool persistent = false,
+ bool watch_accessed_variables = true)
{
- if (_recording)
- throw new ValueError("Tape is still recording, This can happen if you try to " +
- "re-enter an already-active tape.");
-
- if (_tape == null)
- _tape = new Tape(_persistent, _watch_accessed_variables);
- else
- tf.GetTapeSet().Add(_tape);
+ // Enters a context inside which operations are recorded on this tape.
+ if (tf.Context.executing_eagerly())
+ tf.Context.ensure_initialized();
- _recording = true;
+ var tape = new Tape(persistent, watch_accessed_variables);
+ tape.SetTapeId(_nextTapeId++);
+ _tapeSet.Push(tape);
+ return tape;
}
- private void _pop_tape()
+ ITape PopTape()
{
- if (!_recording)
- throw new ValueError("Tape is not recording.");
- _tape.PopTape(_tape);
- _recording = false;
+ _tape.StopRecord();
+ return _tapeSet.Pop();
}
///
@@ -74,7 +56,9 @@ namespace Tensorflow.Gradients
///
public void watch(Tensor x)
{
- _tape.Watch(x.Id);
+ if (!_tapeSet.Any())
+ return;
+ _tape.Watch(x);
}
///
@@ -85,13 +69,9 @@ namespace Tensorflow.Gradients
///
public Tensor gradient(Tensor target, Tensor source)
{
- if (_recording)
- {
- if (!_persistent)
- _pop_tape();
- }
+ ITape tape = stop_recording();
- var results = tf.Runner.TFE_TapeGradient(_tape,
+ var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
new[] { source },
null);
@@ -115,22 +95,17 @@ namespace Tensorflow.Gradients
public Tensor[] gradient(Tensor target, IEnumerable sources)
{
- if (_recording)
- {
- if (!_persistent)
- _pop_tape();
- }
+ var tape = stop_recording();
- var results = tf.Runner.TFE_TapeGradient(_tape,
+ var results = tf.Runner.TFE_TapeGradient(tape,
new[] { target },
sources.Select(x => x.Handle).ToArray(),
null);
- if (!_persistent)
+ if (!tape.Persistent)
{
// Keep track of watched variables before setting tape to None
- _watched_variables = _tape.WatchedVariables();
- _tape = null;
+ // _watched_variables = _tape.WatchedVariables();
}
return results;
@@ -139,18 +114,20 @@ namespace Tensorflow.Gradients
///
/// Temporarily stops recording operations on this tape.
///
- public void stop_recording()
+ public ITape stop_recording()
{
- _pop_tape();
+ var tape = _tape;
+ if (!tape.Persistent)
+ tape = PopTape();
+ return tape;
}
+ public Stack GetTapeSet()
+ => _tapeSet;
+
public void Dispose()
{
- if (_recording)
- _pop_tape();
-
- if (_created_eagerly)
- tf.Context.end_step();
+ _tapeSet.Clear();
}
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/ITape.cs b/src/TensorFlowNET.Core/Gradients/ITape.cs
index 279ad876..c4e88617 100644
--- a/src/TensorFlowNET.Core/Gradients/ITape.cs
+++ b/src/TensorFlowNET.Core/Gradients/ITape.cs
@@ -1,15 +1,15 @@
using System;
using Tensorflow.Util;
-using static Tensorflow.tensorflow;
namespace Tensorflow.Gradients
{
public interface ITape
{
- void PopTape(ITape tape);
-
- bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes);
-
+ void SetTapeId(int id);
+ bool ShouldRecord(Tensor[] tensors);
+ void StartRecord();
+ void StopRecord();
+ bool Persistent { get; }
void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
@@ -17,13 +17,13 @@ namespace Tensorflow.Gradients
void VariableAccessed(ResourceVariable variable);
- void Watch(long tensor_id);
+ void Watch(Tensor x);
ResourceVariable[] WatchedVariables();
- Tensor[] ComputeGradient(long[] target_tensor_ids,
- long[] source_tensor_ids,
- UnorderedMap sources_that_are_targets,
+ Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
+ Tensor[] source_tensor_ids,
+ UnorderedMap sources_that_are_targets,
Tensor[] output_gradients);
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/OpTape.cs b/src/TensorFlowNET.Core/Gradients/OpTape.cs
index 329cdee8..7c79eb5d 100644
--- a/src/TensorFlowNET.Core/Gradients/OpTape.cs
+++ b/src/TensorFlowNET.Core/Gradients/OpTape.cs
@@ -8,7 +8,7 @@ namespace Tensorflow.Gradients
///
///
public class OpTape :
- UnorderedMap>
+ UnorderedMap>
{
}
diff --git a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
index d44ea361..165ef14f 100644
--- a/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
+++ b/src/TensorFlowNET.Core/Gradients/OpTapeEntry.cs
@@ -1,4 +1,6 @@
-namespace Tensorflow.Gradients
+using System.Linq;
+
+namespace Tensorflow.Gradients
{
///
/// Represents an entry in the tape.
@@ -9,9 +11,9 @@
{
public string op_type { get; set; }
public TapeTensor[] output_tensor_info { get; set; }
- public long[] input_tensor_id { get; set; }
+ public Tensor[] input_tensor_id { get; set; }
public BackwardFunction backward_function { get; set; }
public override string ToString()
- => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
+ => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
index a9d8b10a..70e1a743 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.ComputeGradient.cs
@@ -11,17 +11,17 @@ namespace Tensorflow.Gradients
int kMinAggregateCount = 4;
int kMinAggregateBytes = 128 * 1024 * 1024;
- public Tensor[] ComputeGradient(long[] target_tensor_ids,
- long[] source_tensor_ids,
- UnorderedMap sources_that_are_targets,
+ public Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
+ Tensor[] source_tensor_ids,
+ UnorderedMap sources_that_are_targets,
Tensor[] output_gradients)
{
var result = new List(source_tensor_ids.Length);
- var sources_set = new UnorderedSet(source_tensor_ids);
- var gradients_size = new UnorderedMap();
+ var sources_set = new UnorderedSet(source_tensor_ids);
+ var gradients_size = new UnorderedMap();
var state = PrepareBackprop(
- target_tensor_ids, tensor_tape_, op_tape_, sources_set, persistent_);
+ target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent);
var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
output_gradients,
@@ -51,7 +51,7 @@ namespace Tensorflow.Gradients
var zero_indices = new List();
for (int i = 0; i < trace.output_tensor_info.Length; ++i)
{
- var id = trace.output_tensor_info[i].GetID();
+ var id = trace.output_tensor_info[i].GetTensor();
if (!gradients.find(id, out var grad_it))
{
if (FunctionsAcceptingNoneForIndicesMap().find(trace.op_type, out var func_name_it) &&
@@ -96,7 +96,7 @@ namespace Tensorflow.Gradients
if (in_gradients.Count() != trace.input_tensor_id.Count())
throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");
- if (!persistent_)
+ if (!_persistent)
{
// trace.backward_function_deleter(trace.backward_function);
}
@@ -147,7 +147,7 @@ namespace Tensorflow.Gradients
}
var op_id = tape_it;
- if (op_id == -1)
+ if (op_id == null)
continue;
if (state.op_missing_tensor.find(op_id, out var missing_it))
@@ -162,7 +162,7 @@ namespace Tensorflow.Gradients
if (state.op_tape.Count > 0)
throw new RuntimeError("Invalid tape state.");
- var used_gradient_ids = new List(source_tensor_ids.Length);
+ var used_gradient_ids = new List(source_tensor_ids.Length);
foreach (var id in source_tensor_ids)
{
if (!gradients.find(id, out var grad_it))
@@ -203,19 +203,19 @@ namespace Tensorflow.Gradients
return m;
}
- UnorderedMapEnumerable> InitialGradients(long[] target_tensor_ids,
- UnorderedMap sources_that_are_targets,
+ UnorderedMapEnumerable> InitialGradients(Tensor[] target_tensor_ids,
+ UnorderedMap sources_that_are_targets,
Tensor[] output_gradients,
TensorTape tensor_tape,
OpTape op_tape)
{
- var result = new UnorderedMapEnumerable>();
+ var result = new UnorderedMapEnumerable>();
for (int i = 0; i < target_tensor_ids.Length; ++i)
{
var id = target_tensor_ids[i];
if (output_gradients.Length == 0 || output_gradients[i] == null)
{
- if (tensor_tape.find(id, out var tensor_id) && tensor_id != -1)
+ if (tensor_tape.find(id, out var tensor_id) && tensor_id != null)
{
if (!op_tape.find(tensor_tape[id], out var op_it))
throw new RuntimeError("Internal state of the gradient tape is invalid: " +
@@ -223,7 +223,7 @@ namespace Tensorflow.Gradients
bool found = false;
for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
{
- if (op_it.output_tensor_info[j].GetID() == id)
+ if (op_it.output_tensor_info[j].GetTensor() == id)
{
found = true;
var ones = op_it.output_tensor_info[j].OnesLike();
@@ -253,10 +253,10 @@ namespace Tensorflow.Gradients
return result;
}
- Queue InitialStack(OpTape op_tape,
- UnorderedMap op_missing_tensor)
+ Queue InitialStack(OpTape op_tape,
+ UnorderedMap op_missing_tensor)
{
- var result = new Queue();
+ var result = new Queue();
foreach (var op_entry in op_tape)
{
if (!op_missing_tensor.find(op_entry.Key))
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
index ba95fc99..ae81b8d5 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.PrepareBackprop.cs
@@ -6,14 +6,14 @@ namespace Tensorflow.Gradients
{
public partial class Tape
{
- public BackpropInitialState PrepareBackprop(long[] target,
+ public BackpropInitialState PrepareBackprop(Tensor[] target,
TensorTape tensor_tape,
OpTape op_tape,
- UnorderedSet sources_set,
+ UnorderedSet sources_set,
bool persistent_tape)
{
BackpropInitialState result = new BackpropInitialState();
- var tensor_stack = new Queue(target);
+ var tensor_stack = new Queue(target);
while (tensor_stack.Count > 0)
{
var tensor_id = tensor_stack.Dequeue();
@@ -21,7 +21,7 @@ namespace Tensorflow.Gradients
if (!tensor_tape.find(tensor_id, out var op_id))
continue;
- if (op_id == -1 ||
+ if (op_id == null ||
!op_tape.find(op_id, out var op_it) ||
result.op_tape.find(op_id, out var result_op_it))
continue;
@@ -46,7 +46,7 @@ namespace Tensorflow.Gradients
foreach (var pair in result.tensor_usage_counts)
{
- if (tensor_tape.find(pair.Key, out var it) && it != -1)
+ if (tensor_tape.find(pair.Key, out var it) && it != null)
result.op_missing_tensor[it] += 1;
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
index 7b0e51f2..4435c312 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
@@ -4,49 +4,39 @@ using Tensorflow.Util;
using static Tensorflow.tensorflow;
using static Tensorflow.Binding;
using System.Linq;
+using Tensorflow.Eager;
namespace Tensorflow.Gradients
{
public partial class Tape
{
long next_op_id_ = 0;
- UnorderedMap tensor_usage_;
+ UnorderedMap tensor_usage_;
public void RecordOperation(string op_type,
Tensor[] input_tensors,
TapeTensor[] output_tensors,
Func backward_function_getter)
{
- var input_ids = input_tensors.Select(x => x.Id).ToArray();
- var input_dtypes = input_tensors.Select(x => x.dtype).ToArray();
-
- if (!ShouldRecord(input_ids, input_dtypes))
- {
+ if (!ShouldRecord(input_tensors))
return;
- }
- long op_id = next_op_id_++;
- var ids = new List(input_ids.Length);
- foreach (var i in input_ids)
- {
+ var op_id = new EagerTensor(next_op_id_++);
+ foreach (var i in input_tensors)
tensor_usage_[i]++;
- ids.Add(i);
- }
- var tensors = new List(output_tensors.Length);
foreach (var o in output_tensors)
{
- tensor_tape_[o.GetID()] = op_id;
tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
- tensor_usage_[o.GetID()] = 1;
- tensors.Add(o);
+ tensor_tape_[o.GetTensor()] = op_id;
+ tensor_usage_[o.GetTensor()] = 1;
}
op_tape_[op_id] = new OpTapeEntry
{
op_type = op_type,
- output_tensor_info = tensors.ToArray(),
- input_tensor_id = ids.ToArray(),
+ output_tensor_info = output_tensors,
+ input_tensor_id = input_tensors,
backward_function = backward_function_getter()
};
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs
index 08cbc1da..35710c14 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.cs
@@ -1,57 +1,56 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using Tensorflow.Util;
using static Tensorflow.Binding;
-using static Tensorflow.tensorflow;
namespace Tensorflow.Gradients
{
public partial class Tape : ITape
{
- int nesting_id;
- static int tape_nesting_id_counter = 0;
- bool persistent_;
- bool watch_accessed_variables;
+ int _id;
+ // static int tape_nesting_id_counter = 0;
+ bool _persistent;
+ public bool Persistent => _persistent;
+ bool _recording;
+ bool _created_eagerly;
TensorTape tensor_tape_;
OpTape op_tape_;
-
+
///
/// A deque-backed stack, whose element references are not invalidated by
/// pushes and pops at the back.
///
- Stack call_state_;
+ // Stack call_state_;
public Tape(bool persistent, bool watch_accessed_variables)
{
- this.persistent_ = persistent;
- this.watch_accessed_variables = watch_accessed_variables;
-
+ _persistent = persistent;
+ _created_eagerly = tf.Context.executing_eagerly();
tensor_tape_ = new TensorTape();
op_tape_ = new OpTape();
- tensor_usage_ = new UnorderedMap();
-
- nesting_id = ++tape_nesting_id_counter;
- tf.GetTapeSet().Add(this);
+ tensor_usage_ = new UnorderedMap();
+ if(_created_eagerly)
+ tf.Context.start_step();
+ // nesting_id = ++tape_nesting_id_counter;
}
///
/// Marks this tensor to be watched by the given tape.
///
///
- public void Watch(long tensor_id)
+ public void Watch(Tensor x)
{
- if (!CouldBackprop())
- return;
-
- tf.Logger.Debug($"Watch tensor_id={tensor_id}");
- tensor_tape_.emplace(tensor_id, -1);
+ tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
+ tensor_tape_.emplace(x, null);
}
- public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes)
+ public bool ShouldRecord(Tensor[] tensors)
{
- for (int i = 0; i < tensor_ids.Length; ++i)
+ var dtypes = tensors.Select(x => x.dtype).ToArray();
+ for (int i = 0; i < tensors.Length; ++i)
{
- if (tensor_tape_.find(tensor_ids[i]))
+ if (tensor_tape_.find(tensors[i]))
{
if (IsDtypeTrainable(dtypes[i]))
return true;
@@ -60,18 +59,9 @@ namespace Tensorflow.Gradients
return false;
}
- ///
- /// Pops the given tape in the stack.
- ///
- ///
- public void PopTape(ITape tape)
- {
- tf.GetTapeSet().Remove(tape);
- }
-
public void VariableAccessed(ResourceVariable variable)
{
- Watch(variable.Handle.Id);
+ Watch(variable.Handle);
}
public ResourceVariable[] WatchedVariables()
@@ -97,17 +87,29 @@ namespace Tensorflow.Gradients
}
}
- bool CouldForwardprop()
- => HasAccumulator();
+ public void StartRecord()
+ {
+ if (_recording)
+ throw new ValueError("Tape is still recording, This can happen if you try to " +
+ "re-enter an already-active tape.");
+ _recording = true;
+ }
- bool CouldBackprop()
- => HasGradientTape();
+ public void StopRecord()
+ {
+ if (!_recording)
+ throw new ValueError("Tape is not recording.");
+ if (_created_eagerly)
+ tf.Context.end_step();
+ _recording = false;
+ }
- bool HasAccumulator()
- //return !GetAccumulatorSet()->empty();
- => false;
+ public void SetTapeId(int id)
+ {
+ _id = id;
+ }
- bool HasGradientTape()
- => tf.GetTapeSet().Count > 0;
+ public override string ToString()
+ => $"Tape {_id} {(_recording ? "Recording" : "Stopped")}";
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs
index fe24e1d1..210794d8 100644
--- a/src/TensorFlowNET.Core/Gradients/TapeTensor.cs
+++ b/src/TensorFlowNET.Core/Gradients/TapeTensor.cs
@@ -4,18 +4,18 @@ namespace Tensorflow.Gradients
{
public class TapeTensor
{
- long id;
- TF_DataType dtype;
- Shape shape;
+ Tensor tensor;
+ long id => tensor.Id;
+ TF_DataType dtype => tensor.dtype;
+ Shape shape => tensor.shape;
- public TapeTensor(long id, TF_DataType dtype, Shape shape)
+ public TapeTensor(Tensor tensor)
{
- this.id = id;
- this.dtype = dtype;
- this.shape = shape;
+ this.tensor = tensor;
}
- public long GetID() => id;
+ public long GetID() => tensor.Id;
+ public Tensor GetTensor() => tensor;
public Tensor ZerosLike()
=> tf.zeros(shape: shape, dtype: dtype);
diff --git a/src/TensorFlowNET.Core/Gradients/TensorTape.cs b/src/TensorFlowNET.Core/Gradients/TensorTape.cs
index c2760407..de478bee 100644
--- a/src/TensorFlowNET.Core/Gradients/TensorTape.cs
+++ b/src/TensorFlowNET.Core/Gradients/TensorTape.cs
@@ -3,11 +3,11 @@
namespace Tensorflow.Gradients
{
///
- /// Map from tensor_id to internally-defined operation-id of the operation which
+ /// Map from tensor to internally-defined operation-id of the operation which
/// produced this tensor. A value of -1 means that the tensor was directly
/// watched and not the result of any operation in the tape.
///
- public class TensorTape : UnorderedMap
+ public class TensorTape : UnorderedMap
{
}
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
index 771887be..40a83493 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs
@@ -543,7 +543,7 @@ namespace Tensorflow
{
if (_IsBackpropagatable(output))
{
- var c = _Consumers(output, func_graphs).ToList();
+ var c = output.consumers().ToList();
c.ForEach(x => queue.Enqueue(x));
}
}
@@ -551,16 +551,6 @@ namespace Tensorflow
}
}
- ///
- /// Returns the consumers of t, crossing closure boundaries where necessary.
- ///
- ///
- ///
- private static Operation[] _Consumers(Tensor t, List func_graphs)
- {
- return t.consumers();
- }
-
private static bool _IsBackpropagatable(Tensor tensor)
{
if (_IsTrainable(tensor))
diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
index fbab95a9..2d042a5d 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
@@ -12,6 +12,7 @@ namespace Tensorflow.NumPy
{
TF_DataType.TF_UINT8 => Scalar(*(byte*)nd.data),
TF_DataType.TF_FLOAT => Scalar(*(float*)nd.data),
+ TF_DataType.TF_INT32 => Scalar(*(int*)nd.data),
TF_DataType.TF_INT64 => Scalar(*(long*)nd.data),
_ => throw new NotImplementedException("")
};
@@ -34,6 +35,15 @@ namespace Tensorflow.NumPy
_ => throw new NotImplementedException("")
};
+ static T Scalar(int input)
+ => Type.GetTypeCode(typeof(T)) switch
+ {
+ TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
+ TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64),
+ TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
+ _ => throw new NotImplementedException("")
+ };
+
static T Scalar(long input)
=> Type.GetTypeCode(typeof(T)) switch
{
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
index 0b7bad8b..44ac52e1 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs
@@ -98,6 +98,7 @@ namespace Tensorflow
var handle = control_input_handle + Marshal.SizeOf() * i;
control_inputs[i] = new Operation(*(IntPtr*)handle);
}
+ Marshal.FreeHGlobal(control_input_handle);
}
return control_inputs;
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
index 2fd80fb3..b5d6191d 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
@@ -66,7 +66,7 @@ namespace Tensorflow
var inputptr = (TF_Input*)handle;
for (int i = 0; i < num; i++)
consumers[i] = *(inputptr + i);
-
+ Marshal.FreeHGlobal(handle);
return consumers;
}
@@ -83,6 +83,7 @@ namespace Tensorflow
var handle = control_output_handle + Marshal.SizeOf() * i;
control_outputs[i] = new Operation(*(IntPtr*)handle);
}
+ Marshal.FreeHGlobal(control_output_handle);
}
return control_outputs;
diff --git a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs
index f004bc54..4077efa9 100644
--- a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs
+++ b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs
@@ -36,7 +36,7 @@ namespace Tensorflow
consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper));
}
}
-
+ Marshal.FreeHGlobal(handle);
return consumers;
}
}
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index acaa6a1a..fd07cc3b 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -25,7 +25,7 @@ namespace Tensorflow
{
public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients);
- public partial class tensorflow : ITensorFlowObject
+ public partial class tensorflow
{
public TF_DataType byte8 = TF_DataType.TF_UINT8;
public TF_DataType int8 = TF_DataType.TF_INT8;
@@ -64,6 +64,7 @@ namespace Tensorflow
private void InitGradientEnvironment()
{
+ _tapeSet = new GradientTape();
ops.RegisterFromAssembly();
}
@@ -106,41 +107,5 @@ namespace Tensorflow
{
return new Session(null, config).as_default();
}
-
- List tape_set;
- public List GetTapeSet()
- {
- if (tape_set == null)
- {
- tape_set = new List();
- }
-
- return tape_set;
- }
-
- public void __init__()
- {
-
- }
-
- public void __enter__()
- {
-
- }
-
- public void __exit__()
- {
-
- }
-
- public void __del__()
- {
-
- }
-
- public void Dispose()
- {
-
- }
}
}
diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
index 29913ce4..e41e1d61 100644
--- a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
+++ b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
@@ -16,9 +16,9 @@ namespace TensorFlowNET.UnitTest.Gradient
{
// Calcute the gradient of w * w
// by Automatic Differentiation in Eager mode
- // in tensorflow.net 2.x that is in development intensively
var w = tf.constant(1.5f);
using var tape = tf.GradientTape();
+ // w is defined before the tape starts recording
tape.watch(w);
var loss = w * w;
var grad = tape.gradient(loss, w);
@@ -56,8 +56,6 @@ namespace TensorFlowNET.UnitTest.Gradient
}
}
-
- [Ignore]
[TestMethod]
public void SquaredDifference_1D()
{
@@ -66,14 +64,15 @@ namespace TensorFlowNET.UnitTest.Gradient
// Expected is 2*(abs(x1-x2))
Tensor x1 = new NDArray( new float[] { 1, 3, 5, 21, 19, 17 });
Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
- float[] expected = new float[] {
+ float[] expected = new float[]
+ {
(29-1) * 2,
(27-3) * 2,
(23-5) * 2,
(7-21) * 2,
(11-19) * 2,
(13-17) * 2
- };
+ };
// Sanity check
using (var tape = tf.GradientTape())
@@ -100,7 +99,7 @@ namespace TensorFlowNET.UnitTest.Gradient
///
- /// Calcute the gradient of w * w * w
+ /// Calculate the higher derivative gradient of w * w * w
/// 高阶梯度
///
[TestMethod]
@@ -110,10 +109,8 @@ namespace TensorFlowNET.UnitTest.Gradient
using var tape1 = tf.GradientTape();
using var tape2 = tf.GradientTape();
var y = x * x * x;
- tape2.Dispose();
var dy_dx = tape2.gradient(y, x);
Assert.AreEqual((float)dy_dx, 3.0f);
- tape1.Dispose();
var d2y_d2x = tape1.gradient(dy_dx, x);
Assert.AreEqual((float)d2y_d2x, 6.0f);
}
@@ -140,8 +137,6 @@ namespace TensorFlowNET.UnitTest.Gradient
tape.watch(x);
var y = tf.reduce_sum(x);
var z = tf.multiply(y, y);
- tape.Dispose();
-
var dz_dx = tape.gradient(z, x);
var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };