Browse Source

Add FunctionCallOptions.

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
e19e59b3dd
10 changed files with 62 additions and 20 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +20
    -0
      src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  5. +18
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  6. +8
    -8
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  7. +2
    -2
      src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
  8. +5
    -3
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  9. +4
    -4
      src/TensorFlowNET.Core/Graphs/Graph.cs
  10. +1
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs

+ 2
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -35,6 +35,7 @@ namespace Tensorflow.Contexts
public string ScopeName { get; set; } = "";
bool initialized = false;
ContextSwitchStack context_switches;
public FunctionCallOptions FunctionCallOptions { get; }

public SafeContextHandle Handle { get; }

@@ -44,6 +45,7 @@ namespace Tensorflow.Contexts
status.Check(true);
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
initialized = true;
FunctionCallOptions = new FunctionCallOptions();
}

/// <summary>


+ 20
- 0
src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Google.Protobuf;
using Google.Protobuf.Collections;

namespace Tensorflow.Contexts
{
public class FunctionCallOptions
{
public string config_proto_serialized()
{
var config = new ConfigProto
{
AllowSoftPlacement = true,
};
return config.ToByteString().ToStringUtf8();
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -371,7 +371,7 @@ namespace Tensorflow.Eager
switch (type)
{
case TF_AttrType.TF_ATTR_STRING:
c_api.TFE_OpSetAttrString(op, key, value.ToString(), (uint)value.ToString().Length);
c_api.TFE_OpSetAttrString(op, key, value.ToString(), (ulong)value.ToString().Length);
break;
case TF_AttrType.TF_ATTR_TYPE:
c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value);


+ 1
- 1
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -241,7 +241,7 @@ namespace Tensorflow
/// <param name="value">const void*</param>
/// <param name="length">size_t</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length);
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, ulong length);

[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values);


+ 18
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -14,6 +14,7 @@ namespace Tensorflow.Functions
{
IntPtr _handle;
FuncGraph func_graph;
public Tensor[] CapturedInputs => func_graph.external_captures();

public string Name
{
@@ -38,6 +39,8 @@ namespace Tensorflow.Functions
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
{
func_graph = graph;

ToGraph(graph.Inputs, graph.Outputs);
}

public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
@@ -124,6 +127,21 @@ namespace Tensorflow.Functions
return flat_outputs;
}

public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs)
{
var new_args = new List<Tensor>();
new_args.AddRange(args);
new_args.AddRange(captured_inputs);
args = new_args.ToArray();

var attrs = new object[]
{
"executor_type", "",
"config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
};
return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs);
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);


+ 8
- 8
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -48,7 +48,7 @@ namespace Tensorflow.Functions
getBackwardFunction: () => backward_function);
}

(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs)
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
@@ -61,10 +61,11 @@ namespace Tensorflow.Functions
processed_args.add(arg);
input_index += 1;
}
return output_grads;// backward.Invoke(processed_args.ToArray());

return backward.CallFlat(processed_args.ToArray(), outputs);
};

return (_backward_function_wrapper, flat_outputs);
return (_backward_function_wrapper, outputs);
}

protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int)
@@ -82,7 +83,7 @@ namespace Tensorflow.Functions
}

var gradients_wrt_outputs = new List<Tensor>();
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}");
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}");
foreach (var output in trainable_outputs)
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape));
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(),
@@ -90,16 +91,15 @@ namespace Tensorflow.Functions
grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph);

tf.Context.restore_mode();

var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}";
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>();
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
gradients_wrt_outputs.append(backwards_graph.internal_captures());
backwards_graph.Inputs = gradients_wrt_outputs;
backwards_graph.Outputs = gradients_wrt_inputs;

var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr);
var forward_function_attr = new Dictionary<string, string>();
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name;
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph,


+ 2
- 2
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs View File

@@ -49,14 +49,14 @@ namespace Tensorflow
RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
(oper, out_grads) =>
{
tf.Logger.Debug($"Caculate Gradient: {m.Name}");
tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}");
var results = g.InvokeMember(m.Name,
BindingFlags.InvokeMethod,
null,
null,
args: new object[] { oper, out_grads }) as Tensor[];
foreach (var result in results.Where(x => x != null))
tf.Logger.Debug($"{result.TensorShape}");
tf.Logger.Debug($"Gradient: {result.name} {result.TensorShape}");
return results;
}
);


+ 5
- 3
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -26,7 +26,9 @@ namespace Tensorflow.Graphs
public Tensors Outputs { get; set; }
public Dictionary<string, string> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
// new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
// public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray();

/// <summary>
/// Construct a new FuncGraph.
/// </summary>
@@ -129,7 +131,7 @@ namespace Tensorflow.Graphs
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
{
Tensor placeholder = null;
if (!_captures.ContainsKey(tensor.Id))
if (!_captures.Contains(tensor.Id))
{
placeholder = _create_substitute_placeholder(tensor,
name: name,
@@ -139,7 +141,7 @@ namespace Tensorflow.Graphs
}
else
{
placeholder = _captures[tensor.Id].Item1;
placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2;
}

BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>


+ 4
- 4
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -557,16 +557,16 @@ namespace Tensorflow

public Tensor[] external_captures()
{
Tensor[] captures = new Tensor[this._captures.Count];
ICollection inner = this._captures.Keys; // c[0]
Tensor[] captures = new Tensor[_captures.Count];
ICollection inner = _captures.Keys; // c[0]
inner.CopyTo(captures, 0);
return captures;
}

public Tensor[] internal_captures()
{
Tensor[] captures = new Tensor[this._captures.Count];
ICollection inner = this._captures.Values; // c[1]
Tensor[] captures = new Tensor[_captures.Count];
ICollection inner = _captures.Values; // c[1]
inner.CopyTo(captures, 0);
return captures;
}


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -340,7 +340,7 @@ namespace Tensorflow.Keras.Engine
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}");
var outputs = node.Layer.Apply(layer_inputs, is_training: training);
foreach (var output in outputs.Where(x => x != null))
tf.Logger.Debug($"{output.TensorShape}");
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}");
// Update tensor_dict for next input
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));


Loading…
Cancel
Save