diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs
new file mode 100644
index 00000000..2d43b976
--- /dev/null
+++ b/src/TensorFlowNET.Core/Util/nest.py.cs
@@ -0,0 +1,871 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using NumSharp;
+
+namespace Tensorflow.Util
+{
+ //Functions for working with arbitrarily nested sequences of elements.
+
+ //This module can perform operations on nested structures. A nested structure is a
+ //Python sequence, tuple (including `namedtuple`), or dict that can contain
+ //further sequences, tuples, and dicts.
+
+ //The utilities here assume (and do not check) that the nested structures form a
+ //'tree', i.e., no references in the structure of the input of these functions
+ //should be recursive.
+
+ //Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0),
+ // (np.array([3, 4]), tf.constant([3, 4])))`
+ //
+
+ public static class nest
+ {
+
+
+ //def _get_attrs_values(obj):
+ // """Returns the list of values from an attrs instance."""
+ // attrs = getattr(obj.__class__, "__attrs_attrs__")
+ // return [getattr(obj, a.name) for a in attrs]
+
+ ///
+ /// Returns a sorted list of the dict keys, with error if keys not sortable.
+ ///
+ private static IEnumerable _sorted(IDictionary dict_)
+ {
+ return dict_.Keys.OfType().OrderBy(x => x);
+ }
+
+
+ //def _is_namedtuple(instance, strict=False):
+ // """Returns True iff `instance` is a `namedtuple`.
+
+ // Args:
+ // instance: An instance of a Python object.
+ // strict: If True, `instance` is considered to be a `namedtuple` only if
+ // it is a "plain" namedtuple. For instance, a class inheriting
+ // from a `namedtuple` will be considered to be a `namedtuple`
+ // iff `strict=False`.
+
+ // Returns:
+ // True if `instance` is a `namedtuple`.
+ // """
+ // return _pywrap_tensorflow.IsNamedtuple(instance, strict)
+
+
+ //# See the swig file (util.i) for documentation.
+ //_is_mapping = _pywrap_tensorflow.IsMapping
+ //_is_attrs = _pywrap_tensorflow.IsAttrs
+
+ ///
+ /// Converts the sequence `args` to the same type as `instance`.
+ ///
+ /// an instance of `tuple`, `list`, `namedtuple`, `dict`, or
+ /// `collections.OrderedDict`.
+ /// elements to be converted to the `instance` type.
+ /// `args` with the type of `instance`.
+ private static object _sequence_like(object instance, IEnumerable args)
+ {
+ if (is_mapping(instance))
+ {
+ //# Pack dictionaries in a deterministic order by sorting the keys.
+ //# Notice this means that we ignore the original order of `OrderedDict`
+ //# instances. This is intentional, to avoid potential bugs caused by mixing
+ //# ordered and plain dicts (e.g., flattening a dict but using a
+ //# corresponding `OrderedDict` to pack it back).
+ // result = dict(zip(_sorted(instance), args))
+ // return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
+ }
+ //else if( _is_namedtuple(instance) || _is_attrs(instance))
+ // return type(instance)(*args)
+ else
+ {
+ // Not a namedtuple
+ switch (instance)
+ {
+ case object[] array:
+ var result_array = new object[args.Count()];
+ int i = 0;
+ foreach (var x in args)
+ {
+ result_array[i] = x;
+ i++;
+ }
+ return result_array;
+ case List list:
+ return new List(args);
+ default:
+ throw new TypeError("Type of sequence not supported (yet): " + instance.GetType());
+ }
+ }
+ throw new TypeError("Type of sequence not supported (yet): " + instance.GetType());
+ }
+
+ ///
+ /// Yields the next value from the given iterable.
+ ///
+ private static IEnumerable _yield_value(object iterable)
+ {
+ if (is_mapping(iterable))
+ {
+ var dict = iterable as IDictionary;
+ //# Iterate through dictionaries in a deterministic order by sorting the
+ //# keys. Notice this means that we ignore the original order of `OrderedDict`
+ //# instances. This is intentional, to avoid potential bugs caused by mixing
+ //# ordered and plain dicts (e.g., flattening a dict but using a
+ //# corresponding `OrderedDict` to pack it back).
+ foreach (var key in _sorted(dict))
+ yield return dict[key];
+ }
+ //else if (_is_attrs(iterable))
+ //{
+ // // for value in _get_attrs_values(iterable):
+ // // yield value
+ //}
+ else if (iterable is IEnumerable)
+ {
+ var enumerable = iterable as IEnumerable;
+ foreach (var value in enumerable)
+ yield return value;
+ }
+ else
+ {
+ throw new TypeError("Unexpected iterable type: " + iterable.GetType());
+ //var jobj = JObject.FromObject(iterable);
+ //foreach (var key in _sorted())
+ // yield return jobj[key];
+ }
+ }
+
+ //# See the swig file (util.i) for documentation.
+ public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string);
+
+ public static bool is_mapping(object arg) => arg is IDictionary;
+
+ //# See the swig file (util.i) for documentation.
+ //flatten = _pywrap_tensorflow.Flatten
+
+ public static List flatten(object structure)
+ {
+ var list = new List();
+ _flatten_recursive(structure, list);
+ return list;
+ }
+
+ private static void _flatten_recursive(object obj, List list)
+ {
+ if (obj is string)
+ {
+ list.Add(obj);
+ return;
+ }
+ if (obj is IDictionary)
+ {
+ var dict = obj as IDictionary;
+ foreach (var key in _sorted(dict))
+ _flatten_recursive(dict[key], list);
+ return;
+ }
+ if (obj is NDArray)
+ {
+ list.Add(obj);
+ return;
+ }
+ if (obj is IEnumerable)
+ {
+ var structure = obj as IEnumerable;
+ foreach (var child in structure)
+ _flatten_recursive(child, list);
+ return;
+ }
+ list.Add(obj);
+ }
+
+
+ //# See the swig file (util.i) for documentation.
+ //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
+
+
+ //class _DotString(object):
+
+ // def __str__(self):
+ // return "."
+
+ // def __repr__(self):
+ // return "."
+
+
+ //_DOT = _DotString()
+
+
+ //def assert_same_structure(nest1, nest2, check_types=True):
+ // """Asserts that two structures are nested in the same way.
+
+ // Note that namedtuples with identical name and fields are always considered
+ // to have the same shallow structure (even with `check_types=True`).
+ // For intance, this code will print `True`:
+
+ // ```python
+ // def nt(a, b):
+ // return collections.namedtuple('foo', 'a b')(a, b)
+ // print(assert_same_structure(nt(0, 1), nt(2, 3)))
+ // ```
+
+ // Args:
+ // nest1: an arbitrarily nested structure.
+ // nest2: an arbitrarily nested structure.
+ // check_types: if `True` (default) types of sequences are checked as well,
+ // including the keys of dictionaries. If set to `False`, for example a
+ // list and a tuple of objects will look the same if they have the same
+ // size. Note that namedtuples with identical name and fields are always
+ // considered to have the same shallow structure. Two types will also be
+ // considered the same if they are both list subtypes (which allows "list"
+ // and "_ListWrapper" from checkpointable dependency tracking to compare
+ // equal).
+
+ // Raises:
+ // ValueError: If the two structures do not have the same number of elements or
+ // if the two structures are not nested in the same way.
+ // TypeError: If the two structures differ in the type of sequence in any of
+ // their substructures. Only possible if `check_types` is `True`.
+ // """
+ // try:
+ // _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
+ // except (ValueError, TypeError) as e:
+ // str1 = str(map_structure(lambda _: _DOT, nest1))
+ // str2 = str(map_structure(lambda _: _DOT, nest2))
+ // raise type(e)("%s\n"
+ // "Entire first structure:\n%s\n"
+ // "Entire second structure:\n%s"
+ // % (str(e), str1, str2))
+
+
+ //def flatten_dict_items(dictionary):
+ // """Returns a dictionary with flattened keys and values.
+
+ // This function flattens the keys and values of a dictionary, which can be
+ // arbitrarily nested structures, and returns the flattened version of such
+ // structures:
+
+ // ```python
+ // example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
+ // result = {4: "a", 5: "b", 6: "c", 8: "d"}
+ // flatten_dict_items(example_dictionary) == result
+ // ```
+
+ // The input dictionary must satisfy two properties:
+
+ // 1. Its keys and values should have the same exact nested structure.
+ // 2. The set of all flattened keys of the dictionary must not contain repeated
+ // keys.
+
+ // Args:
+ // dictionary: the dictionary to zip
+
+ // Returns:
+ // The zipped dictionary.
+
+ // Raises:
+ // TypeError: If the input is not a dictionary.
+ // ValueError: If any key and value have not the same structure, or if keys are
+ // not unique.
+ // """
+ // if not isinstance(dictionary, (dict, _collections.Mapping)):
+ // raise TypeError("input must be a dictionary")
+ // flat_dictionary = {}
+ // for i, v in _six.iteritems(dictionary):
+ // if not is_sequence(i):
+ // if i in flat_dictionary:
+ // raise ValueError(
+ // "Could not flatten dictionary: key %s is not unique." % i)
+ // flat_dictionary[i] = v
+ // else:
+ // flat_i = flatten(i)
+ // flat_v = flatten(v)
+ // if len(flat_i) != len(flat_v):
+ // raise ValueError(
+ // "Could not flatten dictionary. Key had %d elements, but value had "
+ // "%d elements. Key: %s, value: %s."
+ // % (len(flat_i), len(flat_v), flat_i, flat_v))
+ // for new_i, new_v in zip(flat_i, flat_v):
+ // if new_i in flat_dictionary:
+ // raise ValueError(
+ // "Could not flatten dictionary: key %s is not unique."
+ // % (new_i))
+ // flat_dictionary[new_i] = new_v
+ // return flat_dictionary
+
+ ///
+ /// Helper function for pack_sequence_as.
+ ///
+ /// Substructure (list / tuple / dict) to mimic.
+ /// Flattened values to output substructure for.
+ /// Index at which to start reading from flat.
+ ///
+ /// The tuple(new_index, child), where:
+ /// * new_index - the updated index into `flat` having processed `structure`.
+ /// * packed - the subset of `flat` corresponding to `structure`,
+ /// having started at `index`, and packed into the same nested
+ /// format.
+ private static (int new_index, List child) _packed_nest_with_indices(object structure, List flat,
+ int index)
+ {
+ var packed = new List();
+ foreach (var s in _yield_value(structure))
+ {
+ if (is_sequence(s))
+ {
+ var (new_index, child) = _packed_nest_with_indices(s, flat, index);
+ packed.Add(_sequence_like(s, child));
+ index = new_index;
+ }
+ else
+ {
+ packed.Add(flat[index]);
+ index += 1;
+ }
+ }
+ return (index, packed);
+ }
+
+ private static int len(IEnumerable x) => x.Count();
+
+ ///
+ /// Returns a given flattened sequence packed into a given structure.
+ /// If `structure` is a scalar, `flat_sequence` must be a single-element list;
+ /// in this case the return value is `flat_sequence[0]`.
+ ///
+ /// If `structure` is or contains a dict instance, the keys will be sorted to
+ /// pack the flat sequence in deterministic order. This is true also for
+ /// `OrderedDict` instances: their sequence order is ignored, the sorting order of
+ /// keys is used instead. The same convention is followed in `flatten`.
+ /// This correctly repacks dicts and `OrderedDict`s after they have been
+ /// flattened, and also allows flattening an `OrderedDict` and then repacking it
+ /// back using a corresponding plain dict, or vice-versa.
+ /// Dictionaries with non-sortable keys cannot be flattened.
+ ///
+ ///
+ /// Nested structure, whose structure is given by nested lists,
+ /// tuples, and dicts. Note: numpy arrays and strings are considered
+ /// scalars.
+ ///
+ /// flat sequence to pack.
+ /// `flat_sequence` converted to have the same recursive structure as
+ /// `structure`.
+ ///
+ public static object pack_sequence_as(object structure, List flat_sequence)
+ {
+ if (flat_sequence == null)
+ throw new ArgumentException("flat_sequence must not be null");
+ // if not is_sequence(flat_sequence):
+ // raise TypeError("flat_sequence must be a sequence")
+
+ if (!is_sequence(structure))
+ {
+ if (len(flat_sequence) != 1)
+ throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1");
+ return flat_sequence.FirstOrDefault();
+ }
+ int final_index = 0;
+ List packed = null;
+ try
+ {
+ (final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0);
+ if (final_index < len(flat_sequence))
+ throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }");
+ }
+ catch (IndexOutOfRangeException)
+ {
+ var flat_structure = flatten(structure);
+ if (len(flat_structure) != len(flat_sequence))
+ {
+ throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
+ $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}");
+ }
+ return _sequence_like(structure, packed);
+ }
+ return packed;
+ }
+
+ ///
+ /// Applies `func` to each entry in `structure` and returns a new structure.
+ ///
+ /// Applies `func(x[0], x[1], ...)` where x[i] is an entry in
+ /// `structure[i]`. All structures in `structure` must have the same arity,
+ /// and the return value will contain the results in the same structure.
+ ///
+ ///
+ ///
+ /// A callable that accepts as many arguments as there are structures.
+ /// scalar, or tuple or list of constructed scalars and/or other
+ /// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.
+ /// If set to
+ /// `True` (default) the types of iterables within the structures have to be
+ /// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
+ /// exception). To allow this set this argument to `False`.
+ /// Note that namedtuples with identical name and fields are always
+ /// considered to have the same shallow structure.
+ ///
+ /// A new structure with the same arity as `structure`, whose values correspond
+ /// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
+ /// location in `structure[i]`. If there are different sequence types and
+ /// `check_types` is `False` the sequence types of the first structure will be
+ /// used.
+ ///
+ public static IEnumerable map_structure(Func func, IEnumerable structure, bool check_types = false)
+ {
+
+ // for other in structure[1:]:
+ // assert_same_structure(structure[0], other, check_types=check_types)
+
+ // flat_structure = [flatten(s) for s in structure]
+ // entries = zip(*flat_structure)
+
+ // return pack_sequence_as(
+ // structure[0], [func(*x) for x in entries])
+ return null;
+ }
+
+ //def map_structure_with_paths(func, *structure, **kwargs):
+ // """Applies `func` to each entry in `structure` and returns a new structure.
+
+ // Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
+ // `structure[i]` and `path` is the common path to x[i] in the structures. All
+ // structures in `structure` must have the same arity, and the return value will
+ // contain the results in the same structure. Special kwarg `check_types`
+ // determines whether the types of iterables within the structure must be the
+ // same-- see **kwargs definition below.
+
+ // Args:
+ // func: A callable with the signature func(path, *values, **kwargs) that is
+ // evaluated on the leaves of the structure.
+ // *structure: A variable number of compatible structures to process.
+ // **kwargs: Optional kwargs to be passed through to func. Special kwarg
+ // `check_types` is not passed to func, but instead determines whether the
+ // types of iterables within the structures have to be same (e.g.,
+ // `map_structure(func, [1], (1,))` raises a `TypeError` exception). By
+ // default, the types must match. To allow iteration over structures of
+ // different types (but common arity), set this kwarg to `False`.
+
+ // Returns:
+ // A structure of the same form as the input structures whose leaves are the
+ // result of evaluating func on corresponding leaves of the input structures.
+
+ // Raises:
+ // TypeError: If `func` is not callable or if the structures do not match
+ // each other by depth tree.
+ // TypeError: If `check_types` is not `False` and the two structures differ in
+ // the type of sequence in any of their substructures.
+ // ValueError: If no structures are provided.
+ // """
+ // if not callable(func):
+ // raise TypeError("func must be callable, got: %s" % func)
+ // if not structure:
+ // raise ValueError("Must provide at least one structure")
+
+ // check_types = kwargs.pop("check_types", True)
+ // for other in structure[1:]:
+ // assert_same_structure(structure[0], other, check_types=check_types)
+
+ //# First set paths_and_values to:
+ //# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]]
+ // paths_and_values = [flatten_with_joined_string_paths(s) for s in structure]
+
+ //# Now zip(*paths_and_values) would be:
+ //# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))]
+ //# so grouped_by_path is set to:
+ //# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]]
+ //# Note that p1i, ... pmi must all be equal since the structures are the same.
+ // grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)]
+
+ // return pack_sequence_as(structure[0], [
+ // func(paths[0], *values, **kwargs) for paths, values in grouped_by_path])
+
+
+ //def _yield_flat_up_to(shallow_tree, input_tree):
+ // """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
+ // if is_sequence(shallow_tree):
+ // for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
+ // _yield_value(input_tree)):
+ // for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
+ // yield input_leaf
+ // else:
+ // yield input_tree
+
+
+ //def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
+ // """Asserts that `shallow_tree` is a shallow structure of `input_tree`.
+
+ // That is, this function tests if the `input_tree` structure can be created from
+ // the `shallow_tree` structure by replacing its leaf nodes with deeper
+ // tree structures.
+
+ // Examples:
+
+ // The following code will raise an exception:
+ // ```python
+ // shallow_tree = ["a", "b"]
+ // input_tree = ["c", ["d", "e"], "f"]
+ // assert_shallow_structure(shallow_tree, input_tree)
+ // ```
+
+ // The following code will not raise an exception:
+ // ```python
+ // shallow_tree = ["a", "b"]
+ // input_tree = ["c", ["d", "e"]]
+ // assert_shallow_structure(shallow_tree, input_tree)
+ // ```
+
+ // Args:
+ // shallow_tree: an arbitrarily nested structure.
+ // input_tree: an arbitrarily nested structure.
+ // check_types: if `True` (default) the sequence types of `shallow_tree` and
+ // `input_tree` have to be the same. Note that even with check_types==True,
+ // this function will consider two different namedtuple classes with the same
+ // name and _fields attribute to be the same class.
+
+ // Raises:
+ // TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
+ // TypeError: If the sequence types of `shallow_tree` are different from
+ // `input_tree`. Only raised if `check_types` is `True`.
+ // ValueError: If the sequence lengths of `shallow_tree` are different from
+ // `input_tree`.
+ // """
+ // if is_sequence(shallow_tree):
+ // if not is_sequence(input_tree):
+ // raise TypeError(
+ // "If shallow structure is a sequence, input must also be a sequence. "
+ // "Input has type: %s." % type(input_tree))
+
+ // if check_types and not isinstance(input_tree, type(shallow_tree)):
+ //# Duck-typing means that nest should be fine with two different
+ //# namedtuples with identical name and fields.
+ // shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
+ // input_is_namedtuple = _is_namedtuple(input_tree, False)
+ // if shallow_is_namedtuple and input_is_namedtuple:
+ // if not _same_namedtuples(shallow_tree, input_tree):
+ // raise TypeError(
+ // "The two namedtuples don't have the same sequence type. Input "
+ // "structure has type %s, while shallow structure has type %s."
+ // % (type(input_tree), type(shallow_tree)))
+ // elif not (isinstance(shallow_tree, _collections.Mapping)
+ // and isinstance(input_tree, _collections.Mapping)):
+ // raise TypeError(
+ // "The two structures don't have the same sequence type. Input "
+ // "structure has type %s, while shallow structure has type %s."
+ // % (type(input_tree), type(shallow_tree)))
+
+ // if len(input_tree) != len(shallow_tree):
+ // raise ValueError(
+ // "The two structures don't have the same sequence length. Input "
+ // "structure has length %s, while shallow structure has length %s."
+ // % (len(input_tree), len(shallow_tree)))
+
+ // if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)):
+ // if set(input_tree) != set(shallow_tree):
+ // raise ValueError(
+ // "The two structures don't have the same keys. Input "
+ // "structure has keys %s, while shallow structure has keys %s." %
+ // (list(_six.iterkeys(input_tree)),
+ // list(_six.iterkeys(shallow_tree))))
+
+ // input_tree = list(sorted(_six.iteritems(input_tree)))
+ // shallow_tree = list(sorted(_six.iteritems(shallow_tree)))
+
+ // for shallow_branch, input_branch in zip(shallow_tree, input_tree):
+ // assert_shallow_structure(shallow_branch, input_branch,
+ // check_types=check_types)
+
+
+ //def flatten_up_to(shallow_tree, input_tree):
+ // """Flattens `input_tree` up to `shallow_tree`.
+
+ // Any further depth in structure in `input_tree` is retained as elements in the
+ // partially flatten output.
+
+ // If `shallow_tree` and `input_tree` are not sequences, this returns a
+ // single-element list: `[input_tree]`.
+
+ // Use Case:
+
+ // Sometimes we may wish to partially flatten a nested sequence, retaining some
+ // of the nested structure. We achieve this by specifying a shallow structure,
+ // `shallow_tree`, we wish to flatten up to.
+
+ // The input, `input_tree`, can be thought of as having the same structure as
+ // `shallow_tree`, but with leaf nodes that are themselves tree structures.
+
+ // Examples:
+
+ // ```python
+ // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
+ // shallow_tree = [[True, True], [False, True]]
+
+ // flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)
+
+ //# Output is:
+ //# [[2, 2], [3, 3], [4, 9], [5, 5]]
+ //# [True, True, False, True]
+ // ```
+
+ // ```python
+ // input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
+ // shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]
+
+ // input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
+ // input_tree_flattened = flatten(input_tree)
+
+ //# Output is:
+ //# [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
+ //# ['a', 1, 'b', 2, 'c', 3, 'd', 4]
+ // ```
+
+ // Non-Sequence Edge Cases:
+
+ // ```python
+ // flatten_up_to(0, 0) # Output: [0]
+ // flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]]
+ // flatten_up_to([0, 1, 2], 0) # Output: TypeError
+ // flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2]
+ // ```
+
+ // Args:
+ // shallow_tree: a possibly pruned structure of input_tree.
+ // input_tree: an arbitrarily nested structure or a scalar object.
+ // Note, numpy arrays are considered scalars.
+
+ // Returns:
+ // A Python list, the partially flattened version of `input_tree` according to
+ // the structure of `shallow_tree`.
+
+ // Raises:
+ // TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
+ // TypeError: If the sequence types of `shallow_tree` are different from
+ // `input_tree`.
+ // ValueError: If the sequence lengths of `shallow_tree` are different from
+ // `input_tree`.
+ // """
+ // assert_shallow_structure(shallow_tree, input_tree)
+ // return list(_yield_flat_up_to(shallow_tree, input_tree))
+
+
+ //def map_structure_up_to(shallow_tree, func, *inputs):
+ // """Applies a function or op to a number of partially flattened inputs.
+
+ // The `inputs` are flattened up to `shallow_tree` before being mapped.
+
+ // Use Case:
+
+ // Sometimes we wish to apply a function to a partially flattened
+ // sequence (for example when the function itself takes sequence inputs). We
+ // achieve this by specifying a shallow structure, `shallow_tree` we wish to
+ // flatten up to.
+
+ // The `inputs`, can be thought of as having the same structure as
+ // `shallow_tree`, but with leaf nodes that are themselves tree structures.
+
+ // This function therefore will return something with the same base structure as
+ // `shallow_tree`.
+
+ // Examples:
+
+ // ```python
+ // ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+ // op_tuple = collections.namedtuple("op_tuple", "add, mul")
+ // inp_val = ab_tuple(a=2, b=3)
+ // inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
+ // out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
+ // inp_val, inp_ops)
+
+ //# Output is: ab_tuple(a=6, b=15)
+ // ```
+
+ // ```python
+ // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
+ // name_list = ['evens', ['odds', 'primes']]
+ // out = map_structure_up_to(
+ // name_list,
+ // lambda name, sec: "first_{}_{}".format(len(sec), name),
+ // name_list, data_list)
+
+ //# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
+ // ```
+
+ // Args:
+ // shallow_tree: a shallow tree, common to all the inputs.
+ // func: callable which will be applied to each input individually.
+ // *inputs: arbitrarily nested combination of objects that are compatible with
+ // shallow_tree. The function `func` is applied to corresponding
+ // partially flattened elements of each input, so the function must support
+ // arity of `len(inputs)`.
+
+ // Raises:
+ // TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
+ // TypeError: If the sequence types of `shallow_tree` are different from
+ // `input_tree`.
+ // ValueError: If the sequence lengths of `shallow_tree` are different from
+ // `input_tree`.
+
+ // Returns:
+ // result of repeatedly applying `func`, with same structure as
+ // `shallow_tree`.
+ // """
+ // if not inputs:
+ // raise ValueError("Cannot map over no sequences")
+ // for input_tree in inputs:
+ // assert_shallow_structure(shallow_tree, input_tree)
+
+ //# Flatten each input separately, apply the function to corresponding elements,
+ //# then repack based on the structure of the first input.
+ // all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
+ // for input_tree in inputs]
+ // results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
+ // return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
+
+
+ //def get_traverse_shallow_structure(traverse_fn, structure):
+ // """Generates a shallow structure from a `traverse_fn` and `structure`.
+
+ // `traverse_fn` must accept any possible subtree of `structure` and return
+ // a depth=1 structure containing `True` or `False` values, describing which
+ // of the top-level subtrees may be traversed. It may also
+ // return scalar `True` or `False` "traversal is OK / not OK for all subtrees."
+
+ // Examples are available in the unit tests (nest_test.py).
+
+ // Args:
+ // traverse_fn: Function taking a substructure and returning either a scalar
+ // `bool` (whether to traverse that substructure or not) or a depth=1
+ // shallow structure of the same type, describing which parts of the
+ // substructure to traverse.
+ // structure: The structure to traverse.
+
+ // Returns:
+ // A shallow structure containing python bools, which can be passed to
+ // `map_structure_up_to` and `flatten_up_to`.
+
+ // Raises:
+ // TypeError: if `traverse_fn` returns a sequence for a non-sequence input,
+ // or a structure with depth higher than 1 for a sequence input,
+ // or if any leaf values in the returned structure or scalar are not type
+ // `bool`.
+ // """
+ // to_traverse = traverse_fn(structure)
+ // if not is_sequence(structure):
+ // if not isinstance(to_traverse, bool):
+ // raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
+ // % (to_traverse, structure))
+ // return to_traverse
+ // level_traverse = []
+ // if isinstance(to_traverse, bool):
+ // if not to_traverse:
+ //# Do not traverse this substructure at all. Exit early.
+ // return False
+ // else:
+ //# Traverse the entire substructure.
+ // for branch in _yield_value(structure):
+ // level_traverse.append(
+ // get_traverse_shallow_structure(traverse_fn, branch))
+ // elif not is_sequence(to_traverse):
+ // raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
+ // % (to_traverse, structure))
+ // else:
+ //# Traverse some subset of this substructure.
+ // assert_shallow_structure(to_traverse, structure)
+ // for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)):
+ // if not isinstance(t, bool):
+ // raise TypeError(
+ // "traverse_fn didn't return a depth=1 structure of bools. saw: %s "
+ // " for structure: %s" % (to_traverse, structure))
+ // if t:
+ // level_traverse.append(
+ // get_traverse_shallow_structure(traverse_fn, branch))
+ // else:
+ // level_traverse.append(False)
+ // return _sequence_like(structure, level_traverse)
+
+
+ //def yield_flat_paths(nest):
+ // """Yields paths for some nested structure.
+
+ // Paths are lists of objects which can be str-converted, which may include
+ // integers or other types which are used as indices in a dict.
+
+ // The flat list will be in the corresponding order as if you called
+ // `snt.nest.flatten` on the structure. This is handy for naming Tensors such
+ // the TF scope structure matches the tuple structure.
+
+ // E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`
+
+ // ```shell
+ // >>> nest.flatten(value)
+ // [3, 23, 42]
+ // >>> list(nest.yield_flat_paths(value))
+ // [('a',), ('b', 'c'), ('b', 'd')]
+ // ```
+
+ // ```shell
+ // >>> list(nest.yield_flat_paths({'a': [3]}))
+ // [('a', 0)]
+ // >>> list(nest.yield_flat_paths({'a': 3}))
+ // [('a',)]
+ // ```
+
+ // Args:
+ // nest: the value to produce a flattened paths list for.
+
+ // Yields:
+ // Tuples containing index or key values which form the path to a specific
+ // leaf value in the nested structure.
+ // """
+
+ //# The _maybe_add_final_path_element function is used below in order to avoid
+ //# adding trailing slashes when the sub-element recursed into is a leaf.
+ // if isinstance(nest, (dict, _collections.Mapping)):
+ // for key in _sorted(nest):
+ // value = nest[key]
+ // for sub_path in yield_flat_paths(value):
+ // yield (key,) + sub_path
+ // elif _is_namedtuple(nest):
+ // for key in nest._fields:
+ // value = getattr(nest, key)
+ // for sub_path in yield_flat_paths(value):
+ // yield (key,) + sub_path
+ // elif isinstance(nest, _six.string_types):
+ // yield ()
+ // elif isinstance(nest, _collections.Sequence):
+ // for idx, value in enumerate(nest):
+ // for sub_path in yield_flat_paths(value):
+ // yield (idx,) + sub_path
+ // else:
+ // yield ()
+
+
+ //def flatten_with_joined_string_paths(structure, separator="/"):
+ // """Returns a list of (string path, data element) tuples.
+
+ // The order of tuples produced matches that of `nest.flatten`. This allows you
+ // to flatten a nested structure while keeping information about where in the
+ // structure each data element was located. See `nest.yield_flat_paths`
+ // for more information.
+
+ // Args:
+ // structure: the nested structure to flatten.
+ // separator: string to separate levels of hierarchy in the results, defaults
+ // to '/'.
+
+ // Returns:
+ // A list of (string, data element) tuples.
+ // """
+ // flat_paths = yield_flat_paths(structure)
+ // def stringify_and_join(path_elements):
+ // return separator.join(str(path_element) for path_element in path_elements)
+ // flat_string_paths = [stringify_and_join(path) for path in flat_paths]
+ // return list(zip(flat_string_paths, flatten(structure)))
+
+
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs
new file mode 100644
index 00000000..7b0d61ba
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs
@@ -0,0 +1,852 @@
+using System.Collections;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Newtonsoft.Json.Linq;
+using Tensorflow;
+using Tensorflow.Util;
+
+namespace TensorFlowNET.UnitTest.control_flow_ops_test
+{
+ ///
+ /// excerpt of tensorflow/python/framework/util/nest_test.py
+ ///
+ [TestClass]
+ public class NestTest : PythonTest
+ {
+
+ public class PointXY
+ {
+ public double x;
+ public double y;
+ }
+
+ // if attr:
+ // class BadAttr(object):
+ // """Class that has a non-iterable __attrs_attrs__."""
+ // __attrs_attrs__ = None
+
+ // @attr.s
+ // class SampleAttr(object):
+ // field1 = attr.ib()
+ // field2 = attr.ib()
+
+ // @test_util.assert_no_new_pyobjects_executing_eagerly
+ // def testAttrsFlattenAndPack(self) :
+ // if attr is None:
+ // self.skipTest("attr module is unavailable.")
+
+ // field_values = [1, 2]
+ // sample_attr = NestTest.SampleAttr(* field_values)
+ // self.assertFalse(nest._is_attrs(field_values))
+ // self.assertTrue(nest._is_attrs(sample_attr))
+ // flat = nest.flatten(sample_attr)
+ // self.assertEqual(field_values, flat)
+ // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
+ // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
+ // self.assertEqual(restructured_from_flat, sample_attr)
+
+ //# Check that flatten fails if attributes are not iterable
+ // with self.assertRaisesRegexp(TypeError, "object is not iterable"):
+ // flat = nest.flatten(NestTest.BadAttr())
+
+ [TestMethod]
+ public void testFlattenAndPack()
+ {
+ object structure = new object[] {new object[] {3, 4}, 5, new object[] {6, 7, new object[] {9, 10}, 8}};
+ var flat = new List {"a", "b", "c", "d", "e", "f", "g", "h"};
+
+ self.assertEqual(nest.flatten(structure), new[] {3, 4, 5, 6, 7, 9, 10, 8});
+ self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(),
+ JArray.FromObject(new object[] {new object[] {"a", "b"}, "c", new object[] {"d", "e", new object[] {"f", "g"}, "h"}}).ToString());
+ structure = new object[] { new Hashtable {["x"] = 4, ["y"] = 2}, new object[] { new object[] { new Hashtable { ["x"] = 1,["y"] = 0}, }, }};
+ flat = new List { 4, 2, 1, 0};
+ self.assertEqual(nest.flatten(structure), flat);
+ // restructured_from_flat = nest.pack_sequence_as(structure, flat)
+ // self.assertEqual(restructured_from_flat, structure)
+ // self.assertEqual(restructured_from_flat[0].x, 4)
+ // self.assertEqual(restructured_from_flat[0].y, 2)
+ // self.assertEqual(restructured_from_flat[1][0][0].x, 1)
+ // self.assertEqual(restructured_from_flat[1][0][0].y, 0)
+
+ // self.assertEqual([5], nest.flatten(5))
+ // self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
+
+ // self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
+ // self.assertEqual(
+ // np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
+
+ // with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
+ // nest.pack_sequence_as("scalar", [4, 5])
+
+ // with self.assertRaisesRegexp(TypeError, "flat_sequence"):
+ // nest.pack_sequence_as([4, 5], "bad_sequence")
+
+ // with self.assertRaises(ValueError):
+ // nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
+
+ }
+
+ // @parameterized.parameters({"mapping_type": collections.OrderedDict
+ // },
+ // {"mapping_type": _CustomMapping
+ //})
+ // @test_util.assert_no_new_pyobjects_executing_eagerly
+ // def testFlattenDictOrder(self, mapping_type) :
+ // """`flatten` orders dicts by key, including OrderedDicts."""
+ // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
+ // plain = {"d": 3, "b": 1, "a": 0, "c": 2}
+ // ordered_flat = nest.flatten(ordered)
+ // plain_flat = nest.flatten(plain)
+ // self.assertEqual([0, 1, 2, 3], ordered_flat)
+ // self.assertEqual([0, 1, 2, 3], plain_flat)
+
+ // @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ // {"mapping_type": _CustomMapping})
+ // def testPackDictOrder(self, mapping_type):
+ // """Packing orders dicts by key, including OrderedDicts."""
+ // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
+ // plain = {"d": 0, "b": 0, "a": 0, "c": 0}
+ // seq = [0, 1, 2, 3]
+ //custom_reconstruction = nest.pack_sequence_as(custom, seq)
+ //plain_reconstruction = nest.pack_sequence_as(plain, seq)
+ // self.assertIsInstance(custom_reconstruction, mapping_type)
+ // self.assertIsInstance(plain_reconstruction, dict)
+ // self.assertEqual(
+ // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
+ // custom_reconstruction)
+ // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
+
+ // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
+
+ // @test_util.assert_no_new_pyobjects_executing_eagerly
+ // def testFlattenAndPack_withDicts(self) :
+ // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
+ // mess = [
+ // "z",
+ // NestTest.Abc(3, 4), {
+ // "d": _CustomMapping({
+ // 41: 4
+ // }),
+ // "c": [
+ // 1,
+ // collections.OrderedDict([
+ // ("b", 3),
+ // ("a", 2),
+ // ]),
+ // ],
+ // "b": 5
+ // }, 17
+ // ]
+
+ // flattened = nest.flatten(mess)
+ // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
+
+ // structure_of_mess = [
+ // 14,
+ // NestTest.Abc("a", True),
+ // {
+ // "d": _CustomMapping({
+ // 41: 42
+ // }),
+ // "c": [
+ // 0,
+ // collections.OrderedDict([
+ // ("b", 9),
+ // ("a", 8),
+ // ]),
+ // ],
+ // "b": 3
+ // },
+ // "hi everybody",
+ // ]
+
+ // unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
+ // self.assertEqual(unflattened, mess)
+
+ // # Check also that the OrderedDict was created, with the correct key order.
+ //unflattened_ordered_dict = unflattened[2]["c"][1]
+ // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
+ // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
+
+ // unflattened_custom_mapping = unflattened[2]["d"]
+ // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
+ // self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
+
+ // def testFlatten_numpyIsNotFlattened(self):
+ // structure = np.array([1, 2, 3])
+ // flattened = nest.flatten(structure)
+ // self.assertEqual(len(flattened), 1)
+
+ // def testFlatten_stringIsNotFlattened(self):
+ // structure = "lots of letters"
+ // flattened = nest.flatten(structure)
+ // self.assertEqual(len(flattened), 1)
+ // unflattened = nest.pack_sequence_as("goodbye", flattened)
+ // self.assertEqual(structure, unflattened)
+
+ // def testPackSequenceAs_notIterableError(self) :
+ // with self.assertRaisesRegexp(TypeError,
+ // "flat_sequence must be a sequence"):
+ // nest.pack_sequence_as("hi", "bye")
+
+ // def testPackSequenceAs_wrongLengthsError(self):
+ // with self.assertRaisesRegexp(
+ // ValueError,
+ // "Structure had 2 elements, but flat_sequence had 3 elements."):
+ // nest.pack_sequence_as(["hello", "world"],
+ // ["and", "goodbye", "again"])
+
+ // @test_util.assert_no_new_pyobjects_executing_eagerly
+ // def testIsSequence(self):
+ // self.assertFalse(nest.is_sequence("1234"))
+ // self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
+ // self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
+ // self.assertTrue(nest.is_sequence([]))
+ // self.assertTrue(nest.is_sequence({"a": 1, "b": 2}))
+ // self.assertFalse(nest.is_sequence(set([1, 2])))
+ // ones = array_ops.ones([2, 3])
+ // self.assertFalse(nest.is_sequence(ones))
+ // self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
+ // self.assertFalse(nest.is_sequence(np.ones((4, 5))))
+
+ // @parameterized.parameters({"mapping_type": _CustomMapping},
+ // {"mapping_type": dict})
+ // def testFlattenDictItems(self, mapping_type):
+ // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))})
+ // flat = {4: "a", 5: "b", 6: "c", 8: "d"}
+ // self.assertEqual(nest.flatten_dict_items(dictionary), flat)
+
+ // with self.assertRaises(TypeError):
+ // nest.flatten_dict_items(4)
+
+ // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))})
+ // with self.assertRaisesRegexp(ValueError, "not unique"):
+ // nest.flatten_dict_items(bad_dictionary)
+
+ // another_bad_dictionary = mapping_type({
+ // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
+ // })
+ // with self.assertRaisesRegexp(
+ // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
+ // nest.flatten_dict_items(another_bad_dictionary)
+
+ //# pylint does not correctly recognize these as class names and
+ //# suggests to use variable style under_score naming.
+ //# pylint: disable=invalid-name
+ // Named0ab = collections.namedtuple("named_0", ("a", "b"))
+ // Named1ab = collections.namedtuple("named_1", ("a", "b"))
+ // SameNameab = collections.namedtuple("same_name", ("a", "b"))
+ // SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
+ // SameNamexy = collections.namedtuple("same_name", ("x", "y"))
+ // SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
+ // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
+ // NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
+ // # pylint: enable=invalid-name
+
+ // class SameNamedType1(SameNameab):
+ // pass
+
+ // @test_util.assert_no_new_pyobjects_executing_eagerly
+ // def testAssertSameStructure(self):
+ // structure1 = (((1, 2), 3), 4, (5, 6))
+ // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+ // structure_different_num_elements = ("spam", "eggs")
+ // structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
+ // nest.assert_same_structure(structure1, structure2)
+ // nest.assert_same_structure("abc", 1.0)
+ // nest.assert_same_structure("abc", np.array([0, 1]))
+ // nest.assert_same_structure("abc", constant_op.constant([0, 1]))
+
+ // with self.assertRaisesRegexp(
+ // ValueError,
+ // ("The two structures don't have the same nested structure\\.\n\n"
+ // "First structure:.*?\n\n"
+ // "Second structure:.*\n\n"
+ // "More specifically: Substructure "
+ // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
+ // 'substructure "type=str str=spam" is not\n'
+ // "Entire first structure:\n"
+ // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
+ // "Entire second structure:\n"
+ // r"\(\., \.\)")):
+ // nest.assert_same_structure(structure1, structure_different_num_elements)
+
+ // with self.assertRaisesRegexp(
+ // ValueError,
+ // ("The two structures don't have the same nested structure\\.\n\n"
+ // "First structure:.*?\n\n"
+ // "Second structure:.*\n\n"
+ // r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
+ // "is not")):
+ // nest.assert_same_structure([0, 1], np.array([0, 1]))
+
+ // with self.assertRaisesRegexp(
+ // ValueError,
+ // ("The two structures don't have the same nested structure\\.\n\n"
+ // "First structure:.*?\n\n"
+ // "Second structure:.*\n\n"
+ // r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ // 'is a sequence, while substructure "type=int str=0" '
+ // "is not")):
+ // nest.assert_same_structure(0, [0, 1])
+
+ // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
+
+ // with self.assertRaisesRegexp(
+ // ValueError,
+ // ("don't have the same nested structure\\.\n\n"
+ // "First structure: .*?\n\nSecond structure: ")):
+ // nest.assert_same_structure(structure1, structure_different_nesting)
+
+ // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
+ // NestTest.Named0ab("a", "b"))
+
+ // nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ // NestTest.Named0ab("a", "b"))
+
+ // self.assertRaises(TypeError, nest.assert_same_structure,
+ // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
+
+ // with self.assertRaisesRegexp(
+ // ValueError,
+ // ("don't have the same nested structure\\.\n\n"
+ // "First structure: .*?\n\nSecond structure: ")):
+ // nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ // NestTest.Named0ab([3], 4))
+
+ // with self.assertRaisesRegexp(
+ // ValueError,
+ // ("don't have the same nested structure\\.\n\n"
+ // "First structure: .*?\n\nSecond structure: ")):
+ // nest.assert_same_structure([[3], 4], [3, [4]])
+
+ // structure1_list = [[[1, 2], 3], 4, [5, 6]]
+ // with self.assertRaisesRegexp(TypeError,
+ // "don't have the same sequence type"):
+ // nest.assert_same_structure(structure1, structure1_list)
+ // nest.assert_same_structure(structure1, structure2, check_types= False)
+ // nest.assert_same_structure(structure1, structure1_list, check_types=False)
+
+ // with self.assertRaisesRegexp(ValueError,
+ // "don't have the same set of keys"):
+ // nest.assert_same_structure({"a": 1}, {"b": 1})
+
+ // nest.assert_same_structure(NestTest.SameNameab(0, 1),
+ // NestTest.SameNameab2(2, 3))
+
+ // # This assertion is expected to pass: two namedtuples with the same
+ // # name and field names are considered to be identical.
+ // nest.assert_same_structure(
+ // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
+ // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
+
+ // expected_message = "The two structures don't have the same.*"
+ // with self.assertRaisesRegexp(ValueError, expected_message):
+ // nest.assert_same_structure(
+ // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
+ // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
+
+ // self.assertRaises(TypeError, nest.assert_same_structure,
+ // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
+
+ // self.assertRaises(TypeError, nest.assert_same_structure,
+ // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
+
+ // self.assertRaises(TypeError, nest.assert_same_structure,
+ // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
+
+ // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
+
+ // def testHeterogeneousComparison(self):
+ // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3))
+ // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
+
+ // @test_util.assert_no_new_pyobjects_executing_eagerly
+ // def testMapStructure(self) :
+ // structure1 = (((1, 2), 3), 4, (5, 6))
+ // structure2 = (((7, 8), 9), 10, (11, 12))
+ // structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
+ // nest.assert_same_structure(structure1, structure1_plus1)
+ // self.assertAllEqual(
+ // [2, 3, 4, 5, 6, 7],
+ // nest.flatten(structure1_plus1))
+ // structure1_plus_structure2 = nest.map_structure(
+ // lambda x, y: x + y, structure1, structure2)
+ // self.assertEqual(
+ // (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
+ // structure1_plus_structure2)
+
+ // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
+
+ // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
+
+ // # Empty structures
+ // self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
+ // self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
+ // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
+ // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
+ // NestTest.EmptyNT()))
+
+ // # This is checking actual equality of types, empty list != empty tuple
+ // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
+
+ // with self.assertRaisesRegexp(TypeError, "callable"):
+ // nest.map_structure("bad", structure1_plus1)
+
+ // with self.assertRaisesRegexp(ValueError, "at least one structure"):
+ // nest.map_structure(lambda x: x)
+
+ // with self.assertRaisesRegexp(ValueError, "same number of elements"):
+ // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
+
+ // with self.assertRaisesRegexp(ValueError, "same nested structure"):
+ // nest.map_structure(lambda x, y: None, 3, (3,))
+
+ // with self.assertRaisesRegexp(TypeError, "same sequence type"):
+ // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
+
+ // with self.assertRaisesRegexp(ValueError, "same nested structure"):
+ // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
+
+ // structure1_list = [[[1, 2], 3], 4, [5, 6]]
+ // with self.assertRaisesRegexp(TypeError, "same sequence type"):
+ // nest.map_structure(lambda x, y: None, structure1, structure1_list)
+
+ // nest.map_structure(lambda x, y: None, structure1, structure1_list,
+ // check_types=False)
+
+ // with self.assertRaisesRegexp(ValueError, "same nested structure"):
+ // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
+ // check_types=False)
+
+ // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
+ // nest.map_structure(lambda x: None, structure1, foo="a")
+
+ // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
+ // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
+
+ // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
+
+ // @test_util.assert_no_new_pyobjects_executing_eagerly
+ // def testMapStructureWithStrings(self) :
+ // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
+ // inp_b = NestTest.ABTuple(a=2, b=(1, 3))
+ // out = nest.map_structure(lambda string, repeats: string* repeats,
+ // inp_a,
+ // inp_b)
+ // self.assertEqual("foofoo", out.a)
+ // self.assertEqual("bar", out.b[0])
+ // self.assertEqual("bazbazbaz", out.b[1])
+
+ // nt = NestTest.ABTuple(a=("something", "something_else"),
+ // b="yet another thing")
+ // rev_nt = nest.map_structure(lambda x: x[::- 1], nt)
+ // # Check the output is the correct structure, and all strings are reversed.
+ // nest.assert_same_structure(nt, rev_nt)
+ // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0])
+ // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1])
+ // self.assertEqual(nt.b[::- 1], rev_nt.b)
+
+ // @test_util.run_deprecated_v1
+ // def testMapStructureOverPlaceholders(self) :
+ // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
+ // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
+ // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
+ // array_ops.placeholder(dtypes.float32, shape=[3, 7]))
+
+ // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
+
+ // nest.assert_same_structure(output, inp_a)
+ // self.assertShapeEqual(np.zeros((3, 4)), output[0])
+ // self.assertShapeEqual(np.zeros((3, 7)), output[1])
+
+ // feed_dict = {
+ // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
+ // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
+ // }
+
+ // with self.cached_session() as sess:
+ // output_np = sess.run(output, feed_dict=feed_dict)
+ // self.assertAllClose(output_np[0],
+ // feed_dict[inp_a][0] + feed_dict[inp_b][0])
+ // self.assertAllClose(output_np[1],
+ // feed_dict[inp_a][1] + feed_dict[inp_b][1])
+
+ // def testAssertShallowStructure(self):
+ // inp_ab = ["a", "b"]
+ //inp_abc = ["a", "b", "c"]
+ //expected_message = (
+ // "The two structures don't have the same sequence length. Input "
+ // "structure has length 2, while shallow structure has length 3.")
+ // with self.assertRaisesRegexp(ValueError, expected_message):
+ // nest.assert_shallow_structure(inp_abc, inp_ab)
+
+ // inp_ab1 = [(1, 1), (2, 2)]
+ // inp_ab2 = [[1, 1], [2, 2]]
+ // expected_message = (
+ // "The two structures don't have the same sequence type. Input structure "
+ // "has type <(type|class) 'tuple'>, while shallow structure has type "
+ // "<(type|class) 'list'>.")
+ // with self.assertRaisesRegexp(TypeError, expected_message):
+ // nest.assert_shallow_structure(inp_ab2, inp_ab1)
+ // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False)
+
+ // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
+ // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
+ // expected_message = (
+ // r"The two structures don't have the same keys. Input "
+ // r"structure has keys \['c'\], while shallow structure has "
+ // r"keys \['d'\].")
+
+ // with self.assertRaisesRegexp(ValueError, expected_message):
+ // nest.assert_shallow_structure(inp_ab2, inp_ab1)
+
+ // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
+ // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
+ // nest.assert_shallow_structure(inp_ab, inp_ba)
+
+ // # This assertion is expected to pass: two namedtuples with the same
+ //# name and field names are considered to be identical.
+ //inp_shallow = NestTest.SameNameab(1, 2)
+ // inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
+ // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
+ // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
+
+ // def testFlattenUpTo(self):
+ // # Shallow tree ends at scalar.
+ // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
+ // shallow_tree = [[True, True], [False, True]]
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
+ // self.assertEqual(flattened_shallow_tree, [True, True, False, True])
+
+ //# Shallow tree ends at string.
+ // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
+ // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
+ // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ // input_tree)
+ // input_tree_flattened = nest.flatten(input_tree)
+ // self.assertEqual(input_tree_flattened_as_shallow_tree,
+ // [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
+ // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
+
+ // # Make sure dicts are correctly flattened, yielding values, not keys.
+ //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
+ // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
+ // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ // input_tree)
+ // self.assertEqual(input_tree_flattened_as_shallow_tree,
+ // [1, { "c": 2}, 3, (4, 5)])
+
+ // # Namedtuples.
+ // ab_tuple = NestTest.ABTuple
+ // input_tree = ab_tuple(a =[0, 1], b = 2)
+ // shallow_tree = ab_tuple(a= 0, b= 1)
+ // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ // input_tree)
+ // self.assertEqual(input_tree_flattened_as_shallow_tree,
+ // [[0, 1], 2])
+
+ // # Nested dicts, OrderedDicts and namedtuples.
+ // input_tree = collections.OrderedDict(
+ // [("a", ab_tuple(a =[0, {"b": 1}], b=2)),
+ // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
+ // shallow_tree = input_tree
+ // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ // input_tree)
+ // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
+ // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
+ // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ // input_tree)
+ // self.assertEqual(input_tree_flattened_as_shallow_tree,
+ // [ab_tuple(a =[0, { "b": 1}], b=2),
+ // 3,
+ // collections.OrderedDict([("f", 4)])])
+ // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
+ // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ // input_tree)
+ // self.assertEqual(input_tree_flattened_as_shallow_tree,
+ // [ab_tuple(a =[0, {"b": 1}], b=2),
+ // {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
+
+ // ## Shallow non-list edge-case.
+ // # Using iterable elements.
+ // input_tree = ["input_tree"]
+ //shallow_tree = "shallow_tree"
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_input_tree, [input_tree])
+ // self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ // input_tree = ["input_tree_0", "input_tree_1"]
+ //shallow_tree = "shallow_tree"
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_input_tree, [input_tree])
+ // self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ // # Using non-iterable elements.
+ //input_tree = [0]
+ //shallow_tree = 9
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_input_tree, [input_tree])
+ // self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ // input_tree = [0, 1]
+ //shallow_tree = 9
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_input_tree, [input_tree])
+ // self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ // ## Both non-list edge-case.
+ //# Using iterable elements.
+ //input_tree = "input_tree"
+ // shallow_tree = "shallow_tree"
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_input_tree, [input_tree])
+ // self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ // # Using non-iterable elements.
+ //input_tree = 0
+ // shallow_tree = 0
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_input_tree, [input_tree])
+ // self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ // ## Input non-list edge-case.
+ //# Using iterable elements.
+ //input_tree = "input_tree"
+ // shallow_tree = ["shallow_tree"]
+ //expected_message = ("If shallow structure is a sequence, input must also "
+ // "be a sequence. Input has type: <(type|class) 'str'>.")
+ // with self.assertRaisesRegexp(TypeError, expected_message):
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ // input_tree = "input_tree"
+ // shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
+ //with self.assertRaisesRegexp(TypeError, expected_message):
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ //# Using non-iterable elements.
+ // input_tree = 0
+ // shallow_tree = [9]
+ //expected_message = ("If shallow structure is a sequence, input must also "
+ // "be a sequence. Input has type: <(type|class) 'int'>.")
+ // with self.assertRaisesRegexp(TypeError, expected_message):
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ // input_tree = 0
+ // shallow_tree = [9, 8]
+ //with self.assertRaisesRegexp(TypeError, expected_message):
+ // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ // self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ // def testMapStructureUpTo(self) :
+ // # Named tuples.
+ // ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+ // op_tuple = collections.namedtuple("op_tuple", "add, mul")
+ // inp_val = ab_tuple(a= 2, b= 3)
+ // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3))
+ // out = nest.map_structure_up_to(
+ // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
+ // self.assertEqual(out.a, 6)
+ // self.assertEqual(out.b, 15)
+
+ // # Lists.
+ // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
+ // name_list = ["evens", ["odds", "primes"]]
+ // out = nest.map_structure_up_to(
+ // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
+ // name_list, data_list)
+ // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
+
+ // # Dicts.
+ // inp_val = dict(a= 2, b= 3)
+ // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
+ // out = nest.map_structure_up_to(
+ // inp_val,
+ // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+ // self.assertEqual(out["a"], 6)
+ // self.assertEqual(out["b"], 15)
+
+ // # Non-equal dicts.
+ // inp_val = dict(a= 2, b= 3)
+ // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
+ // with self.assertRaisesRegexp(ValueError, "same keys"):
+ // nest.map_structure_up_to(
+ // inp_val,
+ // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+
+ // # Dict+custom mapping.
+ // inp_val = dict(a= 2, b= 3)
+ // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
+ // out = nest.map_structure_up_to(
+ // inp_val,
+ // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+ // self.assertEqual(out["a"], 6)
+ // self.assertEqual(out["b"], 15)
+
+ // # Non-equal dict/mapping.
+ // inp_val = dict(a= 2, b= 3)
+ // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
+ // with self.assertRaisesRegexp(ValueError, "same keys"):
+ // nest.map_structure_up_to(
+ // inp_val,
+ // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+
+ // def testGetTraverseShallowStructure(self):
+ // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
+ // scalar_traverse_r = nest.get_traverse_shallow_structure(
+ // lambda s: not isinstance(s, tuple),
+ // scalar_traverse_input)
+ // self.assertEqual(scalar_traverse_r,
+ // [True, True, False, [True, True], {"a": False}, []])
+ // nest.assert_shallow_structure(scalar_traverse_r,
+ // scalar_traverse_input)
+
+ // structure_traverse_input = [(1, [2]), ([1], 2)]
+ // structure_traverse_r = nest.get_traverse_shallow_structure(
+ // lambda s: (True, False) if isinstance(s, tuple) else True,
+ // structure_traverse_input)
+ // self.assertEqual(structure_traverse_r,
+ // [(True, False), ([True], False)])
+ // nest.assert_shallow_structure(structure_traverse_r,
+ // structure_traverse_input)
+
+ // with self.assertRaisesRegexp(TypeError, "returned structure"):
+ // nest.get_traverse_shallow_structure(lambda _: [True], 0)
+
+ // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
+ // nest.get_traverse_shallow_structure(lambda _: 1, [1])
+
+ // with self.assertRaisesRegexp(
+ // TypeError, "didn't return a depth=1 structure of bools"):
+ // nest.get_traverse_shallow_structure(lambda _: [1], [1])
+
+ // def testYieldFlatStringPaths(self):
+ // for inputs_expected in ({"inputs": [], "expected": []},
+ // {"inputs": 3, "expected": [()]},
+ // {"inputs": [3], "expected": [(0,)]},
+ // {"inputs": {"a": 3}, "expected": [("a",)]},
+ // {"inputs": {"a": {"b": 4}},
+ // "expected": [("a", "b")]},
+ // {"inputs": [{"a": 2}], "expected": [(0, "a")]},
+ // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
+ // {"inputs": [{"a": [(23, 42)]}],
+ // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
+ // {"inputs": [{"a": ([23], 42)}],
+ // "expected": [(0, "a", 0, 0), (0, "a", 1)]},
+ // {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
+ // "expected": [("a", "a"), ("c", 0, 0, 0)]},
+ // {"inputs": {"0": [{"1": 23}]},
+ // "expected": [("0", 0, "1")]}):
+ // inputs = inputs_expected["inputs"]
+ // expected = inputs_expected["expected"]
+ // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
+
+ // def testFlattenWithStringPaths(self):
+ // for inputs_expected in (
+ // {"inputs": [], "expected": []},
+ // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
+ // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
+ // inputs = inputs_expected["inputs"]
+ // expected = inputs_expected["expected"]
+ // self.assertEqual(
+ // nest.flatten_with_joined_string_paths(inputs, separator="/"),
+ // expected)
+
+ // # Need a separate test for namedtuple as we can't declare tuple definitions
+ // # in the @parameterized arguments.
+ // def testFlattenNamedTuple(self):
+ // # pylint: disable=invalid-name
+ // Foo = collections.namedtuple("Foo", ["a", "b"])
+ // Bar = collections.namedtuple("Bar", ["c", "d"])
+ // # pylint: enable=invalid-name
+ // test_cases = [
+ // (Foo(a = 3, b = Bar(c = 23, d = 42)),
+ // [("a", 3), ("b/c", 23), ("b/d", 42)]),
+ // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")),
+ // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
+ // (Bar(c = 42, d = 43),
+ // [("c", 42), ("d", 43)]),
+ // (Bar(c =[42], d = 43),
+ // [("c/0", 42), ("d", 43)]),
+ // ]
+ // for inputs, expected in test_cases:
+ // self.assertEqual(
+ // list(nest.flatten_with_joined_string_paths(inputs)), expected)
+
+ // @parameterized.named_parameters(
+ // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
+ // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
+ // {"a": ("a", 4), "b": ("b", 6)}),
+ // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
+ // ("nested",
+ // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
+ // {"a": [("a/0", 10), ("a/1", 12)],
+ // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
+ // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
+ // def format_sum(path, * values):
+ // return (path, sum(values))
+ // result = nest.map_structure_with_paths(format_sum, s1, s2,
+ // check_types=check_types)
+ // self.assertEqual(expected, result)
+
+ // @parameterized.named_parameters(
+ // ("tuples", (1, 2), (3, 4, 5), ValueError),
+ // ("dicts", {"a": 1}, {"b": 2}, ValueError),
+ // ("mixed", (1, 2), [3, 4], TypeError),
+ // ("nested",
+ // {"a": [2, 3], "b": [1, 3]},
+ // {"b": [5, 6, 7], "a": [8, 9]},
+ // ValueError
+ // ))
+ // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
+ // with self.assertRaises(error_type):
+ // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2)
+
+
+ //class NestBenchmark(test.Benchmark):
+
+ // def run_and_report(self, s1, s2, name):
+ // burn_iter, test_iter = 100, 30000
+
+ // for _ in xrange(burn_iter) :
+ // nest.assert_same_structure(s1, s2)
+
+ // t0 = time.time()
+ // for _ in xrange(test_iter) :
+ // nest.assert_same_structure(s1, s2)
+ // t1 = time.time()
+
+ // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
+ // name=name)
+
+ // def benchmark_assert_structure(self):
+ // s1 = (((1, 2), 3), 4, (5, 6))
+ // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+ // self.run_and_report(s1, s2, "assert_same_structure_6_elem")
+
+ // s1 = (((1, 2), 3), 4, (5, 6)) * 10
+ // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
+ // self.run_and_report(s1, s2, "assert_same_structure_60_elem")
+
+
+ //if __name__ == "__main__":
+ // test.main()
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/nest_test/nest_test.py b/test/TensorFlowNET.UnitTest/nest_test/nest_test.py
new file mode 100644
index 00000000..d0d0c5f7
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/nest_test/nest_test.py
@@ -0,0 +1,883 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utilities working with arbitrarily nested structures."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import time
+
+from absl.testing import parameterized
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import nest
+
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
+
+class _CustomMapping(collections.Mapping):
+
+ def __init__(self, *args, **kwargs):
+ self._wrapped = dict(*args, **kwargs)
+
+ def __getitem__(self, key):
+ return self._wrapped[key]
+
+ def __iter__(self):
+ return iter(self._wrapped)
+
+ def __len__(self):
+ return len(self._wrapped)
+
+
+class NestTest(parameterized.TestCase, test.TestCase):
+
+ PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
+
+ if attr:
+ class BadAttr(object):
+ """Class that has a non-iterable __attrs_attrs__."""
+ __attrs_attrs__ = None
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testAttrsFlattenAndPack(self):
+ if attr is None:
+ self.skipTest("attr module is unavailable.")
+
+ field_values = [1, 2]
+ sample_attr = NestTest.SampleAttr(*field_values)
+ self.assertFalse(nest._is_attrs(field_values))
+ self.assertTrue(nest._is_attrs(sample_attr))
+ flat = nest.flatten(sample_attr)
+ self.assertEqual(field_values, flat)
+ restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
+ self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
+ self.assertEqual(restructured_from_flat, sample_attr)
+
+ # Check that flatten fails if attributes are not iterable
+ with self.assertRaisesRegexp(TypeError, "object is not iterable"):
+ flat = nest.flatten(NestTest.BadAttr())
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testFlattenAndPack(self):
+ structure = ((3, 4), 5, (6, 7, (9, 10), 8))
+ flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
+ self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
+ self.assertEqual(
+ nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
+ ("d", "e", ("f", "g"), "h")))
+ structure = (NestTest.PointXY(x=4, y=2),
+ ((NestTest.PointXY(x=1, y=0),),))
+ flat = [4, 2, 1, 0]
+ self.assertEqual(nest.flatten(structure), flat)
+ restructured_from_flat = nest.pack_sequence_as(structure, flat)
+ self.assertEqual(restructured_from_flat, structure)
+ self.assertEqual(restructured_from_flat[0].x, 4)
+ self.assertEqual(restructured_from_flat[0].y, 2)
+ self.assertEqual(restructured_from_flat[1][0][0].x, 1)
+ self.assertEqual(restructured_from_flat[1][0][0].y, 0)
+
+ self.assertEqual([5], nest.flatten(5))
+ self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
+
+ self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
+ self.assertEqual(
+ np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
+
+ with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
+ nest.pack_sequence_as("scalar", [4, 5])
+
+ with self.assertRaisesRegexp(TypeError, "flat_sequence"):
+ nest.pack_sequence_as([4, 5], "bad_sequence")
+
+ with self.assertRaises(ValueError):
+ nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
+
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testFlattenDictOrder(self, mapping_type):
+ """`flatten` orders dicts by key, including OrderedDicts."""
+ ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
+ plain = {"d": 3, "b": 1, "a": 0, "c": 2}
+ ordered_flat = nest.flatten(ordered)
+ plain_flat = nest.flatten(plain)
+ self.assertEqual([0, 1, 2, 3], ordered_flat)
+ self.assertEqual([0, 1, 2, 3], plain_flat)
+
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
+ def testPackDictOrder(self, mapping_type):
+ """Packing orders dicts by key, including OrderedDicts."""
+ custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
+ plain = {"d": 0, "b": 0, "a": 0, "c": 0}
+ seq = [0, 1, 2, 3]
+ custom_reconstruction = nest.pack_sequence_as(custom, seq)
+ plain_reconstruction = nest.pack_sequence_as(plain, seq)
+ self.assertIsInstance(custom_reconstruction, mapping_type)
+ self.assertIsInstance(plain_reconstruction, dict)
+ self.assertEqual(
+ mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
+ custom_reconstruction)
+ self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
+
+ Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testFlattenAndPack_withDicts(self):
+ # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
+ mess = [
+ "z",
+ NestTest.Abc(3, 4), {
+ "d": _CustomMapping({
+ 41: 4
+ }),
+ "c": [
+ 1,
+ collections.OrderedDict([
+ ("b", 3),
+ ("a", 2),
+ ]),
+ ],
+ "b": 5
+ }, 17
+ ]
+
+ flattened = nest.flatten(mess)
+ self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
+
+ structure_of_mess = [
+ 14,
+ NestTest.Abc("a", True),
+ {
+ "d": _CustomMapping({
+ 41: 42
+ }),
+ "c": [
+ 0,
+ collections.OrderedDict([
+ ("b", 9),
+ ("a", 8),
+ ]),
+ ],
+ "b": 3
+ },
+ "hi everybody",
+ ]
+
+ unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
+ self.assertEqual(unflattened, mess)
+
+ # Check also that the OrderedDict was created, with the correct key order.
+ unflattened_ordered_dict = unflattened[2]["c"][1]
+ self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
+ self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
+
+ unflattened_custom_mapping = unflattened[2]["d"]
+ self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
+ self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
+
+ def testFlatten_numpyIsNotFlattened(self):
+ structure = np.array([1, 2, 3])
+ flattened = nest.flatten(structure)
+ self.assertEqual(len(flattened), 1)
+
+ def testFlatten_stringIsNotFlattened(self):
+ structure = "lots of letters"
+ flattened = nest.flatten(structure)
+ self.assertEqual(len(flattened), 1)
+ unflattened = nest.pack_sequence_as("goodbye", flattened)
+ self.assertEqual(structure, unflattened)
+
+ def testPackSequenceAs_notIterableError(self):
+ with self.assertRaisesRegexp(TypeError,
+ "flat_sequence must be a sequence"):
+ nest.pack_sequence_as("hi", "bye")
+
+ def testPackSequenceAs_wrongLengthsError(self):
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Structure had 2 elements, but flat_sequence had 3 elements."):
+ nest.pack_sequence_as(["hello", "world"],
+ ["and", "goodbye", "again"])
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testIsSequence(self):
+ self.assertFalse(nest.is_sequence("1234"))
+ self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
+ self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
+ self.assertTrue(nest.is_sequence([]))
+ self.assertTrue(nest.is_sequence({"a": 1, "b": 2}))
+ self.assertFalse(nest.is_sequence(set([1, 2])))
+ ones = array_ops.ones([2, 3])
+ self.assertFalse(nest.is_sequence(ones))
+ self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
+ self.assertFalse(nest.is_sequence(np.ones((4, 5))))
+
+ @parameterized.parameters({"mapping_type": _CustomMapping},
+ {"mapping_type": dict})
+ def testFlattenDictItems(self, mapping_type):
+ dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))})
+ flat = {4: "a", 5: "b", 6: "c", 8: "d"}
+ self.assertEqual(nest.flatten_dict_items(dictionary), flat)
+
+ with self.assertRaises(TypeError):
+ nest.flatten_dict_items(4)
+
+ bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
+ with self.assertRaisesRegexp(ValueError, "not unique"):
+ nest.flatten_dict_items(bad_dictionary)
+
+ another_bad_dictionary = mapping_type({
+ (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
+ })
+ with self.assertRaisesRegexp(
+ ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
+ nest.flatten_dict_items(another_bad_dictionary)
+
+ # pylint does not correctly recognize these as class names and
+ # suggests to use variable style under_score naming.
+ # pylint: disable=invalid-name
+ Named0ab = collections.namedtuple("named_0", ("a", "b"))
+ Named1ab = collections.namedtuple("named_1", ("a", "b"))
+ SameNameab = collections.namedtuple("same_name", ("a", "b"))
+ SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
+ SameNamexy = collections.namedtuple("same_name", ("x", "y"))
+ SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
+ SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
+ NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
+ # pylint: enable=invalid-name
+
+ class SameNamedType1(SameNameab):
+ pass
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testAssertSameStructure(self):
+ structure1 = (((1, 2), 3), 4, (5, 6))
+ structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+ structure_different_num_elements = ("spam", "eggs")
+ structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
+ nest.assert_same_structure(structure1, structure2)
+ nest.assert_same_structure("abc", 1.0)
+ nest.assert_same_structure("abc", np.array([0, 1]))
+ nest.assert_same_structure("abc", constant_op.constant([0, 1]))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ "More specifically: Substructure "
+ r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
+ 'substructure "type=str str=spam" is not\n'
+ "Entire first structure:\n"
+ r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
+ "Entire second structure:\n"
+ r"\(\., \.\)")):
+ nest.assert_same_structure(structure1, structure_different_num_elements)
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
+ "is not")):
+ nest.assert_same_structure([0, 1], np.array([0, 1]))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ ("The two structures don't have the same nested structure\\.\n\n"
+ "First structure:.*?\n\n"
+ "Second structure:.*\n\n"
+ r'More specifically: Substructure "type=list str=\[0, 1\]" '
+ 'is a sequence, while substructure "type=int str=0" '
+ "is not")):
+ nest.assert_same_structure(0, [0, 1])
+
+ self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ ("don't have the same nested structure\\.\n\n"
+ "First structure: .*?\n\nSecond structure: ")):
+ nest.assert_same_structure(structure1, structure_different_nesting)
+
+ self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
+ NestTest.Named0ab("a", "b"))
+
+ nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ NestTest.Named0ab("a", "b"))
+
+ self.assertRaises(TypeError, nest.assert_same_structure,
+ NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ ("don't have the same nested structure\\.\n\n"
+ "First structure: .*?\n\nSecond structure: ")):
+ nest.assert_same_structure(NestTest.Named0ab(3, 4),
+ NestTest.Named0ab([3], 4))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ ("don't have the same nested structure\\.\n\n"
+ "First structure: .*?\n\nSecond structure: ")):
+ nest.assert_same_structure([[3], 4], [3, [4]])
+
+ structure1_list = [[[1, 2], 3], 4, [5, 6]]
+ with self.assertRaisesRegexp(TypeError,
+ "don't have the same sequence type"):
+ nest.assert_same_structure(structure1, structure1_list)
+ nest.assert_same_structure(structure1, structure2, check_types=False)
+ nest.assert_same_structure(structure1, structure1_list, check_types=False)
+
+ with self.assertRaisesRegexp(ValueError,
+ "don't have the same set of keys"):
+ nest.assert_same_structure({"a": 1}, {"b": 1})
+
+ nest.assert_same_structure(NestTest.SameNameab(0, 1),
+ NestTest.SameNameab2(2, 3))
+
+ # This assertion is expected to pass: two namedtuples with the same
+ # name and field names are considered to be identical.
+ nest.assert_same_structure(
+ NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
+ NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
+
+ expected_message = "The two structures don't have the same.*"
+ with self.assertRaisesRegexp(ValueError, expected_message):
+ nest.assert_same_structure(
+ NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
+ NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
+
+ self.assertRaises(TypeError, nest.assert_same_structure,
+ NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
+
+ self.assertRaises(TypeError, nest.assert_same_structure,
+ NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
+
+ self.assertRaises(TypeError, nest.assert_same_structure,
+ NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
+
+ EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
+
+ def testHeterogeneousComparison(self):
+ nest.assert_same_structure({"a": 4}, _CustomMapping(a=3))
+ nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testMapStructure(self):
+ structure1 = (((1, 2), 3), 4, (5, 6))
+ structure2 = (((7, 8), 9), 10, (11, 12))
+ structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
+ nest.assert_same_structure(structure1, structure1_plus1)
+ self.assertAllEqual(
+ [2, 3, 4, 5, 6, 7],
+ nest.flatten(structure1_plus1))
+ structure1_plus_structure2 = nest.map_structure(
+ lambda x, y: x + y, structure1, structure2)
+ self.assertEqual(
+ (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
+ structure1_plus_structure2)
+
+ self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
+
+ self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
+
+ # Empty structures
+ self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
+ self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
+ self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
+ self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
+ NestTest.EmptyNT()))
+
+ # This is checking actual equality of types, empty list != empty tuple
+ self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
+
+ with self.assertRaisesRegexp(TypeError, "callable"):
+ nest.map_structure("bad", structure1_plus1)
+
+ with self.assertRaisesRegexp(ValueError, "at least one structure"):
+ nest.map_structure(lambda x: x)
+
+ with self.assertRaisesRegexp(ValueError, "same number of elements"):
+ nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
+
+ with self.assertRaisesRegexp(ValueError, "same nested structure"):
+ nest.map_structure(lambda x, y: None, 3, (3,))
+
+ with self.assertRaisesRegexp(TypeError, "same sequence type"):
+ nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
+
+ with self.assertRaisesRegexp(ValueError, "same nested structure"):
+ nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
+
+ structure1_list = [[[1, 2], 3], 4, [5, 6]]
+ with self.assertRaisesRegexp(TypeError, "same sequence type"):
+ nest.map_structure(lambda x, y: None, structure1, structure1_list)
+
+ nest.map_structure(lambda x, y: None, structure1, structure1_list,
+ check_types=False)
+
+ with self.assertRaisesRegexp(ValueError, "same nested structure"):
+ nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
+ check_types=False)
+
+ with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
+ nest.map_structure(lambda x: None, structure1, foo="a")
+
+ with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
+ nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
+
+ ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testMapStructureWithStrings(self):
+ inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
+ inp_b = NestTest.ABTuple(a=2, b=(1, 3))
+ out = nest.map_structure(lambda string, repeats: string * repeats,
+ inp_a,
+ inp_b)
+ self.assertEqual("foofoo", out.a)
+ self.assertEqual("bar", out.b[0])
+ self.assertEqual("bazbazbaz", out.b[1])
+
+ nt = NestTest.ABTuple(a=("something", "something_else"),
+ b="yet another thing")
+ rev_nt = nest.map_structure(lambda x: x[::-1], nt)
+ # Check the output is the correct structure, and all strings are reversed.
+ nest.assert_same_structure(nt, rev_nt)
+ self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
+ self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
+ self.assertEqual(nt.b[::-1], rev_nt.b)
+
+ @test_util.run_deprecated_v1
+ def testMapStructureOverPlaceholders(self):
+ inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
+ array_ops.placeholder(dtypes.float32, shape=[3, 7]))
+ inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
+ array_ops.placeholder(dtypes.float32, shape=[3, 7]))
+
+ output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
+
+ nest.assert_same_structure(output, inp_a)
+ self.assertShapeEqual(np.zeros((3, 4)), output[0])
+ self.assertShapeEqual(np.zeros((3, 7)), output[1])
+
+ feed_dict = {
+ inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
+ inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
+ }
+
+ with self.cached_session() as sess:
+ output_np = sess.run(output, feed_dict=feed_dict)
+ self.assertAllClose(output_np[0],
+ feed_dict[inp_a][0] + feed_dict[inp_b][0])
+ self.assertAllClose(output_np[1],
+ feed_dict[inp_a][1] + feed_dict[inp_b][1])
+
+ def testAssertShallowStructure(self):
+ inp_ab = ["a", "b"]
+ inp_abc = ["a", "b", "c"]
+ expected_message = (
+ "The two structures don't have the same sequence length. Input "
+ "structure has length 2, while shallow structure has length 3.")
+ with self.assertRaisesRegexp(ValueError, expected_message):
+ nest.assert_shallow_structure(inp_abc, inp_ab)
+
+ inp_ab1 = [(1, 1), (2, 2)]
+ inp_ab2 = [[1, 1], [2, 2]]
+ expected_message = (
+ "The two structures don't have the same sequence type. Input structure "
+ "has type <(type|class) 'tuple'>, while shallow structure has type "
+ "<(type|class) 'list'>.")
+ with self.assertRaisesRegexp(TypeError, expected_message):
+ nest.assert_shallow_structure(inp_ab2, inp_ab1)
+ nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)
+
+ inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
+ inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
+ expected_message = (
+ r"The two structures don't have the same keys. Input "
+ r"structure has keys \['c'\], while shallow structure has "
+ r"keys \['d'\].")
+
+ with self.assertRaisesRegexp(ValueError, expected_message):
+ nest.assert_shallow_structure(inp_ab2, inp_ab1)
+
+ inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
+ inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
+ nest.assert_shallow_structure(inp_ab, inp_ba)
+
+ # This assertion is expected to pass: two namedtuples with the same
+ # name and field names are considered to be identical.
+ inp_shallow = NestTest.SameNameab(1, 2)
+ inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
+ nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
+ nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
+
+ def testFlattenUpTo(self):
+ # Shallow tree ends at scalar.
+ input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
+ shallow_tree = [[True, True], [False, True]]
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
+ self.assertEqual(flattened_shallow_tree, [True, True, False, True])
+
+ # Shallow tree ends at string.
+ input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
+ shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
+ input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ input_tree)
+ input_tree_flattened = nest.flatten(input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
+ self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
+
+ # Make sure dicts are correctly flattened, yielding values, not keys.
+ input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
+ shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
+ input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [1, {"c": 2}, 3, (4, 5)])
+
+ # Namedtuples.
+ ab_tuple = NestTest.ABTuple
+ input_tree = ab_tuple(a=[0, 1], b=2)
+ shallow_tree = ab_tuple(a=0, b=1)
+ input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [[0, 1], 2])
+
+ # Nested dicts, OrderedDicts and namedtuples.
+ input_tree = collections.OrderedDict(
+ [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
+ ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
+ shallow_tree = input_tree
+ input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
+ shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
+ input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [ab_tuple(a=[0, {"b": 1}], b=2),
+ 3,
+ collections.OrderedDict([("f", 4)])])
+ shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
+ input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [ab_tuple(a=[0, {"b": 1}], b=2),
+ {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
+
+ ## Shallow non-list edge-case.
+ # Using iterable elements.
+ input_tree = ["input_tree"]
+ shallow_tree = "shallow_tree"
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ input_tree = ["input_tree_0", "input_tree_1"]
+ shallow_tree = "shallow_tree"
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ # Using non-iterable elements.
+ input_tree = [0]
+ shallow_tree = 9
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ input_tree = [0, 1]
+ shallow_tree = 9
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ ## Both non-list edge-case.
+ # Using iterable elements.
+ input_tree = "input_tree"
+ shallow_tree = "shallow_tree"
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ # Using non-iterable elements.
+ input_tree = 0
+ shallow_tree = 0
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ ## Input non-list edge-case.
+ # Using iterable elements.
+ input_tree = "input_tree"
+ shallow_tree = ["shallow_tree"]
+ expected_message = ("If shallow structure is a sequence, input must also "
+ "be a sequence. Input has type: <(type|class) 'str'>.")
+ with self.assertRaisesRegexp(TypeError, expected_message):
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ input_tree = "input_tree"
+ shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
+ with self.assertRaisesRegexp(TypeError, expected_message):
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ # Using non-iterable elements.
+ input_tree = 0
+ shallow_tree = [9]
+ expected_message = ("If shallow structure is a sequence, input must also "
+ "be a sequence. Input has type: <(type|class) 'int'>.")
+ with self.assertRaisesRegexp(TypeError, expected_message):
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ input_tree = 0
+ shallow_tree = [9, 8]
+ with self.assertRaisesRegexp(TypeError, expected_message):
+ flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
+ flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ def testMapStructureUpTo(self):
+ # Named tuples.
+ ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+ op_tuple = collections.namedtuple("op_tuple", "add, mul")
+ inp_val = ab_tuple(a=2, b=3)
+ inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
+ out = nest.map_structure_up_to(
+ inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
+ self.assertEqual(out.a, 6)
+ self.assertEqual(out.b, 15)
+
+ # Lists.
+ data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
+ name_list = ["evens", ["odds", "primes"]]
+ out = nest.map_structure_up_to(
+ name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
+ name_list, data_list)
+ self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
+
+ # Dicts.
+ inp_val = dict(a=2, b=3)
+ inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
+ out = nest.map_structure_up_to(
+ inp_val,
+ lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+ self.assertEqual(out["a"], 6)
+ self.assertEqual(out["b"], 15)
+
+ # Non-equal dicts.
+ inp_val = dict(a=2, b=3)
+ inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
+ with self.assertRaisesRegexp(ValueError, "same keys"):
+ nest.map_structure_up_to(
+ inp_val,
+ lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+
+ # Dict+custom mapping.
+ inp_val = dict(a=2, b=3)
+ inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
+ out = nest.map_structure_up_to(
+ inp_val,
+ lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+ self.assertEqual(out["a"], 6)
+ self.assertEqual(out["b"], 15)
+
+ # Non-equal dict/mapping.
+ inp_val = dict(a=2, b=3)
+ inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
+ with self.assertRaisesRegexp(ValueError, "same keys"):
+ nest.map_structure_up_to(
+ inp_val,
+ lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
+
+ def testGetTraverseShallowStructure(self):
+ scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
+ scalar_traverse_r = nest.get_traverse_shallow_structure(
+ lambda s: not isinstance(s, tuple),
+ scalar_traverse_input)
+ self.assertEqual(scalar_traverse_r,
+ [True, True, False, [True, True], {"a": False}, []])
+ nest.assert_shallow_structure(scalar_traverse_r,
+ scalar_traverse_input)
+
+ structure_traverse_input = [(1, [2]), ([1], 2)]
+ structure_traverse_r = nest.get_traverse_shallow_structure(
+ lambda s: (True, False) if isinstance(s, tuple) else True,
+ structure_traverse_input)
+ self.assertEqual(structure_traverse_r,
+ [(True, False), ([True], False)])
+ nest.assert_shallow_structure(structure_traverse_r,
+ structure_traverse_input)
+
+ with self.assertRaisesRegexp(TypeError, "returned structure"):
+ nest.get_traverse_shallow_structure(lambda _: [True], 0)
+
+ with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
+ nest.get_traverse_shallow_structure(lambda _: 1, [1])
+
+ with self.assertRaisesRegexp(
+ TypeError, "didn't return a depth=1 structure of bools"):
+ nest.get_traverse_shallow_structure(lambda _: [1], [1])
+
+ def testYieldFlatStringPaths(self):
+ for inputs_expected in ({"inputs": [], "expected": []},
+ {"inputs": 3, "expected": [()]},
+ {"inputs": [3], "expected": [(0,)]},
+ {"inputs": {"a": 3}, "expected": [("a",)]},
+ {"inputs": {"a": {"b": 4}},
+ "expected": [("a", "b")]},
+ {"inputs": [{"a": 2}], "expected": [(0, "a")]},
+ {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
+ {"inputs": [{"a": [(23, 42)]}],
+ "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
+ {"inputs": [{"a": ([23], 42)}],
+ "expected": [(0, "a", 0, 0), (0, "a", 1)]},
+ {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
+ "expected": [("a", "a"), ("c", 0, 0, 0)]},
+ {"inputs": {"0": [{"1": 23}]},
+ "expected": [("0", 0, "1")]}):
+ inputs = inputs_expected["inputs"]
+ expected = inputs_expected["expected"]
+ self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
+
+ def testFlattenWithStringPaths(self):
+ for inputs_expected in (
+ {"inputs": [], "expected": []},
+ {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
+ {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
+ inputs = inputs_expected["inputs"]
+ expected = inputs_expected["expected"]
+ self.assertEqual(
+ nest.flatten_with_joined_string_paths(inputs, separator="/"),
+ expected)
+
+ # Need a separate test for namedtuple as we can't declare tuple definitions
+ # in the @parameterized arguments.
+ def testFlattenNamedTuple(self):
+ # pylint: disable=invalid-name
+ Foo = collections.namedtuple("Foo", ["a", "b"])
+ Bar = collections.namedtuple("Bar", ["c", "d"])
+ # pylint: enable=invalid-name
+ test_cases = [
+ (Foo(a=3, b=Bar(c=23, d=42)),
+ [("a", 3), ("b/c", 23), ("b/d", 42)]),
+ (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
+ [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
+ (Bar(c=42, d=43),
+ [("c", 42), ("d", 43)]),
+ (Bar(c=[42], d=43),
+ [("c/0", 42), ("d", 43)]),
+ ]
+ for inputs, expected in test_cases:
+ self.assertEqual(
+ list(nest.flatten_with_joined_string_paths(inputs)), expected)
+
+ @parameterized.named_parameters(
+ ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
+ ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
+ {"a": ("a", 4), "b": ("b", 6)}),
+ ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
+ ("nested",
+ {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
+ {"a": [("a/0", 10), ("a/1", 12)],
+ "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
+ def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
+ def format_sum(path, *values):
+ return (path, sum(values))
+ result = nest.map_structure_with_paths(format_sum, s1, s2,
+ check_types=check_types)
+ self.assertEqual(expected, result)
+
+ @parameterized.named_parameters(
+ ("tuples", (1, 2), (3, 4, 5), ValueError),
+ ("dicts", {"a": 1}, {"b": 2}, ValueError),
+ ("mixed", (1, 2), [3, 4], TypeError),
+ ("nested",
+ {"a": [2, 3], "b": [1, 3]},
+ {"b": [5, 6, 7], "a": [8, 9]},
+ ValueError
+ ))
+ def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
+ with self.assertRaises(error_type):
+ nest.map_structure_with_paths(lambda path, *s: 0, s1, s2)
+
+
+class NestBenchmark(test.Benchmark):
+
+ def run_and_report(self, s1, s2, name):
+ burn_iter, test_iter = 100, 30000
+
+ for _ in xrange(burn_iter):
+ nest.assert_same_structure(s1, s2)
+
+ t0 = time.time()
+ for _ in xrange(test_iter):
+ nest.assert_same_structure(s1, s2)
+ t1 = time.time()
+
+ self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
+ name=name)
+
+ def benchmark_assert_structure(self):
+ s1 = (((1, 2), 3), 4, (5, 6))
+ s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
+ self.run_and_report(s1, s2, "assert_same_structure_6_elem")
+
+ s1 = (((1, 2), 3), 4, (5, 6)) * 10
+ s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
+ self.run_and_report(s1, s2, "assert_same_structure_60_elem")
+
+
+if __name__ == "__main__":
+ test.main()