Browse Source

Massive updates for TensorFlowOpLayer #652

tags/v0.30
Oceania2018 4 years ago
parent
commit
fafed7dd7d
19 changed files with 423 additions and 94 deletions
  1. +57
    -7
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  2. +44
    -0
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  3. +25
    -0
      src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs
  4. +38
    -0
      src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs
  5. +120
    -0
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Functions/c_api.function.cs
  7. +0
    -2
      src/TensorFlowNET.Core/Gradients/ITape.cs
  8. +7
    -5
      src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
  9. +28
    -10
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  10. +6
    -5
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs
  11. +20
    -45
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  12. +25
    -6
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  14. +2
    -0
      src/TensorFlowNET.Core/Operations/Operation.cs
  15. +28
    -6
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  16. +2
    -2
      src/TensorFlowNET.Core/tensorflow.cs
  17. +1
    -1
      src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs
  18. +14
    -3
      src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
  19. +2
    -1
      test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs

+ 57
- 7
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Framework.Models; using Tensorflow.Framework.Models;
using Tensorflow.Graphs; using Tensorflow.Graphs;
@@ -11,11 +12,34 @@ namespace Tensorflow.Functions
/// </summary> /// </summary>
public class ConcreteFunction : IDisposable public class ConcreteFunction : IDisposable
{ {
public string Name => _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
IntPtr _handle; IntPtr _handle;
FuncGraph func_graph;

public string Name
{
get
{
if (func_graph != null)
return func_graph.FuncName;

return _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
}
}
public Tensor[] Outputs; public Tensor[] Outputs;
public Type ReturnType;
public TensorSpec[] OutputStructure; public TensorSpec[] OutputStructure;


public ConcreteFunction(string name)
{
func_graph = new FuncGraph(name);
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
{
func_graph = graph;
}

public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
{ {
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
@@ -28,8 +52,8 @@ namespace Tensorflow.Functions


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers, _handle = graph.ToGraph(opers,
new Operation[] { input },
new Operation[] { output },
new[] { input },
new[] { output },
null); null);
} }
} }
@@ -48,8 +72,8 @@ namespace Tensorflow.Functions


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers, _handle = graph.ToGraph(opers,
new Operation[] { input },
new Operation[] { output.variant_tensor.op },
new[] { input },
new[] { output.variant_tensor },
null); null);
} }
} }
@@ -72,12 +96,38 @@ namespace Tensorflow.Functions


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = graph.ToGraph(opers, _handle = graph.ToGraph(opers,
new Operation[] { input1, input2, input3 },
new Operation[] { outputs.Item1.op, outputs.Item2.op },
new[] { input1, input2, input3 },
new[] { outputs.Item1, outputs.Item2 },
null); null);
} }
} }


public void ToGraph(Tensors inputs, Tensors outputs)
{
var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
_handle = func_graph.ToGraph(opers,
inputs,
outputs,
null);
}

public Tensors Invoke(Tensors inputs)
{
var forward_backward = SelectForwardAndBackwardFunctions(inputs, 1, tf.Context.executing_eagerly());
var (forward_function, args_with_tangents) = forward_backward.Forward();
Tensors flat_outputs = null;
if (tf.Context.executing_eagerly())
flat_outputs = forward_function.Call(args_with_tangents);
forward_backward.Record(flat_outputs);
return flat_outputs;
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);
return new ForwardBackwardCall(functions, args, tape_watching: true);
}

public void Dispose() public void Dispose()
{ {
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);


+ 44
- 0
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -0,0 +1,44 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
{
public class EagerDefinedFunction
{
public int _num_outputs;
public string Name => _func_graph.FuncName;

FuncGraph _func_graph;
public EagerDefinedFunction(string name, FuncGraph graph,
Tensors inputs, Tensors outputs,
Dictionary<string, string> attrs)
{
_num_outputs = outputs.Length;
var input_ops = inputs.Select(x => x.op).ToArray();
var operations = graph.get_operations().Where(x => !input_ops.Contains(x.op))
.Select(x => x as Operation).ToArray();
var output_names = new string[0];

_func_graph = new FuncGraph(graph, name, attrs);
_func_graph.ToGraph(operations, inputs, outputs, output_names);
}

public Tensors Call(Tensors args)
{
var results = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
_func_graph.FuncName,
args,
null,
_num_outputs);

return results;
}
}
}

+ 25
- 0
src/TensorFlowNET.Core/Functions/FirstOrderTapeGradientFunctions.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Graphs;

namespace Tensorflow.Functions
{
public class FirstOrderTapeGradientFunctions : TapeGradientFunctions
{
public FirstOrderTapeGradientFunctions(FuncGraph func_graph,
bool need_gradients_for_jvps) : base(func_graph,
need_gradients_for_jvps)
{

}

public override EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
{
var outputs = _func_graph.Outputs;
(_forward, _forward_graph, _backward, _forwardprop_output_indices, _num_forwardprop_outputs)
= BuildFunctionsForOutputs(outputs, inference_args);
return _forward;
}
}
}

+ 38
- 0
src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs View File

@@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Functions
{
/// <summary>
/// Holds the state of a function call between execution and recording.
/// </summary>
public class ForwardBackwardCall
{
TapeGradientFunctions _functions;
Tensors _inference_args;
Tensors _input_tangents;
bool _tape_watching;

public ForwardBackwardCall(TapeGradientFunctions functions,
Tensors inference_args,
bool tape_watching)
{
_functions = functions;
_inference_args = inference_args;
_tape_watching = tape_watching;
}

public (EagerDefinedFunction, Tensors) Forward()
{
var forward_function = _functions.Forward(_inference_args);
return (forward_function, _inference_args);
}

public void Record(Tensors flat_outputs)
{
if (_tape_watching && flat_outputs != null)
_functions.Record(flat_outputs, _inference_args);
}
}
}

+ 120
- 0
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -0,0 +1,120 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Graphs;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;

namespace Tensorflow.Functions
{
/// <summary>
/// Caches forward and backward functions compatible with eager gradients.
/// </summary>
public abstract class TapeGradientFunctions
{
string FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name";
string BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name";
string _FORWARD_PREFIX = "__forward_";
string _BACKWARD_PREFIX = "__backward_";
string _INFERENCE_PREFIX = "__inference_";

protected FuncGraph _func_graph;
protected EagerDefinedFunction _forward;
protected FuncGraph _forward_graph;
protected List<int> _forwardprop_output_indices;
protected int _num_forwardprop_outputs;
protected ConcreteFunction _backward;

public TapeGradientFunctions(FuncGraph func_graph,
bool need_gradients_for_jvps)
{
_func_graph = func_graph;
}

public EagerDefinedFunction Forward(Tensors inference_args)
{
return ForwardAndBackwardFunctions(inference_args);
}

/// <summary>
/// Record the function call operation.
/// </summary>
/// <param name="flat_outputs"></param>
/// <param name="inference_args"></param>
public void Record(Tensors flat_outputs, Tensors inference_args)
{
var (backward_function, to_record) = _wrap_backward_function(_forward_graph, _backward, flat_outputs);
tf.Runner.RecordGradient(_forward.Name, flat_outputs, new object[0], inference_args,
getBackwardFunction: () => backward_function);
}

(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs)
{
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
return new Tensor[0];

/*var gradients = ops.gradientFunctions[op_name](new EagerOperation
{
Name = op_name,
NumInputs = op_inputs.Length,
Inputs = op_inputs,
NumOutputs = op_outputs.Length,
Outputs = op_outputs,
SkipInputIndices = unneeded_gradients,
Attrs = attrs
}, output_grads);

return gradients;*/
};

return (_backward_function_wrapper, flat_outputs);
}

protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
BuildFunctionsForOutputs(Tensors outputs, Tensors inference_args)
{
var trainable_outputs = new List<Tensor>();
var trainable_indices = new List<int>();
foreach(var (index, output) in enumerate(outputs))
{
if (gradients_util.IsTrainable(output))
{
trainable_outputs.Add(output);
trainable_indices.Add(index);
}
}

var gradients_wrt_outputs = new List<Tensor>();
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}");
foreach (var output in trainable_outputs)
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
_func_graph.Inputs,
grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph);

tf.Context.restore_mode();

var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>();
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
backwards_graph.Inputs = gradients_wrt_outputs;
backwards_graph.Outputs = gradients_wrt_inputs;

var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);

var forward_function_attr = new Dictionary<string, string>();
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,
_func_graph.Inputs, _func_graph.Outputs, forward_function_attr);
return (forward_function, _func_graph, backward_function, null, 0);
}

public virtual EagerDefinedFunction ForwardAndBackwardFunctions(Tensors inference_args)
{
throw new NotImplementedException("");
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Functions/c_api.function.cs View File

@@ -47,6 +47,9 @@ namespace Tensorflow
string description, string description,
SafeStatusHandle status); SafeStatusHandle status);


[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FunctionSetAttrValueProto(IntPtr func, string attr_name, byte[] proto, int proto_len, SafeStatusHandle status);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_FunctionName(IntPtr func); public static extern IntPtr TF_FunctionName(IntPtr func);




+ 0
- 2
src/TensorFlowNET.Core/Gradients/ITape.cs View File

@@ -13,8 +13,6 @@ namespace Tensorflow.Gradients
void RecordOperation(string op_type, void RecordOperation(string op_type,
Tensor[] input_tensors, Tensor[] input_tensors,
TapeTensor[] output_tensors, TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter); Func<BackwardFunction> backward_function_getter);


void VariableAccessed(ResourceVariable variable); void VariableAccessed(ResourceVariable variable);


+ 7
- 5
src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using Tensorflow.Util; using Tensorflow.Util;
using static Tensorflow.tensorflow; using static Tensorflow.tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System.Linq;


namespace Tensorflow.Gradients namespace Tensorflow.Gradients
{ {
@@ -14,18 +15,19 @@ namespace Tensorflow.Gradients
public void RecordOperation(string op_type, public void RecordOperation(string op_type,
Tensor[] input_tensors, Tensor[] input_tensors,
TapeTensor[] output_tensors, TapeTensor[] output_tensors,
long[] input_tensor_id,
TF_DataType[] input_dtypes,
Func<BackwardFunction> backward_function_getter) Func<BackwardFunction> backward_function_getter)
{ {
if (!ShouldRecord(input_tensor_id, input_dtypes))
var input_ids = input_tensors.Select(x => x.Id).ToArray();
var input_dtypes = input_tensors.Select(x => x.dtype).ToArray();

if (!ShouldRecord(input_ids, input_dtypes))
{ {
return; return;
} }


long op_id = next_op_id_++; long op_id = next_op_id_++;
var ids = new List<long>(input_tensor_id.Length);
foreach (var i in input_tensor_id)
var ids = new List<long>(input_ids.Length);
foreach (var i in input_ids)
{ {
tensor_usage_[i]++; tensor_usage_[i]++;
ids.Add(i); ids.Add(i);


+ 28
- 10
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -17,6 +17,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Graphs;
using Tensorflow.Operations.ControlFlows; using Tensorflow.Operations.ControlFlows;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -39,7 +40,14 @@ namespace Tensorflow


// If src_graph is a _FuncGraph (i.e. a function body), gather it and all // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
// ancestor graphs. This is necessary for correctly handling captured values. // ancestor graphs. This is necessary for correctly handling captured values.
var func_graphs = new List<FuncGraph>();
var curr_graph = src_graph; var curr_graph = src_graph;
if (src_graph is FuncGraph func_graph)
{
func_graphs.append(func_graph);
curr_graph = func_graph.OuterGraph;
}


if (stop_gradients == null) if (stop_gradients == null)
stop_gradients = new Tensor[0]; stop_gradients = new Tensor[0];
@@ -84,7 +92,7 @@ namespace Tensorflow
var to_ops = ys.Select(x => x.op).ToList(); var to_ops = ys.Select(x => x.op).ToList();
var from_ops = xs.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList();
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
(reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs , xs);


// Add the initial gradients for the ys. // Add the initial gradients for the ys.
foreach (var (y, grad_y) in zip(ys, grad_ys)) foreach (var (y, grad_y) in zip(ys, grad_ys))
@@ -258,11 +266,8 @@ namespace Tensorflow
{ {
var new_grad_ys = new List<Tensor>(); var new_grad_ys = new List<Tensor>();


for (int i = 0; i < grad_ys.Length; i++)
foreach(var (i, (y, grad_y)) in enumerate(zip(ys, grad_ys)))
{ {
var grad_y = grad_ys[i];
var y = ys[i];

_maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops); _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops);


if (grad_y == null) if (grad_y == null)
@@ -272,8 +277,17 @@ namespace Tensorflow
var shape = array_ops.shape(y); var shape = array_ops.shape(y);
var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}");
var fill = gen_array_ops.fill(shape, constant); var fill = gen_array_ops.fill(shape, constant);
new_grad_ys.Add(fill);
new_grad_ys.append(fill);
continue;
} }

if (y.dtype.is_floating() || y.dtype.is_integer())
{

}

// Create a grad_y tensor in the name scope of the gradient.
new_grad_ys.append(array_ops.identity(grad_y, name: $"grad_ys_{i}"));
} }


return new_grad_ys.ToArray(); return new_grad_ys.ToArray();
@@ -294,7 +308,11 @@ namespace Tensorflow
/// <param name="colocate_gradients_with_ops"></param> /// <param name="colocate_gradients_with_ops"></param>
/// <param name="func_graphs"></param> /// <param name="func_graphs"></param>
/// <param name="xs"></param> /// <param name="xs"></param>
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, Tensor[] xs)
private static (Operation[], Dictionary<string, int>, ControlFlowState) _PendingCount(List<Operation> to_ops,
List<Operation> from_ops,
bool colocate_gradients_with_ops,
List<FuncGraph> func_graphs,
Tensor[] xs)
{ {
// Mark reachable ops from from_ops. // Mark reachable ops from from_ops.
var reached_ops = new List<Operation>(); var reached_ops = new List<Operation>();
@@ -511,7 +529,7 @@ namespace Tensorflow
/// <param name="from_ops"></param> /// <param name="from_ops"></param>
/// <param name="reached_ops"></param> /// <param name="reached_ops"></param>
/// <param name="func_graphs"></param> /// <param name="func_graphs"></param>
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs)
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<FuncGraph> func_graphs)
{ {
Queue<Operation> queue = new Queue<Operation>(from_ops); Queue<Operation> queue = new Queue<Operation>(from_ops);
while (queue.Count > 0) while (queue.Count > 0)
@@ -538,7 +556,7 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="t"></param> /// <param name="t"></param>
/// <param name="func_graphs"></param> /// <param name="func_graphs"></param>
private static Operation[] _Consumers(Tensor t, List<object> func_graphs)
private static Operation[] _Consumers(Tensor t, List<FuncGraph> func_graphs)
{ {
return t.consumers(); return t.consumers();
} }
@@ -647,7 +665,7 @@ namespace Tensorflow
} }
} }


private static bool IsTrainable(Tensor tensor)
public static bool IsTrainable(Tensor tensor)
{ {
var dtype = tensor.dtype.as_base_dtype(); var dtype = tensor.dtype.as_base_dtype();
return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64, return new TF_DataType[] { dtypes.float16, dtypes.float32, dtypes.float64,


+ 6
- 5
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -18,10 +18,11 @@ namespace Tensorflow.Graphs


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
var func_handle = graph.ToGraph(opers, var func_handle = graph.ToGraph(opers,
new Operation[] { input },
new Operation[] { output },
new[] { input },
new[] { output },
null); null);
} }


return (Tensor input) => return (Tensor input) =>
{ {
@@ -48,11 +49,11 @@ namespace Tensorflow.Graphs


var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
var func_handle = graph.ToGraph(opers, var func_handle = graph.ToGraph(opers,
new Operation[] { input1, input2 },
new Operation[] { output },
new[] { input1, input2 },
new[] { output },
null); null);
} }
return (Tensor a, Tensor b) => return (Tensor a, Tensor b) =>
{ {
var result = tf.Runner.TFE_Execute(tf.Context, var result = tf.Runner.TFE_Execute(tf.Context,


+ 20
- 45
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Functions;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Graphs namespace Tensorflow.Graphs
@@ -10,10 +11,10 @@ namespace Tensorflow.Graphs
[AllowChangingInputArguments] [AllowChangingInputArguments]
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
{ {
FuncGraph graph;
ConcreteFunction function;
Tensors originalInputs; Tensors originalInputs;
string func_name; string func_name;
static Dictionary<string, Func<Tensors, Tensors>> functions = new Dictionary<string, Func<Tensors, Tensors>>();
static Dictionary<string, ConcreteFunction> functions = new Dictionary<string, ConcreteFunction>();


public override void OnEntry(MethodExecutionArgs args) public override void OnEntry(MethodExecutionArgs args)
{ {
@@ -21,22 +22,24 @@ namespace Tensorflow.Graphs


if (functions.ContainsKey(func_name)) if (functions.ContainsKey(func_name))
{ {
function = functions[func_name];
if (args.Arguments[0] is Tensors tensor_inputs) if (args.Arguments[0] is Tensors tensor_inputs)
args.ReturnValue = functions[func_name](tensor_inputs.ToArray());
args.ReturnValue = ConvertReturnValue(function.Invoke(tensor_inputs));
else else
args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray());
args.ReturnValue = ConvertReturnValue(function.Invoke(args.Arguments.Select(x => x as Tensor).ToArray()));
args.FlowBehavior = FlowBehavior.Return; args.FlowBehavior = FlowBehavior.Return;
return; return;
} }


// make function as an Operation by autograph // make function as an Operation by autograph
graph = new FuncGraph(func_name);
// need to restore mode when exits
function = new ConcreteFunction(func_name);


// convert to Tensors // convert to Tensors
if (args.Arguments[0] is Tensors inputs) if (args.Arguments[0] is Tensors inputs)
{ {
originalInputs = inputs; originalInputs = inputs;
var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray();
var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape, name: "inputs")).ToArray();
args.Arguments[0] = new Tensors(new_inputs); args.Arguments[0] = new Tensors(new_inputs);
} }
else else
@@ -48,7 +51,7 @@ namespace Tensorflow.Graphs
if (args.Arguments[i] is EagerTensor tensor) if (args.Arguments[i] is EagerTensor tensor)
{ {
originalInputs[i] = tensor; originalInputs[i] = tensor;
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape);
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs");
} }
} }
} }
@@ -56,58 +59,30 @@ namespace Tensorflow.Graphs


public override void OnExit(MethodExecutionArgs args) public override void OnExit(MethodExecutionArgs args)
{ {
var returnValue = mark_as_return(args.ReturnValue as Tensors);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();

if (args.ReturnValue is Tensors outputs) if (args.ReturnValue is Tensors outputs)
{ {
if (args.Arguments[0] is Tensors inputs) if (args.Arguments[0] is Tensors inputs)
{
graph.ToGraph(opers,
inputs.Select(x => x.op).ToArray(),
outputs.Select(x => x.op).ToArray(),
null);
}
function.ToGraph(inputs, outputs);
else else
{
graph.ToGraph(opers,
args.Arguments.Select(x => (x as Tensor).op).ToArray(),
outputs.Select(x => x.op).ToArray(),
null);
}
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), outputs);
} }
else else
{
graph.ToGraph(opers,
args.Arguments.Select(x => (x as Tensor).op).ToArray(),
new Operation[] { (args.ReturnValue as Tensor).op },
null);
}

graph.Dispose();
function.ToGraph(args.Arguments.Select(x => x as Tensor).ToArray(), args.ReturnValue as Tensor);


Func<Tensors, Tensors> function = (x) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
x,
null,
1);

return result[0];
};
// cache function. // cache function.
function.ReturnType = args.ReturnValue.GetType();
functions[func_name] = function; functions[func_name] = function;


// run function // run function
args.ReturnValue = function(originalInputs);
args.ReturnValue = ConvertReturnValue(function.Invoke(originalInputs));
} }


Tensor mark_as_return(Tensor tensor)
object ConvertReturnValue(Tensors tensors)
{ {
return array_ops.identity(tensor);
if (function.ReturnType == typeof(Tensor))
return (Tensor)tensors;
else
return tensors;
} }
} }
} }

+ 25
- 6
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -10,13 +10,18 @@ namespace Tensorflow.Graphs
/// </summary> /// </summary>
public class FuncGraph : Graph public class FuncGraph : Graph
{ {
List<Operation> inputs;
List<Operation> outputs;
Graph outer_graph; Graph outer_graph;
public Graph OuterGraph => outer_graph;

string func_name; string func_name;

// _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle));
IntPtr func_handle; IntPtr func_handle;
public string FuncName => func_name; public string FuncName => func_name;


public Tensors Inputs { get; set; }
public Tensors Outputs { get; set; }

/// <summary> /// <summary>
/// Construct a new FuncGraph. /// Construct a new FuncGraph.
/// </summary> /// </summary>
@@ -29,8 +34,17 @@ namespace Tensorflow.Graphs
as_default(); as_default();
} }


public FuncGraph(IntPtr handle, string name)
{
outer_graph = ops.get_default_graph();
func_name = name;

tf.Context.graph_mode();
as_default();
}

public IntPtr ToGraph(Operation[] opers, public IntPtr ToGraph(Operation[] opers,
Operation[] inputs, Operation[] outputs,
Tensor[] inputs, Tensor[] outputs,
string[] output_names) string[] output_names)
{ {
using var status = new Status(); using var status = new Status();
@@ -40,9 +54,9 @@ namespace Tensorflow.Graphs
opers.Length, opers.Length,
opers.Select(x => (IntPtr)x).ToArray(), opers.Select(x => (IntPtr)x).ToArray(),
inputs.Length, inputs.Length,
inputs.Select(x => new TF_Output(x, 0)).ToArray(),
inputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
outputs.Length, outputs.Length,
outputs.Select(x => new TF_Output(x, 0)).ToArray(),
outputs.Select(x => new TF_Output(x.op, 0)).ToArray(),
output_names == null || output_names.Length == 0 ? null : output_names, output_names == null || output_names.Length == 0 ? null : output_names,
IntPtr.Zero, IntPtr.Zero,
null, null,
@@ -57,13 +71,18 @@ namespace Tensorflow.Graphs


func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle));


Inputs = inputs;
// mark_as_return
Outputs = outputs.Select(x => array_ops.identity(x)).ToArray();

tf.Context.restore_mode();

return func_handle; return func_handle;
} }


protected override void DisposeManagedResources() protected override void DisposeManagedResources()
{ {
base.DisposeManagedResources(); base.DisposeManagedResources();
tf.Context.restore_mode();
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -69,7 +69,7 @@ namespace Tensorflow
throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\"."); throw new ValueError($"Could not find operation \"{operName}\" inside graph \"{_graph_key}\".");


var defaultKey = tf.get_default_graph().graph_key; var defaultKey = tf.get_default_graph().graph_key;
if (graph_key != defaultKey)
if (tf.get_default_graph().GetType().Name == "Graph" && graph_key != defaultKey)
{ {
//Console.WriteLine($"Current graph is not default graph."); //Console.WriteLine($"Current graph is not default graph.");
throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}"); throw new RuntimeError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}");


+ 2
- 0
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -218,6 +218,8 @@ namespace Tensorflow
{ {
case nameof(Int32): case nameof(Int32):
return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray(); return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray();
case nameof(Int64):
return x.List.I.Select(x => (T)Convert.ChangeType(x, typeof(T))).ToArray();
default: default:
return null; return null;
} }


+ 28
- 6
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -235,6 +235,14 @@ namespace Tensorflow
} }


var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input });
if (tf.Runner.MustRecordGradient())
{
tf.Runner.RecordGradient("Identity", _op.inputs, new object[]
{
"T", _op.get_attr<TF_DataType>("T")
}, _op.outputs);
}


return _op.output; return _op.output;
} }
@@ -632,8 +640,8 @@ namespace Tensorflow
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0,
int shrink_axis_mask = 0, string name = null) int shrink_axis_mask = 0, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new
{ {
shape, shape,
begin, begin,
@@ -645,8 +653,8 @@ namespace Tensorflow
ellipsis_mask, ellipsis_mask,
new_axis_mask, new_axis_mask,
shrink_axis_mask shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StridedSliceGrad", name, "StridedSliceGrad", name,
null, null,
shape, begin, end, strides, dy, shape, begin, end, strides, dy,
@@ -654,8 +662,22 @@ namespace Tensorflow
"end_mask", end_mask, "end_mask", end_mask,
"ellipsis_mask", ellipsis_mask, "ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask, "new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
shape, begin, end, strides, dy);
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T"),
"Index", op.get_attr<TF_DataType>("Index"),
"begin_mask", op.get_attr<long>("begin_mask"),
"end_mask", op.get_attr<long>("end_mask"),
"ellipsis_mask", op.get_attr<long>("ellipsis_mask"),
"new_axis_mask", op.get_attr<long>("new_axis_mask"),
"shrink_axis_mask", op.get_attr<long>("shrink_axis_mask")
};
tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs);
},
new Tensors(shape, begin, end, strides, dy));


/// <summary> /// <summary>
/// Removes dimensions of size 1 from the shape of a tensor. /// Removes dimensions of size 1 from the shape of a tensor.


+ 2
- 2
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -23,6 +23,8 @@ using Tensorflow.Gradients;


namespace Tensorflow namespace Tensorflow
{ {
public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients);

public partial class tensorflow : ITensorFlowObject public partial class tensorflow : ITensorFlowObject
{ {
public TF_DataType byte8 = TF_DataType.TF_UINT8; public TF_DataType byte8 = TF_DataType.TF_UINT8;
@@ -37,8 +39,6 @@ namespace Tensorflow
public TF_DataType chars = TF_DataType.TF_STRING; public TF_DataType chars = TF_DataType.TF_STRING;
public TF_DataType @string = TF_DataType.TF_STRING; public TF_DataType @string = TF_DataType.TF_STRING;


public delegate Tensor[] BackwardFunction(Tensor[] grads, long[] unneeded_gradients);

public Status Status; public Status Status;
public OpDefLibrary OpDefLib; public OpDefLibrary OpDefLib;
public Context Context; public Context Context;


+ 1
- 1
src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs View File

@@ -56,7 +56,7 @@ namespace Tensorflow.Keras.Engine


// Record the gradient because custom-made ops don't go through the // Record the gradient because custom-made ops don't go through the
// code-gen'd eager call path // code-gen'd eager call path
var op_type = op.node_def.Name;
var op_type = op.node_def.Op;


tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs);




+ 14
- 3
src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs View File

@@ -1,9 +1,9 @@
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using static Tensorflow.KerasApi;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using System.Collections.Generic; using System.Collections.Generic;
using System; using System;
using System.Linq;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers
{ {
@@ -26,9 +26,20 @@ namespace Tensorflow.Keras.Layers


var result = array_ops.reshape(inputs, shape.ToArray()); var result = array_ops.reshape(inputs, shape.ToArray());
if (!tf.Context.executing_eagerly()) if (!tf.Context.executing_eagerly())
// result = result.set_shape(compute_output_shape(inputs.shape));
throw new NotImplementedException("");
result.set_shape(compute_output_shape(inputs.shape));
return result; return result;
} }

TensorShape compute_output_shape(TensorShape input_shape)
{
if (input_shape.dims[0] == -1)
{
input_shape = input_shape.dims[0];
var output_shape = input_shape.concatenate(args.TargetShape.dims);
return output_shape;
}
else
throw new NotImplementedException("");
}
} }
} }

+ 2
- 1
test/TensorFlowNET.UnitTest/ManagedAPI/FunctionApiTest.cs View File

@@ -1,6 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using System; using System;
using Tensorflow; using Tensorflow;
using Tensorflow.Graphs;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.ManagedAPI namespace TensorFlowNET.UnitTest.ManagedAPI
@@ -36,7 +37,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
/// <param name="a"></param> /// <param name="a"></param>
/// <param name="b"></param> /// <param name="b"></param>
/// <returns></returns> /// <returns></returns>
// [AutoGraph]
[AutoGraph]
Tensor Mul(Tensor a, Tensor b) Tensor Mul(Tensor a, Tensor b)
{ {
return a * b; return a * b;


Loading…
Cancel
Save