@@ -14,22 +14,32 @@
    limitations under the License.
 ******************************************************************************/

+using System.Collections.Generic;
 using Tensorflow.Gradients;

 namespace Tensorflow
 {
     public partial class tensorflow
     {
+        GradientTape _tapeSet;
+
         /// <summary>
         /// Record operations for automatic differentiation.
         /// </summary>
         /// <param name="persistent"></param>
         /// <param name="watch_accessed_variables"></param>
-        /// <returns></returns>
+        /// <returns>Tape set</returns>
         public GradientTape GradientTape(bool persistent = false,
             bool watch_accessed_variables = true)
-            => new GradientTape(persistent: persistent,
+        {
+            var tape = _tapeSet.PushTape(persistent: persistent,
                 watch_accessed_variables: watch_accessed_variables);
+            tape.StartRecord();
+            return _tapeSet;
+        }
+
+        public Stack<ITape> GetTapeSet()
+            => _tapeSet.GetTapeSet();

         public Tensor[] gradients(Tensor[] ys,
             Tensor[] xs,
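
The hunk above changes tf.GradientTape() from constructing a standalone tape to pushing a tape onto a shared tape set and starting recording immediately. For orientation, a minimal usage sketch that mirrors the unit tests at the end of this diff:

    using static Tensorflow.Binding;

    var w = tf.constant(1.5f);
    using var tape = tf.GradientTape();  // pushes a new tape and calls StartRecord()
    tape.watch(w);                       // watches w on the innermost tape
    var loss = w * w;
    var grad = tape.gradient(loss, w);   // d(w*w)/dw = 2w = 3.0f
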
@@ -4,7 +4,7 @@ namespace Tensorflow
 {
     public static partial class Binding
     {
-        public static tensorflow tf { get; } = New<tensorflow>();
+        public static tensorflow tf { get; } = new tensorflow();

         /// <summary>
         /// Alias to null, similar to python's None.
@@ -11,5 +11,19 @@ namespace Tensorflow.Eager
         {
             return HasGradientTape();
         }
+
+        private bool ShouldRecord(Tensor[] inputs)
+        {
+            bool should_record = false;
+            foreach (var tape in tf.GetTapeSet())
+            {
+                if (tape.ShouldRecord(inputs))
+                {
+                    should_record = true;
+                    break;
+                }
+            }
+            return should_record;
+        }
     }
 }
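
The foreach above short-circuits on the first active tape that claims any of the inputs. Assuming System.Linq is in scope, a behaviorally identical sketch would be:

    private bool ShouldRecord(Tensor[] inputs)
        => tf.GetTapeSet().Any(tape => tape.ShouldRecord(inputs));
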
@@ -2,7 +2,6 @@
 using System.Linq;
 using Tensorflow.Gradients;
 using static Tensorflow.Binding;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Eager
 {
@@ -14,18 +13,7 @@ namespace Tensorflow.Eager
             Tensor[] results,
             Func<BackwardFunction> getBackwardFunction = null)
         {
-            var input_ids = MakeTensorIDList(inputs);
-            var input_dtypes = MakeTensorDtypeList(inputs);
-
-            bool should_record = false;
-            foreach (var tape in tf.GetTapeSet())
-            {
-                if (tape.ShouldRecord(input_ids, input_dtypes))
-                {
-                    should_record = true;
-                    break;
-                }
-            }
+            bool should_record = ShouldRecord(inputs);

             if (!should_record)
             {
@@ -43,9 +31,6 @@ namespace Tensorflow.Eager
             tf.Logger.Debug($"RecordGradient: op_name={op_name}");

             Tensor[] op_outputs;
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
-            bool op_outputs_tuple_created = false;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
             var unused_output_indices = gradient_exclustions.OpGradientUnusedOutputIndices(op_name);
             if (unused_output_indices != null)
             {
@@ -53,7 +38,6 @@ namespace Tensorflow.Eager
                     op_outputs = new Tensor[0];
                 else
                 {
-                    op_outputs_tuple_created = true;
                     // op_outputs = CopySequenceSettingIndicesToNull(results, *unused_output_indices);
                 }
             }
@@ -61,9 +45,6 @@ namespace Tensorflow.Eager
                 op_outputs = results;

             Tensor[] op_inputs;
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
-            bool op_inputs_tuple_created = false;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
             var unused_input_indices = gradient_exclustions.OpGradientUnusedInputIndices(op_name);
             if (unused_input_indices != null)
             {
@@ -71,7 +52,6 @@ namespace Tensorflow.Eager
                     op_inputs = new Tensor[0];
                 else
                 {
-                    op_inputs_tuple_created = true;
                     // op_inputs = CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
                 }
             }
@@ -125,11 +105,6 @@ namespace Tensorflow.Eager
             return HasGradientTape();
         }

-        long[] MakeTensorIDList(Tensor[] tensors)
-        {
-            return tensors.Select(x => x.Id).ToArray();
-        }
-
         TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
         {
             return tensors.Select(x => x.dtype).ToArray();
@@ -310,7 +310,7 @@ namespace Tensorflow.Eager
             for (int i = 0; i < num_values; ++i)
             {
                 dims[i] = Marshal.AllocHGlobal(sizeof(long) * values1[i].ndim);
-                tf.memcpy(dims[i], values1[i].dims.Select(x => (long)x).ToArray(), values1[i].ndim * sizeof(long));
+                tf.memcpy(dims[i], values1[i].dims, values1[i].ndim * sizeof(long));
             }

             c_api.TFE_OpSetAttrShapeList(op, key, dims, num_dims, num_values, status.Handle);
@@ -14,18 +14,16 @@ namespace Tensorflow.Eager
             Tensor[] sources,
             Tensor[] output_gradients)
         {
-            var target_vec = MakeTensorIDList(target);
-            var sources_vec = MakeTensorIDList(sources);
+            var target_vec = target;
+            var sources_vec = sources;
             var sources_set = sources_vec;

             var seq_array = target;
-            var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();
+            var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();
             for (int i = 0; i < target.Length; ++i)
             {
-                var target_id = target_vec[i];
-                var tensor = seq_array[i];
-                source_tensors_that_are_targets.Add(target_id, TapeTensorFromTensor(tensor));
+                source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
             }

             if (output_gradients != null)
@@ -1,7 +1,7 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using Tensorflow.Gradients;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Eager
 {
@@ -12,16 +12,13 @@ namespace Tensorflow.Eager
             Tensor[] output_tensors,
             Func<BackwardFunction> backward_function_getter)
         {
-            var output_info = new List<TapeTensor>();
-
-            if (!TapeTensorsFromTensorSequence(output_tensors, output_info))
-                return false;
+            var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();

-            if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(),
+            if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
                 backward_function_getter))
                 return false;

-            TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(),
+            TapeSetRecordBackprop(op_type, input_tensors, output_info,
                 backward_function_getter);

             return true;
@@ -1,12 +0,0 @@
-using Tensorflow.Gradients;
-
-namespace Tensorflow.Eager
-{
-    public partial class EagerRunner
-    {
-        TapeTensor TapeTensorFromTensor(Tensor tensor)
-        {
-            return new TapeTensor(tensor.Id, tensor.dtype, tensor.shape);
-        }
-    }
-}
@@ -1,18 +0,0 @@
-using System.Collections.Generic;
-using Tensorflow.Gradients;
-
-namespace Tensorflow.Eager
-{
-    public partial class EagerRunner
-    {
-        bool TapeTensorsFromTensorSequence(Tensor[] output_seq,
-            List<TapeTensor> output_info)
-        {
-            for (var i = 0; i < output_seq.Length; ++i)
-            {
-                output_info.Add(TapeTensorFromTensor(output_seq[i]));
-            }
-
-            return true;
-        }
-    }
-}
@@ -7,21 +7,21 @@ namespace Tensorflow.Gradients
     {
         public OpTape<BackwardFunction, TapeTensor> op_tape { get; set; }

         /// <summary>
-        /// Map from tensor ID to how many references still exist for this tensor in
+        /// Map from tensor to how many references still exist for this tensor in
         /// the tape.
         /// </summary>
-        public UnorderedMap<long, long> tensor_usage_counts { get; set; }
+        public UnorderedMap<Tensor, long> tensor_usage_counts { get; set; }

         /// <summary>
         /// Maps from op ID to how many output tensors of this op still need to have
         /// their gradients computed.
         /// </summary>
-        public UnorderedMap<long, long> op_missing_tensor { get; set; }
+        public UnorderedMap<Tensor, long> op_missing_tensor { get; set; }

         public BackpropInitialState()
         {
             op_tape = new OpTape<BackwardFunction, TapeTensor>();
-            tensor_usage_counts = new UnorderedMap<long, long>();
-            op_missing_tensor = new UnorderedMap<long, long>();
+            tensor_usage_counts = new UnorderedMap<Tensor, long>();
+            op_missing_tensor = new UnorderedMap<Tensor, long>();
         }
     }
 }
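
Keying these maps on Tensor rather than a long id assumes Tensor equality and hashing are stable for the lifetime of the tape (reference identity unless Tensor overrides Equals/GetHashCode). The change also leans on UnorderedMap yielding default(long) for a missing key, as the usage-count increments later in this diff do; a sketch:

    var usage = new UnorderedMap<Tensor, long>();
    var t = tf.constant(1.0f);
    usage[t]++;  // 0 -> 1: a missing key reads as default(long)
    usage[t]++;  // 1 -> 2
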
@@ -6,6 +6,7 @@ using static Tensorflow.Binding;
 namespace Tensorflow.Gradients
 {
     /// <summary>
+    /// Gradient Tape Set
     /// Record operations for automatic differentiation.
     ///
     /// Operations are recorded if they are executed within this context manager and
@@ -18,54 +19,35 @@
     /// </summary>
     public class GradientTape : IDisposable
     {
-        bool _recording;
-        public bool Recording => _recording;
-        bool _persistent;
-        bool _watch_accessed_variables;
-        ResourceVariable[] _watched_variables;
-        bool _created_eagerly;
-        ITape _tape;
-
-        public GradientTape(bool persistent = false,
-            bool watch_accessed_variables = true)
+        int _nextTapeId;
+        ITape _tape => _tapeSet.Peek();
+        Stack<ITape> _tapeSet;
+
+        public GradientTape()
         {
-            _persistent = persistent;
-            _watch_accessed_variables = watch_accessed_variables;
-            _created_eagerly = tf.Context.executing_eagerly();
-            _recording = false;
-            _created_eagerly = tf.Context.executing_eagerly();
-            // Enters a context inside which operations are recorded on this tape.
-            if (_created_eagerly)
-            {
-                tf.Context.ensure_initialized();
-                tf.Context.start_step();
-            }
-            _push_tape();
+            _tapeSet = new Stack<ITape>();
         }

         /// <summary>
-        /// Pushes a new tape onto the tape stack.
+        /// New tape onto the tape stack.
         /// </summary>
-        private void _push_tape()
+        public ITape PushTape(bool persistent = false,
+            bool watch_accessed_variables = true)
         {
-            if (_recording)
-                throw new ValueError("Tape is still recording, This can happen if you try to " +
-                    "re-enter an already-active tape.");
-            if (_tape == null)
-                _tape = new Tape(_persistent, _watch_accessed_variables);
-            else
-                tf.GetTapeSet().Add(_tape);
-            // Enters a context inside which operations are recorded on this tape.
-            if (tf.Context.executing_eagerly())
-                tf.Context.ensure_initialized();
-            _recording = true;
+            var tape = new Tape(persistent, watch_accessed_variables);
+            tape.SetTapeId(_nextTapeId++);
+            _tapeSet.Push(tape);
+            return tape;
        }

-        private void _pop_tape()
+        ITape PopTape()
         {
-            if (!_recording)
-                throw new ValueError("Tape is not recording.");
-            _tape.PopTape(_tape);
-            _recording = false;
+            _tape.StopRecord();
+            return _tapeSet.Pop();
         }

        /// <summary>
@@ -74,7 +56,9 @@
         /// <param name="x"></param>
         public void watch(Tensor x)
         {
-            _tape.Watch(x.Id);
+            if (!_tapeSet.Any())
+                return;
+            _tape.Watch(x);
         }

         /// <summary>
@@ -85,13 +69,9 @@
         /// <returns></returns>
         public Tensor gradient(Tensor target, Tensor source)
         {
-            if (_recording)
-            {
-                if (!_persistent)
-                    _pop_tape();
-            }
+            ITape tape = stop_recording();

-            var results = tf.Runner.TFE_TapeGradient(_tape,
+            var results = tf.Runner.TFE_TapeGradient(tape,
                 new[] { target },
                 new[] { source },
                 null);
@@ -115,22 +95,17 @@
         public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
         {
-            if (_recording)
-            {
-                if (!_persistent)
-                    _pop_tape();
-            }
+            var tape = stop_recording();

-            var results = tf.Runner.TFE_TapeGradient(_tape,
+            var results = tf.Runner.TFE_TapeGradient(tape,
                 new[] { target },
                 sources.Select(x => x.Handle).ToArray(),
                 null);

-            if (!_persistent)
+            if (!tape.Persistent)
             {
                 // Keep track of watched variables before setting tape to None
-                _watched_variables = _tape.WatchedVariables();
-                _tape = null;
+                // _watched_variables = _tape.WatchedVariables();
             }

             return results;
@@ -139,18 +114,20 @@
         /// <summary>
         /// Temporarily stops recording operations on this tape.
         /// </summary>
-        public void stop_recording()
+        public ITape stop_recording()
         {
-            _pop_tape();
+            var tape = _tape;
+            if (!tape.Persistent)
+                tape = PopTape();
+            return tape;
         }

+        public Stack<ITape> GetTapeSet()
+            => _tapeSet;
+
         public void Dispose()
         {
-            if (_recording)
-                _pop_tape();
-            if (_created_eagerly)
-                tf.Context.end_step();
+            _tapeSet.Clear();
         }
     }
 }
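
Nested tapes now map directly onto the stack: every tf.GradientTape() call pushes another tape, watch and gradient address the innermost one, and gradient pops a non-persistent tape via stop_recording. A higher-order sketch of these semantics, assuming x is a watched variable as in the HighGradient test near the end of this diff:

    var x = tf.Variable(1.0f);              // assumption: watched via watch_accessed_variables
    using var tape1 = tf.GradientTape();    // outer tape
    using var tape2 = tf.GradientTape();    // inner tape, top of the stack
    var y = x * x * x;
    var dy_dx = tape2.gradient(y, x);       // pops tape2; 3x^2 = 3.0f
    var d2y_dx2 = tape1.gradient(dy_dx, x); // pops tape1; 6x = 6.0f
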
@@ -1,15 +1,15 @@
 using System;
 using Tensorflow.Util;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Gradients
 {
     public interface ITape
     {
-        void PopTape(ITape tape);
-        bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes);
+        void SetTapeId(int id);
+        bool ShouldRecord(Tensor[] tensors);
+        void StartRecord();
+        void StopRecord();
+        bool Persistent { get; }
         void RecordOperation(string op_type,
             Tensor[] input_tensors,
             TapeTensor[] output_tensors,
@@ -17,13 +17,13 @@ namespace Tensorflow.Gradients
         void VariableAccessed(ResourceVariable variable);
-        void Watch(long tensor_id);
+        void Watch(Tensor x);
         ResourceVariable[] WatchedVariables();
-        Tensor[] ComputeGradient(long[] target_tensor_ids,
-            long[] source_tensor_ids,
-            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+        Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
+            Tensor[] source_tensor_ids,
+            UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
             Tensor[] output_gradients);
     }
 }
@@ -8,7 +8,7 @@ namespace Tensorflow.Gradients
     /// <typeparam name="BackwardFunction"></typeparam>
     /// <typeparam name="TapeTensor"></typeparam>
     public class OpTape<BackwardFunction, TapeTensor> :
-        UnorderedMap<long, OpTapeEntry<BackwardFunction, TapeTensor>>
+        UnorderedMap<Tensor, OpTapeEntry<BackwardFunction, TapeTensor>>
     {
     }
@@ -1,4 +1,6 @@
-namespace Tensorflow.Gradients
+using System.Linq;
+
+namespace Tensorflow.Gradients
 {
     /// <summary>
     /// Represents an entry in the tape.
@@ -9,9 +11,9 @@
     {
         public string op_type { get; set; }
         public TapeTensor[] output_tensor_info { get; set; }
-        public long[] input_tensor_id { get; set; }
+        public Tensor[] input_tensor_id { get; set; }
         public BackwardFunction backward_function { get; set; }

         public override string ToString()
-            => $"{op_type}, inputs: {string.Join(",", input_tensor_id)}";
+            => $"{op_type}, inputs: {string.Join(",", input_tensor_id.Select(x => x.Id))}";
     }
 }
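
With input_tensor_id now a Tensor[], ToString projects each input back to its numeric Id so log output keeps its old shape. An illustrative sketch (the printed Ids depend on allocation order):

    var x = tf.constant(1.0f);
    var y = tf.constant(2.0f);
    var entry = new OpTapeEntry<BackwardFunction, TapeTensor>
    {
        op_type = "Mul",
        input_tensor_id = new[] { x, y }
    };
    Console.WriteLine(entry);  // e.g. "Mul, inputs: 3,4"
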
@@ -11,17 +11,17 @@ namespace Tensorflow.Gradients
         int kMinAggregateCount = 4;
         int kMinAggregateBytes = 128 * 1024 * 1024;

-        public Tensor[] ComputeGradient(long[] target_tensor_ids,
-            long[] source_tensor_ids,
-            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+        public Tensor[] ComputeGradient(Tensor[] target_tensor_ids,
+            Tensor[] source_tensor_ids,
+            UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
             Tensor[] output_gradients)
         {
             var result = new List<Tensor>(source_tensor_ids.Length);
-            var sources_set = new UnorderedSet<long>(source_tensor_ids);
-            var gradients_size = new UnorderedMap<long, long>();
+            var sources_set = new UnorderedSet<Tensor>(source_tensor_ids);
+            var gradients_size = new UnorderedMap<Tensor, long>();
             var state = PrepareBackprop(
-                target_tensor_ids, tensor_tape_, op_tape_, sources_set, persistent_);
+                target_tensor_ids, tensor_tape_, op_tape_, sources_set, _persistent);
             var op_stack = InitialStack(state.op_tape, state.op_missing_tensor);
             var gradients = InitialGradients(target_tensor_ids, sources_that_are_targets,
                 output_gradients,
@@ -51,7 +51,7 @@ namespace Tensorflow.Gradients
                 var zero_indices = new List<int>();
                 for (int i = 0; i < trace.output_tensor_info.Length; ++i)
                 {
-                    var id = trace.output_tensor_info[i].GetID();
+                    var id = trace.output_tensor_info[i].GetTensor();
                     if (!gradients.find(id, out var grad_it))
                     {
                         if (FunctionsAcceptingNoneForIndicesMap().find(trace.op_type, out var func_name_it) &&
@@ -96,7 +96,7 @@ namespace Tensorflow.Gradients
                 if (in_gradients.Count() != trace.input_tensor_id.Count())
                     throw new RuntimeError($"Recorded operation '{trace.op_type}' returned too few gradients. Expected {trace.input_tensor_id.Length} but received {in_gradients.Count()}");

-                if (!persistent_)
+                if (!_persistent)
                 {
                     // trace.backward_function_deleter(trace.backward_function);
                 }
@@ -147,7 +147,7 @@ namespace Tensorflow.Gradients
                     }

                     var op_id = tape_it;
-                    if (op_id == -1)
+                    if (op_id == null)
                         continue;

                     if (state.op_missing_tensor.find(op_id, out var missing_it))
@@ -162,7 +162,7 @@ namespace Tensorflow.Gradients
             if (state.op_tape.Count > 0)
                 throw new RuntimeError("Invalid tape state.");

-            var used_gradient_ids = new List<long>(source_tensor_ids.Length);
+            var used_gradient_ids = new List<Tensor>(source_tensor_ids.Length);
             foreach (var id in source_tensor_ids)
             {
                 if (!gradients.find(id, out var grad_it))
@@ -203,19 +203,19 @@ namespace Tensorflow.Gradients
             return m;
         }

-        UnorderedMapEnumerable<long, List<Tensor>> InitialGradients(long[] target_tensor_ids,
-            UnorderedMap<long, TapeTensor> sources_that_are_targets,
+        UnorderedMapEnumerable<Tensor, List<Tensor>> InitialGradients(Tensor[] target_tensor_ids,
+            UnorderedMap<Tensor, TapeTensor> sources_that_are_targets,
             Tensor[] output_gradients,
             TensorTape tensor_tape,
             OpTape<BackwardFunction, TapeTensor> op_tape)
         {
-            var result = new UnorderedMapEnumerable<long, List<Tensor>>();
+            var result = new UnorderedMapEnumerable<Tensor, List<Tensor>>();
             for (int i = 0; i < target_tensor_ids.Length; ++i)
             {
                 var id = target_tensor_ids[i];
                 if (output_gradients.Length == 0 || output_gradients[i] == null)
                 {
-                    if (tensor_tape.find(id, out var tensor_id) && tensor_id != -1)
+                    if (tensor_tape.find(id, out var tensor_id) && tensor_id != null)
                     {
                         if (!op_tape.find(tensor_tape[id], out var op_it))
                             throw new RuntimeError("Internal state of the gradient tape is invalid: " +
@@ -223,7 +223,7 @@ namespace Tensorflow.Gradients
                         bool found = false;
                         for (int j = 0; j < op_it.output_tensor_info.Length; ++j)
                         {
-                            if (op_it.output_tensor_info[j].GetID() == id)
+                            if (op_it.output_tensor_info[j].GetTensor() == id)
                             {
                                 found = true;
                                 var ones = op_it.output_tensor_info[j].OnesLike();
@@ -253,10 +253,10 @@ namespace Tensorflow.Gradients
             return result;
         }

-        Queue<long> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape,
-            UnorderedMap<long, long> op_missing_tensor)
+        Queue<Tensor> InitialStack(OpTape<BackwardFunction, TapeTensor> op_tape,
+            UnorderedMap<Tensor, long> op_missing_tensor)
         {
-            var result = new Queue<long>();
+            var result = new Queue<Tensor>();
             foreach (var op_entry in op_tape)
             {
                 if (!op_missing_tensor.find(op_entry.Key))
@@ -6,14 +6,14 @@ namespace Tensorflow.Gradients
 {
     public partial class Tape
     {
-        public BackpropInitialState PrepareBackprop(long[] target,
+        public BackpropInitialState PrepareBackprop(Tensor[] target,
             TensorTape tensor_tape,
             OpTape<BackwardFunction, TapeTensor> op_tape,
-            UnorderedSet<long> sources_set,
+            UnorderedSet<Tensor> sources_set,
             bool persistent_tape)
         {
             BackpropInitialState result = new BackpropInitialState();
-            var tensor_stack = new Queue<long>(target);
+            var tensor_stack = new Queue<Tensor>(target);
             while (tensor_stack.Count > 0)
             {
                 var tensor_id = tensor_stack.Dequeue();
@@ -21,7 +21,7 @@ namespace Tensorflow.Gradients
                 if (!tensor_tape.find(tensor_id, out var op_id))
                     continue;

-                if (op_id == -1 ||
+                if (op_id == null ||
                     !op_tape.find(op_id, out var op_it) ||
                     result.op_tape.find(op_id, out var result_op_it))
                     continue;
@@ -46,7 +46,7 @@ namespace Tensorflow.Gradients
             foreach (var pair in result.tensor_usage_counts)
             {
-                if (tensor_tape.find(pair.Key, out var it) && it != -1)
+                if (tensor_tape.find(pair.Key, out var it) && it != null)
                     result.op_missing_tensor[it] += 1;
             }
@@ -4,49 +4,39 @@ using Tensorflow.Util;
 using static Tensorflow.tensorflow;
 using static Tensorflow.Binding;
 using System.Linq;
+using Tensorflow.Eager;

 namespace Tensorflow.Gradients
 {
     public partial class Tape
     {
         long next_op_id_ = 0;
-        UnorderedMap<long, long> tensor_usage_;
+        UnorderedMap<Tensor, long> tensor_usage_;

         public void RecordOperation(string op_type,
             Tensor[] input_tensors,
             TapeTensor[] output_tensors,
             Func<BackwardFunction> backward_function_getter)
         {
-            var input_ids = input_tensors.Select(x => x.Id).ToArray();
-            var input_dtypes = input_tensors.Select(x => x.dtype).ToArray();
-            if (!ShouldRecord(input_ids, input_dtypes))
-            {
+            if (!ShouldRecord(input_tensors))
                 return;
-            }

-            long op_id = next_op_id_++;
-            var ids = new List<long>(input_ids.Length);
-            foreach (var i in input_ids)
-            {
+            var op_id = new EagerTensor(next_op_id_++);
+            foreach (var i in input_tensors)
                 tensor_usage_[i]++;
-                ids.Add(i);
-            }

-            var tensors = new List<TapeTensor>(output_tensors.Length);
             foreach (var o in output_tensors)
             {
-                tensor_tape_[o.GetID()] = op_id;
                 tf.Logger.Debug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
-                tensor_usage_[o.GetID()] = 1;
-                tensors.Add(o);
+                tensor_tape_[o.GetTensor()] = op_id;
+                tensor_usage_[o.GetTensor()] = 1;
             }

             op_tape_[op_id] = new OpTapeEntry<BackwardFunction, TapeTensor>
             {
                 op_type = op_type,
-                output_tensor_info = tensors.ToArray(),
-                input_tensor_id = ids.ToArray(),
+                output_tensor_info = output_tensors,
+                input_tensor_id = input_tensors,
                 backward_function = backward_function_getter()
             };
         }
@@ -1,57 +1,56 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using Tensorflow.Util;
 using static Tensorflow.Binding;
-using static Tensorflow.tensorflow;

 namespace Tensorflow.Gradients
 {
     public partial class Tape : ITape
     {
-        int nesting_id;
-        static int tape_nesting_id_counter = 0;
-        bool persistent_;
-        bool watch_accessed_variables;
+        int _id;
+        // static int tape_nesting_id_counter = 0;
+        bool _persistent;
+        public bool Persistent => _persistent;
+        bool _recording;
+        bool _created_eagerly;
         TensorTape tensor_tape_;
         OpTape<BackwardFunction, TapeTensor> op_tape_;

         /// <summary>
         /// A deque-backed stack, whose element references are not invalidated by
         /// pushes and pops at the back.
         /// </summary>
-        Stack<AccumulatorCallState> call_state_;
+        // Stack<AccumulatorCallState> call_state_;

         public Tape(bool persistent, bool watch_accessed_variables)
         {
-            this.persistent_ = persistent;
-            this.watch_accessed_variables = watch_accessed_variables;
+            _persistent = persistent;
+            _created_eagerly = tf.Context.executing_eagerly();
             tensor_tape_ = new TensorTape();
             op_tape_ = new OpTape<BackwardFunction, TapeTensor>();
-            tensor_usage_ = new UnorderedMap<long, long>();
-            nesting_id = ++tape_nesting_id_counter;
-            tf.GetTapeSet().Add(this);
+            tensor_usage_ = new UnorderedMap<Tensor, long>();
+            if (_created_eagerly)
+                tf.Context.start_step();
+            // nesting_id = ++tape_nesting_id_counter;
         }

         /// <summary>
         /// Marks this tensor to be watched by the given tape.
         /// </summary>
         /// <param name="x"></param>
-        public void Watch(long tensor_id)
+        public void Watch(Tensor x)
         {
-            if (!CouldBackprop())
-                return;
-
-            tf.Logger.Debug($"Watch tensor_id={tensor_id}");
-            tensor_tape_.emplace(tensor_id, -1);
+            tf.Logger.Debug($"Watch tensor id={x.Id}, name={x.name}");
+            tensor_tape_.emplace(x, null);
         }

-        public bool ShouldRecord(long[] tensor_ids, TF_DataType[] dtypes)
+        public bool ShouldRecord(Tensor[] tensors)
         {
-            for (int i = 0; i < tensor_ids.Length; ++i)
+            var dtypes = tensors.Select(x => x.dtype).ToArray();
+            for (int i = 0; i < tensors.Length; ++i)
             {
-                if (tensor_tape_.find(tensor_ids[i]))
+                if (tensor_tape_.find(tensors[i]))
                 {
                     if (IsDtypeTrainable(dtypes[i]))
                         return true;
@@ -60,18 +59,9 @@ namespace Tensorflow.Gradients
             return false;
         }

-        /// <summary>
-        /// Pops the given tape in the stack.
-        /// </summary>
-        /// <param name="tape"></param>
-        public void PopTape(ITape tape)
-        {
-            tf.GetTapeSet().Remove(tape);
-        }
-
         public void VariableAccessed(ResourceVariable variable)
         {
-            Watch(variable.Handle.Id);
+            Watch(variable.Handle);
         }

         public ResourceVariable[] WatchedVariables()
@@ -97,17 +87,29 @@ namespace Tensorflow.Gradients
             }
         }

-        bool CouldForwardprop()
-            => HasAccumulator();
+        public void StartRecord()
+        {
+            if (_recording)
+                throw new ValueError("Tape is still recording, This can happen if you try to " +
+                    "re-enter an already-active tape.");
+            _recording = true;
+        }

-        bool CouldBackprop()
-            => HasGradientTape();
+        public void StopRecord()
+        {
+            if (!_recording)
+                throw new ValueError("Tape is not recording.");
+            if (_created_eagerly)
+                tf.Context.end_step();
+            _recording = false;
+        }

-        bool HasAccumulator()
-            //return !GetAccumulatorSet()->empty();
-            => false;
+        public void SetTapeId(int id)
+        {
+            _id = id;
+        }

-        bool HasGradientTape()
-            => tf.GetTapeSet().Count > 0;
+        public override string ToString()
+            => $"Tape {_id} {(_recording ? "Recording" : "Stopped")}";
     }
 }
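
The recording state that used to live on GradientTape is now owned by each Tape and guarded by StartRecord/StopRecord. A lifecycle sketch implied by the guards above:

    var tape = new Tape(persistent: false, watch_accessed_variables: true);
    tape.StartRecord();
    // tape.StartRecord();  // would throw ValueError: tape is still recording
    tape.StopRecord();      // also ends the eager step if one was started
    // tape.StopRecord();   // would throw ValueError: tape is not recording
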
@@ -4,18 +4,18 @@ namespace Tensorflow.Gradients
 {
     public class TapeTensor
     {
-        long id;
-        TF_DataType dtype;
-        Shape shape;
+        Tensor tensor;
+        long id => tensor.Id;
+        TF_DataType dtype => tensor.dtype;
+        Shape shape => tensor.shape;

-        public TapeTensor(long id, TF_DataType dtype, Shape shape)
+        public TapeTensor(Tensor tensor)
         {
-            this.id = id;
-            this.dtype = dtype;
-            this.shape = shape;
+            this.tensor = tensor;
         }

-        public long GetID() => id;
+        public long GetID() => tensor.Id;
+        public Tensor GetTensor() => tensor;

         public Tensor ZerosLike()
             => tf.zeros(shape: shape, dtype: dtype);
@@ -3,11 +3,11 @@
 namespace Tensorflow.Gradients
 {
     /// <summary>
-    /// Map from tensor_id to internally-defined operation-id of the operation which
+    /// Map from tensor to internally-defined operation-id of the operation which
     /// produced this tensor. A value of -1 means that the tensor was directly
     /// watched and not the result of any operation in the tape.
     /// </summary>
-    public class TensorTape : UnorderedMap<long, long>
+    public class TensorTape : UnorderedMap<Tensor, Tensor>
     {
     }
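
One caveat: the surviving sentence "A value of -1 means ..." is now stale. With Tensor values, the directly-watched sentinel is null, as Tape.Watch (tensor_tape_.emplace(x, null)) and the op_id == null checks in PrepareBackprop above show. Membership tests therefore become null tests, e.g.:

    if (tensor_tape.find(t, out var op_id) && op_id == null)
    {
        // t was watched directly, not produced by any recorded op
    }
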
@@ -543,7 +543,7 @@ namespace Tensorflow
                 {
                     if (_IsBackpropagatable(output))
                     {
-                        var c = _Consumers(output, func_graphs).ToList();
+                        var c = output.consumers().ToList();
                         c.ForEach(x => queue.Enqueue(x));
                     }
                 }
@@ -551,16 +551,6 @@ namespace Tensorflow
             }
         }

-        /// <summary>
-        /// Returns the consumers of t, crossing closure boundaries where necessary.
-        /// </summary>
-        /// <param name="t"></param>
-        /// <param name="func_graphs"></param>
-        private static Operation[] _Consumers(Tensor t, List<FuncGraph> func_graphs)
-        {
-            return t.consumers();
-        }
-
         private static bool _IsBackpropagatable(Tensor tensor)
         {
             if (_IsTrainable(tensor))
@@ -12,6 +12,7 @@ namespace Tensorflow.NumPy
             {
                 TF_DataType.TF_UINT8 => Scalar<T>(*(byte*)nd.data),
                 TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
+                TF_DataType.TF_INT32 => Scalar<T>(*(int*)nd.data),
                 TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data),
                 _ => throw new NotImplementedException("")
             };
@@ -34,6 +35,15 @@ namespace Tensorflow.NumPy
                 _ => throw new NotImplementedException("")
             };

+        static T Scalar<T>(int input)
+            => Type.GetTypeCode(typeof(T)) switch
+            {
+                TypeCode.Byte => (T)Convert.ChangeType(input, TypeCode.Byte),
+                TypeCode.Int64 => (T)Convert.ChangeType(input, TypeCode.Int64),
+                TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
+                _ => throw new NotImplementedException("")
+            };
+
         static T Scalar<T>(long input)
             => Type.GetTypeCode(typeof(T)) switch
             {
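
The TF_INT32 branch plus the Scalar<T>(int) overload close the gap where an int32 scalar NDArray previously fell through to NotImplementedException. A sketch, assuming NDArray's scalar cast operators route through these helpers:

    var nd = np.array(42);    // int32 scalar NDArray
    var asLong = (long)nd;    // TypeCode.Int64 branch
    var asFloat = (float)nd;  // TypeCode.Single branch
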
@@ -98,6 +98,7 @@ namespace Tensorflow
                 var handle = control_input_handle + Marshal.SizeOf<IntPtr>() * i;
                 control_inputs[i] = new Operation(*(IntPtr*)handle);
             }
+            Marshal.FreeHGlobal(control_input_handle);
         }

         return control_inputs;
@@ -66,7 +66,7 @@ namespace Tensorflow
             var inputptr = (TF_Input*)handle;
             for (int i = 0; i < num; i++)
                 consumers[i] = *(inputptr + i);
+            Marshal.FreeHGlobal(handle);

         return consumers;
     }
@@ -83,6 +83,7 @@ namespace Tensorflow
                 var handle = control_output_handle + Marshal.SizeOf<IntPtr>() * i;
                 control_outputs[i] = new Operation(*(IntPtr*)handle);
             }
+            Marshal.FreeHGlobal(control_output_handle);
         }

         return control_outputs;
@@ -36,7 +36,7 @@ namespace Tensorflow
                 consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(oper));
             }
         }
+        Marshal.FreeHGlobal(handle);

         return consumers;
     }
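
All four hunks pair an earlier Marshal.AllocHGlobal with a Marshal.FreeHGlobal once the unmanaged buffer has been copied into managed arrays, plugging native memory leaks. The hunks free unconditionally after the copy; a defensive try/finally variant of the same pattern, as a minimal sketch:

    using System.Runtime.InteropServices;

    var buffer = Marshal.AllocHGlobal(sizeof(long) * 4);
    try
    {
        // ... let the native API fill the buffer, then copy it out ...
    }
    finally
    {
        Marshal.FreeHGlobal(buffer);  // every AllocHGlobal needs exactly one matching free
    }
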
@@ -25,7 +25,7 @@ namespace Tensorflow
 {
     public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients);

-    public partial class tensorflow : ITensorFlowObject
+    public partial class tensorflow
     {
         public TF_DataType byte8 = TF_DataType.TF_UINT8;
         public TF_DataType int8 = TF_DataType.TF_INT8;
@@ -64,6 +64,7 @@ namespace Tensorflow
         private void InitGradientEnvironment()
         {
+            _tapeSet = new GradientTape();
             ops.RegisterFromAssembly();
         }
@@ -106,41 +107,5 @@ namespace Tensorflow
         {
             return new Session(null, config).as_default();
         }
-
-        List<ITape> tape_set;
-        public List<ITape> GetTapeSet()
-        {
-            if (tape_set == null)
-            {
-                tape_set = new List<ITape>();
-            }
-            return tape_set;
-        }
-
-        public void __init__()
-        {
-        }
-
-        public void __enter__()
-        {
-        }
-
-        public void __exit__()
-        {
-        }
-
-        public void __del__()
-        {
-        }
-
-        public void Dispose()
-        {
-        }
     }
 }
@@ -16,9 +16,9 @@ namespace TensorFlowNET.UnitTest.Gradient
         {
             // Calculate the gradient of w * w
            // by Automatic Differentiation in Eager mode
-            // in tensorflow.net 2.x that is in development intensively
             var w = tf.constant(1.5f);
             using var tape = tf.GradientTape();
+            // w is defined before tape is recording
             tape.watch(w);
             var loss = w * w;
             var grad = tape.gradient(loss, w);
@@ -56,8 +56,6 @@ namespace TensorFlowNET.UnitTest.Gradient
             }
         }

-        [Ignore]
         [TestMethod]
         public void SquaredDifference_1D()
         {
@@ -66,14 +64,15 @@ namespace TensorFlowNET.UnitTest.Gradient
             // Expected is 2*(abs(x1-x2))
             Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
             Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
-            float[] expected = new float[] {
+            float[] expected = new float[]
+            {
                 (29-1) * 2,
                 (27-3) * 2,
                 (23-5) * 2,
                 (7-21) * 2,
                 (11-19) * 2,
                 (13-17) * 2
-                };
+            };

             // Sanity check
             using (var tape = tf.GradientTape())
@@ -100,7 +99,7 @@ namespace TensorFlowNET.UnitTest.Gradient
         /// <summary>
-        /// Calculate the gradient of w * w * w
+        /// Calculate the higher-order gradient of w * w * w
         /// 高阶梯度
         /// </summary>
         [TestMethod]
@@ -110,10 +109,8 @@ namespace TensorFlowNET.UnitTest.Gradient
             using var tape1 = tf.GradientTape();
             using var tape2 = tf.GradientTape();
             var y = x * x * x;
-            tape2.Dispose();
             var dy_dx = tape2.gradient(y, x);
             Assert.AreEqual((float)dy_dx, 3.0f);
-            tape1.Dispose();
             var d2y_d2x = tape1.gradient(dy_dx, x);
             Assert.AreEqual((float)d2y_d2x, 6.0f);
         }
@@ -140,8 +137,6 @@ namespace TensorFlowNET.UnitTest.Gradient
             tape.watch(x);
             var y = tf.reduce_sum(x);
             var z = tf.multiply(y, y);
-            tape.Dispose();
-
             var dz_dx = tape.gradient(z, x);
             var expected = new float[] { 8.0f, 8.0f, 8.0f, 8.0f };
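
With the new stop_recording semantics, a persistent tape stays on the stack after gradient() is called, which is why the explicit Dispose calls above could be removed. A sketch of repeated gradient queries on one tape (persistent is the parameter declared at the top of this diff):

    var x = tf.constant(3.0f);
    using var tape = tf.GradientTape(persistent: true);
    tape.watch(x);
    var y = x * x;
    var z = y * y;
    var dz_dx = tape.gradient(z, x);  // 4x^3 = 108.0f
    var dy_dx = tape.gradient(y, x);  // 2x = 6.0f, legal because the tape is persistent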