diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 65ad45b9..44203906 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -26,6 +26,20 @@ namespace Tensorflow
name: name);
public static IActivation relu => new relu();
+
+ public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
+ RefVariable scale,
+ RefVariable offset,
+ Tensor mean = null,
+ Tensor variance = null,
+ float epsilon = 0.001f,
+ string data_format = "NHWC",
+ bool is_training = true,
+ string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance,
+ epsilon: epsilon,
+ data_format: data_format,
+ is_training: is_training,
+ name: name);
}
}
}
diff --git a/src/TensorFlowNET.Core/Framework/smart_module.cs b/src/TensorFlowNET.Core/Framework/smart_module.cs
index 2ba80cbc..ea5bf790 100644
--- a/src/TensorFlowNET.Core/Framework/smart_module.cs
+++ b/src/TensorFlowNET.Core/Framework/smart_module.cs
@@ -6,7 +6,10 @@ namespace Tensorflow.Framework
{
public class smart_module
{
- public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
+ public static object smart_cond(Tensor pred,
+ Func<(Tensor, Tensor, Tensor)> true_fn = null,
+ Func<(Tensor, Tensor, Tensor)> false_fn = null,
+ string name = null)
{
return control_flow_ops.cond(pred,
true_fn: true_fn,
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index c9e3be84..a1977968 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -8,7 +8,7 @@ namespace Tensorflow
{
public partial class Graph
{
- public Context _control_flow_context;
+ public IControlFlowContext _control_flow_context;
private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>();
public Queue<_ControlDependenciesController> _control_dependencies_stack
@@ -72,7 +72,7 @@ namespace Tensorflow
/// Returns the current control flow context.
///
/// A context object.
- public Context _get_control_flow_context()
+ public IControlFlowContext _get_control_flow_context()
{
return _control_flow_context;
}
@@ -81,7 +81,7 @@ namespace Tensorflow
/// Sets the current control flow context.
///
/// a context object.
- public void _set_control_flow_context(Context ctx)
+ public void _set_control_flow_context(IControlFlowContext ctx)
{
_control_flow_context = ctx;
}
diff --git a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
index f1ddcb44..3887d2a1 100644
--- a/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
+++ b/src/TensorFlowNET.Core/Graphs/_ControlDependenciesController.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
private List _seen_nodes;
private Queue<_ControlDependenciesController> _old_stack;
private bool _new_stack;
- private Context _old_control_flow_context;
+ private IControlFlowContext _old_control_flow_context;
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray();
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index 8f82983e..1223e350 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -142,14 +142,27 @@ namespace Tensorflow.Keras.Layers
var beta = this.beta;
var gamma = this.gamma;
- Action _fused_batch_norm_training = () =>
+ Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () =>
{
-
+ return tf.nn.fused_batch_norm(
+ inputs,
+ gamma,
+ beta,
+ epsilon: epsilon,
+ data_format: _data_format);
};
- Action _fused_batch_norm_inference = () =>
+ Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () =>
{
-
+ return tf.nn.fused_batch_norm(
+ inputs,
+ gamma,
+ beta,
+ mean: moving_mean,
+ variance: moving_variance,
+ epsilon: epsilon,
+ is_training: false,
+ data_format: _data_format);
};
tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference);
diff --git a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
index 9a7d5ea1..4e155493 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/tf_utils.cs
@@ -18,7 +18,10 @@ namespace Tensorflow.Keras.Utils
return true;
}
- public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null)
+ public static object smart_cond(Tensor pred,
+ Func<(Tensor, Tensor, Tensor)> true_fn = null,
+ Func<(Tensor, Tensor, Tensor)> false_fn = null,
+ string name = null)
{
return smart_module.smart_cond(pred,
true_fn: true_fn,
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
new file mode 100644
index 00000000..3c233e8d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -0,0 +1,76 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+ ///
+ /// The context for the conditional construct.
+ ///
+ public class CondContext : ControlFlowContext
+ {
+ private string _name;
+ ///
+ /// The boolean tensor for the cond predicate
+ ///
+ private Tensor _pred;
+ ///
+ /// The predicate tensor in this branch
+ ///
+ private Tensor _pivot;
+ ///
+ /// 0 or 1 representing this branch
+ ///
+ private int _branch;
+ ///
+ ///
+ ///
+ private List _values = new List();
+ private Dictionary _external_values = new Dictionary();
+
+ ///
+ ///
+ ///
+ /// The `boolean` tensor for the conditional predicate.
+ /// The predicate tensor in this branch.
+ /// 0 or 1 representing this branch.
+ /// Name of the `CondContext` python object.
+ ///
+ ///
+ public CondContext(Tensor pred,
+ Tensor pivot,
+ int branch,
+ string name = "cond_text",
+ object context_def = null,
+ string import_scope = null)
+ {
+ _name = ops.get_default_graph().unique_name(name);
+ if (context_def != null)
+ throw new NotImplementedException("CondContext context_def is not null");
+ else
+ {
+ // Initializes the default fields.
+ base.__init__();
+ _pred = pred;
+ _pivot = pivot;
+
+ // Values considered to have been already seen in this context. pred is not
+ // included in this context.
+ _values.Add(pred.name);
+ _external_values[pred.name] = pred;
+ _values.Add(pivot.name);
+ pivot.op._set_control_flow_context(this);
+ }
+ }
+
+ public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn)
+ {
+ // Add the subgraph defined by fn() to the graph.
+ var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+ var original_result = fn();
+ var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
+
+ return original_result;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
new file mode 100644
index 00000000..7079606f
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -0,0 +1,46 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+ public abstract class ControlFlowContext : IPython, IControlFlowContext
+ {
+ protected Stack _context_stack;
+ public ControlFlowContext()
+ {
+ _context_stack = new Stack();
+ }
+
+ public void __init__()
+ {
+
+ }
+
+ public void __enter__()
+ {
+ }
+
+ public virtual void Enter()
+ {
+ var graph = ops.get_default_graph();
+ _context_stack.Push(graph._get_control_flow_context());
+ graph._set_control_flow_context(this);
+ }
+
+ public void Exit()
+ {
+ var graph = ops.get_default_graph();
+ var last_context = _context_stack.Pop();
+ graph._set_control_flow_context(last_context);
+ }
+
+ public void __exit__()
+ {
+ }
+
+ public void Dispose()
+ {
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
new file mode 100644
index 00000000..52719538
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public interface IControlFlowContext
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
new file mode 100644
index 00000000..a31819dc
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+ public class WhileContext : ControlFlowContext
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index 78346a8f..a93c1653 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -52,5 +52,30 @@ namespace Tensorflow.Operations
return _op.outputs[0];
}
+
+ public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x,
+ Tensor scale,
+ Tensor offset,
+ Tensor mean,
+ Tensor variance,
+ float epsilon = 0.0001f,
+ string data_format = "NHWC",
+ bool is_training = true,
+ string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("FusedBatchNorm", name: name, args: new
+ {
+ x,
+ scale,
+ offset,
+ mean,
+ variance,
+ epsilon,
+ data_format,
+ is_training
+ });
+
+ return (_op.outputs[0], _op.outputs[1], _op.outputs[2]);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
index a51d1ca9..74078e27 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs
@@ -1,11 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Operations;
namespace Tensorflow
{
public partial class Operation
{
+ private CondContext _control_flow_context;
+
///
/// Add this op to its control flow context.
///
@@ -24,5 +27,10 @@ namespace Tensorflow
c_api.TF_AddControlInput(graph, op);
}
}
+
+ public void _set_control_flow_context(CondContext ctx)
+ {
+ _control_flow_context = ctx;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index e9b75ab8..bca74989 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -136,9 +137,9 @@ namespace Tensorflow
return gen_array_ops.identity(data, name: name);
}
- public static (Tensor, Tensor) cond(Tensor pred,
- Action true_fn = null,
- Action false_fn = null,
+ public static (Tensor, Tensor) cond(Tensor pred,
+ Func<(Tensor, Tensor, Tensor)> true_fn = null,
+ Func<(Tensor, Tensor, Tensor)> false_fn = null,
bool strict = false,
string name = null)
{
@@ -154,6 +155,22 @@ namespace Tensorflow
foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
tensor.op.graph.prevent_fetching(tensor.op);
+ // Build the graph for the true branch in a new context.
+ var context_t = new CondContext(pred, pivot_1, branch: 1);
+ context_t.Enter();
+ var res_t = context_t.BuildCondBranch(true_fn);
+ context_t.Exit();
+
+ // Build the graph for the false branch in a new context.
+ var context_f = new CondContext(pred, pivot_2, branch: 0);
+ context_f.Enter();
+ var res_f = context_f.BuildCondBranch(false_fn);
+ context_f.Exit();
+
+ var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 };
+ var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 };
+
+
return (p_2, p_1);
});
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index fe0f9dcd..81515e18 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -44,5 +45,36 @@ namespace Tensorflow
return (mean, variance);
});
}
+
+ public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x,
+ RefVariable scale,
+ RefVariable offset,
+ Tensor mean,
+ Tensor variance,
+ float epsilon = 0.001f,
+ string data_format = "NHWC",
+ bool is_training = true,
+ string name = null)
+ {
+ x = ops.convert_to_tensor(x, name: "input");
+ var scale_tensor = ops.convert_to_tensor(scale, name: "scale");
+ var offset_tensor = ops.convert_to_tensor(offset, name: "offset");
+ if (mean == null)
+ mean = constant_op.constant(new float[0]);
+ if(variance == null)
+ variance = constant_op.constant(new float[0]);
+ var min_epsilon = 1.001e-5f;
+ epsilon = epsilon > min_epsilon ? epsilon : min_epsilon;
+
+ return gen_nn_ops._fused_batch_norm(x,
+ scale_tensor,
+ offset_tensor,
+ mean,
+ variance,
+ epsilon,
+ data_format,
+ is_training,
+ name);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index ede6d495..bee8e68e 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -107,6 +107,9 @@ namespace Tensorflow
case float floatVal:
nparray = floatVal;
break;
+ case float[] floatVals:
+ nparray = floatVals;
+ break;
case double doubleVal:
nparray = doubleVal;
break;
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index 540f8d55..61d527f6 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -44,6 +44,9 @@ namespace Tensorflow
/// Key to collect update_ops
///
public static string UPDATE_OPS = "update_ops";
+
+ // Used to store v2 summary names.
+ public static string _SUMMARY_COLLECTION = "_SUMMARY_V2";
}
}
}