You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

NestTest.cs 46 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. using System.Collections;
  2. using System.Collections.Generic;
  3. using Colorful;
  4. using Microsoft.VisualStudio.TestTools.UnitTesting;
  5. using Newtonsoft.Json.Linq;
  6. using NumSharp;
  7. using Tensorflow;
  8. using Tensorflow.Util;
  9. using static Tensorflow.Python;
  10. namespace TensorFlowNET.UnitTest.nest_test
  11. {
  12. /// <summary>
  13. /// excerpt of tensorflow/python/framework/util/nest_test.py
  14. /// </summary>
  15. [TestClass]
  16. public class NestTest : PythonTest
  17. {
  /// <summary>
  /// Runs before each test: constructs a new <c>tf.Graph()</c> and calls
  /// <c>as_default()</c> on it.
  /// </summary>
  [TestInitialize]
  public void TestInitialize() => tf.Graph().as_default();
  23. //public class PointXY
  24. //{
  25. // public double x;
  26. // public double y;
  27. //}
  28. // if attr:
  29. // class BadAttr(object):
  30. // """Class that has a non-iterable __attrs_attrs__."""
  31. // __attrs_attrs__ = None
  32. // @attr.s
  33. // class SampleAttr(object):
  34. // field1 = attr.ib()
  35. // field2 = attr.ib()
  36. // @test_util.assert_no_new_pyobjects_executing_eagerly
  37. // def testAttrsFlattenAndPack(self) :
  38. // if attr is None:
  39. // self.skipTest("attr module is unavailable.")
  40. // field_values = [1, 2]
  41. // sample_attr = NestTest.SampleAttr(* field_values)
  42. // self.assertFalse(nest._is_attrs(field_values))
  43. // self.assertTrue(nest._is_attrs(sample_attr))
  44. // flat = nest.flatten(sample_attr)
  45. // self.assertEqual(field_values, flat)
  46. // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
  47. // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
  48. // self.assertEqual(restructured_from_flat, sample_attr)
  49. //# Check that flatten fails if attributes are not iterable
  50. // with self.assertRaisesRegexp(TypeError, "object is not iterable"):
  51. // flat = nest.flatten(NestTest.BadAttr())
  /// <summary>
  /// Port of nest_test.py <c>testFlattenAndPack</c>. Round-trips nested
  /// <c>object[]</c> structures and Hashtable "points" through
  /// <c>nest.flatten</c> / <c>nest.pack_sequence_as</c>, then checks scalar
  /// inputs and leaf-count mismatches.
  /// </summary>
  [TestMethod]
  public void testFlattenAndPack()
  {
      // Plain nested arrays: the expected flat order is depth-first, left to right.
      object nested = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } };
      var leaves = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" };
      self.assertEqual(nest.flatten(nested), new[] { 3, 4, 5, 6, 7, 9, 10, 8 });

      // Compare the serialized JSON text of the packed result and the expected shape.
      var packedJson = JArray.FromObject(nest.pack_sequence_as(nested, leaves)).ToString();
      var expectedJson = JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString();
      self.assertEqual(packedJson, expectedJson);

      // Hashtables keyed "x"/"y" stand in for the (commented-out) PointXY type.
      nested = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } };
      leaves = new List<object> { 4, 2, 1, 0 };
      self.assertEqual(nest.flatten(nested), leaves);
      var repacked = nest.pack_sequence_as(nested, leaves) as object[];
      self.assertEqual(repacked, nested);

      var outerPoint = repacked[0] as Hashtable;
      self.assertEqual(outerPoint["x"], 4);
      self.assertEqual(outerPoint["y"], 2);

      var innerPoint = ((repacked[1] as object[])[0] as object[])[0] as Hashtable;
      self.assertEqual(innerPoint["x"], 1);
      self.assertEqual(innerPoint["y"], 0);

      // Scalars and NDArrays are expected to flatten to a single-element list.
      self.assertEqual(new List<object> { 5 }, nest.flatten(5));
      leaves = nest.flatten(np.array(new[] { 5 }));
      self.assertEqual(new object[] { np.array(new int[] { 5 }) }, leaves);

      // A scalar structure is expected to pack back into its single leaf.
      self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
      self.assertEqual(np.array(new[] { 5 }),
          nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) }));

      // Leaf-count mismatches are expected to raise ValueError.
      Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 }));
      Assert.ThrowsException<ValueError>(() =>
          nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" }));
  }
  80. // @parameterized.parameters({"mapping_type": collections.OrderedDict
  81. // },
  82. // {"mapping_type": _CustomMapping
  83. //})
  84. // @test_util.assert_no_new_pyobjects_executing_eagerly
  85. // def testFlattenDictOrder(self, mapping_type) :
  86. // """`flatten` orders dicts by key, including OrderedDicts."""
  87. // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
  88. // plain = {"d": 3, "b": 1, "a": 0, "c": 2}
  89. // ordered_flat = nest.flatten(ordered)
  90. // plain_flat = nest.flatten(plain)
  91. // self.assertEqual([0, 1, 2, 3], ordered_flat)
  92. // self.assertEqual([0, 1, 2, 3], plain_flat)
  93. // @parameterized.parameters({"mapping_type": collections.OrderedDict},
  94. // {"mapping_type": _CustomMapping})
  95. // def testPackDictOrder(self, mapping_type):
  96. // """Packing orders dicts by key, including OrderedDicts."""
  97. // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
  98. // plain = {"d": 0, "b": 0, "a": 0, "c": 0}
  99. // seq = [0, 1, 2, 3]
  100. //custom_reconstruction = nest.pack_sequence_as(custom, seq)
  101. //plain_reconstruction = nest.pack_sequence_as(plain, seq)
  102. // self.assertIsInstance(custom_reconstruction, mapping_type)
  103. // self.assertIsInstance(plain_reconstruction, dict)
  104. // self.assertEqual(
  105. // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
  106. // custom_reconstruction)
  107. // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
  108. // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
  109. // @test_util.assert_no_new_pyobjects_executing_eagerly
  110. // def testFlattenAndPack_withDicts(self) :
  111. // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
  112. // mess = [
  113. // "z",
  114. // NestTest.Abc(3, 4), {
  115. // "d": _CustomMapping({
  116. // 41: 4
  117. // }),
  118. // "c": [
  119. // 1,
  120. // collections.OrderedDict([
  121. // ("b", 3),
  122. // ("a", 2),
  123. // ]),
  124. // ],
  125. // "b": 5
  126. // }, 17
  127. // ]
  128. // flattened = nest.flatten(mess)
  129. // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
  130. // structure_of_mess = [
  131. // 14,
  132. // NestTest.Abc("a", True),
  133. // {
  134. // "d": _CustomMapping({
  135. // 41: 42
  136. // }),
  137. // "c": [
  138. // 0,
  139. // collections.OrderedDict([
  140. // ("b", 9),
  141. // ("a", 8),
  142. // ]),
  143. // ],
  144. // "b": 3
  145. // },
  146. // "hi everybody",
  147. // ]
  148. // unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
  149. // self.assertEqual(unflattened, mess)
  150. // # Check also that the OrderedDict was created, with the correct key order.
  151. //unflattened_ordered_dict = unflattened[2]["c"][1]
  152. // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
  153. // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
  154. // unflattened_custom_mapping = unflattened[2]["d"]
  155. // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
  156. // self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
  /// <summary>
  /// Checks that <c>nest.flatten</c> on a three-element NDArray yields exactly
  /// one leaf, i.e. the array itself is not expanded.
  /// </summary>
  [TestMethod]
  public void testFlatten_numpyIsNotFlattened()
  {
      var leafCount = len(nest.flatten(np.array(1, 2, 3)));
      self.assertEqual(leafCount, 1);
  }
  /// <summary>
  /// Checks that a string flattens to a single leaf and that packing that
  /// leaf into another scalar structure returns the original string.
  /// </summary>
  [TestMethod]
  public void testFlatten_stringIsNotFlattened()
  {
      const string text = "lots of letters";
      var leaves = nest.flatten(text);
      self.assertEqual(len(leaves), 1);
      self.assertEqual(text, nest.pack_sequence_as("goodbye", leaves));
  }
  173. // def testPackSequenceAs_notIterableError(self) :
  174. // with self.assertRaisesRegexp(TypeError,
  175. // "flat_sequence must be a sequence"):
  176. // nest.pack_sequence_as("hi", "bye")
  /// <summary>
  /// Checks that packing three leaves into a two-element structure throws
  /// <c>ValueError</c>.
  /// </summary>
  [TestMethod]
  public void testPackSequenceAs_wrongLengthsError()
  {
      // The Python original also matches the message:
      //   "Structure had 2 elements, but flat_sequence had 3 elements."
      // Only the exception type is asserted here.
      var twoSlots = new object[] { "hello", "world" };
      var threeLeaves = new object[] { "and", "goodbye", "again" };
      Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as(twoSlots, threeLeaves));
  }
  /// <summary>
  /// Checks which kinds of values <c>nest.is_sequence</c> classifies as
  /// sequences (arrays, Hashtables) and which it does not (strings, sets,
  /// tensors, NDArrays).
  /// </summary>
  [TestMethod]
  public void testIsSequence()
  {
      // A string is expected NOT to count as a sequence.
      self.assertFalse(nest.is_sequence("1234"));

      // Nested arrays, empty arrays and Hashtables are expected to be sequences.
      var nestedArray = new object[] { 1, 3, new object[] { 4, 5 } };
      self.assertTrue(nest.is_sequence(nestedArray));
      // TODO: ValueTuple<T,T>
      //self.assertTrue(nest.is_sequence(((7, 8), (5, 6))));
      var emptyArray = new object[0];
      self.assertTrue(nest.is_sequence(emptyArray));
      var table = new Hashtable { ["a"] = 1, ["b"] = 2 };
      self.assertTrue(nest.is_sequence(table));

      // Sets, tensors and NDArrays are expected NOT to be sequences.
      var intSet = new HashSet<int> { 1, 2 };
      self.assertFalse(nest.is_sequence(intSet));
      var onesTensor = array_ops.ones(new int[] { 2, 3 });
      self.assertFalse(nest.is_sequence(onesTensor));
      self.assertFalse(nest.is_sequence(gen_math_ops.tanh(onesTensor)));
      self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 })));
  }
  203. // @parameterized.parameters({"mapping_type": _CustomMapping},
  204. // {"mapping_type": dict})
  205. // def testFlattenDictItems(self, mapping_type):
  206. // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))})
  207. // flat = {4: "a", 5: "b", 6: "c", 8: "d"}
  208. // self.assertEqual(nest.flatten_dict_items(dictionary), flat)
  209. // with self.assertRaises(TypeError):
  210. // nest.flatten_dict_items(4)
  211. // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))})
  212. // with self.assertRaisesRegexp(ValueError, "not unique"):
  213. // nest.flatten_dict_items(bad_dictionary)
  214. // another_bad_dictionary = mapping_type({
  215. // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
  216. // })
  217. // with self.assertRaisesRegexp(
  218. // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
  219. // nest.flatten_dict_items(another_bad_dictionary)
  220. //# pylint does not correctly recognize these as class names and
  221. //# suggests to use variable style under_score naming.
  222. //# pylint: disable=invalid-name
  223. // Named0ab = collections.namedtuple("named_0", ("a", "b"))
  224. // Named1ab = collections.namedtuple("named_1", ("a", "b"))
  225. // SameNameab = collections.namedtuple("same_name", ("a", "b"))
  226. // SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
  227. // SameNamexy = collections.namedtuple("same_name", ("x", "y"))
  228. // SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
  229. // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
  230. // NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
  231. // # pylint: enable=invalid-name
  232. // class SameNamedType1(SameNameab):
  233. // pass
  234. // @test_util.assert_no_new_pyobjects_executing_eagerly
  235. // def testAssertSameStructure(self):
  236. // structure1 = (((1, 2), 3), 4, (5, 6))
  237. // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  238. // structure_different_num_elements = ("spam", "eggs")
  239. // structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
  240. // nest.assert_same_structure(structure1, structure2)
  241. // nest.assert_same_structure("abc", 1.0)
  242. // nest.assert_same_structure("abc", np.array([0, 1]))
  243. // nest.assert_same_structure("abc", constant_op.constant([0, 1]))
  244. // with self.assertRaisesRegexp(
  245. // ValueError,
  246. // ("The two structures don't have the same nested structure\\.\n\n"
  247. // "First structure:.*?\n\n"
  248. // "Second structure:.*\n\n"
  249. // "More specifically: Substructure "
  250. // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
  251. // 'substructure "type=str str=spam" is not\n'
  252. // "Entire first structure:\n"
  253. // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
  254. // "Entire second structure:\n"
  255. // r"\(\., \.\)")):
  256. // nest.assert_same_structure(structure1, structure_different_num_elements)
  257. // with self.assertRaisesRegexp(
  258. // ValueError,
  259. // ("The two structures don't have the same nested structure\\.\n\n"
  260. // "First structure:.*?\n\n"
  261. // "Second structure:.*\n\n"
  262. // r'More specifically: Substructure "type=list str=\[0, 1\]" '
  263. // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
  264. // "is not")):
  265. // nest.assert_same_structure([0, 1], np.array([0, 1]))
  266. // with self.assertRaisesRegexp(
  267. // ValueError,
  268. // ("The two structures don't have the same nested structure\\.\n\n"
  269. // "First structure:.*?\n\n"
  270. // "Second structure:.*\n\n"
  271. // r'More specifically: Substructure "type=list str=\[0, 1\]" '
  272. // 'is a sequence, while substructure "type=int str=0" '
  273. // "is not")):
  274. // nest.assert_same_structure(0, [0, 1])
  275. // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
  276. // with self.assertRaisesRegexp(
  277. // ValueError,
  278. // ("don't have the same nested structure\\.\n\n"
  279. // "First structure: .*?\n\nSecond structure: ")):
  280. // nest.assert_same_structure(structure1, structure_different_nesting)
  281. // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
  282. // NestTest.Named0ab("a", "b"))
  283. // nest.assert_same_structure(NestTest.Named0ab(3, 4),
  284. // NestTest.Named0ab("a", "b"))
  285. // self.assertRaises(TypeError, nest.assert_same_structure,
  286. // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
  287. // with self.assertRaisesRegexp(
  288. // ValueError,
  289. // ("don't have the same nested structure\\.\n\n"
  290. // "First structure: .*?\n\nSecond structure: ")):
  291. // nest.assert_same_structure(NestTest.Named0ab(3, 4),
  292. // NestTest.Named0ab([3], 4))
  293. // with self.assertRaisesRegexp(
  294. // ValueError,
  295. // ("don't have the same nested structure\\.\n\n"
  296. // "First structure: .*?\n\nSecond structure: ")):
  297. // nest.assert_same_structure([[3], 4], [3, [4]])
  298. // structure1_list = [[[1, 2], 3], 4, [5, 6]]
  299. // with self.assertRaisesRegexp(TypeError,
  300. // "don't have the same sequence type"):
  301. // nest.assert_same_structure(structure1, structure1_list)
  302. // nest.assert_same_structure(structure1, structure2, check_types= False)
  303. // nest.assert_same_structure(structure1, structure1_list, check_types=False)
  304. // with self.assertRaisesRegexp(ValueError,
  305. // "don't have the same set of keys"):
  306. // nest.assert_same_structure({"a": 1}, {"b": 1})
  307. // nest.assert_same_structure(NestTest.SameNameab(0, 1),
  308. // NestTest.SameNameab2(2, 3))
  309. // # This assertion is expected to pass: two namedtuples with the same
  310. // # name and field names are considered to be identical.
  311. // nest.assert_same_structure(
  312. // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
  313. // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
  314. // expected_message = "The two structures don't have the same.*"
  315. // with self.assertRaisesRegexp(ValueError, expected_message):
  316. // nest.assert_same_structure(
  317. // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
  318. // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
  319. // self.assertRaises(TypeError, nest.assert_same_structure,
  320. // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
  321. // self.assertRaises(TypeError, nest.assert_same_structure,
  322. // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
  323. // self.assertRaises(TypeError, nest.assert_same_structure,
  324. // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
  325. // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
  326. // def testHeterogeneousComparison(self):
  327. // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3))
  328. // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
  /// <summary>
  /// Port of nest_test.py <c>testMapStructure</c>. Applies a function leaf-wise
  /// over one nested structure and over two structures of the same shape via
  /// <c>nest.map_structure</c>, then checks the flattened and nested results.
  /// </summary>
  [TestMethod]
  public void testMapStructure()
  {
      var structure1 = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } };
      var structure2 = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } };

      // Single-structure map: every int leaf is incremented / stringified.
      var structure1_plus1 = nest.map_structure(x => (int)x + 1, structure1);
      var structure1_strings = nest.map_structure(x => $"{x}", structure1);
      // nest.assert_same_structure(structure1, structure1_plus1)
      self.assertAllEqual(nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 });
      self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" });

      // Two-structure map: the lambda indexes the paired leaves as x[0] (from
      // structure1) and x[1] (from structure2), per the expected sums below.
      var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2);
      self.assertEqual(
          new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } },
          structure1_plus_structure2);

      // Remaining Python assertions from nest_test.py, not yet ported:
      // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
      // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
      // # Empty structures
      // self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
      // self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
      // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
      // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
      // NestTest.EmptyNT()))
      // # This is checking actual equality of types, empty list != empty tuple
      // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
      // with self.assertRaisesRegexp(TypeError, "callable"):
      // nest.map_structure("bad", structure1_plus1)
      // with self.assertRaisesRegexp(ValueError, "at least one structure"):
      // nest.map_structure(lambda x: x)
      // with self.assertRaisesRegexp(ValueError, "same number of elements"):
      // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      // nest.map_structure(lambda x, y: None, 3, (3,))
      // with self.assertRaisesRegexp(TypeError, "same sequence type"):
      // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
      // structure1_list = [[[1, 2], 3], 4, [5, 6]]
      // with self.assertRaisesRegexp(TypeError, "same sequence type"):
      // nest.map_structure(lambda x, y: None, structure1, structure1_list)
      // nest.map_structure(lambda x, y: None, structure1, structure1_list,
      // check_types=False)
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
      // check_types=False)
      // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      // nest.map_structure(lambda x: None, structure1, foo="a")
      // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
      // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
  }
  381. // @test_util.assert_no_new_pyobjects_executing_eagerly
  382. // def testMapStructureWithStrings(self) :
  383. // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
  384. // inp_b = NestTest.ABTuple(a=2, b=(1, 3))
  385. // out = nest.map_structure(lambda string, repeats: string* repeats,
  386. // inp_a,
  387. // inp_b)
  388. // self.assertEqual("foofoo", out.a)
  389. // self.assertEqual("bar", out.b[0])
  390. // self.assertEqual("bazbazbaz", out.b[1])
  391. // nt = NestTest.ABTuple(a=("something", "something_else"),
  392. // b="yet another thing")
  393. // rev_nt = nest.map_structure(lambda x: x[::- 1], nt)
  394. // # Check the output is the correct structure, and all strings are reversed.
  395. // nest.assert_same_structure(nt, rev_nt)
  396. // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0])
  397. // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1])
  398. // self.assertEqual(nt.b[::- 1], rev_nt.b)
  399. // @test_util.run_deprecated_v1
  400. // def testMapStructureOverPlaceholders(self) :
  401. // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
  402. // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  403. // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
  404. // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  405. // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
  406. // nest.assert_same_structure(output, inp_a)
  407. // self.assertShapeEqual(np.zeros((3, 4)), output[0])
  408. // self.assertShapeEqual(np.zeros((3, 7)), output[1])
  409. // feed_dict = {
  410. // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
  411. // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
  412. // }
  413. // with self.cached_session() as sess:
  414. // output_np = sess.run(output, feed_dict=feed_dict)
  415. // self.assertAllClose(output_np[0],
  416. // feed_dict[inp_a][0] + feed_dict[inp_b][0])
  417. // self.assertAllClose(output_np[1],
  418. // feed_dict[inp_a][1] + feed_dict[inp_b][1])
  419. // def testAssertShallowStructure(self):
  420. // inp_ab = ["a", "b"]
  421. //inp_abc = ["a", "b", "c"]
  422. //expected_message = (
  423. // "The two structures don't have the same sequence length. Input "
  424. // "structure has length 2, while shallow structure has length 3.")
  425. // with self.assertRaisesRegexp(ValueError, expected_message):
  426. // nest.assert_shallow_structure(inp_abc, inp_ab)
  427. // inp_ab1 = [(1, 1), (2, 2)]
  428. // inp_ab2 = [[1, 1], [2, 2]]
  429. // expected_message = (
  430. // "The two structures don't have the same sequence type. Input structure "
  431. // "has type <(type|class) 'tuple'>, while shallow structure has type "
  432. // "<(type|class) 'list'>.")
  433. // with self.assertRaisesRegexp(TypeError, expected_message):
  434. // nest.assert_shallow_structure(inp_ab2, inp_ab1)
  435. // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False)
  436. // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
  437. // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
  438. // expected_message = (
  439. // r"The two structures don't have the same keys. Input "
  440. // r"structure has keys \['c'\], while shallow structure has "
  441. // r"keys \['d'\].")
  442. // with self.assertRaisesRegexp(ValueError, expected_message):
  443. // nest.assert_shallow_structure(inp_ab2, inp_ab1)
  444. // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
  445. // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
  446. // nest.assert_shallow_structure(inp_ab, inp_ba)
  447. // # This assertion is expected to pass: two namedtuples with the same
  448. //# name and field names are considered to be identical.
  449. //inp_shallow = NestTest.SameNameab(1, 2)
  450. // inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
  451. // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
  452. // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
  453. // def testFlattenUpTo(self):
  454. // # Shallow tree ends at scalar.
  455. // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  456. // shallow_tree = [[True, True], [False, True]]
  457. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  458. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  459. // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  460. // self.assertEqual(flattened_shallow_tree, [True, True, False, True])
  461. //# Shallow tree ends at string.
  462. // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  463. // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  464. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  465. // input_tree)
  466. // input_tree_flattened = nest.flatten(input_tree)
  467. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  468. // [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  469. // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
  470. // # Make sure dicts are correctly flattened, yielding values, not keys.
  471. //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  472. // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  473. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  474. // input_tree)
  475. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  476. // [1, { "c": 2}, 3, (4, 5)])
  477. // # Namedtuples.
  478. // ab_tuple = NestTest.ABTuple
  479. // input_tree = ab_tuple(a =[0, 1], b = 2)
  480. // shallow_tree = ab_tuple(a= 0, b= 1)
  481. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  482. // input_tree)
  483. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  484. // [[0, 1], 2])
  485. // # Nested dicts, OrderedDicts and namedtuples.
  486. // input_tree = collections.OrderedDict(
  487. // [("a", ab_tuple(a =[0, {"b": 1}], b=2)),
  488. // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
  489. // shallow_tree = input_tree
  490. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  491. // input_tree)
  492. // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
  493. // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  494. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  495. // input_tree)
  496. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  497. // [ab_tuple(a =[0, { "b": 1}], b=2),
  498. // 3,
  499. // collections.OrderedDict([("f", 4)])])
  500. // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  501. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  502. // input_tree)
  503. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  504. // [ab_tuple(a =[0, {"b": 1}], b=2),
  505. // {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
  506. // ## Shallow non-list edge-case.
  507. // # Using iterable elements.
  508. // input_tree = ["input_tree"]
  509. //shallow_tree = "shallow_tree"
  510. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  511. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  512. // self.assertEqual(flattened_input_tree, [input_tree])
  513. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  514. // input_tree = ["input_tree_0", "input_tree_1"]
  515. //shallow_tree = "shallow_tree"
  516. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  517. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  518. // self.assertEqual(flattened_input_tree, [input_tree])
  519. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  520. // # Using non-iterable elements.
  521. //input_tree = [0]
  522. //shallow_tree = 9
  523. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  524. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  525. // self.assertEqual(flattened_input_tree, [input_tree])
  526. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  527. // input_tree = [0, 1]
  528. //shallow_tree = 9
  529. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  530. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  531. // self.assertEqual(flattened_input_tree, [input_tree])
  532. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  533. // ## Both non-list edge-case.
  534. //# Using iterable elements.
  535. //input_tree = "input_tree"
  536. // shallow_tree = "shallow_tree"
  537. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  538. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  539. // self.assertEqual(flattened_input_tree, [input_tree])
  540. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  541. // # Using non-iterable elements.
  542. //input_tree = 0
  543. // shallow_tree = 0
  544. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  545. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  546. // self.assertEqual(flattened_input_tree, [input_tree])
  547. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  548. // ## Input non-list edge-case.
  549. //# Using iterable elements.
  550. //input_tree = "input_tree"
  551. // shallow_tree = ["shallow_tree"]
  552. //expected_message = ("If shallow structure is a sequence, input must also "
  553. // "be a sequence. Input has type: <(type|class) 'str'>.")
  554. // with self.assertRaisesRegexp(TypeError, expected_message):
  555. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  556. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  557. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  558. // input_tree = "input_tree"
  559. // shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
  560. //with self.assertRaisesRegexp(TypeError, expected_message):
  561. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  562. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  563. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  564. //# Using non-iterable elements.
  565. // input_tree = 0
  566. // shallow_tree = [9]
  567. //expected_message = ("If shallow structure is a sequence, input must also "
  568. // "be a sequence. Input has type: <(type|class) 'int'>.")
  569. // with self.assertRaisesRegexp(TypeError, expected_message):
  570. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  571. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  572. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  573. // input_tree = 0
  574. // shallow_tree = [9, 8]
  575. //with self.assertRaisesRegexp(TypeError, expected_message):
  576. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  577. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  578. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  579. // def testMapStructureUpTo(self) :
  580. // # Named tuples.
  581. // ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  582. // op_tuple = collections.namedtuple("op_tuple", "add, mul")
  583. // inp_val = ab_tuple(a= 2, b= 3)
  584. // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3))
  585. // out = nest.map_structure_up_to(
  586. // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
  587. // self.assertEqual(out.a, 6)
  588. // self.assertEqual(out.b, 15)
  589. // # Lists.
  590. // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  591. // name_list = ["evens", ["odds", "primes"]]
  592. // out = nest.map_structure_up_to(
  593. // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
  594. // name_list, data_list)
  595. // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
  596. // # Dicts.
  597. // inp_val = dict(a= 2, b= 3)
  598. // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
  599. // out = nest.map_structure_up_to(
  600. // inp_val,
  601. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  602. // self.assertEqual(out["a"], 6)
  603. // self.assertEqual(out["b"], 15)
  604. // # Non-equal dicts.
  605. // inp_val = dict(a= 2, b= 3)
  606. // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
  607. // with self.assertRaisesRegexp(ValueError, "same keys"):
  608. // nest.map_structure_up_to(
  609. // inp_val,
  610. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  611. // # Dict+custom mapping.
  612. // inp_val = dict(a= 2, b= 3)
  613. // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
  614. // out = nest.map_structure_up_to(
  615. // inp_val,
  616. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  617. // self.assertEqual(out["a"], 6)
  618. // self.assertEqual(out["b"], 15)
  619. // # Non-equal dict/mapping.
  620. // inp_val = dict(a= 2, b= 3)
  621. // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
  622. // with self.assertRaisesRegexp(ValueError, "same keys"):
  623. // nest.map_structure_up_to(
  624. // inp_val,
  625. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  626. // def testGetTraverseShallowStructure(self):
  627. // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
  628. // scalar_traverse_r = nest.get_traverse_shallow_structure(
  629. // lambda s: not isinstance(s, tuple),
  630. // scalar_traverse_input)
  631. // self.assertEqual(scalar_traverse_r,
  632. // [True, True, False, [True, True], {"a": False}, []])
  633. // nest.assert_shallow_structure(scalar_traverse_r,
  634. // scalar_traverse_input)
  635. // structure_traverse_input = [(1, [2]), ([1], 2)]
  636. // structure_traverse_r = nest.get_traverse_shallow_structure(
  637. // lambda s: (True, False) if isinstance(s, tuple) else True,
  638. // structure_traverse_input)
  639. // self.assertEqual(structure_traverse_r,
  640. // [(True, False), ([True], False)])
  641. // nest.assert_shallow_structure(structure_traverse_r,
  642. // structure_traverse_input)
  643. // with self.assertRaisesRegexp(TypeError, "returned structure"):
  644. // nest.get_traverse_shallow_structure(lambda _: [True], 0)
  645. // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
  646. // nest.get_traverse_shallow_structure(lambda _: 1, [1])
  647. // with self.assertRaisesRegexp(
  648. // TypeError, "didn't return a depth=1 structure of bools"):
  649. // nest.get_traverse_shallow_structure(lambda _: [1], [1])
  650. // def testYieldFlatStringPaths(self):
  651. // for inputs_expected in ({"inputs": [], "expected": []},
  652. // {"inputs": 3, "expected": [()]},
  653. // {"inputs": [3], "expected": [(0,)]},
  654. // {"inputs": {"a": 3}, "expected": [("a",)]},
  655. // {"inputs": {"a": {"b": 4}},
  656. // "expected": [("a", "b")]},
  657. // {"inputs": [{"a": 2}], "expected": [(0, "a")]},
  658. // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
  659. // {"inputs": [{"a": [(23, 42)]}],
  660. // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
  661. // {"inputs": [{"a": ([23], 42)}],
  662. // "expected": [(0, "a", 0, 0), (0, "a", 1)]},
  663. // {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
  664. // "expected": [("a", "a"), ("c", 0, 0, 0)]},
  665. // {"inputs": {"0": [{"1": 23}]},
  666. // "expected": [("0", 0, "1")]}):
  667. // inputs = inputs_expected["inputs"]
  668. // expected = inputs_expected["expected"]
  669. // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
  670. // def testFlattenWithStringPaths(self):
  671. // for inputs_expected in (
  672. // {"inputs": [], "expected": []},
  673. // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
  674. // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
  675. // inputs = inputs_expected["inputs"]
  676. // expected = inputs_expected["expected"]
  677. // self.assertEqual(
  678. // nest.flatten_with_joined_string_paths(inputs, separator="/"),
  679. // expected)
  680. // # Need a separate test for namedtuple as we can't declare tuple definitions
  681. // # in the @parameterized arguments.
  682. // def testFlattenNamedTuple(self):
  683. // # pylint: disable=invalid-name
  684. // Foo = collections.namedtuple("Foo", ["a", "b"])
  685. // Bar = collections.namedtuple("Bar", ["c", "d"])
  686. // # pylint: enable=invalid-name
  687. // test_cases = [
  688. // (Foo(a = 3, b = Bar(c = 23, d = 42)),
  689. // [("a", 3), ("b/c", 23), ("b/d", 42)]),
  690. // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")),
  691. // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
  692. // (Bar(c = 42, d = 43),
  693. // [("c", 42), ("d", 43)]),
  694. // (Bar(c =[42], d = 43),
  695. // [("c/0", 42), ("d", 43)]),
  696. // ]
  697. // for inputs, expected in test_cases:
  698. // self.assertEqual(
  699. // list(nest.flatten_with_joined_string_paths(inputs)), expected)
  700. // @parameterized.named_parameters(
  701. // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
  702. // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
  703. // {"a": ("a", 4), "b": ("b", 6)}),
  704. // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
  705. // ("nested",
  706. // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
  707. // {"a": [("a/0", 10), ("a/1", 12)],
  708. // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
  709. // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
  710. // def format_sum(path, * values):
  711. // return (path, sum(values))
  712. // result = nest.map_structure_with_paths(format_sum, s1, s2,
  713. // check_types=check_types)
  714. // self.assertEqual(expected, result)
  715. // @parameterized.named_parameters(
  716. // ("tuples", (1, 2), (3, 4, 5), ValueError),
  717. // ("dicts", {"a": 1}, {"b": 2}, ValueError),
  718. // ("mixed", (1, 2), [3, 4], TypeError),
  719. // ("nested",
  720. // {"a": [2, 3], "b": [1, 3]},
  721. // {"b": [5, 6, 7], "a": [8, 9]},
  722. // ValueError
  723. // ))
  724. // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
  725. // with self.assertRaises(error_type):
  726. // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2)
  727. //class NestBenchmark(test.Benchmark):
  728. // def run_and_report(self, s1, s2, name):
  729. // burn_iter, test_iter = 100, 30000
  730. // for _ in xrange(burn_iter) :
  731. // nest.assert_same_structure(s1, s2)
  732. // t0 = time.time()
  733. // for _ in xrange(test_iter) :
  734. // nest.assert_same_structure(s1, s2)
  735. // t1 = time.time()
  736. // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
  737. // name=name)
  738. // def benchmark_assert_structure(self):
  739. // s1 = (((1, 2), 3), 4, (5, 6))
  740. // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  741. // self.run_and_report(s1, s2, "assert_same_structure_6_elem")
  742. // s1 = (((1, 2), 3), 4, (5, 6)) * 10
  743. // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
  744. // self.run_and_report(s1, s2, "assert_same_structure_60_elem")
  745. //if __name__ == "__main__":
  746. // test.main()
  747. }
  748. }