@@ -26,6 +26,20 @@ namespace Tensorflow | |||||
name: name); | name: name); | ||||
public static IActivation relu => new relu(); | public static IActivation relu => new relu(); | ||||
public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||||
RefVariable scale, | |||||
RefVariable offset, | |||||
Tensor mean = null, | |||||
Tensor variance = null, | |||||
float epsilon = 0.001f, | |||||
string data_format = "NHWC", | |||||
bool is_training = true, | |||||
string name = null) => nn_impl.fused_batch_norm(x, scale, offset, mean, variance, | |||||
epsilon: epsilon, | |||||
data_format: data_format, | |||||
is_training: is_training, | |||||
name: name); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -6,7 +6,10 @@ namespace Tensorflow.Framework | |||||
{ | { | ||||
public class smart_module | public class smart_module | ||||
{ | { | ||||
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||||
public static object smart_cond(Tensor pred, | |||||
Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||||
Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||||
string name = null) | |||||
{ | { | ||||
return control_flow_ops.cond(pred, | return control_flow_ops.cond(pred, | ||||
true_fn: true_fn, | true_fn: true_fn, | ||||
@@ -8,7 +8,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public partial class Graph | public partial class Graph | ||||
{ | { | ||||
public Context _control_flow_context; | |||||
public IControlFlowContext _control_flow_context; | |||||
private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | private Queue<_ControlDependenciesController> _graph_control_dependencies_stack = new Queue<_ControlDependenciesController>(); | ||||
public Queue<_ControlDependenciesController> _control_dependencies_stack | public Queue<_ControlDependenciesController> _control_dependencies_stack | ||||
@@ -72,7 +72,7 @@ namespace Tensorflow | |||||
/// Returns the current control flow context. | /// Returns the current control flow context. | ||||
/// </summary> | /// </summary> | ||||
/// <returns>A context object.</returns> | /// <returns>A context object.</returns> | ||||
public Context _get_control_flow_context() | |||||
public IControlFlowContext _get_control_flow_context() | |||||
{ | { | ||||
return _control_flow_context; | return _control_flow_context; | ||||
} | } | ||||
@@ -81,7 +81,7 @@ namespace Tensorflow | |||||
/// Sets the current control flow context. | /// Sets the current control flow context. | ||||
/// </summary> | /// </summary> | ||||
/// <param name="ctx">a context object.</param> | /// <param name="ctx">a context object.</param> | ||||
public void _set_control_flow_context(Context ctx) | |||||
public void _set_control_flow_context(IControlFlowContext ctx) | |||||
{ | { | ||||
_control_flow_context = ctx; | _control_flow_context = ctx; | ||||
} | } | ||||
@@ -15,7 +15,7 @@ namespace Tensorflow | |||||
private List<ITensorOrOperation> _seen_nodes; | private List<ITensorOrOperation> _seen_nodes; | ||||
private Queue<_ControlDependenciesController> _old_stack; | private Queue<_ControlDependenciesController> _old_stack; | ||||
private bool _new_stack; | private bool _new_stack; | ||||
private Context _old_control_flow_context; | |||||
private IControlFlowContext _old_control_flow_context; | |||||
public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | ||||
@@ -142,14 +142,27 @@ namespace Tensorflow.Keras.Layers | |||||
var beta = this.beta; | var beta = this.beta; | ||||
var gamma = this.gamma; | var gamma = this.gamma; | ||||
Action _fused_batch_norm_training = () => | |||||
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_training = () => | |||||
{ | { | ||||
return tf.nn.fused_batch_norm( | |||||
inputs, | |||||
gamma, | |||||
beta, | |||||
epsilon: epsilon, | |||||
data_format: _data_format); | |||||
}; | }; | ||||
Action _fused_batch_norm_inference = () => | |||||
Func<(Tensor, Tensor, Tensor)> _fused_batch_norm_inference = () => | |||||
{ | { | ||||
return tf.nn.fused_batch_norm( | |||||
inputs, | |||||
gamma, | |||||
beta, | |||||
mean: moving_mean, | |||||
variance: moving_variance, | |||||
epsilon: epsilon, | |||||
is_training: false, | |||||
data_format: _data_format); | |||||
}; | }; | ||||
tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | tf_utils.smart_cond(training, _fused_batch_norm_training, _fused_batch_norm_inference); | ||||
@@ -18,7 +18,10 @@ namespace Tensorflow.Keras.Utils | |||||
return true; | return true; | ||||
} | } | ||||
public static object smart_cond(Tensor pred, Action true_fn = null, Action false_fn = null, string name = null) | |||||
public static object smart_cond(Tensor pred, | |||||
Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||||
Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||||
string name = null) | |||||
{ | { | ||||
return smart_module.smart_cond(pred, | return smart_module.smart_cond(pred, | ||||
true_fn: true_fn, | true_fn: true_fn, | ||||
@@ -0,0 +1,76 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
/// <summary> | |||||
/// The context for the conditional construct. | |||||
/// </summary> | |||||
public class CondContext : ControlFlowContext | |||||
{ | |||||
private string _name; | |||||
/// <summary> | |||||
/// The boolean tensor for the cond predicate | |||||
/// </summary> | |||||
private Tensor _pred; | |||||
/// <summary> | |||||
/// The predicate tensor in this branch | |||||
/// </summary> | |||||
private Tensor _pivot; | |||||
/// <summary> | |||||
/// 0 or 1 representing this branch | |||||
/// </summary> | |||||
private int _branch; | |||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
private List<string> _values = new List<string>(); | |||||
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>(); | |||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="pred">The `boolean` tensor for the conditional predicate.</param> | |||||
/// <param name="pivot">The predicate tensor in this branch.</param> | |||||
/// <param name="branch">0 or 1 representing this branch.</param> | |||||
/// <param name="name">Name of the `CondContext` python object.</param> | |||||
/// <param name="context_def"></param> | |||||
/// <param name="import_scope"></param> | |||||
public CondContext(Tensor pred, | |||||
Tensor pivot, | |||||
int branch, | |||||
string name = "cond_text", | |||||
object context_def = null, | |||||
string import_scope = null) | |||||
{ | |||||
_name = ops.get_default_graph().unique_name(name); | |||||
if (context_def != null) | |||||
throw new NotImplementedException("CondContext context_def is not null"); | |||||
else | |||||
{ | |||||
// Initializes the default fields. | |||||
base.__init__(); | |||||
_pred = pred; | |||||
_pivot = pivot; | |||||
// Values considered to have been already seen in this context. pred is not | |||||
// included in this context. | |||||
_values.Add(pred.name); | |||||
_external_values[pred.name] = pred; | |||||
_values.Add(pivot.name); | |||||
pivot.op._set_control_flow_context(this); | |||||
} | |||||
} | |||||
public (Tensor, Tensor, Tensor) BuildCondBranch(Func<(Tensor, Tensor, Tensor)> fn) | |||||
{ | |||||
// Add the subgraph defined by fn() to the graph. | |||||
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
var original_result = fn(); | |||||
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION); | |||||
return original_result; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,46 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public abstract class ControlFlowContext : IPython, IControlFlowContext | |||||
{ | |||||
protected Stack<IControlFlowContext> _context_stack; | |||||
public ControlFlowContext() | |||||
{ | |||||
_context_stack = new Stack<IControlFlowContext>(); | |||||
} | |||||
public void __init__() | |||||
{ | |||||
} | |||||
public void __enter__() | |||||
{ | |||||
} | |||||
public virtual void Enter() | |||||
{ | |||||
var graph = ops.get_default_graph(); | |||||
_context_stack.Push(graph._get_control_flow_context()); | |||||
graph._set_control_flow_context(this); | |||||
} | |||||
public void Exit() | |||||
{ | |||||
var graph = ops.get_default_graph(); | |||||
var last_context = _context_stack.Pop(); | |||||
graph._set_control_flow_context(last_context); | |||||
} | |||||
public void __exit__() | |||||
{ | |||||
} | |||||
public void Dispose() | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public interface IControlFlowContext | |||||
{ | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class WhileContext : ControlFlowContext | |||||
{ | |||||
} | |||||
} |
@@ -52,5 +52,30 @@ namespace Tensorflow.Operations | |||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
} | } | ||||
public static (Tensor, Tensor, Tensor) _fused_batch_norm(Tensor x, | |||||
Tensor scale, | |||||
Tensor offset, | |||||
Tensor mean, | |||||
Tensor variance, | |||||
float epsilon = 0.0001f, | |||||
string data_format = "NHWC", | |||||
bool is_training = true, | |||||
string name = null) | |||||
{ | |||||
var _op = _op_def_lib._apply_op_helper("FusedBatchNorm", name: name, args: new | |||||
{ | |||||
x, | |||||
scale, | |||||
offset, | |||||
mean, | |||||
variance, | |||||
epsilon, | |||||
data_format, | |||||
is_training | |||||
}); | |||||
return (_op.outputs[0], _op.outputs[1], _op.outputs[2]); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,11 +1,14 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public partial class Operation | public partial class Operation | ||||
{ | { | ||||
private CondContext _control_flow_context; | |||||
/// <summary> | /// <summary> | ||||
/// Add this op to its control flow context. | /// Add this op to its control flow context. | ||||
/// </summary> | /// </summary> | ||||
@@ -24,5 +27,10 @@ namespace Tensorflow | |||||
c_api.TF_AddControlInput(graph, op); | c_api.TF_AddControlInput(graph, op); | ||||
} | } | ||||
} | } | ||||
public void _set_control_flow_context(CondContext ctx) | |||||
{ | |||||
_control_flow_context = ctx; | |||||
} | |||||
} | } | ||||
} | } |
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -136,9 +137,9 @@ namespace Tensorflow | |||||
return gen_array_ops.identity(data, name: name); | return gen_array_ops.identity(data, name: name); | ||||
} | } | ||||
public static (Tensor, Tensor) cond(Tensor pred, | |||||
Action true_fn = null, | |||||
Action false_fn = null, | |||||
public static (Tensor, Tensor) cond(Tensor pred, | |||||
Func<(Tensor, Tensor, Tensor)> true_fn = null, | |||||
Func<(Tensor, Tensor, Tensor)> false_fn = null, | |||||
bool strict = false, | bool strict = false, | ||||
string name = null) | string name = null) | ||||
{ | { | ||||
@@ -154,6 +155,22 @@ namespace Tensorflow | |||||
foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred }) | ||||
tensor.op.graph.prevent_fetching(tensor.op); | tensor.op.graph.prevent_fetching(tensor.op); | ||||
// Build the graph for the true branch in a new context. | |||||
var context_t = new CondContext(pred, pivot_1, branch: 1); | |||||
context_t.Enter(); | |||||
var res_t = context_t.BuildCondBranch(true_fn); | |||||
context_t.Exit(); | |||||
// Build the graph for the false branch in a new context. | |||||
var context_f = new CondContext(pred, pivot_2, branch: 0); | |||||
context_f.Enter(); | |||||
var res_f = context_f.BuildCondBranch(false_fn); | |||||
context_f.Exit(); | |||||
var res_t_flat = new Tensor[] { res_t.Item1, res_t.Item2, res_t.Item3 }; | |||||
var res_f_flat = new Tensor[] { res_f.Item1, res_f.Item2, res_f.Item3 }; | |||||
return (p_2, p_1); | return (p_2, p_1); | ||||
}); | }); | ||||
} | } | ||||
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -44,5 +45,36 @@ namespace Tensorflow | |||||
return (mean, variance); | return (mean, variance); | ||||
}); | }); | ||||
} | } | ||||
public static (Tensor, Tensor, Tensor) fused_batch_norm(Tensor x, | |||||
RefVariable scale, | |||||
RefVariable offset, | |||||
Tensor mean, | |||||
Tensor variance, | |||||
float epsilon = 0.001f, | |||||
string data_format = "NHWC", | |||||
bool is_training = true, | |||||
string name = null) | |||||
{ | |||||
x = ops.convert_to_tensor(x, name: "input"); | |||||
var scale_tensor = ops.convert_to_tensor(scale, name: "scale"); | |||||
var offset_tensor = ops.convert_to_tensor(offset, name: "offset"); | |||||
if (mean == null) | |||||
mean = constant_op.constant(new float[0]); | |||||
if(variance == null) | |||||
variance = constant_op.constant(new float[0]); | |||||
var min_epsilon = 1.001e-5f; | |||||
epsilon = epsilon > min_epsilon ? epsilon : min_epsilon; | |||||
return gen_nn_ops._fused_batch_norm(x, | |||||
scale_tensor, | |||||
offset_tensor, | |||||
mean, | |||||
variance, | |||||
epsilon, | |||||
data_format, | |||||
is_training, | |||||
name); | |||||
} | |||||
} | } | ||||
} | } |
@@ -107,6 +107,9 @@ namespace Tensorflow | |||||
case float floatVal: | case float floatVal: | ||||
nparray = floatVal; | nparray = floatVal; | ||||
break; | break; | ||||
case float[] floatVals: | |||||
nparray = floatVals; | |||||
break; | |||||
case double doubleVal: | case double doubleVal: | ||||
nparray = doubleVal; | nparray = doubleVal; | ||||
break; | break; | ||||
@@ -44,6 +44,9 @@ namespace Tensorflow | |||||
/// Key to collect update_ops | /// Key to collect update_ops | ||||
/// </summary> | /// </summary> | ||||
public static string UPDATE_OPS = "update_ops"; | public static string UPDATE_OPS = "update_ops"; | ||||
// Used to store v2 summary names. | |||||
public static string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||||
} | } | ||||
} | } | ||||
} | } |