Browse Source

nest.map_structure: fixed bug when merging multiple structures

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
2bf72e9d62
3 changed files with 58 additions and 64 deletions
  1. +0
    -38
      src/TensorFlowNET.Core/Python.cs
  2. +54
    -21
      src/TensorFlowNET.Core/Util/nest.py.cs
  3. +4
    -5
      test/TensorFlowNET.UnitTest/nest_test/NestTest.cs

+ 0
- 38
src/TensorFlowNET.Core/Python.cs View File

@@ -131,44 +131,6 @@ namespace Tensorflow
} }
} }


/// <summary>
/// Untyped implementation of zip for arbitrary data
///
/// Converts an list of lists or arrays [[1,2,3], [4,5,6], [7,8,9]] into a list of arrays
/// representing tuples of the same index of all source arrays [[1,4,7], [2,5,9], [3,6,9]]
/// </summary>
/// <param name="lists">one or multiple sequences to be zipped</param>
/// <returns></returns>
public static IEnumerable<object[]> zip(params object[] lists)
{
if (lists.Length == 0)
yield break;
var first = lists[0];
if (first == null)
yield break;
var arity = (first as IEnumerable).OfType<object>().Count();
for (int i = 0; i < arity; i++)
{
var array= new object[lists.Length];
for (int j = 0; j < lists.Length; j++)
array[j] = GetSequenceElementAt(lists[j], i);
yield return array;
}
}

private static object GetSequenceElementAt(object sequence, int i)
{
switch (sequence)
{
case Array array:
return array.GetValue(i);
case IList list:
return list[i];
default:
return (sequence as IEnumerable).OfType<object>().Skip(Math.Max(0, i)).FirstOrDefault();
}
}

public static IEnumerable<(int, T)> enumerate<T>(IList<T> values) public static IEnumerable<(int, T)> enumerate<T>(IList<T> values)
{ {
for (int i = 0; i < values.Count; i++) for (int i = 0; i < values.Count; i++)


+ 54
- 21
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -23,8 +23,44 @@ namespace Tensorflow.Util
public static class nest public static class nest
{ {
public static IEnumerable<object[]> zip(params object[] structures)
=> Python.zip(structures);
/// <summary>
/// Untyped implementation of zip for arbitrary data
///
/// Converts an list of lists or arrays [[1,2,3], [4,5,6], [7,8,9]] into a list of arrays
/// representing tuples of the same index of all source arrays [[1,4,7], [2,5,9], [3,6,9]]
/// </summary>
/// <param name="lists">one or multiple sequences to be zipped</param>
/// <returns></returns>
public static IEnumerable<object[]> zip_many(params IEnumerable<object>[] lists)
{
if (lists.Length == 0)
yield break;
var first = lists[0];
if (first == null)
yield break;
var arity = first.Count();
for (int i = 0; i < arity; i++)
{
var array = new object[lists.Length];
for (int j = 0; j < lists.Length; j++)
array[j] = GetSequenceElementAt(lists[j], i);
yield return array;
}
}
private static object GetSequenceElementAt(object sequence, int i)
{
switch (sequence)
{
case Array array:
return array.GetValue(i);
case IList list:
return list[i];
default:
return _yield_value(sequence).Skip(Math.Max(0, i)).FirstOrDefault();
}
}
public static IEnumerable<(T1, T2)> zip<T1, T2>(IEnumerable<T1> e1, IEnumerable<T2> e2) public static IEnumerable<(T1, T2)> zip<T1, T2>(IEnumerable<T1> e1, IEnumerable<T2> e2)
=> Python.zip(e1, e2); => Python.zip(e1, e2);
@@ -40,9 +76,9 @@ namespace Tensorflow.Util
/// <summary> /// <summary>
/// Returns a sorted list of the dict keys, with error if keys not sortable. /// Returns a sorted list of the dict keys, with error if keys not sortable.
/// </summary> /// </summary>
private static IEnumerable<string> _sorted(IDictionary dict_)
private static IEnumerable<object> _sorted(IDictionary dict_)
{ {
return dict_.Keys.OfType<string>().OrderBy(x => x);
return dict_.Keys.OfType<object>().OrderBy(x => x);
} }
@@ -86,7 +122,7 @@ namespace Tensorflow.Util
{ {
case Hashtable hash: case Hashtable hash:
var result = new Hashtable(); var result = new Hashtable();
foreach ((object key, object value) in zip(_sorted(hash).OfType<object>(), args))
foreach ((object key, object value) in zip<object, object>(_sorted(hash), args))
result[key] = value; result[key] = value;
return result; return result;
} }
@@ -370,13 +406,13 @@ namespace Tensorflow.Util
/// <returns> `flat_sequence` converted to have the same recursive structure as /// <returns> `flat_sequence` converted to have the same recursive structure as
/// `structure`. /// `structure`.
/// </returns> /// </returns>
public static object pack_sequence_as<T>(object structure, IEnumerable<T> flat_sequence)
public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence)
{ {
List<object> flat = null; List<object> flat = null;
if (flat_sequence is List<object>) if (flat_sequence is List<object>)
flat = flat_sequence as List<object>; flat = flat_sequence as List<object>;
else else
flat=new List<object>(flat_sequence.OfType<object>());
flat=new List<object>(flat_sequence);
if (flat_sequence==null) if (flat_sequence==null)
throw new ArgumentException("flat_sequence must not be null"); throw new ArgumentException("flat_sequence must not be null");
// if not is_sequence(flat_sequence): // if not is_sequence(flat_sequence):
@@ -403,7 +439,7 @@ namespace Tensorflow.Util
var flat_structure = flatten(structure); var flat_structure = flatten(structure);
if (len(flat_structure) != len(flat)) if (len(flat_structure) != len(flat))
{ {
throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
throw new ValueError("Could not pack sequence. Structure had {len(structure)} elements, but " +
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}"); $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}");
} }
return _sequence_like(structure, packed); return _sequence_like(structure, packed);
@@ -413,7 +449,7 @@ namespace Tensorflow.Util
var flat_structure = flatten(structure); var flat_structure = flatten(structure);
if (len(flat_structure) != len(flat)) if (len(flat_structure) != len(flat))
{ {
throw new ValueError("Could not pack sequence. Structure had %d elements, but " +
throw new ValueError("Could not pack sequence. Structure had {len(structure)} elements, but " +
$"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}"); $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat)}");
} }
return _sequence_like(structure, packed); return _sequence_like(structure, packed);
@@ -427,10 +463,8 @@ namespace Tensorflow.Util
/// `structure[i]`. All structures in `structure` must have the same arity, /// `structure[i]`. All structures in `structure` must have the same arity,
/// and the return value will contain the results in the same structure. /// and the return value will contain the results in the same structure.
/// </summary> /// </summary>
/// <typeparam name="T">the type of the elements of the output structure (object if diverse)</typeparam>
/// <param name="func"> A callable that accepts as many arguments as there are structures.</param> /// <param name="func"> A callable that accepts as many arguments as there are structures.</param>
/// <param name="structures">scalar, or tuple or list of constructed scalars and/or other
/// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param>
/// <param name="structures">one or many IEnumerable of object</param>
/// <param name="check_types">If set to /// <param name="check_types">If set to
/// `True` (default) the types of iterables within the structures have to be /// `True` (default) the types of iterables within the structures have to be
/// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` /// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
@@ -444,23 +478,22 @@ namespace Tensorflow.Util
/// `check_types` is `False` the sequence types of the first structure will be /// `check_types` is `False` the sequence types of the first structure will be
/// used. /// used.
/// </returns> /// </returns>
public static IEnumerable<object> map_structure(Func<object[], object> func, object structure, params object[] more_structures)
public static IEnumerable<object> map_structure(Func<object[], object> func, params IEnumerable<object>[] structure)
{ {
// TODO: check structure and types // TODO: check structure and types
// for other in structure[1:]: // for other in structure[1:]:
// assert_same_structure(structure[0], other, check_types=check_types) // assert_same_structure(structure[0], other, check_types=check_types)
if (more_structures.Length==0)
if (structure.Length==1)
{ {
// we don't need to zip if we have only one structure // we don't need to zip if we have only one structure
return map_structure(a => func(new object[]{a}), structure);
return map_structure(a => func(new object[]{a}), structure[0]);
} }
var flat_structures = new List<object>() { flatten(structure) };
flat_structures.AddRange(more_structures.Select(flatten));
var entries = zip(flat_structures);
var flat_structures = structure.Select(flatten).ToArray(); // ToArray is important here!
var entries = zip_many(flat_structures);
var mapped_flat_structure = entries.Select(func); var mapped_flat_structure = entries.Select(func);
return (pack_sequence_as(structure, mapped_flat_structure) as IEnumerable).OfType<object>();
return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList();
} }
/// <summary> /// <summary>
@@ -469,7 +502,7 @@ namespace Tensorflow.Util
/// <param name="func"></param> /// <param name="func"></param>
/// <param name="structure"></param> /// <param name="structure"></param>
/// <returns></returns> /// <returns></returns>
public static IEnumerable<object> map_structure(Func<object, object> func, object structure)
public static IEnumerable<object> map_structure(Func<object, object> func, IEnumerable<object> structure)
{ {
// TODO: check structure and types // TODO: check structure and types
// for other in structure[1:]: // for other in structure[1:]:
@@ -478,7 +511,7 @@ namespace Tensorflow.Util
var flat_structure = flatten(structure); var flat_structure = flatten(structure);
var mapped_flat_structure = flat_structure.Select(func).ToList(); var mapped_flat_structure = flat_structure.Select(func).ToList();
return (pack_sequence_as(structure, mapped_flat_structure) as IEnumerable).OfType<object>();
return _yield_value(pack_sequence_as(structure, mapped_flat_structure)).ToList();
} }
//def map_structure_with_paths(func, *structure, **kwargs): //def map_structure_with_paths(func, *structure, **kwargs):


+ 4
- 5
test/TensorFlowNET.UnitTest/nest_test/NestTest.cs View File

@@ -387,11 +387,10 @@ namespace TensorFlowNET.UnitTest.nest_test
// nest.assert_same_structure(structure1, structure1_plus1) // nest.assert_same_structure(structure1, structure1_plus1)
self.assertAllEqual( nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 }); self.assertAllEqual( nest.flatten(structure1_plus1), new object[] { 2, 3, 4, 5, 6, 7 });
self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" }); self.assertAllEqual(nest.flatten(structure1_strings), new object[] { "1", "2", "3", "4", "5", "6" });
// structure1_plus_structure2 = nest.map_structure(
// lambda x, y: x + y, structure1, structure2)
// self.assertEqual(
// (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
// structure1_plus_structure2)
var structure1_plus_structure2 = nest.map_structure(x => (int)(x[0]) + (int)(x[1]), structure1, structure2);
self.assertEqual(
new object[] { new object[] { new object[] { 1 + 7, 2 + 8}, 3 + 9}, 4 + 10, new object[] { 5 + 11, 6 + 12}},
structure1_plus_structure2);
// self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))


Loading…
Cancel
Save