diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs new file mode 100644 index 00000000..2d43b976 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -0,0 +1,871 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using NumSharp; + +namespace Tensorflow.Util +{ + //Functions for working with arbitrarily nested sequences of elements. + + //This module can perform operations on nested structures. A nested structure is a + //Python sequence, tuple (including `namedtuple`), or dict that can contain + //further sequences, tuples, and dicts. + + //The utilities here assume (and do not check) that the nested structures form a + //'tree', i.e., no references in the structure of the input of these functions + //should be recursive. + + //Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), + // (np.array([3, 4]), tf.constant([3, 4])))` + // + + public static class nest + { + + + //def _get_attrs_values(obj): + // """Returns the list of values from an attrs instance.""" + // attrs = getattr(obj.__class__, "__attrs_attrs__") + // return [getattr(obj, a.name) for a in attrs] + + /// + /// Returns a sorted list of the dict keys, with error if keys not sortable. + /// + private static IEnumerable _sorted(IDictionary dict_) + { + return dict_.Keys.OfType().OrderBy(x => x); + } + + + //def _is_namedtuple(instance, strict=False): + // """Returns True iff `instance` is a `namedtuple`. + + // Args: + // instance: An instance of a Python object. + // strict: If True, `instance` is considered to be a `namedtuple` only if + // it is a "plain" namedtuple. For instance, a class inheriting + // from a `namedtuple` will be considered to be a `namedtuple` + // iff `strict=False`. + + // Returns: + // True if `instance` is a `namedtuple`. + // """ + // return _pywrap_tensorflow.IsNamedtuple(instance, strict) + + + //# See the swig file (util.i) for documentation. + //_is_mapping = _pywrap_tensorflow.IsMapping + //_is_attrs = _pywrap_tensorflow.IsAttrs + + /// + /// Converts the sequence `args` to the same type as `instance`. + /// + /// an instance of `tuple`, `list`, `namedtuple`, `dict`, or + /// `collections.OrderedDict`. + /// elements to be converted to the `instance` type. + /// `args` with the type of `instance`. + private static object _sequence_like(object instance, IEnumerable args) + { + if (is_mapping(instance)) + { + //# Pack dictionaries in a deterministic order by sorting the keys. + //# Notice this means that we ignore the original order of `OrderedDict` + //# instances. This is intentional, to avoid potential bugs caused by mixing + //# ordered and plain dicts (e.g., flattening a dict but using a + //# corresponding `OrderedDict` to pack it back). + // result = dict(zip(_sorted(instance), args)) + // return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) + } + //else if( _is_namedtuple(instance) || _is_attrs(instance)) + // return type(instance)(*args) + else + { + // Not a namedtuple + switch (instance) + { + case object[] array: + var result_array = new object[args.Count()]; + int i = 0; + foreach (var x in args) + { + result_array[i] = x; + i++; + } + return result_array; + case List list: + return new List(args); + default: + throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); + } + } + throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); + } + + /// + /// Yields the next value from the given iterable. + /// + private static IEnumerable _yield_value(object iterable) + { + if (is_mapping(iterable)) + { + var dict = iterable as IDictionary; + //# Iterate through dictionaries in a deterministic order by sorting the + //# keys. Notice this means that we ignore the original order of `OrderedDict` + //# instances. This is intentional, to avoid potential bugs caused by mixing + //# ordered and plain dicts (e.g., flattening a dict but using a + //# corresponding `OrderedDict` to pack it back). + foreach (var key in _sorted(dict)) + yield return dict[key]; + } + //else if (_is_attrs(iterable)) + //{ + // // for value in _get_attrs_values(iterable): + // // yield value + //} + else if (iterable is IEnumerable) + { + var enumerable = iterable as IEnumerable; + foreach (var value in enumerable) + yield return value; + } + else + { + throw new TypeError("Unexpected iterable type: " + iterable.GetType()); + //var jobj = JObject.FromObject(iterable); + //foreach (var key in _sorted()) + // yield return jobj[key]; + } + } + + //# See the swig file (util.i) for documentation. + public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string); + + public static bool is_mapping(object arg) => arg is IDictionary; + + //# See the swig file (util.i) for documentation. + //flatten = _pywrap_tensorflow.Flatten + + public static List flatten(object structure) + { + var list = new List(); + _flatten_recursive(structure, list); + return list; + } + + private static void _flatten_recursive(object obj, List list) + { + if (obj is string) + { + list.Add(obj); + return; + } + if (obj is IDictionary) + { + var dict = obj as IDictionary; + foreach (var key in _sorted(dict)) + _flatten_recursive(dict[key], list); + return; + } + if (obj is NDArray) + { + list.Add(obj); + return; + } + if (obj is IEnumerable) + { + var structure = obj as IEnumerable; + foreach (var child in structure) + _flatten_recursive(child, list); + return; + } + list.Add(obj); + } + + + //# See the swig file (util.i) for documentation. + //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples + + + //class _DotString(object): + + // def __str__(self): + // return "." + + // def __repr__(self): + // return "." + + + //_DOT = _DotString() + + + //def assert_same_structure(nest1, nest2, check_types=True): + // """Asserts that two structures are nested in the same way. + + // Note that namedtuples with identical name and fields are always considered + // to have the same shallow structure (even with `check_types=True`). + // For intance, this code will print `True`: + + // ```python + // def nt(a, b): + // return collections.namedtuple('foo', 'a b')(a, b) + // print(assert_same_structure(nt(0, 1), nt(2, 3))) + // ``` + + // Args: + // nest1: an arbitrarily nested structure. + // nest2: an arbitrarily nested structure. + // check_types: if `True` (default) types of sequences are checked as well, + // including the keys of dictionaries. If set to `False`, for example a + // list and a tuple of objects will look the same if they have the same + // size. Note that namedtuples with identical name and fields are always + // considered to have the same shallow structure. Two types will also be + // considered the same if they are both list subtypes (which allows "list" + // and "_ListWrapper" from checkpointable dependency tracking to compare + // equal). + + // Raises: + // ValueError: If the two structures do not have the same number of elements or + // if the two structures are not nested in the same way. + // TypeError: If the two structures differ in the type of sequence in any of + // their substructures. Only possible if `check_types` is `True`. + // """ + // try: + // _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types) + // except (ValueError, TypeError) as e: + // str1 = str(map_structure(lambda _: _DOT, nest1)) + // str2 = str(map_structure(lambda _: _DOT, nest2)) + // raise type(e)("%s\n" + // "Entire first structure:\n%s\n" + // "Entire second structure:\n%s" + // % (str(e), str1, str2)) + + + //def flatten_dict_items(dictionary): + // """Returns a dictionary with flattened keys and values. + + // This function flattens the keys and values of a dictionary, which can be + // arbitrarily nested structures, and returns the flattened version of such + // structures: + + // ```python + // example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} + // result = {4: "a", 5: "b", 6: "c", 8: "d"} + // flatten_dict_items(example_dictionary) == result + // ``` + + // The input dictionary must satisfy two properties: + + // 1. Its keys and values should have the same exact nested structure. + // 2. The set of all flattened keys of the dictionary must not contain repeated + // keys. + + // Args: + // dictionary: the dictionary to zip + + // Returns: + // The zipped dictionary. + + // Raises: + // TypeError: If the input is not a dictionary. + // ValueError: If any key and value have not the same structure, or if keys are + // not unique. + // """ + // if not isinstance(dictionary, (dict, _collections.Mapping)): + // raise TypeError("input must be a dictionary") + // flat_dictionary = {} + // for i, v in _six.iteritems(dictionary): + // if not is_sequence(i): + // if i in flat_dictionary: + // raise ValueError( + // "Could not flatten dictionary: key %s is not unique." % i) + // flat_dictionary[i] = v + // else: + // flat_i = flatten(i) + // flat_v = flatten(v) + // if len(flat_i) != len(flat_v): + // raise ValueError( + // "Could not flatten dictionary. Key had %d elements, but value had " + // "%d elements. Key: %s, value: %s." + // % (len(flat_i), len(flat_v), flat_i, flat_v)) + // for new_i, new_v in zip(flat_i, flat_v): + // if new_i in flat_dictionary: + // raise ValueError( + // "Could not flatten dictionary: key %s is not unique." + // % (new_i)) + // flat_dictionary[new_i] = new_v + // return flat_dictionary + + /// + /// Helper function for pack_sequence_as. + /// + /// Substructure (list / tuple / dict) to mimic. + /// Flattened values to output substructure for. + /// Index at which to start reading from flat. + /// + /// The tuple(new_index, child), where: + /// * new_index - the updated index into `flat` having processed `structure`. + /// * packed - the subset of `flat` corresponding to `structure`, + /// having started at `index`, and packed into the same nested + /// format. + private static (int new_index, List child) _packed_nest_with_indices(object structure, List flat, + int index) + { + var packed = new List(); + foreach (var s in _yield_value(structure)) + { + if (is_sequence(s)) + { + var (new_index, child) = _packed_nest_with_indices(s, flat, index); + packed.Add(_sequence_like(s, child)); + index = new_index; + } + else + { + packed.Add(flat[index]); + index += 1; + } + } + return (index, packed); + } + + private static int len(IEnumerable x) => x.Count(); + + /// + /// Returns a given flattened sequence packed into a given structure. + /// If `structure` is a scalar, `flat_sequence` must be a single-element list; + /// in this case the return value is `flat_sequence[0]`. + /// + /// If `structure` is or contains a dict instance, the keys will be sorted to + /// pack the flat sequence in deterministic order. This is true also for + /// `OrderedDict` instances: their sequence order is ignored, the sorting order of + /// keys is used instead. The same convention is followed in `flatten`. + /// This correctly repacks dicts and `OrderedDict`s after they have been + /// flattened, and also allows flattening an `OrderedDict` and then repacking it + /// back using a corresponding plain dict, or vice-versa. + /// Dictionaries with non-sortable keys cannot be flattened. + /// + /// + /// Nested structure, whose structure is given by nested lists, + /// tuples, and dicts. Note: numpy arrays and strings are considered + /// scalars. + /// + /// flat sequence to pack. + /// `flat_sequence` converted to have the same recursive structure as + /// `structure`. + /// + public static object pack_sequence_as(object structure, List flat_sequence) + { + if (flat_sequence == null) + throw new ArgumentException("flat_sequence must not be null"); + // if not is_sequence(flat_sequence): + // raise TypeError("flat_sequence must be a sequence") + + if (!is_sequence(structure)) + { + if (len(flat_sequence) != 1) + throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1"); + return flat_sequence.FirstOrDefault(); + } + int final_index = 0; + List packed = null; + try + { + (final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0); + if (final_index < len(flat_sequence)) + throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }"); + } + catch (IndexOutOfRangeException) + { + var flat_structure = flatten(structure); + if (len(flat_structure) != len(flat_sequence)) + { + throw new ValueError("Could not pack sequence. Structure had %d elements, but " + + $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}"); + } + return _sequence_like(structure, packed); + } + return packed; + } + + /// + /// Applies `func` to each entry in `structure` and returns a new structure. + /// + /// Applies `func(x[0], x[1], ...)` where x[i] is an entry in + /// `structure[i]`. All structures in `structure` must have the same arity, + /// and the return value will contain the results in the same structure. + /// + /// + /// + /// A callable that accepts as many arguments as there are structures. + /// scalar, or tuple or list of constructed scalars and/or other + /// tuples/lists, or scalars. Note: numpy arrays are considered as scalars. + /// If set to + /// `True` (default) the types of iterables within the structures have to be + /// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` + /// exception). To allow this set this argument to `False`. + /// Note that namedtuples with identical name and fields are always + /// considered to have the same shallow structure. + /// + /// A new structure with the same arity as `structure`, whose values correspond + /// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding + /// location in `structure[i]`. If there are different sequence types and + /// `check_types` is `False` the sequence types of the first structure will be + /// used. + /// + public static IEnumerable map_structure(Func func, IEnumerable structure, bool check_types = false) + { + + // for other in structure[1:]: + // assert_same_structure(structure[0], other, check_types=check_types) + + // flat_structure = [flatten(s) for s in structure] + // entries = zip(*flat_structure) + + // return pack_sequence_as( + // structure[0], [func(*x) for x in entries]) + return null; + } + + //def map_structure_with_paths(func, *structure, **kwargs): + // """Applies `func` to each entry in `structure` and returns a new structure. + + // Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in + // `structure[i]` and `path` is the common path to x[i] in the structures. All + // structures in `structure` must have the same arity, and the return value will + // contain the results in the same structure. Special kwarg `check_types` + // determines whether the types of iterables within the structure must be the + // same-- see **kwargs definition below. + + // Args: + // func: A callable with the signature func(path, *values, **kwargs) that is + // evaluated on the leaves of the structure. + // *structure: A variable number of compatible structures to process. + // **kwargs: Optional kwargs to be passed through to func. Special kwarg + // `check_types` is not passed to func, but instead determines whether the + // types of iterables within the structures have to be same (e.g., + // `map_structure(func, [1], (1,))` raises a `TypeError` exception). By + // default, the types must match. To allow iteration over structures of + // different types (but common arity), set this kwarg to `False`. + + // Returns: + // A structure of the same form as the input structures whose leaves are the + // result of evaluating func on corresponding leaves of the input structures. + + // Raises: + // TypeError: If `func` is not callable or if the structures do not match + // each other by depth tree. + // TypeError: If `check_types` is not `False` and the two structures differ in + // the type of sequence in any of their substructures. + // ValueError: If no structures are provided. + // """ + // if not callable(func): + // raise TypeError("func must be callable, got: %s" % func) + // if not structure: + // raise ValueError("Must provide at least one structure") + + // check_types = kwargs.pop("check_types", True) + // for other in structure[1:]: + // assert_same_structure(structure[0], other, check_types=check_types) + + //# First set paths_and_values to: + //# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] + // paths_and_values = [flatten_with_joined_string_paths(s) for s in structure] + + //# Now zip(*paths_and_values) would be: + //# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] + //# so grouped_by_path is set to: + //# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] + //# Note that p1i, ... pmi must all be equal since the structures are the same. + // grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] + + // return pack_sequence_as(structure[0], [ + // func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) + + + //def _yield_flat_up_to(shallow_tree, input_tree): + // """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" + // if is_sequence(shallow_tree): + // for shallow_branch, input_branch in zip(_yield_value(shallow_tree), + // _yield_value(input_tree)): + // for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): + // yield input_leaf + // else: + // yield input_tree + + + //def assert_shallow_structure(shallow_tree, input_tree, check_types=True): + // """Asserts that `shallow_tree` is a shallow structure of `input_tree`. + + // That is, this function tests if the `input_tree` structure can be created from + // the `shallow_tree` structure by replacing its leaf nodes with deeper + // tree structures. + + // Examples: + + // The following code will raise an exception: + // ```python + // shallow_tree = ["a", "b"] + // input_tree = ["c", ["d", "e"], "f"] + // assert_shallow_structure(shallow_tree, input_tree) + // ``` + + // The following code will not raise an exception: + // ```python + // shallow_tree = ["a", "b"] + // input_tree = ["c", ["d", "e"]] + // assert_shallow_structure(shallow_tree, input_tree) + // ``` + + // Args: + // shallow_tree: an arbitrarily nested structure. + // input_tree: an arbitrarily nested structure. + // check_types: if `True` (default) the sequence types of `shallow_tree` and + // `input_tree` have to be the same. Note that even with check_types==True, + // this function will consider two different namedtuple classes with the same + // name and _fields attribute to be the same class. + + // Raises: + // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + // TypeError: If the sequence types of `shallow_tree` are different from + // `input_tree`. Only raised if `check_types` is `True`. + // ValueError: If the sequence lengths of `shallow_tree` are different from + // `input_tree`. + // """ + // if is_sequence(shallow_tree): + // if not is_sequence(input_tree): + // raise TypeError( + // "If shallow structure is a sequence, input must also be a sequence. " + // "Input has type: %s." % type(input_tree)) + + // if check_types and not isinstance(input_tree, type(shallow_tree)): + //# Duck-typing means that nest should be fine with two different + //# namedtuples with identical name and fields. + // shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) + // input_is_namedtuple = _is_namedtuple(input_tree, False) + // if shallow_is_namedtuple and input_is_namedtuple: + // if not _same_namedtuples(shallow_tree, input_tree): + // raise TypeError( + // "The two namedtuples don't have the same sequence type. Input " + // "structure has type %s, while shallow structure has type %s." + // % (type(input_tree), type(shallow_tree))) + // elif not (isinstance(shallow_tree, _collections.Mapping) + // and isinstance(input_tree, _collections.Mapping)): + // raise TypeError( + // "The two structures don't have the same sequence type. Input " + // "structure has type %s, while shallow structure has type %s." + // % (type(input_tree), type(shallow_tree))) + + // if len(input_tree) != len(shallow_tree): + // raise ValueError( + // "The two structures don't have the same sequence length. Input " + // "structure has length %s, while shallow structure has length %s." + // % (len(input_tree), len(shallow_tree))) + + // if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): + // if set(input_tree) != set(shallow_tree): + // raise ValueError( + // "The two structures don't have the same keys. Input " + // "structure has keys %s, while shallow structure has keys %s." % + // (list(_six.iterkeys(input_tree)), + // list(_six.iterkeys(shallow_tree)))) + + // input_tree = list(sorted(_six.iteritems(input_tree))) + // shallow_tree = list(sorted(_six.iteritems(shallow_tree))) + + // for shallow_branch, input_branch in zip(shallow_tree, input_tree): + // assert_shallow_structure(shallow_branch, input_branch, + // check_types=check_types) + + + //def flatten_up_to(shallow_tree, input_tree): + // """Flattens `input_tree` up to `shallow_tree`. + + // Any further depth in structure in `input_tree` is retained as elements in the + // partially flatten output. + + // If `shallow_tree` and `input_tree` are not sequences, this returns a + // single-element list: `[input_tree]`. + + // Use Case: + + // Sometimes we may wish to partially flatten a nested sequence, retaining some + // of the nested structure. We achieve this by specifying a shallow structure, + // `shallow_tree`, we wish to flatten up to. + + // The input, `input_tree`, can be thought of as having the same structure as + // `shallow_tree`, but with leaf nodes that are themselves tree structures. + + // Examples: + + // ```python + // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] + // shallow_tree = [[True, True], [False, True]] + + // flattened_input_tree = flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) + + //# Output is: + //# [[2, 2], [3, 3], [4, 9], [5, 5]] + //# [True, True, False, True] + // ``` + + // ```python + // input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] + // shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] + + // input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) + // input_tree_flattened = flatten(input_tree) + + //# Output is: + //# [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + //# ['a', 1, 'b', 2, 'c', 3, 'd', 4] + // ``` + + // Non-Sequence Edge Cases: + + // ```python + // flatten_up_to(0, 0) # Output: [0] + // flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] + // flatten_up_to([0, 1, 2], 0) # Output: TypeError + // flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] + // ``` + + // Args: + // shallow_tree: a possibly pruned structure of input_tree. + // input_tree: an arbitrarily nested structure or a scalar object. + // Note, numpy arrays are considered scalars. + + // Returns: + // A Python list, the partially flattened version of `input_tree` according to + // the structure of `shallow_tree`. + + // Raises: + // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + // TypeError: If the sequence types of `shallow_tree` are different from + // `input_tree`. + // ValueError: If the sequence lengths of `shallow_tree` are different from + // `input_tree`. + // """ + // assert_shallow_structure(shallow_tree, input_tree) + // return list(_yield_flat_up_to(shallow_tree, input_tree)) + + + //def map_structure_up_to(shallow_tree, func, *inputs): + // """Applies a function or op to a number of partially flattened inputs. + + // The `inputs` are flattened up to `shallow_tree` before being mapped. + + // Use Case: + + // Sometimes we wish to apply a function to a partially flattened + // sequence (for example when the function itself takes sequence inputs). We + // achieve this by specifying a shallow structure, `shallow_tree` we wish to + // flatten up to. + + // The `inputs`, can be thought of as having the same structure as + // `shallow_tree`, but with leaf nodes that are themselves tree structures. + + // This function therefore will return something with the same base structure as + // `shallow_tree`. + + // Examples: + + // ```python + // ab_tuple = collections.namedtuple("ab_tuple", "a, b") + // op_tuple = collections.namedtuple("op_tuple", "add, mul") + // inp_val = ab_tuple(a=2, b=3) + // inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) + // out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, + // inp_val, inp_ops) + + //# Output is: ab_tuple(a=6, b=15) + // ``` + + // ```python + // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] + // name_list = ['evens', ['odds', 'primes']] + // out = map_structure_up_to( + // name_list, + // lambda name, sec: "first_{}_{}".format(len(sec), name), + // name_list, data_list) + + //# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] + // ``` + + // Args: + // shallow_tree: a shallow tree, common to all the inputs. + // func: callable which will be applied to each input individually. + // *inputs: arbitrarily nested combination of objects that are compatible with + // shallow_tree. The function `func` is applied to corresponding + // partially flattened elements of each input, so the function must support + // arity of `len(inputs)`. + + // Raises: + // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. + // TypeError: If the sequence types of `shallow_tree` are different from + // `input_tree`. + // ValueError: If the sequence lengths of `shallow_tree` are different from + // `input_tree`. + + // Returns: + // result of repeatedly applying `func`, with same structure as + // `shallow_tree`. + // """ + // if not inputs: + // raise ValueError("Cannot map over no sequences") + // for input_tree in inputs: + // assert_shallow_structure(shallow_tree, input_tree) + + //# Flatten each input separately, apply the function to corresponding elements, + //# then repack based on the structure of the first input. + // all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) + // for input_tree in inputs] + // results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] + // return pack_sequence_as(structure=shallow_tree, flat_sequence=results) + + + //def get_traverse_shallow_structure(traverse_fn, structure): + // """Generates a shallow structure from a `traverse_fn` and `structure`. + + // `traverse_fn` must accept any possible subtree of `structure` and return + // a depth=1 structure containing `True` or `False` values, describing which + // of the top-level subtrees may be traversed. It may also + // return scalar `True` or `False` "traversal is OK / not OK for all subtrees." + + // Examples are available in the unit tests (nest_test.py). + + // Args: + // traverse_fn: Function taking a substructure and returning either a scalar + // `bool` (whether to traverse that substructure or not) or a depth=1 + // shallow structure of the same type, describing which parts of the + // substructure to traverse. + // structure: The structure to traverse. + + // Returns: + // A shallow structure containing python bools, which can be passed to + // `map_structure_up_to` and `flatten_up_to`. + + // Raises: + // TypeError: if `traverse_fn` returns a sequence for a non-sequence input, + // or a structure with depth higher than 1 for a sequence input, + // or if any leaf values in the returned structure or scalar are not type + // `bool`. + // """ + // to_traverse = traverse_fn(structure) + // if not is_sequence(structure): + // if not isinstance(to_traverse, bool): + // raise TypeError("traverse_fn returned structure: %s for non-structure: %s" + // % (to_traverse, structure)) + // return to_traverse + // level_traverse = [] + // if isinstance(to_traverse, bool): + // if not to_traverse: + //# Do not traverse this substructure at all. Exit early. + // return False + // else: + //# Traverse the entire substructure. + // for branch in _yield_value(structure): + // level_traverse.append( + // get_traverse_shallow_structure(traverse_fn, branch)) + // elif not is_sequence(to_traverse): + // raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" + // % (to_traverse, structure)) + // else: + //# Traverse some subset of this substructure. + // assert_shallow_structure(to_traverse, structure) + // for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): + // if not isinstance(t, bool): + // raise TypeError( + // "traverse_fn didn't return a depth=1 structure of bools. saw: %s " + // " for structure: %s" % (to_traverse, structure)) + // if t: + // level_traverse.append( + // get_traverse_shallow_structure(traverse_fn, branch)) + // else: + // level_traverse.append(False) + // return _sequence_like(structure, level_traverse) + + + //def yield_flat_paths(nest): + // """Yields paths for some nested structure. + + // Paths are lists of objects which can be str-converted, which may include + // integers or other types which are used as indices in a dict. + + // The flat list will be in the corresponding order as if you called + // `snt.nest.flatten` on the structure. This is handy for naming Tensors such + // the TF scope structure matches the tuple structure. + + // E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` + + // ```shell + // >>> nest.flatten(value) + // [3, 23, 42] + // >>> list(nest.yield_flat_paths(value)) + // [('a',), ('b', 'c'), ('b', 'd')] + // ``` + + // ```shell + // >>> list(nest.yield_flat_paths({'a': [3]})) + // [('a', 0)] + // >>> list(nest.yield_flat_paths({'a': 3})) + // [('a',)] + // ``` + + // Args: + // nest: the value to produce a flattened paths list for. + + // Yields: + // Tuples containing index or key values which form the path to a specific + // leaf value in the nested structure. + // """ + + //# The _maybe_add_final_path_element function is used below in order to avoid + //# adding trailing slashes when the sub-element recursed into is a leaf. + // if isinstance(nest, (dict, _collections.Mapping)): + // for key in _sorted(nest): + // value = nest[key] + // for sub_path in yield_flat_paths(value): + // yield (key,) + sub_path + // elif _is_namedtuple(nest): + // for key in nest._fields: + // value = getattr(nest, key) + // for sub_path in yield_flat_paths(value): + // yield (key,) + sub_path + // elif isinstance(nest, _six.string_types): + // yield () + // elif isinstance(nest, _collections.Sequence): + // for idx, value in enumerate(nest): + // for sub_path in yield_flat_paths(value): + // yield (idx,) + sub_path + // else: + // yield () + + + //def flatten_with_joined_string_paths(structure, separator="/"): + // """Returns a list of (string path, data element) tuples. + + // The order of tuples produced matches that of `nest.flatten`. This allows you + // to flatten a nested structure while keeping information about where in the + // structure each data element was located. See `nest.yield_flat_paths` + // for more information. + + // Args: + // structure: the nested structure to flatten. + // separator: string to separate levels of hierarchy in the results, defaults + // to '/'. + + // Returns: + // A list of (string, data element) tuples. + // """ + // flat_paths = yield_flat_paths(structure) + // def stringify_and_join(path_elements): + // return separator.join(str(path_element) for path_element in path_elements) + // flat_string_paths = [stringify_and_join(path) for path in flat_paths] + // return list(zip(flat_string_paths, flatten(structure))) + + + } +} diff --git a/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs new file mode 100644 index 00000000..7b0d61ba --- /dev/null +++ b/test/TensorFlowNET.UnitTest/nest_test/NestTest.cs @@ -0,0 +1,852 @@ +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json.Linq; +using Tensorflow; +using Tensorflow.Util; + +namespace TensorFlowNET.UnitTest.control_flow_ops_test +{ + /// + /// excerpt of tensorflow/python/framework/util/nest_test.py + /// + [TestClass] + public class NestTest : PythonTest + { + + public class PointXY + { + public double x; + public double y; + } + + // if attr: + // class BadAttr(object): + // """Class that has a non-iterable __attrs_attrs__.""" + // __attrs_attrs__ = None + + // @attr.s + // class SampleAttr(object): + // field1 = attr.ib() + // field2 = attr.ib() + + // @test_util.assert_no_new_pyobjects_executing_eagerly + // def testAttrsFlattenAndPack(self) : + // if attr is None: + // self.skipTest("attr module is unavailable.") + + // field_values = [1, 2] + // sample_attr = NestTest.SampleAttr(* field_values) + // self.assertFalse(nest._is_attrs(field_values)) + // self.assertTrue(nest._is_attrs(sample_attr)) + // flat = nest.flatten(sample_attr) + // self.assertEqual(field_values, flat) + // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) + // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) + // self.assertEqual(restructured_from_flat, sample_attr) + + //# Check that flatten fails if attributes are not iterable + // with self.assertRaisesRegexp(TypeError, "object is not iterable"): + // flat = nest.flatten(NestTest.BadAttr()) + + [TestMethod] + public void testFlattenAndPack() + { + object structure = new object[] {new object[] {3, 4}, 5, new object[] {6, 7, new object[] {9, 10}, 8}}; + var flat = new List {"a", "b", "c", "d", "e", "f", "g", "h"}; + + self.assertEqual(nest.flatten(structure), new[] {3, 4, 5, 6, 7, 9, 10, 8}); + self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(), + JArray.FromObject(new object[] {new object[] {"a", "b"}, "c", new object[] {"d", "e", new object[] {"f", "g"}, "h"}}).ToString()); + structure = new object[] { new Hashtable {["x"] = 4, ["y"] = 2}, new object[] { new object[] { new Hashtable { ["x"] = 1,["y"] = 0}, }, }}; + flat = new List { 4, 2, 1, 0}; + self.assertEqual(nest.flatten(structure), flat); + // restructured_from_flat = nest.pack_sequence_as(structure, flat) + // self.assertEqual(restructured_from_flat, structure) + // self.assertEqual(restructured_from_flat[0].x, 4) + // self.assertEqual(restructured_from_flat[0].y, 2) + // self.assertEqual(restructured_from_flat[1][0][0].x, 1) + // self.assertEqual(restructured_from_flat[1][0][0].y, 0) + + // self.assertEqual([5], nest.flatten(5)) + // self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) + + // self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) + // self.assertEqual( + // np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) + + // with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): + // nest.pack_sequence_as("scalar", [4, 5]) + + // with self.assertRaisesRegexp(TypeError, "flat_sequence"): + // nest.pack_sequence_as([4, 5], "bad_sequence") + + // with self.assertRaises(ValueError): + // nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + + } + + // @parameterized.parameters({"mapping_type": collections.OrderedDict + // }, + // {"mapping_type": _CustomMapping + //}) + // @test_util.assert_no_new_pyobjects_executing_eagerly + // def testFlattenDictOrder(self, mapping_type) : + // """`flatten` orders dicts by key, including OrderedDicts.""" + // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) + // plain = {"d": 3, "b": 1, "a": 0, "c": 2} + // ordered_flat = nest.flatten(ordered) + // plain_flat = nest.flatten(plain) + // self.assertEqual([0, 1, 2, 3], ordered_flat) + // self.assertEqual([0, 1, 2, 3], plain_flat) + + // @parameterized.parameters({"mapping_type": collections.OrderedDict}, + // {"mapping_type": _CustomMapping}) + // def testPackDictOrder(self, mapping_type): + // """Packing orders dicts by key, including OrderedDicts.""" + // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) + // plain = {"d": 0, "b": 0, "a": 0, "c": 0} + // seq = [0, 1, 2, 3] + //custom_reconstruction = nest.pack_sequence_as(custom, seq) + //plain_reconstruction = nest.pack_sequence_as(plain, seq) + // self.assertIsInstance(custom_reconstruction, mapping_type) + // self.assertIsInstance(plain_reconstruction, dict) + // self.assertEqual( + // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), + // custom_reconstruction) + // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) + + // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name + + // @test_util.assert_no_new_pyobjects_executing_eagerly + // def testFlattenAndPack_withDicts(self) : + // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. + // mess = [ + // "z", + // NestTest.Abc(3, 4), { + // "d": _CustomMapping({ + // 41: 4 + // }), + // "c": [ + // 1, + // collections.OrderedDict([ + // ("b", 3), + // ("a", 2), + // ]), + // ], + // "b": 5 + // }, 17 + // ] + + // flattened = nest.flatten(mess) + // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) + + // structure_of_mess = [ + // 14, + // NestTest.Abc("a", True), + // { + // "d": _CustomMapping({ + // 41: 42 + // }), + // "c": [ + // 0, + // collections.OrderedDict([ + // ("b", 9), + // ("a", 8), + // ]), + // ], + // "b": 3 + // }, + // "hi everybody", + // ] + + // unflattened = nest.pack_sequence_as(structure_of_mess, flattened) + // self.assertEqual(unflattened, mess) + + // # Check also that the OrderedDict was created, with the correct key order. + //unflattened_ordered_dict = unflattened[2]["c"][1] + // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) + // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) + + // unflattened_custom_mapping = unflattened[2]["d"] + // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) + // self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) + + // def testFlatten_numpyIsNotFlattened(self): + // structure = np.array([1, 2, 3]) + // flattened = nest.flatten(structure) + // self.assertEqual(len(flattened), 1) + + // def testFlatten_stringIsNotFlattened(self): + // structure = "lots of letters" + // flattened = nest.flatten(structure) + // self.assertEqual(len(flattened), 1) + // unflattened = nest.pack_sequence_as("goodbye", flattened) + // self.assertEqual(structure, unflattened) + + // def testPackSequenceAs_notIterableError(self) : + // with self.assertRaisesRegexp(TypeError, + // "flat_sequence must be a sequence"): + // nest.pack_sequence_as("hi", "bye") + + // def testPackSequenceAs_wrongLengthsError(self): + // with self.assertRaisesRegexp( + // ValueError, + // "Structure had 2 elements, but flat_sequence had 3 elements."): + // nest.pack_sequence_as(["hello", "world"], + // ["and", "goodbye", "again"]) + + // @test_util.assert_no_new_pyobjects_executing_eagerly + // def testIsSequence(self): + // self.assertFalse(nest.is_sequence("1234")) + // self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) + // self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) + // self.assertTrue(nest.is_sequence([])) + // self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) + // self.assertFalse(nest.is_sequence(set([1, 2]))) + // ones = array_ops.ones([2, 3]) + // self.assertFalse(nest.is_sequence(ones)) + // self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) + // self.assertFalse(nest.is_sequence(np.ones((4, 5)))) + + // @parameterized.parameters({"mapping_type": _CustomMapping}, + // {"mapping_type": dict}) + // def testFlattenDictItems(self, mapping_type): + // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) + // flat = {4: "a", 5: "b", 6: "c", 8: "d"} + // self.assertEqual(nest.flatten_dict_items(dictionary), flat) + + // with self.assertRaises(TypeError): + // nest.flatten_dict_items(4) + + // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) + // with self.assertRaisesRegexp(ValueError, "not unique"): + // nest.flatten_dict_items(bad_dictionary) + + // another_bad_dictionary = mapping_type({ + // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) + // }) + // with self.assertRaisesRegexp( + // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): + // nest.flatten_dict_items(another_bad_dictionary) + + //# pylint does not correctly recognize these as class names and + //# suggests to use variable style under_score naming. + //# pylint: disable=invalid-name + // Named0ab = collections.namedtuple("named_0", ("a", "b")) + // Named1ab = collections.namedtuple("named_1", ("a", "b")) + // SameNameab = collections.namedtuple("same_name", ("a", "b")) + // SameNameab2 = collections.namedtuple("same_name", ("a", "b")) + // SameNamexy = collections.namedtuple("same_name", ("x", "y")) + // SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) + // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) + // NotSameName = collections.namedtuple("not_same_name", ("a", "b")) + // # pylint: enable=invalid-name + + // class SameNamedType1(SameNameab): + // pass + + // @test_util.assert_no_new_pyobjects_executing_eagerly + // def testAssertSameStructure(self): + // structure1 = (((1, 2), 3), 4, (5, 6)) + // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) + // structure_different_num_elements = ("spam", "eggs") + // structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) + // nest.assert_same_structure(structure1, structure2) + // nest.assert_same_structure("abc", 1.0) + // nest.assert_same_structure("abc", np.array([0, 1])) + // nest.assert_same_structure("abc", constant_op.constant([0, 1])) + + // with self.assertRaisesRegexp( + // ValueError, + // ("The two structures don't have the same nested structure\\.\n\n" + // "First structure:.*?\n\n" + // "Second structure:.*\n\n" + // "More specifically: Substructure " + // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' + // 'substructure "type=str str=spam" is not\n' + // "Entire first structure:\n" + // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" + // "Entire second structure:\n" + // r"\(\., \.\)")): + // nest.assert_same_structure(structure1, structure_different_num_elements) + + // with self.assertRaisesRegexp( + // ValueError, + // ("The two structures don't have the same nested structure\\.\n\n" + // "First structure:.*?\n\n" + // "Second structure:.*\n\n" + // r'More specifically: Substructure "type=list str=\[0, 1\]" ' + // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' + // "is not")): + // nest.assert_same_structure([0, 1], np.array([0, 1])) + + // with self.assertRaisesRegexp( + // ValueError, + // ("The two structures don't have the same nested structure\\.\n\n" + // "First structure:.*?\n\n" + // "Second structure:.*\n\n" + // r'More specifically: Substructure "type=list str=\[0, 1\]" ' + // 'is a sequence, while substructure "type=int str=0" ' + // "is not")): + // nest.assert_same_structure(0, [0, 1]) + + // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) + + // with self.assertRaisesRegexp( + // ValueError, + // ("don't have the same nested structure\\.\n\n" + // "First structure: .*?\n\nSecond structure: ")): + // nest.assert_same_structure(structure1, structure_different_nesting) + + // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), + // NestTest.Named0ab("a", "b")) + + // nest.assert_same_structure(NestTest.Named0ab(3, 4), + // NestTest.Named0ab("a", "b")) + + // self.assertRaises(TypeError, nest.assert_same_structure, + // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) + + // with self.assertRaisesRegexp( + // ValueError, + // ("don't have the same nested structure\\.\n\n" + // "First structure: .*?\n\nSecond structure: ")): + // nest.assert_same_structure(NestTest.Named0ab(3, 4), + // NestTest.Named0ab([3], 4)) + + // with self.assertRaisesRegexp( + // ValueError, + // ("don't have the same nested structure\\.\n\n" + // "First structure: .*?\n\nSecond structure: ")): + // nest.assert_same_structure([[3], 4], [3, [4]]) + + // structure1_list = [[[1, 2], 3], 4, [5, 6]] + // with self.assertRaisesRegexp(TypeError, + // "don't have the same sequence type"): + // nest.assert_same_structure(structure1, structure1_list) + // nest.assert_same_structure(structure1, structure2, check_types= False) + // nest.assert_same_structure(structure1, structure1_list, check_types=False) + + // with self.assertRaisesRegexp(ValueError, + // "don't have the same set of keys"): + // nest.assert_same_structure({"a": 1}, {"b": 1}) + + // nest.assert_same_structure(NestTest.SameNameab(0, 1), + // NestTest.SameNameab2(2, 3)) + + // # This assertion is expected to pass: two namedtuples with the same + // # name and field names are considered to be identical. + // nest.assert_same_structure( + // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), + // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) + + // expected_message = "The two structures don't have the same.*" + // with self.assertRaisesRegexp(ValueError, expected_message): + // nest.assert_same_structure( + // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), + // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) + + // self.assertRaises(TypeError, nest.assert_same_structure, + // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) + + // self.assertRaises(TypeError, nest.assert_same_structure, + // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) + + // self.assertRaises(TypeError, nest.assert_same_structure, + // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) + + // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name + + // def testHeterogeneousComparison(self): + // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) + // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) + + // @test_util.assert_no_new_pyobjects_executing_eagerly + // def testMapStructure(self) : + // structure1 = (((1, 2), 3), 4, (5, 6)) + // structure2 = (((7, 8), 9), 10, (11, 12)) + // structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) + // nest.assert_same_structure(structure1, structure1_plus1) + // self.assertAllEqual( + // [2, 3, 4, 5, 6, 7], + // nest.flatten(structure1_plus1)) + // structure1_plus_structure2 = nest.map_structure( + // lambda x, y: x + y, structure1, structure2) + // self.assertEqual( + // (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), + // structure1_plus_structure2) + + // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) + + // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) + + // # Empty structures + // self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) + // self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) + // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) + // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, + // NestTest.EmptyNT())) + + // # This is checking actual equality of types, empty list != empty tuple + // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) + + // with self.assertRaisesRegexp(TypeError, "callable"): + // nest.map_structure("bad", structure1_plus1) + + // with self.assertRaisesRegexp(ValueError, "at least one structure"): + // nest.map_structure(lambda x: x) + + // with self.assertRaisesRegexp(ValueError, "same number of elements"): + // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) + + // with self.assertRaisesRegexp(ValueError, "same nested structure"): + // nest.map_structure(lambda x, y: None, 3, (3,)) + + // with self.assertRaisesRegexp(TypeError, "same sequence type"): + // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) + + // with self.assertRaisesRegexp(ValueError, "same nested structure"): + // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) + + // structure1_list = [[[1, 2], 3], 4, [5, 6]] + // with self.assertRaisesRegexp(TypeError, "same sequence type"): + // nest.map_structure(lambda x, y: None, structure1, structure1_list) + + // nest.map_structure(lambda x, y: None, structure1, structure1_list, + // check_types=False) + + // with self.assertRaisesRegexp(ValueError, "same nested structure"): + // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), + // check_types=False) + + // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): + // nest.map_structure(lambda x: None, structure1, foo="a") + + // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): + // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") + + // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name + + // @test_util.assert_no_new_pyobjects_executing_eagerly + // def testMapStructureWithStrings(self) : + // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) + // inp_b = NestTest.ABTuple(a=2, b=(1, 3)) + // out = nest.map_structure(lambda string, repeats: string* repeats, + // inp_a, + // inp_b) + // self.assertEqual("foofoo", out.a) + // self.assertEqual("bar", out.b[0]) + // self.assertEqual("bazbazbaz", out.b[1]) + + // nt = NestTest.ABTuple(a=("something", "something_else"), + // b="yet another thing") + // rev_nt = nest.map_structure(lambda x: x[::- 1], nt) + // # Check the output is the correct structure, and all strings are reversed. + // nest.assert_same_structure(nt, rev_nt) + // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) + // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) + // self.assertEqual(nt.b[::- 1], rev_nt.b) + + // @test_util.run_deprecated_v1 + // def testMapStructureOverPlaceholders(self) : + // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), + // array_ops.placeholder(dtypes.float32, shape=[3, 7])) + // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), + // array_ops.placeholder(dtypes.float32, shape=[3, 7])) + + // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) + + // nest.assert_same_structure(output, inp_a) + // self.assertShapeEqual(np.zeros((3, 4)), output[0]) + // self.assertShapeEqual(np.zeros((3, 7)), output[1]) + + // feed_dict = { + // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), + // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) + // } + + // with self.cached_session() as sess: + // output_np = sess.run(output, feed_dict=feed_dict) + // self.assertAllClose(output_np[0], + // feed_dict[inp_a][0] + feed_dict[inp_b][0]) + // self.assertAllClose(output_np[1], + // feed_dict[inp_a][1] + feed_dict[inp_b][1]) + + // def testAssertShallowStructure(self): + // inp_ab = ["a", "b"] + //inp_abc = ["a", "b", "c"] + //expected_message = ( + // "The two structures don't have the same sequence length. Input " + // "structure has length 2, while shallow structure has length 3.") + // with self.assertRaisesRegexp(ValueError, expected_message): + // nest.assert_shallow_structure(inp_abc, inp_ab) + + // inp_ab1 = [(1, 1), (2, 2)] + // inp_ab2 = [[1, 1], [2, 2]] + // expected_message = ( + // "The two structures don't have the same sequence type. Input structure " + // "has type <(type|class) 'tuple'>, while shallow structure has type " + // "<(type|class) 'list'>.") + // with self.assertRaisesRegexp(TypeError, expected_message): + // nest.assert_shallow_structure(inp_ab2, inp_ab1) + // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) + + // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} + // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} + // expected_message = ( + // r"The two structures don't have the same keys. Input " + // r"structure has keys \['c'\], while shallow structure has " + // r"keys \['d'\].") + + // with self.assertRaisesRegexp(ValueError, expected_message): + // nest.assert_shallow_structure(inp_ab2, inp_ab1) + + // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) + // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) + // nest.assert_shallow_structure(inp_ab, inp_ba) + + // # This assertion is expected to pass: two namedtuples with the same + //# name and field names are considered to be identical. + //inp_shallow = NestTest.SameNameab(1, 2) + // inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) + // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) + // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) + + // def testFlattenUpTo(self): + // # Shallow tree ends at scalar. + // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] + // shallow_tree = [[True, True], [False, True]] + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) + // self.assertEqual(flattened_shallow_tree, [True, True, False, True]) + + //# Shallow tree ends at string. + // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] + // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] + // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + // input_tree) + // input_tree_flattened = nest.flatten(input_tree) + // self.assertEqual(input_tree_flattened_as_shallow_tree, + // [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) + // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) + + // # Make sure dicts are correctly flattened, yielding values, not keys. + //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} + // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} + // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + // input_tree) + // self.assertEqual(input_tree_flattened_as_shallow_tree, + // [1, { "c": 2}, 3, (4, 5)]) + + // # Namedtuples. + // ab_tuple = NestTest.ABTuple + // input_tree = ab_tuple(a =[0, 1], b = 2) + // shallow_tree = ab_tuple(a= 0, b= 1) + // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + // input_tree) + // self.assertEqual(input_tree_flattened_as_shallow_tree, + // [[0, 1], 2]) + + // # Nested dicts, OrderedDicts and namedtuples. + // input_tree = collections.OrderedDict( + // [("a", ab_tuple(a =[0, {"b": 1}], b=2)), + // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) + // shallow_tree = input_tree + // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + // input_tree) + // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) + // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) + // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + // input_tree) + // self.assertEqual(input_tree_flattened_as_shallow_tree, + // [ab_tuple(a =[0, { "b": 1}], b=2), + // 3, + // collections.OrderedDict([("f", 4)])]) + // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) + // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + // input_tree) + // self.assertEqual(input_tree_flattened_as_shallow_tree, + // [ab_tuple(a =[0, {"b": 1}], b=2), + // {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) + + // ## Shallow non-list edge-case. + // # Using iterable elements. + // input_tree = ["input_tree"] + //shallow_tree = "shallow_tree" + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_input_tree, [input_tree]) + // self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + // input_tree = ["input_tree_0", "input_tree_1"] + //shallow_tree = "shallow_tree" + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_input_tree, [input_tree]) + // self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + // # Using non-iterable elements. + //input_tree = [0] + //shallow_tree = 9 + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_input_tree, [input_tree]) + // self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + // input_tree = [0, 1] + //shallow_tree = 9 + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_input_tree, [input_tree]) + // self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + // ## Both non-list edge-case. + //# Using iterable elements. + //input_tree = "input_tree" + // shallow_tree = "shallow_tree" + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_input_tree, [input_tree]) + // self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + // # Using non-iterable elements. + //input_tree = 0 + // shallow_tree = 0 + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_input_tree, [input_tree]) + // self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + // ## Input non-list edge-case. + //# Using iterable elements. + //input_tree = "input_tree" + // shallow_tree = ["shallow_tree"] + //expected_message = ("If shallow structure is a sequence, input must also " + // "be a sequence. Input has type: <(type|class) 'str'>.") + // with self.assertRaisesRegexp(TypeError, expected_message): + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_shallow_tree, shallow_tree) + + // input_tree = "input_tree" + // shallow_tree = ["shallow_tree_9", "shallow_tree_8"] + //with self.assertRaisesRegexp(TypeError, expected_message): + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_shallow_tree, shallow_tree) + + //# Using non-iterable elements. + // input_tree = 0 + // shallow_tree = [9] + //expected_message = ("If shallow structure is a sequence, input must also " + // "be a sequence. Input has type: <(type|class) 'int'>.") + // with self.assertRaisesRegexp(TypeError, expected_message): + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_shallow_tree, shallow_tree) + + // input_tree = 0 + // shallow_tree = [9, 8] + //with self.assertRaisesRegexp(TypeError, expected_message): + // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + // self.assertEqual(flattened_shallow_tree, shallow_tree) + + // def testMapStructureUpTo(self) : + // # Named tuples. + // ab_tuple = collections.namedtuple("ab_tuple", "a, b") + // op_tuple = collections.namedtuple("op_tuple", "add, mul") + // inp_val = ab_tuple(a= 2, b= 3) + // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) + // out = nest.map_structure_up_to( + // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) + // self.assertEqual(out.a, 6) + // self.assertEqual(out.b, 15) + + // # Lists. + // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] + // name_list = ["evens", ["odds", "primes"]] + // out = nest.map_structure_up_to( + // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), + // name_list, data_list) + // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) + + // # Dicts. + // inp_val = dict(a= 2, b= 3) + // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) + // out = nest.map_structure_up_to( + // inp_val, + // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + // self.assertEqual(out["a"], 6) + // self.assertEqual(out["b"], 15) + + // # Non-equal dicts. + // inp_val = dict(a= 2, b= 3) + // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) + // with self.assertRaisesRegexp(ValueError, "same keys"): + // nest.map_structure_up_to( + // inp_val, + // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + + // # Dict+custom mapping. + // inp_val = dict(a= 2, b= 3) + // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) + // out = nest.map_structure_up_to( + // inp_val, + // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + // self.assertEqual(out["a"], 6) + // self.assertEqual(out["b"], 15) + + // # Non-equal dict/mapping. + // inp_val = dict(a= 2, b= 3) + // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) + // with self.assertRaisesRegexp(ValueError, "same keys"): + // nest.map_structure_up_to( + // inp_val, + // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + + // def testGetTraverseShallowStructure(self): + // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] + // scalar_traverse_r = nest.get_traverse_shallow_structure( + // lambda s: not isinstance(s, tuple), + // scalar_traverse_input) + // self.assertEqual(scalar_traverse_r, + // [True, True, False, [True, True], {"a": False}, []]) + // nest.assert_shallow_structure(scalar_traverse_r, + // scalar_traverse_input) + + // structure_traverse_input = [(1, [2]), ([1], 2)] + // structure_traverse_r = nest.get_traverse_shallow_structure( + // lambda s: (True, False) if isinstance(s, tuple) else True, + // structure_traverse_input) + // self.assertEqual(structure_traverse_r, + // [(True, False), ([True], False)]) + // nest.assert_shallow_structure(structure_traverse_r, + // structure_traverse_input) + + // with self.assertRaisesRegexp(TypeError, "returned structure"): + // nest.get_traverse_shallow_structure(lambda _: [True], 0) + + // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): + // nest.get_traverse_shallow_structure(lambda _: 1, [1]) + + // with self.assertRaisesRegexp( + // TypeError, "didn't return a depth=1 structure of bools"): + // nest.get_traverse_shallow_structure(lambda _: [1], [1]) + + // def testYieldFlatStringPaths(self): + // for inputs_expected in ({"inputs": [], "expected": []}, + // {"inputs": 3, "expected": [()]}, + // {"inputs": [3], "expected": [(0,)]}, + // {"inputs": {"a": 3}, "expected": [("a",)]}, + // {"inputs": {"a": {"b": 4}}, + // "expected": [("a", "b")]}, + // {"inputs": [{"a": 2}], "expected": [(0, "a")]}, + // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, + // {"inputs": [{"a": [(23, 42)]}], + // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, + // {"inputs": [{"a": ([23], 42)}], + // "expected": [(0, "a", 0, 0), (0, "a", 1)]}, + // {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, + // "expected": [("a", "a"), ("c", 0, 0, 0)]}, + // {"inputs": {"0": [{"1": 23}]}, + // "expected": [("0", 0, "1")]}): + // inputs = inputs_expected["inputs"] + // expected = inputs_expected["expected"] + // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) + + // def testFlattenWithStringPaths(self): + // for inputs_expected in ( + // {"inputs": [], "expected": []}, + // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, + // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): + // inputs = inputs_expected["inputs"] + // expected = inputs_expected["expected"] + // self.assertEqual( + // nest.flatten_with_joined_string_paths(inputs, separator="/"), + // expected) + + // # Need a separate test for namedtuple as we can't declare tuple definitions + // # in the @parameterized arguments. + // def testFlattenNamedTuple(self): + // # pylint: disable=invalid-name + // Foo = collections.namedtuple("Foo", ["a", "b"]) + // Bar = collections.namedtuple("Bar", ["c", "d"]) + // # pylint: enable=invalid-name + // test_cases = [ + // (Foo(a = 3, b = Bar(c = 23, d = 42)), + // [("a", 3), ("b/c", 23), ("b/d", 42)]), + // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), + // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), + // (Bar(c = 42, d = 43), + // [("c", 42), ("d", 43)]), + // (Bar(c =[42], d = 43), + // [("c/0", 42), ("d", 43)]), + // ] + // for inputs, expected in test_cases: + // self.assertEqual( + // list(nest.flatten_with_joined_string_paths(inputs)), expected) + + // @parameterized.named_parameters( + // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), + // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, + // {"a": ("a", 4), "b": ("b", 6)}), + // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), + // ("nested", + // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, + // {"a": [("a/0", 10), ("a/1", 12)], + // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) + // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): + // def format_sum(path, * values): + // return (path, sum(values)) + // result = nest.map_structure_with_paths(format_sum, s1, s2, + // check_types=check_types) + // self.assertEqual(expected, result) + + // @parameterized.named_parameters( + // ("tuples", (1, 2), (3, 4, 5), ValueError), + // ("dicts", {"a": 1}, {"b": 2}, ValueError), + // ("mixed", (1, 2), [3, 4], TypeError), + // ("nested", + // {"a": [2, 3], "b": [1, 3]}, + // {"b": [5, 6, 7], "a": [8, 9]}, + // ValueError + // )) + // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): + // with self.assertRaises(error_type): + // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) + + + //class NestBenchmark(test.Benchmark): + + // def run_and_report(self, s1, s2, name): + // burn_iter, test_iter = 100, 30000 + + // for _ in xrange(burn_iter) : + // nest.assert_same_structure(s1, s2) + + // t0 = time.time() + // for _ in xrange(test_iter) : + // nest.assert_same_structure(s1, s2) + // t1 = time.time() + + // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, + // name=name) + + // def benchmark_assert_structure(self): + // s1 = (((1, 2), 3), 4, (5, 6)) + // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) + // self.run_and_report(s1, s2, "assert_same_structure_6_elem") + + // s1 = (((1, 2), 3), 4, (5, 6)) * 10 + // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 + // self.run_and_report(s1, s2, "assert_same_structure_60_elem") + + + //if __name__ == "__main__": + // test.main() + } +} diff --git a/test/TensorFlowNET.UnitTest/nest_test/nest_test.py b/test/TensorFlowNET.UnitTest/nest_test/nest_test.py new file mode 100644 index 00000000..d0d0c5f7 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/nest_test/nest_test.py @@ -0,0 +1,883 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities working with arbitrarily nested structures.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import time + +from absl.testing import parameterized +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.util import nest + +try: + import attr # pylint:disable=g-import-not-at-top +except ImportError: + attr = None + + +class _CustomMapping(collections.Mapping): + + def __init__(self, *args, **kwargs): + self._wrapped = dict(*args, **kwargs) + + def __getitem__(self, key): + return self._wrapped[key] + + def __iter__(self): + return iter(self._wrapped) + + def __len__(self): + return len(self._wrapped) + + +class NestTest(parameterized.TestCase, test.TestCase): + + PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name + + if attr: + class BadAttr(object): + """Class that has a non-iterable __attrs_attrs__.""" + __attrs_attrs__ = None + + @attr.s + class SampleAttr(object): + field1 = attr.ib() + field2 = attr.ib() + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testAttrsFlattenAndPack(self): + if attr is None: + self.skipTest("attr module is unavailable.") + + field_values = [1, 2] + sample_attr = NestTest.SampleAttr(*field_values) + self.assertFalse(nest._is_attrs(field_values)) + self.assertTrue(nest._is_attrs(sample_attr)) + flat = nest.flatten(sample_attr) + self.assertEqual(field_values, flat) + restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) + self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) + self.assertEqual(restructured_from_flat, sample_attr) + + # Check that flatten fails if attributes are not iterable + with self.assertRaisesRegexp(TypeError, "object is not iterable"): + flat = nest.flatten(NestTest.BadAttr()) + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testFlattenAndPack(self): + structure = ((3, 4), 5, (6, 7, (9, 10), 8)) + flat = ["a", "b", "c", "d", "e", "f", "g", "h"] + self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) + self.assertEqual( + nest.pack_sequence_as(structure, flat), (("a", "b"), "c", + ("d", "e", ("f", "g"), "h"))) + structure = (NestTest.PointXY(x=4, y=2), + ((NestTest.PointXY(x=1, y=0),),)) + flat = [4, 2, 1, 0] + self.assertEqual(nest.flatten(structure), flat) + restructured_from_flat = nest.pack_sequence_as(structure, flat) + self.assertEqual(restructured_from_flat, structure) + self.assertEqual(restructured_from_flat[0].x, 4) + self.assertEqual(restructured_from_flat[0].y, 2) + self.assertEqual(restructured_from_flat[1][0][0].x, 1) + self.assertEqual(restructured_from_flat[1][0][0].y, 0) + + self.assertEqual([5], nest.flatten(5)) + self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) + + self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) + self.assertEqual( + np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) + + with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): + nest.pack_sequence_as("scalar", [4, 5]) + + with self.assertRaisesRegexp(TypeError, "flat_sequence"): + nest.pack_sequence_as([4, 5], "bad_sequence") + + with self.assertRaises(ValueError): + nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + + @parameterized.parameters({"mapping_type": collections.OrderedDict}, + {"mapping_type": _CustomMapping}) + @test_util.assert_no_new_pyobjects_executing_eagerly + def testFlattenDictOrder(self, mapping_type): + """`flatten` orders dicts by key, including OrderedDicts.""" + ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) + plain = {"d": 3, "b": 1, "a": 0, "c": 2} + ordered_flat = nest.flatten(ordered) + plain_flat = nest.flatten(plain) + self.assertEqual([0, 1, 2, 3], ordered_flat) + self.assertEqual([0, 1, 2, 3], plain_flat) + + @parameterized.parameters({"mapping_type": collections.OrderedDict}, + {"mapping_type": _CustomMapping}) + def testPackDictOrder(self, mapping_type): + """Packing orders dicts by key, including OrderedDicts.""" + custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) + plain = {"d": 0, "b": 0, "a": 0, "c": 0} + seq = [0, 1, 2, 3] + custom_reconstruction = nest.pack_sequence_as(custom, seq) + plain_reconstruction = nest.pack_sequence_as(plain, seq) + self.assertIsInstance(custom_reconstruction, mapping_type) + self.assertIsInstance(plain_reconstruction, dict) + self.assertEqual( + mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), + custom_reconstruction) + self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) + + Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testFlattenAndPack_withDicts(self): + # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. + mess = [ + "z", + NestTest.Abc(3, 4), { + "d": _CustomMapping({ + 41: 4 + }), + "c": [ + 1, + collections.OrderedDict([ + ("b", 3), + ("a", 2), + ]), + ], + "b": 5 + }, 17 + ] + + flattened = nest.flatten(mess) + self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) + + structure_of_mess = [ + 14, + NestTest.Abc("a", True), + { + "d": _CustomMapping({ + 41: 42 + }), + "c": [ + 0, + collections.OrderedDict([ + ("b", 9), + ("a", 8), + ]), + ], + "b": 3 + }, + "hi everybody", + ] + + unflattened = nest.pack_sequence_as(structure_of_mess, flattened) + self.assertEqual(unflattened, mess) + + # Check also that the OrderedDict was created, with the correct key order. + unflattened_ordered_dict = unflattened[2]["c"][1] + self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) + self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) + + unflattened_custom_mapping = unflattened[2]["d"] + self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) + self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) + + def testFlatten_numpyIsNotFlattened(self): + structure = np.array([1, 2, 3]) + flattened = nest.flatten(structure) + self.assertEqual(len(flattened), 1) + + def testFlatten_stringIsNotFlattened(self): + structure = "lots of letters" + flattened = nest.flatten(structure) + self.assertEqual(len(flattened), 1) + unflattened = nest.pack_sequence_as("goodbye", flattened) + self.assertEqual(structure, unflattened) + + def testPackSequenceAs_notIterableError(self): + with self.assertRaisesRegexp(TypeError, + "flat_sequence must be a sequence"): + nest.pack_sequence_as("hi", "bye") + + def testPackSequenceAs_wrongLengthsError(self): + with self.assertRaisesRegexp( + ValueError, + "Structure had 2 elements, but flat_sequence had 3 elements."): + nest.pack_sequence_as(["hello", "world"], + ["and", "goodbye", "again"]) + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testIsSequence(self): + self.assertFalse(nest.is_sequence("1234")) + self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) + self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) + self.assertTrue(nest.is_sequence([])) + self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) + self.assertFalse(nest.is_sequence(set([1, 2]))) + ones = array_ops.ones([2, 3]) + self.assertFalse(nest.is_sequence(ones)) + self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) + self.assertFalse(nest.is_sequence(np.ones((4, 5)))) + + @parameterized.parameters({"mapping_type": _CustomMapping}, + {"mapping_type": dict}) + def testFlattenDictItems(self, mapping_type): + dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) + flat = {4: "a", 5: "b", 6: "c", 8: "d"} + self.assertEqual(nest.flatten_dict_items(dictionary), flat) + + with self.assertRaises(TypeError): + nest.flatten_dict_items(4) + + bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) + with self.assertRaisesRegexp(ValueError, "not unique"): + nest.flatten_dict_items(bad_dictionary) + + another_bad_dictionary = mapping_type({ + (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) + }) + with self.assertRaisesRegexp( + ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): + nest.flatten_dict_items(another_bad_dictionary) + + # pylint does not correctly recognize these as class names and + # suggests to use variable style under_score naming. + # pylint: disable=invalid-name + Named0ab = collections.namedtuple("named_0", ("a", "b")) + Named1ab = collections.namedtuple("named_1", ("a", "b")) + SameNameab = collections.namedtuple("same_name", ("a", "b")) + SameNameab2 = collections.namedtuple("same_name", ("a", "b")) + SameNamexy = collections.namedtuple("same_name", ("x", "y")) + SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) + SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) + NotSameName = collections.namedtuple("not_same_name", ("a", "b")) + # pylint: enable=invalid-name + + class SameNamedType1(SameNameab): + pass + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testAssertSameStructure(self): + structure1 = (((1, 2), 3), 4, (5, 6)) + structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) + structure_different_num_elements = ("spam", "eggs") + structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) + nest.assert_same_structure(structure1, structure2) + nest.assert_same_structure("abc", 1.0) + nest.assert_same_structure("abc", np.array([0, 1])) + nest.assert_same_structure("abc", constant_op.constant([0, 1])) + + with self.assertRaisesRegexp( + ValueError, + ("The two structures don't have the same nested structure\\.\n\n" + "First structure:.*?\n\n" + "Second structure:.*\n\n" + "More specifically: Substructure " + r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' + 'substructure "type=str str=spam" is not\n' + "Entire first structure:\n" + r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" + "Entire second structure:\n" + r"\(\., \.\)")): + nest.assert_same_structure(structure1, structure_different_num_elements) + + with self.assertRaisesRegexp( + ValueError, + ("The two structures don't have the same nested structure\\.\n\n" + "First structure:.*?\n\n" + "Second structure:.*\n\n" + r'More specifically: Substructure "type=list str=\[0, 1\]" ' + r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' + "is not")): + nest.assert_same_structure([0, 1], np.array([0, 1])) + + with self.assertRaisesRegexp( + ValueError, + ("The two structures don't have the same nested structure\\.\n\n" + "First structure:.*?\n\n" + "Second structure:.*\n\n" + r'More specifically: Substructure "type=list str=\[0, 1\]" ' + 'is a sequence, while substructure "type=int str=0" ' + "is not")): + nest.assert_same_structure(0, [0, 1]) + + self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) + + with self.assertRaisesRegexp( + ValueError, + ("don't have the same nested structure\\.\n\n" + "First structure: .*?\n\nSecond structure: ")): + nest.assert_same_structure(structure1, structure_different_nesting) + + self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), + NestTest.Named0ab("a", "b")) + + nest.assert_same_structure(NestTest.Named0ab(3, 4), + NestTest.Named0ab("a", "b")) + + self.assertRaises(TypeError, nest.assert_same_structure, + NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) + + with self.assertRaisesRegexp( + ValueError, + ("don't have the same nested structure\\.\n\n" + "First structure: .*?\n\nSecond structure: ")): + nest.assert_same_structure(NestTest.Named0ab(3, 4), + NestTest.Named0ab([3], 4)) + + with self.assertRaisesRegexp( + ValueError, + ("don't have the same nested structure\\.\n\n" + "First structure: .*?\n\nSecond structure: ")): + nest.assert_same_structure([[3], 4], [3, [4]]) + + structure1_list = [[[1, 2], 3], 4, [5, 6]] + with self.assertRaisesRegexp(TypeError, + "don't have the same sequence type"): + nest.assert_same_structure(structure1, structure1_list) + nest.assert_same_structure(structure1, structure2, check_types=False) + nest.assert_same_structure(structure1, structure1_list, check_types=False) + + with self.assertRaisesRegexp(ValueError, + "don't have the same set of keys"): + nest.assert_same_structure({"a": 1}, {"b": 1}) + + nest.assert_same_structure(NestTest.SameNameab(0, 1), + NestTest.SameNameab2(2, 3)) + + # This assertion is expected to pass: two namedtuples with the same + # name and field names are considered to be identical. + nest.assert_same_structure( + NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), + NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) + + expected_message = "The two structures don't have the same.*" + with self.assertRaisesRegexp(ValueError, expected_message): + nest.assert_same_structure( + NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), + NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) + + self.assertRaises(TypeError, nest.assert_same_structure, + NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) + + self.assertRaises(TypeError, nest.assert_same_structure, + NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) + + self.assertRaises(TypeError, nest.assert_same_structure, + NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) + + EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name + + def testHeterogeneousComparison(self): + nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) + nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testMapStructure(self): + structure1 = (((1, 2), 3), 4, (5, 6)) + structure2 = (((7, 8), 9), 10, (11, 12)) + structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) + nest.assert_same_structure(structure1, structure1_plus1) + self.assertAllEqual( + [2, 3, 4, 5, 6, 7], + nest.flatten(structure1_plus1)) + structure1_plus_structure2 = nest.map_structure( + lambda x, y: x + y, structure1, structure2) + self.assertEqual( + (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), + structure1_plus_structure2) + + self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) + + self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) + + # Empty structures + self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) + self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) + self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) + self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, + NestTest.EmptyNT())) + + # This is checking actual equality of types, empty list != empty tuple + self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) + + with self.assertRaisesRegexp(TypeError, "callable"): + nest.map_structure("bad", structure1_plus1) + + with self.assertRaisesRegexp(ValueError, "at least one structure"): + nest.map_structure(lambda x: x) + + with self.assertRaisesRegexp(ValueError, "same number of elements"): + nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) + + with self.assertRaisesRegexp(ValueError, "same nested structure"): + nest.map_structure(lambda x, y: None, 3, (3,)) + + with self.assertRaisesRegexp(TypeError, "same sequence type"): + nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) + + with self.assertRaisesRegexp(ValueError, "same nested structure"): + nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) + + structure1_list = [[[1, 2], 3], 4, [5, 6]] + with self.assertRaisesRegexp(TypeError, "same sequence type"): + nest.map_structure(lambda x, y: None, structure1, structure1_list) + + nest.map_structure(lambda x, y: None, structure1, structure1_list, + check_types=False) + + with self.assertRaisesRegexp(ValueError, "same nested structure"): + nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), + check_types=False) + + with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): + nest.map_structure(lambda x: None, structure1, foo="a") + + with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): + nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") + + ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name + + @test_util.assert_no_new_pyobjects_executing_eagerly + def testMapStructureWithStrings(self): + inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) + inp_b = NestTest.ABTuple(a=2, b=(1, 3)) + out = nest.map_structure(lambda string, repeats: string * repeats, + inp_a, + inp_b) + self.assertEqual("foofoo", out.a) + self.assertEqual("bar", out.b[0]) + self.assertEqual("bazbazbaz", out.b[1]) + + nt = NestTest.ABTuple(a=("something", "something_else"), + b="yet another thing") + rev_nt = nest.map_structure(lambda x: x[::-1], nt) + # Check the output is the correct structure, and all strings are reversed. + nest.assert_same_structure(nt, rev_nt) + self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) + self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) + self.assertEqual(nt.b[::-1], rev_nt.b) + + @test_util.run_deprecated_v1 + def testMapStructureOverPlaceholders(self): + inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), + array_ops.placeholder(dtypes.float32, shape=[3, 7])) + inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), + array_ops.placeholder(dtypes.float32, shape=[3, 7])) + + output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) + + nest.assert_same_structure(output, inp_a) + self.assertShapeEqual(np.zeros((3, 4)), output[0]) + self.assertShapeEqual(np.zeros((3, 7)), output[1]) + + feed_dict = { + inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), + inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) + } + + with self.cached_session() as sess: + output_np = sess.run(output, feed_dict=feed_dict) + self.assertAllClose(output_np[0], + feed_dict[inp_a][0] + feed_dict[inp_b][0]) + self.assertAllClose(output_np[1], + feed_dict[inp_a][1] + feed_dict[inp_b][1]) + + def testAssertShallowStructure(self): + inp_ab = ["a", "b"] + inp_abc = ["a", "b", "c"] + expected_message = ( + "The two structures don't have the same sequence length. Input " + "structure has length 2, while shallow structure has length 3.") + with self.assertRaisesRegexp(ValueError, expected_message): + nest.assert_shallow_structure(inp_abc, inp_ab) + + inp_ab1 = [(1, 1), (2, 2)] + inp_ab2 = [[1, 1], [2, 2]] + expected_message = ( + "The two structures don't have the same sequence type. Input structure " + "has type <(type|class) 'tuple'>, while shallow structure has type " + "<(type|class) 'list'>.") + with self.assertRaisesRegexp(TypeError, expected_message): + nest.assert_shallow_structure(inp_ab2, inp_ab1) + nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) + + inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} + inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} + expected_message = ( + r"The two structures don't have the same keys. Input " + r"structure has keys \['c'\], while shallow structure has " + r"keys \['d'\].") + + with self.assertRaisesRegexp(ValueError, expected_message): + nest.assert_shallow_structure(inp_ab2, inp_ab1) + + inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) + inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) + nest.assert_shallow_structure(inp_ab, inp_ba) + + # This assertion is expected to pass: two namedtuples with the same + # name and field names are considered to be identical. + inp_shallow = NestTest.SameNameab(1, 2) + inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) + nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) + nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) + + def testFlattenUpTo(self): + # Shallow tree ends at scalar. + input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] + shallow_tree = [[True, True], [False, True]] + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) + self.assertEqual(flattened_shallow_tree, [True, True, False, True]) + + # Shallow tree ends at string. + input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] + shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + input_tree_flattened = nest.flatten(input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) + self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) + + # Make sure dicts are correctly flattened, yielding values, not keys. + input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} + shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [1, {"c": 2}, 3, (4, 5)]) + + # Namedtuples. + ab_tuple = NestTest.ABTuple + input_tree = ab_tuple(a=[0, 1], b=2) + shallow_tree = ab_tuple(a=0, b=1) + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [[0, 1], 2]) + + # Nested dicts, OrderedDicts and namedtuples. + input_tree = collections.OrderedDict( + [("a", ab_tuple(a=[0, {"b": 1}], b=2)), + ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) + shallow_tree = input_tree + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) + shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [ab_tuple(a=[0, {"b": 1}], b=2), + 3, + collections.OrderedDict([("f", 4)])]) + shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) + input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, + input_tree) + self.assertEqual(input_tree_flattened_as_shallow_tree, + [ab_tuple(a=[0, {"b": 1}], b=2), + {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) + + ## Shallow non-list edge-case. + # Using iterable elements. + input_tree = ["input_tree"] + shallow_tree = "shallow_tree" + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + input_tree = ["input_tree_0", "input_tree_1"] + shallow_tree = "shallow_tree" + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + # Using non-iterable elements. + input_tree = [0] + shallow_tree = 9 + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + input_tree = [0, 1] + shallow_tree = 9 + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + ## Both non-list edge-case. + # Using iterable elements. + input_tree = "input_tree" + shallow_tree = "shallow_tree" + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + # Using non-iterable elements. + input_tree = 0 + shallow_tree = 0 + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_input_tree, [input_tree]) + self.assertEqual(flattened_shallow_tree, [shallow_tree]) + + ## Input non-list edge-case. + # Using iterable elements. + input_tree = "input_tree" + shallow_tree = ["shallow_tree"] + expected_message = ("If shallow structure is a sequence, input must also " + "be a sequence. Input has type: <(type|class) 'str'>.") + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, shallow_tree) + + input_tree = "input_tree" + shallow_tree = ["shallow_tree_9", "shallow_tree_8"] + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, shallow_tree) + + # Using non-iterable elements. + input_tree = 0 + shallow_tree = [9] + expected_message = ("If shallow structure is a sequence, input must also " + "be a sequence. Input has type: <(type|class) 'int'>.") + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, shallow_tree) + + input_tree = 0 + shallow_tree = [9, 8] + with self.assertRaisesRegexp(TypeError, expected_message): + flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) + flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) + self.assertEqual(flattened_shallow_tree, shallow_tree) + + def testMapStructureUpTo(self): + # Named tuples. + ab_tuple = collections.namedtuple("ab_tuple", "a, b") + op_tuple = collections.namedtuple("op_tuple", "add, mul") + inp_val = ab_tuple(a=2, b=3) + inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) + out = nest.map_structure_up_to( + inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) + self.assertEqual(out.a, 6) + self.assertEqual(out.b, 15) + + # Lists. + data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] + name_list = ["evens", ["odds", "primes"]] + out = nest.map_structure_up_to( + name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), + name_list, data_list) + self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) + + # Dicts. + inp_val = dict(a=2, b=3) + inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) + out = nest.map_structure_up_to( + inp_val, + lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + self.assertEqual(out["a"], 6) + self.assertEqual(out["b"], 15) + + # Non-equal dicts. + inp_val = dict(a=2, b=3) + inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) + with self.assertRaisesRegexp(ValueError, "same keys"): + nest.map_structure_up_to( + inp_val, + lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + + # Dict+custom mapping. + inp_val = dict(a=2, b=3) + inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) + out = nest.map_structure_up_to( + inp_val, + lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + self.assertEqual(out["a"], 6) + self.assertEqual(out["b"], 15) + + # Non-equal dict/mapping. + inp_val = dict(a=2, b=3) + inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) + with self.assertRaisesRegexp(ValueError, "same keys"): + nest.map_structure_up_to( + inp_val, + lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) + + def testGetTraverseShallowStructure(self): + scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] + scalar_traverse_r = nest.get_traverse_shallow_structure( + lambda s: not isinstance(s, tuple), + scalar_traverse_input) + self.assertEqual(scalar_traverse_r, + [True, True, False, [True, True], {"a": False}, []]) + nest.assert_shallow_structure(scalar_traverse_r, + scalar_traverse_input) + + structure_traverse_input = [(1, [2]), ([1], 2)] + structure_traverse_r = nest.get_traverse_shallow_structure( + lambda s: (True, False) if isinstance(s, tuple) else True, + structure_traverse_input) + self.assertEqual(structure_traverse_r, + [(True, False), ([True], False)]) + nest.assert_shallow_structure(structure_traverse_r, + structure_traverse_input) + + with self.assertRaisesRegexp(TypeError, "returned structure"): + nest.get_traverse_shallow_structure(lambda _: [True], 0) + + with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): + nest.get_traverse_shallow_structure(lambda _: 1, [1]) + + with self.assertRaisesRegexp( + TypeError, "didn't return a depth=1 structure of bools"): + nest.get_traverse_shallow_structure(lambda _: [1], [1]) + + def testYieldFlatStringPaths(self): + for inputs_expected in ({"inputs": [], "expected": []}, + {"inputs": 3, "expected": [()]}, + {"inputs": [3], "expected": [(0,)]}, + {"inputs": {"a": 3}, "expected": [("a",)]}, + {"inputs": {"a": {"b": 4}}, + "expected": [("a", "b")]}, + {"inputs": [{"a": 2}], "expected": [(0, "a")]}, + {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, + {"inputs": [{"a": [(23, 42)]}], + "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, + {"inputs": [{"a": ([23], 42)}], + "expected": [(0, "a", 0, 0), (0, "a", 1)]}, + {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, + "expected": [("a", "a"), ("c", 0, 0, 0)]}, + {"inputs": {"0": [{"1": 23}]}, + "expected": [("0", 0, "1")]}): + inputs = inputs_expected["inputs"] + expected = inputs_expected["expected"] + self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) + + def testFlattenWithStringPaths(self): + for inputs_expected in ( + {"inputs": [], "expected": []}, + {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, + {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): + inputs = inputs_expected["inputs"] + expected = inputs_expected["expected"] + self.assertEqual( + nest.flatten_with_joined_string_paths(inputs, separator="/"), + expected) + + # Need a separate test for namedtuple as we can't declare tuple definitions + # in the @parameterized arguments. + def testFlattenNamedTuple(self): + # pylint: disable=invalid-name + Foo = collections.namedtuple("Foo", ["a", "b"]) + Bar = collections.namedtuple("Bar", ["c", "d"]) + # pylint: enable=invalid-name + test_cases = [ + (Foo(a=3, b=Bar(c=23, d=42)), + [("a", 3), ("b/c", 23), ("b/d", 42)]), + (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), + [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), + (Bar(c=42, d=43), + [("c", 42), ("d", 43)]), + (Bar(c=[42], d=43), + [("c/0", 42), ("d", 43)]), + ] + for inputs, expected in test_cases: + self.assertEqual( + list(nest.flatten_with_joined_string_paths(inputs)), expected) + + @parameterized.named_parameters( + ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), + ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, + {"a": ("a", 4), "b": ("b", 6)}), + ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), + ("nested", + {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, + {"a": [("a/0", 10), ("a/1", 12)], + "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) + def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): + def format_sum(path, *values): + return (path, sum(values)) + result = nest.map_structure_with_paths(format_sum, s1, s2, + check_types=check_types) + self.assertEqual(expected, result) + + @parameterized.named_parameters( + ("tuples", (1, 2), (3, 4, 5), ValueError), + ("dicts", {"a": 1}, {"b": 2}, ValueError), + ("mixed", (1, 2), [3, 4], TypeError), + ("nested", + {"a": [2, 3], "b": [1, 3]}, + {"b": [5, 6, 7], "a": [8, 9]}, + ValueError + )) + def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): + with self.assertRaises(error_type): + nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) + + +class NestBenchmark(test.Benchmark): + + def run_and_report(self, s1, s2, name): + burn_iter, test_iter = 100, 30000 + + for _ in xrange(burn_iter): + nest.assert_same_structure(s1, s2) + + t0 = time.time() + for _ in xrange(test_iter): + nest.assert_same_structure(s1, s2) + t1 = time.time() + + self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, + name=name) + + def benchmark_assert_structure(self): + s1 = (((1, 2), 3), 4, (5, 6)) + s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) + self.run_and_report(s1, s2, "assert_same_structure_6_elem") + + s1 = (((1, 2), 3), 4, (5, 6)) * 10 + s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 + self.run_and_report(s1, s2, "assert_same_structure_60_elem") + + +if __name__ == "__main__": + test.main()