Browse Source

started porting nest, flatten and pack_structure_as already work

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
3b5c6f762c
3 changed files with 2606 additions and 0 deletions
  1. +871
    -0
      src/TensorFlowNET.Core/Util/nest.py.cs
  2. +852
    -0
      test/TensorFlowNET.UnitTest/nest_test/NestTest.cs
  3. +883
    -0
      test/TensorFlowNET.UnitTest/nest_test/nest_test.py

+ 871
- 0
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -0,0 +1,871 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using NumSharp;
namespace Tensorflow.Util
{
//Functions for working with arbitrarily nested sequences of elements.
//This module can perform operations on nested structures. A nested structure is a
//Python sequence, tuple (including `namedtuple`), or dict that can contain
//further sequences, tuples, and dicts.
//The utilities here assume (and do not check) that the nested structures form a
//'tree', i.e., no references in the structure of the input of these functions
//should be recursive.
//Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0),
// (np.array([3, 4]), tf.constant([3, 4])))`
//
public static class nest
{
//def _get_attrs_values(obj):
// """Returns the list of values from an attrs instance."""
// attrs = getattr(obj.__class__, "__attrs_attrs__")
// return [getattr(obj, a.name) for a in attrs]
/// <summary>
/// Returns a sorted list of the dict keys, with error if keys not sortable.
/// </summary>
private static IEnumerable<string> _sorted(IDictionary dict_)
{
return dict_.Keys.OfType<string>().OrderBy(x => x);
}
//def _is_namedtuple(instance, strict=False):
// """Returns True iff `instance` is a `namedtuple`.
// Args:
// instance: An instance of a Python object.
// strict: If True, `instance` is considered to be a `namedtuple` only if
// it is a "plain" namedtuple. For instance, a class inheriting
// from a `namedtuple` will be considered to be a `namedtuple`
// iff `strict=False`.
// Returns:
// True if `instance` is a `namedtuple`.
// """
// return _pywrap_tensorflow.IsNamedtuple(instance, strict)
//# See the swig file (util.i) for documentation.
//_is_mapping = _pywrap_tensorflow.IsMapping
//_is_attrs = _pywrap_tensorflow.IsAttrs
/// <summary>
/// Converts the sequence `args` to the same type as `instance`.
/// </summary>
/// <param name="instance">an instance of `tuple`, `list`, `namedtuple`, `dict`, or
/// `collections.OrderedDict`.</param>
/// <param name="args">elements to be converted to the `instance` type.</param>
/// <returns>`args` with the type of `instance`.</returns>
private static object _sequence_like(object instance, IEnumerable<object> args)
{
if (is_mapping(instance))
{
//# Pack dictionaries in a deterministic order by sorting the keys.
//# Notice this means that we ignore the original order of `OrderedDict`
//# instances. This is intentional, to avoid potential bugs caused by mixing
//# ordered and plain dicts (e.g., flattening a dict but using a
//# corresponding `OrderedDict` to pack it back).
// result = dict(zip(_sorted(instance), args))
// return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
}
//else if( _is_namedtuple(instance) || _is_attrs(instance))
// return type(instance)(*args)
else
{
// Not a namedtuple
switch (instance)
{
case object[] array:
var result_array = new object[args.Count()];
int i = 0;
foreach (var x in args)
{
result_array[i] = x;
i++;
}
return result_array;
case List<object> list:
return new List<object>(args);
default:
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType());
}
}
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType());
}
/// <summary>
/// Yields the next value from the given iterable.
/// </summary>
private static IEnumerable<object> _yield_value(object iterable)
{
if (is_mapping(iterable))
{
var dict = iterable as IDictionary;
//# Iterate through dictionaries in a deterministic order by sorting the
//# keys. Notice this means that we ignore the original order of `OrderedDict`
//# instances. This is intentional, to avoid potential bugs caused by mixing
//# ordered and plain dicts (e.g., flattening a dict but using a
//# corresponding `OrderedDict` to pack it back).
foreach (var key in _sorted(dict))
yield return dict[key];
}
//else if (_is_attrs(iterable))
//{
// // for value in _get_attrs_values(iterable):
// // yield value
//}
else if (iterable is IEnumerable)
{
var enumerable = iterable as IEnumerable;
foreach (var value in enumerable)
yield return value;
}
else
{
throw new TypeError("Unexpected iterable type: " + iterable.GetType());
//var jobj = JObject.FromObject(iterable);
//foreach (var key in _sorted())
// yield return jobj[key];
}
}
//# See the swig file (util.i) for documentation.
public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string);
public static bool is_mapping(object arg) => arg is IDictionary;
//# See the swig file (util.i) for documentation.
//flatten = _pywrap_tensorflow.Flatten
public static List<object> flatten(object structure)
{
var list = new List<object>();
_flatten_recursive(structure, list);
return list;
}
private static void _flatten_recursive(object obj, List<object> list)
{
if (obj is string)
{
list.Add(obj);
return;
}
if (obj is IDictionary)
{
var dict = obj as IDictionary;
foreach (var key in _sorted(dict))
_flatten_recursive(dict[key], list);
return;
}
if (obj is NDArray)
{
list.Add(obj);
return;
}
if (obj is IEnumerable)
{
var structure = obj as IEnumerable;
foreach (var child in structure)
_flatten_recursive(child, list);
return;
}
list.Add(obj);
}
//# See the swig file (util.i) for documentation.
//_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
//class _DotString(object):
// def __str__(self):
// return "."
// def __repr__(self):
// return "."
//_DOT = _DotString()
//def assert_same_structure(nest1, nest2, check_types=True):
// """Asserts that two structures are nested in the same way.
// Note that namedtuples with identical name and fields are always considered
// to have the same shallow structure (even with `check_types=True`).
// For intance, this code will print `True`:
// ```python
// def nt(a, b):
// return collections.namedtuple('foo', 'a b')(a, b)
// print(assert_same_structure(nt(0, 1), nt(2, 3)))
// ```
// Args:
// nest1: an arbitrarily nested structure.
// nest2: an arbitrarily nested structure.
// check_types: if `True` (default) types of sequences are checked as well,
// including the keys of dictionaries. If set to `False`, for example a
// list and a tuple of objects will look the same if they have the same
// size. Note that namedtuples with identical name and fields are always
// considered to have the same shallow structure. Two types will also be
// considered the same if they are both list subtypes (which allows "list"
// and "_ListWrapper" from checkpointable dependency tracking to compare
// equal).
// Raises:
// ValueError: If the two structures do not have the same number of elements or
// if the two structures are not nested in the same way.
// TypeError: If the two structures differ in the type of sequence in any of
// their substructures. Only possible if `check_types` is `True`.
// """
// try:
// _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
// except (ValueError, TypeError) as e:
// str1 = str(map_structure(lambda _: _DOT, nest1))
// str2 = str(map_structure(lambda _: _DOT, nest2))
// raise type(e)("%s\n"
// "Entire first structure:\n%s\n"
// "Entire second structure:\n%s"
// % (str(e), str1, str2))
//def flatten_dict_items(dictionary):
// """Returns a dictionary with flattened keys and values.
// This function flattens the keys and values of a dictionary, which can be
// arbitrarily nested structures, and returns the flattened version of such
// structures:
// ```python
// example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
// result = {4: "a", 5: "b", 6: "c", 8: "d"}
// flatten_dict_items(example_dictionary) == result
// ```
// The input dictionary must satisfy two properties:
// 1. Its keys and values should have the same exact nested structure.
// 2. The set of all flattened keys of the dictionary must not contain repeated
// keys.
// Args:
// dictionary: the dictionary to zip
// Returns:
// The zipped dictionary.
// Raises:
// TypeError: If the input is not a dictionary.
// ValueError: If any key and value have not the same structure, or if keys are
// not unique.
// """
// if not isinstance(dictionary, (dict, _collections.Mapping)):
// raise TypeError("input must be a dictionary")
// flat_dictionary = {}
// for i, v in _six.iteritems(dictionary):
// if not is_sequence(i):
// if i in flat_dictionary:
// raise ValueError(
// "Could not flatten dictionary: key %s is not unique." % i)
// flat_dictionary[i] = v
// else:
// flat_i = flatten(i)
// flat_v = flatten(v)
// if len(flat_i) != len(flat_v):
// raise ValueError(
// "Could not flatten dictionary. Key had %d elements, but value had "
// "%d elements. Key: %s, value: %s."
// % (len(flat_i), len(flat_v), flat_i, flat_v))
// for new_i, new_v in zip(flat_i, flat_v):
// if new_i in flat_dictionary:
// raise ValueError(
// "Could not flatten dictionary: key %s is not unique."
// % (new_i))
// flat_dictionary[new_i] = new_v
// return flat_dictionary
/// <summary>
/// Helper function for pack_sequence_as.
/// </summary>
/// <param name="structure">Substructure (list / tuple / dict) to mimic.</param>
/// <param name="flat">Flattened values to output substructure for.</param>
/// <param name="index">Index at which to start reading from flat.</param>
/// <returns>
/// The tuple(new_index, child), where:
/// * new_index - the updated index into `flat` having processed `structure`.
/// * packed - the subset of `flat` corresponding to `structure`,
/// having started at `index`, and packed into the same nested
/// format.</returns>
private static (int new_index, List<object> child) _packed_nest_with_indices(object structure, List<object> flat,
int index)
{
var packed = new List<object>();
foreach (var s in _yield_value(structure))
{
if (is_sequence(s))
{
var (new_index, child) = _packed_nest_with_indices(s, flat, index);
packed.Add(_sequence_like(s, child));
index = new_index;
}
else
{
packed.Add(flat[index]);
index += 1;
}
}
return (index, packed);
}
private static int len(IEnumerable<object> x) => x.Count();
/// <summary>
/// Returns a given flattened sequence packed into a given structure.
/// If `structure` is a scalar, `flat_sequence` must be a single-element list;
/// in this case the return value is `flat_sequence[0]`.
///
/// If `structure` is or contains a dict instance, the keys will be sorted to
/// pack the flat sequence in deterministic order. This is true also for
/// `OrderedDict` instances: their sequence order is ignored, the sorting order of
/// keys is used instead. The same convention is followed in `flatten`.
/// This correctly repacks dicts and `OrderedDict`s after they have been
/// flattened, and also allows flattening an `OrderedDict` and then repacking it
/// back using a corresponding plain dict, or vice-versa.
/// Dictionaries with non-sortable keys cannot be flattened.
/// </summary>
/// <param name="structure">
/// Nested structure, whose structure is given by nested lists,
/// tuples, and dicts. Note: numpy arrays and strings are considered
/// scalars.
/// </param>
/// <param name="flat_sequence"> flat sequence to pack.</param>
/// <returns> `flat_sequence` converted to have the same recursive structure as
/// `structure`.
/// </returns>
public static object pack_sequence_as(object structure, List<object> flat_sequence)
{
if (flat_sequence == null)
throw new ArgumentException("flat_sequence must not be null");
// if not is_sequence(flat_sequence):
// raise TypeError("flat_sequence must be a sequence")
if (!is_sequence(structure))
{
if (len(flat_sequence) != 1)
throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1");
return flat_sequence.FirstOrDefault();
}
int final_index = 0;
List<object> packed = null;
try
{
(final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0);
if (final_index < len(flat_sequence))
throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }");
}
catch (IndexOutOfRangeException)
{
var flat_structure = flatten(structure);
if (len(flat_structure) != len(flat_sequence))
{
throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}");
}
return _sequence_like(structure, packed);
}
return packed;
}
/// <summary>
/// Applies `func` to each entry in `structure` and returns a new structure.
///
/// Applies `func(x[0], x[1], ...)` where x[i] is an entry in
/// `structure[i]`. All structures in `structure` must have the same arity,
/// and the return value will contain the results in the same structure.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <typeparam name="U"></typeparam>
/// <param name="func"> A callable that accepts as many arguments as there are structures.</param>
/// <param name="structure">scalar, or tuple or list of constructed scalars and/or other
/// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param>
/// <param name="check_types">If set to
/// `True` (default) the types of iterables within the structures have to be
/// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
/// exception). To allow this set this argument to `False`.
/// Note that namedtuples with identical name and fields are always
/// considered to have the same shallow structure.</param>
/// <returns>
/// A new structure with the same arity as `structure`, whose values correspond
/// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
/// location in `structure[i]`. If there are different sequence types and
/// `check_types` is `False` the sequence types of the first structure will be
/// used.
/// </returns>
public static IEnumerable<U> map_structure<T, U>(Func<T, U> func, IEnumerable<T> structure, bool check_types = false)
{
// for other in structure[1:]:
// assert_same_structure(structure[0], other, check_types=check_types)
// flat_structure = [flatten(s) for s in structure]
// entries = zip(*flat_structure)
// return pack_sequence_as(
// structure[0], [func(*x) for x in entries])
return null;
}
//def map_structure_with_paths(func, *structure, **kwargs):
// """Applies `func` to each entry in `structure` and returns a new structure.
// Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
// `structure[i]` and `path` is the common path to x[i] in the structures. All
// structures in `structure` must have the same arity, and the return value will
// contain the results in the same structure. Special kwarg `check_types`
// determines whether the types of iterables within the structure must be the
// same-- see **kwargs definition below.
// Args:
// func: A callable with the signature func(path, *values, **kwargs) that is
// evaluated on the leaves of the structure.
// *structure: A variable number of compatible structures to process.
// **kwargs: Optional kwargs to be passed through to func. Special kwarg
// `check_types` is not passed to func, but instead determines whether the
// types of iterables within the structures have to be same (e.g.,
// `map_structure(func, [1], (1,))` raises a `TypeError` exception). By
// default, the types must match. To allow iteration over structures of
// different types (but common arity), set this kwarg to `False`.
// Returns:
// A structure of the same form as the input structures whose leaves are the
// result of evaluating func on corresponding leaves of the input structures.
// Raises:
// TypeError: If `func` is not callable or if the structures do not match
// each other by depth tree.
// TypeError: If `check_types` is not `False` and the two structures differ in
// the type of sequence in any of their substructures.
// ValueError: If no structures are provided.
// """
// if not callable(func):
// raise TypeError("func must be callable, got: %s" % func)
// if not structure:
// raise ValueError("Must provide at least one structure")
// check_types = kwargs.pop("check_types", True)
// for other in structure[1:]:
// assert_same_structure(structure[0], other, check_types=check_types)
//# First set paths_and_values to:
//# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]]
// paths_and_values = [flatten_with_joined_string_paths(s) for s in structure]
//# Now zip(*paths_and_values) would be:
//# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))]
//# so grouped_by_path is set to:
//# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]]
//# Note that p1i, ... pmi must all be equal since the structures are the same.
// grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)]
// return pack_sequence_as(structure[0], [
// func(paths[0], *values, **kwargs) for paths, values in grouped_by_path])
//def _yield_flat_up_to(shallow_tree, input_tree):
// """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
// if is_sequence(shallow_tree):
// for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
// _yield_value(input_tree)):
// for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
// yield input_leaf
// else:
// yield input_tree
//def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
// """Asserts that `shallow_tree` is a shallow structure of `input_tree`.
// That is, this function tests if the `input_tree` structure can be created from
// the `shallow_tree` structure by replacing its leaf nodes with deeper
// tree structures.
// Examples:
// The following code will raise an exception:
// ```python
// shallow_tree = ["a", "b"]
// input_tree = ["c", ["d", "e"], "f"]
// assert_shallow_structure(shallow_tree, input_tree)
// ```
// The following code will not raise an exception:
// ```python
// shallow_tree = ["a", "b"]
// input_tree = ["c", ["d", "e"]]
// assert_shallow_structure(shallow_tree, input_tree)
// ```
// Args:
// shallow_tree: an arbitrarily nested structure.
// input_tree: an arbitrarily nested structure.
// check_types: if `True` (default) the sequence types of `shallow_tree` and
// `input_tree` have to be the same. Note that even with check_types==True,
// this function will consider two different namedtuple classes with the same
// name and _fields attribute to be the same class.
// Raises:
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
// TypeError: If the sequence types of `shallow_tree` are different from
// `input_tree`. Only raised if `check_types` is `True`.
// ValueError: If the sequence lengths of `shallow_tree` are different from
// `input_tree`.
// """
// if is_sequence(shallow_tree):
// if not is_sequence(input_tree):
// raise TypeError(
// "If shallow structure is a sequence, input must also be a sequence. "
// "Input has type: %s." % type(input_tree))
// if check_types and not isinstance(input_tree, type(shallow_tree)):
//# Duck-typing means that nest should be fine with two different
//# namedtuples with identical name and fields.
// shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
// input_is_namedtuple = _is_namedtuple(input_tree, False)
// if shallow_is_namedtuple and input_is_namedtuple:
// if not _same_namedtuples(shallow_tree, input_tree):
// raise TypeError(
// "The two namedtuples don't have the same sequence type. Input "
// "structure has type %s, while shallow structure has type %s."
// % (type(input_tree), type(shallow_tree)))
// elif not (isinstance(shallow_tree, _collections.Mapping)
// and isinstance(input_tree, _collections.Mapping)):
// raise TypeError(
// "The two structures don't have the same sequence type. Input "
// "structure has type %s, while shallow structure has type %s."
// % (type(input_tree), type(shallow_tree)))
// if len(input_tree) != len(shallow_tree):
// raise ValueError(
// "The two structures don't have the same sequence length. Input "
// "structure has length %s, while shallow structure has length %s."
// % (len(input_tree), len(shallow_tree)))
// if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)):
// if set(input_tree) != set(shallow_tree):
// raise ValueError(
// "The two structures don't have the same keys. Input "
// "structure has keys %s, while shallow structure has keys %s." %
// (list(_six.iterkeys(input_tree)),
// list(_six.iterkeys(shallow_tree))))
// input_tree = list(sorted(_six.iteritems(input_tree)))
// shallow_tree = list(sorted(_six.iteritems(shallow_tree)))
// for shallow_branch, input_branch in zip(shallow_tree, input_tree):
// assert_shallow_structure(shallow_branch, input_branch,
// check_types=check_types)
//def flatten_up_to(shallow_tree, input_tree):
// """Flattens `input_tree` up to `shallow_tree`.
// Any further depth in structure in `input_tree` is retained as elements in the
// partially flatten output.
// If `shallow_tree` and `input_tree` are not sequences, this returns a
// single-element list: `[input_tree]`.
// Use Case:
// Sometimes we may wish to partially flatten a nested sequence, retaining some
// of the nested structure. We achieve this by specifying a shallow structure,
// `shallow_tree`, we wish to flatten up to.
// The input, `input_tree`, can be thought of as having the same structure as
// `shallow_tree`, but with leaf nodes that are themselves tree structures.
// Examples:
// ```python
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
// shallow_tree = [[True, True], [False, True]]
// flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)
//# Output is:
//# [[2, 2], [3, 3], [4, 9], [5, 5]]
//# [True, True, False, True]
// ```
// ```python
// input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
// shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]
// input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
// input_tree_flattened = flatten(input_tree)
//# Output is:
//# [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
//# ['a', 1, 'b', 2, 'c', 3, 'd', 4]
// ```
// Non-Sequence Edge Cases:
// ```python
// flatten_up_to(0, 0) # Output: [0]
// flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]]
// flatten_up_to([0, 1, 2], 0) # Output: TypeError
// flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2]
// ```
// Args:
// shallow_tree: a possibly pruned structure of input_tree.
// input_tree: an arbitrarily nested structure or a scalar object.
// Note, numpy arrays are considered scalars.
// Returns:
// A Python list, the partially flattened version of `input_tree` according to
// the structure of `shallow_tree`.
// Raises:
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
// TypeError: If the sequence types of `shallow_tree` are different from
// `input_tree`.
// ValueError: If the sequence lengths of `shallow_tree` are different from
// `input_tree`.
// """
// assert_shallow_structure(shallow_tree, input_tree)
// return list(_yield_flat_up_to(shallow_tree, input_tree))
//def map_structure_up_to(shallow_tree, func, *inputs):
// """Applies a function or op to a number of partially flattened inputs.
// The `inputs` are flattened up to `shallow_tree` before being mapped.
// Use Case:
// Sometimes we wish to apply a function to a partially flattened
// sequence (for example when the function itself takes sequence inputs). We
// achieve this by specifying a shallow structure, `shallow_tree` we wish to
// flatten up to.
// The `inputs`, can be thought of as having the same structure as
// `shallow_tree`, but with leaf nodes that are themselves tree structures.
// This function therefore will return something with the same base structure as
// `shallow_tree`.
// Examples:
// ```python
// ab_tuple = collections.namedtuple("ab_tuple", "a, b")
// op_tuple = collections.namedtuple("op_tuple", "add, mul")
// inp_val = ab_tuple(a=2, b=3)
// inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
// out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
// inp_val, inp_ops)
//# Output is: ab_tuple(a=6, b=15)
// ```
// ```python
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
// name_list = ['evens', ['odds', 'primes']]
// out = map_structure_up_to(
// name_list,
// lambda name, sec: "first_{}_{}".format(len(sec), name),
// name_list, data_list)
//# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
// ```
// Args:
// shallow_tree: a shallow tree, common to all the inputs.
// func: callable which will be applied to each input individually.
// *inputs: arbitrarily nested combination of objects that are compatible with
// shallow_tree. The function `func` is applied to corresponding
// partially flattened elements of each input, so the function must support
// arity of `len(inputs)`.
// Raises:
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
// TypeError: If the sequence types of `shallow_tree` are different from
// `input_tree`.
// ValueError: If the sequence lengths of `shallow_tree` are different from
// `input_tree`.
// Returns:
// result of repeatedly applying `func`, with same structure as
// `shallow_tree`.
// """
// if not inputs:
// raise ValueError("Cannot map over no sequences")
// for input_tree in inputs:
// assert_shallow_structure(shallow_tree, input_tree)
//# Flatten each input separately, apply the function to corresponding elements,
//# then repack based on the structure of the first input.
// all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
// for input_tree in inputs]
// results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
// return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
//def get_traverse_shallow_structure(traverse_fn, structure):
// """Generates a shallow structure from a `traverse_fn` and `structure`.
// `traverse_fn` must accept any possible subtree of `structure` and return
// a depth=1 structure containing `True` or `False` values, describing which
// of the top-level subtrees may be traversed. It may also
// return scalar `True` or `False` "traversal is OK / not OK for all subtrees."
// Examples are available in the unit tests (nest_test.py).
// Args:
// traverse_fn: Function taking a substructure and returning either a scalar
// `bool` (whether to traverse that substructure or not) or a depth=1
// shallow structure of the same type, describing which parts of the
// substructure to traverse.
// structure: The structure to traverse.
// Returns:
// A shallow structure containing python bools, which can be passed to
// `map_structure_up_to` and `flatten_up_to`.
// Raises:
// TypeError: if `traverse_fn` returns a sequence for a non-sequence input,
// or a structure with depth higher than 1 for a sequence input,
// or if any leaf values in the returned structure or scalar are not type
// `bool`.
// """
// to_traverse = traverse_fn(structure)
// if not is_sequence(structure):
// if not isinstance(to_traverse, bool):
// raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
// % (to_traverse, structure))
// return to_traverse
// level_traverse = []
// if isinstance(to_traverse, bool):
// if not to_traverse:
//# Do not traverse this substructure at all. Exit early.
// return False
// else:
//# Traverse the entire substructure.
// for branch in _yield_value(structure):
// level_traverse.append(
// get_traverse_shallow_structure(traverse_fn, branch))
// elif not is_sequence(to_traverse):
// raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
// % (to_traverse, structure))
// else:
//# Traverse some subset of this substructure.
// assert_shallow_structure(to_traverse, structure)
// for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)):
// if not isinstance(t, bool):
// raise TypeError(
// "traverse_fn didn't return a depth=1 structure of bools. saw: %s "
// " for structure: %s" % (to_traverse, structure))
// if t:
// level_traverse.append(
// get_traverse_shallow_structure(traverse_fn, branch))
// else:
// level_traverse.append(False)
// return _sequence_like(structure, level_traverse)
//def yield_flat_paths(nest):
// """Yields paths for some nested structure.
// Paths are lists of objects which can be str-converted, which may include
// integers or other types which are used as indices in a dict.
// The flat list will be in the corresponding order as if you called
// `snt.nest.flatten` on the structure. This is handy for naming Tensors such
// the TF scope structure matches the tuple structure.
// E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`
// ```shell
// >>> nest.flatten(value)
// [3, 23, 42]
// >>> list(nest.yield_flat_paths(value))
// [('a',), ('b', 'c'), ('b', 'd')]
// ```
// ```shell
// >>> list(nest.yield_flat_paths({'a': [3]}))
// [('a', 0)]
// >>> list(nest.yield_flat_paths({'a': 3}))
// [('a',)]
// ```
// Args:
// nest: the value to produce a flattened paths list for.
// Yields:
// Tuples containing index or key values which form the path to a specific
// leaf value in the nested structure.
// """
//# The _maybe_add_final_path_element function is used below in order to avoid
//# adding trailing slashes when the sub-element recursed into is a leaf.
// if isinstance(nest, (dict, _collections.Mapping)):
// for key in _sorted(nest):
// value = nest[key]
// for sub_path in yield_flat_paths(value):
// yield (key,) + sub_path
// elif _is_namedtuple(nest):
// for key in nest._fields:
// value = getattr(nest, key)
// for sub_path in yield_flat_paths(value):
// yield (key,) + sub_path
// elif isinstance(nest, _six.string_types):
// yield ()
// elif isinstance(nest, _collections.Sequence):
// for idx, value in enumerate(nest):
// for sub_path in yield_flat_paths(value):
// yield (idx,) + sub_path
// else:
// yield ()
//def flatten_with_joined_string_paths(structure, separator="/"):
// """Returns a list of (string path, data element) tuples.
// The order of tuples produced matches that of `nest.flatten`. This allows you
// to flatten a nested structure while keeping information about where in the
// structure each data element was located. See `nest.yield_flat_paths`
// for more information.
// Args:
// structure: the nested structure to flatten.
// separator: string to separate levels of hierarchy in the results, defaults
// to '/'.
// Returns:
// A list of (string, data element) tuples.
// """
// flat_paths = yield_flat_paths(structure)
// def stringify_and_join(path_elements):
// return separator.join(str(path_element) for path_element in path_elements)
// flat_string_paths = [stringify_and_join(path) for path in flat_paths]
// return list(zip(flat_string_paths, flatten(structure)))
}
}

+ 852
- 0
test/TensorFlowNET.UnitTest/nest_test/NestTest.cs View File

@@ -0,0 +1,852 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using Tensorflow;
using Tensorflow.Util;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/util/nest_test.py
/// </summary>
[TestClass]
public class NestTest : PythonTest
{
public class PointXY
{
public double x;
public double y;
}
// if attr:
// class BadAttr(object):
// """Class that has a non-iterable __attrs_attrs__."""
// __attrs_attrs__ = None
// @attr.s
// class SampleAttr(object):
// field1 = attr.ib()
// field2 = attr.ib()
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testAttrsFlattenAndPack(self) :
// if attr is None:
// self.skipTest("attr module is unavailable.")
// field_values = [1, 2]
// sample_attr = NestTest.SampleAttr(* field_values)
// self.assertFalse(nest._is_attrs(field_values))
// self.assertTrue(nest._is_attrs(sample_attr))
// flat = nest.flatten(sample_attr)
// self.assertEqual(field_values, flat)
// restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
// self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
// self.assertEqual(restructured_from_flat, sample_attr)
//# Check that flatten fails if attributes are not iterable
// with self.assertRaisesRegexp(TypeError, "object is not iterable"):
// flat = nest.flatten(NestTest.BadAttr())
[TestMethod]
public void testFlattenAndPack()
{
object structure = new object[] {new object[] {3, 4}, 5, new object[] {6, 7, new object[] {9, 10}, 8}};
var flat = new List<object> {"a", "b", "c", "d", "e", "f", "g", "h"};
self.assertEqual(nest.flatten(structure), new[] {3, 4, 5, 6, 7, 9, 10, 8});
self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(),
JArray.FromObject(new object[] {new object[] {"a", "b"}, "c", new object[] {"d", "e", new object[] {"f", "g"}, "h"}}).ToString());
structure = new object[] { new Hashtable {["x"] = 4, ["y"] = 2}, new object[] { new object[] { new Hashtable { ["x"] = 1,["y"] = 0}, }, }};
flat = new List<object> { 4, 2, 1, 0};
self.assertEqual(nest.flatten(structure), flat);
// restructured_from_flat = nest.pack_sequence_as(structure, flat)
// self.assertEqual(restructured_from_flat, structure)
// self.assertEqual(restructured_from_flat[0].x, 4)
// self.assertEqual(restructured_from_flat[0].y, 2)
// self.assertEqual(restructured_from_flat[1][0][0].x, 1)
// self.assertEqual(restructured_from_flat[1][0][0].y, 0)
// self.assertEqual([5], nest.flatten(5))
// self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
// self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
// self.assertEqual(
// np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
// with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
// nest.pack_sequence_as("scalar", [4, 5])
// with self.assertRaisesRegexp(TypeError, "flat_sequence"):
// nest.pack_sequence_as([4, 5], "bad_sequence")
// with self.assertRaises(ValueError):
// nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
}
// @parameterized.parameters({"mapping_type": collections.OrderedDict
// },
// {"mapping_type": _CustomMapping
//})
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testFlattenDictOrder(self, mapping_type) :
// """`flatten` orders dicts by key, including OrderedDicts."""
// ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
// plain = {"d": 3, "b": 1, "a": 0, "c": 2}
// ordered_flat = nest.flatten(ordered)
// plain_flat = nest.flatten(plain)
// self.assertEqual([0, 1, 2, 3], ordered_flat)
// self.assertEqual([0, 1, 2, 3], plain_flat)
// @parameterized.parameters({"mapping_type": collections.OrderedDict},
// {"mapping_type": _CustomMapping})
// def testPackDictOrder(self, mapping_type):
// """Packing orders dicts by key, including OrderedDicts."""
// custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
// plain = {"d": 0, "b": 0, "a": 0, "c": 0}
// seq = [0, 1, 2, 3]
//custom_reconstruction = nest.pack_sequence_as(custom, seq)
//plain_reconstruction = nest.pack_sequence_as(plain, seq)
// self.assertIsInstance(custom_reconstruction, mapping_type)
// self.assertIsInstance(plain_reconstruction, dict)
// self.assertEqual(
// mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
// custom_reconstruction)
// self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
// Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testFlattenAndPack_withDicts(self) :
// # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
// mess = [
// "z",
// NestTest.Abc(3, 4), {
// "d": _CustomMapping({
// 41: 4
// }),
// "c": [
// 1,
// collections.OrderedDict([
// ("b", 3),
// ("a", 2),
// ]),
// ],
// "b": 5
// }, 17
// ]
// flattened = nest.flatten(mess)
// self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
// structure_of_mess = [
// 14,
// NestTest.Abc("a", True),
// {
// "d": _CustomMapping({
// 41: 42
// }),
// "c": [
// 0,
// collections.OrderedDict([
// ("b", 9),
// ("a", 8),
// ]),
// ],
// "b": 3
// },
// "hi everybody",
// ]
// unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
// self.assertEqual(unflattened, mess)
// # Check also that the OrderedDict was created, with the correct key order.
//unflattened_ordered_dict = unflattened[2]["c"][1]
// self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
// self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
// unflattened_custom_mapping = unflattened[2]["d"]
// self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
// self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
// def testFlatten_numpyIsNotFlattened(self):
// structure = np.array([1, 2, 3])
// flattened = nest.flatten(structure)
// self.assertEqual(len(flattened), 1)
// def testFlatten_stringIsNotFlattened(self):
// structure = "lots of letters"
// flattened = nest.flatten(structure)
// self.assertEqual(len(flattened), 1)
// unflattened = nest.pack_sequence_as("goodbye", flattened)
// self.assertEqual(structure, unflattened)
// def testPackSequenceAs_notIterableError(self) :
// with self.assertRaisesRegexp(TypeError,
// "flat_sequence must be a sequence"):
// nest.pack_sequence_as("hi", "bye")
// def testPackSequenceAs_wrongLengthsError(self):
// with self.assertRaisesRegexp(
// ValueError,
// "Structure had 2 elements, but flat_sequence had 3 elements."):
// nest.pack_sequence_as(["hello", "world"],
// ["and", "goodbye", "again"])
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testIsSequence(self):
// self.assertFalse(nest.is_sequence("1234"))
// self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
// self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
// self.assertTrue(nest.is_sequence([]))
// self.assertTrue(nest.is_sequence({"a": 1, "b": 2}))
// self.assertFalse(nest.is_sequence(set([1, 2])))
// ones = array_ops.ones([2, 3])
// self.assertFalse(nest.is_sequence(ones))
// self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
// self.assertFalse(nest.is_sequence(np.ones((4, 5))))
// @parameterized.parameters({"mapping_type": _CustomMapping},
// {"mapping_type": dict})
// def testFlattenDictItems(self, mapping_type):
// dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))})
// flat = {4: "a", 5: "b", 6: "c", 8: "d"}
// self.assertEqual(nest.flatten_dict_items(dictionary), flat)
// with self.assertRaises(TypeError):
// nest.flatten_dict_items(4)
// bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))})
// with self.assertRaisesRegexp(ValueError, "not unique"):
// nest.flatten_dict_items(bad_dictionary)
// another_bad_dictionary = mapping_type({
// (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
// })
// with self.assertRaisesRegexp(
// ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
// nest.flatten_dict_items(another_bad_dictionary)
//# pylint does not correctly recognize these as class names and
//# suggests to use variable style under_score naming.
//# pylint: disable=invalid-name
// Named0ab = collections.namedtuple("named_0", ("a", "b"))
// Named1ab = collections.namedtuple("named_1", ("a", "b"))
// SameNameab = collections.namedtuple("same_name", ("a", "b"))
// SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
// SameNamexy = collections.namedtuple("same_name", ("x", "y"))
// SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
// SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
// NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
// # pylint: enable=invalid-name
// class SameNamedType1(SameNameab):
// pass
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testAssertSameStructure(self):
// structure1 = (((1, 2), 3), 4, (5, 6))
// structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
// structure_different_num_elements = ("spam", "eggs")
// structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
// nest.assert_same_structure(structure1, structure2)
// nest.assert_same_structure("abc", 1.0)
// nest.assert_same_structure("abc", np.array([0, 1]))
// nest.assert_same_structure("abc", constant_op.constant([0, 1]))
// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// "More specifically: Substructure "
// r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
// 'substructure "type=str str=spam" is not\n'
// "Entire first structure:\n"
// r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
// "Entire second structure:\n"
// r"\(\., \.\)")):
// nest.assert_same_structure(structure1, structure_different_num_elements)
// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// r'More specifically: Substructure "type=list str=\[0, 1\]" '
// r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
// "is not")):
// nest.assert_same_structure([0, 1], np.array([0, 1]))
// with self.assertRaisesRegexp(
// ValueError,
// ("The two structures don't have the same nested structure\\.\n\n"
// "First structure:.*?\n\n"
// "Second structure:.*\n\n"
// r'More specifically: Substructure "type=list str=\[0, 1\]" '
// 'is a sequence, while substructure "type=int str=0" '
// "is not")):
// nest.assert_same_structure(0, [0, 1])
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])
// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure(structure1, structure_different_nesting)
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
// NestTest.Named0ab("a", "b"))
// nest.assert_same_structure(NestTest.Named0ab(3, 4),
// NestTest.Named0ab("a", "b"))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))
// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure(NestTest.Named0ab(3, 4),
// NestTest.Named0ab([3], 4))
// with self.assertRaisesRegexp(
// ValueError,
// ("don't have the same nested structure\\.\n\n"
// "First structure: .*?\n\nSecond structure: ")):
// nest.assert_same_structure([[3], 4], [3, [4]])
// structure1_list = [[[1, 2], 3], 4, [5, 6]]
// with self.assertRaisesRegexp(TypeError,
// "don't have the same sequence type"):
// nest.assert_same_structure(structure1, structure1_list)
// nest.assert_same_structure(structure1, structure2, check_types= False)
// nest.assert_same_structure(structure1, structure1_list, check_types=False)
// with self.assertRaisesRegexp(ValueError,
// "don't have the same set of keys"):
// nest.assert_same_structure({"a": 1}, {"b": 1})
// nest.assert_same_structure(NestTest.SameNameab(0, 1),
// NestTest.SameNameab2(2, 3))
// # This assertion is expected to pass: two namedtuples with the same
// # name and field names are considered to be identical.
// nest.assert_same_structure(
// NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
// NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))
// expected_message = "The two structures don't have the same.*"
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_same_structure(
// NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
// NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))
// self.assertRaises(TypeError, nest.assert_same_structure,
// NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
// EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
// def testHeterogeneousComparison(self):
// nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3))
// nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testMapStructure(self) :
// structure1 = (((1, 2), 3), 4, (5, 6))
// structure2 = (((7, 8), 9), 10, (11, 12))
// structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
// nest.assert_same_structure(structure1, structure1_plus1)
// self.assertAllEqual(
// [2, 3, 4, 5, 6, 7],
// nest.flatten(structure1_plus1))
// structure1_plus_structure2 = nest.map_structure(
// lambda x, y: x + y, structure1, structure2)
// self.assertEqual(
// (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
// structure1_plus_structure2)
// self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
// self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))
// # Empty structures
// self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
// self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
// self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
// self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
// NestTest.EmptyNT()))
// # This is checking actual equality of types, empty list != empty tuple
// self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))
// with self.assertRaisesRegexp(TypeError, "callable"):
// nest.map_structure("bad", structure1_plus1)
// with self.assertRaisesRegexp(ValueError, "at least one structure"):
// nest.map_structure(lambda x: x)
// with self.assertRaisesRegexp(ValueError, "same number of elements"):
// nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
// with self.assertRaisesRegexp(ValueError, "same nested structure"):
// nest.map_structure(lambda x, y: None, 3, (3,))
// with self.assertRaisesRegexp(TypeError, "same sequence type"):
// nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
// with self.assertRaisesRegexp(ValueError, "same nested structure"):
// nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
// structure1_list = [[[1, 2], 3], 4, [5, 6]]
// with self.assertRaisesRegexp(TypeError, "same sequence type"):
// nest.map_structure(lambda x, y: None, structure1, structure1_list)
// nest.map_structure(lambda x, y: None, structure1, structure1_list,
// check_types=False)
// with self.assertRaisesRegexp(ValueError, "same nested structure"):
// nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
// check_types=False)
// with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
// nest.map_structure(lambda x: None, structure1, foo="a")
// with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
// nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
// ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name
// @test_util.assert_no_new_pyobjects_executing_eagerly
// def testMapStructureWithStrings(self) :
// inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
// inp_b = NestTest.ABTuple(a=2, b=(1, 3))
// out = nest.map_structure(lambda string, repeats: string* repeats,
// inp_a,
// inp_b)
// self.assertEqual("foofoo", out.a)
// self.assertEqual("bar", out.b[0])
// self.assertEqual("bazbazbaz", out.b[1])
// nt = NestTest.ABTuple(a=("something", "something_else"),
// b="yet another thing")
// rev_nt = nest.map_structure(lambda x: x[::- 1], nt)
// # Check the output is the correct structure, and all strings are reversed.
// nest.assert_same_structure(nt, rev_nt)
// self.assertEqual(nt.a[0][::- 1], rev_nt.a[0])
// self.assertEqual(nt.a[1][::- 1], rev_nt.a[1])
// self.assertEqual(nt.b[::- 1], rev_nt.b)
// @test_util.run_deprecated_v1
// def testMapStructureOverPlaceholders(self) :
// inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
// array_ops.placeholder(dtypes.float32, shape=[3, 7]))
// inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
// array_ops.placeholder(dtypes.float32, shape=[3, 7]))
// output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)
// nest.assert_same_structure(output, inp_a)
// self.assertShapeEqual(np.zeros((3, 4)), output[0])
// self.assertShapeEqual(np.zeros((3, 7)), output[1])
// feed_dict = {
// inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
// inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
// }
// with self.cached_session() as sess:
// output_np = sess.run(output, feed_dict=feed_dict)
// self.assertAllClose(output_np[0],
// feed_dict[inp_a][0] + feed_dict[inp_b][0])
// self.assertAllClose(output_np[1],
// feed_dict[inp_a][1] + feed_dict[inp_b][1])
// def testAssertShallowStructure(self):
// inp_ab = ["a", "b"]
//inp_abc = ["a", "b", "c"]
//expected_message = (
// "The two structures don't have the same sequence length. Input "
// "structure has length 2, while shallow structure has length 3.")
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_shallow_structure(inp_abc, inp_ab)
// inp_ab1 = [(1, 1), (2, 2)]
// inp_ab2 = [[1, 1], [2, 2]]
// expected_message = (
// "The two structures don't have the same sequence type. Input structure "
// "has type <(type|class) 'tuple'>, while shallow structure has type "
// "<(type|class) 'list'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// nest.assert_shallow_structure(inp_ab2, inp_ab1)
// nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False)
// inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
// inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
// expected_message = (
// r"The two structures don't have the same keys. Input "
// r"structure has keys \['c'\], while shallow structure has "
// r"keys \['d'\].")
// with self.assertRaisesRegexp(ValueError, expected_message):
// nest.assert_shallow_structure(inp_ab2, inp_ab1)
// inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
// inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
// nest.assert_shallow_structure(inp_ab, inp_ba)
// # This assertion is expected to pass: two namedtuples with the same
//# name and field names are considered to be identical.
//inp_shallow = NestTest.SameNameab(1, 2)
// inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)
// def testFlattenUpTo(self):
// # Shallow tree ends at scalar.
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
// shallow_tree = [[True, True], [False, True]]
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
// self.assertEqual(flattened_shallow_tree, [True, True, False, True])
//# Shallow tree ends at string.
// input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
// shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// input_tree_flattened = nest.flatten(input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
// self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
// # Make sure dicts are correctly flattened, yielding values, not keys.
//input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
// shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [1, { "c": 2}, 3, (4, 5)])
// # Namedtuples.
// ab_tuple = NestTest.ABTuple
// input_tree = ab_tuple(a =[0, 1], b = 2)
// shallow_tree = ab_tuple(a= 0, b= 1)
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [[0, 1], 2])
// # Nested dicts, OrderedDicts and namedtuples.
// input_tree = collections.OrderedDict(
// [("a", ab_tuple(a =[0, {"b": 1}], b=2)),
// ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
// shallow_tree = input_tree
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
// shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [ab_tuple(a =[0, { "b": 1}], b=2),
// 3,
// collections.OrderedDict([("f", 4)])])
// shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
// input_tree)
// self.assertEqual(input_tree_flattened_as_shallow_tree,
// [ab_tuple(a =[0, {"b": 1}], b=2),
// {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
// ## Shallow non-list edge-case.
// # Using iterable elements.
// input_tree = ["input_tree"]
//shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// input_tree = ["input_tree_0", "input_tree_1"]
//shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// # Using non-iterable elements.
//input_tree = [0]
//shallow_tree = 9
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// input_tree = [0, 1]
//shallow_tree = 9
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// ## Both non-list edge-case.
//# Using iterable elements.
//input_tree = "input_tree"
// shallow_tree = "shallow_tree"
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// # Using non-iterable elements.
//input_tree = 0
// shallow_tree = 0
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_input_tree, [input_tree])
// self.assertEqual(flattened_shallow_tree, [shallow_tree])
// ## Input non-list edge-case.
//# Using iterable elements.
//input_tree = "input_tree"
// shallow_tree = ["shallow_tree"]
//expected_message = ("If shallow structure is a sequence, input must also "
// "be a sequence. Input has type: <(type|class) 'str'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
// input_tree = "input_tree"
// shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
//with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
//# Using non-iterable elements.
// input_tree = 0
// shallow_tree = [9]
//expected_message = ("If shallow structure is a sequence, input must also "
// "be a sequence. Input has type: <(type|class) 'int'>.")
// with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
// input_tree = 0
// shallow_tree = [9, 8]
//with self.assertRaisesRegexp(TypeError, expected_message):
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
// self.assertEqual(flattened_shallow_tree, shallow_tree)
// def testMapStructureUpTo(self) :
// # Named tuples.
// ab_tuple = collections.namedtuple("ab_tuple", "a, b")
// op_tuple = collections.namedtuple("op_tuple", "add, mul")
// inp_val = ab_tuple(a= 2, b= 3)
// inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
// self.assertEqual(out.a, 6)
// self.assertEqual(out.b, 15)
// # Lists.
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
// name_list = ["evens", ["odds", "primes"]]
// out = nest.map_structure_up_to(
// name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
// name_list, data_list)
// self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])
// # Dicts.
// inp_val = dict(a= 2, b= 3)
// inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// self.assertEqual(out["a"], 6)
// self.assertEqual(out["b"], 15)
// # Non-equal dicts.
// inp_val = dict(a= 2, b= 3)
// inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
// with self.assertRaisesRegexp(ValueError, "same keys"):
// nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// # Dict+custom mapping.
// inp_val = dict(a= 2, b= 3)
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3))
// out = nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// self.assertEqual(out["a"], 6)
// self.assertEqual(out["b"], 15)
// # Non-equal dict/mapping.
// inp_val = dict(a= 2, b= 3)
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3))
// with self.assertRaisesRegexp(ValueError, "same keys"):
// nest.map_structure_up_to(
// inp_val,
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
// def testGetTraverseShallowStructure(self):
// scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
// scalar_traverse_r = nest.get_traverse_shallow_structure(
// lambda s: not isinstance(s, tuple),
// scalar_traverse_input)
// self.assertEqual(scalar_traverse_r,
// [True, True, False, [True, True], {"a": False}, []])
// nest.assert_shallow_structure(scalar_traverse_r,
// scalar_traverse_input)
// structure_traverse_input = [(1, [2]), ([1], 2)]
// structure_traverse_r = nest.get_traverse_shallow_structure(
// lambda s: (True, False) if isinstance(s, tuple) else True,
// structure_traverse_input)
// self.assertEqual(structure_traverse_r,
// [(True, False), ([True], False)])
// nest.assert_shallow_structure(structure_traverse_r,
// structure_traverse_input)
// with self.assertRaisesRegexp(TypeError, "returned structure"):
// nest.get_traverse_shallow_structure(lambda _: [True], 0)
// with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
// nest.get_traverse_shallow_structure(lambda _: 1, [1])
// with self.assertRaisesRegexp(
// TypeError, "didn't return a depth=1 structure of bools"):
// nest.get_traverse_shallow_structure(lambda _: [1], [1])
// def testYieldFlatStringPaths(self):
// for inputs_expected in ({"inputs": [], "expected": []},
// {"inputs": 3, "expected": [()]},
// {"inputs": [3], "expected": [(0,)]},
// {"inputs": {"a": 3}, "expected": [("a",)]},
// {"inputs": {"a": {"b": 4}},
// "expected": [("a", "b")]},
// {"inputs": [{"a": 2}], "expected": [(0, "a")]},
// {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
// {"inputs": [{"a": [(23, 42)]}],
// "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
// {"inputs": [{"a": ([23], 42)}],
// "expected": [(0, "a", 0, 0), (0, "a", 1)]},
// {"inputs": {"a": {"a": 2}, "c": [[[4]]]},
// "expected": [("a", "a"), ("c", 0, 0, 0)]},
// {"inputs": {"0": [{"1": 23}]},
// "expected": [("0", 0, "1")]}):
// inputs = inputs_expected["inputs"]
// expected = inputs_expected["expected"]
// self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
// def testFlattenWithStringPaths(self):
// for inputs_expected in (
// {"inputs": [], "expected": []},
// {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
// {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
// inputs = inputs_expected["inputs"]
// expected = inputs_expected["expected"]
// self.assertEqual(
// nest.flatten_with_joined_string_paths(inputs, separator="/"),
// expected)
// # Need a separate test for namedtuple as we can't declare tuple definitions
// # in the @parameterized arguments.
// def testFlattenNamedTuple(self):
// # pylint: disable=invalid-name
// Foo = collections.namedtuple("Foo", ["a", "b"])
// Bar = collections.namedtuple("Bar", ["c", "d"])
// # pylint: enable=invalid-name
// test_cases = [
// (Foo(a = 3, b = Bar(c = 23, d = 42)),
// [("a", 3), ("b/c", 23), ("b/d", 42)]),
// (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")),
// [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
// (Bar(c = 42, d = 43),
// [("c", 42), ("d", 43)]),
// (Bar(c =[42], d = 43),
// [("c/0", 42), ("d", 43)]),
// ]
// for inputs, expected in test_cases:
// self.assertEqual(
// list(nest.flatten_with_joined_string_paths(inputs)), expected)
// @parameterized.named_parameters(
// ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
// ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
// {"a": ("a", 4), "b": ("b", 6)}),
// ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
// ("nested",
// {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
// {"a": [("a/0", 10), ("a/1", 12)],
// "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
// def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
// def format_sum(path, * values):
// return (path, sum(values))
// result = nest.map_structure_with_paths(format_sum, s1, s2,
// check_types=check_types)
// self.assertEqual(expected, result)
// @parameterized.named_parameters(
// ("tuples", (1, 2), (3, 4, 5), ValueError),
// ("dicts", {"a": 1}, {"b": 2}, ValueError),
// ("mixed", (1, 2), [3, 4], TypeError),
// ("nested",
// {"a": [2, 3], "b": [1, 3]},
// {"b": [5, 6, 7], "a": [8, 9]},
// ValueError
// ))
// def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
// with self.assertRaises(error_type):
// nest.map_structure_with_paths(lambda path, * s: 0, s1, s2)
//class NestBenchmark(test.Benchmark):
// def run_and_report(self, s1, s2, name):
// burn_iter, test_iter = 100, 30000
// for _ in xrange(burn_iter) :
// nest.assert_same_structure(s1, s2)
// t0 = time.time()
// for _ in xrange(test_iter) :
// nest.assert_same_structure(s1, s2)
// t1 = time.time()
// self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
// name=name)
// def benchmark_assert_structure(self):
// s1 = (((1, 2), 3), 4, (5, 6))
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
// self.run_and_report(s1, s2, "assert_same_structure_6_elem")
// s1 = (((1, 2), 3), 4, (5, 6)) * 10
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
// self.run_and_report(s1, s2, "assert_same_structure_60_elem")
//if __name__ == "__main__":
// test.main()
}
}

+ 883
- 0
test/TensorFlowNET.UnitTest/nest_test/nest_test.py View File

@@ -0,0 +1,883 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import time

from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest

try:
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None


class _CustomMapping(collections.Mapping):

def __init__(self, *args, **kwargs):
self._wrapped = dict(*args, **kwargs)

def __getitem__(self, key):
return self._wrapped[key]

def __iter__(self):
return iter(self._wrapped)

def __len__(self):
return len(self._wrapped)


class NestTest(parameterized.TestCase, test.TestCase):

PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name

if attr:
class BadAttr(object):
"""Class that has a non-iterable __attrs_attrs__."""
__attrs_attrs__ = None

@attr.s
class SampleAttr(object):
field1 = attr.ib()
field2 = attr.ib()

@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrsFlattenAndPack(self):
if attr is None:
self.skipTest("attr module is unavailable.")

field_values = [1, 2]
sample_attr = NestTest.SampleAttr(*field_values)
self.assertFalse(nest._is_attrs(field_values))
self.assertTrue(nest._is_attrs(sample_attr))
flat = nest.flatten(sample_attr)
self.assertEqual(field_values, flat)
restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
self.assertEqual(restructured_from_flat, sample_attr)

# Check that flatten fails if attributes are not iterable
with self.assertRaisesRegexp(TypeError, "object is not iterable"):
flat = nest.flatten(NestTest.BadAttr())

@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
self.assertEqual(
nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
("d", "e", ("f", "g"), "h")))
structure = (NestTest.PointXY(x=4, y=2),
((NestTest.PointXY(x=1, y=0),),))
flat = [4, 2, 1, 0]
self.assertEqual(nest.flatten(structure), flat)
restructured_from_flat = nest.pack_sequence_as(structure, flat)
self.assertEqual(restructured_from_flat, structure)
self.assertEqual(restructured_from_flat[0].x, 4)
self.assertEqual(restructured_from_flat[0].y, 2)
self.assertEqual(restructured_from_flat[1][0][0].x, 1)
self.assertEqual(restructured_from_flat[1][0][0].y, 0)

self.assertEqual([5], nest.flatten(5))
self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
self.assertEqual(
np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
nest.pack_sequence_as("scalar", [4, 5])

with self.assertRaisesRegexp(TypeError, "flat_sequence"):
nest.pack_sequence_as([4, 5], "bad_sequence")

with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])

@parameterized.parameters({"mapping_type": collections.OrderedDict},
{"mapping_type": _CustomMapping})
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenDictOrder(self, mapping_type):
"""`flatten` orders dicts by key, including OrderedDicts."""
ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = nest.flatten(ordered)
plain_flat = nest.flatten(plain)
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)

@parameterized.parameters({"mapping_type": collections.OrderedDict},
{"mapping_type": _CustomMapping})
def testPackDictOrder(self, mapping_type):
"""Packing orders dicts by key, including OrderedDicts."""
custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
custom_reconstruction = nest.pack_sequence_as(custom, seq)
plain_reconstruction = nest.pack_sequence_as(plain, seq)
self.assertIsInstance(custom_reconstruction, mapping_type)
self.assertIsInstance(plain_reconstruction, dict)
self.assertEqual(
mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
custom_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)

Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name

@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack_withDicts(self):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
mess = [
"z",
NestTest.Abc(3, 4), {
"d": _CustomMapping({
41: 4
}),
"c": [
1,
collections.OrderedDict([
("b", 3),
("a", 2),
]),
],
"b": 5
}, 17
]

flattened = nest.flatten(mess)
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])

structure_of_mess = [
14,
NestTest.Abc("a", True),
{
"d": _CustomMapping({
41: 42
}),
"c": [
0,
collections.OrderedDict([
("b", 9),
("a", 8),
]),
],
"b": 3
},
"hi everybody",
]

unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
self.assertEqual(unflattened, mess)

# Check also that the OrderedDict was created, with the correct key order.
unflattened_ordered_dict = unflattened[2]["c"][1]
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])

unflattened_custom_mapping = unflattened[2]["d"]
self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
self.assertEqual(list(unflattened_custom_mapping.keys()), [41])

def testFlatten_numpyIsNotFlattened(self):
structure = np.array([1, 2, 3])
flattened = nest.flatten(structure)
self.assertEqual(len(flattened), 1)

def testFlatten_stringIsNotFlattened(self):
structure = "lots of letters"
flattened = nest.flatten(structure)
self.assertEqual(len(flattened), 1)
unflattened = nest.pack_sequence_as("goodbye", flattened)
self.assertEqual(structure, unflattened)

def testPackSequenceAs_notIterableError(self):
with self.assertRaisesRegexp(TypeError,
"flat_sequence must be a sequence"):
nest.pack_sequence_as("hi", "bye")

def testPackSequenceAs_wrongLengthsError(self):
with self.assertRaisesRegexp(
ValueError,
"Structure had 2 elements, but flat_sequence had 3 elements."):
nest.pack_sequence_as(["hello", "world"],
["and", "goodbye", "again"])

@test_util.assert_no_new_pyobjects_executing_eagerly
def testIsSequence(self):
self.assertFalse(nest.is_sequence("1234"))
self.assertTrue(nest.is_sequence([1, 3, [4, 5]]))
self.assertTrue(nest.is_sequence(((7, 8), (5, 6))))
self.assertTrue(nest.is_sequence([]))
self.assertTrue(nest.is_sequence({"a": 1, "b": 2}))
self.assertFalse(nest.is_sequence(set([1, 2])))
ones = array_ops.ones([2, 3])
self.assertFalse(nest.is_sequence(ones))
self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
self.assertFalse(nest.is_sequence(np.ones((4, 5))))

@parameterized.parameters({"mapping_type": _CustomMapping},
{"mapping_type": dict})
def testFlattenDictItems(self, mapping_type):
dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))})
flat = {4: "a", 5: "b", 6: "c", 8: "d"}
self.assertEqual(nest.flatten_dict_items(dictionary), flat)

with self.assertRaises(TypeError):
nest.flatten_dict_items(4)

bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
with self.assertRaisesRegexp(ValueError, "not unique"):
nest.flatten_dict_items(bad_dictionary)

another_bad_dictionary = mapping_type({
(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
})
with self.assertRaisesRegexp(
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
nest.flatten_dict_items(another_bad_dictionary)

# pylint does not correctly recognize these as class names and
# suggests to use variable style under_score naming.
# pylint: disable=invalid-name
Named0ab = collections.namedtuple("named_0", ("a", "b"))
Named1ab = collections.namedtuple("named_1", ("a", "b"))
SameNameab = collections.namedtuple("same_name", ("a", "b"))
SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
SameNamexy = collections.namedtuple("same_name", ("x", "y"))
SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
# pylint: enable=invalid-name

class SameNamedType1(SameNameab):
pass

@test_util.assert_no_new_pyobjects_executing_eagerly
def testAssertSameStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
structure_different_num_elements = ("spam", "eggs")
structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
nest.assert_same_structure(structure1, structure2)
nest.assert_same_structure("abc", 1.0)
nest.assert_same_structure("abc", np.array([0, 1]))
nest.assert_same_structure("abc", constant_op.constant([0, 1]))

with self.assertRaisesRegexp(
ValueError,
("The two structures don't have the same nested structure\\.\n\n"
"First structure:.*?\n\n"
"Second structure:.*\n\n"
"More specifically: Substructure "
r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
'substructure "type=str str=spam" is not\n'
"Entire first structure:\n"
r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
"Entire second structure:\n"
r"\(\., \.\)")):
nest.assert_same_structure(structure1, structure_different_num_elements)

with self.assertRaisesRegexp(
ValueError,
("The two structures don't have the same nested structure\\.\n\n"
"First structure:.*?\n\n"
"Second structure:.*\n\n"
r'More specifically: Substructure "type=list str=\[0, 1\]" '
r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
"is not")):
nest.assert_same_structure([0, 1], np.array([0, 1]))

with self.assertRaisesRegexp(
ValueError,
("The two structures don't have the same nested structure\\.\n\n"
"First structure:.*?\n\n"
"Second structure:.*\n\n"
r'More specifically: Substructure "type=list str=\[0, 1\]" '
'is a sequence, while substructure "type=int str=0" '
"is not")):
nest.assert_same_structure(0, [0, 1])

self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])

with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure(structure1, structure_different_nesting)

self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
NestTest.Named0ab("a", "b"))

nest.assert_same_structure(NestTest.Named0ab(3, 4),
NestTest.Named0ab("a", "b"))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))

with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure(NestTest.Named0ab(3, 4),
NestTest.Named0ab([3], 4))

with self.assertRaisesRegexp(
ValueError,
("don't have the same nested structure\\.\n\n"
"First structure: .*?\n\nSecond structure: ")):
nest.assert_same_structure([[3], 4], [3, [4]])

structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegexp(TypeError,
"don't have the same sequence type"):
nest.assert_same_structure(structure1, structure1_list)
nest.assert_same_structure(structure1, structure2, check_types=False)
nest.assert_same_structure(structure1, structure1_list, check_types=False)

with self.assertRaisesRegexp(ValueError,
"don't have the same set of keys"):
nest.assert_same_structure({"a": 1}, {"b": 1})

nest.assert_same_structure(NestTest.SameNameab(0, 1),
NestTest.SameNameab2(2, 3))

# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
nest.assert_same_structure(
NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))

expected_message = "The two structures don't have the same.*"
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_same_structure(
NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))

self.assertRaises(TypeError, nest.assert_same_structure,
NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))

EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name

def testHeterogeneousComparison(self):
nest.assert_same_structure({"a": 4}, _CustomMapping(a=3))
nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})

@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
structure2 = (((7, 8), 9), 10, (11, 12))
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
nest.assert_same_structure(structure1, structure1_plus1)
self.assertAllEqual(
[2, 3, 4, 5, 6, 7],
nest.flatten(structure1_plus1))
structure1_plus_structure2 = nest.map_structure(
lambda x, y: x + y, structure1, structure2)
self.assertEqual(
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
structure1_plus_structure2)

self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

# Empty structures
self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
NestTest.EmptyNT()))

# This is checking actual equality of types, empty list != empty tuple
self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))

with self.assertRaisesRegexp(TypeError, "callable"):
nest.map_structure("bad", structure1_plus1)

with self.assertRaisesRegexp(ValueError, "at least one structure"):
nest.map_structure(lambda x: x)

with self.assertRaisesRegexp(ValueError, "same number of elements"):
nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))

with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, 3, (3,))

with self.assertRaisesRegexp(TypeError, "same sequence type"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])

with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegexp(TypeError, "same sequence type"):
nest.map_structure(lambda x, y: None, structure1, structure1_list)

nest.map_structure(lambda x, y: None, structure1, structure1_list,
check_types=False)

with self.assertRaisesRegexp(ValueError, "same nested structure"):
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
check_types=False)

with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, foo="a")

with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")

ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name

@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructureWithStrings(self):
inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
inp_b = NestTest.ABTuple(a=2, b=(1, 3))
out = nest.map_structure(lambda string, repeats: string * repeats,
inp_a,
inp_b)
self.assertEqual("foofoo", out.a)
self.assertEqual("bar", out.b[0])
self.assertEqual("bazbazbaz", out.b[1])

nt = NestTest.ABTuple(a=("something", "something_else"),
b="yet another thing")
rev_nt = nest.map_structure(lambda x: x[::-1], nt)
# Check the output is the correct structure, and all strings are reversed.
nest.assert_same_structure(nt, rev_nt)
self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
self.assertEqual(nt.b[::-1], rev_nt.b)

@test_util.run_deprecated_v1
def testMapStructureOverPlaceholders(self):
inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
array_ops.placeholder(dtypes.float32, shape=[3, 7]))
inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
array_ops.placeholder(dtypes.float32, shape=[3, 7]))

output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)

nest.assert_same_structure(output, inp_a)
self.assertShapeEqual(np.zeros((3, 4)), output[0])
self.assertShapeEqual(np.zeros((3, 7)), output[1])

feed_dict = {
inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
}

with self.cached_session() as sess:
output_np = sess.run(output, feed_dict=feed_dict)
self.assertAllClose(output_np[0],
feed_dict[inp_a][0] + feed_dict[inp_b][0])
self.assertAllClose(output_np[1],
feed_dict[inp_a][1] + feed_dict[inp_b][1])

def testAssertShallowStructure(self):
inp_ab = ["a", "b"]
inp_abc = ["a", "b", "c"]
expected_message = (
"The two structures don't have the same sequence length. Input "
"structure has length 2, while shallow structure has length 3.")
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_abc, inp_ab)

inp_ab1 = [(1, 1), (2, 2)]
inp_ab2 = [[1, 1], [2, 2]]
expected_message = (
"The two structures don't have the same sequence type. Input structure "
"has type <(type|class) 'tuple'>, while shallow structure has type "
"<(type|class) 'list'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False)

inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
expected_message = (
r"The two structures don't have the same keys. Input "
r"structure has keys \['c'\], while shallow structure has "
r"keys \['d'\].")

with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)

inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
nest.assert_shallow_structure(inp_ab, inp_ba)

# This assertion is expected to pass: two namedtuples with the same
# name and field names are considered to be identical.
inp_shallow = NestTest.SameNameab(1, 2)
inp_deep = NestTest.SameNameab2(1, [1, 2, 3])
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False)
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True)

def testFlattenUpTo(self):
# Shallow tree ends at scalar.
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
shallow_tree = [[True, True], [False, True]]
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
self.assertEqual(flattened_shallow_tree, [True, True, False, True])

# Shallow tree ends at string.
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
input_tree_flattened = nest.flatten(input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[("a", 1), ("b", 2), ("c", 3), ("d", 4)])
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

# Make sure dicts are correctly flattened, yielding values, not keys.
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[1, {"c": 2}, 3, (4, 5)])

# Namedtuples.
ab_tuple = NestTest.ABTuple
input_tree = ab_tuple(a=[0, 1], b=2)
shallow_tree = ab_tuple(a=0, b=1)
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[[0, 1], 2])

# Nested dicts, OrderedDicts and namedtuples.
input_tree = collections.OrderedDict(
[("a", ab_tuple(a=[0, {"b": 1}], b=2)),
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
shallow_tree = input_tree
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
3,
collections.OrderedDict([("f", 4)])])
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
[ab_tuple(a=[0, {"b": 1}], b=2),
{"d": 3, "e": collections.OrderedDict([("f", 4)])}])

## Shallow non-list edge-case.
# Using iterable elements.
input_tree = ["input_tree"]
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

input_tree = ["input_tree_0", "input_tree_1"]
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

# Using non-iterable elements.
input_tree = [0]
shallow_tree = 9
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

input_tree = [0, 1]
shallow_tree = 9
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

## Both non-list edge-case.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = "shallow_tree"
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

# Using non-iterable elements.
input_tree = 0
shallow_tree = 0
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_input_tree, [input_tree])
self.assertEqual(flattened_shallow_tree, [shallow_tree])

## Input non-list edge-case.
# Using iterable elements.
input_tree = "input_tree"
shallow_tree = ["shallow_tree"]
expected_message = ("If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <(type|class) 'str'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

input_tree = "input_tree"
shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

# Using non-iterable elements.
input_tree = 0
shallow_tree = [9]
expected_message = ("If shallow structure is a sequence, input must also "
"be a sequence. Input has type: <(type|class) 'int'>.")
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

input_tree = 0
shallow_tree = [9, 8]
with self.assertRaisesRegexp(TypeError, expected_message):
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)

def testMapStructureUpTo(self):
# Named tuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = nest.map_structure_up_to(
inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
self.assertEqual(out.a, 6)
self.assertEqual(out.b, 15)

# Lists.
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ["evens", ["odds", "primes"]]
out = nest.map_structure_up_to(
name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
name_list, data_list)
self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]])

# Dicts.
inp_val = dict(a=2, b=3)
inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
out = nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
self.assertEqual(out["a"], 6)
self.assertEqual(out["b"], 15)

# Non-equal dicts.
inp_val = dict(a=2, b=3)
inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
with self.assertRaisesRegexp(ValueError, "same keys"):
nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)

# Dict+custom mapping.
inp_val = dict(a=2, b=3)
inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3))
out = nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)
self.assertEqual(out["a"], 6)
self.assertEqual(out["b"], 15)

# Non-equal dict/mapping.
inp_val = dict(a=2, b=3)
inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3))
with self.assertRaisesRegexp(ValueError, "same keys"):
nest.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops)

def testGetTraverseShallowStructure(self):
scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
scalar_traverse_r = nest.get_traverse_shallow_structure(
lambda s: not isinstance(s, tuple),
scalar_traverse_input)
self.assertEqual(scalar_traverse_r,
[True, True, False, [True, True], {"a": False}, []])
nest.assert_shallow_structure(scalar_traverse_r,
scalar_traverse_input)

structure_traverse_input = [(1, [2]), ([1], 2)]
structure_traverse_r = nest.get_traverse_shallow_structure(
lambda s: (True, False) if isinstance(s, tuple) else True,
structure_traverse_input)
self.assertEqual(structure_traverse_r,
[(True, False), ([True], False)])
nest.assert_shallow_structure(structure_traverse_r,
structure_traverse_input)

with self.assertRaisesRegexp(TypeError, "returned structure"):
nest.get_traverse_shallow_structure(lambda _: [True], 0)

with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"):
nest.get_traverse_shallow_structure(lambda _: 1, [1])

with self.assertRaisesRegexp(
TypeError, "didn't return a depth=1 structure of bools"):
nest.get_traverse_shallow_structure(lambda _: [1], [1])

def testYieldFlatStringPaths(self):
for inputs_expected in ({"inputs": [], "expected": []},
{"inputs": 3, "expected": [()]},
{"inputs": [3], "expected": [(0,)]},
{"inputs": {"a": 3}, "expected": [("a",)]},
{"inputs": {"a": {"b": 4}},
"expected": [("a", "b")]},
{"inputs": [{"a": 2}], "expected": [(0, "a")]},
{"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]},
{"inputs": [{"a": [(23, 42)]}],
"expected": [(0, "a", 0, 0), (0, "a", 0, 1)]},
{"inputs": [{"a": ([23], 42)}],
"expected": [(0, "a", 0, 0), (0, "a", 1)]},
{"inputs": {"a": {"a": 2}, "c": [[[4]]]},
"expected": [("a", "a"), ("c", 0, 0, 0)]},
{"inputs": {"0": [{"1": 23}]},
"expected": [("0", 0, "1")]}):
inputs = inputs_expected["inputs"]
expected = inputs_expected["expected"]
self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)

def testFlattenWithStringPaths(self):
for inputs_expected in (
{"inputs": [], "expected": []},
{"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
{"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}):
inputs = inputs_expected["inputs"]
expected = inputs_expected["expected"]
self.assertEqual(
nest.flatten_with_joined_string_paths(inputs, separator="/"),
expected)

# Need a separate test for namedtuple as we can't declare tuple definitions
# in the @parameterized arguments.
def testFlattenNamedTuple(self):
# pylint: disable=invalid-name
Foo = collections.namedtuple("Foo", ["a", "b"])
Bar = collections.namedtuple("Bar", ["c", "d"])
# pylint: enable=invalid-name
test_cases = [
(Foo(a=3, b=Bar(c=23, d=42)),
[("a", 3), ("b/c", 23), ("b/d", 42)]),
(Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")),
[("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]),
(Bar(c=42, d=43),
[("c", 42), ("d", 43)]),
(Bar(c=[42], d=43),
[("c/0", 42), ("d", 43)]),
]
for inputs, expected in test_cases:
self.assertEqual(
list(nest.flatten_with_joined_string_paths(inputs)), expected)

@parameterized.named_parameters(
("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))),
("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
{"a": ("a", 4), "b": ("b", 6)}),
("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))),
("nested",
{"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
{"a": [("a/0", 10), ("a/1", 12)],
"b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
def format_sum(path, *values):
return (path, sum(values))
result = nest.map_structure_with_paths(format_sum, s1, s2,
check_types=check_types)
self.assertEqual(expected, result)

@parameterized.named_parameters(
("tuples", (1, 2), (3, 4, 5), ValueError),
("dicts", {"a": 1}, {"b": 2}, ValueError),
("mixed", (1, 2), [3, 4], TypeError),
("nested",
{"a": [2, 3], "b": [1, 3]},
{"b": [5, 6, 7], "a": [8, 9]},
ValueError
))
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
with self.assertRaises(error_type):
nest.map_structure_with_paths(lambda path, *s: 0, s1, s2)


class NestBenchmark(test.Benchmark):

def run_and_report(self, s1, s2, name):
burn_iter, test_iter = 100, 30000

for _ in xrange(burn_iter):
nest.assert_same_structure(s1, s2)

t0 = time.time()
for _ in xrange(test_iter):
nest.assert_same_structure(s1, s2)
t1 = time.time()

self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter,
name=name)

def benchmark_assert_structure(self):
s1 = (((1, 2), 3), 4, (5, 6))
s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
self.run_and_report(s1, s2, "assert_same_structure_6_elem")

s1 = (((1, 2), 3), 4, (5, 6)) * 10
s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10
self.run_and_report(s1, s2, "assert_same_structure_60_elem")


if __name__ == "__main__":
test.main()

Loading…
Cancel
Save