Browse Source

feat: add Nestable and Nest<T>.

pull/1106/head
Yaohui Liu Wanglongzhi2001 2 years ago
parent
commit
1e9708014a
12 changed files with 946 additions and 124 deletions
  1. +7
    -0
      src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
  2. +33
    -0
      src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs
  3. +52
    -1
      src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
  4. +27
    -0
      src/TensorFlowNET.Core/Common/Types/INest.cs
  5. +11
    -0
      src/TensorFlowNET.Core/Common/Types/INestable.cs
  6. +62
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
  7. +458
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.cs
  8. +99
    -0
      src/TensorFlowNET.Core/Common/Types/NestDictionary.cs
  9. +43
    -0
      src/TensorFlowNET.Core/Common/Types/NestList.cs
  10. +32
    -0
      src/TensorFlowNET.Core/Common/Types/NestNode.cs
  11. +121
    -90
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  12. +1
    -33
      src/TensorFlowNET.Core/Util/nest.py.cs

+ 7
- 0
src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs View File

@@ -22,5 +22,12 @@ namespace Tensorflow.Common.Extensions
{
return new Tensors(tensors);
}

public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
{
first = values.Item1;
second = values.Item2;
third = values.Item3;
}
}
}

+ 33
- 0
src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Common.Extensions
{
public static class NestExtensions
{
public static Tensors ToTensors(this INestable<Tensor> tensors)
{
return new Tensors(tensors.AsNest());
}

public static Tensors? ToTensors(this Nest<Tensor> tensors)
{
return Tensors.FromNest(tensors);
}

/// <summary>
/// If the nested object is already a nested type, this function could reduce it.
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="input"></param>
/// <returns></returns>
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut>
{
return Nest<TOut>.ReduceFrom(input);
}
}
}

+ 52
- 1
src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs View File

@@ -5,7 +5,7 @@ using System.Text;

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: IEnumerable<long?[]>
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
{
public TensorShapeConfig[] Shapes { get; set; }
/// <summary>
@@ -63,6 +63,57 @@ namespace Tensorflow.Common.Types
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
}

public IEnumerable<long?> Flatten()
{
List<long?> result = new List<long?>();
foreach(var shapeConfig in Shapes)
{
result.AddRange(shapeConfig.Items);
}
return result;
}
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
{
List<Nest<TOut>> lists = new();
foreach(var shapeConfig in Shapes)
{
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
}
return new Nest<TOut>(lists);
}

public Nest<long?> AsNest()
{
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
{
if (config.Items.Length == 0)
{
return Nest<long?>.Empty;
}
else if (config.Items.Length == 1)
{
return new Nest<long?>(config.Items[0]);
}
else
{
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
}
}

if(Shapes.Length == 0)
{
return Nest<long?>.Empty;
}
else if(Shapes.Length == 1)
{
return DealWithSingleShape(Shapes[0]);
}
else
{
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}

public IEnumerator<long?[]> GetEnumerator()
{
foreach (var shape in Shapes)


+ 27
- 0
src/TensorFlowNET.Core/Common/Types/INest.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface indicates that a class may have a nested structure and provide
/// methods to manipulate with the structure.
/// </summary>
public interface INestStructure<T>: INestable<T>
{
/// <summary>
/// Flatten the Nestable object. Node that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
/// </summary>
/// <returns></returns>
IEnumerable<T> Flatten();
/// <summary>
/// Construct a new object with the same nested structure.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <param name="func"></param>
/// <returns></returns>
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func);
}
}

+ 11
- 0
src/TensorFlowNET.Core/Common/Types/INestable.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public interface INestable<T>
{
Nest<T> AsNest();
}
}

+ 62
- 0
src/TensorFlowNET.Core/Common/Types/Nest.Static.cs View File

@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public static class Nest
{
/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
{
return template.AsNest().PackSequence(flatItems);
}

/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems)
{
return template.AsNest().PackSequence(flatItems.ToArray());
}

/// <summary>
/// Flatten the nested object.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="nestedObject"></param>
/// <returns></returns>
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject)
{
return nestedObject.AsNest().Flatten();
}

/// <summary>
/// Map the structure with specified function.
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="func"></param>
/// <param name="nestedObject"></param>
/// <returns></returns>
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject)
{
return nestedObject.AsNest().MapStructure(func);
}

public static bool IsNested<T>(INestable<T> obj)
{
return obj.AsNest().IsNested();
}
}
}

+ 458
- 0
src/TensorFlowNET.Core/Common/Types/Nest.cs View File

@@ -0,0 +1,458 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Common.Types
{
public enum NestType
{
Empty,
Node,
List,
Dictionary
}

/// <summary>
/// A nested structure which may inclulde value, list and dictionary.
/// Note that dictionary does not ensure the data order. When using it as IEnumerable,
/// its order is depth-first.
/// </summary>
/// <typeparam name="T"></typeparam>
public class Nest<T> : INestStructure<T>, IEnumerable<T>
{
private static readonly Nest<T> _empty = new Nest<T>()
{
NestType = NestType.Empty,
};
public static Nest<T> Empty => _empty;
public NestType NestType { get; protected set; }
public string? Name { get; set; }
public T? Value { get; protected set; }
public List<Nest<T>>? ListValue { get; protected set; }
public Dictionary<string, Nest<T>>? DictValue { get; protected set; }

protected Nest() { }

public Nest(T value, string? name = null)
{
Value = value;
Name = name;
NestType = NestType.Node;
}

public Nest(IEnumerable<Nest<T>> values, string? name = null)
{
ListValue = values.ToList();
Name = name;
NestType = NestType.List;
}

public Nest(Dictionary<string, Nest<T>> value, string? name = null)
{
DictValue = value;
Name = name;
NestType = NestType.Dictionary;
}

public Nest(Nest<T> other)
{
NestType = other.NestType;
Value = other.Value;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public virtual IEnumerable<T> Flatten()
{
return FlattenInternal(this);
}
public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return MapStructureInternal(func);
}

/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <param name="flatItems"></param>
/// <returns></returns>
public virtual Nest<T> PackSequence(T[] flatItems)
{
if(flatItems.Length == 0)
{
return Nest<T>.Empty;
}
int index = 0;
return PackSequenceInternal(this, flatItems, ref index);
}

private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index)
{
if(template.NestType == NestType.Node)
{
if(index >= flatItems.Length)
{
throw new InvalidArgumentError("The template and flat items are not matched.");
}
return new Nest<T>(flatItems[index++]);
}
else if(template.NestType == NestType.List)
{
List<Nest<T>> nestedObjects = new List<Nest<T>>();
for (int i = 0; i < template.ListValue!.Count; i++)
{
nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index));
}
return new Nest<T>(nestedObjects);
}
else if(template.NestType == NestType.Node)
{
Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>();
foreach(var (key, value) in template.DictValue!)
{
dict[key] = PackSequenceInternal(value, flatItems, ref index);
}
return new Nest<T>(dict);
}
// Consider Empty as invalid type.
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
}

public virtual Nest<T> AsNest()
{
return this;
}

public virtual Nest<T> MergeWith(Nest<T>? other)
{
if(other is null || other == Nest<T>.Empty)
{
return this;
}
if(this == Nest<T>.Empty)
{
return other;
}
if(NestType == NestType.Node && other.NestType == NestType.Node)
{
return new Nest<T>(new Nest<T>[] { this, other });
}
else if(NestType == NestType.List && other.NestType == NestType.List)
{
return new Nest<T>(this.ListValue!.Concat(other.ListValue!));
}
else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary)
{
return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value));
}
else
{
return new Nest<T>(new Nest<T>[] { this, other });
}
}

/// <summary>
/// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not
/// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested.
/// </summary>
/// <returns></returns>
public bool IsNested()
{
if(NestType is NestType.Empty or NestType.Node)
{
return false;
}
else if(NestType is NestType.List)
{
foreach(var item in ListValue!)
{
if(item.NestType is NestType.List or NestType.Dictionary)
{
return true;
}
}
return false;
}
else
{
foreach (var item in DictValue!.Values)
{
if (item.NestType is NestType.List or NestType.Dictionary)
{
return true;
}
}
return false;
}
}

[Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")]
public T this[int index]
{
get
{
bool success = FindInternal(this, index, out var result);
if (success)
{
return result;
}
else
{
throw new IndexOutOfRangeException();
}
}
set
{
bool success = SetInternal(this, index, value);
if (!success)
{
throw new IndexOutOfRangeException();
}
}
}

/// <summary>
/// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it
/// to `Nest[T]`.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <param name="input"></param>
/// <returns></returns>
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T>
{
var nested = input.AsNest();
return ReduceInternal(nested);
}

private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
{
if(node.NestType == NestType.Empty)
{
return Nest<T>.Empty;
}
else if(node.NestType == NestType.Node)
{
return node.Value!.AsNest();
}
else if(node.NestType == NestType.List)
{
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x)));
}
else // Dictionary type
{
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value)));
}
}

private static bool FindInternal(Nest<T> node, int index, out T? result)
{
if (node.NestType == NestType.Node)
{
if(index == 0)
{
result = node.Value!;
return true;
}
result = default(T);
return false;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
if(index == 0)
{
return FindInternal(item, index, out result);
}
index--;
}
result = default(T);
return false;
}
else if(node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
if (index == 0)
{
return FindInternal(item, index, out result);
}
index--;
}
result = default(T);
return false;
}
else
{
result = default(T);
return false;
}
}

private static bool SetInternal(Nest<T> node, int index, T newValue)
{
if (node.NestType == NestType.Node)
{
if (index == 0)
{
node.Value = newValue;
return true;
}
return false;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
if (index == 0)
{
return SetInternal(item, index, newValue);
}
index--;
}
return false;
}
else if (node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
if (index == 0)
{
return SetInternal(item, index, newValue);
}
index--;
}
return false;
}
else
{
return false;
}
}

private static IEnumerable<T> FlattenInternal(Nest<T> node)
{
if (node.NestType == NestType.Node)
{
yield return node.Value!;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
foreach(var val in FlattenInternal(item))
{
yield return val;
}
}
}
else if (node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
foreach (var val in FlattenInternal(item))
{
yield return val;
}
}
}
}

private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func)
{
if (NestType == NestType.Node)
{
return new Nest<TOut>(func(Value!));
}
else if (NestType == NestType.List)
{
List<Nest<TOut>> outs = new List<Nest<TOut>>();
foreach (var item in ListValue!)
{
outs.Add(item.MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else if (NestType == NestType.Dictionary)
{
Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>();
foreach (var (key, value) in DictValue!)
{
outs.Add(key, value.MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else
{
return Nest<TOut>.Empty;
}
}

public IEnumerator<T> GetEnumerator()
{
return Flatten().GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}

public override string ToString()
{
StringBuilder sb = new StringBuilder();
sb.Append("(");
WriteString(this, sb);
sb.Append(")");
return sb.ToString();
}

private static void WriteString(Nest<T> node, StringBuilder sb)
{
if (!string.IsNullOrEmpty(node.Name))
{
sb.Append($"{node.Name}: ");
}
if (node.NestType == NestType.Node)
{
sb.Append(node.Value!.ToString());
}
else if (node.NestType == NestType.List)
{
sb.Append("[");
for(int i = 0; i < node.ListValue!.Count; i++)
{
WriteString(node.ListValue![i], sb);
if(i != node.ListValue!.Count - 1)
{
sb.Append(", ");
}
}
sb.Append("]");
}
else if (node.NestType == NestType.Dictionary)
{
sb.Append("{");
int count = node.DictValue!.Count;
int i = 0;
foreach (var (key, value) in node.DictValue!)
{
sb.Append($"{key}: ");
WriteString(value, sb);
if (i != count - 1)
{
sb.Append(", ");
}
i++;
}
sb.Append("}");
}
else
{
sb.Append("<empty>");
}
}
}
}

+ 99
- 0
src/TensorFlowNET.Core/Common/Types/NestDictionary.cs View File

@@ -0,0 +1,99 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull
{
public IDictionary<TKey, TValue> Value { get; set; }
public NestDictionary(IDictionary<TKey, TValue> dict)
{
Value = dict;
}
public IEnumerable<TValue> Flatten()
{
return Value.Select(x => x.Value);
}
public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func)
{
return new NestList<TOut>(Value.Select(x => func(x.Value)));
}

public Nest<TValue> AsNest()
{
return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x)));
}

// Required IDictionary<TKey, TValue> members
public int Count => Value.Count;

public bool IsReadOnly => Value.IsReadOnly;

public ICollection<TKey> Keys => Value.Keys;

public ICollection<TValue> Values => Value.Values;

public void Add(TKey key, TValue value)
{
Value.Add(key, value);
}

public void Add(KeyValuePair<TKey, TValue> item)
{
Value.Add(item);
}

public void Clear()
{
Value.Clear();
}

public bool Contains(KeyValuePair<TKey, TValue> item)
{
return Value.Contains(item);
}

public bool ContainsKey(TKey key)
{
return Value.ContainsKey(key);
}

public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
{
Value.CopyTo(array, arrayIndex);
}

public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
{
return Value.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}

public bool Remove(TKey key)
{
return Value.Remove(key);
}

public bool Remove(KeyValuePair<TKey, TValue> item)
{
return Value.Remove(item);
}

public bool TryGetValue(TKey key, out TValue value)
{
return Value.TryGetValue(key, out value);
}

// Optional IDictionary<TKey, TValue> members
public TValue this[TKey key]
{
get => Value[key];
set => Value[key] = value;
}
}
}

+ 43
- 0
src/TensorFlowNET.Core/Common/Types/NestList.cs View File

@@ -0,0 +1,43 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// The implementation of a list that support nest structure, in which the depth is 1.
/// </summary>
/// <typeparam name="T"></typeparam>
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
{
public List<T> Value { get; set; }
public NestList(IEnumerable<T> values)
{
Value = new List<T>(values);
}
public IEnumerable<T> Flatten()
{
return Value;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestList<TOut>(Value.Select(x => func(x)));
}

public Nest<T> AsNest()
{
return new Nest<T>(Value.Select(x => new Nest<T>(x)));
}

// Enumerator implementation
public IEnumerator<T> GetEnumerator()
{
return Value.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

+ 32
- 0
src/TensorFlowNET.Core/Common/Types/NestNode.cs View File

@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// A nested structure with only one element.
/// </summary>
/// <typeparam name="T"></typeparam>
public class NestNode<T> : INestStructure<T>
{
public T Value { get; set; }
public NestNode(T value)
{
Value = value;
}
public IEnumerable<T> Flatten()
{
yield return Value;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestNode<TOut>(func(Value));
}

public Nest<T> AsNest()
{
return new Nest<T>(Value);
}
}
}

+ 121
- 90
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;

namespace Tensorflow
{
@@ -13,16 +14,14 @@ namespace Tensorflow
/// and Tensor[] from Tensors implicitily.
/// It works for tuple and scalar as well.
/// </summary>
public class Tensors : IEnumerable<Tensor>, IDisposable
public sealed class Tensors : Nest<Tensor>, IDisposable
{
List<Tensor> items = new List<Tensor>();

public TF_DataType dtype => items.First().dtype;
public Shape shape => items.First().shape;
public int rank => items.First().rank;
public Graph graph => items.First().graph;
public TF_DataType dtype => this.First().dtype;
public Shape shape => this.First().shape;
public int rank => this.First().rank;
public Graph graph => this.First().graph;
public bool IsList { get; set; }
public int Length => items.Count();
public int Length => this.Count();
/// <summary>
/// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception.
/// </summary>
@@ -35,7 +34,7 @@ namespace Tensorflow
throw new ValueError("Tensors with more than one tensor cannot be " +
"implicitly converted to Tensor.");
}
return items.First();
return this.First();
}
}

@@ -52,150 +51,194 @@ namespace Tensorflow
throw new ValueError($"Tensors with {Length} tensor cannot be " +
"implicitly converted to Tensor.");
}
return items.FirstOrDefault();
return this.FirstOrDefault();
}
}

public Tensor this[int index]
public Tensor this[params string[] slices]
=> this.First()[slices];

public Tensors(Tensor tensor) : base(tensor)
{
get => items[index];
set => items[index] = value;

}

public Tensor this[params string[] slices]
=> items.First()[slices];
public Tensors(params Tensor[] tensors)
private Tensors(Nest<Tensor> nested) : base(nested)
{
items.AddRange(tensors);

}

public Tensors(IEnumerable<Tensor> tensors)
public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
{
items.AddRange(tensors);
}

public Tensors(NDArray nd)
public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
{
items.Add(ops.convert_to_tensor(nd));
}

public IEnumerator<Tensor> GetEnumerator()
public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
{
foreach (var tensor in items)
yield return tensor;
}

public bool IsSingle()
{
return Length == 1;
}

public new Tensors MergeWith(Nest<Tensor>? other)
{
return FromNest(base.MergeWith(other));
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " +
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")]
public void Add(Tensor tensor)
=> items.Add(tensor);
{
if(NestType == NestType.Dictionary)
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
Value = null;
}
else
{
ListValue.Add(new Nest<Tensor>(tensor));
}
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " +
"some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")]
public void AddRange(IEnumerable<Tensor> tensors)
=> items.AddRange(tensors);
{
if (NestType == NestType.Dictionary)
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
else if (NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
Value = null;
}
else
{
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
}
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " +
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")]
public void Insert(int index, Tensor tensor)
=> items.Insert(index, tensor);

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
{
if (NestType == NestType.List)
{
ListValue.Insert(index, new Nest<Tensor>(tensor));
}
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue.Insert(index, new Nest<Tensor>(tensor));
Value = null;
}
else
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
}

public string[] StringData()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData();
return Single.StringData();
}

public string StringData(int index)
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData(index);
return Single.StringData(index);
}

public NDArray numpy()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].numpy();
return Single.numpy();
}

[Obsolete]
public T[] ToArray<T>() where T: unmanaged
{
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
return this[0].ToArray<T>();
return Single.ToArray<T>();
}

#region Explicit Conversions
public unsafe static explicit operator bool(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to bool");
return (bool)tensor[0];
return (bool)tensor.Single;
}

public unsafe static explicit operator sbyte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to sbyte");
return (sbyte)tensor[0];
return (sbyte)tensor.Single;
}

public unsafe static explicit operator byte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
return (byte)tensor.Single;
}

public unsafe static explicit operator ushort(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ushort");
return (ushort)tensor[0];
return (ushort)tensor.Single;
}

public unsafe static explicit operator short(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to short");
return (short)tensor[0];
return (short)tensor.Single;
}

public unsafe static explicit operator int(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to int");
return (int)tensor[0];
return (int)tensor.Single;
}

public unsafe static explicit operator uint(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to uint");
return (uint)tensor[0];
return (uint)tensor.Single;
}

public unsafe static explicit operator long(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to long");
return (long)tensor[0];
return (long)tensor.Single;
}

public unsafe static explicit operator ulong(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ulong");
return (ulong)tensor[0];
return (ulong)tensor.Single;
}

public unsafe static explicit operator float(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
return (byte)tensor.Single;
}

public unsafe static explicit operator double(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to double");
return (double)tensor[0];
return (double)tensor.Single;
}

public unsafe static explicit operator string(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to string");
return (string)tensor[0];
return (string)tensor.Single;
}

public static explicit operator object[](Tensors tensors)
=> tensors.items.ToArray();
=> tensors.Flatten().ToArray();
#endregion

#region Implicit Conversions
@@ -219,52 +262,40 @@ namespace Tensorflow
=> tensors?.SingleOrNull;

public static implicit operator Tensor[](Tensors tensors)
=> tensors.items.ToArray();

=> tensors.Flatten().ToArray();
#endregion

public void Deconstruct(out Tensor a, out Tensors? b)
public static Tensors? FromNest(Nest<Tensor> nested)
{
a = items[0];
b = Length == 1? null : new Tensors(items.Skip(1));
if(nested == Nest<Tensor>.Empty)
{
return null;
}
return new Tensors(nested);
}

private static void EnsureSingleTensor(Tensors tensors, string methodnName)
public void Deconstruct(out Tensor a, out Tensors? b)
{
if(tensors.Length == 0)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
}
else if(tensors.Length > 1)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
}
a = this.First();
b = Length == 1? null : new Tensors(this.Skip(1));
}

public override string ToString()
{
if(items.Count == 1)
if(Length == 1)
{
return items[0].ToString();
return this.First().ToString();
}
else
{
StringBuilder sb = new StringBuilder();
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n");
for(int i = 0; i < items.Count; i++)
{
var tensor = items[i];
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n");
}
sb.Append("]\n");
return sb.ToString();
return $"Totally {Length} tensors: {base.ToString()}";
}
}

public void Dispose()
{
foreach (var item in items)
item.Dispose();
foreach (var tensor in this)
tensor.Dispose();
}
}
}

+ 1
- 33
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -36,6 +36,7 @@ namespace Tensorflow.Util
// (np.array([3, 4]), tf.constant([3, 4])))`
//

[Obsolete]
public static class nest
{

@@ -170,39 +171,6 @@ namespace Tensorflow.Util
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType());
}

public static bool is_nested(object obj)
{
// Refer to https://www.tensorflow.org/api_docs/python/tf/nest
//if (obj is IList || obj is IDictionary || obj is ITuple)
// return true;
if (obj is IList || obj is IDictionary)
return true;

if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType
|| obj is ISet<int> || obj is ISet<float> || obj is ISet<double>)
return false;

if (obj.GetType().IsNested) return true;
// Check if the object is an IEnumerable
if (obj is IEnumerable)
{
// If it is, check if it is a nested structure
foreach (object item in (IEnumerable)obj)
{
if (is_nested(item))
{
return true;
}
}
return true;
}
else
{
// If it is not, return false
return false;
}
}

/// <summary>
/// Yields the next value from the given iterable.
/// </summary>


Loading…
Cancel
Save