You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

NestTest.cs 45 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. using System.Collections;
  2. using System.Collections.Generic;
  3. using Colorful;
  4. using Microsoft.VisualStudio.TestTools.UnitTesting;
  5. using Newtonsoft.Json.Linq;
  6. using NumSharp;
  7. using Tensorflow;
  8. using Tensorflow.Util;
  9. namespace TensorFlowNET.UnitTest.nest_test
  10. {
  11. /// <summary>
  12. /// excerpt of tensorflow/python/framework/util/nest_test.py
  13. /// </summary>
  14. [TestClass]
  15. public class NestTest : PythonTest
  16. {
  17. //public class PointXY
  18. //{
  19. // public double x;
  20. // public double y;
  21. //}
  22. // if attr:
  23. // class BadAttr(object):
  24. // """Class that has a non-iterable __attrs_attrs__."""
  25. // __attrs_attrs__ = None
  26. // @attr.s
  27. // class SampleAttr(object):
  28. // field1 = attr.ib()
  29. // field2 = attr.ib()
  30. // @test_util.assert_no_new_pyobjects_executing_eagerly
  31. // def testAttrsFlattenAndPack(self) :
  32. // if attr is None:
  33. // self.skipTest("attr module is unavailable.")
  34. // field_values = [1, 2]
  35. // sample_attr = NestTest.SampleAttr(* field_values)
  36. // self.assertFalse(nest._is_attrs(field_values))
  37. // self.assertTrue(nest._is_attrs(sample_attr))
  38. // flat = nest.flatten(sample_attr)
  39. // self.assertEqual(field_values, flat)
  40. // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
  41. // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
  42. // self.assertEqual(restructured_from_flat, sample_attr)
  43. //# Check that flatten fails if attributes are not iterable
  44. // with self.assertRaisesRegexp(TypeError, "object is not iterable"):
  45. // flat = nest.flatten(NestTest.BadAttr())
  /// <summary>
  /// Port of <c>testFlattenAndPack</c>: <c>nest.flatten</c> lists the leaves of a nested
  /// <c>object[]</c> / <see cref="Hashtable"/> structure in depth-first order, and
  /// <c>nest.pack_sequence_as</c> rebuilds that same shape from a flat list of values.
  /// </summary>
  [TestMethod]
  public void testFlattenAndPack()
  {
      // Nested arrays: depth-first flatten order and a round trip through pack_sequence_as.
      object structure = new object[] { new object[] { 3, 4 }, 5, new object[] { 6, 7, new object[] { 9, 10 }, 8 } };
      var flat = new List<object> { "a", "b", "c", "d", "e", "f", "g", "h" };
      self.assertEqual(nest.flatten(structure), new[] { 3, 4, 5, 6, 7, 9, 10, 8 });
      // The packed result is a freshly built nested object[], so compare its JSON rendering.
      self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(),
          JArray.FromObject(new object[] { new object[] { "a", "b" }, "c", new object[] { "d", "e", new object[] { "f", "g" }, "h" } }).ToString());

      // Hashtables with "x"/"y" keys play the role of the Python test's Point namedtuple.
      structure = new object[] { new Hashtable { ["x"] = 4, ["y"] = 2 }, new object[] { new object[] { new Hashtable { ["x"] = 1, ["y"] = 0 }, }, } };
      flat = new List<object> { 4, 2, 1, 0 };
      self.assertEqual(nest.flatten(structure), flat);

      // Direct casts (not `as`): an unexpected shape fails here with an InvalidCastException
      // instead of a NullReferenceException on the indexing that follows.
      var restructured_from_flat = (object[])nest.pack_sequence_as(structure, flat);
      self.assertEqual(restructured_from_flat, structure);
      var outerPoint = (Hashtable)restructured_from_flat[0];
      self.assertEqual(outerPoint["x"], 4);
      self.assertEqual(outerPoint["y"], 2);
      var innerPoint = (Hashtable)((object[])((object[])restructured_from_flat[1])[0])[0];
      self.assertEqual(innerPoint["x"], 1);
      self.assertEqual(innerPoint["y"], 0);

      // Scalars and numpy arrays are leaves: each flattens to a one-element list.
      self.assertEqual(new List<object> { 5 }, nest.flatten(5));
      flat = nest.flatten(np.array(new[] { 5 }));
      self.assertEqual(new object[] { np.array(new int[] { 5 }) }, flat);

      // Packing a single value into a scalar template returns that value unchanged.
      self.assertEqual("a", nest.pack_sequence_as(5, new List<object> { "a" }));
      self.assertEqual(np.array(new[] { 5 }),
          nest.pack_sequence_as("scalar", new List<object> { np.array(new[] { 5 }) }));

      // Length mismatches between structure and flat sequence are rejected.
      Assert.ThrowsException<ValueError>(() => nest.pack_sequence_as("scalar", new List<object>() { 4, 5 }));
      Assert.ThrowsException<ValueError>(() =>
          nest.pack_sequence_as(new object[] { 5, 6, new object[] { 7, 8 } }, new List<object> { "a", "b", "c" }));
  }
  74. // @parameterized.parameters({"mapping_type": collections.OrderedDict
  75. // },
  76. // {"mapping_type": _CustomMapping
  77. //})
  78. // @test_util.assert_no_new_pyobjects_executing_eagerly
  79. // def testFlattenDictOrder(self, mapping_type) :
  80. // """`flatten` orders dicts by key, including OrderedDicts."""
  81. // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
  82. // plain = {"d": 3, "b": 1, "a": 0, "c": 2}
  83. // ordered_flat = nest.flatten(ordered)
  84. // plain_flat = nest.flatten(plain)
  85. // self.assertEqual([0, 1, 2, 3], ordered_flat)
  86. // self.assertEqual([0, 1, 2, 3], plain_flat)
  87. // @parameterized.parameters({"mapping_type": collections.OrderedDict},
  88. // {"mapping_type": _CustomMapping})
  89. // def testPackDictOrder(self, mapping_type):
  90. // """Packing orders dicts by key, including OrderedDicts."""
  91. // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
  92. // plain = {"d": 0, "b": 0, "a": 0, "c": 0}
  93. // seq = [0, 1, 2, 3]
  94. //custom_reconstruction = nest.pack_sequence_as(custom, seq)
  95. //plain_reconstruction = nest.pack_sequence_as(plain, seq)
  96. // self.assertIsInstance(custom_reconstruction, mapping_type)
  97. // self.assertIsInstance(plain_reconstruction, dict)
  98. // self.assertEqual(
  99. // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
  100. // custom_reconstruction)
  101. // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
  102. // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
  103. // @test_util.assert_no_new_pyobjects_executing_eagerly
  104. // def testFlattenAndPack_withDicts(self) :
  105. // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
  106. // mess = [
  107. // "z",
  108. // NestTest.Abc(3, 4), {
  109. // "d": _CustomMapping({
  110. // 41: 4
  111. // }),
  112. // "c": [
  113. // 1,
  114. // collections.OrderedDict([
  115. // ("b", 3),
  116. // ("a", 2),
  117. // ]),
  118. // ],
  119. // "b": 5
  120. // }, 17
  121. // ]
  122. // flattened = nest.flatten(mess)
  123. // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
  124. // structure_of_mess = [
  125. // 14,
  126. // NestTest.Abc("a", True),
  127. // {
  128. // "d": _CustomMapping({
  129. // 41: 42
  130. // }),
  131. // "c": [
  132. // 0,
  133. // collections.OrderedDict([
  134. // ("b", 9),
  135. // ("a", 8),
  136. // ]),
  137. // ],
  138. // "b": 3
  139. // },
  140. // "hi everybody",
  141. // ]
  142. // unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
  143. // self.assertEqual(unflattened, mess)
  144. // # Check also that the OrderedDict was created, with the correct key order.
  145. //unflattened_ordered_dict = unflattened[2]["c"][1]
  146. // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
  147. // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
  148. // unflattened_custom_mapping = unflattened[2]["d"]
  149. // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
  150. // self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
  /// <summary>
  /// A numpy array is treated as a single leaf by <c>nest.flatten</c>:
  /// flattening it yields exactly one element, not one per array entry.
  /// </summary>
  [TestMethod]
  public void testFlatten_numpyIsNotFlattened()
  {
      var leaves = nest.flatten(np.array(1, 2, 3));
      self.assertEqual(len(leaves), 1);
  }
  /// <summary>
  /// A string is a single leaf for <c>nest.flatten</c>. Packing that one leaf
  /// into any scalar template (here another string) gives the original string back.
  /// </summary>
  [TestMethod]
  public void testFlatten_stringIsNotFlattened()
  {
      const string letters = "lots of letters";
      var leaves = nest.flatten(letters);
      self.assertEqual(len(leaves), 1);
      self.assertEqual(letters, nest.pack_sequence_as("goodbye", leaves));
  }
  167. // def testPackSequenceAs_notIterableError(self) :
  168. // with self.assertRaisesRegexp(TypeError,
  169. // "flat_sequence must be a sequence"):
  170. // nest.pack_sequence_as("hi", "bye")
  /// <summary>
  /// <c>nest.pack_sequence_as</c> must throw <c>ValueError</c> when the flat sequence
  /// (3 values) does not match the number of leaves in the structure (2).
  /// </summary>
  [TestMethod]
  public void testPackSequenceAs_wrongLengthsError()
  {
      var twoLeaves = new object[] { "hello", "world" };
      var threeValues = new object[] { "and", "goodbye", "again" };
      // The Python original additionally matches the message:
      // "Structure had 2 elements, but flat_sequence had 3 elements."
      Assert.ThrowsException<ValueError>(() =>
      {
          nest.pack_sequence_as(twoLeaves, threeValues);
      });
  }
  /// <summary>
  /// Checks which values <c>nest.is_sequence</c> treats as nestable containers.
  /// Arrays (nested or empty) and Hashtables count as containers. Strings, sets,
  /// tensors and numpy arrays do not.
  /// </summary>
  [TestMethod]
  public void testIsSequence()
  {
      // A string is enumerable in C#, but nest treats it as a leaf.
      self.assertFalse(nest.is_sequence("1234"));

      var nested = new object[] { 1, 3, new object[] { 4, 5 } };
      self.assertTrue(nest.is_sequence(nested));
      // TODO: ValueTuple<T,T>
      //self.assertTrue(nest.is_sequence(((7, 8), (5, 6))));
      self.assertTrue(nest.is_sequence(new object[0]));
      var mapping = new Hashtable { ["a"] = 1, ["b"] = 2 };
      self.assertTrue(nest.is_sequence(mapping));

      self.assertFalse(nest.is_sequence(new HashSet<int> { 1, 2 }));
      var onesTensor = array_ops.ones(new int[] { 2, 3 });
      self.assertFalse(nest.is_sequence(onesTensor));
      self.assertFalse(nest.is_sequence(gen_math_ops.tanh(onesTensor)));
      self.assertFalse(nest.is_sequence(np.ones(new int[] { 4, 5 })));
  }
  197. // @parameterized.parameters({"mapping_type": _CustomMapping},
  198. // {"mapping_type": dict})
  199. // def testFlattenDictItems(self, mapping_type):
  200. // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))})
  201. // flat = {4: "a", 5: "b", 6: "c", 8: "d"}
  202. // self.assertEqual(nest.flatten_dict_items(dictionary), flat)
  203. // with self.assertRaises(TypeError):
  204. // nest.flatten_dict_items(4)
  205. // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))})
  206. // with self.assertRaisesRegexp(ValueError, "not unique"):
  207. // nest.flatten_dict_items(bad_dictionary)
  208. // another_bad_dictionary = mapping_type({
  209. // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
  210. // })
  211. // with self.assertRaisesRegexp(
  212. // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
  213. // nest.flatten_dict_items(another_bad_dictionary)
  214. //# pylint does not correctly recognize these as class names and
  215. //# suggests to use variable style under_score naming.
  216. //# pylint: disable=invalid-name
  217. // Named0ab = collections.namedtuple("named_0", ("a", "b"))
  218. // Named1ab = collections.namedtuple("named_1", ("a", "b"))
  219. // SameNameab = collections.namedtuple("same_name", ("a", "b"))
  220. // SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
  221. // SameNamexy = collections.namedtuple("same_name", ("x", "y"))
  222. // SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
  223. // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
  224. // NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
  225. // # pylint: enable=invalid-name
  226. // class SameNamedType1(SameNameab):
  227. // pass
  228. // @test_util.assert_no_new_pyobjects_executing_eagerly
  229. // def testAssertSameStructure(self):
  230. // structure1 = (((1, 2), 3), 4, (5, 6))
  231. // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  232. // structure_different_num_elements = ("spam", "eggs")
  233. // structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
  234. // nest.assert_same_structure(structure1, structure2)
  235. // nest.assert_same_structure("abc", 1.0)
  236. // nest.assert_same_structure("abc", np.array([0, 1]))
  237. // nest.assert_same_structure("abc", constant_op.constant([0, 1]))
  238. // with self.assertRaisesRegexp(
  239. // ValueError,
  240. // ("The two structures don't have the same nested structure\\.\n\n"
  241. // "First structure:.*?\n\n"
  242. // "Second structure:.*\n\n"
  243. // "More specifically: Substructure "
  244. // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
  245. // 'substructure "type=str str=spam" is not\n'
  246. // "Entire first structure:\n"
  247. // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
  248. // "Entire second structure:\n"
  249. // r"\(\., \.\)")):
  250. // nest.assert_same_structure(structure1, structure_different_num_elements)
  251. // with self.assertRaisesRegexp(
  252. // ValueError,
  253. // ("The two structures don't have the same nested structure\\.\n\n"
  254. // "First structure:.*?\n\n"
  255. // "Second structure:.*\n\n"
  256. // r'More specifically: Substructure "type=list str=\[0, 1\]" '
  257. // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
  258. // "is not")):
  259. // nest.assert_same_structure([0, 1], np.array([0, 1]))
  260. // with self.assertRaisesRegexp(
  261. // ValueError,
  262. // ("The two structures don't have the same nested structure\\.\n\n"
  263. // "First structure:.*?\n\n"
  264. // "Second structure:.*\n\n"
  265. // r'More specifically: Substructure "type=list str=\[0, 1\]" '
  266. // 'is a sequence, while substructure "type=int str=0" '
  267. // "is not")):
  268. // nest.assert_same_structure(0, [0, 1])
  269. // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
  270. // with self.assertRaisesRegexp(
  271. // ValueError,
  272. // ("don't have the same nested structure\\.\n\n"
  273. // "First structure: .*?\n\nSecond structure: ")):
  274. // nest.assert_same_structure(structure1, structure_different_nesting)
  275. // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
  276. // NestTest.Named0ab("a", "b"))
  277. // nest.assert_same_structure(NestTest.Named0ab(3, 4),
  278. // NestTest.Named0ab("a", "b"))
  279. // self.assertRaises(TypeError, nest.assert_same_structure,
  280. // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
  281. // with self.assertRaisesRegexp(
  282. // ValueError,
  283. // ("don't have the same nested structure\\.\n\n"
  284. // "First structure: .*?\n\nSecond structure: ")):
  285. // nest.assert_same_structure(NestTest.Named0ab(3, 4),
  286. // NestTest.Named0ab([3], 4))
  287. // with self.assertRaisesRegexp(
  288. // ValueError,
  289. // ("don't have the same nested structure\\.\n\n"
  290. // "First structure: .*?\n\nSecond structure: ")):
  291. // nest.assert_same_structure([[3], 4], [3, [4]])
  292. // structure1_list = [[[1, 2], 3], 4, [5, 6]]
  293. // with self.assertRaisesRegexp(TypeError,
  294. // "don't have the same sequence type"):
  295. // nest.assert_same_structure(structure1, structure1_list)
  296. // nest.assert_same_structure(structure1, structure2, check_types= False)
  297. // nest.assert_same_structure(structure1, structure1_list, check_types=False)
  298. // with self.assertRaisesRegexp(ValueError,
  299. // "don't have the same set of keys"):
  300. // nest.assert_same_structure({"a": 1}, {"b": 1})
  301. // nest.assert_same_structure(NestTest.SameNameab(0, 1),
  302. // NestTest.SameNameab2(2, 3))
  303. // # This assertion is expected to pass: two namedtuples with the same
  304. // # name and field names are considered to be identical.
  305. // nest.assert_same_structure(
  306. // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
  307. // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
  308. // expected_message = "The two structures don't have the same.*"
  309. // with self.assertRaisesRegexp(ValueError, expected_message):
  310. // nest.assert_same_structure(
  311. // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
  312. // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
  313. // self.assertRaises(TypeError, nest.assert_same_structure,
  314. // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
  315. // self.assertRaises(TypeError, nest.assert_same_structure,
  316. // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
  317. // self.assertRaises(TypeError, nest.assert_same_structure,
  318. // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
  319. // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
  320. // def testHeterogeneousComparison(self):
  321. // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3))
  322. // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
  /// <summary>
  /// Port of <c>testMapStructure</c>: <c>nest.map_structure</c> applies a function to
  /// every leaf of a nested <c>object[]</c> structure. Flattening the result yields the
  /// mapped leaves in depth-first order.
  /// </summary>
  [TestMethod]
  public void testMapStructure()
  {
      var structure1 = new object[] { new object[] { new object[] { 1, 2 }, 3 }, 4, new object[] { 5, 6 } };
      // Only referenced by the two-structure cases that are still pending port (commented out below).
      var structure2 = new object[] { new object[] { new object[] { 7, 8 }, 9 }, 10, new object[] { 11, 12 } };

      var structure1_plus1 = nest.map_structure(x => (int)x + 1, structure1);
      var structure1_strings = nest.map_structure(x => $"{x}", structure1);

      // Not yet ported: nest.assert_same_structure(structure1, structure1_plus1)
      self.assertAllEqual(nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 });
      self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" });

      // Remaining Python cases from nest_test.py, not yet ported:
      // structure1_plus_structure2 = nest.map_structure(
      //     lambda x, y: x + y, structure1, structure2)
      // self.assertEqual(
      //     (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
      //     structure1_plus_structure2)
      // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
      // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
      // # Empty structures
      // self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
      // self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
      // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
      // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
      //                                                         NestTest.EmptyNT()))
      // # This is checking actual equality of types, empty list != empty tuple
      // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
      // with self.assertRaisesRegexp(TypeError, "callable"):
      //   nest.map_structure("bad", structure1_plus1)
      // with self.assertRaisesRegexp(ValueError, "at least one structure"):
      //   nest.map_structure(lambda x: x)
      // with self.assertRaisesRegexp(ValueError, "same number of elements"):
      //   nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      //   nest.map_structure(lambda x, y: None, 3, (3,))
      // with self.assertRaisesRegexp(TypeError, "same sequence type"):
      //   nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      //   nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
      // structure1_list = [[[1, 2], 3], 4, [5, 6]]
      // with self.assertRaisesRegexp(TypeError, "same sequence type"):
      //   nest.map_structure(lambda x, y: None, structure1, structure1_list)
      // nest.map_structure(lambda x, y: None, structure1, structure1_list,
      //                    check_types=False)
      // with self.assertRaisesRegexp(ValueError, "same nested structure"):
      //   nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
      //                      check_types=False)
      // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      //   nest.map_structure(lambda x: None, structure1, foo="a")
      // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
      //   nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
      // ABTuple = collections.namedtuple("ab_tuple", "a, b")  # pylint: disable=invalid-name
  }
  376. // @test_util.assert_no_new_pyobjects_executing_eagerly
  377. // def testMapStructureWithStrings(self) :
  378. // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
  379. // inp_b = NestTest.ABTuple(a=2, b=(1, 3))
  380. // out = nest.map_structure(lambda string, repeats: string* repeats,
  381. // inp_a,
  382. // inp_b)
  383. // self.assertEqual("foofoo", out.a)
  384. // self.assertEqual("bar", out.b[0])
  385. // self.assertEqual("bazbazbaz", out.b[1])
  386. // nt = NestTest.ABTuple(a=("something", "something_else"),
  387. // b="yet another thing")
  388. // rev_nt = nest.map_structure(lambda x: x[::- 1], nt)
  389. // # Check the output is the correct structure, and all strings are reversed.
  390. // nest.assert_same_structure(nt, rev_nt)
  391. // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0])
  392. // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1])
  393. // self.assertEqual(nt.b[::- 1], rev_nt.b)
  394. // @test_util.run_deprecated_v1
  395. // def testMapStructureOverPlaceholders(self) :
  396. // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
  397. // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  398. // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
  399. // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  400. // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
  401. // nest.assert_same_structure(output, inp_a)
  402. // self.assertShapeEqual(np.zeros((3, 4)), output[0])
  403. // self.assertShapeEqual(np.zeros((3, 7)), output[1])
  404. // feed_dict = {
  405. // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
  406. // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
  407. // }
  408. // with self.cached_session() as sess:
  409. // output_np = sess.run(output, feed_dict=feed_dict)
  410. // self.assertAllClose(output_np[0],
  411. // feed_dict[inp_a][0] + feed_dict[inp_b][0])
  412. // self.assertAllClose(output_np[1],
  413. // feed_dict[inp_a][1] + feed_dict[inp_b][1])
  414. // def testAssertShallowStructure(self):
  415. // inp_ab = ["a", "b"]
  416. //inp_abc = ["a", "b", "c"]
  417. //expected_message = (
  418. // "The two structures don't have the same sequence length. Input "
  419. // "structure has length 2, while shallow structure has length 3.")
  420. // with self.assertRaisesRegexp(ValueError, expected_message):
  421. // nest.assert_shallow_structure(inp_abc, inp_ab)
  422. // inp_ab1 = [(1, 1), (2, 2)]
  423. // inp_ab2 = [[1, 1], [2, 2]]
  424. // expected_message = (
  425. // "The two structures don't have the same sequence type. Input structure "
  426. // "has type <(type|class) 'tuple'>, while shallow structure has type "
  427. // "<(type|class) 'list'>.")
  428. // with self.assertRaisesRegexp(TypeError, expected_message):
  429. // nest.assert_shallow_structure(inp_ab2, inp_ab1)
  430. // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False)
  431. // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
  432. // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
  433. // expected_message = (
  434. // r"The two structures don't have the same keys. Input "
  435. // r"structure has keys \['c'\], while shallow structure has "
  436. // r"keys \['d'\].")
  437. // with self.assertRaisesRegexp(ValueError, expected_message):
  438. // nest.assert_shallow_structure(inp_ab2, inp_ab1)
  439. // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
  440. // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
  441. // nest.assert_shallow_structure(inp_ab, inp_ba)
  442. // # This assertion is expected to pass: two namedtuples with the same
  443. //# name and field names are considered to be identical.
  444. //inp_shallow = NestTest.SameNameab(1, 2)
  445. // inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
  446. // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
  447. // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
  448. // def testFlattenUpTo(self):
  449. // # Shallow tree ends at scalar.
  450. // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  451. // shallow_tree = [[True, True], [False, True]]
  452. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  453. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  454. // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  455. // self.assertEqual(flattened_shallow_tree, [True, True, False, True])
  456. //# Shallow tree ends at string.
  457. // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  458. // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  459. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  460. // input_tree)
  461. // input_tree_flattened = nest.flatten(input_tree)
  462. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  463. // [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  464. // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
  465. // # Make sure dicts are correctly flattened, yielding values, not keys.
  466. //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  467. // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  468. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  469. // input_tree)
  470. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  471. // [1, { "c": 2}, 3, (4, 5)])
  472. // # Namedtuples.
  473. // ab_tuple = NestTest.ABTuple
  474. // input_tree = ab_tuple(a =[0, 1], b = 2)
  475. // shallow_tree = ab_tuple(a= 0, b= 1)
  476. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  477. // input_tree)
  478. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  479. // [[0, 1], 2])
  480. // # Nested dicts, OrderedDicts and namedtuples.
  481. // input_tree = collections.OrderedDict(
  482. // [("a", ab_tuple(a =[0, {"b": 1}], b=2)),
  483. // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
  484. // shallow_tree = input_tree
  485. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  486. // input_tree)
  487. // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
  488. // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  489. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  490. // input_tree)
  491. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  492. // [ab_tuple(a =[0, { "b": 1}], b=2),
  493. // 3,
  494. // collections.OrderedDict([("f", 4)])])
  495. // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  496. // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
  497. // input_tree)
  498. // self.assertEqual(input_tree_flattened_as_shallow_tree,
  499. // [ab_tuple(a =[0, {"b": 1}], b=2),
  500. // {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
  501. // ## Shallow non-list edge-case.
  502. // # Using iterable elements.
  503. // input_tree = ["input_tree"]
  504. //shallow_tree = "shallow_tree"
  505. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  506. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  507. // self.assertEqual(flattened_input_tree, [input_tree])
  508. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  509. // input_tree = ["input_tree_0", "input_tree_1"]
  510. //shallow_tree = "shallow_tree"
  511. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  512. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  513. // self.assertEqual(flattened_input_tree, [input_tree])
  514. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  515. // # Using non-iterable elements.
  516. //input_tree = [0]
  517. //shallow_tree = 9
  518. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  519. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  520. // self.assertEqual(flattened_input_tree, [input_tree])
  521. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  522. // input_tree = [0, 1]
  523. //shallow_tree = 9
  524. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  525. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  526. // self.assertEqual(flattened_input_tree, [input_tree])
  527. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  528. // ## Both non-list edge-case.
  529. //# Using iterable elements.
  530. //input_tree = "input_tree"
  531. // shallow_tree = "shallow_tree"
  532. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  533. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  534. // self.assertEqual(flattened_input_tree, [input_tree])
  535. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  536. // # Using non-iterable elements.
  537. //input_tree = 0
  538. // shallow_tree = 0
  539. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  540. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  541. // self.assertEqual(flattened_input_tree, [input_tree])
  542. // self.assertEqual(flattened_shallow_tree, [shallow_tree])
  543. // ## Input non-list edge-case.
  544. //# Using iterable elements.
  545. //input_tree = "input_tree"
  546. // shallow_tree = ["shallow_tree"]
  547. //expected_message = ("If shallow structure is a sequence, input must also "
  548. // "be a sequence. Input has type: <(type|class) 'str'>.")
  549. // with self.assertRaisesRegexp(TypeError, expected_message):
  550. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  551. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  552. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  553. // input_tree = "input_tree"
  554. // shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
  555. //with self.assertRaisesRegexp(TypeError, expected_message):
  556. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  557. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  558. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  559. //# Using non-iterable elements.
  560. // input_tree = 0
  561. // shallow_tree = [9]
  562. //expected_message = ("If shallow structure is a sequence, input must also "
  563. // "be a sequence. Input has type: <(type|class) 'int'>.")
  564. // with self.assertRaisesRegexp(TypeError, expected_message):
  565. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  566. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  567. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  568. // input_tree = 0
  569. // shallow_tree = [9, 8]
  570. //with self.assertRaisesRegexp(TypeError, expected_message):
  571. // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  572. // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  573. // self.assertEqual(flattened_shallow_tree, shallow_tree)
  574. // def testMapStructureUpTo(self) :
  575. // # Named tuples.
  576. // ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  577. // op_tuple = collections.namedtuple("op_tuple", "add, mul")
  578. // inp_val = ab_tuple(a= 2, b= 3)
  579. // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3))
  580. // out = nest.map_structure_up_to(
  581. // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
  582. // self.assertEqual(out.a, 6)
  583. // self.assertEqual(out.b, 15)
  584. // # Lists.
  585. // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  586. // name_list = ["evens", ["odds", "primes"]]
  587. // out = nest.map_structure_up_to(
  588. // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
  589. // name_list, data_list)
  590. // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
  591. // # Dicts.
  592. // inp_val = dict(a= 2, b= 3)
  593. // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
  594. // out = nest.map_structure_up_to(
  595. // inp_val,
  596. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  597. // self.assertEqual(out["a"], 6)
  598. // self.assertEqual(out["b"], 15)
  599. // # Non-equal dicts.
  600. // inp_val = dict(a= 2, b= 3)
  601. // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
  602. // with self.assertRaisesRegexp(ValueError, "same keys"):
  603. // nest.map_structure_up_to(
  604. // inp_val,
  605. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  606. // # Dict+custom mapping.
  607. // inp_val = dict(a= 2, b= 3)
  608. // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
  609. // out = nest.map_structure_up_to(
  610. // inp_val,
  611. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  612. // self.assertEqual(out["a"], 6)
  613. // self.assertEqual(out["b"], 15)
  614. // # Non-equal dict/mapping.
  615. // inp_val = dict(a= 2, b= 3)
  616. // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
  617. // with self.assertRaisesRegexp(ValueError, "same keys"):
  618. // nest.map_structure_up_to(
  619. // inp_val,
  620. // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
  621. // def testGetTraverseShallowStructure(self):
  622. // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
  623. // scalar_traverse_r = nest.get_traverse_shallow_structure(
  624. // lambda s: not isinstance(s, tuple),
  625. // scalar_traverse_input)
  626. // self.assertEqual(scalar_traverse_r,
  627. // [True, True, False, [True, True], {"a": False}, []])
  628. // nest.assert_shallow_structure(scalar_traverse_r,
  629. // scalar_traverse_input)
  630. // structure_traverse_input = [(1, [2]), ([1], 2)]
  631. // structure_traverse_r = nest.get_traverse_shallow_structure(
  632. // lambda s: (True, False) if isinstance(s, tuple) else True,
  633. // structure_traverse_input)
  634. // self.assertEqual(structure_traverse_r,
  635. // [(True, False), ([True], False)])
  636. // nest.assert_shallow_structure(structure_traverse_r,
  637. // structure_traverse_input)
  638. // with self.assertRaisesRegexp(TypeError, "returned structure"):
  639. // nest.get_traverse_shallow_structure(lambda _: [True], 0)
  640. // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
  641. // nest.get_traverse_shallow_structure(lambda _: 1, [1])
  642. // with self.assertRaisesRegexp(
  643. // TypeError, "didn't return a depth=1 structure of bools"):
  644. // nest.get_traverse_shallow_structure(lambda _: [1], [1])
  645. // def testYieldFlatStringPaths(self):
  646. // for inputs_expected in ({"inputs": [], "expected": []},
  647. // {"inputs": 3, "expected": [()]},
  648. // {"inputs": [3], "expected": [(0,)]},
  649. // {"inputs": {"a": 3}, "expected": [("a",)]},
  650. // {"inputs": {"a": {"b": 4}},
  651. // "expected": [("a", "b")]},
  652. // {"inputs": [{"a": 2}], "expected": [(0, "a")]},
  653. // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
  654. // {"inputs": [{"a": [(23, 42)]}],
  655. // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
  656. // {"inputs": [{"a": ([23], 42)}],
  657. // "expected": [(0, "a", 0, 0), (0, "a", 1)]},
  658. // {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
  659. // "expected": [("a", "a"), ("c", 0, 0, 0)]},
  660. // {"inputs": {"0": [{"1": 23}]},
  661. // "expected": [("0", 0, "1")]}):
  662. // inputs = inputs_expected["inputs"]
  663. // expected = inputs_expected["expected"]
  664. // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
  665. // def testFlattenWithStringPaths(self):
  666. // for inputs_expected in (
  667. // {"inputs": [], "expected": []},
  668. // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
  669. // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
  670. // inputs = inputs_expected["inputs"]
  671. // expected = inputs_expected["expected"]
  672. // self.assertEqual(
  673. // nest.flatten_with_joined_string_paths(inputs, separator="/"),
  674. // expected)
  675. // # Need a separate test for namedtuple as we can't declare tuple definitions
  676. // # in the @parameterized arguments.
  677. // def testFlattenNamedTuple(self):
  678. // # pylint: disable=invalid-name
  679. // Foo = collections.namedtuple("Foo", ["a", "b"])
  680. // Bar = collections.namedtuple("Bar", ["c", "d"])
  681. // # pylint: enable=invalid-name
  682. // test_cases = [
  683. // (Foo(a = 3, b = Bar(c = 23, d = 42)),
  684. // [("a", 3), ("b/c", 23), ("b/d", 42)]),
  685. // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")),
  686. // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
  687. // (Bar(c = 42, d = 43),
  688. // [("c", 42), ("d", 43)]),
  689. // (Bar(c =[42], d = 43),
  690. // [("c/0", 42), ("d", 43)]),
  691. // ]
  692. // for inputs, expected in test_cases:
  693. // self.assertEqual(
  694. // list(nest.flatten_with_joined_string_paths(inputs)), expected)
  695. // @parameterized.named_parameters(
  696. // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
  697. // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
  698. // {"a": ("a", 4), "b": ("b", 6)}),
  699. // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
  700. // ("nested",
  701. // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
  702. // {"a": [("a/0", 10), ("a/1", 12)],
  703. // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
  704. // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
  705. // def format_sum(path, * values):
  706. // return (path, sum(values))
  707. // result = nest.map_structure_with_paths(format_sum, s1, s2,
  708. // check_types=check_types)
  709. // self.assertEqual(expected, result)
  710. // @parameterized.named_parameters(
  711. // ("tuples", (1, 2), (3, 4, 5), ValueError),
  712. // ("dicts", {"a": 1}, {"b": 2}, ValueError),
  713. // ("mixed", (1, 2), [3, 4], TypeError),
  714. // ("nested",
  715. // {"a": [2, 3], "b": [1, 3]},
  716. // {"b": [5, 6, 7], "a": [8, 9]},
  717. // ValueError
  718. // ))
  719. // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
  720. // with self.assertRaises(error_type):
  721. // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2)
  722. //class NestBenchmark(test.Benchmark):
  723. // def run_and_report(self, s1, s2, name):
  724. // burn_iter, test_iter = 100, 30000
  725. // for _ in xrange(burn_iter) :
  726. // nest.assert_same_structure(s1, s2)
  727. // t0 = time.time()
  728. // for _ in xrange(test_iter) :
  729. // nest.assert_same_structure(s1, s2)
  730. // t1 = time.time()
  731. // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
  732. // name=name)
  733. // def benchmark_assert_structure(self):
  734. // s1 = (((1, 2), 3), 4, (5, 6))
  735. // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  736. // self.run_and_report(s1, s2, "assert_same_structure_6_elem")
  737. // s1 = (((1, 2), 3), 4, (5, 6)) * 10
  738. // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
  739. // self.run_and_report(s1, s2, "assert_same_structure_60_elem")
  740. //if __name__ == "__main__":
  741. // test.main()
  742. }
  743. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。