diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 43480a83..709b478d 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -18,6 +18,7 @@ using System; using System.Diagnostics; using System.Linq; using Tensorflow.Eager; +using static Tensorflow.Binding; namespace Tensorflow.Contexts { @@ -114,6 +115,36 @@ namespace Tensorflow.Contexts } } + [DebuggerStepThrough] + public Tensors RunInAutoMode2(Func graphAction, + Func eagerAction, + Action recordGradient, + Tensors tensors) + { + var shouldRunInEager = executing_eagerly() + && tensors.Count(x => x.IsEagerTensor) == tensors.Length; + + if (shouldRunInEager) + return eagerAction(); + else + { + if (executing_eagerly()) + { + graph_mode(); + var result = graphAction(); + restore_mode(); + return result; + } + else + { + var result = graphAction(); + if (tf.Runner.MustRecordGradient()) + recordGradient(result[0].op); + return result; + } + } + } + public void Dispose() => Handle.Dispose(); } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs index bf5eba19..ad3bd244 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs @@ -11,7 +11,8 @@ namespace Tensorflow.Eager public bool RecordGradient(string op_name, Tensor[] inputs, object[] attrs, - Tensor[] results) + Tensor[] results, + Func getBackwardFunction = null) { var input_ids = MakeTensorIDList(inputs); var input_dtypes = MakeTensorDtypeList(inputs); @@ -77,13 +78,20 @@ namespace Tensorflow.Eager else op_inputs = inputs; - TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, - () => GetGradientFunction(op_name, inputs, attrs, results)); - + TapeSetRecordOperation(op_name, inputs, results, + getBackwardFunction ?? GetBackwradFunction(op_name, inputs, attrs, results)); return true; } + Func GetBackwradFunction(string op_name, + Tensor[] op_inputs, + object[] attrs, + Tensor[] op_outputs) + { + return () => GetGradientFunction(op_name, op_inputs, attrs, op_outputs); + } + BackwardFunction GetGradientFunction(string op_name, Tensor[] op_inputs, object[] attrs, diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs index 1e764e64..22515f4e 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs @@ -10,8 +10,6 @@ namespace Tensorflow.Eager void TapeSetRecordBackprop(string op_type, Tensor[] input_tensors, TapeTensor[] output_tensors, - long[] input_ids, - TF_DataType[] input_dtypes, Func backward_function_getter) { if (!CouldBackprop()) @@ -22,7 +20,6 @@ namespace Tensorflow.Eager foreach (var tape in tf.GetTapeSet()) { tape.RecordOperation(op_type, input_tensors, output_tensors, - input_ids, input_dtypes, backward_function_getter); } } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs index 844addd2..1c5cac7b 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordForwardprop.cs @@ -9,8 +9,6 @@ namespace Tensorflow.Eager bool TapeSetRecordForwardprop(string op_type, Tensor[] input_tensors, TapeTensor[] output_tensors, - long[] input_ids, - TF_DataType[] input_dtypes, Func backward_function_getter) { if (!CouldForwardprop()) diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs index bb623d72..e70a513f 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs @@ -7,11 +7,9 @@ namespace Tensorflow.Eager { public partial class EagerRunner { - bool TapeSetRecordOperation(string op_type, + public bool TapeSetRecordOperation(string op_type, Tensor[] input_tensors, Tensor[] output_tensors, - long[] input_ids, - TF_DataType[] input_dtypes, Func backward_function_getter) { var output_info = new List(); @@ -20,11 +18,11 @@ namespace Tensorflow.Eager return false; if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(), - input_ids, input_dtypes, backward_function_getter)) + backward_function_getter)) return false; TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(), - input_ids, input_dtypes, backward_function_getter); + backward_function_getter); return true; } diff --git a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs index bbc1b882..6f401c1d 100644 --- a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs +++ b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs @@ -1,6 +1,7 @@ using System; using Tensorflow.Contexts; using Tensorflow.Gradients; +using static Tensorflow.tensorflow; namespace Tensorflow.Eager { @@ -37,7 +38,8 @@ namespace Tensorflow.Eager bool RecordGradient(string op_name, Tensor[] inputs, object[] attrs, - Tensor[] results); + Tensor[] results, + Func getBackwardFunction = null); bool MustRecordGradient(); diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs index ddfa590e..08cbc1da 100644 --- a/src/TensorFlowNET.Core/Gradients/Tape.cs +++ b/src/TensorFlowNET.Core/Gradients/Tape.cs @@ -54,10 +54,7 @@ namespace Tensorflow.Gradients if (tensor_tape_.find(tensor_ids[i])) { if (IsDtypeTrainable(dtypes[i])) - { - tf.Logger.Debug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}"); return true; - } } } return false; diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index a43799aa..5a3f5835 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -19,6 +19,7 @@ using System.Collections.Generic; using System.Linq; using System.Reflection; using Tensorflow.Gradients; +using static Tensorflow.Binding; namespace Tensorflow { @@ -47,11 +48,17 @@ namespace Tensorflow { RegisterGradientFunction(m.GetCustomAttribute().Name, (oper, out_grads) => - g.InvokeMember(m.Name, - BindingFlags.InvokeMethod, - null, - null, - args: new object[] { oper, out_grads }) as Tensor[] + { + tf.Logger.Debug($"Caculate Gradient: {m.Name}"); + var results = g.InvokeMember(m.Name, + BindingFlags.InvokeMethod, + null, + null, + args: new object[] { oper, out_grads }) as Tensor[]; + foreach (var result in results.Where(x => x != null)) + tf.Logger.Debug($"{result.TensorShape}"); + return results; + } ); } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index f059b9bd..fa74bffe 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -358,7 +358,7 @@ namespace Tensorflow } var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy }); - + return op.output; } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index ac7f084c..56d0863a 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -335,9 +335,10 @@ namespace Tensorflow.Keras.Engine var layer_inputs = node.MapArguments(tensor_dict); - tf.Logger.Debug($"{node.Layer}: {node.Layer.Name}"); + tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); var outputs = node.Layer.Apply(layer_inputs, is_training: training); - + foreach (var output in outputs.Where(x => x != null)) + tf.Logger.Debug($"{output.TensorShape}"); // Update tensor_dict for next input foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) tensor_dict[x_id] = new Queue(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); diff --git a/src/TensorFlowNET.Keras/KerasApi.cs b/src/TensorFlowNET.Keras/KerasApi.cs index a22c0399..d10ced0c 100644 --- a/src/TensorFlowNET.Keras/KerasApi.cs +++ b/src/TensorFlowNET.Keras/KerasApi.cs @@ -4,9 +4,6 @@ namespace Tensorflow { public static class KerasApi { - public static KerasInterface Keras(this tensorflow tf) - => new KerasInterface(); - public static KerasInterface keras { get; } = new KerasInterface(); } }