using System.Collections;
using System.Collections.Generic;
using Colorful;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using NumSharp;
using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Python;
namespace TensorFlowNET.UnitTest.nest_test
{
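/// <summary>
/// Excerpt of tensorflow/python/framework/util/nest_test.py, ported to C#.
/// Python tests that have not been ported yet are kept below as commented-out reference code.
/// </summary>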
[TestClass]
public class NestTest : PythonTest
{
/// <summary>
/// MSTest per-test setup: builds a new graph with <c>tf.Graph()</c> and
/// calls <c>as_default()</c> on it.
/// </summary>
[TestInitialize]
public void TestInitialize()
{
    // NOTE(review): presumably as_default() installs the new graph as the
    // default so tests do not share graph state - confirm against Graph.
    var freshGraph = tf.Graph();
    freshGraph.as_default();
}
//public class PointXY
//{
// public double x;
// public double y;
//}
// if attr:
// class BadAttr(object):
// """Class that has a non-iterable __attrs_attrs__."""
// __attrs_attrs__ = None
// @attr.s
// class SampleAttr(object):
// field1 = attr.ib()
// field2 = attr.ib()
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testAttrsFlattenAndPack(self) :
// if attr is None:
// self.skipTest("attr module is unavailable.")
// field_values = [1, 2]
// sample_attr = NestTest.SampleAttr(* field_values)
// self.assertFalse(nest._is_attrs(field_values))
// self.assertTrue(nest._is_attrs(sample_attr))
// flat = nest.flatten(sample_attr)
// self.assertEqual(field_values, flat)
// restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
// self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
// self.assertEqual(restructured_from_flat, sample_attr)
//# Check that flatten fails if attributes are not iterable
// with self.assertRaisesRegexp(TypeError, "object is not iterable"):
// flat = nest.flatten(NestTest.BadAttr())
/// <summary>
/// Round-trips nested structures through <c>nest.flatten</c> and
/// <c>nest.pack_sequence_as</c>: nested object arrays, arrays containing
/// <see cref="Hashtable"/> instances, bare scalars and NumSharp arrays.
/// Finishes with two calls whose flat sequence does not match the structure
/// and which are therefore expected to throw.
/// </summary>
[TestMethod]
public void testFlattenAndPack()
{
    // Nested object[] structure with eight leaves.
    object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } };
    // List<object>: the same variable holds strings here and ints further down.
    var flat = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" };
    self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 });
    // The packed result is compared through its JSON text representation.
    self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(),
        JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString());

    // Structure mixing Hashtables and nested arrays.
    structure = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } };
    flat = new List<object> { 4, 2, 1, 0 };
    self.assertEqual(nest.flatten(structure), flat);
    var restructured_from_flat = nest.pack_sequence_as(structure, flat) as object[];
    // Report a wrong return type as an assertion failure instead of a
    // NullReferenceException when the result is indexed below.
    Assert.IsNotNull(restructured_from_flat, "pack_sequence_as did not return an object[]");
    self.assertEqual(restructured_from_flat, structure);
    // Direct casts: an unexpected element type surfaces as InvalidCastException
    // at the cast rather than as a null dereference later.
    var outerTable = (Hashtable)restructured_from_flat[0];
    var innerTable = (Hashtable)((object[])((object[])restructured_from_flat[1])[0])[0];
    self.assertEqual(outerTable["x"], 4);
    self.assertEqual(outerTable["y"], 2);
    self.assertEqual(innerTable["x"], 1);
    self.assertEqual(innerTable["y"], 0);

    // Scalars and NumSharp arrays are treated as single leaves.
    self.assertEqual(new List<object> { 5 }, nest.flatten(5));
    var flat1 = nest.flatten(np.array(new[] { 5 }));
    self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1);
    self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
    self.assertEqual(np.array(new[] { 5 }),
        nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) }));

    // Element-count mismatches must be rejected. MSTest's ThrowsException<T>
    // needs an explicit exception type and matches it exactly.
    // NOTE(review): ValueError follows the Python reference test - confirm it
    // is the type the C# nest implementation actually throws.
    Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 }));
    Assert.ThrowsException<ValueError>(() =>
        nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" }));
}
// @parameterized.parameters({"mapping_type": collections.OrderedDict
// },
// {"mapping_type": _CustomMapping
//})
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testFlattenDictOrder(self, mapping_type) :
// """`flatten` orders dicts by key, including OrderedDicts."""
// ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
// plain = {"d": 3, "b": 1, "a": 0, "c": 2}
// ordered_flat = nest.flatten(ordered)
// plain_flat = nest.flatten(plain)
// self.assertEqual([0, 1, 2, 3], ordered_flat)
// self.assertEqual([0, 1, 2, 3], plain_flat)
// @parameterized.parameters({"mapping_type": collections.OrderedDict},
// {"mapping_type": _CustomMapping})
// def testPackDictOrder(self, mapping_type):
// """Packing orders dicts by key, including OrderedDicts."""
// custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
// plain = {"d": 0, "b": 0, "a": 0, "c": 0}
// seq = [0, 1, 2, 3]
//custom_reconstruction = nest.pack_sequence_as(custom, seq)
//plain_reconstruction = nest.pack_sequence_as(plain, seq)
// self.assertIsInstance(custom_reconstruction, mapping_type)
// self.assertIsInstance(plain_reconstruction, dict)
// self.assertEqual(
// mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
// custom_reconstruction)
// self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
// Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testFlattenAndPack_withDicts(self) :
// # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
// mess = [
// "z",
// NestTest.Abc(3, 4), {
// "d": _CustomMapping({
// 41: 4
// }),
// "c": [
// 1,
// collections.OrderedDict([
// ("b", 3),
// ("a", 2),
// ]),
// ],
// "b": 5
// }, 17
// ]
// flattened = nest.flatten(mess)
// self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
// structure_of_mess = [
// 14,
// NestTest.Abc("a", True),
// {
// "d": _CustomMapping({
// 41: 42
// }),
// "c": [
// 0,
// collections.OrderedDict([
// ("b", 9),
// ("a", 8),
// ]),
// ],
// "b": 3
// },
// "hi everybody",
// ]
// unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
// self.assertEqual(unflattened, mess)
// # Check also that the OrderedDict was created, with the correct key order.
//unflattened_ordered_dict = unflattened[2]["c"][1]
// self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
// self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
// unflattened_custom_mapping = unflattened[2]["d"]
// self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
// self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
/// <summary>
/// Flattening a NumSharp array built from three values must yield exactly
/// one element.
/// </summary>
[TestMethod]
public void testFlatten_numpyIsNotFlattened()
{
    var leaves = nest.flatten(np.array(1, 2, 3));
    self.assertEqual(len(leaves), 1);
}
/// <summary>
/// Flattening a string must yield exactly one element, and packing that
/// flat result against another string structure must give back the
/// original string.
/// </summary>
[TestMethod]
public void testFlatten_stringIsNotFlattened()
{
    var original = "lots of letters";
    var leaves = nest.flatten(original);
    self.assertEqual(len(leaves), 1);

    // "goodbye" only serves as the structure template for packing.
    var repacked = nest.pack_sequence_as("goodbye", leaves);
    self.assertEqual(original, repacked);
}
// def testPackSequenceAs_notIterableError(self) :
// with self.assertRaisesRegexp(TypeError,
// "flat_sequence must be a sequence"):
// nest.pack_sequence_as("hi", "bye")
/// <summary>
/// Packing a three-element flat sequence into a two-element structure must
/// throw.
/// </summary>
[TestMethod]
public void testPackSequenceAs_wrongLengthsError()
{
    // Python reference:
    //   with self.assertRaisesRegexp(
    //       ValueError,
    //       "Structure had 2 elements, but flat_sequence had 3 elements."):
    // Only the exception type is checked here; the message is not asserted.
    // MSTest's ThrowsException<T> needs an explicit exception type and
    // matches it exactly.
    Assert.ThrowsException<ValueError>(() =>
    {
        nest.pack_sequence_as(new object[] { "hello", "world" }, new object[] { "and", "goodbye", "again" });
    });
}
/// <summary>
/// Checks which values <c>nest.is_sequence</c> classifies as sequences:
/// object arrays (including an empty one) and a <see cref="Hashtable"/> are
/// expected to be sequences; a string, a hash set, tensors and a NumSharp
/// array are not.
/// </summary>
[TestMethod]
public void testIsSequence()
{
    self.assertFalse(nest.is_sequence("1234"));
    self.assertTrue(nest.is_sequence(new object[] { 1, 3, new object[] { 4, 5 } }));
    // TODO: ValueTuple
    //self.assertTrue(nest.is_sequence(((7, 8), (5, 6))));
    self.assertTrue(nest.is_sequence(new object[] { }));
    self.assertTrue(nest.is_sequence(new Hashtable { ["a"] = 1, ["b"] = 2 }));
    // HashSet needs its element type spelled out; the literals are ints.
    self.assertFalse(nest.is_sequence(new HashSet<int> { 1, 2 }));

    // Neither the result of array_ops.ones nor an op applied to it counts
    // as a sequence.
    var ones = array_ops.ones(new int[] { 2, 3 });
    self.assertFalse(nest.is_sequence(ones));
    self.assertFalse(nest.is_sequence(gen_math_ops.tanh(ones)));
    self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 })));
}
// @parameterized.parameters({"mapping_type": _CustomMapping},
// {"mapping_type": dict})
// def testFlattenDictItems(self, mapping_type):
// dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))})
// flat = {4: "a", 5: "b", 6: "c", 8: "d"}
// self.assertEqual(nest.flatten_dict_items(dictionary), flat)
// with self.assertRaises(TypeError):
// nest.flatten_dict_items(4)
// bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))})
// with self.assertRaisesRegexp(ValueError, "not unique"):
// nest.flatten_dict_items(bad_dictionary)
// another_bad_dictionary = mapping_type({
// (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
// })
// with self.assertRaisesRegexp(
// ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
// nest.flatten_dict_items(another_bad_dictionary)
//# pylint does not correctly recognize these as class names and
//# suggests to use variable style under_score naming.
//# pylint: disable=invalid-name
// Named0ab = collections.namedtuple("named_0", ("a", "b"))
// Named1ab = collections.namedtuple("named_1", ("a", "b"))
// SameNameab = collections.namedtuple("same_name", ("a", "b"))
// SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
// SameNamexy = collections.namedtuple("same_name", ("x", "y"))
// SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
// SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
// NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
// # pylint: enable=invalid-name
// class SameNamedType1(SameNameab):
// pass
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testAssertSameStructure(self):
// structure1 = (((1, 2), 3), 4, (5, 6))
// structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
// structure_different_num_elements = ("spam", "eggs")
// structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
// nest.assert_same_structure(structure1, structure2)
// nest.assert_same_structure("abc", 1.0)
// nest.assert_same_structure("abc", np.array([0, 1]))
// nest.assert_same_structure("abc", constant_op.constant([0, 1]))
// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// "More specifically: Substructure "
// r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
// 'substructure "type=str str=spam" is not\n'
// "Entire first structure:\n"
// r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
// "Entire second structure:\n"
// r"\(\., \.\)")):
// nest.assert_same_structure(structure1, structure_different_num_elements)
// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// r'More specifically: Substructure "type=list str=\[0, 1\]" '
// r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
// "is not")):
// nest.assert_same_structure([0, 1], np.array([0, 1]))
// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// r'More specifically: Substructure "type=list str=\[0, 1\]" '
// 'is a sequence, while substructure "type=int str=0" '
// "is not")):
// nest.assert_same_structure(0, [0, 1])
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure(structure1, structure_different_nesting)
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
// NestTest.Named0ab("a", "b"))
// nest.assert_same_structure(NestTest.Named0ab(3, 4),
// NestTest.Named0ab("a", "b"))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure(NestTest.Named0ab(3, 4),
// NestTest.Named0ab([3], 4))
// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure([[3], 4], [3, [4]])
// structure1_list = [[[1, 2], 3], 4, [5, 6]]
// with self.assertRaisesRegexp(TypeError,
// "don't have the same sequence type"):
// nest.assert_same_structure(structure1, structure1_list)
// nest.assert_same_structure(structure1, structure2, check_types= False)
// nest.assert_same_structure(structure1, structure1_list, check_types=False)
// with self.assertRaisesRegexp(ValueError,
// "don't have the same set of keys"):
// nest.assert_same_structure({"a": 1}, {"b": 1})
// nest.assert_same_structure(NestTest.SameNameab(0, 1),
// NestTest.SameNameab2(2, 3))
// # This assertion is expected to pass: two namedtuples with the same
// # name and field names are considered to be identical.
// nest.assert_same_structure(
// NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
// NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
// expected_message = "The two structures don't have the same.*"
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_same_structure(
// NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
// NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
// EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
// def testHeterogeneousComparison(self):
// nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3))
// nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
/// <summary>
/// Exercises <c>nest.map_structure</c> on nested object arrays: a
/// single-structure mapping that adds one to every leaf, a single-structure
/// mapping that formats every leaf as a string, and a two-structure mapping
/// that adds corresponding leaves. The remaining assertions of the Python
/// reference test are kept as comments until they are ported.
/// </summary>
[TestMethod]
public void testMapStructure()
{
    var lhs = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } };
    var rhs = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } };

    var incremented = nest.map_structure(leaf => (int)leaf + 1, lhs);
    var stringified = nest.map_structure(leaf => $"{leaf}", lhs);

    // Print the incremented structure as JSON text.
    Console.WriteLine(JArray.FromObject(incremented).ToString());

    // Not ported yet:
    // nest.assert_same_structure(structure1, structure1_plus1)
    var expectedIncremented = new object[] { 2, 3, 4, 5, 6, 7 };
    self.assertAllEqual(nest.flatten(incremented), expectedIncremented);

    var expectedStrings = new object[] { "1", "2", "3", "4", "5", "6" };
    self.assertAllEqual(nest.flatten(stringified), expectedStrings);

    // With two structures the callback indexes its argument: [0] is read
    // for the first operand and [1] for the second.
    var summed = nest.map_structure(pair => (int)(pair[0]) + (int)(pair[1]), lhs, rhs);
    var expectedSums = new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } };
    self.assertEqual(expectedSums, summed);

    // Remaining Python reference assertions (not ported yet):
    // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
    // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
    // # Empty structures
    // self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
    // self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
    // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
    // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
    // NestTest.EmptyNT()))
    // # This is checking actual equality of types, empty list != empty tuple
    // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
    // with self.assertRaisesRegexp(TypeError, "callable"):
    // nest.map_structure("bad", structure1_plus1)
    // with self.assertRaisesRegexp(ValueError, "at least one structure"):
    // nest.map_structure(lambda x: x)
    // with self.assertRaisesRegexp(ValueError, "same number of elements"):
    // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
    // with self.assertRaisesRegexp(ValueError, "same nested structure"):
    // nest.map_structure(lambda x, y: None, 3, (3,))
    // with self.assertRaisesRegexp(TypeError, "same sequence type"):
    // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
    // with self.assertRaisesRegexp(ValueError, "same nested structure"):
    // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
    // structure1_list = [[[1, 2], 3], 4, [5, 6]]
    // with self.assertRaisesRegexp(TypeError, "same sequence type"):
    // nest.map_structure(lambda x, y: None, structure1, structure1_list)
    // nest.map_structure(lambda x, y: None, structure1, structure1_list,
    // check_types=False)
    // with self.assertRaisesRegexp(ValueError, "same nested structure"):
    // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
    // check_types=False)
    // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    // nest.map_structure(lambda x: None, structure1, foo="a")
    // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
    // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
}
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testMapStructureWithStrings(self) :
// inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
// inp_b = NestTest.ABTuple(a=2, b=(1, 3))
// out = nest.map_structure(lambda string, repeats: string* repeats,
// inp_a,
// inp_b)
// self.assertEqual("foofoo", out.a)
// self.assertEqual("bar", out.b[0])
// self.assertEqual("bazbazbaz", out.b[1])
// nt = NestTest.ABTuple(a=("something", "something_else"),
// b="yet another thing")
// rev_nt = nest.map_structure(lambda x: x[::- 1], nt)
// # Check the output is the correct structure, and all strings are reversed.
// nest.assert_same_structure(nt, rev_nt)
// self.assertEqual(nt.a[0][::- 1], rev_nt.a[0])
// self.assertEqual(nt.a[1][::- 1], rev_nt.a[1])
// self.assertEqual(nt.b[::- 1], rev_nt.b)
// @test_util.run_deprecated_v1
// def testMapStructureOverPlaceholders(self) :
// inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
// array_ops.placeholder(dtypes.float32, shape=[3, 7]))
// inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
// array_ops.placeholder(dtypes.float32, shape=[3, 7]))
// output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
// nest.assert_same_structure(output, inp_a)
// self.assertShapeEqual(np.zeros((3, 4)), output[0])
// self.assertShapeEqual(np.zeros((3, 7)), output[1])
// feed_dict = {
// inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
// inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
// }
// with self.cached_session() as sess:
// output_np = sess.run(output, feed_dict=feed_dict)
// self.assertAllClose(output_np[0],
// feed_dict[inp_a][0] + feed_dict[inp_b][0])
// self.assertAllClose(output_np[1],
// feed_dict[inp_a][1] + feed_dict[inp_b][1])
// def testAssertShallowStructure(self):
// inp_ab = ["a", "b"]
//inp_abc = ["a", "b", "c"]
//expected_message = (
// "The two structures don't have the same sequence length. Input "
// "structure has length 2, while shallow structure has length 3.")
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_shallow_structure(inp_abc, inp_ab)
// inp_ab1 = [(1, 1), (2, 2)]
// inp_ab2 = [[1, 1], [2, 2]]
// expected_message = (
// "The two structures don't have the same sequence type. Input structure "
// "has type <(type|class) 'tuple'>, while shallow structure has type "
// "<(type|class) 'list'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// nest.assert_shallow_structure(inp_ab2, inp_ab1)
// nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False)
// inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
// inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
// expected_message = (
// r"The two structures don't have the same keys. Input "
// r"structure has keys \['c'\], while shallow structure has "
// r"keys \['d'\].")
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_shallow_structure(inp_ab2, inp_ab1)
// inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
// inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
// nest.assert_shallow_structure(inp_ab, inp_ba)
// # This assertion is expected to pass: two namedtuples with the same
//# name and field names are considered to be identical.
//inp_shallow = NestTest.SameNameab(1, 2)
// inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
// def testFlattenUpTo(self):
// # Shallow tree ends at scalar.
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
// shallow_tree = [[True, True], [False, True]]
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
// self.assertEqual(flattened_shallow_tree, [True, True, False, True])
//# Shallow tree ends at string.
// input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
// shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// input_tree_flattened = nest.flatten(input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
// self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
// # Make sure dicts are correctly flattened, yielding values, not keys.
//input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
// shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [1, { "c": 2}, 3, (4, 5)])
// # Namedtuples.
// ab_tuple = NestTest.ABTuple
// input_tree = ab_tuple(a =[0, 1], b = 2)
// shallow_tree = ab_tuple(a= 0, b= 1)
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [[0, 1], 2])
// # Nested dicts, OrderedDicts and namedtuples.
// input_tree = collections.OrderedDict(
// [("a", ab_tuple(a =[0, {"b": 1}], b=2)),
// ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
// shallow_tree = input_tree
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
// shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [ab_tuple(a =[0, { "b": 1}], b=2),
// 3,
// collections.OrderedDict([("f", 4)])])
// shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [ab_tuple(a =[0, {"b": 1}], b=2),
// {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
// ## Shallow non-list edge-case.
// # Using iterable elements.
// input_tree = ["input_tree"]
//shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// input_tree = ["input_tree_0", "input_tree_1"]
//shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// # Using non-iterable elements.
//input_tree = [0]
//shallow_tree = 9
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// input_tree = [0, 1]
//shallow_tree = 9
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// ## Both non-list edge-case.
//# Using iterable elements.
//input_tree = "input_tree"
// shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// # Using non-iterable elements.
//input_tree = 0
// shallow_tree = 0
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// ## Input non-list edge-case.
//# Using iterable elements.
//input_tree = "input_tree"
// shallow_tree = ["shallow_tree"]
//expected_message = ("If shallow structure is a sequence, input must also "
// "be a sequence. Input has type: <(type|class) 'str'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
// input_tree = "input_tree"
// shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
//with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
//# Using non-iterable elements.
// input_tree = 0
// shallow_tree = [9]
//expected_message = ("If shallow structure is a sequence, input must also "
// "be a sequence. Input has type: <(type|class) 'int'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
// input_tree = 0
// shallow_tree = [9, 8]
//with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
// def testMapStructureUpTo(self) :
// # Named tuples.
// ab_tuple = collections.namedtuple("ab_tuple", "a, b")
// op_tuple = collections.namedtuple("op_tuple", "add, mul")
// inp_val = ab_tuple(a= 2, b= 3)
// inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
// self.assertEqual(out.a, 6)
// self.assertEqual(out.b, 15)
// # Lists.
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
// name_list = ["evens", ["odds", "primes"]]
// out = nest.map_structure_up_to(
// name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
// name_list, data_list)
// self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
// # Dicts.
// inp_val = dict(a= 2, b= 3)
// inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// self.assertEqual(out["a"], 6)
// self.assertEqual(out["b"], 15)
// # Non-equal dicts.
// inp_val = dict(a= 2, b= 3)
// inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
// with self.assertRaisesRegexp(ValueError, "same keys"):
// nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// # Dict+custom mapping.
// inp_val = dict(a= 2, b= 3)
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// self.assertEqual(out["a"], 6)
// self.assertEqual(out["b"], 15)
// # Non-equal dict/mapping.
// inp_val = dict(a= 2, b= 3)
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
// with self.assertRaisesRegexp(ValueError, "same keys"):
// nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// def testGetTraverseShallowStructure(self):
// scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
// scalar_traverse_r = nest.get_traverse_shallow_structure(
// lambda s: not isinstance(s, tuple),
// scalar_traverse_input)
// self.assertEqual(scalar_traverse_r,
// [True, True, False, [True, True], {"a": False}, []])
// nest.assert_shallow_structure(scalar_traverse_r,
// scalar_traverse_input)
// structure_traverse_input = [(1, [2]), ([1], 2)]
// structure_traverse_r = nest.get_traverse_shallow_structure(
// lambda s: (True, False) if isinstance(s, tuple) else True,
// structure_traverse_input)
// self.assertEqual(structure_traverse_r,
// [(True, False), ([True], False)])
// nest.assert_shallow_structure(structure_traverse_r,
// structure_traverse_input)
// with self.assertRaisesRegexp(TypeError, "returned structure"):
// nest.get_traverse_shallow_structure(lambda _: [True], 0)
// with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
// nest.get_traverse_shallow_structure(lambda _: 1, [1])
// with self.assertRaisesRegexp(
// TypeError, "didn't return a depth=1 structure of bools"):
// nest.get_traverse_shallow_structure(lambda _: [1], [1])
// def testYieldFlatStringPaths(self):
// for inputs_expected in ({"inputs": [], "expected": []},
// {"inputs": 3, "expected": [()]},
// {"inputs": [3], "expected": [(0,)]},
// {"inputs": {"a": 3}, "expected": [("a",)]},
// {"inputs": {"a": {"b": 4}},
// "expected": [("a", "b")]},
// {"inputs": [{"a": 2}], "expected": [(0, "a")]},
// {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
// {"inputs": [{"a": [(23, 42)]}],
// "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
// {"inputs": [{"a": ([23], 42)}],
// "expected": [(0, "a", 0, 0), (0, "a", 1)]},
// {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
// "expected": [("a", "a"), ("c", 0, 0, 0)]},
// {"inputs": {"0": [{"1": 23}]},
// "expected": [("0", 0, "1")]}):
// inputs = inputs_expected["inputs"]
// expected = inputs_expected["expected"]
// self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
// def testFlattenWithStringPaths(self):
// for inputs_expected in (
// {"inputs": [], "expected": []},
// {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
// {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
// inputs = inputs_expected["inputs"]
// expected = inputs_expected["expected"]
// self.assertEqual(
// nest.flatten_with_joined_string_paths(inputs, separator="/"),
// expected)
// # Need a separate test for namedtuple as we can't declare tuple definitions
// # in the @parameterized arguments.
// def testFlattenNamedTuple(self):
// # pylint: disable=invalid-name
// Foo = collections.namedtuple("Foo", ["a", "b"])
// Bar = collections.namedtuple("Bar", ["c", "d"])
// # pylint: enable=invalid-name
// test_cases = [
// (Foo(a = 3, b = Bar(c = 23, d = 42)),
// [("a", 3), ("b/c", 23), ("b/d", 42)]),
// (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")),
// [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
// (Bar(c = 42, d = 43),
// [("c", 42), ("d", 43)]),
// (Bar(c =[42], d = 43),
// [("c/0", 42), ("d", 43)]),
// ]
// for inputs, expected in test_cases:
// self.assertEqual(
// list(nest.flatten_with_joined_string_paths(inputs)), expected)
// @parameterized.named_parameters(
// ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
// ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
// {"a": ("a", 4), "b": ("b", 6)}),
// ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
// ("nested",
// {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
// {"a": [("a/0", 10), ("a/1", 12)],
// "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
// def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
// def format_sum(path, * values):
// return (path, sum(values))
// result = nest.map_structure_with_paths(format_sum, s1, s2,
// check_types=check_types)
// self.assertEqual(expected, result)
// @parameterized.named_parameters(
// ("tuples", (1, 2), (3, 4, 5), ValueError),
// ("dicts", {"a": 1}, {"b": 2}, ValueError),
// ("mixed", (1, 2), [3, 4], TypeError),
// ("nested",
// {"a": [2, 3], "b": [1, 3]},
// {"b": [5, 6, 7], "a": [8, 9]},
// ValueError
// ))
// def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
// with self.assertRaises(error_type):
// nest.map_structure_with_paths(lambda path, * s: 0, s1, s2)
//class NestBenchmark(test.Benchmark):
// def run_and_report(self, s1, s2, name):
// burn_iter, test_iter = 100, 30000
// for _ in xrange(burn_iter) :
// nest.assert_same_structure(s1, s2)
// t0 = time.time()
// for _ in xrange(test_iter) :
// nest.assert_same_structure(s1, s2)
// t1 = time.time()
// self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
// name=name)
// def benchmark_assert_structure(self):
// s1 = (((1, 2), 3), 4, (5, 6))
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
// self.run_and_report(s1, s2, "assert_same_structure_6_elem")
// s1 = (((1, 2), 3), 4, (5, 6)) * 10
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
// self.run_and_report(s1, s2, "assert_same_structure_60_elem")
//if __name__ == "__main__":
// test.main()
}
}