using System;
using System.Collections;
using System.Collections.Generic;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using NumSharp;
using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.nest_test
{
    /// <summary>
    /// Excerpt of tensorflow/python/util/nest_test.py, ported to TensorFlow.NET.
    /// Python tests that have not been ported yet are kept below as comments for reference.
    /// </summary>
    [TestClass]
    public class NestTest : PythonTest
    {
        /// <summary>
        /// Runs before every test: constructs a graph with <c>tf.Graph()</c> and makes it the default.
        /// </summary>
        [TestInitialize]
        public void TestInitialize()
        {
            tf.Graph().as_default();
        }

        //public class PointXY
        //{
        //    public double x;
        //    public double y;
        //}

        // if attr:
        //   class BadAttr(object):
        //     """Class that has a non-iterable __attrs_attrs__."""
        //     __attrs_attrs__ = None
        //
        //   @attr.s
        //   class SampleAttr(object):
        //     field1 = attr.ib()
        //     field2 = attr.ib()
        //
        // @test_util.assert_no_new_pyobjects_executing_eagerly
        // def testAttrsFlattenAndPack(self):
        //   if attr is None:
        //     self.skipTest("attr module is unavailable.")
        //
        //   field_values = [1, 2]
        //   sample_attr = NestTest.SampleAttr(*field_values)
        //   self.assertFalse(nest._is_attrs(field_values))
        //   self.assertTrue(nest._is_attrs(sample_attr))
        //   flat = nest.flatten(sample_attr)
        //   self.assertEqual(field_values, flat)
        //   restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
        //   self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
        //   self.assertEqual(restructured_from_flat, sample_attr)
        //
        //   # Check that flatten fails if attributes are not iterable
        //   with self.assertRaisesRegexp(TypeError, "object is not iterable"):
        //     flat = nest.flatten(NestTest.BadAttr())

        /// <summary>
        /// Round-trips nested <c>object[]</c> / <see cref="Hashtable"/> structures through
        /// <c>nest.flatten</c> and <c>nest.pack_sequence_as</c>, then checks the scalar cases and
        /// that a flat sequence of the wrong length is rejected with <c>ValueError</c>.
        /// NOTE(review): marked [Ignore]; the reason is not recorded in this file.
        /// </summary>
        [Ignore]
        [TestMethod]
        public void testFlattenAndPack()
        {
            object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } };
            var flat = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" };

            self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 });
            self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(),
                JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString());

            structure = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } };
            flat = new List<object> { 4, 2, 1, 0 };
            self.assertEqual(nest.flatten(structure), flat);

            var restructured_from_flat = nest.pack_sequence_as(structure, flat) as object[];
            // Fail with a clear message instead of a NullReferenceException on the indexing below.
            Assert.IsNotNull(restructured_from_flat, "pack_sequence_as did not return an object[]");
            //Console.WriteLine(JArray.FromObject(restructured_from_flat));
            self.assertEqual(restructured_from_flat, structure);
            self.assertEqual((restructured_from_flat[0] as Hashtable)["x"], 4);
            self.assertEqual((restructured_from_flat[0] as Hashtable)["y"], 2);
            self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["x"], 1);
            self.assertEqual((((restructured_from_flat[1] as object[])[0] as object[])[0] as Hashtable)["y"], 0);

            self.assertEqual(new List<object> { 5 }, nest.flatten(5));
            var flat1 = nest.flatten(np.array(new[] { 5 }));
            self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat1);

            self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
            self.assertEqual(np.array(new[] { 5 }),
                nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) }));

            // A scalar structure accepts exactly one flat element.
            Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 }));

            // Structure has 4 leaves but only 3 flat values are supplied.
            Assert.ThrowsException<ValueError>(() =>
                nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" }));
        }

        // @parameterized.parameters({"mapping_type": collections.OrderedDict},
        //                           {"mapping_type": _CustomMapping})
        // @test_util.assert_no_new_pyobjects_executing_eagerly
        // def testFlattenDictOrder(self, mapping_type):
        //   """`flatten` orders dicts by key, including OrderedDicts."""
        //   ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
        //   plain = {"d": 3, "b": 1, "a": 0, "c": 2}
        //   ordered_flat = nest.flatten(ordered)
        //   plain_flat = nest.flatten(plain)
        //   self.assertEqual([0, 1, 2, 3], ordered_flat)
        //   self.assertEqual([0, 1, 2, 3], plain_flat)
        //
        // @parameterized.parameters({"mapping_type": collections.OrderedDict},
        //                           {"mapping_type": _CustomMapping})
        // def testPackDictOrder(self, mapping_type):
        //   """Packing orders dicts by key, including OrderedDicts."""
        //   custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
        //   plain = {"d": 0, "b": 0, "a": 0, "c": 0}
        //   seq = [0, 1, 2, 3]
        //   custom_reconstruction = nest.pack_sequence_as(custom, seq)
        //   plain_reconstruction = nest.pack_sequence_as(plain, seq)
        //   self.assertIsInstance(custom_reconstruction, mapping_type)
        //   self.assertIsInstance(plain_reconstruction, dict)
        //   self.assertEqual(
        //       mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
        //       custom_reconstruction)
        //   self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
        //
        // Abc = collections.namedtuple("A", ("b", "c"))  # pylint: disable=invalid-name
        //
        // @test_util.assert_no_new_pyobjects_executing_eagerly
        // def testFlattenAndPack_withDicts(self):
        //   # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
        //   mess = [
        //       "z",
        //       NestTest.Abc(3, 4), {
        //           "d": _CustomMapping({
        //               41: 4
        //           }),
        //           "c": [
        //               1,
        //               collections.OrderedDict([
        //                   ("b", 3),
        //                   ("a", 2),
        //               ]),
        //           ],
        //           "b": 5
        //       }, 17
        //   ]
        //
        //   flattened = nest.flatten(mess)
        //   self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
        //
        //   structure_of_mess = [
        //       14,
        //       NestTest.Abc("a", True),
        //       {
        //           "d": _CustomMapping({
        //               41: 42
        //           }),
        //           "c": [
        //               0,
        //               collections.OrderedDict([
        //                   ("b", 9),
        //                   ("a", 8),
        //               ]),
        //           ],
        //           "b": 3
        //       },
        //       "hi everybody",
        //   ]
        //
        //   unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
        //   self.assertEqual(unflattened, mess)
        //
        //   # Check also that the OrderedDict was created, with the correct key order.
        //   unflattened_ordered_dict = unflattened[2]["c"][1]
        //   self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
        //   self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
        //
        //   unflattened_custom_mapping = unflattened[2]["d"]
        //   self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
        //   self.assertEqual(list(unflattened_custom_mapping.keys()), [41])

        /// <summary>
        /// A NumPy-style array is treated as a single leaf by <c>nest.flatten</c>.
        /// </summary>
        [TestMethod]
        public void testFlatten_numpyIsNotFlattened()
        {
            var structure = np.array(1, 2, 3);
            var flattened = nest.flatten(structure);
            self.assertEqual(len(flattened), 1);
        }

        /// <summary>
        /// A string is treated as a single leaf, and packing it into another scalar returns it unchanged.
        /// </summary>
        [TestMethod]
        public void testFlatten_stringIsNotFlattened()
        {
            var structure = "lots of letters";
            var flattened = nest.flatten(structure);
            self.assertEqual(len(flattened), 1);
            var unflattened = nest.pack_sequence_as("goodbye", flattened);
            self.assertEqual(structure, unflattened);
        }

        // def testPackSequenceAs_notIterableError(self):
        //   with self.assertRaisesRegexp(TypeError,
        //                                "flat_sequence must be a sequence"):
        //     nest.pack_sequence_as("hi", "bye")

        /// <summary>
        /// Packing 3 flat values into a 2-leaf structure raises <c>ValueError</c>.
        /// </summary>
        [TestMethod]
        public void testPackSequenceAs_wrongLengthsError()
        {
            Assert.ThrowsException<ValueError>(() =>
            {
                // with self.assertRaisesRegexp(
                //     ValueError,
                //     "Structure had 2 elements, but flat_sequence had 3 elements."):
                nest.pack_sequence_as(new object[] { "hello", "world" }, new object[] { "and", "goodbye", "again" });
            });
        }

        /// <summary>
        /// Checks which values <c>nest.is_sequence</c> treats as sequences.
        /// Arrays and Hashtables are sequences. Strings, HashSets, tensors and NDArrays are not.
        /// </summary>
        [TestMethod]
        public void testIsSequence()
        {
            self.assertFalse(nest.is_sequence("1234"));
            self.assertTrue(nest.is_sequence(new object[] { 1, 3, new object[] { 4, 5 } }));
            // TODO: ValueTuple
            //self.assertTrue(nest.is_sequence(((7, 8), (5, 6))));
            self.assertTrue(nest.is_sequence(new object[] { }));
            self.assertTrue(nest.is_sequence(new Hashtable { ["a"] = 1, ["b"] = 2 }));
            self.assertFalse(nest.is_sequence(new HashSet<int> { 1, 2 }));
            var ones = array_ops.ones(new int[] { 2, 3 });
            self.assertFalse(nest.is_sequence(ones));
            self.assertFalse(nest.is_sequence(gen_math_ops.tanh(ones)));
            self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 })));
        }

        // @parameterized.parameters({"mapping_type": _CustomMapping},
        //                           {"mapping_type": dict})
        // def testFlattenDictItems(self, mapping_type):
        //   dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))})
        //   flat = {4: "a", 5: "b", 6: "c", 8: "d"}
        //   self.assertEqual(nest.flatten_dict_items(dictionary), flat)
        //
        //   with self.assertRaises(TypeError):
        //     nest.flatten_dict_items(4)
        //
        //   bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
        //   with self.assertRaisesRegexp(ValueError, "not unique"):
        //     nest.flatten_dict_items(bad_dictionary)
        //
        //   another_bad_dictionary = mapping_type({
        //       (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
        //   })
        //   with self.assertRaisesRegexp(
        //       ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
        //     nest.flatten_dict_items(another_bad_dictionary)
        //
        // # pylint does not correctly recognize these as class names and
        // # suggests to use variable style under_score naming.
        // # pylint: disable=invalid-name
        // Named0ab = collections.namedtuple("named_0", ("a", "b"))
        // Named1ab = collections.namedtuple("named_1", ("a", "b"))
        // SameNameab = collections.namedtuple("same_name", ("a", "b"))
        // SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
        // SameNamexy = collections.namedtuple("same_name", ("x", "y"))
        // SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
        // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
        // NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
        // # pylint: enable=invalid-name
        //
        // class SameNamedType1(SameNameab):
        //   pass
        //
        // @test_util.assert_no_new_pyobjects_executing_eagerly
        // def testAssertSameStructure(self):
        //   structure1 = (((1, 2), 3), 4, (5, 6))
        //   structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
        //   structure_different_num_elements = ("spam", "eggs")
        //   structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
        //   nest.assert_same_structure(structure1, structure2)
        //   nest.assert_same_structure("abc", 1.0)
        //   nest.assert_same_structure("abc", np.array([0, 1]))
        //   nest.assert_same_structure("abc", constant_op.constant([0, 1]))
        //   with self.assertRaisesRegexp(
        //       ValueError,
        //       ("The two structures don't have the same nested structure\\.\n\n"
        //        "First structure:.*?\n\n"
        //        "Second structure:.*\n\n"
        //        "More specifically: Substructure "
        //        r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
        //        'substructure "type=str str=spam" is not\n'
        //        "Entire first structure:\n"
        //        r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
        //        "Entire second structure:\n"
        //        r"\(\., \.\)")):
        //     nest.assert_same_structure(structure1, structure_different_num_elements)
        //
        //   with self.assertRaisesRegexp(
        //       ValueError,
        //       ("The two structures don't have the same nested structure\\.\n\n"
        //        "First structure:.*?\n\n"
        //        "Second structure:.*\n\n"
        //        r'More specifically: Substructure "type=list str=\[0, 1\]" '
        //        r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
        //        "is not")):
        //     nest.assert_same_structure([0, 1], np.array([0, 1]))
        //
        //   with self.assertRaisesRegexp(
        //       ValueError,
        //       ("The two structures don't have the same nested structure\\.\n\n"
        //        "First structure:.*?\n\n"
        //        "Second structure:.*\n\n"
        //        r'More specifically: Substructure "type=list str=\[0, 1\]" '
        //        'is a sequence, while substructure "type=int str=0" '
        //        "is not")):
        //     nest.assert_same_structure(0, [0, 1])
        //
        //   self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
        //
        //   with self.assertRaisesRegexp(
        //       ValueError,
        //       ("don't have the same nested structure\\.\n\n"
        //        "First structure: .*?\n\nSecond structure: ")):
        //     nest.assert_same_structure(structure1, structure_different_nesting)
        //
        //   self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
        //                     NestTest.Named0ab("a", "b"))
        //
        //   nest.assert_same_structure(NestTest.Named0ab(3, 4),
        //                              NestTest.Named0ab("a", "b"))
        //
        //   self.assertRaises(TypeError, nest.assert_same_structure,
        //                     NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
        //
        //   with self.assertRaisesRegexp(
        //       ValueError,
        //       ("don't have the same nested structure\\.\n\n"
        //        "First structure: .*?\n\nSecond structure: ")):
        //     nest.assert_same_structure(NestTest.Named0ab(3, 4),
        //                                NestTest.Named0ab([3], 4))
        //
        //   with self.assertRaisesRegexp(
        //       ValueError,
        //       ("don't have the same nested structure\\.\n\n"
        //        "First structure: .*?\n\nSecond structure: ")):
        //     nest.assert_same_structure([[3], 4], [3, [4]])
        //
        //   structure1_list = [[[1, 2], 3], 4, [5, 6]]
        //   with self.assertRaisesRegexp(TypeError,
        //                                "don't have the same sequence type"):
        //     nest.assert_same_structure(structure1, structure1_list)
        //   nest.assert_same_structure(structure1, structure2, check_types=False)
        //   nest.assert_same_structure(structure1, structure1_list, check_types=False)
        //
        //   with self.assertRaisesRegexp(ValueError,
        //                                "don't have the same set of keys"):
        //     nest.assert_same_structure({"a": 1}, {"b": 1})
        //
        //   nest.assert_same_structure(NestTest.SameNameab(0, 1),
        //                              NestTest.SameNameab2(2, 3))
        //
        //   # This assertion is expected to pass: two namedtuples with the same
        //   # name and field names are considered to be identical.
        //   nest.assert_same_structure(
        //       NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
        //       NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
        //
        //   expected_message = "The two structures don't have the same.*"
        //   with self.assertRaisesRegexp(ValueError, expected_message):
        //     nest.assert_same_structure(
        //         NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
        //         NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
        //
        //   self.assertRaises(TypeError, nest.assert_same_structure,
        //                     NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
        //
        //   self.assertRaises(TypeError, nest.assert_same_structure,
        //                     NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
        //
        //   self.assertRaises(TypeError, nest.assert_same_structure,
        //                     NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
        //
        // EmptyNT = collections.namedtuple("empty_nt", "")  # pylint: disable=invalid-name
        //
        // def testHeterogeneousComparison(self):
        //   nest.assert_same_structure({"a": 4}, _CustomMapping(a=3))
        //   nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})

        /// <summary>
        /// Applies <c>nest.map_structure</c> with one and with two input structures and checks that the
        /// flattened results and the combined nested result hold the mapped values.
        /// NOTE(review): marked [Ignore]; the reason is not recorded in this file.
        /// </summary>
        [Ignore]
        [TestMethod]
        public void testMapStructure()
        {
            var structure1 = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } };
            var structure2 = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } };
            var structure1_plus1 = nest.map_structure(x => (int)x + 1, structure1);
            var structure1_strings = nest.map_structure(x => $"{x}", structure1);
            var s = JArray.FromObject(structure1_plus1).ToString();
            Console.WriteLine(s);
            // nest.assert_same_structure(structure1, structure1_plus1)
            self.assertAllEqual(nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 });
            self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" });

            // Two-structure map: the callback receives one leaf from each structure.
            var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2);
            self.assertEqual(
                new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } },
                structure1_plus_structure2);

            // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
            //
            // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
            //
            // # Empty structures
            // self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
            // self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
            // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
            // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
            //                                                         NestTest.EmptyNT()))
            //
            // # This is checking actual equality of types, empty list != empty tuple
            // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
            //
            // with self.assertRaisesRegexp(TypeError, "callable"):
            //   nest.map_structure("bad", structure1_plus1)
            //
            // with self.assertRaisesRegexp(ValueError, "at least one structure"):
            //   nest.map_structure(lambda x: x)
            //
            // with self.assertRaisesRegexp(ValueError, "same number of elements"):
            //   nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
            //
            // with self.assertRaisesRegexp(ValueError, "same nested structure"):
            //   nest.map_structure(lambda x, y: None, 3, (3,))
            //
            // with self.assertRaisesRegexp(TypeError, "same sequence type"):
            //   nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
            //
            // with self.assertRaisesRegexp(ValueError, "same nested structure"):
            //   nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
            //
            // structure1_list = [[[1, 2], 3], 4, [5, 6]]
            // with self.assertRaisesRegexp(TypeError, "same sequence type"):
            //   nest.map_structure(lambda x, y: None, structure1, structure1_list)
            //
            // nest.map_structure(lambda x, y: None, structure1, structure1_list,
            //                    check_types=False)
            //
            // with self.assertRaisesRegexp(ValueError, "same nested structure"):
            //   nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
            //                      check_types=False)
            //
            // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
            //   nest.map_structure(lambda x: None, structure1, foo="a")
            //
            // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
            //   nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
        }

        // ABTuple = collections.namedtuple("ab_tuple", "a, b")  # pylint: disable=invalid-name
        //
        // @test_util.assert_no_new_pyobjects_executing_eagerly
        // def testMapStructureWithStrings(self):
        //   inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
        //   inp_b = NestTest.ABTuple(a=2, b=(1, 3))
        //   out = nest.map_structure(lambda string, repeats: string * repeats,
        //                            inp_a,
        //                            inp_b)
        //   self.assertEqual("foofoo", out.a)
        //   self.assertEqual("bar", out.b[0])
        //   self.assertEqual("bazbazbaz", out.b[1])
        //
        //   nt = NestTest.ABTuple(a=("something", "something_else"),
        //                         b="yet another thing")
        //   rev_nt = nest.map_structure(lambda x: x[::-1], nt)
        //   # Check the output is the correct structure, and all strings are reversed.
        //   nest.assert_same_structure(nt, rev_nt)
        //   self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
        //   self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
        //   self.assertEqual(nt.b[::-1], rev_nt.b)
        //
        // @test_util.run_deprecated_v1
        // def testMapStructureOverPlaceholders(self):
        //   inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
        //            array_ops.placeholder(dtypes.float32, shape=[3, 7]))
        //   inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
        //            array_ops.placeholder(dtypes.float32, shape=[3, 7]))
        //
        //   output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
        //
        //   nest.assert_same_structure(output, inp_a)
        //   self.assertShapeEqual(np.zeros((3, 4)), output[0])
        //   self.assertShapeEqual(np.zeros((3, 7)), output[1])
        //
        //   feed_dict = {
        //       inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
        //       inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
        //   }
        //
        //   with self.cached_session() as sess:
        //     output_np = sess.run(output, feed_dict=feed_dict)
        //   self.assertAllClose(output_np[0],
        //                       feed_dict[inp_a][0] + feed_dict[inp_b][0])
        //   self.assertAllClose(output_np[1],
        //                       feed_dict[inp_a][1] + feed_dict[inp_b][1])
        //
        // def testAssertShallowStructure(self):
        //   inp_ab = ["a", "b"]
        //   inp_abc = ["a", "b", "c"]
        //   expected_message = (
        //       "The two structures don't have the same sequence length. Input "
        //       "structure has length 2, while shallow structure has length 3.")
        //   with self.assertRaisesRegexp(ValueError, expected_message):
        //     nest.assert_shallow_structure(inp_abc, inp_ab)
        //
        //   inp_ab1 = [(1, 1), (2, 2)]
        //   inp_ab2 = [[1, 1], [2, 2]]
        //   expected_message = (
        //       "The two structures don't have the same sequence type. Input structure "
        //       "has type <(type|class) 'tuple'>, while shallow structure has type "
        //       "<(type|class) 'list'>.")
        //   with self.assertRaisesRegexp(TypeError, expected_message):
        //     nest.assert_shallow_structure(inp_ab2, inp_ab1)
        //   nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
        //
        //   inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
        //   inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
        //   expected_message = (
        //       r"The two structures don't have the same keys. Input "
        //       r"structure has keys \['c'\], while shallow structure has "
        //       r"keys \['d'\].")
        //   with self.assertRaisesRegexp(ValueError, expected_message):
        //     nest.assert_shallow_structure(inp_ab2, inp_ab1)
        //
        //   inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
        //   inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
        //   nest.assert_shallow_structure(inp_ab, inp_ba)
        //
        //   # This assertion is expected to pass: two namedtuples with the same
        //   # name and field names are considered to be identical.
        //   inp_shallow = NestTest.SameNameab(1, 2)
        //   inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
        //   nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
        //   nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
        //
        // def testFlattenUpTo(self):
        //   # Shallow tree ends at scalar.
        //   input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
        //   shallow_tree = [[True, True], [False, True]]
        //   flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
        //   self.assertEqual(flattened_shallow_tree, [True, True, False, True])
        //
        //   # Shallow tree ends at string.
        //   input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
        //   shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
        //   input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
        //                                                             input_tree)
        //   input_tree_flattened = nest.flatten(input_tree)
        //   self.assertEqual(input_tree_flattened_as_shallow_tree,
        //                    [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
        //   self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
        //
        //   # Make sure dicts are correctly flattened, yielding values, not keys.
        //   input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
        //   shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
        //   input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
        //                                                             input_tree)
        //   self.assertEqual(input_tree_flattened_as_shallow_tree,
        //                    [1, {"c": 2}, 3, (4, 5)])
        //
        //   # Namedtuples.
        //   ab_tuple = NestTest.ABTuple
        //   input_tree = ab_tuple(a=[0, 1], b=2)
        //   shallow_tree = ab_tuple(a=0, b=1)
        //   input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
        //                                                             input_tree)
        //   self.assertEqual(input_tree_flattened_as_shallow_tree,
        //                    [[0, 1], 2])
        //
        //   # Nested dicts, OrderedDicts and namedtuples.
        //   input_tree = collections.OrderedDict(
        //       [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
        //        ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
        //   shallow_tree = input_tree
        //   input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
        //                                                             input_tree)
        //   self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
        //   shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
        //   input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
        //                                                             input_tree)
        //   self.assertEqual(input_tree_flattened_as_shallow_tree,
        //                    [ab_tuple(a=[0, {"b": 1}], b=2),
        //                     3,
        //                     collections.OrderedDict([("f", 4)])])
        //   shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
        //   input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
        //                                                             input_tree)
        //   self.assertEqual(input_tree_flattened_as_shallow_tree,
        //                    [ab_tuple(a=[0, {"b": 1}], b=2),
        //                     {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
        //
        //   ## Shallow non-list edge-case.
        //   # Using iterable elements.
        //   input_tree = ["input_tree"]
        //   shallow_tree = "shallow_tree"
        //   flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_input_tree, [input_tree])
        //   self.assertEqual(flattened_shallow_tree, [shallow_tree])
        //
        //   input_tree = ["input_tree_0", "input_tree_1"]
        //   shallow_tree = "shallow_tree"
        //   flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_input_tree, [input_tree])
        //   self.assertEqual(flattened_shallow_tree, [shallow_tree])
        //
        //   # Using non-iterable elements.
        //   input_tree = [0]
        //   shallow_tree = 9
        //   flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_input_tree, [input_tree])
        //   self.assertEqual(flattened_shallow_tree, [shallow_tree])
        //
        //   input_tree = [0, 1]
        //   shallow_tree = 9
        //   flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_input_tree, [input_tree])
        //   self.assertEqual(flattened_shallow_tree, [shallow_tree])
        //
        //   ## Both non-list edge-case.
        //   # Using iterable elements.
        //   input_tree = "input_tree"
        //   shallow_tree = "shallow_tree"
        //   flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_input_tree, [input_tree])
        //   self.assertEqual(flattened_shallow_tree, [shallow_tree])
        //
        //   # Using non-iterable elements.
        //   input_tree = 0
        //   shallow_tree = 0
        //   flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_input_tree, [input_tree])
        //   self.assertEqual(flattened_shallow_tree, [shallow_tree])
        //
        //   ## Input non-list edge-case.
        //   # Using iterable elements.
        //   input_tree = "input_tree"
        //   shallow_tree = ["shallow_tree"]
        //   expected_message = ("If shallow structure is a sequence, input must also "
        //                       "be a sequence. Input has type: <(type|class) 'str'>.")
        //   with self.assertRaisesRegexp(TypeError, expected_message):
        //     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_shallow_tree, shallow_tree)
        //
        //   input_tree = "input_tree"
        //   shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
        //   with self.assertRaisesRegexp(TypeError, expected_message):
        //     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_shallow_tree, shallow_tree)
        //
        //   # Using non-iterable elements.
        //   input_tree = 0
        //   shallow_tree = [9]
        //   expected_message = ("If shallow structure is a sequence, input must also "
        //                       "be a sequence. Input has type: <(type|class) 'int'>.")
        //   with self.assertRaisesRegexp(TypeError, expected_message):
        //     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_shallow_tree, shallow_tree)
        //
        //   input_tree = 0
        //   shallow_tree = [9, 8]
        //   with self.assertRaisesRegexp(TypeError, expected_message):
        //     flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
        //   flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
        //   self.assertEqual(flattened_shallow_tree, shallow_tree)
        //
        // def testMapStructureUpTo(self):
        //   # Named tuples.
        //   ab_tuple = collections.namedtuple("ab_tuple", "a, b")
        //   op_tuple = collections.namedtuple("op_tuple", "add, mul")
        //   inp_val = ab_tuple(a=2, b=3)
        //   inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
        //   out = nest.map_structure_up_to(
        //       inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
        //   self.assertEqual(out.a, 6)
        //   self.assertEqual(out.b, 15)
        //
        //   # Lists.
        //   data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
        //   name_list = ["evens", ["odds", "primes"]]
        //   out = nest.map_structure_up_to(
        //       name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
        //       name_list, data_list)
        //   self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
        //
        //   # Dicts.
        //   inp_val = dict(a=2, b=3)
        //   inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
        //   out = nest.map_structure_up_to(
        //       inp_val,
        //       lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
        //   self.assertEqual(out["a"], 6)
        //   self.assertEqual(out["b"], 15)
        //
        //   # Non-equal dicts.
        //   inp_val = dict(a=2, b=3)
        //   inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
        //   with self.assertRaisesRegexp(ValueError, "same keys"):
        //     nest.map_structure_up_to(
        //         inp_val,
        //         lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
        //
        //   # Dict+custom mapping.
        //   inp_val = dict(a=2, b=3)
        //   inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
        //   out = nest.map_structure_up_to(
        //       inp_val,
        //       lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
        //   self.assertEqual(out["a"], 6)
        //   self.assertEqual(out["b"], 15)
        //
        //   # Non-equal dict/mapping.
        //   inp_val = dict(a=2, b=3)
        //   inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
        //   with self.assertRaisesRegexp(ValueError, "same keys"):
        //     nest.map_structure_up_to(
        //         inp_val,
        //         lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
        //
        // def testGetTraverseShallowStructure(self):
        //   scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
        //   scalar_traverse_r = nest.get_traverse_shallow_structure(
        //       lambda s: not isinstance(s, tuple),
        //       scalar_traverse_input)
        //   self.assertEqual(scalar_traverse_r,
        //                    [True, True, False, [True, True], {"a": False}, []])
        //   nest.assert_shallow_structure(scalar_traverse_r,
        //                                 scalar_traverse_input)
        //
        //   structure_traverse_input = [(1, [2]), ([1], 2)]
        //   structure_traverse_r = nest.get_traverse_shallow_structure(
        //       lambda s: (True, False) if isinstance(s, tuple) else True,
        //       structure_traverse_input)
        //   self.assertEqual(structure_traverse_r,
        //                    [(True, False), ([True], False)])
        //   nest.assert_shallow_structure(structure_traverse_r,
        //                                 structure_traverse_input)
        //
        //   with self.assertRaisesRegexp(TypeError, "returned structure"):
        //     nest.get_traverse_shallow_structure(lambda _: [True], 0)
        //
        //   with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
        //     nest.get_traverse_shallow_structure(lambda _: 1, [1])
        //
        //   with self.assertRaisesRegexp(
        //       TypeError, "didn't return a depth=1 structure of bools"):
        //     nest.get_traverse_shallow_structure(lambda _: [1], [1])
        //
        // def testYieldFlatStringPaths(self):
        //   for inputs_expected in ({"inputs": [], "expected": []},
        //                           {"inputs": 3, "expected": [()]},
        //                           {"inputs": [3], "expected": [(0,)]},
        //                           {"inputs": {"a": 3}, "expected": [("a",)]},
        //                           {"inputs": {"a": {"b": 4}},
        //                            "expected": [("a", "b")]},
        //                           {"inputs": [{"a": 2}], "expected": [(0, "a")]},
        //                           {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
        //                           {"inputs": [{"a": [(23, 42)]}],
        //                            "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
        //                           {"inputs": [{"a": ([23], 42)}],
        //                            "expected": [(0, "a", 0, 0), (0, "a", 1)]},
        //                           {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
        //                            "expected": [("a", "a"), ("c", 0, 0, 0)]},
        //                           {"inputs": {"0": [{"1": 23}]},
        //                            "expected": [("0", 0, "1")]}):
        //     inputs = inputs_expected["inputs"]
        //     expected = inputs_expected["expected"]
        //     self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
        //
        // def testFlattenWithStringPaths(self):
        //   for inputs_expected in (
        //       {"inputs": [], "expected": []},
        //       {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
        //       {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
        //     inputs = inputs_expected["inputs"]
        //     expected = inputs_expected["expected"]
        //     self.assertEqual(
        //         nest.flatten_with_joined_string_paths(inputs, separator="/"),
        //         expected)
        //
        // # Need a separate test for namedtuple as we can't declare tuple definitions
        // # in the @parameterized arguments.
        // def testFlattenNamedTuple(self):
        //   # pylint: disable=invalid-name
        //   Foo = collections.namedtuple("Foo", ["a", "b"])
        //   Bar = collections.namedtuple("Bar", ["c", "d"])
        //   # pylint: enable=invalid-name
        //   test_cases = [
        //       (Foo(a=3, b=Bar(c=23, d=42)),
        //        [("a", 3), ("b/c", 23), ("b/d", 42)]),
        //       (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
        //        [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
        //       (Bar(c=42, d=43),
        //        [("c", 42), ("d", 43)]),
        //       (Bar(c=[42], d=43),
        //        [("c/0", 42), ("d", 43)]),
        //   ]
        //   for inputs, expected in test_cases:
        //     self.assertEqual(
        //         list(nest.flatten_with_joined_string_paths(inputs)), expected)
        //
        // @parameterized.named_parameters(
        //     ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
        //     ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
        //      {"a": ("a", 4), "b": ("b", 6)}),
        //     ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
        //     ("nested",
        //      {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
        //      {"a": [("a/0", 10), ("a/1", 12)],
        //       "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
        // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
        //   def format_sum(path, *values):
        //     return (path, sum(values))
        //   result = nest.map_structure_with_paths(format_sum, s1, s2,
        //                                          check_types=check_types)
        //   self.assertEqual(expected, result)
        //
        // @parameterized.named_parameters(
        //     ("tuples", (1, 2), (3, 4, 5), ValueError),
        //     ("dicts", {"a": 1}, {"b": 2}, ValueError),
        //     ("mixed", (1, 2), [3, 4], TypeError),
        //     ("nested",
        //      {"a": [2, 3], "b": [1, 3]},
        //      {"b": [5, 6, 7], "a": [8, 9]},
        //      ValueError
        //     ))
        // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
        //   with self.assertRaises(error_type):
        //     nest.map_structure_with_paths(lambda path, *s: 0, s1, s2)
        //
        //
        // class NestBenchmark(test.Benchmark):
        //
        //   def run_and_report(self, s1, s2, name):
        //     burn_iter, test_iter = 100, 30000
        //
        //     for _ in xrange(burn_iter):
        //       nest.assert_same_structure(s1, s2)
        //
        //     t0 = time.time()
        //     for _ in xrange(test_iter):
        //       nest.assert_same_structure(s1, s2)
        //     t1 = time.time()
        //
        //     self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
        //                           name=name)
        //
        //   def benchmark_assert_structure(self):
        //     s1 = (((1, 2), 3), 4, (5, 6))
        //     s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
        //     self.run_and_report(s1, s2, "assert_same_structure_6_elem")
        //
        //     s1 = (((1, 2), 3), 4, (5, 6)) * 10
        //     s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
        //     self.run_and_report(s1, s2, "assert_same_structure_60_elem")
        //
        //
        // if __name__ == "__main__":
        //   test.main()
    }
}