@@ -238,9 +238,67 @@ namespace Tensorflow.Operations | |||
return real_val; | |||
} | |||
public override void AddInnerOp(Operation resultOp) | |||
{ | |||
throw new NotImplementedException(); | |||
protected override void _AddOpInternal(Operation op) | |||
{ | |||
if (op.inputs.Length == 0) | |||
{ | |||
//If we're in a while loop, remove any control inputs from outside the | |||
// loop. | |||
_RemoveExternalControlEdges(op); | |||
if (!op.control_inputs.Any(input_op => OpInContext(input_op))) | |||
op._add_control_input(_pivot.op); | |||
} | |||
else | |||
{ | |||
// Make each input to 'op' available in this CondContext. If an input is | |||
// already part of this context there's nothing to do, but if it's | |||
// external, AddValue() will handle adding the appropriate Switch node and | |||
// other bookkeeping. | |||
for (int index = 0; index < op.inputs.Length; index++) | |||
{ | |||
var x = op.inputs[index]; | |||
Tensor real_x = null; | |||
if (op.type == "Merge" && x.op.type == "NextIteration") | |||
{ | |||
//# Edge case: if we're importing a while loop inside this CondContext, | |||
//# AddValue() will not correctly handle the NextIteration inputs to | |||
//# Merge node. The problem is that the NextIteration should also be | |||
//# part of this context, but if we're importing it won't have been | |||
//# processed and added to the context yet, so AddValue() will try to | |||
//# add a Switch which results in an invalid graph. Instead, we use the | |||
//# NextIteration input as-is here, and it will eventually be added to | |||
//# the context via AddOp(). | |||
real_x = x; | |||
} | |||
else | |||
{ | |||
real_x = AddValue(x); | |||
} | |||
if (real_x != x) | |||
op._update_input(index, real_x); | |||
} | |||
// Remove any external control dependency on this op. | |||
_RemoveExternalControlEdges(op); | |||
// TODO: implement below code dependencies | |||
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | |||
// op._add_control_input(_pivot.op); | |||
} | |||
// Mark op's outputs as seen by this context and any outer contexts. | |||
var output_names = op.outputs.Select(x => x.name).ToArray(); | |||
IControlFlowContext ctxt = this; | |||
while (ctxt != null) | |||
{ | |||
foreach (var name in output_names) | |||
ctxt.values.Add(name); | |||
ctxt = ctxt.outer_context; | |||
} | |||
if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) | |||
op.graph.prevent_fetching(op); | |||
if (_outer_context != null) | |||
_outer_context.AddInnerOp(op); | |||
} | |||
public CondContextDef to_proto(string export_scope) | |||
@@ -119,9 +119,14 @@ namespace Tensorflow.Operations | |||
return null; | |||
} | |||
public virtual void AddInnerOp(Operation resultOp) | |||
/// <summary> | |||
/// Notifies a scope about an operator added to an inner scope. | |||
/// </summary> | |||
/// <param name="op"></param> | |||
public virtual void AddInnerOp(Operation op) | |||
{ | |||
// to be overridden | |||
if (_outer_context != null) | |||
_outer_context.AddInnerOp(op); | |||
} | |||
protected HashSet<string> _values = new HashSet<string>(); | |||
@@ -131,68 +136,10 @@ namespace Tensorflow.Operations | |||
/// </summary> | |||
protected virtual void _AddOpInternal(Operation op) | |||
{ | |||
if (op.inputs.Length == 0) | |||
{ | |||
//If we're in a while loop, remove any control inputs from outside the | |||
// loop. | |||
_RemoveExternalControlEdges(op); | |||
if (!op.control_inputs.Any(input_op => OpInContext(input_op))) | |||
op._add_control_input(_pivot.op); | |||
} | |||
else | |||
{ | |||
// Make each input to 'op' available in this CondContext. If an input is | |||
// already part of this context there's nothing to do, but if it's | |||
// external, AddValue() will handle adding the appropriate Switch node and | |||
// other bookkeeping. | |||
for (int index = 0; index < op.inputs.Length; index++) | |||
{ | |||
var x = op.inputs[index]; | |||
Tensor real_x = null; | |||
if (op.type == "Merge" && x.op.type == "NextIteration") | |||
{ | |||
//# Edge case: if we're importing a while loop inside this CondContext, | |||
//# AddValue() will not correctly handle the NextIteration inputs to | |||
//# Merge node. The problem is that the NextIteration should also be | |||
//# part of this context, but if we're importing it won't have been | |||
//# processed and added to the context yet, so AddValue() will try to | |||
//# add a Switch which results in an invalid graph. Instead, we use the | |||
//# NextIteration input as-is here, and it will eventually be added to | |||
//# the context via AddOp(). | |||
real_x = x; | |||
} | |||
else | |||
{ | |||
real_x = AddValue(x); | |||
} | |||
if (real_x != x) | |||
op._update_input(index, real_x); | |||
} | |||
// Remove any external control dependency on this op. | |||
_RemoveExternalControlEdges(op); | |||
// TODO: implement below code dependencies | |||
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient") | |||
// op._add_control_input(_pivot.op); | |||
} | |||
// Mark op's outputs as seen by this context and any outer contexts. | |||
var output_names = op.outputs.Select(x => x.name).ToArray(); | |||
IControlFlowContext ctxt = this; | |||
while (ctxt != null) | |||
{ | |||
foreach(var name in output_names) | |||
ctxt.values.Add(name); | |||
ctxt = ctxt.outer_context; | |||
} | |||
if (_outer_context != null || !control_flow_ops.IsLoopExit(op)) | |||
op.graph.prevent_fetching(op); | |||
if (_outer_context != null) | |||
_outer_context.AddInnerOp(op); | |||
} | |||
private bool OpInContext(Operation op) | |||
protected bool OpInContext(Operation op) | |||
{ | |||
return IsContainingContext(op._get_control_flow_context(), this); | |||
} | |||
@@ -23,7 +23,8 @@ namespace Tensorflow | |||
return with(ops.name_scope(name, "l2_normalize", new { x }), scope => | |||
{ | |||
x = ops.convert_to_tensor(x, name: "x"); | |||
var square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims: true); | |||
var sq = math_ops.square(x); | |||
var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true); | |||
var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon)); | |||
return math_ops.multiply(x, x_inv_norm, name: name); | |||
}); | |||
@@ -360,6 +360,8 @@ namespace Tensorflow | |||
/// <returns>The default `Session` being used in the current thread.</returns> | |||
public static Session get_default_session() | |||
{ | |||
if (tf.defaultSession == null) | |||
tf.defaultSession = tf.Session(); | |||
return tf.defaultSession; | |||
} | |||
@@ -143,10 +143,7 @@ namespace TensorFlowNET.UnitTest | |||
// return self._eval_helper(tensors) | |||
// else: | |||
{ | |||
var sess = ops.get_default_session(); | |||
if (sess == null) | |||
sess = self.session(); | |||
with<Session>(sess, s => | |||
with(ops.get_default_session(), s => | |||
{ | |||
var ndarray=tensor.eval(); | |||
if (typeof(T) == typeof(double)) | |||