Browse Source

fix unit test evaluate.

tags/v0.9
Oceania2018 6 years ago
parent
commit
bd1e853187
5 changed files with 74 additions and 69 deletions
  1. +61
    -3
      src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
  2. +8
    -61
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Operations/nn_impl.py.cs
  4. +2
    -0
      src/TensorFlowNET.Core/ops.py.cs
  5. +1
    -4
      test/TensorFlowNET.UnitTest/PythonTest.cs

+ 61
- 3
src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs View File

@@ -238,9 +238,67 @@ namespace Tensorflow.Operations
return real_val;
}
public override void AddInnerOp(Operation resultOp)
{
throw new NotImplementedException();
protected override void _AddOpInternal(Operation op)
{
if (op.inputs.Length == 0)
{
//If we're in a while loop, remove any control inputs from outside the
// loop.
_RemoveExternalControlEdges(op);
if (!op.control_inputs.Any(input_op => OpInContext(input_op)))
op._add_control_input(_pivot.op);
}
else
{
// Make each input to 'op' available in this CondContext. If an input is
// already part of this context there's nothing to do, but if it's
// external, AddValue() will handle adding the appropriate Switch node and
// other bookkeeping.
for (int index = 0; index < op.inputs.Length; index++)
{
var x = op.inputs[index];
Tensor real_x = null;
if (op.type == "Merge" && x.op.type == "NextIteration")
{
//# Edge case: if we're importing a while loop inside this CondContext,
//# AddValue() will not correctly handle the NextIteration inputs to
//# Merge node. The problem is that the NextIteration should also be
//# part of this context, but if we're importing it won't have been
//# processed and added to the context yet, so AddValue() will try to
//# add a Switch which results in an invalid graph. Instead, we use the
//# NextIteration input as-is here, and it will eventually be added to
//# the context via AddOp().
real_x = x;
}
else
{
real_x = AddValue(x);
}
if (real_x != x)
op._update_input(index, real_x);
}
// Remove any external control dependency on this op.
_RemoveExternalControlEdges(op);
// TODO: implement below code dependencies
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
// op._add_control_input(_pivot.op);
}
// Mark op's outputs as seen by this context and any outer contexts.
var output_names = op.outputs.Select(x => x.name).ToArray();
IControlFlowContext ctxt = this;
while (ctxt != null)
{
foreach (var name in output_names)
ctxt.values.Add(name);
ctxt = ctxt.outer_context;
}

if (_outer_context != null || !control_flow_ops.IsLoopExit(op))
op.graph.prevent_fetching(op);

if (_outer_context != null)
_outer_context.AddInnerOp(op);
}
public CondContextDef to_proto(string export_scope)


+ 8
- 61
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -119,9 +119,14 @@ namespace Tensorflow.Operations
return null;
}

public virtual void AddInnerOp(Operation resultOp)
/// <summary>
/// Notifies a scope about an operator added to an inner scope.
/// </summary>
/// <param name="op"></param>
public virtual void AddInnerOp(Operation op)
{
// to be overridden
if (_outer_context != null)
_outer_context.AddInnerOp(op);
}

protected HashSet<string> _values = new HashSet<string>();
@@ -131,68 +136,10 @@ namespace Tensorflow.Operations
/// </summary>
protected virtual void _AddOpInternal(Operation op)
{
if (op.inputs.Length == 0)
{
//If we're in a while loop, remove any control inputs from outside the
// loop.
_RemoveExternalControlEdges(op);
if (!op.control_inputs.Any(input_op => OpInContext(input_op)))
op._add_control_input(_pivot.op);
}
else
{
// Make each input to 'op' available in this CondContext. If an input is
// already part of this context there's nothing to do, but if it's
// external, AddValue() will handle adding the appropriate Switch node and
// other bookkeeping.
for (int index = 0; index < op.inputs.Length; index++)
{
var x = op.inputs[index];
Tensor real_x = null;
if (op.type == "Merge" && x.op.type == "NextIteration")
{
//# Edge case: if we're importing a while loop inside this CondContext,
//# AddValue() will not correctly handle the NextIteration inputs to
//# Merge node. The problem is that the NextIteration should also be
//# part of this context, but if we're importing it won't have been
//# processed and added to the context yet, so AddValue() will try to
//# add a Switch which results in an invalid graph. Instead, we use the
//# NextIteration input as-is here, and it will eventually be added to
//# the context via AddOp().
real_x = x;
}
else
{
real_x = AddValue(x);
}
if (real_x != x)
op._update_input(index, real_x);
}
// Remove any external control dependency on this op.
_RemoveExternalControlEdges(op);
// TODO: implement below code dependencies
//if (op.graph._is_function(op.type) || op.type == "SymbolicGradient")
// op._add_control_input(_pivot.op);
}
// Mark op's outputs as seen by this context and any outer contexts.
var output_names = op.outputs.Select(x => x.name).ToArray();
IControlFlowContext ctxt = this;
while (ctxt != null)
{
foreach(var name in output_names)
ctxt.values.Add(name);
ctxt = ctxt.outer_context;
}

if (_outer_context != null || !control_flow_ops.IsLoopExit(op))
op.graph.prevent_fetching(op);

if (_outer_context != null)
_outer_context.AddInnerOp(op);
}
private bool OpInContext(Operation op)
protected bool OpInContext(Operation op)
{
return IsContainingContext(op._get_control_flow_context(), this);
}


+ 2
- 1
src/TensorFlowNET.Core/Operations/nn_impl.py.cs View File

@@ -23,7 +23,8 @@ namespace Tensorflow
return with(ops.name_scope(name, "l2_normalize", new { x }), scope =>
{
x = ops.convert_to_tensor(x, name: "x");
var square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims: true);
var sq = math_ops.square(x);
var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true);
var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon));
return math_ops.multiply(x, x_inv_norm, name: name);
});


+ 2
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -360,6 +360,8 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session()
{
if (tf.defaultSession == null)
tf.defaultSession = tf.Session();
return tf.defaultSession;
}



+ 1
- 4
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -143,10 +143,7 @@ namespace TensorFlowNET.UnitTest
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
if (sess == null)
sess = self.session();
with<Session>(sess, s =>
with(ops.get_default_session(), s =>
{
var ndarray=tensor.eval();
if (typeof(T) == typeof(double))


Loading…
Cancel
Save