You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

NestTest.cs 45 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using Microsoft.VisualStudio.TestTools.UnitTesting;
  5. using Newtonsoft.Json.Linq;
  6. using NumSharp;
  7. using Tensorflow;
  8. using Tensorflow.Util;
  9. using static Tensorflow.Binding;
  10. namespace TensorFlowNET.UnitTest.nest_test
  11. {
  12. /// <summary>
  13. /// excerpt of tensorflow/python/framework/util/nest_test.py
  14. /// </summary>
  15. [TestClass]
  16. public class NestTest : PythonTest
  17. {
  /// <summary>
  /// MSTest per-test setup: constructs a new graph with <c>tf.Graph()</c>
  /// and calls <c>as_default()</c> on it before each test method runs.
  /// </summary>
  [TestInitialize]
  public void TestInitialize() => tf.Graph().as_default();
  23. //public class PointXY
  24. //{
  25. // public double x;
  26. // public double y;
  27. //}
  28. // if attr:
  29. // class BadAttr(object):
  30. // """Class that has a non-iterable __attrs_attrs__."""
  31. // __attrs_attrs__ = None
  32. // @attr.s
  33. // class SampleAttr(object):
  34. // field1 = attr.ib()
  35. // field2 = attr.ib()
  36. // @test_util.assert_no_new_pyobjects_executing_eagerly
  37. // def testAttrsFlattenAndPack(self) :
  38. // if attr is None:
  39. // self.skipTest("attr module is unavailable.")
  40. // field_values = [1, 2]
  41. // sample_attr = NestTest.SampleAttr(* field_values)
  42. // self.assertFalse(nest._is_attrs(field_values))
  43. // self.assertTrue(nest._is_attrs(sample_attr))
  44. // flat = nest.flatten(sample_attr)
  45. // self.assertEqual(field_values, flat)
  46. // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
  47. // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
  48. // self.assertEqual(restructured_from_flat, sample_attr)
  49. //# Check that flatten fails if attributes are not iterable
  50. // with self.assertRaisesRegexp(TypeError, "object is not iterable"):
  51. // flat = nest.flatten(NestTest.BadAttr())
  /// <summary>
  /// Port of nest_test.py <c>testFlattenAndPack</c>. It covers:
  /// flattening nested <c>object[]</c> structures and rebuilding them with
  /// <c>pack_sequence_as</c>; <c>Hashtable</c> leaves standing in for Python's
  /// <c>Point(x, y)</c> namedtuple; scalar and NDArray leaves; and
  /// <c>ValueError</c> on leaf-count mismatches.
  /// </summary>
  [Ignore]
  [TestMethod]
  public void testFlattenAndPack()
  {
      // ((3, 4), 5, (6, 7, (9, 10), 8)) flattens depth-first.
      object structure = new object[]
      {
          new object[] { 3, 4 },
          5,
          new object[] { 6, 7, new object[] { 9, 10 }, 8 },
      };
      var leaves = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" };
      self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 });

      // Packing replaces each leaf in order; compare via JSON to check nesting.
      string packedJson = JArray.FromObject(nest.pack_sequence_as(structure, leaves)).ToString();
      string expectedJson = JArray.FromObject(new object[]
      {
          new object[] { "a", "b" },
          "c",
          new object[] { "d", "e", new object[] { "f", "g" }, "h" },
      }).ToString();
      self.assertEqual(packedJson, expectedJson);

      // (Point(4, 2), ((Point(1, 0),),)) with Hashtables as the points.
      structure = new object[]
      {
          new Hashtable { ["x"] = 4, ["y"] = 2 },
          new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 } } },
      };
      leaves = new List<object> { 4, 2, 1, 0 };
      self.assertEqual(nest.flatten(structure), leaves);

      var repacked = nest.pack_sequence_as(structure, leaves) as object[];
      self.assertEqual(repacked, structure);

      var outerPoint = repacked[0] as Hashtable;
      self.assertEqual(outerPoint["x"], 4);
      self.assertEqual(outerPoint["y"], 2);

      var innerPoint = ((repacked[1] as object[])[0] as object[])[0] as Hashtable;
      self.assertEqual(innerPoint["x"], 1);
      self.assertEqual(innerPoint["y"], 0);

      // Scalars and NDArrays are single leaves.
      self.assertEqual(new List<object> { 5 }, nest.flatten(5));
      var flattenedArray = nest.flatten(np.array(new[] { 5 }));
      self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flattenedArray);
      self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
      var expectedArray = np.array(new[] { 5 });
      var packedArray = nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) });
      self.assertEqual(expectedArray, packedArray);

      // Leaf-count mismatches must raise ValueError.
      Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 }));
      Assert.ThrowsException<ValueError>(() =>
          nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" }));
  }
  81. // @parameterized.parameters({"mapping_type": collections.OrderedDict
  82. // },
  83. // {"mapping_type": _CustomMapping
  84. //})
  85. // @test_util.assert_no_new_pyobjects_executing_eagerly
  86. // def testFlattenDictOrder(self, mapping_type) :
  87. // """`flatten` orders dicts by key, including OrderedDicts."""
  88. // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
  89. // plain = {"d": 3, "b": 1, "a": 0, "c": 2}
  90. // ordered_flat = nest.flatten(ordered)
  91. // plain_flat = nest.flatten(plain)
  92. // self.assertEqual([0, 1, 2, 3], ordered_flat)
  93. // self.assertEqual([0, 1, 2, 3], plain_flat)
  94. // @parameterized.parameters({"mapping_type": collections.OrderedDict},
  95. // {"mapping_type": _CustomMapping})
  96. // def testPackDictOrder(self, mapping_type):
  97. // """Packing orders dicts by key, including OrderedDicts."""
  98. // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
  99. // plain = {"d": 0, "b": 0, "a": 0, "c": 0}
  100. // seq = [0, 1, 2, 3]
  101. //custom_reconstruction = nest.pack_sequence_as(custom, seq)
  102. //plain_reconstruction = nest.pack_sequence_as(plain, seq)
  103. // self.assertIsInstance(custom_reconstruction, mapping_type)
  104. // self.assertIsInstance(plain_reconstruction, dict)
  105. // self.assertEqual(
  106. // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
  107. // custom_reconstruction)
  108. // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
  109. // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
  110. // @test_util.assert_no_new_pyobjects_executing_eagerly
  111. // def testFlattenAndPack_withDicts(self) :
  112. // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
  113. // mess = [
  114. // "z",
  115. // NestTest.Abc(3, 4), {
  116. // "d": _CustomMapping({
  117. // 41: 4
  118. // }),
  119. // "c": [
  120. // 1,
  121. // collections.OrderedDict([
  122. // ("b", 3),
  123. // ("a", 2),
  124. // ]),
  125. // ],
  126. // "b": 5
  127. // }, 17
  128. // ]
  129. // flattened = nest.flatten(mess)
  130. // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
  131. // structure_of_mess = [
  132. // 14,
  133. // NestTest.Abc("a", True),
  134. // {
  135. // "d": _CustomMapping({
  136. // 41: 42
  137. // }),
  138. // "c": [
  139. // 0,
  140. // collections.OrderedDict([
  141. // ("b", 9),
  142. // ("a", 8),
  143. // ]),
  144. // ],
  145. // "b": 3
  146. // },
  147. // "hi everybody",
  148. // ]
  149. // unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
  150. // self.assertEqual(unflattened, mess)
  151. // # Check also that the OrderedDict was created, with the correct key order.
  152. //unflattened_ordered_dict = unflattened[2]["c"][1]
  153. // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
  154. // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
  155. // unflattened_custom_mapping = unflattened[2]["d"]
  156. // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
  157. // self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
  /// <summary>
  /// An NDArray is treated as one leaf: flattening <c>np.array(1, 2, 3)</c>
  /// must produce exactly one element.
  /// </summary>
  [TestMethod]
  public void testFlatten_numpyIsNotFlattened()
  {
      var leaves = nest.flatten(np.array(1, 2, 3));
      self.assertEqual(len(leaves), 1);
  }
  /// <summary>
  /// A string is treated as one leaf, not a sequence of characters.
  /// Packing that single leaf into another string template returns the
  /// original string.
  /// </summary>
  [TestMethod]
  public void testFlatten_stringIsNotFlattened()
  {
      const string text = "lots of letters";

      var leaves = nest.flatten(text);
      self.assertEqual(len(leaves), 1);

      var repacked = nest.pack_sequence_as("goodbye", leaves);
      self.assertEqual(text, repacked);
  }
  174. // def testPackSequenceAs_notIterableError(self) :
  175. // with self.assertRaisesRegexp(TypeError,
  176. // "flat_sequence must be a sequence"):
  177. // nest.pack_sequence_as("hi", "bye")
  /// <summary>
  /// <c>pack_sequence_as</c> must throw <c>ValueError</c> when the structure
  /// has two leaves but the flat sequence supplies three. The Python original
  /// also matches the message "Structure had 2 elements, but flat_sequence had
  /// 3 elements."; only the exception type is asserted here.
  /// </summary>
  [TestMethod]
  public void testPackSequenceAs_wrongLengthsError()
  {
      var twoLeafStructure = new object[] { "hello", "world" };
      var threeValues = new object[] { "and", "goodbye", "again" };

      Assert.ThrowsException<ValueError>(() =>
      {
          nest.pack_sequence_as(twoLeafStructure, threeValues);
      });
  }
  /// <summary>
  /// Checks which values <c>nest.is_sequence</c> classifies as sequences.
  /// Sequences: arrays (nested or empty) and Hashtables. Not sequences:
  /// strings, HashSets, the results of <c>array_ops.ones</c> and
  /// <c>gen_math_ops.tanh</c>, and NDArrays.
  /// </summary>
  [Ignore]
  [TestMethod]
  public void testIsSequence()
  {
      // A string is a leaf.
      self.assertFalse(nest.is_sequence("1234"));

      // Arrays and mappings count as sequences.
      self.assertTrue(nest.is_sequence(new object[] { 1, 3, new object[] { 4, 5 } }));
      // TODO: ValueTuple<T,T>
      //self.assertTrue(nest.is_sequence(((7, 8), (5, 6))));
      self.assertTrue(nest.is_sequence(new object[0]));
      self.assertTrue(nest.is_sequence(new Hashtable { ["a"] = 1, ["b"] = 2 }));

      // Sets, tensor-op results and NDArrays do not.
      self.assertFalse(nest.is_sequence(new HashSet<int> { 1, 2 }));
      var onesTensor = array_ops.ones(new[] { 2, 3 });
      self.assertFalse(nest.is_sequence(onesTensor));
      self.assertFalse(nest.is_sequence(gen_math_ops.tanh(onesTensor)));
      self.assertFalse(nest.is_sequence(np.ones(new[] { 4, 5 })));
  }
  205. // @parameterized.parameters({"mapping_type": _CustomMapping},
  206. // {"mapping_type": dict})
  207. // def testFlattenDictItems(self, mapping_type):
  208. // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))})
  209. // flat = {4: "a", 5: "b", 6: "c", 8: "d"}
  210. // self.assertEqual(nest.flatten_dict_items(dictionary), flat)
  211. // with self.assertRaises(TypeError):
  212. // nest.flatten_dict_items(4)
  213. // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))})
  214. // with self.assertRaisesRegexp(ValueError, "not unique"):
  215. // nest.flatten_dict_items(bad_dictionary)
  216. // another_bad_dictionary = mapping_type({
  217. // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
  218. // })
  219. // with self.assertRaisesRegexp(
  220. // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
  221. // nest.flatten_dict_items(another_bad_dictionary)
  222. //# pylint does not correctly recognize these as class names and
  223. //# suggests to use variable style under_score naming.
  224. //# pylint: disable=invalid-name
  225. // Named0ab = collections.namedtuple("named_0", ("a", "b"))
  226. // Named1ab = collections.namedtuple("named_1", ("a", "b"))
  227. // SameNameab = collections.namedtuple("same_name", ("a", "b"))
  228. // SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
  229. // SameNamexy = collections.namedtuple("same_name", ("x", "y"))
  230. // SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
  231. // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
  232. // NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
  233. // # pylint: enable=invalid-name
  234. // class SameNamedType1(SameNameab):
  235. // pass
  236. // @test_util.assert_no_new_pyobjects_executing_eagerly
  237. // def testAssertSameStructure(self):
  238. // structure1 = (((1, 2), 3), 4, (5, 6))
  239. // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  240. // structure_different_num_elements = ("spam", "eggs")
  241. // structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
  242. // nest.assert_same_structure(structure1, structure2)
  243. // nest.assert_same_structure("abc", 1.0)
  244. // nest.assert_same_structure("abc", np.array([0, 1]))
  245. // nest.assert_same_structure("abc", constant_op.constant([0, 1]))
  246. // with self.assertRaisesRegexp(
  247. // ValueError,
  248. // ("The two structures don't have the same nested structure\\.\n\n"
  249. // "First structure:.*?\n\n"
  250. // "Second structure:.*\n\n"
  251. // "More specifically: Substructure "
  252. // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
  253. // 'substructure "type=str str=spam" is not\n'
  254. // "Entire first structure:\n"
  255. // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
  256. // "Entire second structure:\n"
  257. // r"\(\., \.\)")):
  258. // nest.assert_same_structure(structure1, structure_different_num_elements)
  259. // with self.assertRaisesRegexp(
  260. // ValueError,
  261. // ("The two structures don't have the same nested structure\\.\n\n"
  262. // "First structure:.*?\n\n"
  263. // "Second structure:.*\n\n"
  264. // r'More specifically: Substructure "type=list str=\[0, 1\]" '
  265. // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
  266. // "is not")):
  267. // nest.assert_same_structure([0, 1], np.array([0, 1]))
  268. // with self.assertRaisesRegexp(
  269. // ValueError,
  270. // ("The two structures don't have the same nested structure\\.\n\n"
  271. // "First structure:.*?\n\n"
  272. // "Second structure:.*\n\n"
  273. // r'More specifically: Substructure "type=list str=\[0, 1\]" '
  274. // 'is a sequence, while substructure "type=int str=0" '
  275. // "is not")):
  276. // nest.assert_same_structure(0, [0, 1])
  277. // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
  278. // with self.assertRaisesRegexp(
  279. // ValueError,
  280. // ("don't have the same nested structure\\.\n\n"
  281. // "First structure: .*?\n\nSecond structure: ")):
  282. // nest.assert_same_structure(structure1, structure_different_nesting)
  283. // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
  284. // NestTest.Named0ab("a", "b"))
  285. // nest.assert_same_structure(NestTest.Named0ab(3, 4),
  286. // NestTest.Named0ab("a", "b"))
  287. // self.assertRaises(TypeError, nest.assert_same_structure,
  288. // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
  289. // with self.assertRaisesRegexp(
  290. // ValueError,
  291. // ("don't have the same nested structure\\.\n\n"
  292. // "First structure: .*?\n\nSecond structure: ")):
  293. // nest.assert_same_structure(NestTest.Named0ab(3, 4),
  294. // NestTest.Named0ab([3], 4))
  295. // with self.assertRaisesRegexp(
  296. // ValueError,
  297. // ("don't have the same nested structure\\.\n\n"
  298. // "First structure: .*?\n\nSecond structure: ")):
  299. // nest.assert_same_structure([[3], 4], [3, [4]])
  300. // structure1_list = [[[1, 2], 3], 4, [5, 6]]
  301. // with self.assertRaisesRegexp(TypeError,
  302. // "don't have the same sequence type"):
  303. // nest.assert_same_structure(structure1, structure1_list)
  304. // nest.assert_same_structure(structure1, structure2, check_types= False)
  305. // nest.assert_same_structure(structure1, structure1_list, check_types=False)
  306. // with self.assertRaisesRegexp(ValueError,
  307. // "don't have the same set of keys"):
  308. // nest.assert_same_structure({"a": 1}, {"b": 1})
  309. // nest.assert_same_structure(NestTest.SameNameab(0, 1),
  310. // NestTest.SameNameab2(2, 3))
  311. // # This assertion is expected to pass: two namedtuples with the same
  312. // # name and field names are considered to be identical.
  313. // nest.assert_same_structure(
  314. // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
  315. // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
  316. // expected_message = "The two structures don't have the same.*"
  317. // with self.assertRaisesRegexp(ValueError, expected_message):
  318. // nest.assert_same_structure(
  319. // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
  320. // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
  321. // self.assertRaises(TypeError, nest.assert_same_structure,
  322. // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
  323. // self.assertRaises(TypeError, nest.assert_same_structure,
  324. // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
  325. // self.assertRaises(TypeError, nest.assert_same_structure,
  326. // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
  327. // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
  328. // def testHeterogeneousComparison(self):
  329. // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3))
  330. // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
  /// <summary>
  /// Partial port of nest_test.py <c>testMapStructure</c>: applies
  /// <c>nest.map_structure</c> to one nested <c>object[]</c> structure, then
  /// to two structures together. Results are checked through <c>nest.flatten</c>
  /// and structural equality. The Python cases not yet ported are kept below
  /// as notes.
  /// </summary>
  [Ignore]
  [TestMethod]
  public void testMapStructure()
  {
      var structure1 = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } };
      var structure2 = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } };

      // Single structure: the callback is applied to each leaf.
      // (Leftover debug serialization/Console.WriteLine of the result removed:
      // it only produced noise and could fail independently of the assertions.)
      var structure1_plus1 = nest.map_structure(x => (int)x + 1, structure1);
      var structure1_strings = nest.map_structure(x => $"{x}", structure1);
      // Not yet ported: nest.assert_same_structure(structure1, structure1_plus1)
      self.assertAllEqual(nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 });
      self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" });

      // Two structures: x is indexed as x[0] / x[1], presumably holding the
      // corresponding leaf from structure1 and structure2 respectively.
      var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2);
      self.assertEqual(
          new object[] { new object[] { new object[] { 1 + 7, 2 + 8 }, 3 + 9 }, 4 + 10, new object[] { 5 + 11, 6 + 12 } },
          structure1_plus_structure2);

      // Remaining Python cases (not yet ported):
      // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
      // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
      // # Empty structures
      // self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
      // self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
      // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
      // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
      // NestTest.EmptyNT()))
      // # This is checking actual equality of types, empty list != empty tuple
      // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
      // with self.assertRaisesRegexp(TypeError, "callable"):
      // nest.map_structure("bad", structure1_plus1)
      // with self.assertRaisesRegexp(ValueError, "at least one structure"):
      // nest.map_structure(lambda x: x)
      // with self.assertRaisesRegexp(ValueError, "same number of elements"):
      // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      // nest.map_structure(lambda x, y: None, 3, (3,))
      // with self.assertRaisesRegexp(TypeError, "same sequence type"):
      // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
      // structure1_list = [[[1, 2], 3], 4, [5, 6]]
      // with self.assertRaisesRegexp(TypeError, "same sequence type"):
      // nest.map_structure(lambda x, y: None, structure1, structure1_list)
      // nest.map_structure(lambda x, y: None, structure1, structure1_list,
      // check_types=False)
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
      // check_types=False)
      // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      // nest.map_structure(lambda x: None, structure1, foo="a")
      // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
  }
  // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
  384. // @test_util.assert_no_new_pyobjects_executing_eagerly
  385. // def testMapStructureWithStrings(self) :
  386. // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
  387. // inp_b = NestTest.ABTuple(a=2, b=(1, 3))
  388. // out = nest.map_structure(lambda string, repeats: string* repeats,
  389. // inp_a,
  390. // inp_b)
  391. // self.assertEqual("foofoo", out.a)
  392. // self.assertEqual("bar", out.b[0])
  393. // self.assertEqual("bazbazbaz", out.b[1])
  394. // nt = NestTest.ABTuple(a=("something", "something_else"),
  395. // b="yet another thing")
  396. // rev_nt = nest.map_structure(lambda x: x[::- 1], nt)
  397. // # Check the output is the correct structure, and all strings are reversed.
  398. // nest.assert_same_structure(nt, rev_nt)
  399. // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0])
  400. // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1])
  401. // self.assertEqual(nt.b[::- 1], rev_nt.b)
  402. // @test_util.run_deprecated_v1
  403. // def testMapStructureOverPlaceholders(self) :
  404. // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
  405. // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  406. // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
  407. // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  408. // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
  409. // nest.assert_same_structure(output, inp_a)
  410. // self.assertShapeEqual(np.zeros((3, 4)), output[0])
  411. // self.assertShapeEqual(np.zeros((3, 7)), output[1])
  412. // feed_dict = {
  413. // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
  414. // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
  415. // }
  416. // with self.cached_session() as sess:
  417. // output_np = sess.run(output, feed_dict=feed_dict)
  418. // self.assertAllClose(output_np[0],
  419. // feed_dict[inp_a][0] + feed_dict[inp_b][0])
  420. // self.assertAllClose(output_np[1],
  421. // feed_dict[inp_a][1] + feed_dict[inp_b][1])
  422. // def testAssertShallowStructure(self):
  423. // inp_ab = ["a", "b"]
  424. //inp_abc = ["a", "b", "c"]
  425. //expected_message = (
  426. // "The two structures don't have the same sequence length. Input "
  427. // "structure has length 2, while shallow structure has length 3.")
  428. // with self.assertRaisesRegexp(ValueError, expected_message):
  429. // nest.assert_shallow_structure(inp_abc, inp_ab)
  430. // inp_ab1 = [(1, 1), (2, 2)]
  431. // inp_ab2 = [[1, 1], [2, 2]]
  432. // expected_message = (
  433. // "The two structures don't have the same sequence type. Input structure "
  434. // "has type <(type|class) 'tuple'>, while shallow structure has type "
  435. // "<(type|class) 'list'>.")
  436. // with self.assertRaisesRegexp(TypeError, expected_message):
  437. // nest.assert_shallow_structure(inp_ab2, inp_ab1)
  438. // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False)
  439. // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
  440. // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
  441. // expected_message = (
  442. // r"The two structures don't have the same keys. Input "
  443. // r"structure has keys \['c'\], while shallow structure has "
  444. // r"keys \['d'\].")
  445. // with self.assertRaisesRegexp(ValueError, expected_message):
  446. // nest.assert_shallow_structure(inp_ab2, inp_ab1)
  447. // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
  448. // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
  449. // nest.assert_shallow_structure(inp_ab, inp_ba)
  450. // # This assertion is expected to pass: two namedtuples with the same
  451. //# name and field names are considered to be identical.
  452. //inp_shallow = NestTest.SameNameab(1, 2)
  453. // inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
  454. // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
  455. // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
  456. // def testFlattenUpTo(self):
  457. // # Shallow tree ends at scalar.
  458. // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  459. // shallow_tree = [[True, True], [False, True]]
  460. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  461. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  462. // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  463. // self.assertEqual(flattened_shallow_tree, [True, True, False, True])
  464. //# Shallow tree ends at string.
  465. // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  466. // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  467. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  468. // input_tree)
  469. // input_tree_flattened = nest.flatten(input_tree)
  470. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  471. // [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  472. // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
  473. // # Make sure dicts are correctly flattened, yielding values, not keys.
  474. //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  475. // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  476. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  477. // input_tree)
  478. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  479. // [1, { "c": 2}, 3, (4, 5)])
  480. // # Namedtuples.
  481. // ab_tuple = NestTest.ABTuple
  482. // input_tree = ab_tuple(a =[0, 1], b = 2)
  483. // shallow_tree = ab_tuple(a= 0, b= 1)
  484. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  485. // input_tree)
  486. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  487. // [[0, 1], 2])
  488. // # Nested dicts, OrderedDicts and namedtuples.
  489. // input_tree = collections.OrderedDict(
  490. // [("a", ab_tuple(a =[0, {"b": 1}], b=2)),
  491. // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
  492. // shallow_tree = input_tree
  493. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  494. // input_tree)
  495. // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
  496. // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  497. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  498. // input_tree)
  499. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  500. // [ab_tuple(a =[0, { "b": 1}], b=2),
  501. // 3,
  502. // collections.OrderedDict([("f", 4)])])
  503. // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  504. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  505. // input_tree)
  506. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  507. // [ab_tuple(a =[0, {"b": 1}], b=2),
  508. // {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
  509. // ## Shallow non-list edge-case.
  510. // # Using iterable elements.
  511. // input_tree = ["input_tree"]
  512. //shallow_tree = "shallow_tree"
  513. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  514. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  515. // self.assertEqual(flattened_input_tree, [input_tree])
  516. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  517. // input_tree = ["input_tree_0", "input_tree_1"]
  518. //shallow_tree = "shallow_tree"
  519. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  520. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  521. // self.assertEqual(flattened_input_tree, [input_tree])
  522. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  523. // # Using non-iterable elements.
  524. //input_tree = [0]
  525. //shallow_tree = 9
  526. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  527. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  528. // self.assertEqual(flattened_input_tree, [input_tree])
  529. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  530. // input_tree = [0, 1]
  531. //shallow_tree = 9
  532. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  533. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  534. // self.assertEqual(flattened_input_tree, [input_tree])
  535. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  536. // ## Both non-list edge-case.
  537. //# Using iterable elements.
  538. //input_tree = "input_tree"
  539. // shallow_tree = "shallow_tree"
  540. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  541. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  542. // self.assertEqual(flattened_input_tree, [input_tree])
  543. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  544. // # Using non-iterable elements.
  545. //input_tree = 0
  546. // shallow_tree = 0
  547. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  548. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  549. // self.assertEqual(flattened_input_tree, [input_tree])
  550. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  551. // ## Input non-list edge-case.
  552. //# Using iterable elements.
  553. //input_tree = "input_tree"
  554. // shallow_tree = ["shallow_tree"]
  555. //expected_message = ("If shallow structure is a sequence, input must also "
  556. // "be a sequence. Input has type: <(type|class) 'str'>.")
  557. // with self.assertRaisesRegexp(TypeError, expected_message):
  558. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  559. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  560. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  561. // input_tree = "input_tree"
  562. // shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
  563. //with self.assertRaisesRegexp(TypeError, expected_message):
  564. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  565. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  566. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  567. //# Using non-iterable elements.
  568. // input_tree = 0
  569. // shallow_tree = [9]
  570. //expected_message = ("If shallow structure is a sequence, input must also "
  571. // "be a sequence. Input has type: <(type|class) 'int'>.")
  572. // with self.assertRaisesRegexp(TypeError, expected_message):
  573. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  574. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  575. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  576. // input_tree = 0
  577. // shallow_tree = [9, 8]
  578. //with self.assertRaisesRegexp(TypeError, expected_message):
  579. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  580. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  581. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  582. // def testMapStructureUpTo(self):
  583. // # Named tuples.
  584. // ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  585. // op_tuple = collections.namedtuple("op_tuple", "add, mul")
  586. // inp_val = ab_tuple(a=2, b=3)
  587. // inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  588. // out = nest.map_structure_up_to(
  589. // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
  590. // self.assertEqual(out.a, 6)
  591. // self.assertEqual(out.b, 15)
  592. // # Lists.
  593. // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  594. // name_list = ["evens", ["odds", "primes"]]
  595. // out = nest.map_structure_up_to(
  596. // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
  597. // name_list, data_list)
  598. // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
  599. // # Dicts.
  600. // inp_val = dict(a=2, b=3)
  601. // inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
  602. // out = nest.map_structure_up_to(
  603. // inp_val,
  604. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  605. // self.assertEqual(out["a"], 6)
  606. // self.assertEqual(out["b"], 15)
  607. // # Non-equal dicts.
  608. // inp_val = dict(a=2, b=3)
  609. // inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
  610. // with self.assertRaisesRegexp(ValueError, "same keys"):
  611. // nest.map_structure_up_to(
  612. // inp_val,
  613. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  614. // # Dict+custom mapping.
  615. // inp_val = dict(a=2, b=3)
  616. // inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
  617. // out = nest.map_structure_up_to(
  618. // inp_val,
  619. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  620. // self.assertEqual(out["a"], 6)
  621. // self.assertEqual(out["b"], 15)
  622. // # Non-equal dict/mapping.
  623. // inp_val = dict(a=2, b=3)
  624. // inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
  625. // with self.assertRaisesRegexp(ValueError, "same keys"):
  626. // nest.map_structure_up_to(
  627. // inp_val,
  628. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  629. // def testGetTraverseShallowStructure(self):
  630. // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
  631. // scalar_traverse_r = nest.get_traverse_shallow_structure(
  632. // lambda s: not isinstance(s, tuple),
  633. // scalar_traverse_input)
  634. // self.assertEqual(scalar_traverse_r,
  635. // [True, True, False, [True, True], {"a": False}, []])
  636. // nest.assert_shallow_structure(scalar_traverse_r,
  637. // scalar_traverse_input)
  638. // structure_traverse_input = [(1, [2]), ([1], 2)]
  639. // structure_traverse_r = nest.get_traverse_shallow_structure(
  640. // lambda s: (True, False) if isinstance(s, tuple) else True,
  641. // structure_traverse_input)
  642. // self.assertEqual(structure_traverse_r,
  643. // [(True, False), ([True], False)])
  644. // nest.assert_shallow_structure(structure_traverse_r,
  645. // structure_traverse_input)
  646. // with self.assertRaisesRegexp(TypeError, "returned structure"):
  647. // nest.get_traverse_shallow_structure(lambda _: [True], 0)
  648. // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
  649. // nest.get_traverse_shallow_structure(lambda _: 1, [1])
  650. // with self.assertRaisesRegexp(
  651. // TypeError, "didn't return a depth=1 structure of bools"):
  652. // nest.get_traverse_shallow_structure(lambda _: [1], [1])
  653. // def testYieldFlatStringPaths(self):
  654. // for inputs_expected in ({"inputs": [], "expected": []},
  655. // {"inputs": 3, "expected": [()]},
  656. // {"inputs": [3], "expected": [(0,)]},
  657. // {"inputs": {"a": 3}, "expected": [("a",)]},
  658. // {"inputs": {"a": {"b": 4}},
  659. // "expected": [("a", "b")]},
  660. // {"inputs": [{"a": 2}], "expected": [(0, "a")]},
  661. // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
  662. // {"inputs": [{"a": [(23, 42)]}],
  663. // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
  664. // {"inputs": [{"a": ([23], 42)}],
  665. // "expected": [(0, "a", 0, 0), (0, "a", 1)]},
  666. // {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
  667. // "expected": [("a", "a"), ("c", 0, 0, 0)]},
  668. // {"inputs": {"0": [{"1": 23}]},
  669. // "expected": [("0", 0, "1")]}):
  670. // inputs = inputs_expected["inputs"]
  671. // expected = inputs_expected["expected"]
  672. // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
  673. // def testFlattenWithStringPaths(self):
  674. // for inputs_expected in (
  675. // {"inputs": [], "expected": []},
  676. // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
  677. // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
  678. // inputs = inputs_expected["inputs"]
  679. // expected = inputs_expected["expected"]
  680. // self.assertEqual(
  681. // nest.flatten_with_joined_string_paths(inputs, separator="/"),
  682. // expected)
  683. // # Need a separate test for namedtuple as we can't declare tuple definitions
  684. // # in the @parameterized arguments.
  685. // def testFlattenNamedTuple(self):
  686. // # pylint: disable=invalid-name
  687. // Foo = collections.namedtuple("Foo", ["a", "b"])
  688. // Bar = collections.namedtuple("Bar", ["c", "d"])
  689. // # pylint: enable=invalid-name
  690. // test_cases = [
  691. // (Foo(a=3, b=Bar(c=23, d=42)),
  692. // [("a", 3), ("b/c", 23), ("b/d", 42)]),
  693. // (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
  694. // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
  695. // (Bar(c=42, d=43),
  696. // [("c", 42), ("d", 43)]),
  697. // (Bar(c=[42], d=43),
  698. // [("c/0", 42), ("d", 43)]),
  699. // ]
  700. // for inputs, expected in test_cases:
  701. // self.assertEqual(
  702. // list(nest.flatten_with_joined_string_paths(inputs)), expected)
  703. // @parameterized.named_parameters(
  704. // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
  705. // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
  706. // {"a": ("a", 4), "b": ("b", 6)}),
  707. // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
  708. // ("nested",
  709. // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
  710. // {"a": [("a/0", 10), ("a/1", 12)],
  711. // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
  712. // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
  713. // def format_sum(path, *values):
  714. // return (path, sum(values))
  715. // result = nest.map_structure_with_paths(format_sum, s1, s2,
  716. // check_types=check_types)
  717. // self.assertEqual(expected, result)
  718. // @parameterized.named_parameters(
  719. // ("tuples", (1, 2), (3, 4, 5), ValueError),
  720. // ("dicts", {"a": 1}, {"b": 2}, ValueError),
  721. // ("mixed", (1, 2), [3, 4], TypeError),
  722. // ("nested",
  723. // {"a": [2, 3], "b": [1, 3]},
  724. // {"b": [5, 6, 7], "a": [8, 9]},
  725. // ValueError
  726. // ))
  727. // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
  728. // with self.assertRaises(error_type):
  729. // nest.map_structure_with_paths(lambda path, *s: 0, s1, s2)
  730. //class NestBenchmark(test.Benchmark):
  731. // def run_and_report(self, s1, s2, name):
  732. // burn_iter, test_iter = 100, 30000
  733. // for _ in xrange(burn_iter):
  734. // nest.assert_same_structure(s1, s2)
  735. // t0 = time.time()
  736. // for _ in xrange(test_iter):
  737. // nest.assert_same_structure(s1, s2)
  738. // t1 = time.time()
  739. // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
  740. // name=name)
  741. // def benchmark_assert_structure(self):
  742. // s1 = (((1, 2), 3), 4, (5, 6))
  743. // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  744. // self.run_and_report(s1, s2, "assert_same_structure_6_elem")
  745. // s1 = (((1, 2), 3), 4, (5, 6)) * 10
  746. // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
  747. // self.run_and_report(s1, s2, "assert_same_structure_60_elem")
  748. //if __name__ == "__main__":
  749. // test.main()
  750. }
  751. }