
Merge pull request #216 from henon/master

more tests and Graph.unique_name fix
tags/v0.9
Haiping (GitHub) committed 6 years ago
parent commit 8d93cb003f
5 changed files with 368 additions and 196 deletions
  1. +47  -40   src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +24  -3    src/TensorFlowNET.Core/Tensors/tensor_util.cs
  3. +114 -152  test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs
  4. +175 -0    test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs
  5. +8   -1    test/TensorFlowNET.UnitTest/PythonTest.cs

+47 -40  src/TensorFlowNET.Core/Graphs/Graph.cs

@@ -73,8 +73,8 @@ namespace Tensorflow
return var._as_graph_element();

return null;
- }
+ }
private ITensorOrOperation _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
{
string types_str = "";
@@ -99,7 +99,7 @@ namespace Tensorflow
// If obj appears to be a name...
if (obj is string name)
{
- if(name.Contains(":") && allow_tensor)
+ if (name.Contains(":") && allow_tensor)
{
string op_name = name.Split(':')[0];
int out_n = int.Parse(name.Split(':')[1]);
@@ -107,7 +107,7 @@ namespace Tensorflow
if (_nodes_by_name.ContainsKey(op_name))
return _nodes_by_name[op_name].outputs[out_n];
}
- else if(!name.Contains(":") & allow_operation)
+ else if (!name.Contains(":") & allow_operation)
{
if (!_nodes_by_name.ContainsKey(name))
throw new KeyError($"The name {name} refers to an Operation not in the graph.");
@@ -166,8 +166,8 @@ namespace Tensorflow
throw new RuntimeError("Graph is finalized and cannot be modified.");
}

- public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
- TF_DataType[] input_types = null, string name = null,
+ public unsafe Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
+ TF_DataType[] input_types = null, string name = null,
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
{
if (inputs == null)
@@ -188,7 +188,7 @@ namespace Tensorflow
var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops);

- var op = new Operation(node_def,
+ var op = new Operation(node_def,
this,
inputs: inputs,
output_types: dtypes,
@@ -259,54 +259,61 @@ namespace Tensorflow
_name_stack = new_stack;

return String.IsNullOrEmpty(new_stack) ? "" : new_stack + "/";
- }
-
+ }
+ /// <summary>
+ /// Return a unique operation name for `name`.
+ ///
+ /// Note: You rarely need to call `unique_name()` directly. Most of
+ /// the time you just need to create `with g.name_scope()` blocks to
+ /// generate structured names.
+ ///
+ /// `unique_name` is used to generate structured names, separated by
+ /// `"/"`, to help identify operations when debugging a graph.
+ /// Operation names are displayed in error messages reported by the
+ /// TensorFlow runtime, and in various visualization tools such as
+ /// TensorBoard.
+ ///
+ /// If `mark_as_used` is set to `True`, which is the default, a new
+ /// unique name is created and marked as in use. If it's set to `False`,
+ /// the unique name is returned without actually being marked as used.
+ /// This is useful when the caller simply wants to know what the name
+ /// to be created will be.
+ /// </summary>
+ /// <param name="name">The name for an operation.</param>
+ /// <param name="mark_as_used">Whether to mark this name as being used.</param>
+ /// <returns>A string to be passed to `create_op()` that will be used
+ /// to name the operation being created.</returns>
public string unique_name(string name, bool mark_as_used = true)
{
if (!String.IsNullOrEmpty(_name_stack))
{
name = _name_stack + "/" + name;
}

// For the sake of checking for names in use, we treat names as case
// insensitive (e.g. foo = Foo).
var name_key = name.ToLower();
int i = 0;
if (_names_in_use.ContainsKey(name_key))
- {
- foreach (var item in _names_in_use)
- {
- if (item.Key == name_key)
- {
- i = _names_in_use[name_key];
- break;
- }
- i++;
- }
- }
-
+ i = _names_in_use[name_key];
// Increment the number for "name_key".
if (mark_as_used)
- if (_names_in_use.ContainsKey(name_key))
- _names_in_use[name_key]++;
- else
- _names_in_use[name_key] = i + 1;
+ _names_in_use[name_key] = i + 1;
if (i > 0)
{
- var base_name_key = name_key;
-
- // Make sure the composed name key is not already used.
- if (_names_in_use.ContainsKey(name_key))
+ var base_name_key = name_key;
+ while (_names_in_use.ContainsKey(name_key))
{
name_key = $"{base_name_key}_{i}";
i += 1;
}

+ // Mark the composed name_key as used in case someone wants
+ // to call unique_name("name_1").
+ if (mark_as_used)
+ _names_in_use[name_key] = 1;

- name = $"{name}_{i - 1}";
+ // Return the new name with the original capitalization of the given name.
+ name = $"{name}_{i-1}";
}

return name;
}
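The old implementation checked the composed name only once (`if`) and never marked it as used, so `unique_name` could hand out a name that an explicitly named op already held. A minimal sketch of the intended behavior after this change (mirroring the TestUniqueName test added below; the graph setup is illustrative only):

var g = tf.Graph().as_default();
// Take the names "myop" and "myop_1" explicitly.
var op1 = constant_op.constant(0, name: "myop").op;    // named "myop"
var op2 = constant_op.constant(0, name: "myop_1").op;  // named "myop_1"
// Reusing "myop" must now skip the taken "myop_1" and yield "myop_2";
// reusing "myop_1" composes "myop_1_1" instead of colliding.
var op3 = constant_op.constant(0, name: "myop").op;    // named "myop_2"
var op4 = constant_op.constant(0, name: "myop_1").op;  // named "myop_1_1"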

@@ -375,8 +382,8 @@ namespace Tensorflow
public void prevent_fetching(Operation op)
{
_unfetchable_ops.Add(op);
- }
+ }
public void Dispose()
{
c_api.TF_DeleteGraph(_handle);
@@ -387,8 +394,8 @@ namespace Tensorflow
}

public void __exit__()
- {
+ {
}

public static implicit operator IntPtr(Graph graph)


+24 -3  src/TensorFlowNET.Core/Tensors/tensor_util.cs

@@ -116,10 +116,19 @@ namespace Tensorflow
case int intVal:
nparray = intVal;
break;
+ case int[] intVals:
+ nparray = np.array(intVals);
+ break;
+ case int[,] intVals:
+ nparray = np.array(intVals);
+ break;
case long intVal:
nparray = intVal;
break;
- case int[] intVals:
+ case long[] intVals:
nparray = np.array(intVals);
break;
+ case long[,] intVals:
+ nparray = np.array(intVals);
+ break;
case float floatVal:
@@ -128,9 +137,18 @@ namespace Tensorflow
case float[] floatVals:
nparray = floatVals;
break;
+ case float[,] floatVals:
+ nparray = np.array(floatVals);
+ break;
case double doubleVal:
nparray = doubleVal;
break;
+ case double[] doubleVals:
+ nparray = np.array(doubleVals);
+ break;
+ case double[,] doubleVals:
+ nparray = np.array(doubleVals);
+ break;
case string strVal:
nparray = strVal;
break;
@@ -140,8 +158,11 @@ namespace Tensorflow
case byte[] byteValues:
nparray = byteValues;
break;
+ case byte[,] byteValues:
+ nparray = np.array(byteValues);
+ break;
default:
- throw new NotImplementedException("make_tensor_proto Not Implemented");
+ throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented");
}
}
else
@@ -174,7 +195,7 @@ namespace Tensorflow
nparray = Convert.ToString(values);
break;
default:
throw new NotImplementedException("make_tensor_proto Not Implemented");
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
}
}
}
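Taken together, the new case arms let make_tensor_proto wrap 1-D and 2-D int, long, float, double and byte arrays in a NumSharp array before serializing, and both default arms now name the offending type. A hedged usage sketch (assuming the make_tensor_proto(object values, ...) entry point that contains this switch):

// A 2-D double array is now handled by the new "case double[,]" arm.
var proto = tensor_util.make_tensor_proto(new double[,] { { 1.0, 2.0 }, { 3.0, 4.0 } });
// Unsupported types still throw, but the message now names them, e.g.
// "make_tensor_proto: Support for type System.Decimal Not Implemented".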


+114 -152  test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs

@@ -70,7 +70,7 @@ namespace TensorFlowNET.UnitTest
{
a = constant_op.constant(1.0);
var b1 = future();
- with(g.control_dependencies(new [] { a, b}), ctrl =>
+ with(g.control_dependencies(new[] { a, b }), ctrl =>
{
c = constant_op.constant(3.0);
});
@@ -157,8 +157,8 @@ namespace TensorFlowNET.UnitTest
});
});
});
- AssertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
- AssertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
+ assertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs);
+ assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs);
}
[TestMethod]
@@ -170,158 +170,114 @@ namespace TensorFlowNET.UnitTest
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null;
- with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
- {
- with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
- {
- with(g.control_dependencies(null), ctrl3 =>
- {
- with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
- {
- with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
- {
- // deps [a_3, a_4]
- b_3_4 = constant_op.constant(7.0);
- });
- // deps = [a_3]
- b_3 = constant_op.constant(8.0);
- });
- // deps back to None
- b_none = constant_op.constant(9.0);
- });
- // deps back to [a_1, a_2]
- b_1_2 = constant_op.constant(10.0);
- });
- // deps back to [a_1]
- b_1 = constant_op.constant(11.0);
- with(g.control_dependencies(null), ctrl6 =>
- {
- // deps are None again
- b_none2 = constant_op.constant(12.0);
- });
- });
- AssertItemsEqual(new[] {a_3.op, a_4.op}, b_3_4.op.control_inputs);
- AssertItemsEqual(new[] {a_3.op}, b_3.op.control_inputs);
- AssertItemsEqual(new object[0], b_none.op.control_inputs);
- AssertItemsEqual(new[] {a_1.op, a_2.op}, b_1_2.op.control_inputs);
- AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
- AssertItemsEqual(new object[0], b_none2.op.control_inputs);
- /*
- def testClear(self):
- g = ops.Graph()
- a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- with g.control_dependencies([a_1]):
- with g.control_dependencies([a_2]):
- with g.control_dependencies(None):
- with g.control_dependencies([a_3]):
- with g.control_dependencies([a_4]):
- # deps [a_3, a_4]
- b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- # deps = [a_3]
- b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- # deps back to None
- b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- # deps back to [a_1, a_2]
- b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- # deps back to [a_1]
- b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- with g.control_dependencies(None):
- # deps are None again
- b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
- self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
- self.assertItemsEqual([], b_none.op.control_inputs)
- self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
- self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
- self.assertItemsEqual([], b_none2.op.control_inputs)
- */
+ with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
+ {
+ with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
+ {
+ with(g.control_dependencies(null), ctrl3 =>
+ {
+ with(g.control_dependencies(new[] { a_3 }), ctrl4 =>
+ {
+ with(g.control_dependencies(new[] { a_4 }), ctrl5 =>
+ {
+ // deps [a_3, a_4]
+ b_3_4 = constant_op.constant(7.0);
+ });
+ // deps = [a_3]
+ b_3 = constant_op.constant(8.0);
+ });
+ // deps back to None
+ b_none = constant_op.constant(9.0);
+ });
+ // deps back to [a_1, a_2]
+ b_1_2 = constant_op.constant(10.0);
+ });
+ // deps back to [a_1]
+ b_1 = constant_op.constant(11.0);
+ with(g.control_dependencies(null), ctrl6 =>
+ {
+ // deps are None again
+ b_none2 = constant_op.constant(12.0);
+ });
+ });
+ assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs);
+ assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs);
+ assertItemsEqual(new object[0], b_none.op.control_inputs);
+ assertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs);
+ assertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs);
+ assertItemsEqual(new object[0], b_none2.op.control_inputs);
}
[Ignore("will fail due to unsupported op 'FloatOutput'")]
[TestMethod]
public void TestComplex()
{
- /*
- def testComplex(self):
- g = ops.Graph()
- # Usage pattern:
- # * Nodes a_i are constants defined at the outermost scope, and are used
- # as control inputs for the ith nested scope.
- # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
- # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
- # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
- # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
- a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- with g.control_dependencies([a_1]):
- b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
- [dtypes.float32])
- c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
- [dtypes.float32])
- d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
- [dtypes.float32])
- e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- with g.control_dependencies([a_2]):
- b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
- [dtypes.float32])
- c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
- [dtypes.float32])
- d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
- [dtypes.float32])
- e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
- [dtypes.float32])
- with g.control_dependencies([a_3]):
- b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
- [dtypes.float32])
- c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
- [dtypes.float32])
- d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
- [dtypes.float32])
- e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
- [dtypes.float32])
- with g.control_dependencies([a_4]):
- b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
- [dtypes.float32])
- c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
- [dtypes.float32])
- d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
- [dtypes.float32])
- e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
- [dtypes.float32])
+ var g = tf.Graph().as_default();
+ // Usage pattern:
+ // * Nodes a_i are constants defined at the outermost scope, and are used
+ // as control inputs for the ith nested scope.
+ // * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
+ // * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
+ // * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
+ // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
+ var a_1 = constant_op.constant(1.0);
+ var a_2 = constant_op.constant(2.0);
+ var a_3 = constant_op.constant(3.0);
+ var a_4 = constant_op.constant(4.0);
+ Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null;
+ Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null;
+ Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null;
+ Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null;
+ with(g.control_dependencies(new[] { a_1 }), ctrl1 =>
+ {
+ b_1 = tf.multiply(a_3, a_4);
+ c_1 = tf.multiply(a_1, b_1.output);
+ d_1 = tf.multiply(b_1.output, c_1.output);
+ e_1 = constant_op.constant(5.0);
+ with(g.control_dependencies(new[] { a_2 }), ctrl2 =>
+ {
+ b_2 = tf.multiply(a_3, a_4);
+ c_2 = tf.multiply(a_1, b_1.output);
+ d_2 = tf.multiply(b_2.output, c_2.output);
+ e_2 = tf.multiply(e_1.output, e_1.output);
+ with(g.control_dependencies(new[] { a_3 }), ctrl3 =>
+ {
+ b_3 = tf.multiply(a_3, a_4);
+ c_3 = tf.multiply(a_1, b_1.output);
+ d_3 = tf.multiply(b_3.output, c_3.output);
+ e_3 = tf.multiply(e_2.output, e_2.output);
+ with(g.control_dependencies(new[] { a_4 }), ctrl4 =>
+ {
+ b_4 = tf.multiply(a_3, a_4);
+ c_4 = tf.multiply(a_1, b_1.output);
+ d_4 = tf.multiply(b_4.output, c_4.output);
+ e_4 = tf.multiply(e_3.output, e_3.output);
+ });
+ });
+ });
+ });
- self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
- self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
- self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
- self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)
+ assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs);
+ assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs);
+ assertItemsEqual(new[] {a_1.op, a_2.op}, b_3.op.control_inputs);
+ assertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs);
- self.assertItemsEqual([], c_1.op.control_inputs)
- self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
- self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
- self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)
+ assertItemsEqual(new object[0], c_1.op.control_inputs);
+ assertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs);
+ assertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs);
+ assertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs);
- self.assertItemsEqual([], d_1.op.control_inputs)
- self.assertItemsEqual([], d_2.op.control_inputs)
- self.assertItemsEqual([], d_3.op.control_inputs)
- self.assertItemsEqual([], d_4.op.control_inputs)
+ assertItemsEqual(new object[0], d_1.op.control_inputs);
+ assertItemsEqual(new object[0], d_2.op.control_inputs);
+ assertItemsEqual(new object[0], d_3.op.control_inputs);
+ assertItemsEqual(new object[0], d_4.op.control_inputs);
- self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
- self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
- self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
- self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
- */
+ assertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs);
+ assertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs);
+ assertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs);
+ assertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs);
}
[Ignore("will fail due to unsupported op 'FloatOutput'")]
[Ignore("Don't know how to create an operation with two outputs")]
[TestMethod]
public void TestRepeatedDependency()
{
@@ -337,16 +293,22 @@ namespace TensorFlowNET.UnitTest
self.assertEqual(b.op.control_inputs, [a])
self.assertEqual(c.op.control_inputs, [a])
- def testNoControlDependencyWithDataDependency(self):
- g = ops.Graph()
- a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
- with g.control_dependencies([a]):
- b = _apply_op(g, "Identity", [a], [dtypes.float32])
- self.assertEqual(b.op.control_inputs, [])
*/
}
+ [TestMethod]
+ public void TestNoControlDependencyWithDataDependency()
+ {
+ var g = tf.Graph().as_default();
+ Operation b = null;
+ var a = constant_op.constant(100.0);
+ with(g.control_dependencies(new[] { a }), ctrl1 =>
+ {
+ b = array_ops.identity(a);
+ });
+ Assert.AreEqual(0, b.op.control_inputs.Length);
+ }
}
}

+175 -0  test/TensorFlowNET.UnitTest/CreateOpFromTfOperationTest.cs

@@ -0,0 +1,175 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops_test.py
/// # These cases test the private Graph._create_op_from_tf_operation
/// # method. Arguably we should only test the public APIs that depend on this
/// # method. However, this logic is complex and tricky, and it can be difficult to
/// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
/// # the control flow context isn't set properly, but a more complicated use case
/// # that might not be obvious to test will fail). Thus we instead explicitly test
/// # the low-level behavior.
/// </summary>
[TestClass]
public class CreateOpFromTfOperationTest : PythonTest
{
[TestMethod]
public void TestShape()
{
var graph = tf.Graph().as_default();
with<Graph>(graph, g =>
{
var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}});
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var op = g._create_op_from_tf_operation(c_op);
Assert.AreEqual("myop", op.name);
Assert.AreEqual("Identity", op.type);
Assert.AreEqual(1, len(op.outputs));
assertItemsEqual(new []{2, 3}, op.outputs[0].shape);
});
}
[TestMethod]
public void TestUniqueName()
{
var graph = tf.Graph().as_default();
with<Graph>(graph, g =>
{
//var (c_op,op_desc) = ops._create_c_op(g, ops._NodeDef("Const", "myop"), new Tensor[0], new Operation[0]);
//var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]);
//var op = g._create_op_from_tf_operation(c_op);
//var op2 = g._create_op_from_tf_operation(c_op2);
var op = constant_op.constant(0, name:"myop").op;
var op2 = constant_op.constant(0, name: "myop_1").op;
// Create ops with same names as op1 and op2. We expect the new names to be
// uniquified.
var op3 = constant_op.constant(0, name: "myop").op;
var op4 = constant_op.constant(0, name: "myop_1").op;
self.assertEqual(op.name, "myop");
self.assertEqual(op2.name, "myop_1");
self.assertEqual(op3.name, "myop_2");
self.assertEqual(op4.name, "myop_1_1");
});
}
/*
@test_util.run_v1_only("b/120545219")
def testCond(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def true_fn():
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "cond/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return x
control_flow_ops.cond(x < 10, true_fn, lambda: x)
op = g.get_operation_by_name("cond/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "cond/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Switch")
self.assertEqual(op_input.inputs[0], x)
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"cond/cond_text")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoop(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
self.assertEqual(op.name, "myloop/myop")
self.assertEqual(op.type, "IntInput")
self.assertEqual(op.outputs, [])
op_input = op.inputs[0].op
self.assertEqual(op_input.type, "Enter")
self.assertEqual(list(op_input.inputs), [x])
self.assertEqual(op.graph, g)
# pylint: disable=protected-access
self.assertIsNotNone(op._get_control_flow_context())
self.assertEqual(op._get_control_flow_context().name,
"myloop/while_context")
# pylint: enable=protected-access
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithInternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
def body(i):
c = constant_op.constant(1.0, name="c")
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
c = g.get_operation_by_name("myloop/c")
self.assertIsNotNone(c)
# Internal control dep is preserved
self.assertEqual(op.control_inputs, [c])
@test_util.run_v1_only("b/120545219")
def testWhileLoopWithExternalControlDep(self):
g = ops.Graph()
with g.as_default():
x = test_ops.int_output()
c = constant_op.constant(1.0)
def body(i):
ops._create_c_op(ops.get_default_graph(),
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
with ops.control_dependencies([c]):
new_ops = g._add_new_tf_operations()
self.assertEqual(len(new_ops), 1)
return i
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
op = g.get_operation_by_name("myloop/myop")
self.assertIsNotNone(op)
# External control dep is removed and replaced with internal control dep
self.assertNotEqual(op.control_inputs[0], c.op)
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
*/
}
}

+8 -1  test/TensorFlowNET.UnitTest/PythonTest.cs

@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest
/// </summary>
public class PythonTest : Python
{
- public void AssertItemsEqual(ICollection expected, ICollection given)
+ public void assertItemsEqual(ICollection expected, ICollection given)
{
Assert.IsNotNull(expected);
Assert.IsNotNull(given);
@@ -23,5 +23,12 @@ namespace TensorFlowNET.UnitTest
for(int i=0; i<e.Length; i++)
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
}
+ public void assertEqual(object given, object expected)
+ {
+ Assert.AreEqual(expected, given);
+ }
+ protected PythonTest self { get => this; }
}
}
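The lowercase helper names and the self property let Python test bodies port to C# almost line-for-line. A small illustration (hypothetical test body in any class deriving from PythonTest):

[TestMethod]
public void TestPortedAssertions()
{
    var op = constant_op.constant(0, name: "myop").op;
    // Reads like the Python original: self.assertEqual(op.name, "myop")
    self.assertEqual(op.name, "myop");
    self.assertItemsEqual(new[] { 1, 2, 3 }, new[] { 1, 2, 3 });
}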
