@@ -1,4 +1,5 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Graphs; | |||
@@ -11,11 +12,34 @@ namespace Tensorflow.Functions | |||
/// </summary> | |||
public class ConcreteFunction : IDisposable | |||
{ | |||
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
IntPtr _handle; | |||
FuncGraph func_graph; | |||
/// <summary>
/// Name of this function. When a <c>FuncGraph</c> is attached its <c>FuncName</c> wins;
/// otherwise the name is read from the native function handle, or an empty string
/// is returned when there is no handle.
/// </summary>
public string Name
{
    get
    {
        if (func_graph != null)
            return func_graph.FuncName;

        if (_handle == IntPtr.Zero)
            return string.Empty;

        return c_api.StringPiece(c_api.TF_FunctionName(_handle));
    }
}
public Tensor[] Outputs; | |||
public Type ReturnType; | |||
public TensorSpec[] OutputStructure; | |||
/// <summary>
/// Creates a function backed by a new <c>FuncGraph</c> constructed with <paramref name="name"/>.
/// </summary>
/// <param name="name">Name handed to the new graph.</param>
public ConcreteFunction(string name)
    => func_graph = new FuncGraph(name);
/// <summary>
/// Wraps an already constructed <c>FuncGraph</c>.
/// </summary>
/// <param name="graph">Graph stored as this function's body.</param>
/// <param name="attrs">Accepted for callers, but not read by this constructor.</param>
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
    => func_graph = graph;
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
{ | |||
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; | |||
@@ -28,8 +52,8 @@ namespace Tensorflow.Functions | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new Operation[] { input }, | |||
new Operation[] { output }, | |||
new[] { input }, | |||
new[] { output }, | |||
null); | |||
} | |||
} | |||
@@ -48,8 +72,8 @@ namespace Tensorflow.Functions | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new Operation[] { input }, | |||
new Operation[] { output.variant_tensor.op }, | |||
new[] { input }, | |||
new[] { output.variant_tensor }, | |||
null); | |||
} | |||
} | |||
@@ -72,12 +96,38 @@ namespace Tensorflow.Functions | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
_handle = graph.ToGraph(opers, | |||
new Operation[] { input1, input2, input3 }, | |||
new Operation[] { outputs.Item1.op, outputs.Item2.op }, | |||
new[] { input1, input2, input3 }, | |||
new[] { outputs.Item1, outputs.Item2 }, | |||
null); | |||
} | |||
} | |||
/// <summary>
/// Converts the wrapped graph into a function and stores the returned handle.
/// Every node currently registered in the graph is passed along as an operation.
/// </summary>
/// <param name="inputs">Tensors forwarded as the function inputs.</param>
/// <param name="outputs">Tensors forwarded as the function outputs.</param>
public void ToGraph(Tensors inputs, Tensors outputs)
{
    var operations = func_graph._nodes_by_name.Values
        .Select(node => node as Operation)
        .ToArray();

    // No explicit output names are supplied.
    _handle = func_graph.ToGraph(operations, inputs, outputs, null);
}
/// <summary>
/// Runs the forward function on <paramref name="inputs"/> and hands the result to the
/// forward/backward call object for recording.
/// </summary>
/// <param name="inputs">Arguments for the call.</param>
/// <returns>
/// The outputs of the forward function when the context is executing eagerly;
/// <c>null</c> otherwise, because the forward function is only called in eager mode here.
/// </returns>
public Tensors Invoke(Tensors inputs)
{
    var call = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
    var (forward, forward_args) = call.Forward();

    // The eager check is evaluated again after the forward function has been built.
    Tensors results = tf.Context.executing_eagerly()
        ? forward.Call(forward_args)
        : null;

    call.Record(results);
    return results;
}
/// <summary>
/// Builds the forward/backward call object for <paramref name="args"/>.
/// Always uses first-order tape gradient functions with tape watching enabled;
/// <paramref name="possible_gradient_type"/> and <paramref name="executing_eagerly"/>
/// are not consulted by this implementation.
/// </summary>
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
    => new ForwardBackwardCall(
        new FirstOrderTapeGradientFunctions(func_graph, false),
        args,
        tape_watching: true);
public void Dispose() | |||
{ | |||
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | |||
@@ -0,0 +1,44 @@ | |||
using Google.Protobuf; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Functions | |||
{ | |||
/// <summary>
/// A function built from a <c>FuncGraph</c> that is executed by name through
/// <c>tf.Runner.TFE_Execute</c>.
/// </summary>
public class EagerDefinedFunction
{
    /// <summary>Number of outputs requested when the function is executed.</summary>
    public int _num_outputs;

    /// <summary>Function name as reported by the underlying graph.</summary>
    public string Name => _func_graph.FuncName;

    FuncGraph _func_graph;

    /// <summary>
    /// Wraps <paramref name="graph"/> in a new <c>FuncGraph</c> and converts it to a function.
    /// </summary>
    /// <param name="name">Name passed to the new <c>FuncGraph</c>.</param>
    /// <param name="graph">Source graph whose operations form the function body.</param>
    /// <param name="inputs">Function inputs; the ops producing them are excluded from the body.</param>
    /// <param name="outputs">Function outputs; their count becomes <see cref="_num_outputs"/>.</param>
    /// <param name="attrs">Attributes passed to the new <c>FuncGraph</c>.</param>
    public EagerDefinedFunction(string name, FuncGraph graph,
        Tensors inputs, Tensors outputs,
        Dictionary<string, string> attrs)
    {
        _num_outputs = outputs.Length;

        // Everything in the graph except the ops that own the input tensors.
        var placeholder_ops = inputs.Select(tensor => tensor.op).ToArray();
        var body_ops = graph.get_operations()
            .Where(node => !placeholder_ops.Contains(node.op))
            .Select(node => node as Operation)
            .ToArray();

        _func_graph = new FuncGraph(graph, name, attrs);
        _func_graph.ToGraph(body_ops, inputs, outputs, new string[0]);
    }

    /// <summary>
    /// Executes the function on the current context and device.
    /// </summary>
    /// <param name="args">Arguments passed to the execution call.</param>
    /// <returns>The result of <c>TFE_Execute</c>.</returns>
    public Tensors Call(Tensors args)
        => tf.Runner.TFE_Execute(tf.Context,
            tf.Context.DeviceName,
            _func_graph.FuncName,
            args,
            null,
            _num_outputs);
}
} |
@@ -0,0 +1,25 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow.Functions | |||
{ | |||
/// <summary>
/// Tape gradient functions that build the forward and backward functions from the
/// outputs of the wrapped function graph.
/// </summary>
public class FirstOrderTapeGradientFunctions : TapeGradientFunctions
{
    public FirstOrderTapeGradientFunctions(FuncGraph func_graph, bool need_gradients_for_jvps)
        : base(func_graph, need_gradients_for_jvps)
    {
    }

    /// <summary>
    /// Builds the forward/backward functions for the graph outputs, caches every
    /// piece on the base-class fields and returns the forward function.
    /// </summary>
    /// <param name="inference_args">Arguments forwarded to the builder.</param>
    public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
    {
        var built = BuildFunctionsForOutputs(_func_graph.Outputs, inference_args);

        _forward = built.Item1;
        _forward_graph = built.Item2;
        _backward = built.Item3;
        _forwardprop_output_indices = built.Item4;
        _num_forwardprop_outputs = built.Item5;

        return _forward;
    }
}
} |
@@ -0,0 +1,38 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Functions | |||
{ | |||
/// <summary>
/// Holds the state of a function call between execution and recording.
/// </summary>
public class ForwardBackwardCall
{
    TapeGradientFunctions _functions;
    Tensors _inference_args;
    Tensors _input_tangents;
    bool _tape_watching;

    /// <summary>
    /// Captures the gradient functions and call arguments for a later
    /// <see cref="Forward"/> / <see cref="Record"/> pair.
    /// </summary>
    /// <param name="functions">Provider of the forward function and the recording logic.</param>
    /// <param name="inference_args">Arguments of the call.</param>
    /// <param name="tape_watching">When false, <see cref="Record"/> does nothing.</param>
    public ForwardBackwardCall(TapeGradientFunctions functions,
        Tensors inference_args,
        bool tape_watching)
    {
        _functions = functions;
        _inference_args = inference_args;
        _tape_watching = tape_watching;
    }

    /// <summary>
    /// Returns the forward function together with the arguments it should be called with.
    /// </summary>
    public (EagerDefinedFunction, Tensors) Forward()
        => (_functions.Forward(_inference_args), _inference_args);

    /// <summary>
    /// Records the call, but only when the tape is watching and outputs were produced.
    /// </summary>
    /// <param name="flat_outputs">Outputs of the forward call; may be null.</param>
    public void Record(Tensors flat_outputs)
    {
        bool should_record = _tape_watching && flat_outputs != null;
        if (should_record)
        {
            _functions.Record(flat_outputs, _inference_args);
        }
    }
}
} |
@@ -0,0 +1,120 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.tensorflow; | |||
namespace Tensorflow.Functions | |||
{ | |||
/// <summary>
/// Caches forward and backward functions compatible with eager gradients.
/// </summary>
public abstract class TapeGradientFunctions
{
    // Attribute keys and name prefixes. These never change, so they are compile-time
    // constants rather than per-instance fields.
    const string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
    const string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
    const string _FORWARD_PREFIX = "__forward_";
    const string _BACKWARD_PREFIX = "__backward_";
    const string _INFERENCE_PREFIX = "__inference_";

    protected FuncGraph _func_graph;
    protected EagerDefinedFunction _forward;
    protected FuncGraph _forward_graph;
    protected List<int> _forwardprop_output_indices;
    protected int _num_forwardprop_outputs;
    protected ConcreteFunction _backward;

    /// <summary>
    /// Stores the function graph the forward/backward functions are derived from.
    /// </summary>
    /// <param name="func_graph">Graph of the function being differentiated.</param>
    /// <param name="need_gradients_for_jvps">Not read by this constructor.</param>
    public TapeGradientFunctions(FuncGraph func_graph,
        bool need_gradients_for_jvps)
    {
        _func_graph = func_graph;
    }

    /// <summary>
    /// Returns the forward function for the given arguments by delegating to
    /// <see cref="ForwardAndBackwardFunctions"/>.
    /// </summary>
    public EagerDefinedFunction Forward(Tensors inference_args)
    {
        return ForwardAndBackwardFunctions(inference_args);
    }

    /// <summary>
    /// Record the function call operation.
    /// </summary>
    /// <param name="flat_outputs">Outputs produced by the forward function.</param>
    /// <param name="inference_args">Arguments the forward function was called with.</param>
    public void Record(Tensors flat_outputs, Tensors inference_args)
    {
        var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);

        // RecordGradient takes (op name, inputs, attrs, outputs): the call arguments are
        // the inputs of the recorded operation and the tensors to record are its outputs.
        tf.Runner.RecordGradient(_forward.Name, inference_args, new object[0], to_record,
            getBackwardFunction: () => backward_function);
    }

    /// <summary>
    /// Produces the backward callback to register on the tape and the tensors to record.
    /// </summary>
    /// <returns>
    /// A backward function and <paramref name="flat_outputs"/> unchanged.
    /// NOTE: the backward function is a placeholder that returns an empty tensor array;
    /// <paramref name="forward_graph"/> and <paramref name="backward"/> are not used yet.
    /// </returns>
    (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs)
    {
        BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
        {
            // TODO: invoke the real backward function; no gradients are computed for now.
            return new Tensor[0];
        };

        return (_backward_function_wrapper, flat_outputs);
    }

    /// <summary>
    /// Builds the forward function and its backward counterpart for <paramref name="outputs"/>.
    /// </summary>
    /// <param name="outputs">Function outputs; only the trainable ones receive gradient placeholders.</param>
    /// <param name="inference_args">Not read by this implementation.</param>
    /// <returns>
    /// (forward function, forward graph, backward function, forwardprop output indices,
    /// number of forwardprop outputs). The last two are always <c>null</c> and <c>0</c> here.
    /// </returns>
    protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
        BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args)
    {
        // Keep only the outputs a gradient can flow through.
        var trainable_outputs = new List<Tensor>();
        var trainable_indices = new List<int>();
        foreach (var (index, output) in enumerate(outputs))
        {
            if (gradients_util.IsTrainable(output))
            {
                trainable_outputs.Add(output);
                trainable_indices.Add(index);
            }
        }

        var gradients_wrt_outputs = new List<Tensor>();

        // NOTE(review): the placeholders and gradient ops below are created after this
        // graph is constructed; presumably the constructor makes it the active graph and
        // switches the context mode, which restore_mode() undoes. Confirm against FuncGraph.
        var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}");
        foreach (var output in trainable_outputs)
            gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));

        var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
            _func_graph.Inputs,
            grad_ys: gradients_wrt_outputs.ToArray(),
            src_graph: _func_graph);

        tf.Context.restore_mode();

        // The forward and backward functions reference each other by name through attributes.
        var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}";
        var backward_function_attr = new Dictionary<string, string>();
        backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
        backwards_graph.Inputs = gradients_wrt_outputs;
        backwards_graph.Outputs = gradients_wrt_inputs;
        var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);

        var forward_function_attr = new Dictionary<string, string>();
        forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
        var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
            _func_graph.Inputs, _func_graph.Outputs, forward_function_attr);

        return (forward_function, _func_graph, backward_function, null, 0);
    }

    /// <summary>
    /// Builds (or returns cached) forward function for the given arguments.
    /// Subclasses must override; the base implementation always throws.
    /// </summary>
    /// <exception cref="NotImplementedException">Always, in the base class.</exception>
    public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
    {
        throw new NotImplementedException("");
    }
}
} |
@@ -47,6 +47,9 @@ namespace Tensorflow | |||
string description, | |||
SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern IntPtr TF_FunctionName(IntPtr func); | |||
@@ -13,8 +13,6 @@ namespace Tensorflow.Gradients | |||
void RecordOperation(string op_type, | |||
Tensor[] input_tensors, | |||
TapeTensor[] output_tensors, | |||
long[] input_tensor_id, | |||
TF_DataType[] input_dtypes, | |||
Func<BackwardFunction> backward_function_getter); | |||
void VariableAccessed(ResourceVariable variable); | |||
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using Tensorflow.Util; | |||
using static Tensorflow.tensorflow; | |||
using static Tensorflow.Binding; | |||
using System.Linq; | |||
namespace Tensorflow.Gradients | |||
{ | |||
@@ -14,18 +15,19 @@ namespace Tensorflow.Gradients | |||
public void RecordOperation(string op_type, | |||
Tensor[] input_tensors, | |||
TapeTensor[] output_tensors, | |||
long[] input_tensor_id, | |||
TF_DataType[] input_dtypes, | |||
Func<BackwardFunction> backward_function_getter) | |||
{ | |||
if (!ShouldRecord(input_tensor_id, input_dtypes)) | |||
var input_ids = input_tensors.Select(x => x.Id).ToArray(); | |||
var input_dtypes = input_tensors.Select(x => x.dtype).ToArray(); | |||
if (!ShouldRecord(input_ids, input_dtypes)) | |||
{ | |||
return; | |||
} | |||
long op_id = next_op_id_++; | |||
var ids = new List<long>(input_tensor_id.Length); | |||
foreach (var i in input_tensor_id) | |||
var ids = new List<long>(input_ids.Length); | |||
foreach (var i in input_ids) | |||
{ | |||
tensor_usage_[i]++; | |||
ids.Add(i); | |||
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Operations.ControlFlows; | |||
using static Tensorflow.Binding; | |||
@@ -39,7 +40,14 @@ namespace Tensorflow | |||
// If src_graph is a _FuncGraph (i.e. a function body), gather it and all | |||
// ancestor graphs. This is necessary for correctly handling captured values. | |||
var func_graphs = new List<FuncGraph>(); | |||
var curr_graph = src_graph; | |||
if (src_graph is FuncGraph func_graph) | |||
{ | |||
func_graphs.append(func_graph); | |||
curr_graph = func_graph.OuterGraph; | |||
} | |||
if (stop_gradients == null) | |||
stop_gradients = new Tensor[0]; | |||
@@ -84,7 +92,7 @@ namespace Tensorflow | |||
var to_ops = ys.Select(x => x.op).ToList(); | |||
var from_ops = xs.Select(x => x.op).ToList(); | |||
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); | |||
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs); | |||
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs , xs); | |||
// Add the initial gradients for the ys. | |||
foreach (var (y, grad_y) in zip(ys, grad_ys)) | |||
@@ -258,11 +266,8 @@ namespace Tensorflow | |||
{ | |||
var new_grad_ys = new List<Tensor>(); | |||
for (int i = 0; i < grad_ys.Length; i++) | |||
foreach(var (i, (y, grad_y)) in enumerate(zip(ys, grad_ys))) | |||
{ | |||
var grad_y = grad_ys[i]; | |||
var y = ys[i]; | |||
_maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); | |||
if (grad_y == null) | |||
@@ -272,8 +277,17 @@ namespace Tensorflow | |||
var shape = array_ops.shape(y); | |||
var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); | |||
var fill = gen_array_ops.fill(shape, constant); | |||
new_grad_ys.Add(fill); | |||
new_grad_ys.append(fill); | |||
continue; | |||
} | |||
if (y.dtype.is_floating() || y.dtype.is_integer()) | |||
{ | |||
} | |||
// Create a grad_y tensor in the name scope of the gradient. | |||
new_grad_ys.append(array_ops.identity(grad_y, name: $"grad_ys_{i}")); | |||
} | |||
return new_grad_ys.ToArray(); | |||
@@ -294,7 +308,11 @@ namespace Tensorflow | |||
/// <param name="colocate_gradients_with_ops"></param> | |||
/// <param name="func_graphs"></param> | |||
/// <param name="xs"></param> | |||
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs) | |||
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, | |||
List<Operation> from_ops, | |||
bool colocate_gradients_with_ops, | |||
List<FuncGraph> func_graphs, | |||
Tensor[] xs) | |||
{ | |||
// Mark reachable ops from from_ops. | |||
var reached_ops = new List<Operation>(); | |||
@@ -511,7 +529,7 @@ namespace Tensorflow | |||
/// <param name="from_ops"></param> | |||
/// <param name="reached_ops"></param> | |||
/// <param name="func_graphs"></param> | |||
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs) | |||
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<FuncGraph> func_graphs) | |||
{ | |||
Queue<Operation> queue = new Queue<Operation>(from_ops); | |||
while (queue.Count > 0) | |||
@@ -538,7 +556,7 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <param name="t"></param> | |||
/// <param name="func_graphs"></param> | |||
private static Operation[] _Consumers(Tensor t, List<object> func_graphs) | |||
private static Operation[] _Consumers(Tensor t, List<FuncGraph> func_graphs) | |||
{ | |||
return t.consumers(); | |||
} | |||
@@ -647,7 +665,7 @@ namespace Tensorflow | |||
} | |||
} | |||
private static bool IsTrainable(Tensor tensor) | |||
public static bool IsTrainable(Tensor tensor) | |||
{ | |||
var dtype = tensor.dtype.as_base_dtype(); | |||
return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, | |||
@@ -18,10 +18,11 @@ namespace Tensorflow.Graphs | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
var func_handle = graph.ToGraph(opers, | |||
new Operation[] { input }, | |||
new Operation[] { output }, | |||
new[] { input }, | |||
new[] { output }, | |||
null); | |||
} | |||
return (Tensor input) => | |||
{ | |||
@@ -48,11 +49,11 @@ namespace Tensorflow.Graphs | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
var func_handle = graph.ToGraph(opers, | |||
new Operation[] { input1, input2 }, | |||
new Operation[] { output }, | |||
new[] { input1, input2 }, | |||
new[] { output }, | |||
null); | |||
} | |||
return (Tensor a, Tensor b) => | |||
{ | |||
var result = tf.Runner.TFE_Execute(tf.Context, | |||
@@ -3,6 +3,7 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Functions; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Graphs | |||
@@ -10,10 +11,10 @@ namespace Tensorflow.Graphs | |||
[AllowChangingInputArguments] | |||
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect | |||
{ | |||
FuncGraph graph; | |||
ConcreteFunction function; | |||
Tensors originalInputs; | |||
string func_name; | |||
static Dictionary<string, Func<Tensors, Tensors>> functions = new Dictionary<string, Func<Tensors, Tensors>>(); | |||
static Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>(); | |||
public override void OnEntry(MethodExecutionArgs args) | |||
{ | |||
@@ -21,22 +22,24 @@ namespace Tensorflow.Graphs | |||
if (functions.ContainsKey(func_name)) | |||
{ | |||
function = functions[func_name]; | |||
if (args.Arguments[0] is Tensors tensor_inputs) | |||
args.ReturnValue = functions[func_name](tensor_inputs.ToArray()); | |||
args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs)); | |||
else | |||
args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray()); | |||
args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray())); | |||
args.FlowBehavior = FlowBehavior.Return; | |||
return; | |||
} | |||
// make function as an Operation by autograph | |||
graph = new FuncGraph(func_name); | |||
// need to restore mode when exits | |||
function = new ConcreteFunction(func_name); | |||
// convert to Tensors | |||
if (args.Arguments[0] is Tensors inputs) | |||
{ | |||
originalInputs = inputs; | |||
var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray(); | |||
var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape, name: "inputs")).ToArray(); | |||
args.Arguments[0] = new Tensors(new_inputs); | |||
} | |||
else | |||
@@ -48,7 +51,7 @@ namespace Tensorflow.Graphs | |||
if (args.Arguments[i] is EagerTensor tensor) | |||
{ | |||
originalInputs[i] = tensor; | |||
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape); | |||
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); | |||
} | |||
} | |||
} | |||
@@ -56,58 +59,30 @@ namespace Tensorflow.Graphs | |||
public override void OnExit(MethodExecutionArgs args) | |||
{ | |||
var returnValue = mark_as_return(args.ReturnValue as Tensors); | |||
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); | |||
if (args.ReturnValue is Tensors outputs) | |||
{ | |||
if (args.Arguments[0] is Tensors inputs) | |||
{ | |||
graph.ToGraph(opers, | |||
inputs.Select(x => x.op).ToArray(), | |||
outputs.Select(x => x.op).ToArray(), | |||
null); | |||
} | |||
function.ToGraph(inputs, outputs); | |||
else | |||
{ | |||
graph.ToGraph(opers, | |||
args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||
outputs.Select(x => x.op).ToArray(), | |||
null); | |||
} | |||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs); | |||
} | |||
else | |||
{ | |||
graph.ToGraph(opers, | |||
args.Arguments.Select(x => (x as Tensor).op).ToArray(), | |||
new Operation[] { (args.ReturnValue as Tensor).op }, | |||
null); | |||
} | |||
graph.Dispose(); | |||
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor); | |||
Func<Tensors, Tensors> function = (x) => | |||
{ | |||
var result = tf.Runner.TFE_Execute(tf.Context, | |||
tf.Context.DeviceName, | |||
func_name, | |||
x, | |||
null, | |||
1); | |||
return result[0]; | |||
}; | |||
// cache function. | |||
function.ReturnType = args.ReturnValue.GetType(); | |||
functions[func_name] = function; | |||
// run function | |||
args.ReturnValue = function(originalInputs); | |||
args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs)); | |||
} | |||
Tensor mark_as_return(Tensor tensor) | |||
object ConvertReturnValue(Tensors tensors) | |||
{ | |||
return array_ops.identity(tensor); | |||
if (function.ReturnType == typeof(Tensor)) | |||
return (Tensor)tensors; | |||
else | |||
return tensors; | |||
} | |||
} | |||
} |
@@ -10,13 +10,18 @@ namespace Tensorflow.Graphs | |||
/// </summary> | |||
public class FuncGraph : Graph | |||
{ | |||
List<Operation> inputs; | |||
List<Operation> outputs; | |||
Graph outer_graph; | |||
public Graph OuterGraph => outer_graph; | |||
string func_name; | |||
// _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
IntPtr func_handle; | |||
public string FuncName => func_name; | |||
public Tensors Inputs { get; set; } | |||
public Tensors Outputs { get; set; } | |||
/// <summary> | |||
/// Construct a new FuncGraph. | |||
/// </summary> | |||
@@ -29,8 +34,17 @@ namespace Tensorflow.Graphs | |||
as_default(); | |||
} | |||
/// <summary>
/// Construct a FuncGraph with the given function name. The graph that is the default
/// at construction time is remembered as the outer graph, the context is switched to
/// graph mode, and <c>as_default()</c> is called on this graph.
/// </summary>
/// <param name="handle">NOTE(review): not referenced in this constructor body; confirm whether it should be used.</param>
/// <param name="name">Stored as the function name.</param>
public FuncGraph(IntPtr handle, string name) | |||
{ | |||
// Read the current default graph first, before as_default() is called below.
outer_graph = ops.get_default_graph(); | |||
func_name = name; | |||
tf.Context.graph_mode(); | |||
as_default(); | |||
} | |||
public IntPtr ToGraph(Operation[] opers, | |||
Operation[] inputs, Operation[] outputs, | |||
Tensor[] inputs, Tensor[] outputs, | |||
string[] output_names) | |||
{ | |||
using var status = new Status(); | |||
@@ -40,9 +54,9 @@ namespace Tensorflow.Graphs | |||
opers.Length, | |||
opers.Select(x => (IntPtr)x).ToArray(), | |||
inputs.Length, | |||
inputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
outputs.Length, | |||
outputs.Select(x => new TF_Output(x, 0)).ToArray(), | |||
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(), | |||
output_names == null || output_names.Length == 0 ? null : output_names, | |||
IntPtr.Zero, | |||
null, | |||
@@ -57,13 +71,18 @@ namespace Tensorflow.Graphs | |||
func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | |||
Inputs = inputs; | |||
// mark_as_return | |||
Outputs = outputs.Select(x => array_ops.identity(x)).ToArray(); | |||
tf.Context.restore_mode(); | |||
return func_handle; | |||
} | |||
protected override void DisposeManagedResources() | |||
{ | |||
base.DisposeManagedResources(); | |||
tf.Context.restore_mode(); | |||
} | |||
} | |||
} |
@@ -69,7 +69,7 @@ namespace Tensorflow | |||
throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); | |||
var defaultKey = tf.get_default_graph().graph_key; | |||
if (graph_key != defaultKey) | |||
if (tf.get_default_graph().GetType().Name == "Graph" && graph_key != defaultKey) | |||
{ | |||
//Console.WriteLine($"Current graph is not default graph."); | |||
throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); | |||
@@ -218,6 +218,8 @@ namespace Tensorflow | |||
{ | |||
case nameof(Int32): | |||
return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | |||
case nameof(Int64): | |||
return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); | |||
default: | |||
return null; | |||
} | |||
@@ -235,6 +235,14 @@ namespace Tensorflow | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); | |||
if (tf.Runner.MustRecordGradient()) | |||
{ | |||
tf.Runner.RecordGradient("Identity", _op.inputs, new object[] | |||
{ | |||
"T", _op.get_attr<TF_DataType>("T") | |||
}, _op.outputs); | |||
} | |||
return _op.output; | |||
} | |||
@@ -632,8 +640,8 @@ namespace Tensorflow | |||
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | |||
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, | |||
int shrink_axis_mask = 0, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||
=> tf.Context.RunInAutoMode2( | |||
() => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new | |||
{ | |||
shape, | |||
begin, | |||
@@ -645,8 +653,8 @@ namespace Tensorflow | |||
ellipsis_mask, | |||
new_axis_mask, | |||
shrink_axis_mask | |||
}).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
}).output, | |||
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"StridedSliceGrad", name, | |||
null, | |||
shape, begin, end, strides, dy, | |||
@@ -654,8 +662,22 @@ namespace Tensorflow | |||
"end_mask", end_mask, | |||
"ellipsis_mask", ellipsis_mask, | |||
"new_axis_mask", new_axis_mask, | |||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
shape, begin, end, strides, dy); | |||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
(op) => | |||
{ | |||
var attrs = new object[] | |||
{ | |||
"T", op.get_attr<TF_DataType>("T"), | |||
"Index", op.get_attr<TF_DataType>("Index"), | |||
"begin_mask", op.get_attr<long>("begin_mask"), | |||
"end_mask", op.get_attr<long>("end_mask"), | |||
"ellipsis_mask", op.get_attr<long>("ellipsis_mask"), | |||
"new_axis_mask", op.get_attr<long>("new_axis_mask"), | |||
"shrink_axis_mask", op.get_attr<long>("shrink_axis_mask") | |||
}; | |||
tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs); | |||
}, | |||
new Tensors(shape, begin, end, strides, dy)); | |||
/// <summary> | |||
/// Removes dimensions of size 1 from the shape of a tensor. | |||
@@ -23,6 +23,8 @@ using Tensorflow.Gradients; | |||
namespace Tensorflow | |||
{ | |||
public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||
public partial class tensorflow : ITensorFlowObject | |||
{ | |||
public TF_DataType byte8 = TF_DataType.TF_UINT8; | |||
@@ -37,8 +39,6 @@ namespace Tensorflow | |||
public TF_DataType chars = TF_DataType.TF_STRING; | |||
public TF_DataType @string = TF_DataType.TF_STRING; | |||
public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients); | |||
public Status Status; | |||
public OpDefLibrary OpDefLib; | |||
public Context Context; | |||
@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Engine | |||
// Record the gradient because custom-made ops don't go through the | |||
// code-gen'd eager call path | |||
var op_type = op.node_def.Name; | |||
var op_type = op.node_def.Op; | |||
tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | |||
@@ -1,9 +1,9 @@ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.KerasApi; | |||
using static Tensorflow.Binding; | |||
using System.Collections.Generic; | |||
using System; | |||
using System.Linq; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
@@ -26,9 +26,20 @@ namespace Tensorflow.Keras.Layers | |||
var result = array_ops.reshape(inputs, shape.ToArray()); | |||
if (!tf.Context.executing_eagerly()) | |||
// result = result.set_shape(compute_output_shape(inputs.shape)); | |||
throw new NotImplementedException(""); | |||
result.set_shape(compute_output_shape(inputs.shape)); | |||
return result; | |||
} | |||
/// <summary>
/// Computes the static output shape: the leading dimension of
/// <paramref name="input_shape"/> followed by the configured target shape.
/// </summary>
/// <param name="input_shape">Shape of the layer input.</param>
/// <exception cref="NotImplementedException">
/// When the leading dimension is anything other than -1.
/// </exception>
TensorShape compute_output_shape(TensorShape input_shape)
{
    // Only a leading dimension of -1 is handled so far.
    if (input_shape.dims[0] != -1)
        throw new NotImplementedException("");

    TensorShape leading_shape = input_shape.dims[0];
    return leading_shape.concatenate(args.TargetShape.dims);
}
} | |||
} |
@@ -1,6 +1,7 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.ManagedAPI | |||
@@ -36,7 +37,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
/// <param name="a"></param> | |||
/// <param name="b"></param> | |||
/// <returns></returns> | |||
// [AutoGraph] | |||
[AutoGraph] | |||
Tensor Mul(Tensor a, Tensor b) | |||
{ | |||
return a * b; | |||