@@ -26,6 +26,20 @@ namespace Tensorflow | |||
name: name); | |||
public static IActivation relu => new relu(); | |||
public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||
RefVariable scale, | |||
RefVariable offset, | |||
Tensor mean = null, | |||
Tensor variance = null, | |||
float epsilon = 0.001f, | |||
string data_format = "NHWC", | |||
bool is_training = true, | |||
string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance, | |||
epsilon: epsilon, | |||
data_format: data_format, | |||
is_training: is_training, | |||
name: name); | |||
} | |||
} | |||
} |
@@ -6,7 +6,10 @@ namespace Tensorflow.Framework | |||
{ | |||
public class smart_module | |||
{ | |||
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||
public static object smart_cond(Tensor pred, | |||
Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||
Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||
string name = null) | |||
{ | |||
return control_flow_ops.cond(pred, | |||
true_fn: true_fn, | |||
@@ -8,7 +8,7 @@ namespace Tensorflow | |||
{ | |||
public partial class Graph | |||
{ | |||
public Context _control_flow_context; | |||
public IControlFlowContext _control_flow_context; | |||
private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | |||
public Queue<_ControlDependenciesController> _control_dependencies_stack | |||
@@ -72,7 +72,7 @@ namespace Tensorflow | |||
/// Returns the current control flow context. | |||
/// </summary> | |||
/// <returns>A context object.</returns> | |||
public Context _get_control_flow_context() | |||
public IControlFlowContext _get_control_flow_context() | |||
{ | |||
return _control_flow_context; | |||
} | |||
@@ -81,7 +81,7 @@ namespace Tensorflow | |||
/// Sets the current control flow context. | |||
/// </summary> | |||
/// <param name="ctx">a context object.</param> | |||
public void _set_control_flow_context(Context ctx) | |||
public void _set_control_flow_context(IControlFlowContext ctx) | |||
{ | |||
_control_flow_context = ctx; | |||
} | |||
@@ -15,7 +15,7 @@ namespace Tensorflow | |||
private List<ITensorOrOperation> _seen_nodes; | |||
private Queue<_ControlDependenciesController> _old_stack; | |||
private bool _new_stack; | |||
private Context _old_control_flow_context; | |||
private IControlFlowContext _old_control_flow_context; | |||
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | |||
@@ -142,14 +142,27 @@ namespace Tensorflow.Keras.Layers | |||
var beta = this.beta; | |||
var gamma = this.gamma; | |||
Action _fused_batch_norm_training = () => | |||
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () => | |||
{ | |||
return tf.nn.fused_batch_norm( | |||
inputs, | |||
gamma, | |||
beta, | |||
epsilon: epsilon, | |||
data_format: _data_format); | |||
}; | |||
Action _fused_batch_norm_inference = () => | |||
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () => | |||
{ | |||
return tf.nn.fused_batch_norm( | |||
inputs, | |||
gamma, | |||
beta, | |||
mean: moving_mean, | |||
variance: moving_variance, | |||
epsilon: epsilon, | |||
is_training: false, | |||
data_format: _data_format); | |||
}; | |||
tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | |||
@@ -18,7 +18,10 @@ namespace Tensorflow.Keras.Utils | |||
return true; | |||
} | |||
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||
public static object smart_cond(Tensor pred, | |||
Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||
Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||
string name = null) | |||
{ | |||
return smart_module.smart_cond(pred, | |||
true_fn: true_fn, | |||
@@ -0,0 +1,76 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
{ | |||
/// <summary> | |||
/// The context for the conditional construct. | |||
/// </summary> | |||
public class CondContext : ControlFlowContext | |||
{ | |||
private string _name; | |||
/// <summary> | |||
/// The boolean tensor for the cond predicate | |||
/// </summary> | |||
private Tensor _pred; | |||
/// <summary> | |||
/// The predicate tensor in this branch | |||
/// </summary> | |||
private Tensor _pivot; | |||
/// <summary> | |||
/// 0 or 1 representing this branch | |||
/// </summary> | |||
private int _branch; | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
private List<string> _values = new List<string>(); | |||
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="pred">The `boolean` tensor for the conditional predicate.</param> | |||
/// <param name="pivot">The predicate tensor in this branch.</param> | |||
/// <param name="branch">0 or 1 representing this branch.</param> | |||
/// <param name="name">Name of the `CondContext` python object.</param> | |||
/// <param name="context_def"></param> | |||
/// <param name="import_scope"></param> | |||
public CondContext(Tensor pred, | |||
Tensor pivot, | |||
int branch, | |||
string name = "cond_text", | |||
object context_def = null, | |||
string import_scope = null) | |||
{ | |||
_name = ops.get_default_graph().unique_name(name); | |||
if (context_def != null) | |||
throw new NotImplementedException("CondContext context_def is not null"); | |||
else | |||
{ | |||
// Initializes the default fields. | |||
base.__init__(); | |||
_pred = pred; | |||
_pivot = pivot; | |||
// Values considered to have been already seen in this context. pred is not | |||
// included in this context. | |||
_values.Add(pred.name); | |||
_external_values[pred.name] = pred; | |||
_values.Add(pivot.name); | |||
pivot.op._set_control_flow_context(this); | |||
} | |||
} | |||
public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn) | |||
{ | |||
// Add the subgraph defined by fn() to the graph. | |||
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
var original_result = fn(); | |||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||
return original_result; | |||
} | |||
} | |||
} |
@@ -0,0 +1,46 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
{ | |||
public abstract class ControlFlowContext : IPython, IControlFlowContext | |||
{ | |||
protected Stack<IControlFlowContext> _context_stack; | |||
public ControlFlowContext() | |||
{ | |||
_context_stack = new Stack<IControlFlowContext>(); | |||
} | |||
public void __init__() | |||
{ | |||
} | |||
public void __enter__() | |||
{ | |||
} | |||
public virtual void Enter() | |||
{ | |||
var graph = ops.get_default_graph(); | |||
_context_stack.Push(graph._get_control_flow_context()); | |||
graph._set_control_flow_context(this); | |||
} | |||
public void Exit() | |||
{ | |||
var graph = ops.get_default_graph(); | |||
var last_context = _context_stack.Pop(); | |||
graph._set_control_flow_context(last_context); | |||
} | |||
public void __exit__() | |||
{ | |||
} | |||
public void Dispose() | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public interface IControlFlowContext | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
{ | |||
public class WhileContext : ControlFlowContext | |||
{ | |||
} | |||
} |
@@ -52,5 +52,30 @@ namespace Tensorflow.Operations | |||
return _op.outputs[0]; | |||
} | |||
public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x, | |||
Tensor scale, | |||
Tensor offset, | |||
Tensor mean, | |||
Tensor variance, | |||
float epsilon = 0.0001f, | |||
string data_format = "NHWC", | |||
bool is_training = true, | |||
string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("FusedBatchNorm", name: name, args: new | |||
{ | |||
x, | |||
scale, | |||
offset, | |||
mean, | |||
variance, | |||
epsilon, | |||
data_format, | |||
is_training | |||
}); | |||
return (_op.outputs[0], _op.outputs[1], _op.outputs[2]); | |||
} | |||
} | |||
} |
@@ -1,11 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
public partial class Operation | |||
{ | |||
private CondContext _control_flow_context; | |||
/// <summary> | |||
/// Add this op to its control flow context. | |||
/// </summary> | |||
@@ -24,5 +27,10 @@ namespace Tensorflow | |||
c_api.TF_AddControlInput(graph, op); | |||
} | |||
} | |||
public void _set_control_flow_context(CondContext ctx) | |||
{ | |||
_control_flow_context = ctx; | |||
} | |||
} | |||
} |
@@ -2,6 +2,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
@@ -136,9 +137,9 @@ namespace Tensorflow | |||
return gen_array_ops.identity(data, name: name); | |||
} | |||
public static (Tensor, Tensor) cond(Tensor pred, | |||
Action true_fn = null, | |||
Action false_fn = null, | |||
public static (Tensor, Tensor) cond(Tensor pred, | |||
Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||
Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||
bool strict = false, | |||
string name = null) | |||
{ | |||
@@ -154,6 +155,22 @@ namespace Tensorflow | |||
foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | |||
tensor.op.graph.prevent_fetching(tensor.op); | |||
// Build the graph for the true branch in a new context. | |||
var context_t = new CondContext(pred, pivot_1, branch: 1); | |||
context_t.Enter(); | |||
var res_t = context_t.BuildCondBranch(true_fn); | |||
context_t.Exit(); | |||
// Build the graph for the false branch in a new context. | |||
var context_f = new CondContext(pred, pivot_2, branch: 0); | |||
context_f.Enter(); | |||
var res_f = context_f.BuildCondBranch(false_fn); | |||
context_f.Exit(); | |||
var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 }; | |||
var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 }; | |||
return (p_2, p_1); | |||
}); | |||
} | |||
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
@@ -44,5 +45,36 @@ namespace Tensorflow | |||
return (mean, variance); | |||
}); | |||
} | |||
public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||
RefVariable scale, | |||
RefVariable offset, | |||
Tensor mean, | |||
Tensor variance, | |||
float epsilon = 0.001f, | |||
string data_format = "NHWC", | |||
bool is_training = true, | |||
string name = null) | |||
{ | |||
x = ops.convert_to_tensor(x, name: "input"); | |||
var scale_tensor = ops.convert_to_tensor(scale, name: "scale"); | |||
var offset_tensor = ops.convert_to_tensor(offset, name: "offset"); | |||
if (mean == null) | |||
mean = constant_op.constant(new float[0]); | |||
if(variance == null) | |||
variance = constant_op.constant(new float[0]); | |||
var min_epsilon = 1.001e-5f; | |||
epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; | |||
return gen_nn_ops._fused_batch_norm(x, | |||
scale_tensor, | |||
offset_tensor, | |||
mean, | |||
variance, | |||
epsilon, | |||
data_format, | |||
is_training, | |||
name); | |||
} | |||
} | |||
} |
@@ -107,6 +107,9 @@ namespace Tensorflow | |||
case float floatVal: | |||
nparray = floatVal; | |||
break; | |||
case float[] floatVals: | |||
nparray = floatVals; | |||
break; | |||
case double doubleVal: | |||
nparray = doubleVal; | |||
break; | |||
@@ -44,6 +44,9 @@ namespace Tensorflow | |||
/// Key to collect update_ops | |||
/// </summary> | |||
public static string UPDATE_OPS = "update_ops"; | |||
// Used to store v2 summary names. | |||
public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||
} | |||
} | |||
} |