@@ -0,0 +1,871 @@ | |||||
using System; | |||||
using System.Collections; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using NumSharp; | |||||
namespace Tensorflow.Util | |||||
{ | |||||
//Functions for working with arbitrarily nested sequences of elements. | |||||
//This module can perform operations on nested structures. A nested structure is a | |||||
//Python sequence, tuple (including `namedtuple`), or dict that can contain | |||||
//further sequences, tuples, and dicts. | |||||
//The utilities here assume (and do not check) that the nested structures form a | |||||
//'tree', i.e., no references in the structure of the input of these functions | |||||
//should be recursive. | |||||
//Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), | |||||
// (np.array([3, 4]), tf.constant([3, 4])))` | |||||
// | |||||
public static class nest | |||||
{ | |||||
//def _get_attrs_values(obj): | |||||
// """Returns the list of values from an attrs instance.""" | |||||
// attrs = getattr(obj.__class__, "__attrs_attrs__") | |||||
// return [getattr(obj, a.name) for a in attrs] | |||||
/// <summary> | |||||
/// Returns a sorted list of the dict keys, with error if keys not sortable. | |||||
/// </summary> | |||||
private static IEnumerable<string> _sorted(IDictionary dict_) | |||||
{ | |||||
return dict_.Keys.OfType<string>().OrderBy(x => x); | |||||
} | |||||
//def _is_namedtuple(instance, strict=False): | |||||
// """Returns True iff `instance` is a `namedtuple`. | |||||
// Args: | |||||
// instance: An instance of a Python object. | |||||
// strict: If True, `instance` is considered to be a `namedtuple` only if | |||||
// it is a "plain" namedtuple. For instance, a class inheriting | |||||
// from a `namedtuple` will be considered to be a `namedtuple` | |||||
// iff `strict=False`. | |||||
// Returns: | |||||
// True if `instance` is a `namedtuple`. | |||||
// """ | |||||
// return _pywrap_tensorflow.IsNamedtuple(instance, strict) | |||||
//# See the swig file (util.i) for documentation. | |||||
//_is_mapping = _pywrap_tensorflow.IsMapping | |||||
//_is_attrs = _pywrap_tensorflow.IsAttrs | |||||
/// <summary> | |||||
/// Converts the sequence `args` to the same type as `instance`. | |||||
/// </summary> | |||||
/// <param name="instance">an instance of `tuple`, `list`, `namedtuple`, `dict`, or | |||||
/// `collections.OrderedDict`.</param> | |||||
/// <param name="args">elements to be converted to the `instance` type.</param> | |||||
/// <returns>`args` with the type of `instance`.</returns> | |||||
private static object _sequence_like(object instance, IEnumerable<object> args) | |||||
{ | |||||
if (is_mapping(instance)) | |||||
{ | |||||
//# Pack dictionaries in a deterministic order by sorting the keys. | |||||
//# Notice this means that we ignore the original order of `OrderedDict` | |||||
//# instances. This is intentional, to avoid potential bugs caused by mixing | |||||
//# ordered and plain dicts (e.g., flattening a dict but using a | |||||
//# corresponding `OrderedDict` to pack it back). | |||||
// result = dict(zip(_sorted(instance), args)) | |||||
// return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) | |||||
} | |||||
//else if( _is_namedtuple(instance) || _is_attrs(instance)) | |||||
// return type(instance)(*args) | |||||
else | |||||
{ | |||||
// Not a namedtuple | |||||
switch (instance) | |||||
{ | |||||
case object[] array: | |||||
var result_array = new object[args.Count()]; | |||||
int i = 0; | |||||
foreach (var x in args) | |||||
{ | |||||
result_array[i] = x; | |||||
i++; | |||||
} | |||||
return result_array; | |||||
case List<object> list: | |||||
return new List<object>(args); | |||||
default: | |||||
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||||
} | |||||
} | |||||
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||||
} | |||||
/// <summary> | |||||
/// Yields the next value from the given iterable. | |||||
/// </summary> | |||||
private static IEnumerable<object> _yield_value(object iterable) | |||||
{ | |||||
if (is_mapping(iterable)) | |||||
{ | |||||
var dict = iterable as IDictionary; | |||||
//# Iterate through dictionaries in a deterministic order by sorting the | |||||
//# keys. Notice this means that we ignore the original order of `OrderedDict` | |||||
//# instances. This is intentional, to avoid potential bugs caused by mixing | |||||
//# ordered and plain dicts (e.g., flattening a dict but using a | |||||
//# corresponding `OrderedDict` to pack it back). | |||||
foreach (var key in _sorted(dict)) | |||||
yield return dict[key]; | |||||
} | |||||
//else if (_is_attrs(iterable)) | |||||
//{ | |||||
// // for value in _get_attrs_values(iterable): | |||||
// // yield value | |||||
//} | |||||
else if (iterable is IEnumerable) | |||||
{ | |||||
var enumerable = iterable as IEnumerable; | |||||
foreach (var value in enumerable) | |||||
yield return value; | |||||
} | |||||
else | |||||
{ | |||||
throw new TypeError("Unexpected iterable type: " + iterable.GetType()); | |||||
//var jobj = JObject.FromObject(iterable); | |||||
//foreach (var key in _sorted()) | |||||
// yield return jobj[key]; | |||||
} | |||||
} | |||||
//# See the swig file (util.i) for documentation. | |||||
public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string); | |||||
public static bool is_mapping(object arg) => arg is IDictionary; | |||||
//# See the swig file (util.i) for documentation. | |||||
//flatten = _pywrap_tensorflow.Flatten | |||||
public static List<object> flatten(object structure) | |||||
{ | |||||
var list = new List<object>(); | |||||
_flatten_recursive(structure, list); | |||||
return list; | |||||
} | |||||
private static void _flatten_recursive(object obj, List<object> list) | |||||
{ | |||||
if (obj is string) | |||||
{ | |||||
list.Add(obj); | |||||
return; | |||||
} | |||||
if (obj is IDictionary) | |||||
{ | |||||
var dict = obj as IDictionary; | |||||
foreach (var key in _sorted(dict)) | |||||
_flatten_recursive(dict[key], list); | |||||
return; | |||||
} | |||||
if (obj is NDArray) | |||||
{ | |||||
list.Add(obj); | |||||
return; | |||||
} | |||||
if (obj is IEnumerable) | |||||
{ | |||||
var structure = obj as IEnumerable; | |||||
foreach (var child in structure) | |||||
_flatten_recursive(child, list); | |||||
return; | |||||
} | |||||
list.Add(obj); | |||||
} | |||||
//# See the swig file (util.i) for documentation. | |||||
//_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | |||||
//class _DotString(object): | |||||
// def __str__(self): | |||||
// return "." | |||||
// def __repr__(self): | |||||
// return "." | |||||
//_DOT = _DotString() | |||||
//def assert_same_structure(nest1, nest2, check_types=True): | |||||
// """Asserts that two structures are nested in the same way. | |||||
// Note that namedtuples with identical name and fields are always considered | |||||
// to have the same shallow structure (even with `check_types=True`). | |||||
// For intance, this code will print `True`: | |||||
// ```python | |||||
// def nt(a, b): | |||||
// return collections.namedtuple('foo', 'a b')(a, b) | |||||
// print(assert_same_structure(nt(0, 1), nt(2, 3))) | |||||
// ``` | |||||
// Args: | |||||
// nest1: an arbitrarily nested structure. | |||||
// nest2: an arbitrarily nested structure. | |||||
// check_types: if `True` (default) types of sequences are checked as well, | |||||
// including the keys of dictionaries. If set to `False`, for example a | |||||
// list and a tuple of objects will look the same if they have the same | |||||
// size. Note that namedtuples with identical name and fields are always | |||||
// considered to have the same shallow structure. Two types will also be | |||||
// considered the same if they are both list subtypes (which allows "list" | |||||
// and "_ListWrapper" from checkpointable dependency tracking to compare | |||||
// equal). | |||||
// Raises: | |||||
// ValueError: If the two structures do not have the same number of elements or | |||||
// if the two structures are not nested in the same way. | |||||
// TypeError: If the two structures differ in the type of sequence in any of | |||||
// their substructures. Only possible if `check_types` is `True`. | |||||
// """ | |||||
// try: | |||||
// _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types) | |||||
// except (ValueError, TypeError) as e: | |||||
// str1 = str(map_structure(lambda _: _DOT, nest1)) | |||||
// str2 = str(map_structure(lambda _: _DOT, nest2)) | |||||
// raise type(e)("%s\n" | |||||
// "Entire first structure:\n%s\n" | |||||
// "Entire second structure:\n%s" | |||||
// % (str(e), str1, str2)) | |||||
//def flatten_dict_items(dictionary): | |||||
// """Returns a dictionary with flattened keys and values. | |||||
// This function flattens the keys and values of a dictionary, which can be | |||||
// arbitrarily nested structures, and returns the flattened version of such | |||||
// structures: | |||||
// ```python | |||||
// example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} | |||||
// result = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
// flatten_dict_items(example_dictionary) == result | |||||
// ``` | |||||
// The input dictionary must satisfy two properties: | |||||
// 1. Its keys and values should have the same exact nested structure. | |||||
// 2. The set of all flattened keys of the dictionary must not contain repeated | |||||
// keys. | |||||
// Args: | |||||
// dictionary: the dictionary to zip | |||||
// Returns: | |||||
// The zipped dictionary. | |||||
// Raises: | |||||
// TypeError: If the input is not a dictionary. | |||||
// ValueError: If any key and value have not the same structure, or if keys are | |||||
// not unique. | |||||
// """ | |||||
// if not isinstance(dictionary, (dict, _collections.Mapping)): | |||||
// raise TypeError("input must be a dictionary") | |||||
// flat_dictionary = {} | |||||
// for i, v in _six.iteritems(dictionary): | |||||
// if not is_sequence(i): | |||||
// if i in flat_dictionary: | |||||
// raise ValueError( | |||||
// "Could not flatten dictionary: key %s is not unique." % i) | |||||
// flat_dictionary[i] = v | |||||
// else: | |||||
// flat_i = flatten(i) | |||||
// flat_v = flatten(v) | |||||
// if len(flat_i) != len(flat_v): | |||||
// raise ValueError( | |||||
// "Could not flatten dictionary. Key had %d elements, but value had " | |||||
// "%d elements. Key: %s, value: %s." | |||||
// % (len(flat_i), len(flat_v), flat_i, flat_v)) | |||||
// for new_i, new_v in zip(flat_i, flat_v): | |||||
// if new_i in flat_dictionary: | |||||
// raise ValueError( | |||||
// "Could not flatten dictionary: key %s is not unique." | |||||
// % (new_i)) | |||||
// flat_dictionary[new_i] = new_v | |||||
// return flat_dictionary | |||||
/// <summary> | |||||
/// Helper function for pack_sequence_as. | |||||
/// </summary> | |||||
/// <param name="structure">Substructure (list / tuple / dict) to mimic.</param> | |||||
/// <param name="flat">Flattened values to output substructure for.</param> | |||||
/// <param name="index">Index at which to start reading from flat.</param> | |||||
/// <returns> | |||||
/// The tuple(new_index, child), where: | |||||
/// * new_index - the updated index into `flat` having processed `structure`. | |||||
/// * packed - the subset of `flat` corresponding to `structure`, | |||||
/// having started at `index`, and packed into the same nested | |||||
/// format.</returns> | |||||
private static (int new_index, List<object> child) _packed_nest_with_indices(object structure, List<object> flat, | |||||
int index) | |||||
{ | |||||
var packed = new List<object>(); | |||||
foreach (var s in _yield_value(structure)) | |||||
{ | |||||
if (is_sequence(s)) | |||||
{ | |||||
var (new_index, child) = _packed_nest_with_indices(s, flat, index); | |||||
packed.Add(_sequence_like(s, child)); | |||||
index = new_index; | |||||
} | |||||
else | |||||
{ | |||||
packed.Add(flat[index]); | |||||
index += 1; | |||||
} | |||||
} | |||||
return (index, packed); | |||||
} | |||||
private static int len(IEnumerable<object> x) => x.Count(); | |||||
/// <summary> | |||||
/// Returns a given flattened sequence packed into a given structure. | |||||
/// If `structure` is a scalar, `flat_sequence` must be a single-element list; | |||||
/// in this case the return value is `flat_sequence[0]`. | |||||
/// | |||||
/// If `structure` is or contains a dict instance, the keys will be sorted to | |||||
/// pack the flat sequence in deterministic order. This is true also for | |||||
/// `OrderedDict` instances: their sequence order is ignored, the sorting order of | |||||
/// keys is used instead. The same convention is followed in `flatten`. | |||||
/// This correctly repacks dicts and `OrderedDict`s after they have been | |||||
/// flattened, and also allows flattening an `OrderedDict` and then repacking it | |||||
/// back using a corresponding plain dict, or vice-versa. | |||||
/// Dictionaries with non-sortable keys cannot be flattened. | |||||
/// </summary> | |||||
/// <param name="structure"> | |||||
/// Nested structure, whose structure is given by nested lists, | |||||
/// tuples, and dicts. Note: numpy arrays and strings are considered | |||||
/// scalars. | |||||
/// </param> | |||||
/// <param name="flat_sequence"> flat sequence to pack.</param> | |||||
/// <returns> `flat_sequence` converted to have the same recursive structure as | |||||
/// `structure`. | |||||
/// </returns> | |||||
public static object pack_sequence_as(object structure, List<object> flat_sequence) | |||||
{ | |||||
if (flat_sequence == null) | |||||
throw new ArgumentException("flat_sequence must not be null"); | |||||
// if not is_sequence(flat_sequence): | |||||
// raise TypeError("flat_sequence must be a sequence") | |||||
if (!is_sequence(structure)) | |||||
{ | |||||
if (len(flat_sequence) != 1) | |||||
throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1"); | |||||
return flat_sequence.FirstOrDefault(); | |||||
} | |||||
int final_index = 0; | |||||
List<object> packed = null; | |||||
try | |||||
{ | |||||
(final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0); | |||||
if (final_index < len(flat_sequence)) | |||||
throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }"); | |||||
} | |||||
catch (IndexOutOfRangeException) | |||||
{ | |||||
var flat_structure = flatten(structure); | |||||
if (len(flat_structure) != len(flat_sequence)) | |||||
{ | |||||
throw new ValueError("Could not pack sequence. Structure had %d elements, but " + | |||||
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}"); | |||||
} | |||||
return _sequence_like(structure, packed); | |||||
} | |||||
return packed; | |||||
} | |||||
/// <summary> | |||||
/// Applies `func` to each entry in `structure` and returns a new structure. | |||||
/// | |||||
/// Applies `func(x[0], x[1], ...)` where x[i] is an entry in | |||||
/// `structure[i]`. All structures in `structure` must have the same arity, | |||||
/// and the return value will contain the results in the same structure. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <typeparam name="U"></typeparam> | |||||
/// <param name="func"> A callable that accepts as many arguments as there are structures.</param> | |||||
/// <param name="structure">scalar, or tuple or list of constructed scalars and/or other | |||||
/// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param> | |||||
/// <param name="check_types">If set to | |||||
/// `True` (default) the types of iterables within the structures have to be | |||||
/// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` | |||||
/// exception). To allow this set this argument to `False`. | |||||
/// Note that namedtuples with identical name and fields are always | |||||
/// considered to have the same shallow structure.</param> | |||||
/// <returns> | |||||
/// A new structure with the same arity as `structure`, whose values correspond | |||||
/// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding | |||||
/// location in `structure[i]`. If there are different sequence types and | |||||
/// `check_types` is `False` the sequence types of the first structure will be | |||||
/// used. | |||||
/// </returns> | |||||
public static IEnumerable<U> map_structure<T, U>(Func<T, U> func, IEnumerable<T> structure, bool check_types = false) | |||||
{ | |||||
// for other in structure[1:]: | |||||
// assert_same_structure(structure[0], other, check_types=check_types) | |||||
// flat_structure = [flatten(s) for s in structure] | |||||
// entries = zip(*flat_structure) | |||||
// return pack_sequence_as( | |||||
// structure[0], [func(*x) for x in entries]) | |||||
return null; | |||||
} | |||||
//def map_structure_with_paths(func, *structure, **kwargs): | |||||
// """Applies `func` to each entry in `structure` and returns a new structure. | |||||
// Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in | |||||
// `structure[i]` and `path` is the common path to x[i] in the structures. All | |||||
// structures in `structure` must have the same arity, and the return value will | |||||
// contain the results in the same structure. Special kwarg `check_types` | |||||
// determines whether the types of iterables within the structure must be the | |||||
// same-- see **kwargs definition below. | |||||
// Args: | |||||
// func: A callable with the signature func(path, *values, **kwargs) that is | |||||
// evaluated on the leaves of the structure. | |||||
// *structure: A variable number of compatible structures to process. | |||||
// **kwargs: Optional kwargs to be passed through to func. Special kwarg | |||||
// `check_types` is not passed to func, but instead determines whether the | |||||
// types of iterables within the structures have to be same (e.g., | |||||
// `map_structure(func, [1], (1,))` raises a `TypeError` exception). By | |||||
// default, the types must match. To allow iteration over structures of | |||||
// different types (but common arity), set this kwarg to `False`. | |||||
// Returns: | |||||
// A structure of the same form as the input structures whose leaves are the | |||||
// result of evaluating func on corresponding leaves of the input structures. | |||||
// Raises: | |||||
// TypeError: If `func` is not callable or if the structures do not match | |||||
// each other by depth tree. | |||||
// TypeError: If `check_types` is not `False` and the two structures differ in | |||||
// the type of sequence in any of their substructures. | |||||
// ValueError: If no structures are provided. | |||||
// """ | |||||
// if not callable(func): | |||||
// raise TypeError("func must be callable, got: %s" % func) | |||||
// if not structure: | |||||
// raise ValueError("Must provide at least one structure") | |||||
// check_types = kwargs.pop("check_types", True) | |||||
// for other in structure[1:]: | |||||
// assert_same_structure(structure[0], other, check_types=check_types) | |||||
//# First set paths_and_values to: | |||||
//# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] | |||||
// paths_and_values = [flatten_with_joined_string_paths(s) for s in structure] | |||||
//# Now zip(*paths_and_values) would be: | |||||
//# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] | |||||
//# so grouped_by_path is set to: | |||||
//# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] | |||||
//# Note that p1i, ... pmi must all be equal since the structures are the same. | |||||
// grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] | |||||
// return pack_sequence_as(structure[0], [ | |||||
// func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) | |||||
//def _yield_flat_up_to(shallow_tree, input_tree): | |||||
// """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" | |||||
// if is_sequence(shallow_tree): | |||||
// for shallow_branch, input_branch in zip(_yield_value(shallow_tree), | |||||
// _yield_value(input_tree)): | |||||
// for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): | |||||
// yield input_leaf | |||||
// else: | |||||
// yield input_tree | |||||
//def assert_shallow_structure(shallow_tree, input_tree, check_types=True): | |||||
// """Asserts that `shallow_tree` is a shallow structure of `input_tree`. | |||||
// That is, this function tests if the `input_tree` structure can be created from | |||||
// the `shallow_tree` structure by replacing its leaf nodes with deeper | |||||
// tree structures. | |||||
// Examples: | |||||
// The following code will raise an exception: | |||||
// ```python | |||||
// shallow_tree = ["a", "b"] | |||||
// input_tree = ["c", ["d", "e"], "f"] | |||||
// assert_shallow_structure(shallow_tree, input_tree) | |||||
// ``` | |||||
// The following code will not raise an exception: | |||||
// ```python | |||||
// shallow_tree = ["a", "b"] | |||||
// input_tree = ["c", ["d", "e"]] | |||||
// assert_shallow_structure(shallow_tree, input_tree) | |||||
// ``` | |||||
// Args: | |||||
// shallow_tree: an arbitrarily nested structure. | |||||
// input_tree: an arbitrarily nested structure. | |||||
// check_types: if `True` (default) the sequence types of `shallow_tree` and | |||||
// `input_tree` have to be the same. Note that even with check_types==True, | |||||
// this function will consider two different namedtuple classes with the same | |||||
// name and _fields attribute to be the same class. | |||||
// Raises: | |||||
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||||
// TypeError: If the sequence types of `shallow_tree` are different from | |||||
// `input_tree`. Only raised if `check_types` is `True`. | |||||
// ValueError: If the sequence lengths of `shallow_tree` are different from | |||||
// `input_tree`. | |||||
// """ | |||||
// if is_sequence(shallow_tree): | |||||
// if not is_sequence(input_tree): | |||||
// raise TypeError( | |||||
// "If shallow structure is a sequence, input must also be a sequence. " | |||||
// "Input has type: %s." % type(input_tree)) | |||||
// if check_types and not isinstance(input_tree, type(shallow_tree)): | |||||
//# Duck-typing means that nest should be fine with two different | |||||
//# namedtuples with identical name and fields. | |||||
// shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) | |||||
// input_is_namedtuple = _is_namedtuple(input_tree, False) | |||||
// if shallow_is_namedtuple and input_is_namedtuple: | |||||
// if not _same_namedtuples(shallow_tree, input_tree): | |||||
// raise TypeError( | |||||
// "The two namedtuples don't have the same sequence type. Input " | |||||
// "structure has type %s, while shallow structure has type %s." | |||||
// % (type(input_tree), type(shallow_tree))) | |||||
// elif not (isinstance(shallow_tree, _collections.Mapping) | |||||
// and isinstance(input_tree, _collections.Mapping)): | |||||
// raise TypeError( | |||||
// "The two structures don't have the same sequence type. Input " | |||||
// "structure has type %s, while shallow structure has type %s." | |||||
// % (type(input_tree), type(shallow_tree))) | |||||
// if len(input_tree) != len(shallow_tree): | |||||
// raise ValueError( | |||||
// "The two structures don't have the same sequence length. Input " | |||||
// "structure has length %s, while shallow structure has length %s." | |||||
// % (len(input_tree), len(shallow_tree))) | |||||
// if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): | |||||
// if set(input_tree) != set(shallow_tree): | |||||
// raise ValueError( | |||||
// "The two structures don't have the same keys. Input " | |||||
// "structure has keys %s, while shallow structure has keys %s." % | |||||
// (list(_six.iterkeys(input_tree)), | |||||
// list(_six.iterkeys(shallow_tree)))) | |||||
// input_tree = list(sorted(_six.iteritems(input_tree))) | |||||
// shallow_tree = list(sorted(_six.iteritems(shallow_tree))) | |||||
// for shallow_branch, input_branch in zip(shallow_tree, input_tree): | |||||
// assert_shallow_structure(shallow_branch, input_branch, | |||||
// check_types=check_types) | |||||
//def flatten_up_to(shallow_tree, input_tree): | |||||
// """Flattens `input_tree` up to `shallow_tree`. | |||||
// Any further depth in structure in `input_tree` is retained as elements in the | |||||
// partially flatten output. | |||||
// If `shallow_tree` and `input_tree` are not sequences, this returns a | |||||
// single-element list: `[input_tree]`. | |||||
// Use Case: | |||||
// Sometimes we may wish to partially flatten a nested sequence, retaining some | |||||
// of the nested structure. We achieve this by specifying a shallow structure, | |||||
// `shallow_tree`, we wish to flatten up to. | |||||
// The input, `input_tree`, can be thought of as having the same structure as | |||||
// `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||||
// Examples: | |||||
// ```python | |||||
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
// shallow_tree = [[True, True], [False, True]] | |||||
// flattened_input_tree = flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) | |||||
//# Output is: | |||||
//# [[2, 2], [3, 3], [4, 9], [5, 5]] | |||||
//# [True, True, False, True] | |||||
// ``` | |||||
// ```python | |||||
// input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] | |||||
// shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] | |||||
// input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) | |||||
// input_tree_flattened = flatten(input_tree) | |||||
//# Output is: | |||||
//# [('a', 1), ('b', 2), ('c', 3), ('d', 4)] | |||||
//# ['a', 1, 'b', 2, 'c', 3, 'd', 4] | |||||
// ``` | |||||
// Non-Sequence Edge Cases: | |||||
// ```python | |||||
// flatten_up_to(0, 0) # Output: [0] | |||||
// flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] | |||||
// flatten_up_to([0, 1, 2], 0) # Output: TypeError | |||||
// flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] | |||||
// ``` | |||||
// Args: | |||||
// shallow_tree: a possibly pruned structure of input_tree. | |||||
// input_tree: an arbitrarily nested structure or a scalar object. | |||||
// Note, numpy arrays are considered scalars. | |||||
// Returns: | |||||
// A Python list, the partially flattened version of `input_tree` according to | |||||
// the structure of `shallow_tree`. | |||||
// Raises: | |||||
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||||
// TypeError: If the sequence types of `shallow_tree` are different from | |||||
// `input_tree`. | |||||
// ValueError: If the sequence lengths of `shallow_tree` are different from | |||||
// `input_tree`. | |||||
// """ | |||||
// assert_shallow_structure(shallow_tree, input_tree) | |||||
// return list(_yield_flat_up_to(shallow_tree, input_tree)) | |||||
//def map_structure_up_to(shallow_tree, func, *inputs): | |||||
// """Applies a function or op to a number of partially flattened inputs. | |||||
// The `inputs` are flattened up to `shallow_tree` before being mapped. | |||||
// Use Case: | |||||
// Sometimes we wish to apply a function to a partially flattened | |||||
// sequence (for example when the function itself takes sequence inputs). We | |||||
// achieve this by specifying a shallow structure, `shallow_tree` we wish to | |||||
// flatten up to. | |||||
// The `inputs`, can be thought of as having the same structure as | |||||
// `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||||
// This function therefore will return something with the same base structure as | |||||
// `shallow_tree`. | |||||
// Examples: | |||||
// ```python | |||||
// ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
// op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
// inp_val = ab_tuple(a=2, b=3) | |||||
// inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||||
// out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, | |||||
// inp_val, inp_ops) | |||||
//# Output is: ab_tuple(a=6, b=15) | |||||
// ``` | |||||
// ```python | |||||
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
// name_list = ['evens', ['odds', 'primes']] | |||||
// out = map_structure_up_to( | |||||
// name_list, | |||||
// lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
// name_list, data_list) | |||||
//# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] | |||||
// ``` | |||||
// Args: | |||||
// shallow_tree: a shallow tree, common to all the inputs. | |||||
// func: callable which will be applied to each input individually. | |||||
// *inputs: arbitrarily nested combination of objects that are compatible with | |||||
// shallow_tree. The function `func` is applied to corresponding | |||||
// partially flattened elements of each input, so the function must support | |||||
// arity of `len(inputs)`. | |||||
// Raises: | |||||
// TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||||
// TypeError: If the sequence types of `shallow_tree` are different from | |||||
// `input_tree`. | |||||
// ValueError: If the sequence lengths of `shallow_tree` are different from | |||||
// `input_tree`. | |||||
// Returns: | |||||
// result of repeatedly applying `func`, with same structure as | |||||
// `shallow_tree`. | |||||
// """ | |||||
// if not inputs: | |||||
// raise ValueError("Cannot map over no sequences") | |||||
// for input_tree in inputs: | |||||
// assert_shallow_structure(shallow_tree, input_tree) | |||||
//# Flatten each input separately, apply the function to corresponding elements, | |||||
//# then repack based on the structure of the first input. | |||||
// all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) | |||||
// for input_tree in inputs] | |||||
// results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] | |||||
// return pack_sequence_as(structure=shallow_tree, flat_sequence=results) | |||||
//def get_traverse_shallow_structure(traverse_fn, structure): | |||||
// """Generates a shallow structure from a `traverse_fn` and `structure`. | |||||
// `traverse_fn` must accept any possible subtree of `structure` and return | |||||
// a depth=1 structure containing `True` or `False` values, describing which | |||||
// of the top-level subtrees may be traversed. It may also | |||||
// return scalar `True` or `False` "traversal is OK / not OK for all subtrees." | |||||
// Examples are available in the unit tests (nest_test.py). | |||||
// Args: | |||||
// traverse_fn: Function taking a substructure and returning either a scalar | |||||
// `bool` (whether to traverse that substructure or not) or a depth=1 | |||||
// shallow structure of the same type, describing which parts of the | |||||
// substructure to traverse. | |||||
// structure: The structure to traverse. | |||||
// Returns: | |||||
// A shallow structure containing python bools, which can be passed to | |||||
// `map_structure_up_to` and `flatten_up_to`. | |||||
// Raises: | |||||
// TypeError: if `traverse_fn` returns a sequence for a non-sequence input, | |||||
// or a structure with depth higher than 1 for a sequence input, | |||||
// or if any leaf values in the returned structure or scalar are not type | |||||
// `bool`. | |||||
// """ | |||||
// to_traverse = traverse_fn(structure) | |||||
// if not is_sequence(structure): | |||||
// if not isinstance(to_traverse, bool): | |||||
// raise TypeError("traverse_fn returned structure: %s for non-structure: %s" | |||||
// % (to_traverse, structure)) | |||||
// return to_traverse | |||||
// level_traverse = [] | |||||
// if isinstance(to_traverse, bool): | |||||
// if not to_traverse: | |||||
//# Do not traverse this substructure at all. Exit early. | |||||
// return False | |||||
// else: | |||||
//# Traverse the entire substructure. | |||||
// for branch in _yield_value(structure): | |||||
// level_traverse.append( | |||||
// get_traverse_shallow_structure(traverse_fn, branch)) | |||||
// elif not is_sequence(to_traverse): | |||||
// raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" | |||||
// % (to_traverse, structure)) | |||||
// else: | |||||
//# Traverse some subset of this substructure. | |||||
// assert_shallow_structure(to_traverse, structure) | |||||
// for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): | |||||
// if not isinstance(t, bool): | |||||
// raise TypeError( | |||||
// "traverse_fn didn't return a depth=1 structure of bools. saw: %s " | |||||
// " for structure: %s" % (to_traverse, structure)) | |||||
// if t: | |||||
// level_traverse.append( | |||||
// get_traverse_shallow_structure(traverse_fn, branch)) | |||||
// else: | |||||
// level_traverse.append(False) | |||||
// return _sequence_like(structure, level_traverse) | |||||
//def yield_flat_paths(nest): | |||||
// """Yields paths for some nested structure. | |||||
// Paths are lists of objects which can be str-converted, which may include | |||||
// integers or other types which are used as indices in a dict. | |||||
// The flat list will be in the corresponding order as if you called | |||||
// `snt.nest.flatten` on the structure. This is handy for naming Tensors such | |||||
// the TF scope structure matches the tuple structure. | |||||
// E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` | |||||
// ```shell | |||||
// >>> nest.flatten(value) | |||||
// [3, 23, 42] | |||||
// >>> list(nest.yield_flat_paths(value)) | |||||
// [('a',), ('b', 'c'), ('b', 'd')] | |||||
// ``` | |||||
// ```shell | |||||
// >>> list(nest.yield_flat_paths({'a': [3]})) | |||||
// [('a', 0)] | |||||
// >>> list(nest.yield_flat_paths({'a': 3})) | |||||
// [('a',)] | |||||
// ``` | |||||
// Args: | |||||
// nest: the value to produce a flattened paths list for. | |||||
// Yields: | |||||
// Tuples containing index or key values which form the path to a specific | |||||
// leaf value in the nested structure. | |||||
// """ | |||||
//# The _maybe_add_final_path_element function is used below in order to avoid | |||||
//# adding trailing slashes when the sub-element recursed into is a leaf. | |||||
// if isinstance(nest, (dict, _collections.Mapping)): | |||||
// for key in _sorted(nest): | |||||
// value = nest[key] | |||||
// for sub_path in yield_flat_paths(value): | |||||
// yield (key,) + sub_path | |||||
// elif _is_namedtuple(nest): | |||||
// for key in nest._fields: | |||||
// value = getattr(nest, key) | |||||
// for sub_path in yield_flat_paths(value): | |||||
// yield (key,) + sub_path | |||||
// elif isinstance(nest, _six.string_types): | |||||
// yield () | |||||
// elif isinstance(nest, _collections.Sequence): | |||||
// for idx, value in enumerate(nest): | |||||
// for sub_path in yield_flat_paths(value): | |||||
// yield (idx,) + sub_path | |||||
// else: | |||||
// yield () | |||||
//def flatten_with_joined_string_paths(structure, separator="/"): | |||||
// """Returns a list of (string path, data element) tuples. | |||||
// The order of tuples produced matches that of `nest.flatten`. This allows you | |||||
// to flatten a nested structure while keeping information about where in the | |||||
// structure each data element was located. See `nest.yield_flat_paths` | |||||
// for more information. | |||||
// Args: | |||||
// structure: the nested structure to flatten. | |||||
// separator: string to separate levels of hierarchy in the results, defaults | |||||
// to '/'. | |||||
// Returns: | |||||
// A list of (string, data element) tuples. | |||||
// """ | |||||
// flat_paths = yield_flat_paths(structure) | |||||
// def stringify_and_join(path_elements): | |||||
// return separator.join(str(path_element) for path_element in path_elements) | |||||
// flat_string_paths = [stringify_and_join(path) for path in flat_paths] | |||||
// return list(zip(flat_string_paths, flatten(structure))) | |||||
} | |||||
} |
@@ -0,0 +1,852 @@ | |||||
using System.Collections; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Newtonsoft.Json.Linq; | |||||
using Tensorflow; | |||||
using Tensorflow.Util; | |||||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
{ | |||||
/// <summary> | |||||
/// excerpt of tensorflow/python/framework/util/nest_test.py | |||||
/// </summary> | |||||
[TestClass] | |||||
public class NestTest : PythonTest | |||||
{ | |||||
public class PointXY | |||||
{ | |||||
public double x; | |||||
public double y; | |||||
} | |||||
// if attr: | |||||
// class BadAttr(object): | |||||
// """Class that has a non-iterable __attrs_attrs__.""" | |||||
// __attrs_attrs__ = None | |||||
// @attr.s | |||||
// class SampleAttr(object): | |||||
// field1 = attr.ib() | |||||
// field2 = attr.ib() | |||||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
// def testAttrsFlattenAndPack(self) : | |||||
// if attr is None: | |||||
// self.skipTest("attr module is unavailable.") | |||||
// field_values = [1, 2] | |||||
// sample_attr = NestTest.SampleAttr(* field_values) | |||||
// self.assertFalse(nest._is_attrs(field_values)) | |||||
// self.assertTrue(nest._is_attrs(sample_attr)) | |||||
// flat = nest.flatten(sample_attr) | |||||
// self.assertEqual(field_values, flat) | |||||
// restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||||
// self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||||
// self.assertEqual(restructured_from_flat, sample_attr) | |||||
//# Check that flatten fails if attributes are not iterable | |||||
// with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||||
// flat = nest.flatten(NestTest.BadAttr()) | |||||
[TestMethod] | |||||
public void testFlattenAndPack() | |||||
{ | |||||
object structure = new object[] {new object[] {3, 4}, 5, new object[] {6, 7, new object[] {9, 10}, 8}}; | |||||
var flat = new List<object> {"a", "b", "c", "d", "e", "f", "g", "h"}; | |||||
self.assertEqual(nest.flatten(structure), new[] {3, 4, 5, 6, 7, 9, 10, 8}); | |||||
self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(), | |||||
JArray.FromObject(new object[] {new object[] {"a", "b"}, "c", new object[] {"d", "e", new object[] {"f", "g"}, "h"}}).ToString()); | |||||
structure = new object[] { new Hashtable {["x"] = 4, ["y"] = 2}, new object[] { new object[] { new Hashtable { ["x"] = 1,["y"] = 0}, }, }}; | |||||
flat = new List<object> { 4, 2, 1, 0}; | |||||
self.assertEqual(nest.flatten(structure), flat); | |||||
// restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||||
// self.assertEqual(restructured_from_flat, structure) | |||||
// self.assertEqual(restructured_from_flat[0].x, 4) | |||||
// self.assertEqual(restructured_from_flat[0].y, 2) | |||||
// self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||||
// self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||||
// self.assertEqual([5], nest.flatten(5)) | |||||
// self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||||
// self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||||
// self.assertEqual( | |||||
// np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||||
// with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||||
// nest.pack_sequence_as("scalar", [4, 5]) | |||||
// with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||||
// nest.pack_sequence_as([4, 5], "bad_sequence") | |||||
// with self.assertRaises(ValueError): | |||||
// nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||||
} | |||||
// @parameterized.parameters({"mapping_type": collections.OrderedDict | |||||
// }, | |||||
// {"mapping_type": _CustomMapping | |||||
//}) | |||||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
// def testFlattenDictOrder(self, mapping_type) : | |||||
// """`flatten` orders dicts by key, including OrderedDicts.""" | |||||
// ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||||
// plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||||
// ordered_flat = nest.flatten(ordered) | |||||
// plain_flat = nest.flatten(plain) | |||||
// self.assertEqual([0, 1, 2, 3], ordered_flat) | |||||
// self.assertEqual([0, 1, 2, 3], plain_flat) | |||||
// @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
// {"mapping_type": _CustomMapping}) | |||||
// def testPackDictOrder(self, mapping_type): | |||||
// """Packing orders dicts by key, including OrderedDicts.""" | |||||
// custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||||
// plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||||
// seq = [0, 1, 2, 3] | |||||
//custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||||
//plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||||
// self.assertIsInstance(custom_reconstruction, mapping_type) | |||||
// self.assertIsInstance(plain_reconstruction, dict) | |||||
// self.assertEqual( | |||||
// mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||||
// custom_reconstruction) | |||||
// self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||||
// Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
// def testFlattenAndPack_withDicts(self) : | |||||
// # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||||
// mess = [ | |||||
// "z", | |||||
// NestTest.Abc(3, 4), { | |||||
// "d": _CustomMapping({ | |||||
// 41: 4 | |||||
// }), | |||||
// "c": [ | |||||
// 1, | |||||
// collections.OrderedDict([ | |||||
// ("b", 3), | |||||
// ("a", 2), | |||||
// ]), | |||||
// ], | |||||
// "b": 5 | |||||
// }, 17 | |||||
// ] | |||||
// flattened = nest.flatten(mess) | |||||
// self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||||
// structure_of_mess = [ | |||||
// 14, | |||||
// NestTest.Abc("a", True), | |||||
// { | |||||
// "d": _CustomMapping({ | |||||
// 41: 42 | |||||
// }), | |||||
// "c": [ | |||||
// 0, | |||||
// collections.OrderedDict([ | |||||
// ("b", 9), | |||||
// ("a", 8), | |||||
// ]), | |||||
// ], | |||||
// "b": 3 | |||||
// }, | |||||
// "hi everybody", | |||||
// ] | |||||
// unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||||
// self.assertEqual(unflattened, mess) | |||||
// # Check also that the OrderedDict was created, with the correct key order. | |||||
//unflattened_ordered_dict = unflattened[2]["c"][1] | |||||
// self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||||
// self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||||
// unflattened_custom_mapping = unflattened[2]["d"] | |||||
// self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||||
// self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||||
// def testFlatten_numpyIsNotFlattened(self): | |||||
// structure = np.array([1, 2, 3]) | |||||
// flattened = nest.flatten(structure) | |||||
// self.assertEqual(len(flattened), 1) | |||||
// def testFlatten_stringIsNotFlattened(self): | |||||
// structure = "lots of letters" | |||||
// flattened = nest.flatten(structure) | |||||
// self.assertEqual(len(flattened), 1) | |||||
// unflattened = nest.pack_sequence_as("goodbye", flattened) | |||||
// self.assertEqual(structure, unflattened) | |||||
// def testPackSequenceAs_notIterableError(self) : | |||||
// with self.assertRaisesRegexp(TypeError, | |||||
// "flat_sequence must be a sequence"): | |||||
// nest.pack_sequence_as("hi", "bye") | |||||
// def testPackSequenceAs_wrongLengthsError(self): | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, | |||||
// "Structure had 2 elements, but flat_sequence had 3 elements."): | |||||
// nest.pack_sequence_as(["hello", "world"], | |||||
// ["and", "goodbye", "again"]) | |||||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
// def testIsSequence(self): | |||||
// self.assertFalse(nest.is_sequence("1234")) | |||||
// self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||||
// self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||||
// self.assertTrue(nest.is_sequence([])) | |||||
// self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||||
// self.assertFalse(nest.is_sequence(set([1, 2]))) | |||||
// ones = array_ops.ones([2, 3]) | |||||
// self.assertFalse(nest.is_sequence(ones)) | |||||
// self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||||
// self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||||
// @parameterized.parameters({"mapping_type": _CustomMapping}, | |||||
// {"mapping_type": dict}) | |||||
// def testFlattenDictItems(self, mapping_type): | |||||
// dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||||
// flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
// self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||||
// with self.assertRaises(TypeError): | |||||
// nest.flatten_dict_items(4) | |||||
// bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||||
// with self.assertRaisesRegexp(ValueError, "not unique"): | |||||
// nest.flatten_dict_items(bad_dictionary) | |||||
// another_bad_dictionary = mapping_type({ | |||||
// (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||||
// }) | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||||
// nest.flatten_dict_items(another_bad_dictionary) | |||||
//# pylint does not correctly recognize these as class names and | |||||
//# suggests to use variable style under_score naming. | |||||
//# pylint: disable=invalid-name | |||||
// Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||||
// Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||||
// SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||||
// SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||||
// SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||||
// SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||||
// SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||||
// NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||||
// # pylint: enable=invalid-name | |||||
// class SameNamedType1(SameNameab): | |||||
// pass | |||||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
// def testAssertSameStructure(self): | |||||
// structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
// structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
// structure_different_num_elements = ("spam", "eggs") | |||||
// structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||||
// nest.assert_same_structure(structure1, structure2) | |||||
// nest.assert_same_structure("abc", 1.0) | |||||
// nest.assert_same_structure("abc", np.array([0, 1])) | |||||
// nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, | |||||
// ("The two structures don't have the same nested structure\\.\n\n" | |||||
// "First structure:.*?\n\n" | |||||
// "Second structure:.*\n\n" | |||||
// "More specifically: Substructure " | |||||
// r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||||
// 'substructure "type=str str=spam" is not\n' | |||||
// "Entire first structure:\n" | |||||
// r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||||
// "Entire second structure:\n" | |||||
// r"\(\., \.\)")): | |||||
// nest.assert_same_structure(structure1, structure_different_num_elements) | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, | |||||
// ("The two structures don't have the same nested structure\\.\n\n" | |||||
// "First structure:.*?\n\n" | |||||
// "Second structure:.*\n\n" | |||||
// r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
// r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||||
// "is not")): | |||||
// nest.assert_same_structure([0, 1], np.array([0, 1])) | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, | |||||
// ("The two structures don't have the same nested structure\\.\n\n" | |||||
// "First structure:.*?\n\n" | |||||
// "Second structure:.*\n\n" | |||||
// r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
// 'is a sequence, while substructure "type=int str=0" ' | |||||
// "is not")): | |||||
// nest.assert_same_structure(0, [0, 1]) | |||||
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, | |||||
// ("don't have the same nested structure\\.\n\n" | |||||
// "First structure: .*?\n\nSecond structure: ")): | |||||
// nest.assert_same_structure(structure1, structure_different_nesting) | |||||
// self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||||
// NestTest.Named0ab("a", "b")) | |||||
// nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
// NestTest.Named0ab("a", "b")) | |||||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||||
// NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, | |||||
// ("don't have the same nested structure\\.\n\n" | |||||
// "First structure: .*?\n\nSecond structure: ")): | |||||
// nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
// NestTest.Named0ab([3], 4)) | |||||
// with self.assertRaisesRegexp( | |||||
// ValueError, | |||||
// ("don't have the same nested structure\\.\n\n" | |||||
// "First structure: .*?\n\nSecond structure: ")): | |||||
// nest.assert_same_structure([[3], 4], [3, [4]]) | |||||
// structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
// with self.assertRaisesRegexp(TypeError, | |||||
// "don't have the same sequence type"): | |||||
// nest.assert_same_structure(structure1, structure1_list) | |||||
// nest.assert_same_structure(structure1, structure2, check_types= False) | |||||
// nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||||
// with self.assertRaisesRegexp(ValueError, | |||||
// "don't have the same set of keys"): | |||||
// nest.assert_same_structure({"a": 1}, {"b": 1}) | |||||
// nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||||
// NestTest.SameNameab2(2, 3)) | |||||
// # This assertion is expected to pass: two namedtuples with the same | |||||
// # name and field names are considered to be identical. | |||||
// nest.assert_same_structure( | |||||
// NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||||
// NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||||
// expected_message = "The two structures don't have the same.*" | |||||
// with self.assertRaisesRegexp(ValueError, expected_message): | |||||
// nest.assert_same_structure( | |||||
// NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||||
// NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||||
// NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||||
// NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||||
// self.assertRaises(TypeError, nest.assert_same_structure, | |||||
// NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||||
// EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||||
// def testHeterogeneousComparison(self): | |||||
// nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) | |||||
// nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
// def testMapStructure(self) : | |||||
// structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
// structure2 = (((7, 8), 9), 10, (11, 12)) | |||||
// structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||||
// nest.assert_same_structure(structure1, structure1_plus1) | |||||
// self.assertAllEqual( | |||||
// [2, 3, 4, 5, 6, 7], | |||||
// nest.flatten(structure1_plus1)) | |||||
// structure1_plus_structure2 = nest.map_structure( | |||||
// lambda x, y: x + y, structure1, structure2) | |||||
// self.assertEqual( | |||||
// (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||||
// structure1_plus_structure2) | |||||
// self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||||
// self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||||
// # Empty structures | |||||
// self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||||
// self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||||
// self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||||
// self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||||
// NestTest.EmptyNT())) | |||||
// # This is checking actual equality of types, empty list != empty tuple | |||||
// self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||||
// with self.assertRaisesRegexp(TypeError, "callable"): | |||||
// nest.map_structure("bad", structure1_plus1) | |||||
// with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||||
// nest.map_structure(lambda x: x) | |||||
// with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||||
// nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||||
// with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
// nest.map_structure(lambda x, y: None, 3, (3,)) | |||||
// with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
// nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||||
// with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
// nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||||
// structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
// with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
// nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||||
// nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||||
// check_types=False) | |||||
// with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
// nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||||
// check_types=False) | |||||
// with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
// nest.map_structure(lambda x: None, structure1, foo="a") | |||||
// with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
// nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||||
// ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||||
// @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
// def testMapStructureWithStrings(self) : | |||||
// inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||||
// inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||||
// out = nest.map_structure(lambda string, repeats: string* repeats, | |||||
// inp_a, | |||||
// inp_b) | |||||
// self.assertEqual("foofoo", out.a) | |||||
// self.assertEqual("bar", out.b[0]) | |||||
// self.assertEqual("bazbazbaz", out.b[1]) | |||||
// nt = NestTest.ABTuple(a=("something", "something_else"), | |||||
// b="yet another thing") | |||||
// rev_nt = nest.map_structure(lambda x: x[::- 1], nt) | |||||
// # Check the output is the correct structure, and all strings are reversed. | |||||
// nest.assert_same_structure(nt, rev_nt) | |||||
// self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) | |||||
// self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) | |||||
// self.assertEqual(nt.b[::- 1], rev_nt.b) | |||||
// @test_util.run_deprecated_v1 | |||||
// def testMapStructureOverPlaceholders(self) : | |||||
// inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
// array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
// inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
// array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
// output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||||
// nest.assert_same_structure(output, inp_a) | |||||
// self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||||
// self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||||
// feed_dict = { | |||||
// inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||||
// inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||||
// } | |||||
// with self.cached_session() as sess: | |||||
// output_np = sess.run(output, feed_dict=feed_dict) | |||||
// self.assertAllClose(output_np[0], | |||||
// feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||||
// self.assertAllClose(output_np[1], | |||||
// feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||||
// def testAssertShallowStructure(self): | |||||
// inp_ab = ["a", "b"] | |||||
//inp_abc = ["a", "b", "c"] | |||||
//expected_message = ( | |||||
// "The two structures don't have the same sequence length. Input " | |||||
// "structure has length 2, while shallow structure has length 3.") | |||||
// with self.assertRaisesRegexp(ValueError, expected_message): | |||||
// nest.assert_shallow_structure(inp_abc, inp_ab) | |||||
// inp_ab1 = [(1, 1), (2, 2)] | |||||
// inp_ab2 = [[1, 1], [2, 2]] | |||||
// expected_message = ( | |||||
// "The two structures don't have the same sequence type. Input structure " | |||||
// "has type <(type|class) 'tuple'>, while shallow structure has type " | |||||
// "<(type|class) 'list'>.") | |||||
// with self.assertRaisesRegexp(TypeError, expected_message): | |||||
// nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
// nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) | |||||
// inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||||
// inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||||
// expected_message = ( | |||||
// r"The two structures don't have the same keys. Input " | |||||
// r"structure has keys \['c'\], while shallow structure has " | |||||
// r"keys \['d'\].") | |||||
// with self.assertRaisesRegexp(ValueError, expected_message): | |||||
// nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
// inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||||
// inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||||
// nest.assert_shallow_structure(inp_ab, inp_ba) | |||||
// # This assertion is expected to pass: two namedtuples with the same | |||||
//# name and field names are considered to be identical. | |||||
//inp_shallow = NestTest.SameNameab(1, 2) | |||||
// inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||||
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||||
// nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||||
// def testFlattenUpTo(self): | |||||
// # Shallow tree ends at scalar. | |||||
// input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
// shallow_tree = [[True, True], [False, True]] | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||||
// self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||||
//# Shallow tree ends at string. | |||||
// input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||||
// shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
// input_tree) | |||||
// input_tree_flattened = nest.flatten(input_tree) | |||||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
// [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||||
// self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||||
// # Make sure dicts are correctly flattened, yielding values, not keys. | |||||
//input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||||
// shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
// input_tree) | |||||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
// [1, { "c": 2}, 3, (4, 5)]) | |||||
// # Namedtuples. | |||||
// ab_tuple = NestTest.ABTuple | |||||
// input_tree = ab_tuple(a =[0, 1], b = 2) | |||||
// shallow_tree = ab_tuple(a= 0, b= 1) | |||||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
// input_tree) | |||||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
// [[0, 1], 2]) | |||||
// # Nested dicts, OrderedDicts and namedtuples. | |||||
// input_tree = collections.OrderedDict( | |||||
// [("a", ab_tuple(a =[0, {"b": 1}], b=2)), | |||||
// ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||||
// shallow_tree = input_tree | |||||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
// input_tree) | |||||
// self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||||
// shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
// input_tree) | |||||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
// [ab_tuple(a =[0, { "b": 1}], b=2), | |||||
// 3, | |||||
// collections.OrderedDict([("f", 4)])]) | |||||
// shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||||
// input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
// input_tree) | |||||
// self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
// [ab_tuple(a =[0, {"b": 1}], b=2), | |||||
// {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||||
// ## Shallow non-list edge-case. | |||||
// # Using iterable elements. | |||||
// input_tree = ["input_tree"] | |||||
//shallow_tree = "shallow_tree" | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
// input_tree = ["input_tree_0", "input_tree_1"] | |||||
//shallow_tree = "shallow_tree" | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
// # Using non-iterable elements. | |||||
//input_tree = [0] | |||||
//shallow_tree = 9 | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
// input_tree = [0, 1] | |||||
//shallow_tree = 9 | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
// ## Both non-list edge-case. | |||||
//# Using iterable elements. | |||||
//input_tree = "input_tree" | |||||
// shallow_tree = "shallow_tree" | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
// # Using non-iterable elements. | |||||
//input_tree = 0 | |||||
// shallow_tree = 0 | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_input_tree, [input_tree]) | |||||
// self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
// ## Input non-list edge-case. | |||||
//# Using iterable elements. | |||||
//input_tree = "input_tree" | |||||
// shallow_tree = ["shallow_tree"] | |||||
//expected_message = ("If shallow structure is a sequence, input must also " | |||||
// "be a sequence. Input has type: <(type|class) 'str'>.") | |||||
// with self.assertRaisesRegexp(TypeError, expected_message): | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
// input_tree = "input_tree" | |||||
// shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||||
//with self.assertRaisesRegexp(TypeError, expected_message): | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
//# Using non-iterable elements. | |||||
// input_tree = 0 | |||||
// shallow_tree = [9] | |||||
//expected_message = ("If shallow structure is a sequence, input must also " | |||||
// "be a sequence. Input has type: <(type|class) 'int'>.") | |||||
// with self.assertRaisesRegexp(TypeError, expected_message): | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
// input_tree = 0 | |||||
// shallow_tree = [9, 8] | |||||
//with self.assertRaisesRegexp(TypeError, expected_message): | |||||
// flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
// flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
// self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
// def testMapStructureUpTo(self) : | |||||
// # Named tuples. | |||||
// ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
// op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
// inp_val = ab_tuple(a= 2, b= 3) | |||||
// inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) | |||||
// out = nest.map_structure_up_to( | |||||
// inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||||
// self.assertEqual(out.a, 6) | |||||
// self.assertEqual(out.b, 15) | |||||
// # Lists. | |||||
// data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
// name_list = ["evens", ["odds", "primes"]] | |||||
// out = nest.map_structure_up_to( | |||||
// name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
// name_list, data_list) | |||||
// self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||||
// # Dicts. | |||||
// inp_val = dict(a= 2, b= 3) | |||||
// inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||||
// out = nest.map_structure_up_to( | |||||
// inp_val, | |||||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
// self.assertEqual(out["a"], 6) | |||||
// self.assertEqual(out["b"], 15) | |||||
// # Non-equal dicts. | |||||
// inp_val = dict(a= 2, b= 3) | |||||
// inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||||
// with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
// nest.map_structure_up_to( | |||||
// inp_val, | |||||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
// # Dict+custom mapping. | |||||
// inp_val = dict(a= 2, b= 3) | |||||
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||||
// out = nest.map_structure_up_to( | |||||
// inp_val, | |||||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
// self.assertEqual(out["a"], 6) | |||||
// self.assertEqual(out["b"], 15) | |||||
// # Non-equal dict/mapping. | |||||
// inp_val = dict(a= 2, b= 3) | |||||
// inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||||
// with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
// nest.map_structure_up_to( | |||||
// inp_val, | |||||
// lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
// def testGetTraverseShallowStructure(self): | |||||
// scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||||
// scalar_traverse_r = nest.get_traverse_shallow_structure( | |||||
// lambda s: not isinstance(s, tuple), | |||||
// scalar_traverse_input) | |||||
// self.assertEqual(scalar_traverse_r, | |||||
// [True, True, False, [True, True], {"a": False}, []]) | |||||
// nest.assert_shallow_structure(scalar_traverse_r, | |||||
// scalar_traverse_input) | |||||
// structure_traverse_input = [(1, [2]), ([1], 2)] | |||||
// structure_traverse_r = nest.get_traverse_shallow_structure( | |||||
// lambda s: (True, False) if isinstance(s, tuple) else True, | |||||
// structure_traverse_input) | |||||
// self.assertEqual(structure_traverse_r, | |||||
// [(True, False), ([True], False)]) | |||||
// nest.assert_shallow_structure(structure_traverse_r, | |||||
// structure_traverse_input) | |||||
// with self.assertRaisesRegexp(TypeError, "returned structure"): | |||||
// nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||||
// with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||||
// nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||||
// with self.assertRaisesRegexp( | |||||
// TypeError, "didn't return a depth=1 structure of bools"): | |||||
// nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||||
// def testYieldFlatStringPaths(self): | |||||
// for inputs_expected in ({"inputs": [], "expected": []}, | |||||
// {"inputs": 3, "expected": [()]}, | |||||
// {"inputs": [3], "expected": [(0,)]}, | |||||
// {"inputs": {"a": 3}, "expected": [("a",)]}, | |||||
// {"inputs": {"a": {"b": 4}}, | |||||
// "expected": [("a", "b")]}, | |||||
// {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||||
// {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||||
// {"inputs": [{"a": [(23, 42)]}], | |||||
// "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||||
// {"inputs": [{"a": ([23], 42)}], | |||||
// "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||||
// {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||||
// "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||||
// {"inputs": {"0": [{"1": 23}]}, | |||||
// "expected": [("0", 0, "1")]}): | |||||
// inputs = inputs_expected["inputs"] | |||||
// expected = inputs_expected["expected"] | |||||
// self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||||
// def testFlattenWithStringPaths(self): | |||||
// for inputs_expected in ( | |||||
// {"inputs": [], "expected": []}, | |||||
// {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||||
// {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||||
// inputs = inputs_expected["inputs"] | |||||
// expected = inputs_expected["expected"] | |||||
// self.assertEqual( | |||||
// nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||||
// expected) | |||||
// # Need a separate test for namedtuple as we can't declare tuple definitions | |||||
// # in the @parameterized arguments. | |||||
// def testFlattenNamedTuple(self): | |||||
// # pylint: disable=invalid-name | |||||
// Foo = collections.namedtuple("Foo", ["a", "b"]) | |||||
// Bar = collections.namedtuple("Bar", ["c", "d"]) | |||||
// # pylint: enable=invalid-name | |||||
// test_cases = [ | |||||
// (Foo(a = 3, b = Bar(c = 23, d = 42)), | |||||
// [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||||
// (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), | |||||
// [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||||
// (Bar(c = 42, d = 43), | |||||
// [("c", 42), ("d", 43)]), | |||||
// (Bar(c =[42], d = 43), | |||||
// [("c/0", 42), ("d", 43)]), | |||||
// ] | |||||
// for inputs, expected in test_cases: | |||||
// self.assertEqual( | |||||
// list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||||
// @parameterized.named_parameters( | |||||
// ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||||
// ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||||
// {"a": ("a", 4), "b": ("b", 6)}), | |||||
// ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||||
// ("nested", | |||||
// {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||||
// {"a": [("a/0", 10), ("a/1", 12)], | |||||
// "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||||
// def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||||
// def format_sum(path, * values): | |||||
// return (path, sum(values)) | |||||
// result = nest.map_structure_with_paths(format_sum, s1, s2, | |||||
// check_types=check_types) | |||||
// self.assertEqual(expected, result) | |||||
// @parameterized.named_parameters( | |||||
// ("tuples", (1, 2), (3, 4, 5), ValueError), | |||||
// ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||||
// ("mixed", (1, 2), [3, 4], TypeError), | |||||
// ("nested", | |||||
// {"a": [2, 3], "b": [1, 3]}, | |||||
// {"b": [5, 6, 7], "a": [8, 9]}, | |||||
// ValueError | |||||
// )) | |||||
// def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||||
// with self.assertRaises(error_type): | |||||
// nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) | |||||
//class NestBenchmark(test.Benchmark): | |||||
// def run_and_report(self, s1, s2, name): | |||||
// burn_iter, test_iter = 100, 30000 | |||||
// for _ in xrange(burn_iter) : | |||||
// nest.assert_same_structure(s1, s2) | |||||
// t0 = time.time() | |||||
// for _ in xrange(test_iter) : | |||||
// nest.assert_same_structure(s1, s2) | |||||
// t1 = time.time() | |||||
// self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||||
// name=name) | |||||
// def benchmark_assert_structure(self): | |||||
// s1 = (((1, 2), 3), 4, (5, 6)) | |||||
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
// self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||||
// s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||||
// s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||||
// self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||||
//if __name__ == "__main__": | |||||
// test.main() | |||||
} | |||||
} |
@@ -0,0 +1,883 @@ | |||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
# ============================================================================== | |||||
"""Tests for utilities working with arbitrarily nested structures.""" | |||||
from __future__ import absolute_import | |||||
from __future__ import division | |||||
from __future__ import print_function | |||||
import collections | |||||
import time | |||||
from absl.testing import parameterized | |||||
import numpy as np | |||||
from six.moves import xrange # pylint: disable=redefined-builtin | |||||
from tensorflow.python.framework import constant_op | |||||
from tensorflow.python.framework import dtypes | |||||
from tensorflow.python.framework import test_util | |||||
from tensorflow.python.ops import array_ops | |||||
from tensorflow.python.ops import math_ops | |||||
from tensorflow.python.platform import test | |||||
from tensorflow.python.util import nest | |||||
try: | |||||
import attr # pylint:disable=g-import-not-at-top | |||||
except ImportError: | |||||
attr = None | |||||
class _CustomMapping(collections.Mapping): | |||||
def __init__(self, *args, **kwargs): | |||||
self._wrapped = dict(*args, **kwargs) | |||||
def __getitem__(self, key): | |||||
return self._wrapped[key] | |||||
def __iter__(self): | |||||
return iter(self._wrapped) | |||||
def __len__(self): | |||||
return len(self._wrapped) | |||||
class NestTest(parameterized.TestCase, test.TestCase): | |||||
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name | |||||
if attr: | |||||
class BadAttr(object): | |||||
"""Class that has a non-iterable __attrs_attrs__.""" | |||||
__attrs_attrs__ = None | |||||
@attr.s | |||||
class SampleAttr(object): | |||||
field1 = attr.ib() | |||||
field2 = attr.ib() | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testAttrsFlattenAndPack(self): | |||||
if attr is None: | |||||
self.skipTest("attr module is unavailable.") | |||||
field_values = [1, 2] | |||||
sample_attr = NestTest.SampleAttr(*field_values) | |||||
self.assertFalse(nest._is_attrs(field_values)) | |||||
self.assertTrue(nest._is_attrs(sample_attr)) | |||||
flat = nest.flatten(sample_attr) | |||||
self.assertEqual(field_values, flat) | |||||
restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||||
self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||||
self.assertEqual(restructured_from_flat, sample_attr) | |||||
# Check that flatten fails if attributes are not iterable | |||||
with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||||
flat = nest.flatten(NestTest.BadAttr()) | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testFlattenAndPack(self): | |||||
structure = ((3, 4), 5, (6, 7, (9, 10), 8)) | |||||
flat = ["a", "b", "c", "d", "e", "f", "g", "h"] | |||||
self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) | |||||
self.assertEqual( | |||||
nest.pack_sequence_as(structure, flat), (("a", "b"), "c", | |||||
("d", "e", ("f", "g"), "h"))) | |||||
structure = (NestTest.PointXY(x=4, y=2), | |||||
((NestTest.PointXY(x=1, y=0),),)) | |||||
flat = [4, 2, 1, 0] | |||||
self.assertEqual(nest.flatten(structure), flat) | |||||
restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||||
self.assertEqual(restructured_from_flat, structure) | |||||
self.assertEqual(restructured_from_flat[0].x, 4) | |||||
self.assertEqual(restructured_from_flat[0].y, 2) | |||||
self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||||
self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||||
self.assertEqual([5], nest.flatten(5)) | |||||
self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||||
self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||||
self.assertEqual( | |||||
np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||||
with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||||
nest.pack_sequence_as("scalar", [4, 5]) | |||||
with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||||
nest.pack_sequence_as([4, 5], "bad_sequence") | |||||
with self.assertRaises(ValueError): | |||||
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||||
@parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
{"mapping_type": _CustomMapping}) | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testFlattenDictOrder(self, mapping_type): | |||||
"""`flatten` orders dicts by key, including OrderedDicts.""" | |||||
ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||||
plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||||
ordered_flat = nest.flatten(ordered) | |||||
plain_flat = nest.flatten(plain) | |||||
self.assertEqual([0, 1, 2, 3], ordered_flat) | |||||
self.assertEqual([0, 1, 2, 3], plain_flat) | |||||
@parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
{"mapping_type": _CustomMapping}) | |||||
def testPackDictOrder(self, mapping_type): | |||||
"""Packing orders dicts by key, including OrderedDicts.""" | |||||
custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||||
plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||||
seq = [0, 1, 2, 3] | |||||
custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||||
plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||||
self.assertIsInstance(custom_reconstruction, mapping_type) | |||||
self.assertIsInstance(plain_reconstruction, dict) | |||||
self.assertEqual( | |||||
mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||||
custom_reconstruction) | |||||
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||||
Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testFlattenAndPack_withDicts(self): | |||||
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||||
mess = [ | |||||
"z", | |||||
NestTest.Abc(3, 4), { | |||||
"d": _CustomMapping({ | |||||
41: 4 | |||||
}), | |||||
"c": [ | |||||
1, | |||||
collections.OrderedDict([ | |||||
("b", 3), | |||||
("a", 2), | |||||
]), | |||||
], | |||||
"b": 5 | |||||
}, 17 | |||||
] | |||||
flattened = nest.flatten(mess) | |||||
self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||||
structure_of_mess = [ | |||||
14, | |||||
NestTest.Abc("a", True), | |||||
{ | |||||
"d": _CustomMapping({ | |||||
41: 42 | |||||
}), | |||||
"c": [ | |||||
0, | |||||
collections.OrderedDict([ | |||||
("b", 9), | |||||
("a", 8), | |||||
]), | |||||
], | |||||
"b": 3 | |||||
}, | |||||
"hi everybody", | |||||
] | |||||
unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||||
self.assertEqual(unflattened, mess) | |||||
# Check also that the OrderedDict was created, with the correct key order. | |||||
unflattened_ordered_dict = unflattened[2]["c"][1] | |||||
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||||
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||||
unflattened_custom_mapping = unflattened[2]["d"] | |||||
self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||||
self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||||
def testFlatten_numpyIsNotFlattened(self): | |||||
structure = np.array([1, 2, 3]) | |||||
flattened = nest.flatten(structure) | |||||
self.assertEqual(len(flattened), 1) | |||||
def testFlatten_stringIsNotFlattened(self): | |||||
structure = "lots of letters" | |||||
flattened = nest.flatten(structure) | |||||
self.assertEqual(len(flattened), 1) | |||||
unflattened = nest.pack_sequence_as("goodbye", flattened) | |||||
self.assertEqual(structure, unflattened) | |||||
def testPackSequenceAs_notIterableError(self): | |||||
with self.assertRaisesRegexp(TypeError, | |||||
"flat_sequence must be a sequence"): | |||||
nest.pack_sequence_as("hi", "bye") | |||||
def testPackSequenceAs_wrongLengthsError(self): | |||||
with self.assertRaisesRegexp( | |||||
ValueError, | |||||
"Structure had 2 elements, but flat_sequence had 3 elements."): | |||||
nest.pack_sequence_as(["hello", "world"], | |||||
["and", "goodbye", "again"]) | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testIsSequence(self): | |||||
self.assertFalse(nest.is_sequence("1234")) | |||||
self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||||
self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||||
self.assertTrue(nest.is_sequence([])) | |||||
self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||||
self.assertFalse(nest.is_sequence(set([1, 2]))) | |||||
ones = array_ops.ones([2, 3]) | |||||
self.assertFalse(nest.is_sequence(ones)) | |||||
self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||||
self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||||
@parameterized.parameters({"mapping_type": _CustomMapping}, | |||||
{"mapping_type": dict}) | |||||
def testFlattenDictItems(self, mapping_type): | |||||
dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||||
flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||||
with self.assertRaises(TypeError): | |||||
nest.flatten_dict_items(4) | |||||
bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||||
with self.assertRaisesRegexp(ValueError, "not unique"): | |||||
nest.flatten_dict_items(bad_dictionary) | |||||
another_bad_dictionary = mapping_type({ | |||||
(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||||
}) | |||||
with self.assertRaisesRegexp( | |||||
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||||
nest.flatten_dict_items(another_bad_dictionary) | |||||
# pylint does not correctly recognize these as class names and | |||||
# suggests to use variable style under_score naming. | |||||
# pylint: disable=invalid-name | |||||
Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||||
Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||||
SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||||
SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||||
SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||||
SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||||
SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||||
NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||||
# pylint: enable=invalid-name | |||||
class SameNamedType1(SameNameab): | |||||
pass | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testAssertSameStructure(self): | |||||
structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
structure_different_num_elements = ("spam", "eggs") | |||||
structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||||
nest.assert_same_structure(structure1, structure2) | |||||
nest.assert_same_structure("abc", 1.0) | |||||
nest.assert_same_structure("abc", np.array([0, 1])) | |||||
nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||||
with self.assertRaisesRegexp( | |||||
ValueError, | |||||
("The two structures don't have the same nested structure\\.\n\n" | |||||
"First structure:.*?\n\n" | |||||
"Second structure:.*\n\n" | |||||
"More specifically: Substructure " | |||||
r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||||
'substructure "type=str str=spam" is not\n' | |||||
"Entire first structure:\n" | |||||
r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||||
"Entire second structure:\n" | |||||
r"\(\., \.\)")): | |||||
nest.assert_same_structure(structure1, structure_different_num_elements) | |||||
with self.assertRaisesRegexp( | |||||
ValueError, | |||||
("The two structures don't have the same nested structure\\.\n\n" | |||||
"First structure:.*?\n\n" | |||||
"Second structure:.*\n\n" | |||||
r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||||
"is not")): | |||||
nest.assert_same_structure([0, 1], np.array([0, 1])) | |||||
with self.assertRaisesRegexp( | |||||
ValueError, | |||||
("The two structures don't have the same nested structure\\.\n\n" | |||||
"First structure:.*?\n\n" | |||||
"Second structure:.*\n\n" | |||||
r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
'is a sequence, while substructure "type=int str=0" ' | |||||
"is not")): | |||||
nest.assert_same_structure(0, [0, 1]) | |||||
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||||
with self.assertRaisesRegexp( | |||||
ValueError, | |||||
("don't have the same nested structure\\.\n\n" | |||||
"First structure: .*?\n\nSecond structure: ")): | |||||
nest.assert_same_structure(structure1, structure_different_nesting) | |||||
self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||||
NestTest.Named0ab("a", "b")) | |||||
nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
NestTest.Named0ab("a", "b")) | |||||
self.assertRaises(TypeError, nest.assert_same_structure, | |||||
NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||||
with self.assertRaisesRegexp( | |||||
ValueError, | |||||
("don't have the same nested structure\\.\n\n" | |||||
"First structure: .*?\n\nSecond structure: ")): | |||||
nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
NestTest.Named0ab([3], 4)) | |||||
with self.assertRaisesRegexp( | |||||
ValueError, | |||||
("don't have the same nested structure\\.\n\n" | |||||
"First structure: .*?\n\nSecond structure: ")): | |||||
nest.assert_same_structure([[3], 4], [3, [4]]) | |||||
structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
with self.assertRaisesRegexp(TypeError, | |||||
"don't have the same sequence type"): | |||||
nest.assert_same_structure(structure1, structure1_list) | |||||
nest.assert_same_structure(structure1, structure2, check_types=False) | |||||
nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||||
with self.assertRaisesRegexp(ValueError, | |||||
"don't have the same set of keys"): | |||||
nest.assert_same_structure({"a": 1}, {"b": 1}) | |||||
nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||||
NestTest.SameNameab2(2, 3)) | |||||
# This assertion is expected to pass: two namedtuples with the same | |||||
# name and field names are considered to be identical. | |||||
nest.assert_same_structure( | |||||
NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||||
NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||||
expected_message = "The two structures don't have the same.*" | |||||
with self.assertRaisesRegexp(ValueError, expected_message): | |||||
nest.assert_same_structure( | |||||
NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||||
NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||||
self.assertRaises(TypeError, nest.assert_same_structure, | |||||
NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||||
self.assertRaises(TypeError, nest.assert_same_structure, | |||||
NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||||
self.assertRaises(TypeError, nest.assert_same_structure, | |||||
NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||||
EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||||
def testHeterogeneousComparison(self): | |||||
nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) | |||||
nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testMapStructure(self): | |||||
structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
structure2 = (((7, 8), 9), 10, (11, 12)) | |||||
structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||||
nest.assert_same_structure(structure1, structure1_plus1) | |||||
self.assertAllEqual( | |||||
[2, 3, 4, 5, 6, 7], | |||||
nest.flatten(structure1_plus1)) | |||||
structure1_plus_structure2 = nest.map_structure( | |||||
lambda x, y: x + y, structure1, structure2) | |||||
self.assertEqual( | |||||
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||||
structure1_plus_structure2) | |||||
self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||||
self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||||
# Empty structures | |||||
self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||||
self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||||
self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||||
self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||||
NestTest.EmptyNT())) | |||||
# This is checking actual equality of types, empty list != empty tuple | |||||
self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||||
with self.assertRaisesRegexp(TypeError, "callable"): | |||||
nest.map_structure("bad", structure1_plus1) | |||||
with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||||
nest.map_structure(lambda x: x) | |||||
with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||||
nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||||
with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
nest.map_structure(lambda x, y: None, 3, (3,)) | |||||
with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||||
with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||||
structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||||
nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||||
check_types=False) | |||||
with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||||
check_types=False) | |||||
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
nest.map_structure(lambda x: None, structure1, foo="a") | |||||
with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||||
ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||||
@test_util.assert_no_new_pyobjects_executing_eagerly | |||||
def testMapStructureWithStrings(self): | |||||
inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||||
inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||||
out = nest.map_structure(lambda string, repeats: string * repeats, | |||||
inp_a, | |||||
inp_b) | |||||
self.assertEqual("foofoo", out.a) | |||||
self.assertEqual("bar", out.b[0]) | |||||
self.assertEqual("bazbazbaz", out.b[1]) | |||||
nt = NestTest.ABTuple(a=("something", "something_else"), | |||||
b="yet another thing") | |||||
rev_nt = nest.map_structure(lambda x: x[::-1], nt) | |||||
# Check the output is the correct structure, and all strings are reversed. | |||||
nest.assert_same_structure(nt, rev_nt) | |||||
self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) | |||||
self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) | |||||
self.assertEqual(nt.b[::-1], rev_nt.b) | |||||
@test_util.run_deprecated_v1 | |||||
def testMapStructureOverPlaceholders(self): | |||||
inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||||
nest.assert_same_structure(output, inp_a) | |||||
self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||||
self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||||
feed_dict = { | |||||
inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||||
inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||||
} | |||||
with self.cached_session() as sess: | |||||
output_np = sess.run(output, feed_dict=feed_dict) | |||||
self.assertAllClose(output_np[0], | |||||
feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||||
self.assertAllClose(output_np[1], | |||||
feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||||
def testAssertShallowStructure(self): | |||||
inp_ab = ["a", "b"] | |||||
inp_abc = ["a", "b", "c"] | |||||
expected_message = ( | |||||
"The two structures don't have the same sequence length. Input " | |||||
"structure has length 2, while shallow structure has length 3.") | |||||
with self.assertRaisesRegexp(ValueError, expected_message): | |||||
nest.assert_shallow_structure(inp_abc, inp_ab) | |||||
inp_ab1 = [(1, 1), (2, 2)] | |||||
inp_ab2 = [[1, 1], [2, 2]] | |||||
expected_message = ( | |||||
"The two structures don't have the same sequence type. Input structure " | |||||
"has type <(type|class) 'tuple'>, while shallow structure has type " | |||||
"<(type|class) 'list'>.") | |||||
with self.assertRaisesRegexp(TypeError, expected_message): | |||||
nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) | |||||
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||||
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||||
expected_message = ( | |||||
r"The two structures don't have the same keys. Input " | |||||
r"structure has keys \['c'\], while shallow structure has " | |||||
r"keys \['d'\].") | |||||
with self.assertRaisesRegexp(ValueError, expected_message): | |||||
nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||||
inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||||
nest.assert_shallow_structure(inp_ab, inp_ba) | |||||
# This assertion is expected to pass: two namedtuples with the same | |||||
# name and field names are considered to be identical. | |||||
inp_shallow = NestTest.SameNameab(1, 2) | |||||
inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||||
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||||
nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||||
def testFlattenUpTo(self): | |||||
# Shallow tree ends at scalar. | |||||
input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
shallow_tree = [[True, True], [False, True]] | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||||
self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||||
# Shallow tree ends at string. | |||||
input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||||
shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
input_tree) | |||||
input_tree_flattened = nest.flatten(input_tree) | |||||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
[("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||||
self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||||
# Make sure dicts are correctly flattened, yielding values, not keys. | |||||
input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||||
shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
input_tree) | |||||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
[1, {"c": 2}, 3, (4, 5)]) | |||||
# Namedtuples. | |||||
ab_tuple = NestTest.ABTuple | |||||
input_tree = ab_tuple(a=[0, 1], b=2) | |||||
shallow_tree = ab_tuple(a=0, b=1) | |||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
input_tree) | |||||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
[[0, 1], 2]) | |||||
# Nested dicts, OrderedDicts and namedtuples. | |||||
input_tree = collections.OrderedDict( | |||||
[("a", ab_tuple(a=[0, {"b": 1}], b=2)), | |||||
("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||||
shallow_tree = input_tree | |||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
input_tree) | |||||
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||||
shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
input_tree) | |||||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
[ab_tuple(a=[0, {"b": 1}], b=2), | |||||
3, | |||||
collections.OrderedDict([("f", 4)])]) | |||||
shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||||
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
input_tree) | |||||
self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
[ab_tuple(a=[0, {"b": 1}], b=2), | |||||
{"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||||
## Shallow non-list edge-case. | |||||
# Using iterable elements. | |||||
input_tree = ["input_tree"] | |||||
shallow_tree = "shallow_tree" | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_input_tree, [input_tree]) | |||||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
input_tree = ["input_tree_0", "input_tree_1"] | |||||
shallow_tree = "shallow_tree" | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_input_tree, [input_tree]) | |||||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
# Using non-iterable elements. | |||||
input_tree = [0] | |||||
shallow_tree = 9 | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_input_tree, [input_tree]) | |||||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
input_tree = [0, 1] | |||||
shallow_tree = 9 | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_input_tree, [input_tree]) | |||||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
## Both non-list edge-case. | |||||
# Using iterable elements. | |||||
input_tree = "input_tree" | |||||
shallow_tree = "shallow_tree" | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_input_tree, [input_tree]) | |||||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
# Using non-iterable elements. | |||||
input_tree = 0 | |||||
shallow_tree = 0 | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_input_tree, [input_tree]) | |||||
self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
## Input non-list edge-case. | |||||
# Using iterable elements. | |||||
input_tree = "input_tree" | |||||
shallow_tree = ["shallow_tree"] | |||||
expected_message = ("If shallow structure is a sequence, input must also " | |||||
"be a sequence. Input has type: <(type|class) 'str'>.") | |||||
with self.assertRaisesRegexp(TypeError, expected_message): | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
input_tree = "input_tree" | |||||
shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||||
with self.assertRaisesRegexp(TypeError, expected_message): | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
# Using non-iterable elements. | |||||
input_tree = 0 | |||||
shallow_tree = [9] | |||||
expected_message = ("If shallow structure is a sequence, input must also " | |||||
"be a sequence. Input has type: <(type|class) 'int'>.") | |||||
with self.assertRaisesRegexp(TypeError, expected_message): | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
input_tree = 0 | |||||
shallow_tree = [9, 8] | |||||
with self.assertRaisesRegexp(TypeError, expected_message): | |||||
flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
def testMapStructureUpTo(self): | |||||
# Named tuples. | |||||
ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
inp_val = ab_tuple(a=2, b=3) | |||||
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||||
out = nest.map_structure_up_to( | |||||
inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||||
self.assertEqual(out.a, 6) | |||||
self.assertEqual(out.b, 15) | |||||
# Lists. | |||||
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
name_list = ["evens", ["odds", "primes"]] | |||||
out = nest.map_structure_up_to( | |||||
name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
name_list, data_list) | |||||
self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||||
# Dicts. | |||||
inp_val = dict(a=2, b=3) | |||||
inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||||
out = nest.map_structure_up_to( | |||||
inp_val, | |||||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
self.assertEqual(out["a"], 6) | |||||
self.assertEqual(out["b"], 15) | |||||
# Non-equal dicts. | |||||
inp_val = dict(a=2, b=3) | |||||
inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||||
with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
nest.map_structure_up_to( | |||||
inp_val, | |||||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
# Dict+custom mapping. | |||||
inp_val = dict(a=2, b=3) | |||||
inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||||
out = nest.map_structure_up_to( | |||||
inp_val, | |||||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
self.assertEqual(out["a"], 6) | |||||
self.assertEqual(out["b"], 15) | |||||
# Non-equal dict/mapping. | |||||
inp_val = dict(a=2, b=3) | |||||
inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||||
with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
nest.map_structure_up_to( | |||||
inp_val, | |||||
lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
def testGetTraverseShallowStructure(self): | |||||
scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||||
scalar_traverse_r = nest.get_traverse_shallow_structure( | |||||
lambda s: not isinstance(s, tuple), | |||||
scalar_traverse_input) | |||||
self.assertEqual(scalar_traverse_r, | |||||
[True, True, False, [True, True], {"a": False}, []]) | |||||
nest.assert_shallow_structure(scalar_traverse_r, | |||||
scalar_traverse_input) | |||||
structure_traverse_input = [(1, [2]), ([1], 2)] | |||||
structure_traverse_r = nest.get_traverse_shallow_structure( | |||||
lambda s: (True, False) if isinstance(s, tuple) else True, | |||||
structure_traverse_input) | |||||
self.assertEqual(structure_traverse_r, | |||||
[(True, False), ([True], False)]) | |||||
nest.assert_shallow_structure(structure_traverse_r, | |||||
structure_traverse_input) | |||||
with self.assertRaisesRegexp(TypeError, "returned structure"): | |||||
nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||||
with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||||
nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||||
with self.assertRaisesRegexp( | |||||
TypeError, "didn't return a depth=1 structure of bools"): | |||||
nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||||
def testYieldFlatStringPaths(self): | |||||
for inputs_expected in ({"inputs": [], "expected": []}, | |||||
{"inputs": 3, "expected": [()]}, | |||||
{"inputs": [3], "expected": [(0,)]}, | |||||
{"inputs": {"a": 3}, "expected": [("a",)]}, | |||||
{"inputs": {"a": {"b": 4}}, | |||||
"expected": [("a", "b")]}, | |||||
{"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||||
{"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||||
{"inputs": [{"a": [(23, 42)]}], | |||||
"expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||||
{"inputs": [{"a": ([23], 42)}], | |||||
"expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||||
{"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||||
"expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||||
{"inputs": {"0": [{"1": 23}]}, | |||||
"expected": [("0", 0, "1")]}): | |||||
inputs = inputs_expected["inputs"] | |||||
expected = inputs_expected["expected"] | |||||
self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||||
def testFlattenWithStringPaths(self): | |||||
for inputs_expected in ( | |||||
{"inputs": [], "expected": []}, | |||||
{"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||||
{"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||||
inputs = inputs_expected["inputs"] | |||||
expected = inputs_expected["expected"] | |||||
self.assertEqual( | |||||
nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||||
expected) | |||||
# Need a separate test for namedtuple as we can't declare tuple definitions | |||||
# in the @parameterized arguments. | |||||
def testFlattenNamedTuple(self): | |||||
# pylint: disable=invalid-name | |||||
Foo = collections.namedtuple("Foo", ["a", "b"]) | |||||
Bar = collections.namedtuple("Bar", ["c", "d"]) | |||||
# pylint: enable=invalid-name | |||||
test_cases = [ | |||||
(Foo(a=3, b=Bar(c=23, d=42)), | |||||
[("a", 3), ("b/c", 23), ("b/d", 42)]), | |||||
(Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), | |||||
[("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||||
(Bar(c=42, d=43), | |||||
[("c", 42), ("d", 43)]), | |||||
(Bar(c=[42], d=43), | |||||
[("c/0", 42), ("d", 43)]), | |||||
] | |||||
for inputs, expected in test_cases: | |||||
self.assertEqual( | |||||
list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||||
@parameterized.named_parameters( | |||||
("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||||
("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||||
{"a": ("a", 4), "b": ("b", 6)}), | |||||
("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||||
("nested", | |||||
{"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||||
{"a": [("a/0", 10), ("a/1", 12)], | |||||
"b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||||
def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||||
def format_sum(path, *values): | |||||
return (path, sum(values)) | |||||
result = nest.map_structure_with_paths(format_sum, s1, s2, | |||||
check_types=check_types) | |||||
self.assertEqual(expected, result) | |||||
@parameterized.named_parameters( | |||||
("tuples", (1, 2), (3, 4, 5), ValueError), | |||||
("dicts", {"a": 1}, {"b": 2}, ValueError), | |||||
("mixed", (1, 2), [3, 4], TypeError), | |||||
("nested", | |||||
{"a": [2, 3], "b": [1, 3]}, | |||||
{"b": [5, 6, 7], "a": [8, 9]}, | |||||
ValueError | |||||
)) | |||||
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||||
with self.assertRaises(error_type): | |||||
nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) | |||||
class NestBenchmark(test.Benchmark): | |||||
def run_and_report(self, s1, s2, name): | |||||
burn_iter, test_iter = 100, 30000 | |||||
for _ in xrange(burn_iter): | |||||
nest.assert_same_structure(s1, s2) | |||||
t0 = time.time() | |||||
for _ in xrange(test_iter): | |||||
nest.assert_same_structure(s1, s2) | |||||
t1 = time.time() | |||||
self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||||
name=name) | |||||
def benchmark_assert_structure(self): | |||||
s1 = (((1, 2), 3), 4, (5, 6)) | |||||
s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||||
s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||||
s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||||
self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||||
if __name__ == "__main__": | |||||
test.main() |