You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

nest_test.py 37 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for utilities working with arbitrarily nested structures."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import collections
  20. import time
  21. from absl.testing import parameterized
  22. import numpy as np
  23. from six.moves import xrange # pylint: disable=redefined-builtin
  24. from tensorflow.python.framework import constant_op
  25. from tensorflow.python.framework import dtypes
  26. from tensorflow.python.framework import test_util
  27. from tensorflow.python.ops import array_ops
  28. from tensorflow.python.ops import math_ops
  29. from tensorflow.python.platform import test
  30. from tensorflow.python.util import nest
  31. try:
  32. import attr # pylint:disable=g-import-not-at-top
  33. except ImportError:
  34. attr = None
  35. class _CustomMapping(collections.Mapping):
  36. def __init__(self, *args, **kwargs):
  37. self._wrapped = dict(*args, **kwargs)
  38. def __getitem__(self, key):
  39. return self._wrapped[key]
  40. def __iter__(self):
  41. return iter(self._wrapped)
  42. def __len__(self):
  43. return len(self._wrapped)
  44. class NestTest(parameterized.TestCase, test.TestCase):
  # Two-field namedtuple fixture used by the flatten/pack tests.
  PointXY = collections.namedtuple("Point", ["x", "y"])  # pylint: disable=invalid-name

  # The attrs-based fixtures can only be declared when `attr` was importable
  # (the module-level name is None otherwise).
  if attr is not None:

    class BadAttr(object):
      """Class that has a non-iterable __attrs_attrs__."""
      __attrs_attrs__ = None

    # attrs-decorated fixture with two attributes, `field1` and `field2`.
    @attr.s
    class SampleAttr(object):
      field1 = attr.ib()
      field2 = attr.ib()
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testAttrsFlattenAndPack(self):
    """Round-trips an attrs instance through flatten/pack_sequence_as."""
    if attr is None:
      self.skipTest("attr module is unavailable.")

    values = [1, 2]
    attrs_instance = NestTest.SampleAttr(*values)

    # A plain list is not reported as attrs; the decorated instance is.
    self.assertFalse(nest._is_attrs(values))
    self.assertTrue(nest._is_attrs(attrs_instance))

    flattened = nest.flatten(attrs_instance)
    self.assertEqual(values, flattened)

    repacked = nest.pack_sequence_as(attrs_instance, flattened)
    self.assertIsInstance(repacked, NestTest.SampleAttr)
    self.assertEqual(repacked, attrs_instance)

    # Check that flatten fails if attributes are not iterable.
    with self.assertRaisesRegexp(TypeError, "object is not iterable"):
      nest.flatten(NestTest.BadAttr())
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFlattenAndPack(self):
    """Exercises flatten/pack_sequence_as on tuples, namedtuples, scalars."""
    nested_tuple = ((3, 4), 5, (6, 7, (9, 10), 8))
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(nested_tuple), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(
        nest.pack_sequence_as(nested_tuple, letters),
        (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

    # Namedtuples nested inside plain tuples.
    point_structure = (NestTest.PointXY(x=4, y=2),
                       ((NestTest.PointXY(x=1, y=0),),))
    coordinates = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(point_structure), coordinates)
    repacked = nest.pack_sequence_as(point_structure, coordinates)
    self.assertEqual(repacked, point_structure)
    outer_point = repacked[0]
    self.assertEqual(outer_point.x, 4)
    self.assertEqual(outer_point.y, 2)
    inner_point = repacked[1][0][0]
    self.assertEqual(inner_point.x, 1)
    self.assertEqual(inner_point.y, 0)

    # Non-sequence inputs flatten to a one-element list.
    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

    # A scalar structure packs a one-element flat sequence back to its item.
    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    # Error cases.
    with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
      nest.pack_sequence_as("scalar", [4, 5])

    with self.assertRaisesRegexp(TypeError, "flat_sequence"):
      nest.pack_sequence_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
  @parameterized.parameters({"mapping_type": collections.OrderedDict},
                            {"mapping_type": _CustomMapping})
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFlattenDictOrder(self, mapping_type):
    """`flatten` orders dicts by key, including OrderedDicts."""
    # Items are deliberately listed out of key order.
    unsorted_items = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
    values_in_key_order = [0, 1, 2, 3]
    for mapping in (mapping_type(unsorted_items), dict(unsorted_items)):
      self.assertEqual(values_in_key_order, nest.flatten(mapping))
  @parameterized.parameters({"mapping_type": collections.OrderedDict},
                            {"mapping_type": _CustomMapping})
  def testPackDictOrder(self, mapping_type):
    """Packing orders dicts by key, including OrderedDicts."""
    insertion_order = ("d", "b", "a", "c")
    custom_template = mapping_type([(key, 0) for key in insertion_order])
    plain_template = {key: 0 for key in insertion_order}
    flat_values = [0, 1, 2, 3]

    repacked_custom = nest.pack_sequence_as(custom_template, flat_values)
    repacked_plain = nest.pack_sequence_as(plain_template, flat_values)

    # The container type of the template is preserved.
    self.assertIsInstance(repacked_custom, mapping_type)
    self.assertIsInstance(repacked_plain, dict)

    # Values are assigned to keys in sorted-key order: a=0, b=1, c=2, d=3.
    self.assertEqual(
        mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
        repacked_custom)
    self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, repacked_plain)
  # Namedtuple fixture with fields `b` and `c`.
  Abc = collections.namedtuple("A", ("b", "c"))  # pylint: disable=invalid-name

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testFlattenAndPack_withDicts(self):
    """Flattens and repacks a structure mixing lists, dicts and mappings."""
    # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
    mess_ordered = collections.OrderedDict([("b", 3), ("a", 2)])
    mess_mapping = {
        "d": _CustomMapping({41: 4}),
        "c": [1, mess_ordered],
        "b": 5,
    }
    mess = ["z", NestTest.Abc(3, 4), mess_mapping, 17]

    flattened = nest.flatten(mess)
    self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])

    # Same shape as `mess`, but with unrelated leaf values.
    template_ordered = collections.OrderedDict([("b", 9), ("a", 8)])
    template_mapping = {
        "d": _CustomMapping({41: 42}),
        "c": [0, template_ordered],
        "b": 3,
    }
    structure_of_mess = [
        14,
        NestTest.Abc("a", True),
        template_mapping,
        "hi everybody",
    ]

    unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
    self.assertEqual(unflattened, mess)

    # Check also that the OrderedDict was created, with the correct key order.
    repacked_ordered = unflattened[2]["c"][1]
    self.assertIsInstance(repacked_ordered, collections.OrderedDict)
    self.assertEqual(list(repacked_ordered.keys()), ["b", "a"])

    # The custom mapping type must be preserved as well.
    repacked_custom = unflattened[2]["d"]
    self.assertIsInstance(repacked_custom, _CustomMapping)
    self.assertEqual(list(repacked_custom.keys()), [41])
  def testFlatten_numpyIsNotFlattened(self):
    """Flattening a numpy array yields exactly one element."""
    array = np.array([1, 2, 3])
    self.assertEqual(len(nest.flatten(array)), 1)
  def testFlatten_stringIsNotFlattened(self):
    """Flattening a string yields one element that packs back unchanged."""
    text = "lots of letters"
    leaves = nest.flatten(text)
    self.assertEqual(len(leaves), 1)
    # The structure argument is a different string; only the leaf is reused.
    repacked = nest.pack_sequence_as("goodbye", leaves)
    self.assertEqual(text, repacked)
  def testPackSequenceAs_notIterableError(self):
    """A string passed as `flat_sequence` is rejected with a TypeError."""
    expected_message = "flat_sequence must be a sequence"
    with self.assertRaisesRegexp(TypeError, expected_message):
      nest.pack_sequence_as("hi", "bye")
  def testPackSequenceAs_wrongLengthsError(self):
    """A flat sequence longer than the structure raises a ValueError."""
    two_slots = ["hello", "world"]
    three_values = ["and", "goodbye", "again"]
    expected_message = (
        "Structure had 2 elements, but flat_sequence had 3 elements.")
    with self.assertRaisesRegexp(ValueError, expected_message):
      nest.pack_sequence_as(two_slots, three_values)
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testIsSequence(self):
    """Checks which Python/numpy/tensor values `is_sequence` accepts."""
    # Lists, tuples (also empty ones) and dicts are reported as sequences.
    sequences = (
        [1, 3, [4, 5]],
        ((7, 8), (5, 6)),
        [],
        {"a": 1, "b": 2},
    )
    for value in sequences:
      self.assertTrue(nest.is_sequence(value))

    # Strings, sets, tensors and numpy arrays are not.
    ones = array_ops.ones([2, 3])
    non_sequences = (
        "1234",
        set([1, 2]),
        ones,
        math_ops.tanh(ones),
        np.ones((4, 5)),
    )
    for value in non_sequences:
      self.assertFalse(nest.is_sequence(value))
  @parameterized.parameters({"mapping_type": _CustomMapping},
                            {"mapping_type": dict})
  def testFlattenDictItems(self, mapping_type):
    """Checks `flatten_dict_items` on nested keys/values and its errors."""
    nested_key = (4, 5, (6, 8))
    nested_value = ("a", "b", ("c", "d"))
    well_formed = mapping_type({nested_key: nested_value})
    self.assertEqual(nest.flatten_dict_items(well_formed),
                     {4: "a", 5: "b", 6: "c", 8: "d"})

    # A non-mapping argument is rejected.
    with self.assertRaises(TypeError):
      nest.flatten_dict_items(4)

    # The flattened key 4 occurs twice.
    duplicate_keys = mapping_type({(4, 5, (4, 8)): nested_value})
    with self.assertRaisesRegexp(ValueError, "not unique"):
      nest.flatten_dict_items(duplicate_keys)

    # The value nests one level deeper than the key.
    mismatched = mapping_type({nested_key: ("a", "b", ("c", ("d", "e")))})
    size_mismatch_pattern = (
        "Key had [0-9]* elements, but value had [0-9]* elements")
    with self.assertRaisesRegexp(ValueError, size_mismatch_pattern):
      nest.flatten_dict_items(mismatched)
  # Namedtuple fixtures for the structure-comparison tests. They vary the
  # type name and the field names independently of each other.
  #
  # pylint does not correctly recognize these as class names and
  # suggests to use variable style under_score naming.
  # pylint: disable=invalid-name
  Named0ab = collections.namedtuple("named_0", ("a", "b"))
  Named1ab = collections.namedtuple("named_1", ("a", "b"))
  SameNameab = collections.namedtuple("same_name", ("a", "b"))
  SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
  SameNamexy = collections.namedtuple("same_name", ("x", "y"))
  SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
  SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
  NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
  # pylint: enable=invalid-name

  class SameNamedType1(SameNameab):
    """Subclass of `SameNameab` that adds nothing of its own."""
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testAssertSameStructure(self):
    """Checks `assert_same_structure` on matching and mismatching inputs."""
    int_leaves = (((1, 2), 3), 4, (5, 6))
    str_leaves = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    two_leaves = ("spam", "eggs")
    other_nesting = (((1, 2), 3), 4, 5, (6,))

    # Start of the detailed mismatch message, shared by the first three
    # error cases below.
    detailed_header = (
        "The two structures don't have the same nested structure\\.\n\n"
        "First structure:.*?\n\n"
        "Second structure:.*\n\n")
    # Shorter pattern used where only the start of the message is checked.
    nesting_mismatch = ("don't have the same nested structure\\.\n\n"
                        "First structure: .*?\n\nSecond structure: ")

    # Same nesting with different leaves; non-sequence leaves always match.
    nest.assert_same_structure(int_leaves, str_leaves)
    for leaf in (1.0, np.array([0, 1]), constant_op.constant([0, 1])):
      nest.assert_same_structure("abc", leaf)

    tuple_vs_str = detailed_header + (
        "More specifically: Substructure "
        r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
        'substructure "type=str str=spam" is not\n'
        "Entire first structure:\n"
        r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
        "Entire second structure:\n"
        r"\(\., \.\)")
    with self.assertRaisesRegexp(ValueError, tuple_vs_str):
      nest.assert_same_structure(int_leaves, two_leaves)

    list_vs_ndarray = detailed_header + (
        r'More specifically: Substructure "type=list str=\[0, 1\]" '
        r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
        "is not")
    with self.assertRaisesRegexp(ValueError, list_vs_ndarray):
      nest.assert_same_structure([0, 1], np.array([0, 1]))

    list_vs_int = detailed_header + (
        r'More specifically: Substructure "type=list str=\[0, 1\]" '
        'is a sequence, while substructure "type=int str=0" '
        "is not")
    with self.assertRaisesRegexp(ValueError, list_vs_int):
      nest.assert_same_structure(0, [0, 1])

    # Tuple versus list is a sequence-type mismatch.
    with self.assertRaises(TypeError):
      nest.assert_same_structure((0, 1), [0, 1])

    with self.assertRaisesRegexp(ValueError, nesting_mismatch):
      nest.assert_same_structure(int_leaves, other_nesting)

    # Plain tuple versus namedtuple.
    with self.assertRaises(TypeError):
      nest.assert_same_structure((0, 1), NestTest.Named0ab("a", "b"))

    nest.assert_same_structure(NestTest.Named0ab(3, 4),
                               NestTest.Named0ab("a", "b"))

    # Namedtuples with different type names.
    with self.assertRaises(TypeError):
      nest.assert_same_structure(NestTest.Named0ab(3, 4),
                                 NestTest.Named1ab(3, 4))

    with self.assertRaisesRegexp(ValueError, nesting_mismatch):
      nest.assert_same_structure(NestTest.Named0ab(3, 4),
                                 NestTest.Named0ab([3], 4))

    with self.assertRaisesRegexp(ValueError, nesting_mismatch):
      nest.assert_same_structure([[3], 4], [3, [4]])

    # Same nesting built from lists instead of tuples.
    int_leaves_as_lists = [[[1, 2], 3], 4, [5, 6]]
    with self.assertRaisesRegexp(TypeError,
                                 "don't have the same sequence type"):
      nest.assert_same_structure(int_leaves, int_leaves_as_lists)
    nest.assert_same_structure(int_leaves, str_leaves, check_types=False)
    nest.assert_same_structure(int_leaves, int_leaves_as_lists,
                               check_types=False)

    with self.assertRaisesRegexp(ValueError,
                                 "don't have the same set of keys"):
      nest.assert_same_structure({"a": 1}, {"b": 1})

    nest.assert_same_structure(NestTest.SameNameab(0, 1),
                               NestTest.SameNameab2(2, 3))

    # This assertion is expected to pass: two namedtuples with the same
    # name and field names are considered to be identical.
    nest.assert_same_structure(
        NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
        NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))

    # The nested namedtuple sits in a different field on each side.
    expected_message = "The two structures don't have the same.*"
    with self.assertRaisesRegexp(ValueError, expected_message):
      nest.assert_same_structure(
          NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
          NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))

    # A different type name, different field names, or a subclass are each
    # treated as a type mismatch.
    same_name_ab = NestTest.SameNameab(0, 1)
    with self.assertRaises(TypeError):
      nest.assert_same_structure(same_name_ab, NestTest.NotSameName(2, 3))

    with self.assertRaises(TypeError):
      nest.assert_same_structure(same_name_ab, NestTest.SameNamexy(2, 3))

    with self.assertRaises(TypeError):
      nest.assert_same_structure(same_name_ab,
                                 NestTest.SameNamedType1(2, 3))
  # Namedtuple fixture with no fields.
  EmptyNT = collections.namedtuple("empty_nt", "")  # pylint: disable=invalid-name

  def testHeterogeneousComparison(self):
    """A dict and a custom mapping with equal keys have the same structure."""
    # Exercise both argument orders.
    plain_first = ({"a": 4}, _CustomMapping(a=3))
    custom_first = (_CustomMapping(b=3), {"b": 4})
    nest.assert_same_structure(*plain_first)
    nest.assert_same_structure(*custom_first)
  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testMapStructure(self):
    """Checks `map_structure` results and its argument validation."""

    def add_one(x):
      return x + 1

    def add(x, y):
      return x + y

    def ignore_one(x):
      del x  # Unused.
      return None

    def ignore_two(x, y):
      del x, y  # Unused.
      return None

    first = (((1, 2), 3), 4, (5, 6))
    second = (((7, 8), 9), 10, (11, 12))

    # Single structure: nesting is kept, every leaf is transformed.
    first_plus_one = nest.map_structure(add_one, first)
    nest.assert_same_structure(first, first_plus_one)
    self.assertAllEqual(
        [2, 3, 4, 5, 6, 7],
        nest.flatten(first_plus_one))

    # Two structures: the function receives the paired leaves.
    pairwise_sum = nest.map_structure(add, first, second)
    self.assertEqual(
        (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
        pairwise_sum)

    # Scalars are handled as single leaves.
    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
    self.assertEqual(7, nest.map_structure(add, 3, 4))

    # Empty structures
    self.assertEqual((), nest.map_structure(add_one, ()))
    self.assertEqual([], nest.map_structure(add_one, []))
    self.assertEqual({}, nest.map_structure(add_one, {}))
    self.assertEqual(NestTest.EmptyNT(),
                     nest.map_structure(add_one, NestTest.EmptyNT()))

    # This is checking actual equality of types, empty list != empty tuple
    self.assertNotEqual((), nest.map_structure(add_one, []))

    with self.assertRaisesRegexp(TypeError, "callable"):
      nest.map_structure("bad", first_plus_one)

    with self.assertRaisesRegexp(ValueError, "at least one structure"):
      nest.map_structure(lambda x: x)

    with self.assertRaisesRegexp(ValueError, "same number of elements"):
      nest.map_structure(ignore_two, (3, 4), (3, 4, 5))

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(ignore_two, 3, (3,))

    with self.assertRaisesRegexp(TypeError, "same sequence type"):
      nest.map_structure(ignore_two, ((3, 4), 5), [(3, 4), 5])

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(ignore_two, ((3, 4), 5), (3, (4, 5)))

    # List-based twin of `first`: rejected unless type checking is disabled.
    first_as_lists = [[[1, 2], 3], 4, [5, 6]]
    with self.assertRaisesRegexp(TypeError, "same sequence type"):
      nest.map_structure(ignore_two, first, first_as_lists)

    nest.map_structure(ignore_two, first, first_as_lists, check_types=False)

    # Disabling type checks does not disable the nesting check.
    with self.assertRaisesRegexp(ValueError, "same nested structure"):
      nest.map_structure(ignore_two, ((3, 4), 5), (3, (4, 5)),
                         check_types=False)

    # Unknown keyword arguments are rejected.
    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      nest.map_structure(ignore_one, first, foo="a")

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      nest.map_structure(ignore_one, first, check_types=False, foo="a")
  # Namedtuple fixture with fields `a` and `b`.
  ABTuple = collections.namedtuple("ab_tuple", "a, b")  # pylint: disable=invalid-name

  @test_util.assert_no_new_pyobjects_executing_eagerly
  def testMapStructureWithStrings(self):
    """Maps over string leaves, which must be treated as atoms."""

    def repeat(string, repeats):
      return string * repeats

    def reverse(string):
      return string[::-1]

    strings = NestTest.ABTuple(a="foo", b=("bar", "baz"))
    counts = NestTest.ABTuple(a=2, b=(1, 3))
    repeated = nest.map_structure(repeat, strings, counts)
    self.assertEqual("foofoo", repeated.a)
    self.assertEqual("bar", repeated.b[0])
    self.assertEqual("bazbazbaz", repeated.b[1])

    original = NestTest.ABTuple(a=("something", "something_else"),
                                b="yet another thing")
    reversed_nt = nest.map_structure(reverse, original)
    # Check the output is the correct structure, and all strings are reversed.
    nest.assert_same_structure(original, reversed_nt)
    self.assertEqual(original.a[0][::-1], reversed_nt.a[0])
    self.assertEqual(original.a[1][::-1], reversed_nt.a[1])
    self.assertEqual(original.b[::-1], reversed_nt.b)
  @test_util.run_deprecated_v1
  def testMapStructureOverPlaceholders(self):
    """Adds two tuples of placeholders leaf-wise and evaluates the result."""

    def make_placeholder_pair():
      # One float32 placeholder of shape [3, 4] and one of shape [3, 7].
      return (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
              array_ops.placeholder(dtypes.float32, shape=[3, 7]))

    inp_a = make_placeholder_pair()
    inp_b = make_placeholder_pair()

    summed = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)

    # Static checks: same nesting as the inputs and the expected shapes.
    nest.assert_same_structure(summed, inp_a)
    self.assertShapeEqual(np.zeros((3, 4)), summed[0])
    self.assertShapeEqual(np.zeros((3, 7)), summed[1])

    # The placeholder tuples themselves are used as feed keys.
    values_a = (np.random.randn(3, 4), np.random.randn(3, 7))
    values_b = (np.random.randn(3, 4), np.random.randn(3, 7))
    feed_dict = {inp_a: values_a, inp_b: values_b}

    with self.cached_session() as sess:
      summed_np = sess.run(summed, feed_dict=feed_dict)
    self.assertAllClose(summed_np[0],
                        feed_dict[inp_a][0] + feed_dict[inp_b][0])
    self.assertAllClose(summed_np[1],
                        feed_dict[inp_a][1] + feed_dict[inp_b][1])
  def testAssertShallowStructure(self):
    """Checks `assert_shallow_structure` errors and accepted inputs."""
    # Length mismatch between the two structures.
    two_items = ["a", "b"]
    three_items = ["a", "b", "c"]
    length_message = (
        "The two structures don't have the same sequence length. Input "
        "structure has length 2, while shallow structure has length 3.")
    with self.assertRaisesRegexp(ValueError, length_message):
      nest.assert_shallow_structure(three_items, two_items)

    # Sequence type mismatch, accepted once type checking is disabled.
    tuples_inside = [(1, 1), (2, 2)]
    lists_inside = [[1, 1], [2, 2]]
    type_message = (
        "The two structures don't have the same sequence type. Input structure "
        "has type <(type|class) 'tuple'>, while shallow structure has type "
        "<(type|class) 'list'>.")
    with self.assertRaisesRegexp(TypeError, type_message):
      nest.assert_shallow_structure(lists_inside, tuples_inside)
    nest.assert_shallow_structure(lists_inside, tuples_inside,
                                  check_types=False)

    # Key mismatch inside a nested dict.
    nested_key_c = {"a": (1, 1), "b": {"c": (2, 2)}}
    nested_key_d = {"a": (1, 1), "b": {"d": (2, 2)}}
    keys_message = (
        r"The two structures don't have the same keys. Input "
        r"structure has keys \['c'\], while shallow structure has "
        r"keys \['d'\].")
    with self.assertRaisesRegexp(ValueError, keys_message):
      nest.assert_shallow_structure(nested_key_d, nested_key_c)

    # OrderedDicts holding the same items in a different insertion order.
    ordered_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
    ordered_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
    nest.assert_shallow_structure(ordered_ab, ordered_ba)

    # This assertion is expected to pass: two namedtuples with the same
    # name and field names are considered to be identical.
    shallow_nt = NestTest.SameNameab(1, 2)
    deep_nt = NestTest.SameNameab2(1, [1, 2, 3])
    for check_types in (False, True):
      nest.assert_shallow_structure(shallow_nt, deep_nt,
                                    check_types=check_types)
  def testFlattenUpTo(self):
    """Checks `flatten_up_to` on nested inputs and on non-sequence inputs."""
    # Shallow tree ends at scalar.
    deep = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
    shallow = [[True, True], [False, True]]
    deep_flat = nest.flatten_up_to(shallow, deep)
    shallow_flat = nest.flatten_up_to(shallow, shallow)
    self.assertEqual(deep_flat, [[2, 2], [3, 3], [4, 9], [5, 5]])
    self.assertEqual(shallow_flat, [True, True, False, True])

    # Shallow tree ends at string.
    deep = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
    shallow = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
    partially_flat = nest.flatten_up_to(shallow, deep)
    fully_flat = nest.flatten(deep)
    self.assertEqual(partially_flat,
                     [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
    self.assertEqual(fully_flat, ["a", 1, "b", 2, "c", 3, "d", 4])

    # Make sure dicts are correctly flattened, yielding values, not keys.
    deep = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
    shallow = {"a": 0, "b": 0, "d": [0, 0]}
    partially_flat = nest.flatten_up_to(shallow, deep)
    self.assertEqual(partially_flat, [1, {"c": 2}, 3, (4, 5)])

    # Namedtuples.
    ab_tuple = NestTest.ABTuple
    deep = ab_tuple(a=[0, 1], b=2)
    shallow = ab_tuple(a=0, b=1)
    partially_flat = nest.flatten_up_to(shallow, deep)
    self.assertEqual(partially_flat, [[0, 1], 2])

    # Nested dicts, OrderedDicts and namedtuples.
    deep = collections.OrderedDict(
        [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
         ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
    # Using the input as its own shallow tree flattens it completely.
    partially_flat = nest.flatten_up_to(deep, deep)
    self.assertEqual(partially_flat, [0, 1, 2, 3, 4])

    shallow = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
    partially_flat = nest.flatten_up_to(shallow, deep)
    self.assertEqual(partially_flat,
                     [ab_tuple(a=[0, {"b": 1}], b=2),
                      3,
                      collections.OrderedDict([("f", 4)])])

    shallow = collections.OrderedDict([("a", 0), ("c", 0)])
    partially_flat = nest.flatten_up_to(shallow, deep)
    self.assertEqual(partially_flat,
                     [ab_tuple(a=[0, {"b": 1}], b=2),
                      {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

    # With a non-sequence shallow tree, both trees come back as a single
    # element. Each entry is (input_tree, shallow_tree).
    single_leaf_cases = (
        ## Shallow non-list edge-case.
        # Using iterable elements.
        (["input_tree"], "shallow_tree"),
        (["input_tree_0", "input_tree_1"], "shallow_tree"),
        # Using non-iterable elements.
        ([0], 9),
        ([0, 1], 9),
        ## Both non-list edge-case.
        # Using iterable elements.
        ("input_tree", "shallow_tree"),
        # Using non-iterable elements.
        (0, 0),
    )
    for input_tree, shallow_tree in single_leaf_cases:
      flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
      flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
      self.assertEqual(flattened_input_tree, [input_tree])
      self.assertEqual(flattened_shallow_tree, [shallow_tree])

    ## Input non-list edge-case.
    # A sequence shallow tree with a non-sequence input raises TypeError;
    # the shallow tree still flattens against itself.
    str_input_message = (
        "If shallow structure is a sequence, input must also "
        "be a sequence. Input has type: <(type|class) 'str'>.")
    int_input_message = (
        "If shallow structure is a sequence, input must also "
        "be a sequence. Input has type: <(type|class) 'int'>.")
    # Each entry is (input_tree, shallow_tree, expected_message).
    rejected_cases = (
        # Using iterable elements.
        ("input_tree", ["shallow_tree"], str_input_message),
        ("input_tree", ["shallow_tree_9", "shallow_tree_8"],
         str_input_message),
        # Using non-iterable elements.
        (0, [9], int_input_message),
        (0, [9, 8], int_input_message),
    )
    for input_tree, shallow_tree, expected_message in rejected_cases:
      with self.assertRaisesRegexp(TypeError, expected_message):
        nest.flatten_up_to(shallow_tree, input_tree)
      flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
      self.assertEqual(flattened_shallow_tree, shallow_tree)
  def testMapStructureUpTo(self):
    """Checks `nest.map_structure_up_to` on namedtuples, lists and mappings."""

    def apply_item_ops(val, ops):
      # Mapping flavour of the arithmetic: `ops` is indexed by key.
      return (val + ops["add"]) * ops["mul"]

    # Named tuples.
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    values = ab_tuple(a=2, b=3)
    operations = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    mapped = nest.map_structure_up_to(
        values, lambda val, ops: (val + ops.add) * ops.mul, values, operations)
    self.assertEqual(mapped.a, 6)
    self.assertEqual(mapped.b, 15)

    # Lists.
    sections = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
    labels = ["evens", ["odds", "primes"]]

    def describe(name, sec):
      return "first_{}_{}".format(len(sec), name)

    mapped = nest.map_structure_up_to(labels, describe, labels, sections)
    self.assertEqual(
        mapped, ["first_4_evens", ["first_5_odds", "first_3_primes"]])

    # Dicts.
    values = {"a": 2, "b": 3}
    operations = {"a": {"add": 1, "mul": 2}, "b": {"add": 2, "mul": 3}}
    mapped = nest.map_structure_up_to(
        values, apply_item_ops, values, operations)
    self.assertEqual(mapped["a"], 6)
    self.assertEqual(mapped["b"], 15)

    # Non-equal dicts: the key sets differ ("b" vs. "c").
    values = {"a": 2, "b": 3}
    operations = {"a": {"add": 1, "mul": 2}, "c": {"add": 2, "mul": 3}}
    with self.assertRaisesRegexp(ValueError, "same keys"):
      nest.map_structure_up_to(values, apply_item_ops, values, operations)

    # Dict+custom mapping.
    values = {"a": 2, "b": 3}
    operations = _CustomMapping(
        a={"add": 1, "mul": 2}, b={"add": 2, "mul": 3})
    mapped = nest.map_structure_up_to(
        values, apply_item_ops, values, operations)
    self.assertEqual(mapped["a"], 6)
    self.assertEqual(mapped["b"], 15)

    # Non-equal dict/mapping: the key sets differ ("b" vs. "c").
    values = {"a": 2, "b": 3}
    operations = _CustomMapping(
        a={"add": 1, "mul": 2}, c={"add": 2, "mul": 3})
    with self.assertRaisesRegexp(ValueError, "same keys"):
      nest.map_structure_up_to(values, apply_item_ops, values, operations)
  def testGetTraverseShallowStructure(self):
    """Exercises `nest.get_traverse_shallow_structure` and its error paths."""

    # Traverse function that answers with a single bool per substructure.
    def is_not_tuple(s):
      return not isinstance(s, tuple)

    scalar_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
    scalar_result = nest.get_traverse_shallow_structure(
        is_not_tuple, scalar_input)
    expected_scalar = [True, True, False, [True, True], {"a": False}, []]
    self.assertEqual(scalar_result, expected_scalar)
    nest.assert_shallow_structure(scalar_result, scalar_input)

    # Traverse function that answers with a structure of bools for tuples.
    def tuple_to_pair(s):
      if isinstance(s, tuple):
        return (True, False)
      return True

    structure_input = [(1, [2]), ([1], 2)]
    structure_result = nest.get_traverse_shallow_structure(
        tuple_to_pair, structure_input)
    expected_structure = [(True, False), ([True], False)]
    self.assertEqual(structure_result, expected_structure)
    nest.assert_shallow_structure(structure_result, structure_input)

    # Invalid traverse-function return values must raise TypeError.
    with self.assertRaisesRegexp(TypeError, "returned structure"):
      nest.get_traverse_shallow_structure(lambda _: [True], 0)

    with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
      nest.get_traverse_shallow_structure(lambda _: 1, [1])

    not_bools_message = "didn't return a depth=1 structure of bools"
    with self.assertRaisesRegexp(TypeError, not_bools_message):
      nest.get_traverse_shallow_structure(lambda _: [1], [1])
  def testYieldFlatStringPaths(self):
    """Compares `nest.yield_flat_paths` output with hand-written paths."""
    # Each case pairs an input structure with the list of paths expected for
    # it, in order.
    cases = (
        {"inputs": [], "expected": []},
        {"inputs": 3, "expected": [()]},
        {"inputs": [3], "expected": [(0,)]},
        {"inputs": {"a": 3}, "expected": [("a",)]},
        {"inputs": {"a": {"b": 4}}, "expected": [("a", "b")]},
        {"inputs": [{"a": 2}], "expected": [(0, "a")]},
        {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
        {
            "inputs": [{"a": [(23, 42)]}],
            "expected": [(0, "a", 0, 0), (0, "a", 0, 1)],
        },
        {
            "inputs": [{"a": ([23], 42)}],
            "expected": [(0, "a", 0, 0), (0, "a", 1)],
        },
        {
            "inputs": {"a": {"a": 2}, "c": [[[4]]]},
            "expected": [("a", "a"), ("c", 0, 0, 0)],
        },
        {"inputs": {"0": [{"1": 23}]}, "expected": [("0", 0, "1")]},
    )
    for case in cases:
      actual_paths = list(nest.yield_flat_paths(case["inputs"]))
      self.assertEqual(actual_paths, case["expected"])
  def testFlattenWithStringPaths(self):
    """Checks `nest.flatten_with_joined_string_paths` with a "/" separator."""
    cases = (
        {"inputs": [], "expected": []},
        {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
        {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]},
    )
    for case in cases:
      flattened = nest.flatten_with_joined_string_paths(
          case["inputs"], separator="/")
      self.assertEqual(flattened, case["expected"])
  # Namedtuple cases live in their own test because the tuple classes cannot
  # be declared inside the @parameterized arguments.
  def testFlattenNamedTuple(self):
    """Checks joined string paths produced for (nested) namedtuples."""
    # pylint: disable=invalid-name
    Foo = collections.namedtuple("Foo", ["a", "b"])
    Bar = collections.namedtuple("Bar", ["c", "d"])
    # pylint: enable=invalid-name

    # Pairs of (structure, expected list of (joined_path, leaf) tuples).
    cases = (
        (
            Foo(a=3, b=Bar(c=23, d=42)),
            [("a", 3), ("b/c", 23), ("b/d", 42)],
        ),
        (
            Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
            [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")],
        ),
        (
            Bar(c=42, d=43),
            [("c", 42), ("d", 43)],
        ),
        (
            Bar(c=[42], d=43),
            [("c/0", 42), ("d", 43)],
        ),
    )
    for structure, expected_pairs in cases:
      flattened = list(nest.flatten_with_joined_string_paths(structure))
      self.assertEqual(flattened, expected_pairs)
  @parameterized.named_parameters(
      ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
      ("dicts",
       {"a": 1, "b": 2},
       {"b": 4, "a": 3},
       True,
       {"a": ("a", 4), "b": ("b", 6)}),
      ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
      ("nested",
       {"a": [2, 3], "b": [1, 2, 3]},
       {"b": [5, 6, 7], "a": [8, 9]},
       True,
       {"a": [("a/0", 10), ("a/1", 12)],
        "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
  def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
    """Maps a (path, sum) function over two structures and checks the result.

    Args:
      s1: First input structure.
      s2: Second input structure.
      check_types: Forwarded to `nest.map_structure_with_paths`.
      expected: Structure the mapped result must equal.
    """

    def path_and_total(path, *leaves):
      # Pair the path with the sum of the leaves passed alongside it.
      return (path, sum(leaves))

    mapped = nest.map_structure_with_paths(
        path_and_total, s1, s2, check_types=check_types)
    self.assertEqual(expected, mapped)
  @parameterized.named_parameters(
      ("tuples", (1, 2), (3, 4, 5), ValueError),
      ("dicts", {"a": 1}, {"b": 2}, ValueError),
      ("mixed", (1, 2), [3, 4], TypeError),
      ("nested",
       {"a": [2, 3], "b": [1, 3]},
       {"b": [5, 6, 7], "a": [8, 9]},
       ValueError))
  def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
    """Expects `nest.map_structure_with_paths` to raise for mismatched inputs.

    Args:
      s1: First input structure.
      s2: Second input structure, incompatible with `s1`.
      error_type: Exception class the call is expected to raise.
    """

    def always_zero(path, *s):
      # The mapped value is irrelevant; only the raised error is checked.
      return 0

    with self.assertRaises(error_type):
      nest.map_structure_with_paths(always_zero, s1, s2)
  class NestBenchmark(test.Benchmark):
    """Wall-time benchmarks for `nest.assert_same_structure`."""

    def run_and_report(self, s1, s2, name):
      """Times repeated `nest.assert_same_structure(s1, s2)` calls.

      Args:
        s1: First structure passed to `nest.assert_same_structure`.
        s2: Second structure passed to `nest.assert_same_structure`.
        name: Name under which the benchmark result is reported.
      """
      burn_iter = 100
      test_iter = 30000

      # Untimed warm-up iterations.
      for _ in xrange(burn_iter):
        nest.assert_same_structure(s1, s2)

      start = time.time()
      for _ in xrange(test_iter):
        nest.assert_same_structure(s1, s2)
      elapsed = time.time() - start

      # Report the mean wall time of a single call.
      self.report_benchmark(
          iters=test_iter, wall_time=elapsed / test_iter, name=name)

    def benchmark_assert_structure(self):
      """Benchmarks matching structures with 6 and with 60 leaves."""
      numbers = (((1, 2), 3), 4, (5, 6))
      strings = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
      self.run_and_report(numbers, strings, "assert_same_structure_6_elem")

      # Repeating the outer tuple ten times gives the 60-leaf variant.
      self.run_and_report(
          numbers * 10, strings * 10, "assert_same_structure_60_elem")
  # Script entry point: hand control to `test.main()`.
  if __name__ == "__main__":
    test.main()

TensorFlow 框架的 .NET 版本,提供了丰富的特性和 API,可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。