@@ -1,4 +1,5 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
@@ -11,11 +12,34 @@ namespace Tensorflow.Functions | |||||
/// </summary> | /// </summary> | ||||
public class ConcreteFunction : IDisposable | public class ConcreteFunction : IDisposable | ||||
{ | { | ||||
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||||
IntPtr _handle; | IntPtr _handle; | ||||
FuncGraph func_graph; | |||||
public string Name | |||||
{ | |||||
get | |||||
{ | |||||
if (func_graph != null) | |||||
return func_graph.FuncName; | |||||
return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||||
} | |||||
} | |||||
public Tensor[] Outputs; | public Tensor[] Outputs; | ||||
public Type ReturnType; | |||||
public TensorSpec[] OutputStructure; | public TensorSpec[] OutputStructure; | ||||
public ConcreteFunction(string name) | |||||
{ | |||||
func_graph = new FuncGraph(name); | |||||
} | |||||
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | |||||
{ | |||||
func_graph = graph; | |||||
} | |||||
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
{ | { | ||||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | ||||
@@ -28,8 +52,8 @@ namespace Tensorflow.Functions | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
_handle = graph.ToGraph(opers, | _handle = graph.ToGraph(opers, | ||||
new Operation[] { input }, | |||||
new Operation[] { output }, | |||||
new[] { input }, | |||||
new[] { output }, | |||||
null); | null); | ||||
} | } | ||||
} | } | ||||
@@ -48,8 +72,8 @@ namespace Tensorflow.Functions | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
_handle = graph.ToGraph(opers, | _handle = graph.ToGraph(opers, | ||||
new Operation[] { input }, | |||||
new Operation[] { output.variant_tensor.op }, | |||||
new[] { input }, | |||||
new[] { output.variant_tensor }, | |||||
null); | null); | ||||
} | } | ||||
} | } | ||||
@@ -72,12 +96,38 @@ namespace Tensorflow.Functions | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
_handle = graph.ToGraph(opers, | _handle = graph.ToGraph(opers, | ||||
new Operation[] { input1, input2, input3 }, | |||||
new Operation[] { outputs.Item1.op, outputs.Item2.op }, | |||||
new[] { input1, input2, input3 }, | |||||
new[] { outputs.Item1, outputs.Item2 }, | |||||
null); | null); | ||||
} | } | ||||
} | } | ||||
public void ToGraph(Tensors inputs, Tensors outputs) | |||||
{ | |||||
var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
_handle = func_graph.ToGraph(opers, | |||||
inputs, | |||||
outputs, | |||||
null); | |||||
} | |||||
public Tensors Invoke(Tensors inputs) | |||||
{ | |||||
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly()); | |||||
var (forward_function, args_with_tangents) = forward_backward.Forward(); | |||||
Tensors flat_outputs = null; | |||||
if (tf.Context.executing_eagerly()) | |||||
flat_outputs = forward_function.Call(args_with_tangents); | |||||
forward_backward.Record(flat_outputs); | |||||
return flat_outputs; | |||||
} | |||||
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||||
{ | |||||
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||||
return new ForwardBackwardCall(functions, args, tape_watching: true); | |||||
} | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | ||||
@@ -0,0 +1,44 @@ | |||||
using Google.Protobuf; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Graphs; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
public class EagerDefinedFunction | |||||
{ | |||||
public int _num_outputs; | |||||
public string Name => _func_graph.FuncName; | |||||
FuncGraph _func_graph; | |||||
public EagerDefinedFunction(string name, FuncGraph graph, | |||||
Tensors inputs, Tensors outputs, | |||||
Dictionary<string, string> attrs) | |||||
{ | |||||
_num_outputs = outputs.Length; | |||||
var input_ops = inputs.Select(x => x.op).ToArray(); | |||||
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op)) | |||||
.Select(x => x as Operation).ToArray(); | |||||
var output_names = new string[0]; | |||||
_func_graph = new FuncGraph(graph, name, attrs); | |||||
_func_graph.ToGraph(operations, inputs, outputs, output_names); | |||||
} | |||||
public Tensors Call(Tensors args) | |||||
{ | |||||
var results = tf.Runner.TFE_Execute(tf.Context, | |||||
tf.Context.DeviceName, | |||||
_func_graph.FuncName, | |||||
args, | |||||
null, | |||||
_num_outputs); | |||||
return results; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,25 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Graphs; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
public class FirstOrderTapeGradientFunctions : TapeGradientFunctions | |||||
{ | |||||
public FirstOrderTapeGradientFunctions(FuncGraph func_graph, | |||||
bool need_gradients_for_jvps) : base(func_graph, | |||||
need_gradients_for_jvps) | |||||
{ | |||||
} | |||||
public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||||
{ | |||||
var outputs = _func_graph.Outputs; | |||||
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs) | |||||
= BuildFunctionsForOutputs(outputs, inference_args); | |||||
return _forward; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,38 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
/// <summary> | |||||
/// Holds the state of a function call between execution and recording. | |||||
/// </summary> | |||||
public class ForwardBackwardCall | |||||
{ | |||||
TapeGradientFunctions _functions; | |||||
Tensors _inference_args; | |||||
Tensors _input_tangents; | |||||
bool _tape_watching; | |||||
public ForwardBackwardCall(TapeGradientFunctions functions, | |||||
Tensors inference_args, | |||||
bool tape_watching) | |||||
{ | |||||
_functions = functions; | |||||
_inference_args = inference_args; | |||||
_tape_watching = tape_watching; | |||||
} | |||||
public (EagerDefinedFunction, Tensors) Forward() | |||||
{ | |||||
var forward_function = _functions.Forward(_inference_args); | |||||
return (forward_function, _inference_args); | |||||
} | |||||
public void Record(Tensors flat_outputs) | |||||
{ | |||||
if (_tape_watching && flat_outputs != null) | |||||
_functions.Record(flat_outputs, _inference_args); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,120 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Graphs; | |||||
using static Tensorflow.Binding; | |||||
using static Tensorflow.tensorflow; | |||||
namespace Tensorflow.Functions | |||||
{ | |||||
/// <summary> | |||||
/// Caches forward and backward functions compatible with eager gradients. | |||||
/// </summary> | |||||
public abstract class TapeGradientFunctions | |||||
{ | |||||
string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"; | |||||
string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"; | |||||
string _FORWARD_PREFIX = "__forward_"; | |||||
string _BACKWARD_PREFIX = "__backward_"; | |||||
string _INFERENCE_PREFIX = "__inference_"; | |||||
protected FuncGraph _func_graph; | |||||
protected EagerDefinedFunction _forward; | |||||
protected FuncGraph _forward_graph; | |||||
protected List<int> _forwardprop_output_indices; | |||||
protected int _num_forwardprop_outputs; | |||||
protected ConcreteFunction _backward; | |||||
public TapeGradientFunctions(FuncGraph func_graph, | |||||
bool need_gradients_for_jvps) | |||||
{ | |||||
_func_graph = func_graph; | |||||
} | |||||
public EagerDefinedFunction Forward(Tensors inference_args) | |||||
{ | |||||
return ForwardAndBackwardFunctions(inference_args); | |||||
} | |||||
/// <summary> | |||||
/// Record the function call operation. | |||||
/// </summary> | |||||
/// <param name="flat_outputs"></param> | |||||
/// <param name="inference_args"></param> | |||||
public void Record(Tensors flat_outputs, Tensors inference_args) | |||||
{ | |||||
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs); | |||||
tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args, | |||||
getBackwardFunction: () => backward_function); | |||||
} | |||||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs) | |||||
{ | |||||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||||
{ | |||||
return new Tensor[0]; | |||||
/*var gradients = ops.gradientFunctions[op_name](new EagerOperation | |||||
{ | |||||
Name = op_name, | |||||
NumInputs = op_inputs.Length, | |||||
Inputs = op_inputs, | |||||
NumOutputs = op_outputs.Length, | |||||
Outputs = op_outputs, | |||||
SkipInputIndices = unneeded_gradients, | |||||
Attrs = attrs | |||||
}, output_grads); | |||||
return gradients;*/ | |||||
}; | |||||
return (_backward_function_wrapper, flat_outputs); | |||||
} | |||||
protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||||
BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args) | |||||
{ | |||||
var trainable_outputs = new List<Tensor>(); | |||||
var trainable_indices = new List<int>(); | |||||
foreach(var (index, output) in enumerate(outputs)) | |||||
{ | |||||
if (gradients_util.IsTrainable(output)) | |||||
{ | |||||
trainable_outputs.Add(output); | |||||
trainable_indices.Add(index); | |||||
} | |||||
} | |||||
var gradients_wrt_outputs = new List<Tensor>(); | |||||
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"); | |||||
foreach (var output in trainable_outputs) | |||||
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||||
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||||
_func_graph.Inputs, | |||||
grad_ys: gradients_wrt_outputs.ToArray(), | |||||
src_graph: _func_graph); | |||||
tf.Context.restore_mode(); | |||||
var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"; | |||||
var backward_function_attr = new Dictionary<string, string>(); | |||||
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||||
backwards_graph.Inputs = gradients_wrt_outputs; | |||||
backwards_graph.Outputs = gradients_wrt_inputs; | |||||
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||||
var forward_function_attr = new Dictionary<string, string>(); | |||||
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||||
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||||
_func_graph.Inputs, _func_graph.Outputs, forward_function_attr); | |||||
return (forward_function, _func_graph, backward_function, null, 0); | |||||
} | |||||
public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -47,6 +47,9 @@ namespace Tensorflow | |||||
string description, | string description, | ||||
SafeStatusHandle status); | SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern IntPtr TF_FunctionName(IntPtr func); | public static extern IntPtr TF_FunctionName(IntPtr func); | ||||
@@ -13,8 +13,6 @@ namespace Tensorflow.Gradients | |||||
void RecordOperation(string op_type, | void RecordOperation(string op_type, | ||||
Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
long[] input_tensor_id, | |||||
TF_DataType[] input_dtypes, | |||||
Func<BackwardFunction> backward_function_getter); | Func<BackwardFunction> backward_function_getter); | ||||
void VariableAccessed(ResourceVariable variable); | void VariableAccessed(ResourceVariable variable); | ||||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.tensorflow; | using static Tensorflow.tensorflow; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Linq; | |||||
namespace Tensorflow.Gradients | namespace Tensorflow.Gradients | ||||
{ | { | ||||
@@ -14,18 +15,19 @@ namespace Tensorflow.Gradients | |||||
public void RecordOperation(string op_type, | public void RecordOperation(string op_type, | ||||
Tensor[] input_tensors, | Tensor[] input_tensors, | ||||
TapeTensor[] output_tensors, | TapeTensor[] output_tensors, | ||||
long[] input_tensor_id, | |||||
TF_DataType[] input_dtypes, | |||||
Func<BackwardFunction> backward_function_getter) | Func<BackwardFunction> backward_function_getter) | ||||
{ | { | ||||
if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||||
var input_ids = input_tensors.Select(x => x.Id).ToArray(); | |||||
var input_dtypes = input_tensors.Select(x => x.dtype).ToArray(); | |||||
if (!ShouldRecord(input_ids, input_dtypes)) | |||||
{ | { | ||||
return; | return; | ||||
} | } | ||||
long op_id = next_op_id_++; | long op_id = next_op_id_++; | ||||
var ids = new List<long>(input_tensor_id.Length); | |||||
foreach (var i in input_tensor_id) | |||||
var ids = new List<long>(input_ids.Length); | |||||
foreach (var i in input_ids) | |||||
{ | { | ||||
tensor_usage_[i]++; | tensor_usage_[i]++; | ||||
ids.Add(i); | ids.Add(i); | ||||
@@ -17,6 +17,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Graphs; | |||||
using Tensorflow.Operations.ControlFlows; | using Tensorflow.Operations.ControlFlows; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -39,7 +40,14 @@ namespace Tensorflow | |||||
// If src_graph is a _FuncGraph (i.e. a function body), gather it and all | // If src_graph is a _FuncGraph (i.e. a function body), gather it and all | ||||
// ancestor graphs. This is necessary for correctly handling captured values. | // ancestor graphs. This is necessary for correctly handling captured values. | ||||
var func_graphs = new List<FuncGraph>(); | |||||
var curr_graph = src_graph; | var curr_graph = src_graph; | ||||
if (src_graph is FuncGraph func_graph) | |||||
{ | |||||
func_graphs.append(func_graph); | |||||
curr_graph = func_graph.OuterGraph; | |||||
} | |||||
if (stop_gradients == null) | if (stop_gradients == null) | ||||
stop_gradients = new Tensor[0]; | stop_gradients = new Tensor[0]; | ||||
@@ -84,7 +92,7 @@ namespace Tensorflow | |||||
var to_ops = ys.Select(x => x.op).ToList(); | var to_ops = ys.Select(x => x.op).ToList(); | ||||
var from_ops = xs.Select(x => x.op).ToList(); | var from_ops = xs.Select(x => x.op).ToList(); | ||||
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | ||||
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | |||||
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs , xs); | |||||
// Add the initial gradients for the ys. | // Add the initial gradients for the ys. | ||||
foreach (var (y, grad_y) in zip(ys, grad_ys)) | foreach (var (y, grad_y) in zip(ys, grad_ys)) | ||||
@@ -258,11 +266,8 @@ namespace Tensorflow | |||||
{ | { | ||||
var new_grad_ys = new List<Tensor>(); | var new_grad_ys = new List<Tensor>(); | ||||
for (int i = 0; i < grad_ys.Length; i++) | |||||
foreach(var (i, (y, grad_y)) in enumerate(zip(ys, grad_ys))) | |||||
{ | { | ||||
var grad_y = grad_ys[i]; | |||||
var y = ys[i]; | |||||
_maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); | _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); | ||||
if (grad_y == null) | if (grad_y == null) | ||||
@@ -272,8 +277,17 @@ namespace Tensorflow | |||||
var shape = array_ops.shape(y); | var shape = array_ops.shape(y); | ||||
var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); | var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); | ||||
var fill = gen_array_ops.fill(shape, constant); | var fill = gen_array_ops.fill(shape, constant); | ||||
new_grad_ys.Add(fill); | |||||
new_grad_ys.append(fill); | |||||
continue; | |||||
} | } | ||||
if (y.dtype.is_floating() || y.dtype.is_integer()) | |||||
{ | |||||
} | |||||
// Create a grad_y tensor in the name scope of the gradient. | |||||
new_grad_ys.append(array_ops.identity(grad_y, name: $"grad_ys_{i}")); | |||||
} | } | ||||
return new_grad_ys.ToArray(); | return new_grad_ys.ToArray(); | ||||
@@ -294,7 +308,11 @@ namespace Tensorflow | |||||
/// <param name="colocate_gradients_with_ops"></param> | /// <param name="colocate_gradients_with_ops"></param> | ||||
/// <param name="func_graphs"></param> | /// <param name="func_graphs"></param> | ||||
/// <param name="xs"></param> | /// <param name="xs"></param> | ||||
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||||
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, | |||||
List<Operation> from_ops, | |||||
bool colocate_gradients_with_ops, | |||||
List<FuncGraph> func_graphs, | |||||
Tensor[] xs) | |||||
{ | { | ||||
// Mark reachable ops from from_ops. | // Mark reachable ops from from_ops. | ||||
var reached_ops = new List<Operation>(); | var reached_ops = new List<Operation>(); | ||||
@@ -511,7 +529,7 @@ namespace Tensorflow | |||||
/// <param name="from_ops"></param> | /// <param name="from_ops"></param> | ||||
/// <param name="reached_ops"></param> | /// <param name="reached_ops"></param> | ||||
/// <param name="func_graphs"></param> | /// <param name="func_graphs"></param> | ||||
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs) | |||||
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<FuncGraph> func_graphs) | |||||
{ | { | ||||
Queue<Operation> queue = new Queue<Operation>(from_ops); | Queue<Operation> queue = new Queue<Operation>(from_ops); | ||||
while (queue.Count > 0) | while (queue.Count > 0) | ||||
@@ -538,7 +556,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
/// <param name="t"></param> | /// <param name="t"></param> | ||||
/// <param name="func_graphs"></param> | /// <param name="func_graphs"></param> | ||||
private static Operation[] _Consumers(Tensor t, List<object> func_graphs) | |||||
private static Operation[] _Consumers(Tensor t, List<FuncGraph> func_graphs) | |||||
{ | { | ||||
return t.consumers(); | return t.consumers(); | ||||
} | } | ||||
@@ -647,7 +665,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
private static bool IsTrainable(Tensor tensor) | |||||
public static bool IsTrainable(Tensor tensor) | |||||
{ | { | ||||
var dtype = tensor.dtype.as_base_dtype(); | var dtype = tensor.dtype.as_base_dtype(); | ||||
return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, | return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, | ||||
@@ -18,10 +18,11 @@ namespace Tensorflow.Graphs | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
var func_handle = graph.ToGraph(opers, | var func_handle = graph.ToGraph(opers, | ||||
new Operation[] { input }, | |||||
new Operation[] { output }, | |||||
new[] { input }, | |||||
new[] { output }, | |||||
null); | null); | ||||
} | } | ||||
return (Tensor input) => | return (Tensor input) => | ||||
{ | { | ||||
@@ -48,11 +49,11 @@ namespace Tensorflow.Graphs | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | ||||
var func_handle = graph.ToGraph(opers, | var func_handle = graph.ToGraph(opers, | ||||
new Operation[] { input1, input2 }, | |||||
new Operation[] { output }, | |||||
new[] { input1, input2 }, | |||||
new[] { output }, | |||||
null); | null); | ||||
} | } | ||||
return (Tensor a, Tensor b) => | return (Tensor a, Tensor b) => | ||||
{ | { | ||||
var result = tf.Runner.TFE_Execute(tf.Context, | var result = tf.Runner.TFE_Execute(tf.Context, | ||||
@@ -3,6 +3,7 @@ using System; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.Functions; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Graphs | namespace Tensorflow.Graphs | ||||
@@ -10,10 +11,10 @@ namespace Tensorflow.Graphs | |||||
[AllowChangingInputArguments] | [AllowChangingInputArguments] | ||||
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | ||||
{ | { | ||||
FuncGraph graph; | |||||
ConcreteFunction function; | |||||
Tensors originalInputs; | Tensors originalInputs; | ||||
string func_name; | string func_name; | ||||
static Dictionary<string, Func<Tensors, Tensors>> functions = new Dictionary<string, Func<Tensors, Tensors>>(); | |||||
static Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>(); | |||||
public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
{ | { | ||||
@@ -21,22 +22,24 @@ namespace Tensorflow.Graphs | |||||
if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
{ | { | ||||
function = functions[func_name]; | |||||
if (args.Arguments[0] is Tensors tensor_inputs) | if (args.Arguments[0] is Tensors tensor_inputs) | ||||
args.ReturnValue = functions[func_name](tensor_inputs.ToArray()); | |||||
args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); | |||||
else | else | ||||
args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray()); | |||||
args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); | |||||
args.FlowBehavior = FlowBehavior.Return; | args.FlowBehavior = FlowBehavior.Return; | ||||
return; | return; | ||||
} | } | ||||
// make function as an Operation by autograph | // make function as an Operation by autograph | ||||
graph = new FuncGraph(func_name); | |||||
// need to restore mode when exits | |||||
function = new ConcreteFunction(func_name); | |||||
// convert to Tensors | // convert to Tensors | ||||
if (args.Arguments[0] is Tensors inputs) | if (args.Arguments[0] is Tensors inputs) | ||||
{ | { | ||||
originalInputs = inputs; | originalInputs = inputs; | ||||
var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray(); | |||||
var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape, name: "inputs")).ToArray(); | |||||
args.Arguments[0] = new Tensors(new_inputs); | args.Arguments[0] = new Tensors(new_inputs); | ||||
} | } | ||||
else | else | ||||
@@ -48,7 +51,7 @@ namespace Tensorflow.Graphs | |||||
if (args.Arguments[i] is EagerTensor tensor) | if (args.Arguments[i] is EagerTensor tensor) | ||||
{ | { | ||||
originalInputs[i] = tensor; | originalInputs[i] = tensor; | ||||
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape); | |||||
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -56,58 +59,30 @@ namespace Tensorflow.Graphs | |||||
public override void OnExit(MethodExecutionArgs args) | public override void OnExit(MethodExecutionArgs args) | ||||
{ | { | ||||
var returnValue = mark_as_return(args.ReturnValue as Tensors); | |||||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||||
if (args.ReturnValue is Tensors outputs) | if (args.ReturnValue is Tensors outputs) | ||||
{ | { | ||||
if (args.Arguments[0] is Tensors inputs) | if (args.Arguments[0] is Tensors inputs) | ||||
{ | |||||
graph.ToGraph(opers, | |||||
inputs.Select(x => x.op).ToArray(), | |||||
outputs.Select(x => x.op).ToArray(), | |||||
null); | |||||
} | |||||
function.ToGraph(inputs, outputs); | |||||
else | else | ||||
{ | |||||
graph.ToGraph(opers, | |||||
args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||||
outputs.Select(x => x.op).ToArray(), | |||||
null); | |||||
} | |||||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); | |||||
} | } | ||||
else | else | ||||
{ | |||||
graph.ToGraph(opers, | |||||
args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||||
new Operation[] { (args.ReturnValue as Tensor).op }, | |||||
null); | |||||
} | |||||
graph.Dispose(); | |||||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||||
Func<Tensors, Tensors> function = (x) => | |||||
{ | |||||
var result = tf.Runner.TFE_Execute(tf.Context, | |||||
tf.Context.DeviceName, | |||||
func_name, | |||||
x, | |||||
null, | |||||
1); | |||||
return result[0]; | |||||
}; | |||||
// cache function. | // cache function. | ||||
function.ReturnType = args.ReturnValue.GetType(); | |||||
functions[func_name] = function; | functions[func_name] = function; | ||||
// run function | // run function | ||||
args.ReturnValue = function(originalInputs); | |||||
args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); | |||||
} | } | ||||
Tensor mark_as_return(Tensor tensor) | |||||
object ConvertReturnValue(Tensors tensors) | |||||
{ | { | ||||
return array_ops.identity(tensor); | |||||
if (function.ReturnType == typeof(Tensor)) | |||||
return (Tensor)tensors; | |||||
else | |||||
return tensors; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -10,13 +10,18 @@ namespace Tensorflow.Graphs | |||||
/// </summary> | /// </summary> | ||||
public class FuncGraph : Graph | public class FuncGraph : Graph | ||||
{ | { | ||||
List<Operation> inputs; | |||||
List<Operation> outputs; | |||||
Graph outer_graph; | Graph outer_graph; | ||||
public Graph OuterGraph => outer_graph; | |||||
string func_name; | string func_name; | ||||
// _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||||
IntPtr func_handle; | IntPtr func_handle; | ||||
public string FuncName => func_name; | public string FuncName => func_name; | ||||
public Tensors Inputs { get; set; } | |||||
public Tensors Outputs { get; set; } | |||||
/// <summary> | /// <summary> | ||||
/// Construct a new FuncGraph. | /// Construct a new FuncGraph. | ||||
/// </summary> | /// </summary> | ||||
@@ -29,8 +34,17 @@ namespace Tensorflow.Graphs | |||||
as_default(); | as_default(); | ||||
} | } | ||||
public FuncGraph(IntPtr handle, string name) | |||||
{ | |||||
outer_graph = ops.get_default_graph(); | |||||
func_name = name; | |||||
tf.Context.graph_mode(); | |||||
as_default(); | |||||
} | |||||
public IntPtr ToGraph(Operation[] opers, | public IntPtr ToGraph(Operation[] opers, | ||||
Operation[] inputs, Operation[] outputs, | |||||
Tensor[] inputs, Tensor[] outputs, | |||||
string[] output_names) | string[] output_names) | ||||
{ | { | ||||
using var status = new Status(); | using var status = new Status(); | ||||
@@ -40,9 +54,9 @@ namespace Tensorflow.Graphs | |||||
opers.Length, | opers.Length, | ||||
opers.Select(x => (IntPtr)x).ToArray(), | opers.Select(x => (IntPtr)x).ToArray(), | ||||
inputs.Length, | inputs.Length, | ||||
inputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
outputs.Length, | outputs.Length, | ||||
outputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||||
output_names == null || output_names.Length == 0 ? null : output_names, | output_names == null || output_names.Length == 0 ? null : output_names, | ||||
IntPtr.Zero, | IntPtr.Zero, | ||||
null, | null, | ||||
@@ -57,13 +71,18 @@ namespace Tensorflow.Graphs | |||||
func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | ||||
Inputs = inputs; | |||||
// mark_as_return | |||||
Outputs = outputs.Select(x => array_ops.identity(x)).ToArray(); | |||||
tf.Context.restore_mode(); | |||||
return func_handle; | return func_handle; | ||||
} | } | ||||
protected override void DisposeManagedResources() | protected override void DisposeManagedResources() | ||||
{ | { | ||||
base.DisposeManagedResources(); | base.DisposeManagedResources(); | ||||
tf.Context.restore_mode(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -69,7 +69,7 @@ namespace Tensorflow | |||||
throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); | throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); | ||||
var defaultKey = tf.get_default_graph().graph_key; | var defaultKey = tf.get_default_graph().graph_key; | ||||
if (graph_key != defaultKey) | |||||
if (tf.get_default_graph().GetType().Name == "Graph" && graph_key != defaultKey) | |||||
{ | { | ||||
//Console.WriteLine($"Current graph is not default graph."); | //Console.WriteLine($"Current graph is not default graph."); | ||||
throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); | throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); | ||||
@@ -218,6 +218,8 @@ namespace Tensorflow | |||||
{ | { | ||||
case nameof(Int32): | case nameof(Int32): | ||||
return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | ||||
case nameof(Int64): | |||||
return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | |||||
default: | default: | ||||
return null; | return null; | ||||
} | } | ||||
@@ -235,6 +235,14 @@ namespace Tensorflow | |||||
} | } | ||||
var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); | var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); | ||||
if (tf.Runner.MustRecordGradient()) | |||||
{ | |||||
tf.Runner.RecordGradient("Identity", _op.inputs, new object[] | |||||
{ | |||||
"T", _op.get_attr<TF_DataType>("T") | |||||
}, _op.outputs); | |||||
} | |||||
return _op.output; | return _op.output; | ||||
} | } | ||||
@@ -632,8 +640,8 @@ namespace Tensorflow | |||||
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | ||||
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, | int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, | ||||
int shrink_axis_mask = 0, string name = null) | int shrink_axis_mask = 0, string name = null) | ||||
=> tf.Context.RunInAutoMode(() | |||||
=> tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||||
=> tf.Context.RunInAutoMode2( | |||||
() => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||||
{ | { | ||||
shape, | shape, | ||||
begin, | begin, | ||||
@@ -645,8 +653,8 @@ namespace Tensorflow | |||||
ellipsis_mask, | ellipsis_mask, | ||||
new_axis_mask, | new_axis_mask, | ||||
shrink_axis_mask | shrink_axis_mask | ||||
}).output, () | |||||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
}).output, | |||||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"StridedSliceGrad", name, | "StridedSliceGrad", name, | ||||
null, | null, | ||||
shape, begin, end, strides, dy, | shape, begin, end, strides, dy, | ||||
@@ -654,8 +662,22 @@ namespace Tensorflow | |||||
"end_mask", end_mask, | "end_mask", end_mask, | ||||
"ellipsis_mask", ellipsis_mask, | "ellipsis_mask", ellipsis_mask, | ||||
"new_axis_mask", new_axis_mask, | "new_axis_mask", new_axis_mask, | ||||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||||
shape, begin, end, strides, dy); | |||||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||||
(op) => | |||||
{ | |||||
var attrs = new object[] | |||||
{ | |||||
"T", op.get_attr<TF_DataType>("T"), | |||||
"Index", op.get_attr<TF_DataType>("Index"), | |||||
"begin_mask", op.get_attr<long>("begin_mask"), | |||||
"end_mask", op.get_attr<long>("end_mask"), | |||||
"ellipsis_mask", op.get_attr<long>("ellipsis_mask"), | |||||
"new_axis_mask", op.get_attr<long>("new_axis_mask"), | |||||
"shrink_axis_mask", op.get_attr<long>("shrink_axis_mask") | |||||
}; | |||||
tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs); | |||||
}, | |||||
new Tensors(shape, begin, end, strides, dy)); | |||||
/// <summary> | /// <summary> | ||||
/// Removes dimensions of size 1 from the shape of a tensor. | /// Removes dimensions of size 1 from the shape of a tensor. | ||||
@@ -23,6 +23,8 @@ using Tensorflow.Gradients; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||||
public partial class tensorflow : ITensorFlowObject | public partial class tensorflow : ITensorFlowObject | ||||
{ | { | ||||
public TF_DataType byte8 = TF_DataType.TF_UINT8; | public TF_DataType byte8 = TF_DataType.TF_UINT8; | ||||
@@ -37,8 +39,6 @@ namespace Tensorflow | |||||
public TF_DataType chars = TF_DataType.TF_STRING; | public TF_DataType chars = TF_DataType.TF_STRING; | ||||
public TF_DataType @string = TF_DataType.TF_STRING; | public TF_DataType @string = TF_DataType.TF_STRING; | ||||
public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||||
public Status Status; | public Status Status; | ||||
public OpDefLibrary OpDefLib; | public OpDefLibrary OpDefLib; | ||||
public Context Context; | public Context Context; | ||||
@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Engine | |||||
// Record the gradient because custom-made ops don't go through the | // Record the gradient because custom-made ops don't go through the | ||||
// code-gen'd eager call path | // code-gen'd eager call path | ||||
var op_type = op.node_def.Name; | |||||
var op_type = op.node_def.Op; | |||||
tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | ||||
@@ -1,9 +1,9 @@ | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.KerasApi; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System; | using System; | ||||
using System.Linq; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
@@ -26,9 +26,20 @@ namespace Tensorflow.Keras.Layers | |||||
var result = array_ops.reshape(inputs, shape.ToArray()); | var result = array_ops.reshape(inputs, shape.ToArray()); | ||||
if (!tf.Context.executing_eagerly()) | if (!tf.Context.executing_eagerly()) | ||||
// result = result.set_shape(compute_output_shape(inputs.shape)); | |||||
throw new NotImplementedException(""); | |||||
result.set_shape(compute_output_shape(inputs.shape)); | |||||
return result; | return result; | ||||
} | } | ||||
TensorShape compute_output_shape(TensorShape input_shape) | |||||
{ | |||||
if (input_shape.dims[0] == -1) | |||||
{ | |||||
input_shape = input_shape.dims[0]; | |||||
var output_shape = input_shape.concatenate(args.TargetShape.dims); | |||||
return output_shape; | |||||
} | |||||
else | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,6 +1,7 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System; | using System; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Graphs; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
@@ -36,7 +37,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
/// <param name="a"></param> | /// <param name="a"></param> | ||||
/// <param name="b"></param> | /// <param name="b"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
// [AutoGraph] | |||||
[AutoGraph] | |||||
Tensor Mul(Tensor a, Tensor b) | Tensor Mul(Tensor a, Tensor b) | ||||
{ | { | ||||
return a * b; | return a * b; | ||||