From 48a11d47100b8f63f3347f6d78cff27214e6a7ee Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Tue, 12 Mar 2019 23:07:49 -0500
Subject: [PATCH] CondContext, BatchNormalization.

---
Review notes (not part of the commit message):
- This copy was rebuilt from a whitespace-collapsed export. Indentation, blank
  context lines, XML doc tags and generic type arguments (List<string>,
  Dictionary<string, Tensor>, Stack<IControlFlowContext>,
  List<ITensorOrOperation>) were inferred from hunk line counts and from how
  the values are used; verify against the repository before applying.
- CondContext now stores `branch` in `_branch`; the field was declared but
  never assigned. The diffstat and the CondContext hunk header were bumped by
  one line, so the CondContext blob id on its index line is stale.

 src/TensorFlowNET.Core/APIs/tf.nn.cs          | 14 ++++
 .../Framework/smart_module.cs                 |  5 +-
 .../Graphs/Graph.Control.cs                   |  6 +-
 .../Graphs/_ControlDependenciesController.cs  |  2 +-
 .../Keras/Layers/BatchNormalization.cs        | 21 ++++-
 .../Keras/Utils/tf_utils.cs                   |  5 +-
 .../Operations/ControlFlows/CondContext.cs    | 77 +++++++++++++++++++
 .../ControlFlows/ControlFlowContext.cs        | 46 +++++++++++
 .../ControlFlows/IControlFlowContext.cs       | 10 +++
 .../Operations/ControlFlows/WhileContext.cs   | 10 +++
 .../Operations/NnOps/gen_nn_ops.cs            | 25 ++++++
 .../Operations/Operation.Control.cs           |  8 ++
 .../Operations/control_flow_ops.py.cs         | 23 +++++-
 .../Operations/nn_impl.py.cs                  | 32 ++++++++
 src/TensorFlowNET.Core/Tensors/tensor_util.cs |  3 +
 src/TensorFlowNET.Core/ops.GraphKeys.cs       |  3 +
 16 files changed, 277 insertions(+), 13 deletions(-)
 create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
 create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
 create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
 create mode 100644 src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs

diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 65ad45b9..44203906 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -26,6 +26,20 @@ namespace Tensorflow
                 name: name);

             public static IActivation relu => new relu();
+
+            public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
+                RefVariable scale,
+                RefVariable offset,
+                Tensor mean = null,
+                Tensor variance = null,
+                float epsilon = 0.001f,
+                string data_format = "NHWC",
+                bool is_training = true,
+                string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
+                    epsilon: epsilon,
+                    data_format: data_format,
+                    is_training: is_training,
+                    name: name);
         }
     }
 }
diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs
index 2ba80cbc..ea5bf790 100644
--- a/src/TensorFlowNET.Core/Framework/smart_module.cs
+++ b/src/TensorFlowNET.Core/Framework/smart_module.cs
@@ -6,7 +6,10 @@ namespace Tensorflow.Framework
 {
     public class smart_module
     {
-        public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
+        public static object smart_cond(Tensor pred,
+            Func<(Tensor, Tensor, Tensor)> true_fn = null,
+            Func<(Tensor, Tensor, Tensor)> false_fn = null,
+            string name = null)
         {
             return control_flow_ops.cond(pred,
                 true_fn: true_fn,
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index c9e3be84..a1977968 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -8,7 +8,7 @@ namespace Tensorflow
 {
     public partial class Graph
     {
-        public Context _control_flow_context;
+        public IControlFlowContext _control_flow_context;

         private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>();
         public Queue<_ControlDependenciesController> _control_dependencies_stack
@@ -72,7 +72,7 @@ namespace Tensorflow
         /// Returns the current control flow context.
         /// </summary>
         /// <returns>A context object.</returns>
-        public Context _get_control_flow_context()
+        public IControlFlowContext _get_control_flow_context()
         {
             return _control_flow_context;
         }
@@ -81,7 +81,7 @@ namespace Tensorflow
         /// Sets the current control flow context.
         /// </summary>
         /// <param name="ctx">a context object.</param>
-        public void _set_control_flow_context(Context ctx)
+        public void _set_control_flow_context(IControlFlowContext ctx)
         {
             _control_flow_context = ctx;
         }
diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
index f1ddcb44..3887d2a1 100644
--- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
+++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
         private List<ITensorOrOperation> _seen_nodes;
         private Queue<_ControlDependenciesController> _old_stack;
         private bool _new_stack;
-        private Context _old_control_flow_context;
+        private IControlFlowContext _old_control_flow_context;

         public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();

diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index 8f82983e..1223e350 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -142,14 +142,27 @@ namespace Tensorflow.Keras.Layers
             var beta = this.beta;
             var gamma = this.gamma;

-            Action _fused_batch_norm_training = () =>
+            Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () =>
             {
-
+                return tf.nn.fused_batch_norm(
+                    inputs,
+                    gamma,
+                    beta,
+                    epsilon: epsilon,
+                    data_format: _data_format);
             };

-            Action _fused_batch_norm_inference = () =>
+            Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () =>
             {
-
+                return tf.nn.fused_batch_norm(
+                    inputs,
+                    gamma,
+                    beta,
+                    mean: moving_mean,
+                    variance: moving_variance,
+                    epsilon: epsilon,
+                    is_training: false,
+                    data_format: _data_format);
             };

             tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
diff --git a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
index 9a7d5ea1..4e155493 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
@@ -18,7 +18,10 @@ namespace Tensorflow.Keras.Utils
             return true;
         }

-        public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
+        public static object smart_cond(Tensor pred,
+            Func<(Tensor, Tensor, Tensor)> true_fn = null,
+            Func<(Tensor, Tensor, Tensor)> false_fn = null,
+            string name = null)
         {
             return smart_module.smart_cond(pred,
                 true_fn: true_fn,
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
new file mode 100644
index 00000000..3c233e8d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -0,0 +1,77 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    /// <summary>
+    /// The context for the conditional construct.
+    /// </summary>
+    public class CondContext : ControlFlowContext
+    {
+        private string _name;
+        /// <summary>
+        /// The boolean tensor for the cond predicate
+        /// </summary>
+        private Tensor _pred;
+        /// <summary>
+        /// The predicate tensor in this branch
+        /// </summary>
+        private Tensor _pivot;
+        /// <summary>
+        /// 0 or 1 representing this branch
+        /// </summary>
+        private int _branch;
+        /// <summary>
+        /// 
+        /// </summary>
+        private List<string> _values = new List<string>();
+        private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();
+
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="pred">The `boolean` tensor for the conditional predicate.</param>
+        /// <param name="pivot">The predicate tensor in this branch.</param>
+        /// <param name="branch">0 or 1 representing this branch.</param>
+        /// <param name="name">Name of the `CondContext` python object.</param>
+        /// <param name="context_def"></param>
+        /// <param name="import_scope"></param>
+        public CondContext(Tensor pred,
+            Tensor pivot,
+            int branch,
+            string name = "cond_text",
+            object context_def = null,
+            string import_scope = null)
+        {
+            _name = ops.get_default_graph().unique_name(name);
+            if (context_def != null)
+                throw new NotImplementedException("CondContext context_def is not null");
+            else
+            {
+                // Initializes the default fields.
+                base.__init__();
+                _pred = pred;
+                _pivot = pivot;
+                _branch = branch;
+
+                // Values considered to have been already seen in this context. pred is not
+                // included in this context.
+                _values.Add(pred.name);
+                _external_values[pred.name] = pred;
+                _values.Add(pivot.name);
+                pivot.op._set_control_flow_context(this);
+            }
+        }
+
+        public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn)
+        {
+            // Add the subgraph defined by fn() to the graph.
+            var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+            var original_result = fn();
+            var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+
+            return original_result;
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
new file mode 100644
index 00000000..7079606f
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -0,0 +1,46 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    public abstract class ControlFlowContext : IPython, IControlFlowContext
+    {
+        protected Stack<IControlFlowContext> _context_stack;
+        public ControlFlowContext()
+        {
+            _context_stack = new Stack<IControlFlowContext>();
+        }
+
+        public void __init__()
+        {
+
+        }
+
+        public void __enter__()
+        {
+        }
+
+        public virtual void Enter()
+        {
+            var graph = ops.get_default_graph();
+            _context_stack.Push(graph._get_control_flow_context());
+            graph._set_control_flow_context(this);
+        }
+
+        public void Exit()
+        {
+            var graph = ops.get_default_graph();
+            var last_context = _context_stack.Pop();
+            graph._set_control_flow_context(last_context);
+        }
+
+        public void __exit__()
+        {
+        }
+
+        public void Dispose()
+        {
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
new file mode 100644
index 00000000..52719538
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    public interface IControlFlowContext
+    {
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
new file mode 100644
index 00000000..a31819dc
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+    public class WhileContext : ControlFlowContext
+    {
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 78346a8f..a93c1653 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -52,5 +52,30 @@ namespace Tensorflow.Operations

             return _op.outputs[0];
         }
+
+        public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x,
+                Tensor scale,
+                Tensor offset,
+                Tensor mean,
+                Tensor variance,
+                float epsilon = 0.0001f,
+                string data_format = "NHWC",
+                bool is_training = true,
+                string name = null)
+        {
+            var _op = _op_def_lib._apply_op_helper("FusedBatchNorm", name: name, args: new
+            {
+                x,
+                scale,
+                offset,
+                mean,
+                variance,
+                epsilon,
+                data_format,
+                is_training
+            });
+
+            return (_op.outputs[0], _op.outputs[1], _op.outputs[2]);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index a51d1ca9..74078e27 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -1,11 +1,14 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Tensorflow.Operations;

 namespace Tensorflow
 {
     public partial class Operation
     {
+        private CondContext _control_flow_context;
+
         /// <summary>
         /// Add this op to its control flow context.
         /// </summary>
@@ -24,5 +27,10 @@ namespace Tensorflow
                 c_api.TF_AddControlInput(graph, op);
             }
         }
+
+        public void _set_control_flow_context(CondContext ctx)
+        {
+            _control_flow_context = ctx;
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index e9b75ab8..bca74989 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using Tensorflow.Operations;

 namespace Tensorflow
 {
@@ -136,9 +137,9 @@ namespace Tensorflow
             return gen_array_ops.identity(data, name: name);
         }

-        public static (Tensor, Tensor) cond(Tensor pred, 
-            Action true_fn = null,
-            Action false_fn = null,
+        public static (Tensor, Tensor) cond(Tensor pred,
+            Func<(Tensor, Tensor, Tensor)> true_fn = null,
+            Func<(Tensor, Tensor, Tensor)> false_fn = null,
             bool strict = false,
             string name = null)
         {
@@ -154,6 +155,22 @@ namespace Tensorflow

                 foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                     tensor.op.graph.prevent_fetching(tensor.op);
+                // Build the graph for the true branch in a new context.
+                var context_t = new CondContext(pred, pivot_1, branch: 1);
+                context_t.Enter();
+                var res_t = context_t.BuildCondBranch(true_fn);
+                context_t.Exit();
+
+                // Build the graph for the false branch in a new context.
+                var context_f = new CondContext(pred, pivot_2, branch: 0);
+                context_f.Enter();
+                var res_f = context_f.BuildCondBranch(false_fn);
+                context_f.Exit();
+
+                var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 };
+                var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 };
+
+
                 return (p_2, p_1);
             });
         }
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index fe0f9dcd..81515e18 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Text;
+using Tensorflow.Operations;

 namespace Tensorflow
 {
@@ -44,5 +45,36 @@ namespace Tensorflow
                 return (mean, variance);
             });
         }
+
+        public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
+            RefVariable scale,
+            RefVariable offset,
+            Tensor mean,
+            Tensor variance,
+            float epsilon = 0.001f,
+            string data_format = "NHWC",
+            bool is_training = true,
+            string name = null)
+        {
+            x = ops.convert_to_tensor(x, name: "input");
+            var scale_tensor = ops.convert_to_tensor(scale, name: "scale");
+            var offset_tensor = ops.convert_to_tensor(offset, name: "offset");
+            if (mean == null)
+                mean = constant_op.constant(new float[0]);
+            if(variance == null)
+                variance = constant_op.constant(new float[0]);
+            var min_epsilon = 1.001e-5f;
+            epsilon = epsilon > min_epsilon ? epsilon : min_epsilon;
+
+            return gen_nn_ops._fused_batch_norm(x,
+                scale_tensor,
+                offset_tensor,
+                mean,
+                variance,
+                epsilon,
+                data_format,
+                is_training,
+                name);
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index ede6d495..bee8e68e 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -107,6 +107,9 @@ namespace Tensorflow
                     case float floatVal:
                         nparray = floatVal;
                         break;
+                    case float[] floatVals:
+                        nparray = floatVals;
+                        break;
                     case double doubleVal:
                         nparray = doubleVal;
                         break;
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index 540f8d55..61d527f6 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -44,6 +44,9 @@ namespace Tensorflow
             /// Key to collect update_ops
             /// </summary>
             public static string UPDATE_OPS = "update_ops";
+
+            // Used to store v2 summary names.
+            public static string _SUMMARY_COLLECTION = "_SUMMARY_V2";
         }
     }
 }