@@ -18,6 +18,7 @@ using System; | |||||
using System.Diagnostics; | using System.Diagnostics; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Contexts | namespace Tensorflow.Contexts | ||||
{ | { | ||||
@@ -114,6 +115,36 @@ namespace Tensorflow.Contexts | |||||
} | } | ||||
} | } | ||||
[DebuggerStepThrough] | |||||
public Tensors RunInAutoMode2(Func<Tensors> graphAction, | |||||
Func<Tensors> eagerAction, | |||||
Action<Operation> recordGradient, | |||||
Tensors tensors) | |||||
{ | |||||
var shouldRunInEager = executing_eagerly() | |||||
&& tensors.Count(x => x.IsEagerTensor) == tensors.Length; | |||||
if (shouldRunInEager) | |||||
return eagerAction(); | |||||
else | |||||
{ | |||||
if (executing_eagerly()) | |||||
{ | |||||
graph_mode(); | |||||
var result = graphAction(); | |||||
restore_mode(); | |||||
return result; | |||||
} | |||||
else | |||||
{ | |||||
var result = graphAction(); | |||||
if (tf.Runner.MustRecordGradient()) | |||||
recordGradient(result[0].op); | |||||
return result; | |||||
} | |||||
} | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
=> Handle.Dispose(); | => Handle.Dispose(); | ||||
} | } | ||||
@@ -11,7 +11,8 @@ namespace Tensorflow.Eager | |||||
public bool RecordGradient(string op_name, | public bool RecordGradient(string op_name, | ||||
Tensor[] inputs, | Tensor[] inputs, | ||||
object[] attrs, | object[] attrs, | ||||
Tensor[] results) | |||||
Tensor[] results, | |||||
Func<BackwardFunction> getBackwardFunction = null) | |||||
{ | { | ||||
var input_ids = MakeTensorIDList(inputs); | var input_ids = MakeTensorIDList(inputs); | ||||
var input_dtypes = MakeTensorDtypeList(inputs); | var input_dtypes = MakeTensorDtypeList(inputs); | ||||
@@ -77,13 +78,20 @@ namespace Tensorflow.Eager | |||||
else | else | ||||
op_inputs = inputs; | op_inputs = inputs; | ||||
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, | |||||
() => GetGradientFunction(op_name, inputs, attrs, results)); | |||||
TapeSetRecordOperation(op_name, inputs, results, | |||||
getBackwardFunction ?? GetBackwradFunction(op_name, inputs, attrs, results)); | |||||
return true; | return true; | ||||
} | } | ||||
Func<BackwardFunction> GetBackwradFunction(string op_name, | |||||
Tensor[] op_inputs, | |||||
object[] attrs, | |||||
Tensor[] op_outputs) | |||||
{ | |||||
return () => GetGradientFunction(op_name, op_inputs, attrs, op_outputs); | |||||
} | |||||
BackwardFunction GetGradientFunction(string op_name, | BackwardFunction GetGradientFunction(string op_name, | ||||
Tensor[] op_inputs, | Tensor[] op_inputs, | ||||
object[] attrs, | object[] attrs, | ||||
@@ -10,8 +10,6 @@ namespace Tensorflow.Eager | |||||
void TapeSetRecordBackprop(string op_type, | void TapeSetRecordBackprop(string op_type, | ||||
Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
long[] input_ids, | |||||
TF_DataType[] input_dtypes, | |||||
Func<BackwardFunction> backward_function_getter) | Func<BackwardFunction> backward_function_getter) | ||||
{ | { | ||||
if (!CouldBackprop()) | if (!CouldBackprop()) | ||||
@@ -22,7 +20,6 @@ namespace Tensorflow.Eager | |||||
foreach (var tape in tf.GetTapeSet()) | foreach (var tape in tf.GetTapeSet()) | ||||
{ | { | ||||
tape.RecordOperation(op_type, input_tensors, output_tensors, | tape.RecordOperation(op_type, input_tensors, output_tensors, | ||||
input_ids, input_dtypes, | |||||
backward_function_getter); | backward_function_getter); | ||||
} | } | ||||
} | } | ||||
@@ -9,8 +9,6 @@ namespace Tensorflow.Eager | |||||
bool TapeSetRecordForwardprop(string op_type, | bool TapeSetRecordForwardprop(string op_type, | ||||
Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
long[] input_ids, | |||||
TF_DataType[] input_dtypes, | |||||
Func<BackwardFunction> backward_function_getter) | Func<BackwardFunction> backward_function_getter) | ||||
{ | { | ||||
if (!CouldForwardprop()) | if (!CouldForwardprop()) | ||||
@@ -7,11 +7,9 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
public partial class EagerRunner | public partial class EagerRunner | ||||
{ | { | ||||
bool TapeSetRecordOperation(string op_type, | |||||
public bool TapeSetRecordOperation(string op_type, | |||||
Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
Tensor[] output_tensors, | Tensor[] output_tensors, | ||||
long[] input_ids, | |||||
TF_DataType[] input_dtypes, | |||||
Func<BackwardFunction> backward_function_getter) | Func<BackwardFunction> backward_function_getter) | ||||
{ | { | ||||
var output_info = new List<TapeTensor>(); | var output_info = new List<TapeTensor>(); | ||||
@@ -20,11 +18,11 @@ namespace Tensorflow.Eager | |||||
return false; | return false; | ||||
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(), | if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info.ToArray(), | ||||
input_ids, input_dtypes, backward_function_getter)) | |||||
backward_function_getter)) | |||||
return false; | return false; | ||||
TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(), | TapeSetRecordBackprop(op_type, input_tensors, output_info.ToArray(), | ||||
input_ids, input_dtypes, backward_function_getter); | |||||
backward_function_getter); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Contexts; | using Tensorflow.Contexts; | ||||
using Tensorflow.Gradients; | using Tensorflow.Gradients; | ||||
using static Tensorflow.tensorflow; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
@@ -37,7 +38,8 @@ namespace Tensorflow.Eager | |||||
bool RecordGradient(string op_name, | bool RecordGradient(string op_name, | ||||
Tensor[] inputs, | Tensor[] inputs, | ||||
object[] attrs, | object[] attrs, | ||||
Tensor[] results); | |||||
Tensor[] results, | |||||
Func<BackwardFunction> getBackwardFunction = null); | |||||
bool MustRecordGradient(); | bool MustRecordGradient(); | ||||
@@ -54,10 +54,7 @@ namespace Tensorflow.Gradients | |||||
if (tensor_tape_.find(tensor_ids[i])) | if (tensor_tape_.find(tensor_ids[i])) | ||||
{ | { | ||||
if (IsDtypeTrainable(dtypes[i])) | if (IsDtypeTrainable(dtypes[i])) | ||||
{ | |||||
tf.Logger.Debug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}"); | |||||
return true; | return true; | ||||
} | |||||
} | } | ||||
} | } | ||||
return false; | return false; | ||||
@@ -19,6 +19,7 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Reflection; | using System.Reflection; | ||||
using Tensorflow.Gradients; | using Tensorflow.Gradients; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -47,11 +48,17 @@ namespace Tensorflow | |||||
{ | { | ||||
RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, | RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, | ||||
(oper, out_grads) => | (oper, out_grads) => | ||||
g.InvokeMember(m.Name, | |||||
BindingFlags.InvokeMethod, | |||||
null, | |||||
null, | |||||
args: new object[] { oper, out_grads }) as Tensor[] | |||||
{ | |||||
tf.Logger.Debug($"Caculate Gradient: {m.Name}"); | |||||
var results = g.InvokeMember(m.Name, | |||||
BindingFlags.InvokeMethod, | |||||
null, | |||||
null, | |||||
args: new object[] { oper, out_grads }) as Tensor[]; | |||||
foreach (var result in results.Where(x => x != null)) | |||||
tf.Logger.Debug($"{result.TensorShape}"); | |||||
return results; | |||||
} | |||||
); | ); | ||||
} | } | ||||
@@ -358,7 +358,7 @@ namespace Tensorflow | |||||
} | } | ||||
var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy }); | var op = tf.OpDefLib._apply_op_helper("SigmoidGrad", name: name, args: new { y, dy }); | ||||
return op.output; | return op.output; | ||||
} | } | ||||
@@ -335,9 +335,10 @@ namespace Tensorflow.Keras.Engine | |||||
var layer_inputs = node.MapArguments(tensor_dict); | var layer_inputs = node.MapArguments(tensor_dict); | ||||
tf.Logger.Debug($"{node.Layer}: {node.Layer.Name}"); | |||||
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); | |||||
var outputs = node.Layer.Apply(layer_inputs, is_training: training); | var outputs = node.Layer.Apply(layer_inputs, is_training: training); | ||||
foreach (var output in outputs.Where(x => x != null)) | |||||
tf.Logger.Debug($"{output.TensorShape}"); | |||||
// Update tensor_dict for next input | // Update tensor_dict for next input | ||||
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | ||||
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | ||||
@@ -4,9 +4,6 @@ namespace Tensorflow | |||||
{ | { | ||||
public static class KerasApi | public static class KerasApi | ||||
{ | { | ||||
public static KerasInterface Keras(this tensorflow tf) | |||||
=> new KerasInterface(); | |||||
public static KerasInterface keras { get; } = new KerasInterface(); | public static KerasInterface keras { get; } = new KerasInterface(); | ||||
} | } | ||||
} | } |