@@ -0,0 +1,871 @@ | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using NumSharp; | |||
namespace Tensorflow.Util | |||
{ | |||
//Functions for working with arbitrarily nested sequences of elements. | |||
//This module can perform operations on nested structures. A nested structure is a | |||
//Python sequence, tuple (including `namedtuple`), or dict that can contain | |||
//further sequences, tuples, and dicts. | |||
//The utilities here assume (and do not check) that the nested structures form a | |||
//'tree', i.e., no references in the structure of the input of these functions | |||
//should be recursive. | |||
//Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), | |||
// (np.array([3, 4]), tf.constant([3, 4])))` | |||
// | |||
public static class nest | |||
{ | |||
//def _get_attrs_values(obj): | |||
// """Returns the list of values from an attrs instance.""" | |||
// attrs = getattr(obj.__class__, "__attrs_attrs__") | |||
// return [getattr(obj, a.name) for a in attrs] | |||
/// <summary> | |||
/// Returns a sorted list of the dict keys, with error if keys not sortable. | |||
/// </summary> | |||
private static IEnumerable<string> _sorted(IDictionary dict_) | |||
{ | |||
return dict_.Keys.OfType<string>().OrderBy(x => x); | |||
} | |||
//def _is_namedtuple(instance, strict=False): | |||
// """Returns True iff `instance` is a `namedtuple`. | |||
// Args: | |||
// instance: An instance of a Python object. | |||
// strict: If True, `instance` is considered to be a `namedtuple` only if | |||
// it is a "plain" namedtuple. For instance, a class inheriting | |||
// from a `namedtuple` will be considered to be a `namedtuple` | |||
// iff `strict=False`. | |||
// Returns: | |||
// True if `instance` is a `namedtuple`. | |||
// """ | |||
// return _pywrap_tensorflow.IsNamedtuple(instance, strict) | |||
//# See the swig file (util.i) for documentation. | |||
//_is_mapping = _pywrap_tensorflow.IsMapping | |||
//_is_attrs = _pywrap_tensorflow.IsAttrs | |||
/// <summary> | |||
/// Converts the sequence `args` to the same type as `instance`. | |||
/// </summary> | |||
/// <param name="instance">an instance of `tuple`, `list`, `namedtuple`, `dict`, or | |||
/// `collections.OrderedDict`.</param> | |||
/// <param name="args">elements to be converted to the `instance` type.</param> | |||
/// <returns>`args` with the type of `instance`.</returns> | |||
private static object _sequence_like(object instance, IEnumerable<object> args) | |||
{ | |||
if (is_mapping(instance)) | |||
{ | |||
//# Pack dictionaries in a deterministic order by sorting the keys. | |||
//# Notice this means that we ignore the original order of `OrderedDict` | |||
//# instances. This is intentional, to avoid potential bugs caused by mixing | |||
//# ordered and plain dicts (e.g., flattening a dict but using a | |||
//# corresponding `OrderedDict` to pack it back). | |||
// result = dict(zip(_sorted(instance), args)) | |||
// return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) | |||
} | |||
//else if( _is_namedtuple(instance) || _is_attrs(instance)) | |||
// return type(instance)(*args) | |||
else | |||
{ | |||
// Not a namedtuple | |||
switch (instance) | |||
{ | |||
case object[] array: | |||
var result_array = new object[args.Count()]; | |||
int i = 0; | |||
foreach (var x in args) | |||
{ | |||
result_array[i] = x; | |||
i++; | |||
} | |||
return result_array; | |||
case List<object> list: | |||
return new List<object>(args); | |||
default: | |||
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||
} | |||
} | |||
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||
} | |||
/// <summary> | |||
/// Yields the next value from the given iterable. | |||
/// </summary> | |||
private static IEnumerable<object> _yield_value(object iterable) | |||
{ | |||
if (is_mapping(iterable)) | |||
{ | |||
var dict = iterable as IDictionary; | |||
//# Iterate through dictionaries in a deterministic order by sorting the | |||
//# keys. Notice this means that we ignore the original order of `OrderedDict` | |||
//# instances. This is intentional, to avoid potential bugs caused by mixing | |||
//# ordered and plain dicts (e.g., flattening a dict but using a | |||
//# corresponding `OrderedDict` to pack it back). | |||
foreach (var key in _sorted(dict)) | |||
yield return dict[key]; | |||
} | |||
//else if (_is_attrs(iterable)) | |||
//{ | |||
// // for value in _get_attrs_values(iterable): | |||
// // yield value | |||
//} | |||
else if (iterable is IEnumerable) | |||
{ | |||
var enumerable = iterable as IEnumerable; | |||
foreach (var value in enumerable) | |||
yield return value; | |||
} | |||
else | |||
{ | |||
throw new TypeError("Unexpected iterable type: " + iterable.GetType()); | |||
//var jobj = JObject.FromObject(iterable); | |||
//foreach (var key in _sorted()) | |||
// yield return jobj[key]; | |||
} | |||
} | |||
//# See the swig file (util.i) for documentation. | |||
public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string); | |||
public static bool is_mapping(object arg) => arg is IDictionary; | |||
//# See the swig file (util.i) for documentation. | |||
//flatten = _pywrap_tensorflow.Flatten | |||
public static List<object> flatten(object structure) | |||
{ | |||
var list = new List<object>(); | |||
_flatten_recursive(structure, list); | |||
return list; | |||
} | |||
private static void _flatten_recursive(object obj, List<object> list) | |||
{ | |||
if (obj is string) | |||
{ | |||
list.Add(obj); | |||
return; | |||
} | |||
if (obj is IDictionary) | |||
{ | |||
var dict = obj as IDictionary; | |||
foreach (var key in _sorted(dict)) | |||
_flatten_recursive(dict[key], list); | |||
return; | |||
} | |||
if (obj is NDArray) | |||
{ | |||
list.Add(obj); | |||
return; | |||
} | |||
if (obj is IEnumerable) | |||
{ | |||
var structure = obj as IEnumerable; | |||
foreach (var child in structure) | |||
_flatten_recursive(child, list); | |||
return; | |||
} | |||
list.Add(obj); | |||
} | |||
//# See the swig file (util.i) for documentation. | |||
//_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | |||
//class _DotString(object): | |||
// def __str__(self): | |||
// return "." | |||
// def __repr__(self): | |||
// return "." | |||
//_DOT = _DotString() | |||
//def assert_same_structure(nest1, nest2, check_types=True): | |||
// """Asserts that two structures are nested in the same way. | |||
// Note that namedtuples with identical name and fields are always considered | |||
// to have the same shallow structure (even with `check_types=True`). | |||
// For intance, this code will print `True`: | |||
// ```python | |||
// def nt(a, b): | |||
// return collections.namedtuple('foo', 'a b')(a, b) | |||
// print(assert_same_structure(nt(0, 1), nt(2, 3))) | |||
// ``` | |||
// Args: | |||
// nest1: an arbitrarily nested structure. | |||
// nest2: an arbitrarily nested structure. | |||
// check_types: if `True` (default) types of sequences are checked as well, | |||
// including the keys of dictionaries. If set to `False`, for example a | |||
// list and a tuple of objects will look the same if they have the same | |||
// size. Note that namedtuples with identical name and fields are always | |||
// considered to have the same shallow structure. Two types will also be | |||
// considered the same if they are both list subtypes (which allows "list" | |||
// and "_ListWrapper" from checkpointable dependency tracking to compare | |||
// equal). | |||
// Raises: | |||
// ValueError: If the two structures do not have the same number of elements or | |||
// if the two structures are not nested in the same way. | |||
// TypeError: If the two structures differ in the type of sequence in any of | |||
// their substructures. Only possible if `check_types` is `True`. | |||
// """ | |||
// try: | |||
// _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types) | |||
// except (ValueError, TypeError) as e: | |||
// str1 = str(map_structure(lambda _: _DOT, nest1)) | |||
// str2 = str(map_structure(lambda _: _DOT, nest2)) | |||
// raise type(e)("%s\n" | |||
// "Entire first structure:\n%s\n" | |||
// "Entire second structure:\n%s" | |||
// % (str(e), str1, str2)) | |||
//def flatten_dict_items(dictionary): | |||
// """Returns a dictionary with flattened keys and values. | |||
// This function flattens the keys and values of a dictionary, which can be | |||
// arbitrarily nested structures, and returns the flattened version of such | |||
// structures: | |||
// ```python | |||
// example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} | |||
// result = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
// flatten_dict_items(example_dictionary) == result | |||
// ``` | |||
// The input dictionary must satisfy two properties: | |||
// 1. Its keys and values should have the same exact nested structure. | |||
// 2. The set of all flattened keys of the dictionary must not contain repeated | |||
// keys. | |||
// Args: | |||
// dictionary: the dictionary to zip | |||
// Returns: | |||
// The zipped dictionary. | |||
// Raises: | |||
// TypeError: If the input is not a dictionary. | |||
// ValueError: If any key and value have not the same structure, or if keys are | |||
// not unique. | |||
// """ | |||
// if not isinstance(dictionary, (dict, _collections.Mapping)): | |||
// raise TypeError("input must be a dictionary") | |||
// flat_dictionary = {} | |||
// for i, v in _six.iteritems(dictionary): | |||
// if not is_sequence(i): | |||
// if i in flat_dictionary: | |||
// raise ValueError( | |||
// "Could not flatten dictionary: key %s is not unique." % i) | |||
// flat_dictionary[i] = v | |||
// else: | |||
// flat_i = flatten(i) | |||
// flat_v = flatten(v) | |||
// if len(flat_i) != len(flat_v): | |||
// raise ValueError( | |||
// "Could not flatten dictionary. Key had %d elements, but value had " | |||
// "%d elements. Key: %s, value: %s." | |||
// % (len(flat_i), len(flat_v), flat_i, flat_v)) | |||
// for new_i, new_v in zip(flat_i, flat_v): | |||
// if new_i in flat_dictionary: | |||
// raise ValueError( | |||
// "Could not flatten dictionary: key %s is not unique." | |||
// % (new_i)) | |||
// flat_dictionary[new_i] = new_v | |||
// return flat_dictionary | |||
/// <summary> | |||
/// Helper function for pack_sequence_as. | |||
/// </summary> | |||
/// <param name="structure">Substructure (list / tuple / dict) to mimic.</param> | |||
/// <param name="flat">Flattened values to output substructure for.</param> | |||
/// <param name="index">Index at which to start reading from flat.</param> | |||
/// <returns> | |||
/// The tuple(new_index, child), where: | |||
/// * new_index - the updated index into `flat` having processed `structure`. | |||
/// * packed - the subset of `flat` corresponding to `structure`, | |||
/// having started at `index`, and packed into the same nested | |||
/// format.</returns> | |||
private static (int new_index, List<object> child) _packed_nest_with_indices(object structure, List<object> flat, | |||
int index) | |||
{ | |||
var packed = new List<object>(); | |||
foreach (var s in _yield_value(structure)) | |||
{ | |||
if (is_sequence(s)) | |||
{ | |||
var (new_index, child) = _packed_nest_with_indices(s, flat, index); | |||
packed.Add(_sequence_like(s, child)); | |||
index = new_index; | |||
} | |||
else | |||
{ | |||
packed.Add(flat[index]); | |||
index += 1; | |||
} | |||
} | |||
return (index, packed); | |||
} | |||
private static int len(IEnumerable<object> x) => x.Count(); | |||
/// <summary> | |||
/// Returns a given flattened sequence packed into a given structure. | |||
/// If `structure` is a scalar, `flat_sequence` must be a single-element list; | |||
/// in this case the return value is `flat_sequence[0]`. | |||
/// | |||
/// If `structure` is or contains a dict instance, the keys will be sorted to | |||
/// pack the flat sequence in deterministic order. This is true also for | |||
/// `OrderedDict` instances: their sequence order is ignored, the sorting order of | |||
/// keys is used instead. The same convention is followed in `flatten`. | |||
/// This correctly repacks dicts and `OrderedDict`s after they have been | |||
/// flattened, and also allows flattening an `OrderedDict` and then repacking it | |||
/// back using a corresponding plain dict, or vice-versa. | |||
/// Dictionaries with non-sortable keys cannot be flattened. | |||
/// </summary> | |||
/// <param name="structure"> | |||
/// Nested structure, whose structure is given by nested lists, | |||
/// tuples, and dicts. Note: numpy arrays and strings are considered | |||
/// scalars. | |||
/// </param> | |||
/// <param name="flat_sequence"> flat sequence to pack.</param> | |||
/// <returns> `flat_sequence` converted to have the same recursive structure as | |||
/// `structure`. | |||
/// </returns> | |||
public static object pack_sequence_as(object structure, List<object> flat_sequence) | |||
{ | |||
if (flat_sequence == null) | |||
throw new ArgumentException("flat_sequence must not be null"); | |||
// if not is_sequence(flat_sequence): | |||
// raise TypeError("flat_sequence must be a sequence") | |||
if (!is_sequence(structure)) | |||
{ | |||
if (len(flat_sequence) != 1) | |||
throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1"); | |||
return flat_sequence.FirstOrDefault(); | |||
} | |||
int final_index = 0; | |||
List<object> packed = null; | |||
try | |||
{ | |||
(final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0); | |||
if (final_index < len(flat_sequence)) | |||
throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }"); | |||
} | |||
catch (IndexOutOfRangeException) | |||
{ | |||
var flat_structure = flatten(structure); | |||
if (len(flat_structure) != len(flat_sequence)) | |||
{ | |||
throw new ValueError("Could not pack sequence. Structure had %d elements, but " + | |||
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}"); | |||
} | |||
return _sequence_like(structure, packed); | |||
} | |||
return packed; | |||
} | |||
/// <summary> | |||
/// Applies `func` to each entry in `structure` and returns a new structure. | |||
/// | |||
/// Applies `func(x[0], x[1], ...)` where x[i] is an entry in | |||
/// `structure[i]`. All structures in `structure` must have the same arity, | |||
/// and the return value will contain the results in the same structure. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <typeparam name="U"></typeparam> | |||
/// <param name="func"> A callable that accepts as many arguments as there are structures.</param> | |||
/// <param name="structure">scalar, or tuple or list of constructed scalars and/or other | |||
/// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param> | |||
/// <param name="check_types">If set to | |||
/// `True` (default) the types of iterables within the structures have to be | |||
/// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` | |||
/// exception). To allow this set this argument to `False`. | |||
/// Note that namedtuples with identical name and fields are always | |||
/// considered to have the same shallow structure.</param> | |||
/// <returns> | |||
/// A new structure with the same arity as `structure`, whose values correspond | |||
/// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding | |||
/// location in `structure[i]`. If there are different sequence types and | |||
/// `check_types` is `False` the sequence types of the first structure will be | |||
/// used. | |||
/// </returns> | |||
public static IEnumerable<U> map_structure<T, U>(Func<T, U> func, IEnumerable<T> structure, bool check_types = false) | |||
{ | |||
// for other in structure[1:]: | |||
// assert_same_structure(structure[0], other, check_types=check_types) | |||
// flat_structure = [flatten(s) for s in structure] | |||
// entries = zip(*flat_structure) | |||
// return pack_sequence_as( | |||
// structure[0], [func(*x) for x in entries]) | |||
return null; | |||
} | |||
//def map_structure_with_paths(func, *structure, **kwargs): | |||
// """Applies `func` to each entry in `structure` and returns a new structure. | |||
// Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in | |||
// `structure[i]` and `path` is the common path to x[i] in the structures. All | |||
// structures in `structure` must have the same arity, and the return value will | |||
// contain the results in the same structure. Special kwarg `check_types` | |||
// determines whether the types of iterables within the structure must be the | |||
// same-- see **kwargs definition below. | |||
// Args: | |||
// func: A callable with the signature func(path, *values, **kwargs) that is | |||
// evaluated on the leaves of the structure. | |||
// *structure: A variable number of compatible structures to process. | |||
// **kwargs: Optional kwargs to be passed through to func. Special kwarg | |||
// `check_types` is not passed to func, but instead determines whether the | |||
// types of iterables within the structures have to be same (e.g., | |||
// `map_structure(func, [1], (1,))` raises a `TypeError` exception). By | |||
// default, the types must match. To allow iteration over structures of | |||
// different types (but common arity), set this kwarg to `False`. | |||
// Returns: | |||
// A structure of the same form as the input structures whose leaves are the | |||
// result of evaluating func on corresponding leaves of the input structures. | |||
// Raises: | |||
// TypeError: If `func` is not callable or if the structures do not match | |||
// each other by depth tree. | |||
// TypeError: If `check_types` is not `False` and the two structures differ in | |||
// the type of sequence in any of their substructures. | |||
// ValueError: If no structures are provided. | |||
// """ | |||
// if not callable(func): | |||
// raise TypeError("func must be callable, got: %s" % func) | |||
// if not structure: | |||
// raise ValueError("Must provide at least one structure") | |||
// check_types = kwargs.pop("check_types", True) | |||
// for other in structure[1:]: | |||
// assert_same_structure(structure[0], other, check_types=check_types) | |||
//# First set paths_and_values to: | |||
//# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] | |||
// paths_and_values = [flatten_with_joined_string_paths(s) for s in structure] | |||
//# Now zip(*paths_and_values) would be: | |||
//# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] | |||
//# so grouped_by_path is set to: | |||
//# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] | |||
//# Note that p1i, ... pmi must all be equal since the structures are the same. | |||
// grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] | |||
// return pack_sequence_as(structure[0], [ | |||
// func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) | |||
//def _yield_flat_up_to(shallow_tree, input_tree): | |||
// """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" | |||
// if is_sequence(shallow_tree): | |||
// for shallow_branch, input_branch in zip(_yield_value(shallow_tree), | |||
// _yield_value(input_tree)): | |||
// for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): | |||
// yield input_leaf | |||
// else: | |||
// yield input_tree | |||
//def assert_shallow_structure(shallow_tree, input_tree, check_types=True): | |||
// """Asserts that `shallow_tree` is a shallow structure of `input_tree`. | |||
// That is, this function tests if the `input_tree` structure can be created from | |||
// the `shallow_tree` structure by replacing its leaf nodes with deeper | |||
// tree structures. | |||
// Examples: | |||
// The following code will raise an exception: | |||
// ```python | |||
// shallow_tree = ["a", "b"] | |||
// input_tree = ["c", ["d", "e"], "f"] | |||
// assert_shallow_structure(shallow_tree, input_tree) | |||
// ``` | |||
// The following code will not raise an exception: | |||
// ```python | |||
// shallow_tree = ["a", "b"] | |||
// input_tree = ["c", ["d", "e"]] | |||
// assert_shallow_structure(shallow_tree, input_tree) | |||
// ``` | |||
// Args: | |||
// shallow_tree: an arbitrarily nested structure. | |||
// input_tree: an arbitrarily nested structure. | |||
// check_types: if `True` (default) the sequence types of `shallow_tree` and | |||
// `input_tree` have to be the same. Note that even with check_types==True, | |||
// this function will consider two different namedtuple classes with the same | |||
// name and _fields attribute to be the same class. | |||
// Raises: | |||
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||
// TypeError: If the sequence types of `shallow_tree` are different from | |||
// `input_tree`. Only raised if `check_types` is `True`. | |||
// ValueError: If the sequence lengths of `shallow_tree` are different from | |||
// `input_tree`. | |||
// """ | |||
// if is_sequence(shallow_tree): | |||
// if not is_sequence(input_tree): | |||
// raise TypeError( | |||
// "If shallow structure is a sequence, input must also be a sequence. " | |||
// "Input has type: %s." % type(input_tree)) | |||
// if check_types and not isinstance(input_tree, type(shallow_tree)): | |||
//# Duck-typing means that nest should be fine with two different | |||
//# namedtuples with identical name and fields. | |||
// shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) | |||
// input_is_namedtuple = _is_namedtuple(input_tree, False) | |||
// if shallow_is_namedtuple and input_is_namedtuple: | |||
// if not _same_namedtuples(shallow_tree, input_tree): | |||
// raise TypeError( | |||
// "The two namedtuples don't have the same sequence type. Input " | |||
// "structure has type %s, while shallow structure has type %s." | |||
// % (type(input_tree), type(shallow_tree))) | |||
// elif not (isinstance(shallow_tree, _collections.Mapping) | |||
// and isinstance(input_tree, _collections.Mapping)): | |||
// raise TypeError( | |||
// "The two structures don't have the same sequence type. Input " | |||
// "structure has type %s, while shallow structure has type %s." | |||
// % (type(input_tree), type(shallow_tree))) | |||
// if len(input_tree) != len(shallow_tree): | |||
// raise ValueError( | |||
// "The two structures don't have the same sequence length. Input " | |||
// "structure has length %s, while shallow structure has length %s." | |||
// % (len(input_tree), len(shallow_tree))) | |||
// if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): | |||
// if set(input_tree) != set(shallow_tree): | |||
// raise ValueError( | |||
// "The two structures don't have the same keys. Input " | |||
// "structure has keys %s, while shallow structure has keys %s." % | |||
// (list(_six.iterkeys(input_tree)), | |||
// list(_six.iterkeys(shallow_tree)))) | |||
// input_tree = list(sorted(_six.iteritems(input_tree))) | |||
// shallow_tree = list(sorted(_six.iteritems(shallow_tree))) | |||
// for shallow_branch, input_branch in zip(shallow_tree, input_tree): | |||
// assert_shallow_structure(shallow_branch, input_branch, | |||
// check_types=check_types) | |||
//def flatten_up_to(shallow_tree, input_tree): | |||
// """Flattens `input_tree` up to `shallow_tree`. | |||
// Any further depth in structure in `input_tree` is retained as elements in the | |||
// partially flatten output. | |||
// If `shallow_tree` and `input_tree` are not sequences, this returns a | |||
// single-element list: `[input_tree]`. | |||
// Use Case: | |||
// Sometimes we may wish to partially flatten a nested sequence, retaining some | |||
// of the nested structure. We achieve this by specifying a shallow structure, | |||
// `shallow_tree`, we wish to flatten up to. | |||
// The input, `input_tree`, can be thought of as having the same structure as | |||
// `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||
// Examples: | |||
// ```python | |||
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
// shallow_tree = [[True, True], [False, True]] | |||
// flattened_input_tree = flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) | |||
//# Output is: | |||
//# [[2, 2], [3, 3], [4, 9], [5, 5]] | |||
//# [True, True, False, True] | |||
// ``` | |||
// ```python | |||
// input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] | |||
// shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] | |||
// input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) | |||
// input_tree_flattened = flatten(input_tree) | |||
//# Output is: | |||
//# [('a', 1), ('b', 2), ('c', 3), ('d', 4)] | |||
//# ['a', 1, 'b', 2, 'c', 3, 'd', 4] | |||
// ``` | |||
// Non-Sequence Edge Cases: | |||
// ```python | |||
// flatten_up_to(0, 0) # Output: [0] | |||
// flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] | |||
// flatten_up_to([0, 1, 2], 0) # Output: TypeError | |||
// flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] | |||
// ``` | |||
// Args: | |||
// shallow_tree: a possibly pruned structure of input_tree. | |||
// input_tree: an arbitrarily nested structure or a scalar object. | |||
// Note, numpy arrays are considered scalars. | |||
// Returns: | |||
// A Python list, the partially flattened version of `input_tree` according to | |||
// the structure of `shallow_tree`. | |||
// Raises: | |||
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||
// TypeError: If the sequence types of `shallow_tree` are different from | |||
// `input_tree`. | |||
// ValueError: If the sequence lengths of `shallow_tree` are different from | |||
// `input_tree`. | |||
// """ | |||
// assert_shallow_structure(shallow_tree, input_tree) | |||
// return list(_yield_flat_up_to(shallow_tree, input_tree)) | |||
//def map_structure_up_to(shallow_tree, func, *inputs): | |||
// """Applies a function or op to a number of partially flattened inputs. | |||
// The `inputs` are flattened up to `shallow_tree` before being mapped. | |||
// Use Case: | |||
// Sometimes we wish to apply a function to a partially flattened | |||
// sequence (for example when the function itself takes sequence inputs). We | |||
// achieve this by specifying a shallow structure, `shallow_tree` we wish to | |||
// flatten up to. | |||
// The `inputs`, can be thought of as having the same structure as | |||
// `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||
// This function therefore will return something with the same base structure as | |||
// `shallow_tree`. | |||
// Examples: | |||
// ```python | |||
// ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
// op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
// inp_val = ab_tuple(a=2, b=3) | |||
// inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||
// out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, | |||
// inp_val, inp_ops) | |||
//# Output is: ab_tuple(a=6, b=15) | |||
// ``` | |||
// ```python | |||
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
// name_list = ['evens', ['odds', 'primes']] | |||
// out = map_structure_up_to( | |||
// name_list, | |||
// lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
// name_list, data_list) | |||
//# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] | |||
// ``` | |||
// Args: | |||
// shallow_tree: a shallow tree, common to all the inputs. | |||
// func: callable which will be applied to each input individually. | |||
// *inputs: arbitrarily nested combination of objects that are compatible with | |||
// shallow_tree. The function `func` is applied to corresponding | |||
// partially flattened elements of each input, so the function must support | |||
// arity of `len(inputs)`. | |||
// Raises: | |||
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||
// TypeError: If the sequence types of `shallow_tree` are different from | |||
// `input_tree`. | |||
// ValueError: If the sequence lengths of `shallow_tree` are different from | |||
// `input_tree`. | |||
// Returns: | |||
// result of repeatedly applying `func`, with same structure as | |||
// `shallow_tree`. | |||
// """ | |||
// if not inputs: | |||
// raise ValueError("Cannot map over no sequences") | |||
// for input_tree in inputs: | |||
// assert_shallow_structure(shallow_tree, input_tree) | |||
//# Flatten each input separately, apply the function to corresponding elements, | |||
//# then repack based on the structure of the first input. | |||
// all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) | |||
// for input_tree in inputs] | |||
// results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] | |||
// return pack_sequence_as(structure=shallow_tree, flat_sequence=results) | |||
//def get_traverse_shallow_structure(traverse_fn, structure): | |||
// """Generates a shallow structure from a `traverse_fn` and `structure`. | |||
// `traverse_fn` must accept any possible subtree of `structure` and return | |||
// a depth=1 structure containing `True` or `False` values, describing which | |||
// of the top-level subtrees may be traversed. It may also | |||
// return scalar `True` or `False` "traversal is OK / not OK for all subtrees." | |||
// Examples are available in the unit tests (nest_test.py). | |||
// Args: | |||
// traverse_fn: Function taking a substructure and returning either a scalar | |||
// `bool` (whether to traverse that substructure or not) or a depth=1 | |||
// shallow structure of the same type, describing which parts of the | |||
// substructure to traverse. | |||
// structure: The structure to traverse. | |||
// Returns: | |||
// A shallow structure containing python bools, which can be passed to | |||
// `map_structure_up_to` and `flatten_up_to`. | |||
// Raises: | |||
// TypeError: if `traverse_fn` returns a sequence for a non-sequence input, | |||
// or a structure with depth higher than 1 for a sequence input, | |||
// or if any leaf values in the returned structure or scalar are not type | |||
// `bool`. | |||
// """ | |||
// to_traverse = traverse_fn(structure) | |||
// if not is_sequence(structure): | |||
// if not isinstance(to_traverse, bool): | |||
// raise TypeError("traverse_fn returned structure: %s for non-structure: %s" | |||
// % (to_traverse, structure)) | |||
// return to_traverse | |||
// level_traverse = [] | |||
// if isinstance(to_traverse, bool): | |||
// if not to_traverse: | |||
//# Do not traverse this substructure at all. Exit early. | |||
// return False | |||
// else: | |||
//# Traverse the entire substructure. | |||
// for branch in _yield_value(structure): | |||
// level_traverse.append( | |||
// get_traverse_shallow_structure(traverse_fn, branch)) | |||
// elif not is_sequence(to_traverse): | |||
// raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" | |||
// % (to_traverse, structure)) | |||
// else: | |||
//# Traverse some subset of this substructure. | |||
// assert_shallow_structure(to_traverse, structure) | |||
// for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): | |||
// if not isinstance(t, bool): | |||
// raise TypeError( | |||
// "traverse_fn didn't return a depth=1 structure of bools. saw: %s " | |||
// " for structure: %s" % (to_traverse, structure)) | |||
// if t: | |||
// level_traverse.append( | |||
// get_traverse_shallow_structure(traverse_fn, branch)) | |||
// else: | |||
// level_traverse.append(False) | |||
// return _sequence_like(structure, level_traverse) | |||
//def yield_flat_paths(nest): | |||
// """Yields paths for some nested structure. | |||
// Paths are lists of objects which can be str-converted, which may include | |||
// integers or other types which are used as indices in a dict. | |||
// The flat list will be in the corresponding order as if you called | |||
// `snt.nest.flatten` on the structure. This is handy for naming Tensors such | |||
// the TF scope structure matches the tuple structure. | |||
// E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` | |||
// ```shell | |||
// >>> nest.flatten(value) | |||
// [3, 23, 42] | |||
// >>> list(nest.yield_flat_paths(value)) | |||
// [('a',), ('b', 'c'), ('b', 'd')] | |||
// ``` | |||
// ```shell | |||
// >>> list(nest.yield_flat_paths({'a': [3]})) | |||
// [('a', 0)] | |||
// >>> list(nest.yield_flat_paths({'a': 3})) | |||
// [('a',)] | |||
// ``` | |||
// Args: | |||
// nest: the value to produce a flattened paths list for. | |||
// Yields: | |||
// Tuples containing index or key values which form the path to a specific | |||
// leaf value in the nested structure. | |||
// """ | |||
//# The _maybe_add_final_path_element function is used below in order to avoid | |||
//# adding trailing slashes when the sub-element recursed into is a leaf. | |||
// if isinstance(nest, (dict, _collections.Mapping)): | |||
// for key in _sorted(nest): | |||
// value = nest[key] | |||
// for sub_path in yield_flat_paths(value): | |||
// yield (key,) + sub_path | |||
// elif _is_namedtuple(nest): | |||
// for key in nest._fields: | |||
// value = getattr(nest, key) | |||
// for sub_path in yield_flat_paths(value): | |||
// yield (key,) + sub_path | |||
// elif isinstance(nest, _six.string_types): | |||
// yield () | |||
// elif isinstance(nest, _collections.Sequence): | |||
// for idx, value in enumerate(nest): | |||
// for sub_path in yield_flat_paths(value): | |||
// yield (idx,) + sub_path | |||
// else: | |||
// yield () | |||
//def flatten_with_joined_string_paths(structure, separator="/"): | |||
// """Returns a list of (string path, data element) tuples. | |||
// The order of tuples produced matches that of `nest.flatten`. This allows you | |||
// to flatten a nested structure while keeping information about where in the | |||
// structure each data element was located. See `nest.yield_flat_paths` | |||
// for more information. | |||
// Args: | |||
// structure: the nested structure to flatten. | |||
// separator: string to separate levels of hierarchy in the results, defaults | |||
// to '/'. | |||
// Returns: | |||
// A list of (string, data element) tuples. | |||
// """ | |||
// flat_paths = yield_flat_paths(structure) | |||
// def stringify_and_join(path_elements): | |||
// return separator.join(str(path_element) for path_element in path_elements) | |||
// flat_string_paths = [stringify_and_join(path) for path in flat_paths] | |||
// return list(zip(flat_string_paths, flatten(structure))) | |||
} | |||
} |
@@ -0,0 +1,852 @@ | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Newtonsoft.Json.Linq; | |||
using Tensorflow; | |||
using Tensorflow.Util; | |||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
{ | |||
/// <summary> | |||
/// excerpt of tensorflow/python/framework/util/nest_test.py | |||
/// </summary> | |||
[TestClass] | |||
public class NestTest : PythonTest | |||
{ | |||
public class PointXY | |||
{ | |||
public double x; | |||
public double y; | |||
} | |||
// if attr: | |||
// class BadAttr(object): | |||
// """Class that has a non-iterable __attrs_attrs__.""" | |||
// __attrs_attrs__ = None | |||
// @attr.s | |||
// class SampleAttr(object): | |||
// field1 = attr.ib() | |||
// field2 = attr.ib() | |||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||
// def testAttrsFlattenAndPack(self) : | |||
// if attr is None: | |||
// self.skipTest("attr module is unavailable.") | |||
// field_values = [1, 2] | |||
// sample_attr = NestTest.SampleAttr(* field_values) | |||
// self.assertFalse(nest._is_attrs(field_values)) | |||
// self.assertTrue(nest._is_attrs(sample_attr)) | |||
// flat = nest.flatten(sample_attr) | |||
// self.assertEqual(field_values, flat) | |||
// restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||
// self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||
// self.assertEqual(restructured_from_flat, sample_attr) | |||
//# Check that flatten fails if attributes are not iterable | |||
// with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||
// flat = nest.flatten(NestTest.BadAttr()) | |||
[TestMethod] | |||
public void testFlattenAndPack() | |||
{ | |||
object structure = new object[] {new object[] {3, 4}, 5, new object[] {6, 7, new object[] {9, 10}, 8}}; | |||
var flat = new List<object> {"a", "b", "c", "d", "e", "f", "g", "h"}; | |||
self.assertEqual(nest.flatten(structure), new[] {3, 4, 5, 6, 7, 9, 10, 8}); | |||
self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(), | |||
JArray.FromObject(new object[] {new object[] {"a", "b"}, "c", new object[] {"d", "e", new object[] {"f", "g"}, "h"}}).ToString()); | |||
structure = new object[] { new Hashtable {["x"] = 4, ["y"] = 2}, new object[] { new object[] { new Hashtable { ["x"] = 1,["y"] = 0}, }, }}; | |||
flat = new List<object> { 4, 2, 1, 0}; | |||
self.assertEqual(nest.flatten(structure), flat); | |||
// restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||
// self.assertEqual(restructured_from_flat, structure) | |||
// self.assertEqual(restructured_from_flat[0].x, 4) | |||
// self.assertEqual(restructured_from_flat[0].y, 2) | |||
// self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||
// self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||
// self.assertEqual([5], nest.flatten(5)) | |||
// self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||
// self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||
// self.assertEqual( | |||
// np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||
// with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||
// nest.pack_sequence_as("scalar", [4, 5]) | |||
// with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||
// nest.pack_sequence_as([4, 5], "bad_sequence") | |||
// with self.assertRaises(ValueError): | |||
// nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||
} | |||
// @parameterized.parameters({"mapping_type": collections.OrderedDict | |||
// }, | |||
// {"mapping_type": _CustomMapping | |||
//}) | |||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||
// def testFlattenDictOrder(self, mapping_type) : | |||
// """`flatten` orders dicts by key, including OrderedDicts.""" | |||
// ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||
// plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||
// ordered_flat = nest.flatten(ordered) | |||
// plain_flat = nest.flatten(plain) | |||
// self.assertEqual([0, 1, 2, 3], ordered_flat) | |||
// self.assertEqual([0, 1, 2, 3], plain_flat) | |||
// @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
// {"mapping_type": _CustomMapping}) | |||
// def testPackDictOrder(self, mapping_type): | |||
// """Packing orders dicts by key, including OrderedDicts.""" | |||
// custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||
// plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||
// seq = [0, 1, 2, 3] | |||
//custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||
//plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||
// self.assertIsInstance(custom_reconstruction, mapping_type) | |||
// self.assertIsInstance(plain_reconstruction, dict) | |||
// self.assertEqual( | |||
// mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||
// custom_reconstruction) | |||
// self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||
// Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||
// def testFlattenAndPack_withDicts(self) : | |||
// # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||
// mess = [ | |||
// "z", | |||
// NestTest.Abc(3, 4), { | |||
// "d": _CustomMapping({ | |||
// 41: 4 | |||
// }), | |||
// "c": [ | |||
// 1, | |||
// collections.OrderedDict([ | |||
// ("b", 3), | |||
// ("a", 2), | |||
// ]), | |||
// ], | |||
// "b": 5 | |||
// }, 17 | |||
// ] | |||
// flattened = nest.flatten(mess) | |||
// self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||
// structure_of_mess = [ | |||
// 14, | |||
// NestTest.Abc("a", True), | |||
// { | |||
// "d": _CustomMapping({ | |||
// 41: 42 | |||
// }), | |||
// "c": [ | |||
// 0, | |||
// collections.OrderedDict([ | |||
// ("b", 9), | |||
// ("a", 8), | |||
// ]), | |||
// ], | |||
// "b": 3 | |||
// }, | |||
// "hi everybody", | |||
// ] | |||
// unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||
// self.assertEqual(unflattened, mess) | |||
// # Check also that the OrderedDict was created, with the correct key order. | |||
//unflattened_ordered_dict = unflattened[2]["c"][1] | |||
// self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||
// self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||
// unflattened_custom_mapping = unflattened[2]["d"] | |||
// self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||
// self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||
// def testFlatten_numpyIsNotFlattened(self): | |||
// structure = np.array([1, 2, 3]) | |||
// flattened = nest.flatten(structure) | |||
// self.assertEqual(len(flattened), 1) | |||
// def testFlatten_stringIsNotFlattened(self): | |||
// structure = "lots of letters" | |||
// flattened = nest.flatten(structure) | |||
// self.assertEqual(len(flattened), 1) | |||
// unflattened = nest.pack_sequence_as("goodbye", flattened) | |||
// self.assertEqual(structure, unflattened) | |||
// def testPackSequenceAs_notIterableError(self) : | |||
// with self.assertRaisesRegexp(TypeError, | |||
// "flat_sequence must be a sequence"): | |||
// nest.pack_sequence_as("hi", "bye") | |||
// def testPackSequenceAs_wrongLengthsError(self): | |||
// with self.assertRaisesRegexp( | |||
// ValueError, | |||
// "Structure had 2 elements, but flat_sequence had 3 elements."): | |||
// nest.pack_sequence_as(["hello", "world"], | |||
// ["and", "goodbye", "again"]) | |||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||
// def testIsSequence(self): | |||
// self.assertFalse(nest.is_sequence("1234")) | |||
// self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||
// self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||
// self.assertTrue(nest.is_sequence([])) | |||
// self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||
// self.assertFalse(nest.is_sequence(set([1, 2]))) | |||
// ones = array_ops.ones([2, 3]) | |||
// self.assertFalse(nest.is_sequence(ones)) | |||
// self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||
// self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||
// @parameterized.parameters({"mapping_type": _CustomMapping}, | |||
// {"mapping_type": dict}) | |||
// def testFlattenDictItems(self, mapping_type): | |||
// dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||
// flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
// self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||
// with self.assertRaises(TypeError): | |||
// nest.flatten_dict_items(4) | |||
// bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||
// with self.assertRaisesRegexp(ValueError, "not unique"): | |||
// nest.flatten_dict_items(bad_dictionary) | |||
// another_bad_dictionary = mapping_type({ | |||
// (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||
// }) | |||
// with self.assertRaisesRegexp( | |||
// ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||
// nest.flatten_dict_items(another_bad_dictionary) | |||
//# pylint does not correctly recognize these as class names and | |||
//# suggests to use variable style under_score naming. | |||
//# pylint: disable=invalid-name | |||
// Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||
// Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||
// SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||
// SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||
// SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||
// SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||
// SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||
// NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||
// # pylint: enable=invalid-name | |||
// class SameNamedType1(SameNameab): | |||
// pass | |||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||
// def testAssertSameStructure(self): | |||
// structure1 = (((1, 2), 3), 4, (5, 6)) | |||
// structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
// structure_different_num_elements = ("spam", "eggs") | |||
// structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||
// nest.assert_same_structure(structure1, structure2) | |||
// nest.assert_same_structure("abc", 1.0) | |||
// nest.assert_same_structure("abc", np.array([0, 1])) | |||
// nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||
// with self.assertRaisesRegexp( | |||
// ValueError, | |||
// ("The two structures don't have the same nested structure\\.\n\n" | |||
// "First structure:.*?\n\n" | |||
// "Second structure:.*\n\n" | |||
// "More specifically: Substructure " | |||
// r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||
// 'substructure "type=str str=spam" is not\n' | |||
// "Entire first structure:\n" | |||
// r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||
// "Entire second structure:\n" | |||
// r"\(\., \.\)")): | |||
// nest.assert_same_structure(structure1, structure_different_num_elements) | |||
// with self.assertRaisesRegexp( | |||
// ValueError, | |||
// ("The two structures don't have the same nested structure\\.\n\n" | |||
// "First structure:.*?\n\n" | |||
// "Second structure:.*\n\n" | |||
// r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
// r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||
// "is not")): | |||
// nest.assert_same_structure([0, 1], np.array([0, 1])) | |||
// with self.assertRaisesRegexp( | |||
// ValueError, | |||
// ("The two structures don't have the same nested structure\\.\n\n" | |||
// "First structure:.*?\n\n" | |||
// "Second structure:.*\n\n" | |||
// r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
// 'is a sequence, while substructure "type=int str=0" ' | |||
// "is not")): | |||
// nest.assert_same_structure(0, [0, 1]) | |||
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||
// with self.assertRaisesRegexp( | |||
// ValueError, | |||
// ("don't have the same nested structure\\.\n\n" | |||
// "First structure: .*?\n\nSecond structure: ")): | |||
// nest.assert_same_structure(structure1, structure_different_nesting) | |||
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||
// NestTest.Named0ab("a", "b")) | |||
// nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
// NestTest.Named0ab("a", "b")) | |||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||
// NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||
// with self.assertRaisesRegexp( | |||
// ValueError, | |||
// ("don't have the same nested structure\\.\n\n" | |||
// "First structure: .*?\n\nSecond structure: ")): | |||
// nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
// NestTest.Named0ab([3], 4)) | |||
// with self.assertRaisesRegexp( | |||
// ValueError, | |||
// ("don't have the same nested structure\\.\n\n" | |||
// "First structure: .*?\n\nSecond structure: ")): | |||
// nest.assert_same_structure([[3], 4], [3, [4]]) | |||
// structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
// with self.assertRaisesRegexp(TypeError, | |||
// "don't have the same sequence type"): | |||
// nest.assert_same_structure(structure1, structure1_list) | |||
// nest.assert_same_structure(structure1, structure2, check_types= False) | |||
// nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||
// with self.assertRaisesRegexp(ValueError, | |||
// "don't have the same set of keys"): | |||
// nest.assert_same_structure({"a": 1}, {"b": 1}) | |||
// nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||
// NestTest.SameNameab2(2, 3)) | |||
// # This assertion is expected to pass: two namedtuples with the same | |||
// # name and field names are considered to be identical. | |||
// nest.assert_same_structure( | |||
// NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||
// NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||
// expected_message = "The two structures don't have the same.*" | |||
// with self.assertRaisesRegexp(ValueError, expected_message): | |||
// nest.assert_same_structure( | |||
// NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||
// NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||
// NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||
// NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||
// NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||
// EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||
// def testHeterogeneousComparison(self): | |||
// nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) | |||
// nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||
// def testMapStructure(self) : | |||
// structure1 = (((1, 2), 3), 4, (5, 6)) | |||
// structure2 = (((7, 8), 9), 10, (11, 12)) | |||
// structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||
// nest.assert_same_structure(structure1, structure1_plus1) | |||
// self.assertAllEqual( | |||
// [2, 3, 4, 5, 6, 7], | |||
// nest.flatten(structure1_plus1)) | |||
// structure1_plus_structure2 = nest.map_structure( | |||
// lambda x, y: x + y, structure1, structure2) | |||
// self.assertEqual( | |||
// (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||
// structure1_plus_structure2) | |||
// self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||
// self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||
// # Empty structures | |||
// self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||
// self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||
// self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||
// self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||
// NestTest.EmptyNT())) | |||
// # This is checking actual equality of types, empty list != empty tuple | |||
// self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||
// with self.assertRaisesRegexp(TypeError, "callable"): | |||
// nest.map_structure("bad", structure1_plus1) | |||
// with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||
// nest.map_structure(lambda x: x) | |||
// with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||
// nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||
// with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
// nest.map_structure(lambda x, y: None, 3, (3,)) | |||
// with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
// nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||
// with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
// nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||
// structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
// with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
// nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||
// nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||
// check_types=False) | |||
// with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
// nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||
// check_types=False) | |||
// with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
// nest.map_structure(lambda x: None, structure1, foo="a") | |||
// with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
// nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||
// ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||
// def testMapStructureWithStrings(self) : | |||
// inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||
// inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||
// out = nest.map_structure(lambda string, repeats: string* repeats, | |||
// inp_a, | |||
// inp_b) | |||
// self.assertEqual("foofoo", out.a) | |||
// self.assertEqual("bar", out.b[0]) | |||
// self.assertEqual("bazbazbaz", out.b[1]) | |||
// nt = NestTest.ABTuple(a=("something", "something_else"), | |||
// b="yet another thing") | |||
// rev_nt = nest.map_structure(lambda x: x[::- 1], nt) | |||
// # Check the output is the correct structure, and all strings are reversed. | |||
// nest.assert_same_structure(nt, rev_nt) | |||
// self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) | |||
// self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) | |||
// self.assertEqual(nt.b[::- 1], rev_nt.b) | |||
// @test_util.run_deprecated_v1 | |||
// def testMapStructureOverPlaceholders(self) : | |||
// inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
// array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
// inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
// array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
// output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||
// nest.assert_same_structure(output, inp_a) | |||
// self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||
// self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||
// feed_dict = { | |||
// inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||
// inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||
// } | |||
// with self.cached_session() as sess: | |||
// output_np = sess.run(output, feed_dict=feed_dict) | |||
// self.assertAllClose(output_np[0], | |||
// feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||
// self.assertAllClose(output_np[1], | |||
// feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||
// def testAssertShallowStructure(self): | |||
// inp_ab = ["a", "b"] | |||
//inp_abc = ["a", "b", "c"] | |||
//expected_message = ( | |||
// "The two structures don't have the same sequence length. Input " | |||
// "structure has length 2, while shallow structure has length 3.") | |||
// with self.assertRaisesRegexp(ValueError, expected_message): | |||
// nest.assert_shallow_structure(inp_abc, inp_ab) | |||
// inp_ab1 = [(1, 1), (2, 2)] | |||
// inp_ab2 = [[1, 1], [2, 2]] | |||
// expected_message = ( | |||
// "The two structures don't have the same sequence type. Input structure " | |||
// "has type <(type|class) 'tuple'>, while shallow structure has type " | |||
// "<(type|class) 'list'>.") | |||
// with self.assertRaisesRegexp(TypeError, expected_message): | |||
// nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
// nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) | |||
// inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||
// inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||
// expected_message = ( | |||
// r"The two structures don't have the same keys. Input " | |||
// r"structure has keys \['c'\], while shallow structure has " | |||
// r"keys \['d'\].") | |||
// with self.assertRaisesRegexp(ValueError, expected_message): | |||
// nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
// inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||
// inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||
// nest.assert_shallow_structure(inp_ab, inp_ba) | |||
// # This assertion is expected to pass: two namedtuples with the same | |||
//# name and field names are considered to be identical. | |||
//inp_shallow = NestTest.SameNameab(1, 2) | |||
// inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||
// def testFlattenUpTo(self): | |||
// # Shallow tree ends at scalar. | |||
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
// shallow_tree = [[True, True], [False, True]] | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||
// self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||
//# Shallow tree ends at string. | |||
// input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||
// shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
// input_tree) | |||
// input_tree_flattened = nest.flatten(input_tree) | |||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
// [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||
// self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||
// # Make sure dicts are correctly flattened, yielding values, not keys. | |||
//input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||
// shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
// input_tree) | |||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
// [1, { "c": 2}, 3, (4, 5)]) | |||
// # Namedtuples. | |||
// ab_tuple = NestTest.ABTuple | |||
// input_tree = ab_tuple(a =[0, 1], b = 2) | |||
// shallow_tree = ab_tuple(a= 0, b= 1) | |||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
// input_tree) | |||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
// [[0, 1], 2]) | |||
// # Nested dicts, OrderedDicts and namedtuples. | |||
// input_tree = collections.OrderedDict( | |||
// [("a", ab_tuple(a =[0, {"b": 1}], b=2)), | |||
// ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||
// shallow_tree = input_tree | |||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
// input_tree) | |||
// self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||
// shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
// input_tree) | |||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
// [ab_tuple(a =[0, { "b": 1}], b=2), | |||
// 3, | |||
// collections.OrderedDict([("f", 4)])]) | |||
// shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
// input_tree) | |||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
// [ab_tuple(a =[0, {"b": 1}], b=2), | |||
// {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||
// ## Shallow non-list edge-case. | |||
// # Using iterable elements. | |||
// input_tree = ["input_tree"] | |||
//shallow_tree = "shallow_tree" | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
// input_tree = ["input_tree_0", "input_tree_1"] | |||
//shallow_tree = "shallow_tree" | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
// # Using non-iterable elements. | |||
//input_tree = [0] | |||
//shallow_tree = 9 | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
// input_tree = [0, 1] | |||
//shallow_tree = 9 | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
// ## Both non-list edge-case. | |||
//# Using iterable elements. | |||
//input_tree = "input_tree" | |||
// shallow_tree = "shallow_tree" | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
// # Using non-iterable elements. | |||
//input_tree = 0 | |||
// shallow_tree = 0 | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
// ## Input non-list edge-case. | |||
//# Using iterable elements. | |||
//input_tree = "input_tree" | |||
// shallow_tree = ["shallow_tree"] | |||
//expected_message = ("If shallow structure is a sequence, input must also " | |||
// "be a sequence. Input has type: <(type|class) 'str'>.") | |||
// with self.assertRaisesRegexp(TypeError, expected_message): | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
// input_tree = "input_tree" | |||
// shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||
//with self.assertRaisesRegexp(TypeError, expected_message): | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
//# Using non-iterable elements. | |||
// input_tree = 0 | |||
// shallow_tree = [9] | |||
//expected_message = ("If shallow structure is a sequence, input must also " | |||
// "be a sequence. Input has type: <(type|class) 'int'>.") | |||
// with self.assertRaisesRegexp(TypeError, expected_message): | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
// input_tree = 0 | |||
// shallow_tree = [9, 8] | |||
//with self.assertRaisesRegexp(TypeError, expected_message): | |||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
// def testMapStructureUpTo(self) : | |||
// # Named tuples. | |||
// ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
// op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
// inp_val = ab_tuple(a= 2, b= 3) | |||
// inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) | |||
// out = nest.map_structure_up_to( | |||
// inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||
// self.assertEqual(out.a, 6) | |||
// self.assertEqual(out.b, 15) | |||
// # Lists. | |||
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
// name_list = ["evens", ["odds", "primes"]] | |||
// out = nest.map_structure_up_to( | |||
// name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
// name_list, data_list) | |||
// self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||
// # Dicts. | |||
// inp_val = dict(a= 2, b= 3) | |||
// inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||
// out = nest.map_structure_up_to( | |||
// inp_val, | |||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
// self.assertEqual(out["a"], 6) | |||
// self.assertEqual(out["b"], 15) | |||
// # Non-equal dicts. | |||
// inp_val = dict(a= 2, b= 3) | |||
// inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||
// with self.assertRaisesRegexp(ValueError, "same keys"): | |||
// nest.map_structure_up_to( | |||
// inp_val, | |||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
// # Dict+custom mapping. | |||
// inp_val = dict(a= 2, b= 3) | |||
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||
// out = nest.map_structure_up_to( | |||
// inp_val, | |||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
// self.assertEqual(out["a"], 6) | |||
// self.assertEqual(out["b"], 15) | |||
// # Non-equal dict/mapping. | |||
// inp_val = dict(a= 2, b= 3) | |||
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||
// with self.assertRaisesRegexp(ValueError, "same keys"): | |||
// nest.map_structure_up_to( | |||
// inp_val, | |||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
// def testGetTraverseShallowStructure(self): | |||
// scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||
// scalar_traverse_r = nest.get_traverse_shallow_structure( | |||
// lambda s: not isinstance(s, tuple), | |||
// scalar_traverse_input) | |||
// self.assertEqual(scalar_traverse_r, | |||
// [True, True, False, [True, True], {"a": False}, []]) | |||
// nest.assert_shallow_structure(scalar_traverse_r, | |||
// scalar_traverse_input) | |||
// structure_traverse_input = [(1, [2]), ([1], 2)] | |||
// structure_traverse_r = nest.get_traverse_shallow_structure( | |||
// lambda s: (True, False) if isinstance(s, tuple) else True, | |||
// structure_traverse_input) | |||
// self.assertEqual(structure_traverse_r, | |||
// [(True, False), ([True], False)]) | |||
// nest.assert_shallow_structure(structure_traverse_r, | |||
// structure_traverse_input) | |||
// with self.assertRaisesRegexp(TypeError, "returned structure"): | |||
// nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||
// with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||
// nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||
// with self.assertRaisesRegexp( | |||
// TypeError, "didn't return a depth=1 structure of bools"): | |||
// nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||
// def testYieldFlatStringPaths(self): | |||
// for inputs_expected in ({"inputs": [], "expected": []}, | |||
// {"inputs": 3, "expected": [()]}, | |||
// {"inputs": [3], "expected": [(0,)]}, | |||
// {"inputs": {"a": 3}, "expected": [("a",)]}, | |||
// {"inputs": {"a": {"b": 4}}, | |||
// "expected": [("a", "b")]}, | |||
// {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||
// {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||
// {"inputs": [{"a": [(23, 42)]}], | |||
// "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||
// {"inputs": [{"a": ([23], 42)}], | |||
// "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||
// {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||
// "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||
// {"inputs": {"0": [{"1": 23}]}, | |||
// "expected": [("0", 0, "1")]}): | |||
// inputs = inputs_expected["inputs"] | |||
// expected = inputs_expected["expected"] | |||
// self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||
// def testFlattenWithStringPaths(self): | |||
// for inputs_expected in ( | |||
// {"inputs": [], "expected": []}, | |||
// {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||
// {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||
// inputs = inputs_expected["inputs"] | |||
// expected = inputs_expected["expected"] | |||
// self.assertEqual( | |||
// nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||
// expected) | |||
// # Need a separate test for namedtuple as we can't declare tuple definitions | |||
// # in the @parameterized arguments. | |||
// def testFlattenNamedTuple(self): | |||
// # pylint: disable=invalid-name | |||
// Foo = collections.namedtuple("Foo", ["a", "b"]) | |||
// Bar = collections.namedtuple("Bar", ["c", "d"]) | |||
// # pylint: enable=invalid-name | |||
// test_cases = [ | |||
// (Foo(a = 3, b = Bar(c = 23, d = 42)), | |||
// [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||
// (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), | |||
// [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||
// (Bar(c = 42, d = 43), | |||
// [("c", 42), ("d", 43)]), | |||
// (Bar(c =[42], d = 43), | |||
// [("c/0", 42), ("d", 43)]), | |||
// ] | |||
// for inputs, expected in test_cases: | |||
// self.assertEqual( | |||
// list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||
// @parameterized.named_parameters( | |||
// ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||
// ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||
// {"a": ("a", 4), "b": ("b", 6)}), | |||
// ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||
// ("nested", | |||
// {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||
// {"a": [("a/0", 10), ("a/1", 12)], | |||
// "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||
// def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||
// def format_sum(path, * values): | |||
// return (path, sum(values)) | |||
// result = nest.map_structure_with_paths(format_sum, s1, s2, | |||
// check_types=check_types) | |||
// self.assertEqual(expected, result) | |||
// @parameterized.named_parameters( | |||
// ("tuples", (1, 2), (3, 4, 5), ValueError), | |||
// ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||
// ("mixed", (1, 2), [3, 4], TypeError), | |||
// ("nested", | |||
// {"a": [2, 3], "b": [1, 3]}, | |||
// {"b": [5, 6, 7], "a": [8, 9]}, | |||
// ValueError | |||
// )) | |||
// def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||
// with self.assertRaises(error_type): | |||
// nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) | |||
//class NestBenchmark(test.Benchmark): | |||
// def run_and_report(self, s1, s2, name): | |||
// burn_iter, test_iter = 100, 30000 | |||
// for _ in xrange(burn_iter) : | |||
// nest.assert_same_structure(s1, s2) | |||
// t0 = time.time() | |||
// for _ in xrange(test_iter) : | |||
// nest.assert_same_structure(s1, s2) | |||
// t1 = time.time() | |||
// self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||
// name=name) | |||
// def benchmark_assert_structure(self): | |||
// s1 = (((1, 2), 3), 4, (5, 6)) | |||
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
// self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||
// s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||
// self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||
//if __name__ == "__main__": | |||
// test.main() | |||
} | |||
} |
@@ -0,0 +1,883 @@ | |||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
# ============================================================================== | |||
"""Tests for utilities working with arbitrarily nested structures.""" | |||
from __future__ import absolute_import | |||
from __future__ import division | |||
from __future__ import print_function | |||
import collections | |||
import time | |||
from absl.testing import parameterized | |||
import numpy as np | |||
from six.moves import xrange # pylint: disable=redefined-builtin | |||
from tensorflow.python.framework import constant_op | |||
from tensorflow.python.framework import dtypes | |||
from tensorflow.python.framework import test_util | |||
from tensorflow.python.ops import array_ops | |||
from tensorflow.python.ops import math_ops | |||
from tensorflow.python.platform import test | |||
from tensorflow.python.util import nest | |||
try: | |||
import attr # pylint:disable=g-import-not-at-top | |||
except ImportError: | |||
attr = None | |||
class _CustomMapping(collections.Mapping): | |||
def __init__(self, *args, **kwargs): | |||
self._wrapped = dict(*args, **kwargs) | |||
def __getitem__(self, key): | |||
return self._wrapped[key] | |||
def __iter__(self): | |||
return iter(self._wrapped) | |||
def __len__(self): | |||
return len(self._wrapped) | |||
class NestTest(parameterized.TestCase, test.TestCase): | |||
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name | |||
if attr: | |||
class BadAttr(object): | |||
"""Class that has a non-iterable __attrs_attrs__.""" | |||
__attrs_attrs__ = None | |||
@attr.s | |||
class SampleAttr(object): | |||
field1 = attr.ib() | |||
field2 = attr.ib() | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testAttrsFlattenAndPack(self): | |||
if attr is None: | |||
self.skipTest("attr module is unavailable.") | |||
field_values = [1, 2] | |||
sample_attr = NestTest.SampleAttr(*field_values) | |||
self.assertFalse(nest._is_attrs(field_values)) | |||
self.assertTrue(nest._is_attrs(sample_attr)) | |||
flat = nest.flatten(sample_attr) | |||
self.assertEqual(field_values, flat) | |||
restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||
self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||
self.assertEqual(restructured_from_flat, sample_attr) | |||
# Check that flatten fails if attributes are not iterable | |||
with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||
flat = nest.flatten(NestTest.BadAttr()) | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testFlattenAndPack(self): | |||
structure = ((3, 4), 5, (6, 7, (9, 10), 8)) | |||
flat = ["a", "b", "c", "d", "e", "f", "g", "h"] | |||
self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) | |||
self.assertEqual( | |||
nest.pack_sequence_as(structure, flat), (("a", "b"), "c", | |||
("d", "e", ("f", "g"), "h"))) | |||
structure = (NestTest.PointXY(x=4, y=2), | |||
((NestTest.PointXY(x=1, y=0),),)) | |||
flat = [4, 2, 1, 0] | |||
self.assertEqual(nest.flatten(structure), flat) | |||
restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||
self.assertEqual(restructured_from_flat, structure) | |||
self.assertEqual(restructured_from_flat[0].x, 4) | |||
self.assertEqual(restructured_from_flat[0].y, 2) | |||
self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||
self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||
self.assertEqual([5], nest.flatten(5)) | |||
self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||
self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||
self.assertEqual( | |||
np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||
with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||
nest.pack_sequence_as("scalar", [4, 5]) | |||
with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||
nest.pack_sequence_as([4, 5], "bad_sequence") | |||
with self.assertRaises(ValueError): | |||
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||
@parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
{"mapping_type": _CustomMapping}) | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testFlattenDictOrder(self, mapping_type): | |||
"""`flatten` orders dicts by key, including OrderedDicts.""" | |||
ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||
plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||
ordered_flat = nest.flatten(ordered) | |||
plain_flat = nest.flatten(plain) | |||
self.assertEqual([0, 1, 2, 3], ordered_flat) | |||
self.assertEqual([0, 1, 2, 3], plain_flat) | |||
@parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
{"mapping_type": _CustomMapping}) | |||
def testPackDictOrder(self, mapping_type): | |||
"""Packing orders dicts by key, including OrderedDicts.""" | |||
custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||
plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||
seq = [0, 1, 2, 3] | |||
custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||
plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||
self.assertIsInstance(custom_reconstruction, mapping_type) | |||
self.assertIsInstance(plain_reconstruction, dict) | |||
self.assertEqual( | |||
mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||
custom_reconstruction) | |||
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||
Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testFlattenAndPack_withDicts(self): | |||
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||
mess = [ | |||
"z", | |||
NestTest.Abc(3, 4), { | |||
"d": _CustomMapping({ | |||
41: 4 | |||
}), | |||
"c": [ | |||
1, | |||
collections.OrderedDict([ | |||
("b", 3), | |||
("a", 2), | |||
]), | |||
], | |||
"b": 5 | |||
}, 17 | |||
] | |||
flattened = nest.flatten(mess) | |||
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||
structure_of_mess = [ | |||
14, | |||
NestTest.Abc("a", True), | |||
{ | |||
"d": _CustomMapping({ | |||
41: 42 | |||
}), | |||
"c": [ | |||
0, | |||
collections.OrderedDict([ | |||
("b", 9), | |||
("a", 8), | |||
]), | |||
], | |||
"b": 3 | |||
}, | |||
"hi everybody", | |||
] | |||
unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||
self.assertEqual(unflattened, mess) | |||
# Check also that the OrderedDict was created, with the correct key order. | |||
unflattened_ordered_dict = unflattened[2]["c"][1] | |||
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||
unflattened_custom_mapping = unflattened[2]["d"] | |||
self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||
self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||
def testFlatten_numpyIsNotFlattened(self): | |||
structure = np.array([1, 2, 3]) | |||
flattened = nest.flatten(structure) | |||
self.assertEqual(len(flattened), 1) | |||
def testFlatten_stringIsNotFlattened(self): | |||
structure = "lots of letters" | |||
flattened = nest.flatten(structure) | |||
self.assertEqual(len(flattened), 1) | |||
unflattened = nest.pack_sequence_as("goodbye", flattened) | |||
self.assertEqual(structure, unflattened) | |||
def testPackSequenceAs_notIterableError(self): | |||
with self.assertRaisesRegexp(TypeError, | |||
"flat_sequence must be a sequence"): | |||
nest.pack_sequence_as("hi", "bye") | |||
def testPackSequenceAs_wrongLengthsError(self): | |||
with self.assertRaisesRegexp( | |||
ValueError, | |||
"Structure had 2 elements, but flat_sequence had 3 elements."): | |||
nest.pack_sequence_as(["hello", "world"], | |||
["and", "goodbye", "again"]) | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testIsSequence(self): | |||
self.assertFalse(nest.is_sequence("1234")) | |||
self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||
self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||
self.assertTrue(nest.is_sequence([])) | |||
self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||
self.assertFalse(nest.is_sequence(set([1, 2]))) | |||
ones = array_ops.ones([2, 3]) | |||
self.assertFalse(nest.is_sequence(ones)) | |||
self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||
self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||
@parameterized.parameters({"mapping_type": _CustomMapping}, | |||
{"mapping_type": dict}) | |||
def testFlattenDictItems(self, mapping_type): | |||
dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||
flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||
with self.assertRaises(TypeError): | |||
nest.flatten_dict_items(4) | |||
bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||
with self.assertRaisesRegexp(ValueError, "not unique"): | |||
nest.flatten_dict_items(bad_dictionary) | |||
another_bad_dictionary = mapping_type({ | |||
(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||
}) | |||
with self.assertRaisesRegexp( | |||
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||
nest.flatten_dict_items(another_bad_dictionary) | |||
# pylint does not correctly recognize these as class names and | |||
# suggests to use variable style under_score naming. | |||
# pylint: disable=invalid-name | |||
Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||
Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||
SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||
SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||
SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||
SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||
SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||
NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||
# pylint: enable=invalid-name | |||
class SameNamedType1(SameNameab): | |||
pass | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testAssertSameStructure(self): | |||
structure1 = (((1, 2), 3), 4, (5, 6)) | |||
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
structure_different_num_elements = ("spam", "eggs") | |||
structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||
nest.assert_same_structure(structure1, structure2) | |||
nest.assert_same_structure("abc", 1.0) | |||
nest.assert_same_structure("abc", np.array([0, 1])) | |||
nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||
with self.assertRaisesRegexp( | |||
ValueError, | |||
("The two structures don't have the same nested structure\\.\n\n" | |||
"First structure:.*?\n\n" | |||
"Second structure:.*\n\n" | |||
"More specifically: Substructure " | |||
r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||
'substructure "type=str str=spam" is not\n' | |||
"Entire first structure:\n" | |||
r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||
"Entire second structure:\n" | |||
r"\(\., \.\)")): | |||
nest.assert_same_structure(structure1, structure_different_num_elements) | |||
with self.assertRaisesRegexp( | |||
ValueError, | |||
("The two structures don't have the same nested structure\\.\n\n" | |||
"First structure:.*?\n\n" | |||
"Second structure:.*\n\n" | |||
r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||
"is not")): | |||
nest.assert_same_structure([0, 1], np.array([0, 1])) | |||
with self.assertRaisesRegexp( | |||
ValueError, | |||
("The two structures don't have the same nested structure\\.\n\n" | |||
"First structure:.*?\n\n" | |||
"Second structure:.*\n\n" | |||
r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
'is a sequence, while substructure "type=int str=0" ' | |||
"is not")): | |||
nest.assert_same_structure(0, [0, 1]) | |||
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||
with self.assertRaisesRegexp( | |||
ValueError, | |||
("don't have the same nested structure\\.\n\n" | |||
"First structure: .*?\n\nSecond structure: ")): | |||
nest.assert_same_structure(structure1, structure_different_nesting) | |||
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||
NestTest.Named0ab("a", "b")) | |||
nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
NestTest.Named0ab("a", "b")) | |||
self.assertRaises(TypeError, nest.assert_same_structure, | |||
NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||
with self.assertRaisesRegexp( | |||
ValueError, | |||
("don't have the same nested structure\\.\n\n" | |||
"First structure: .*?\n\nSecond structure: ")): | |||
nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
NestTest.Named0ab([3], 4)) | |||
with self.assertRaisesRegexp( | |||
ValueError, | |||
("don't have the same nested structure\\.\n\n" | |||
"First structure: .*?\n\nSecond structure: ")): | |||
nest.assert_same_structure([[3], 4], [3, [4]]) | |||
structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
with self.assertRaisesRegexp(TypeError, | |||
"don't have the same sequence type"): | |||
nest.assert_same_structure(structure1, structure1_list) | |||
nest.assert_same_structure(structure1, structure2, check_types=False) | |||
nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||
with self.assertRaisesRegexp(ValueError, | |||
"don't have the same set of keys"): | |||
nest.assert_same_structure({"a": 1}, {"b": 1}) | |||
nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||
NestTest.SameNameab2(2, 3)) | |||
# This assertion is expected to pass: two namedtuples with the same | |||
# name and field names are considered to be identical. | |||
nest.assert_same_structure( | |||
NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||
NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||
expected_message = "The two structures don't have the same.*" | |||
with self.assertRaisesRegexp(ValueError, expected_message): | |||
nest.assert_same_structure( | |||
NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||
NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||
self.assertRaises(TypeError, nest.assert_same_structure, | |||
NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||
self.assertRaises(TypeError, nest.assert_same_structure, | |||
NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||
self.assertRaises(TypeError, nest.assert_same_structure, | |||
NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||
EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||
def testHeterogeneousComparison(self): | |||
nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) | |||
nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testMapStructure(self): | |||
structure1 = (((1, 2), 3), 4, (5, 6)) | |||
structure2 = (((7, 8), 9), 10, (11, 12)) | |||
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||
nest.assert_same_structure(structure1, structure1_plus1) | |||
self.assertAllEqual( | |||
[2, 3, 4, 5, 6, 7], | |||
nest.flatten(structure1_plus1)) | |||
structure1_plus_structure2 = nest.map_structure( | |||
lambda x, y: x + y, structure1, structure2) | |||
self.assertEqual( | |||
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||
structure1_plus_structure2) | |||
self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||
self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||
# Empty structures | |||
self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||
self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||
self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||
self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||
NestTest.EmptyNT())) | |||
# This is checking actual equality of types, empty list != empty tuple | |||
self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||
with self.assertRaisesRegexp(TypeError, "callable"): | |||
nest.map_structure("bad", structure1_plus1) | |||
with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||
nest.map_structure(lambda x: x) | |||
with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||
nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||
with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
nest.map_structure(lambda x, y: None, 3, (3,)) | |||
with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||
with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||
structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||
nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||
check_types=False) | |||
with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||
check_types=False) | |||
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
nest.map_structure(lambda x: None, structure1, foo="a") | |||
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||
ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||
def testMapStructureWithStrings(self): | |||
inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||
inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||
out = nest.map_structure(lambda string, repeats: string * repeats, | |||
inp_a, | |||
inp_b) | |||
self.assertEqual("foofoo", out.a) | |||
self.assertEqual("bar", out.b[0]) | |||
self.assertEqual("bazbazbaz", out.b[1]) | |||
nt = NestTest.ABTuple(a=("something", "something_else"), | |||
b="yet another thing") | |||
rev_nt = nest.map_structure(lambda x: x[::-1], nt) | |||
# Check the output is the correct structure, and all strings are reversed. | |||
nest.assert_same_structure(nt, rev_nt) | |||
self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) | |||
self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) | |||
self.assertEqual(nt.b[::-1], rev_nt.b) | |||
@test_util.run_deprecated_v1 | |||
def testMapStructureOverPlaceholders(self): | |||
inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||
nest.assert_same_structure(output, inp_a) | |||
self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||
self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||
feed_dict = { | |||
inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||
inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||
} | |||
with self.cached_session() as sess: | |||
output_np = sess.run(output, feed_dict=feed_dict) | |||
self.assertAllClose(output_np[0], | |||
feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||
self.assertAllClose(output_np[1], | |||
feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||
def testAssertShallowStructure(self): | |||
inp_ab = ["a", "b"] | |||
inp_abc = ["a", "b", "c"] | |||
expected_message = ( | |||
"The two structures don't have the same sequence length. Input " | |||
"structure has length 2, while shallow structure has length 3.") | |||
with self.assertRaisesRegexp(ValueError, expected_message): | |||
nest.assert_shallow_structure(inp_abc, inp_ab) | |||
inp_ab1 = [(1, 1), (2, 2)] | |||
inp_ab2 = [[1, 1], [2, 2]] | |||
expected_message = ( | |||
"The two structures don't have the same sequence type. Input structure " | |||
"has type <(type|class) 'tuple'>, while shallow structure has type " | |||
"<(type|class) 'list'>.") | |||
with self.assertRaisesRegexp(TypeError, expected_message): | |||
nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) | |||
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||
expected_message = ( | |||
r"The two structures don't have the same keys. Input " | |||
r"structure has keys \['c'\], while shallow structure has " | |||
r"keys \['d'\].") | |||
with self.assertRaisesRegexp(ValueError, expected_message): | |||
nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||
inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||
nest.assert_shallow_structure(inp_ab, inp_ba) | |||
# This assertion is expected to pass: two namedtuples with the same | |||
# name and field names are considered to be identical. | |||
inp_shallow = NestTest.SameNameab(1, 2) | |||
inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||
def testFlattenUpTo(self): | |||
# Shallow tree ends at scalar. | |||
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
shallow_tree = [[True, True], [False, True]] | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||
self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||
# Shallow tree ends at string. | |||
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
input_tree) | |||
input_tree_flattened = nest.flatten(input_tree) | |||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
[("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||
# Make sure dicts are correctly flattened, yielding values, not keys. | |||
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
input_tree) | |||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
[1, {"c": 2}, 3, (4, 5)]) | |||
# Namedtuples. | |||
ab_tuple = NestTest.ABTuple | |||
input_tree = ab_tuple(a=[0, 1], b=2) | |||
shallow_tree = ab_tuple(a=0, b=1) | |||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
input_tree) | |||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
[[0, 1], 2]) | |||
# Nested dicts, OrderedDicts and namedtuples. | |||
input_tree = collections.OrderedDict( | |||
[("a", ab_tuple(a=[0, {"b": 1}], b=2)), | |||
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||
shallow_tree = input_tree | |||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
input_tree) | |||
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||
shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
input_tree) | |||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
[ab_tuple(a=[0, {"b": 1}], b=2), | |||
3, | |||
collections.OrderedDict([("f", 4)])]) | |||
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
input_tree) | |||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
[ab_tuple(a=[0, {"b": 1}], b=2), | |||
{"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||
## Shallow non-list edge-case. | |||
# Using iterable elements. | |||
input_tree = ["input_tree"] | |||
shallow_tree = "shallow_tree" | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_input_tree, [input_tree]) | |||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
input_tree = ["input_tree_0", "input_tree_1"] | |||
shallow_tree = "shallow_tree" | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_input_tree, [input_tree]) | |||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
# Using non-iterable elements. | |||
input_tree = [0] | |||
shallow_tree = 9 | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_input_tree, [input_tree]) | |||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
input_tree = [0, 1] | |||
shallow_tree = 9 | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_input_tree, [input_tree]) | |||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
## Both non-list edge-case. | |||
# Using iterable elements. | |||
input_tree = "input_tree" | |||
shallow_tree = "shallow_tree" | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_input_tree, [input_tree]) | |||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
# Using non-iterable elements. | |||
input_tree = 0 | |||
shallow_tree = 0 | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_input_tree, [input_tree]) | |||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
## Input non-list edge-case. | |||
# Using iterable elements. | |||
input_tree = "input_tree" | |||
shallow_tree = ["shallow_tree"] | |||
expected_message = ("If shallow structure is a sequence, input must also " | |||
"be a sequence. Input has type: <(type|class) 'str'>.") | |||
with self.assertRaisesRegexp(TypeError, expected_message): | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
input_tree = "input_tree" | |||
shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||
with self.assertRaisesRegexp(TypeError, expected_message): | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
# Using non-iterable elements. | |||
input_tree = 0 | |||
shallow_tree = [9] | |||
expected_message = ("If shallow structure is a sequence, input must also " | |||
"be a sequence. Input has type: <(type|class) 'int'>.") | |||
with self.assertRaisesRegexp(TypeError, expected_message): | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
input_tree = 0 | |||
shallow_tree = [9, 8] | |||
with self.assertRaisesRegexp(TypeError, expected_message): | |||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
def testMapStructureUpTo(self): | |||
# Named tuples. | |||
ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
inp_val = ab_tuple(a=2, b=3) | |||
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||
out = nest.map_structure_up_to( | |||
inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||
self.assertEqual(out.a, 6) | |||
self.assertEqual(out.b, 15) | |||
# Lists. | |||
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
name_list = ["evens", ["odds", "primes"]] | |||
out = nest.map_structure_up_to( | |||
name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
name_list, data_list) | |||
self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||
# Dicts. | |||
inp_val = dict(a=2, b=3) | |||
inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||
out = nest.map_structure_up_to( | |||
inp_val, | |||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
self.assertEqual(out["a"], 6) | |||
self.assertEqual(out["b"], 15) | |||
# Non-equal dicts. | |||
inp_val = dict(a=2, b=3) | |||
inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||
with self.assertRaisesRegexp(ValueError, "same keys"): | |||
nest.map_structure_up_to( | |||
inp_val, | |||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
# Dict+custom mapping. | |||
inp_val = dict(a=2, b=3) | |||
inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||
out = nest.map_structure_up_to( | |||
inp_val, | |||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
self.assertEqual(out["a"], 6) | |||
self.assertEqual(out["b"], 15) | |||
# Non-equal dict/mapping. | |||
inp_val = dict(a=2, b=3) | |||
inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||
with self.assertRaisesRegexp(ValueError, "same keys"): | |||
nest.map_structure_up_to( | |||
inp_val, | |||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
def testGetTraverseShallowStructure(self): | |||
scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||
scalar_traverse_r = nest.get_traverse_shallow_structure( | |||
lambda s: not isinstance(s, tuple), | |||
scalar_traverse_input) | |||
self.assertEqual(scalar_traverse_r, | |||
[True, True, False, [True, True], {"a": False}, []]) | |||
nest.assert_shallow_structure(scalar_traverse_r, | |||
scalar_traverse_input) | |||
structure_traverse_input = [(1, [2]), ([1], 2)] | |||
structure_traverse_r = nest.get_traverse_shallow_structure( | |||
lambda s: (True, False) if isinstance(s, tuple) else True, | |||
structure_traverse_input) | |||
self.assertEqual(structure_traverse_r, | |||
[(True, False), ([True], False)]) | |||
nest.assert_shallow_structure(structure_traverse_r, | |||
structure_traverse_input) | |||
with self.assertRaisesRegexp(TypeError, "returned structure"): | |||
nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||
with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||
nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||
with self.assertRaisesRegexp( | |||
TypeError, "didn't return a depth=1 structure of bools"): | |||
nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||
def testYieldFlatStringPaths(self): | |||
for inputs_expected in ({"inputs": [], "expected": []}, | |||
{"inputs": 3, "expected": [()]}, | |||
{"inputs": [3], "expected": [(0,)]}, | |||
{"inputs": {"a": 3}, "expected": [("a",)]}, | |||
{"inputs": {"a": {"b": 4}}, | |||
"expected": [("a", "b")]}, | |||
{"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||
{"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||
{"inputs": [{"a": [(23, 42)]}], | |||
"expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||
{"inputs": [{"a": ([23], 42)}], | |||
"expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||
{"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||
"expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||
{"inputs": {"0": [{"1": 23}]}, | |||
"expected": [("0", 0, "1")]}): | |||
inputs = inputs_expected["inputs"] | |||
expected = inputs_expected["expected"] | |||
self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||
def testFlattenWithStringPaths(self): | |||
for inputs_expected in ( | |||
{"inputs": [], "expected": []}, | |||
{"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||
{"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||
inputs = inputs_expected["inputs"] | |||
expected = inputs_expected["expected"] | |||
self.assertEqual( | |||
nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||
expected) | |||
# Need a separate test for namedtuple as we can't declare tuple definitions | |||
# in the @parameterized arguments. | |||
def testFlattenNamedTuple(self): | |||
# pylint: disable=invalid-name | |||
Foo = collections.namedtuple("Foo", ["a", "b"]) | |||
Bar = collections.namedtuple("Bar", ["c", "d"]) | |||
# pylint: enable=invalid-name | |||
test_cases = [ | |||
(Foo(a=3, b=Bar(c=23, d=42)), | |||
[("a", 3), ("b/c", 23), ("b/d", 42)]), | |||
(Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), | |||
[("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||
(Bar(c=42, d=43), | |||
[("c", 42), ("d", 43)]), | |||
(Bar(c=[42], d=43), | |||
[("c/0", 42), ("d", 43)]), | |||
] | |||
for inputs, expected in test_cases: | |||
self.assertEqual( | |||
list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||
@parameterized.named_parameters( | |||
("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||
("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||
{"a": ("a", 4), "b": ("b", 6)}), | |||
("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||
("nested", | |||
{"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||
{"a": [("a/0", 10), ("a/1", 12)], | |||
"b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||
def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||
def format_sum(path, *values): | |||
return (path, sum(values)) | |||
result = nest.map_structure_with_paths(format_sum, s1, s2, | |||
check_types=check_types) | |||
self.assertEqual(expected, result) | |||
@parameterized.named_parameters( | |||
("tuples", (1, 2), (3, 4, 5), ValueError), | |||
("dicts", {"a": 1}, {"b": 2}, ValueError), | |||
("mixed", (1, 2), [3, 4], TypeError), | |||
("nested", | |||
{"a": [2, 3], "b": [1, 3]}, | |||
{"b": [5, 6, 7], "a": [8, 9]}, | |||
ValueError | |||
)) | |||
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||
with self.assertRaises(error_type): | |||
nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) | |||
class NestBenchmark(test.Benchmark): | |||
def run_and_report(self, s1, s2, name): | |||
burn_iter, test_iter = 100, 30000 | |||
for _ in xrange(burn_iter): | |||
nest.assert_same_structure(s1, s2) | |||
t0 = time.time() | |||
for _ in xrange(test_iter): | |||
nest.assert_same_structure(s1, s2) | |||
t1 = time.time() | |||
self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||
name=name) | |||
def benchmark_assert_structure(self): | |||
s1 = (((1, 2), 3), 4, (5, 6)) | |||
s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||
s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||
s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||
self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||
if __name__ == "__main__": | |||
test.main() |