Browse Source

fix Keras Functional Inputs #624

tags/v0.30
Oceania2018 5 years ago
parent
commit
b6f155c71d
15 changed files with 161 additions and 55 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  3. +5
    -0
      src/TensorFlowNET.Core/Eager/IEagerRunner.cs
  4. +0
    -5
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  5. +30
    -4
      src/TensorFlowNET.Core/Graphs/AutoGraph.cs
  6. +7
    -7
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  7. +12
    -1
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  8. +25
    -21
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  9. +1
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  10. +32
    -2
      src/TensorFlowNET.Core/Keras/Engine/Node.cs
  11. +31
    -5
      src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
  12. +12
    -9
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  13. +1
    -0
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  14. +1
    -0
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  15. +1
    -0
      src/TensorFlowNET.Core/ops.name_scope.cs

+ 2
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using System;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;

@@ -88,6 +89,7 @@ namespace Tensorflow.Contexts
context_switches.Pop();
}

[DebuggerStepThrough]
public Tensor RunInAutoMode(Func<Tensor> graphAction, Func<Tensor> eagerAction, params Tensor[] tensors)
{
var shouldRunInEager = executing_eagerly()


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Eager
{
public partial class EagerRunner
{
bool RecordGradient(string op_name,
public bool RecordGradient(string op_name,
Tensor[] inputs,
object[] attrs,
Tensor[] results)


+ 5
- 0
src/TensorFlowNET.Core/Eager/IEagerRunner.cs View File

@@ -35,5 +35,10 @@ namespace Tensorflow.Eager
Tensor[] target,
Tensor[] sources,
Tensor[] output_gradients);

bool RecordGradient(string op_name,
Tensor[] inputs,
object[] attrs,
Tensor[] results);
}
}

+ 0
- 5
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -19,12 +19,9 @@ namespace Tensorflow.Functions
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

tf.compat.v1.disable_eager_execution();

// IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
{
graph.as_default();
var input = tf.placeholder(dtype);
var output = func(input);

@@ -34,8 +31,6 @@ namespace Tensorflow.Functions
new Operation[] { output },
null);
}

tf.enable_eager_execution();
}

public Tensor Execute(Tensor arg)


+ 30
- 4
src/TensorFlowNET.Core/Graphs/AutoGraph.cs View File

@@ -9,14 +9,42 @@ namespace Tensorflow.Graphs
{
public class AutoGraph
{
public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";

// IntPtr func_handle;
using (var graph = new FuncGraph(func_name))
{
var input = tf.placeholder(tf.int32);
var output = func(input);

var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
var func_handle = graph.ToGraph(opers,
new Operation[] { input },
new Operation[] { output },
null);
}

return (Tensor input) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,
tf.Context.DeviceName,
func_name,
new[] { input },
null,
1);
return result[0];
};
}

public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func)
{
string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}";
tf.compat.v1.disable_eager_execution();
// IntPtr func_handle;
using(var graph = new FuncGraph(func_name))
{
graph.as_default();
var input1 = tf.placeholder(tf.int32);
var input2 = tf.placeholder(tf.int32);
var output = func(input1, input2);
@@ -28,8 +56,6 @@ namespace Tensorflow.Graphs
null);
}

tf.enable_eager_execution();

return (Tensor a, Tensor b) =>
{
var result = tf.Runner.TFE_Execute(tf.Context,


+ 7
- 7
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -1,9 +1,10 @@
/*using MethodBoundaryAspect.Fody.Attributes;
using MethodBoundaryAspect.Fody.Attributes;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
@@ -18,7 +19,10 @@ namespace Tensorflow.Graphs

public override void OnEntry(MethodExecutionArgs args)
{
func_name = $"autograph_{args.Instance}.{args.Method.Name}";
if (args.Instance is TensorFlowOpLayer op)
func_name = $"autograph_{op.OpType}.{args.Method.Name}";
else
func_name = $"autograph_{args.Instance}.{args.Method.Name}";

if (functions.ContainsKey(func_name))
{
@@ -27,11 +31,8 @@ namespace Tensorflow.Graphs
return;
}
tf.compat.v1.disable_eager_execution();

// make function as an Operation by autograph
graph = new FuncGraph(func_name);
graph.as_default();

originalInputs = new Tensor[args.Arguments.Length];
// convert args to placeholder
@@ -57,7 +58,6 @@ namespace Tensorflow.Graphs
null);

graph.Dispose();
tf.enable_eager_execution();

Func<Tensor[], Tensor> function = (x) =>
{
@@ -77,4 +77,4 @@ namespace Tensorflow.Graphs
args.ReturnValue = function(originalInputs);
}
}
}*/
}

+ 12
- 1
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Graphs
Graph outer_graph;
string func_name;
IntPtr func_handle;
public string FuncName => c_api.StringPiece(c_api.TF_FunctionName(func_handle));
public string FuncName => func_name;

/// <summary>
/// Construct a new FuncGraph.
@@ -27,6 +27,9 @@ namespace Tensorflow.Graphs
{
outer_graph = ops.get_default_graph();
func_name = name;

tf.Context.graph_mode();
as_default();
}

public IntPtr ToGraph(Operation[] opers,
@@ -55,7 +58,15 @@ namespace Tensorflow.Graphs
c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle);
status.Check(true);

func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle));

return func_handle;
}

protected override void DisposeManagedResources()
{
base.DisposeManagedResources();
tf.Context.restore_mode();
}
}
}

+ 25
- 21
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -174,19 +174,19 @@ namespace Tensorflow.Keras.Engine

// Build a dict {depth: list of nodes with this depth}
var nodes_by_depth = new Dictionary<int, List<Node>>();
foreach (var node in nodes_depths)
foreach (var (node, depth) in enumerate(nodes_depths))
{
if (!nodes_by_depth.ContainsKey(node.Value))
nodes_by_depth[node.Value] = new List<Node>();
nodes_by_depth[node.Value].append(node.Key);
if (!nodes_by_depth.ContainsKey(depth))
nodes_by_depth[depth] = new List<Node>();
nodes_by_depth[depth].append(node);
}

var layers_by_depth = new Dictionary<int, List<Layer>>();
foreach (var layer in layers_depths)
foreach (var (layer, depth) in enumerate(layers_depths))
{
if (!layers_by_depth.ContainsKey(layer.Value))
layers_by_depth[layer.Value] = new List<Layer>();
layers_by_depth[layer.Value].append(layer.Key);
if (!layers_by_depth.ContainsKey(depth))
layers_by_depth[depth] = new List<Layer>();
layers_by_depth[depth].append(layer);
}

// Get sorted list of layer depths.
@@ -256,16 +256,21 @@ namespace Tensorflow.Keras.Engine

// Propagate to all previous tensors connected to this node.
nodes_in_progress.Add(node);
foreach (var k_tensor in node.KerasInputs)
BuildMapHelper(k_tensor,
finished_nodes,
nodes_in_progress,
nodes_in_decreasing_depth,
layer_indices);
if (!node.IsInput)
{
foreach (var k_tensor in node.KerasInputs)
{
BuildMapHelper(k_tensor,
finished_nodes,
nodes_in_progress,
nodes_in_decreasing_depth,
layer_indices);
}
}

finished_nodes.Add(node);
nodes_in_progress.Remove(node);
nodes_in_decreasing_depth.Insert(nodes_in_decreasing_depth.Count, node);
nodes_in_decreasing_depth.append(node);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
@@ -282,12 +287,12 @@ namespace Tensorflow.Keras.Engine
input_t.KerasMask = masks[i];
}

var tensor_dict = new Dictionary<int, Tensor[]>();
var tensor_dict = new Dictionary<int, Queue<Tensor>>();
foreach (var (x, y) in zip(this.inputs, inputs))
{
var y1 = conform_to_reference_input(y, x);
var x_id = x.GetHashCode();
tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1).ToArray();
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1));
}

var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
@@ -301,14 +306,13 @@ namespace Tensorflow.Keras.Engine
if (node.IsInput)
continue;

var layer_inputs = new Tensors(tensor_dict[node.FlatInputIds[0]]);
tensor_dict[node.FlatInputIds[0]] = new Tensor[0];
var layer_inputs = node.MapArguments(tensor_dict);

var outputs = node.Layer.Apply(layer_inputs, is_training: training);
// Update tensor_dict.
// Update tensor_dict for next input
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y).ToArray();
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));
}
}



+ 1
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -89,6 +89,7 @@ namespace Tensorflow.Keras.Engine

ThreadLocal<CallContext> callContext;
public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors;

public Layer(LayerArgs args)
{


+ 32
- 2
src/TensorFlowNET.Core/Keras/Engine/Node.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
@@ -35,7 +36,7 @@ namespace Tensorflow.Keras.Engine

public int[] node_indices;
public int[] tensor_indices;
public Tensors input_tensors;
public Tensors input_tensors => args.InputTensors;
public Tensors Outputs => args.Outputs;
public TensorShape[] input_shapes;
public TensorShape[] output_shapes;
@@ -44,7 +45,8 @@ namespace Tensorflow.Keras.Engine
public bool IsInput => args.InputTensors == null;
public int[] FlatInputIds { get; set; }
public int[] FlatOutputIds { get; set; }

bool _single_positional_tensor_passed => KerasInputs.Count() == 1;
Dictionary<int, int> _keras_inputs_ids_and_indices = new Dictionary<int, int>();
public Node[] ParentNodes
{
get
@@ -68,6 +70,9 @@ namespace Tensorflow.Keras.Engine
if (args.InputTensors != null)
KerasInputs.AddRange(args.InputTensors);

foreach(var(i, ele) in enumerate(KerasInputs))
_keras_inputs_ids_and_indices[i] = ele.GetHashCode();

// Wire up Node to Layers.
layer.InboundNodes.Add(this);
foreach (var kt in KerasInputs)
@@ -88,5 +93,30 @@ namespace Tensorflow.Keras.Engine
FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray();
FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray();
}

/// <summary>
/// Maps Keras Tensors to computed Tensors using `tensor_dict`.
/// </summary>
/// <param name="tensor_dict"></param>
/// <returns></returns>
public Tensors MapArguments(Dictionary<int, Queue<Tensor>> tensor_dict)
{
if (_single_positional_tensor_passed)
{
var kt_id = _keras_inputs_ids_and_indices[0];
return tensor_dict[kt_id].Dequeue();
}
else
{
var flat_arguments = KerasInputs.Select(x => x).ToArray();
foreach (var (kt_index, kt_id) in enumerate(_keras_inputs_ids_and_indices))
flat_arguments[kt_index] = tensor_dict[kt_id].Dequeue();

return flat_arguments;
}
}

public override string ToString()
=> $"{Layer.Name}, {KerasInputs.Count} inputs: {string.Join(",", KerasInputs.Select(x => x.name))}";
}
}

+ 31
- 5
src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs View File

@@ -1,6 +1,8 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Graphs;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;

@@ -9,7 +11,10 @@ namespace Tensorflow.Keras.Engine
public class TensorFlowOpLayer : Layer
{
TensorFlowOpLayerArgs args;
Dictionary<int, NDArray> constants => args.Constants;
NodeDef node_def => args.NodeDef;
static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_";
public string OpType => node_def.Op;

public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
: base(new LayerArgs
@@ -26,13 +31,34 @@ namespace Tensorflow.Keras.Engine

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
if (tf.Context.executing_eagerly())
return _defun_call(inputs);
return MakOp(inputs);
}

// [AutoGraph]
Tensors MakOp(Tensors inputs)
[AutoGraph]
Tensor _defun_call(Tensor inputs)
=> MakOp(inputs);

Tensor MakOp(Tensor inputs)
{
return inputs;
}
foreach (var (index, constant) in enumerate(constants))
{

}

var graph = inputs.graph;
var (c_op, c_op_desc) = ops._create_c_op(graph, node_def, new[] { inputs }, new Operation[0]);
var op = graph._create_op_from_tf_operation(c_op);
op._control_flow_post_processing();

// Record the gradient because custom-made ops don't go through the
// code-gen'd eager call path
var op_type = op.node_def.Name;

tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs);

return op.output;
}
}
}

+ 12
- 9
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -409,15 +409,18 @@ namespace Tensorflow.Operations
}

public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("LeakyRelu", name: name, args: new
{
features,
alpha
});

return _op.output;
}
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("LeakyRelu", name: name,
args: new {
features,
alpha
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"LeakyRelu", name,
null,
features,
"alpha", alpha).FirstOrDefault(),
features);

public static Tensor max_pool(Tensor input,
int[] ksize,


+ 1
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -79,6 +79,7 @@ https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.11.4" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="NumSharp.Lite" Version="0.1.9" />
<PackageReference Include="Protobuf.Text" Version="0.4.0" />
</ItemGroup>


+ 1
- 0
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -22,6 +22,7 @@ namespace Tensorflow
public TF_DataType dtype => items.First().dtype;
public TensorShape shape => items.First().TensorShape;
public int rank => items.First().rank;
public Graph graph => items.First().graph;
public bool IsEagerTensor => items.First().IsEagerTensor;

public Tensor this[int index] => items[index];


+ 1
- 0
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -46,6 +46,7 @@ namespace Tensorflow
_values = values;
}

[DebuggerStepThrough]
public void __enter__()
{
if (tf.Context.executing_eagerly())


Loading…
Cancel
Save