From 4939105b8f2de49d1f943c7edafd1b35690366ff Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Wed, 7 Jun 2023 07:56:19 +0800 Subject: [PATCH] feat: add Nestable and Nest. --- .../Common/Extensions/LinqExtensions.cs | 7 + .../Common/Extensions/NestExtensions.cs | 33 ++ .../Common/Types/GeneralizedTensorShape.cs | 53 +- src/TensorFlowNET.Core/Common/Types/INest.cs | 27 ++ .../Common/Types/INestable.cs | 11 + .../Common/Types/Nest.Static.cs | 62 +++ src/TensorFlowNET.Core/Common/Types/Nest.cs | 458 ++++++++++++++++++ .../Common/Types/NestDictionary.cs | 99 ++++ .../Common/Types/NestList.cs | 43 ++ .../Common/Types/NestNode.cs | 32 ++ src/TensorFlowNET.Core/Tensors/Tensors.cs | 211 ++++---- src/TensorFlowNET.Core/Util/nest.py.cs | 34 +- 12 files changed, 946 insertions(+), 124 deletions(-) create mode 100644 src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs create mode 100644 src/TensorFlowNET.Core/Common/Types/INest.cs create mode 100644 src/TensorFlowNET.Core/Common/Types/INestable.cs create mode 100644 src/TensorFlowNET.Core/Common/Types/Nest.Static.cs create mode 100644 src/TensorFlowNET.Core/Common/Types/Nest.cs create mode 100644 src/TensorFlowNET.Core/Common/Types/NestDictionary.cs create mode 100644 src/TensorFlowNET.Core/Common/Types/NestList.cs create mode 100644 src/TensorFlowNET.Core/Common/Types/NestNode.cs diff --git a/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs index 0402fca0..6cf62e7b 100644 --- a/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs +++ b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs @@ -22,5 +22,12 @@ namespace Tensorflow.Common.Extensions { return new Tensors(tensors); } + + public static void Deconstruct(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) + { + first = values.Item1; + second = values.Item2; + third = values.Item3; + } } } diff --git a/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs new file mode 100644 index 00000000..76bdd613 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Types; + +namespace Tensorflow.Common.Extensions +{ + public static class NestExtensions + { + public static Tensors ToTensors(this INestable tensors) + { + return new Tensors(tensors.AsNest()); + } + + public static Tensors? ToTensors(this Nest tensors) + { + return Tensors.FromNest(tensors); + } + + /// + /// If the nested object is already a nested type, this function could reduce it. + /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. + /// + /// + /// + /// + /// + public static Nest ReduceTo(this INestStructure input) where TIn: INestStructure + { + return Nest.ReduceFrom(input); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs index edb9a802..e05d3deb 100644 --- a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs +++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow.Common.Types { - public class GeneralizedTensorShape: IEnumerable + public class GeneralizedTensorShape: IEnumerable, INestStructure, INestable { public TensorShapeConfig[] Shapes { get; set; } /// @@ -63,6 +63,57 @@ namespace Tensorflow.Common.Types return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); } + public IEnumerable Flatten() + { + List result = new List(); + foreach(var shapeConfig in Shapes) + { + result.AddRange(shapeConfig.Items); + } + return result; + } + public INestStructure MapStructure(Func func) + { + List> lists = new(); + foreach(var shapeConfig in Shapes) + { + lists.Add(new Nest(shapeConfig.Items.Select(x => new Nest(func(x))))); + } + return new Nest(lists); + } + + public Nest AsNest() + { + Nest DealWithSingleShape(TensorShapeConfig config) + { + if (config.Items.Length == 0) + { + return Nest.Empty; + } + else if (config.Items.Length == 1) + { + return new Nest(config.Items[0]); + } + else + { + return new Nest(config.Items.Select(x => new Nest(x))); + } + } + + if(Shapes.Length == 0) + { + return Nest.Empty; + } + else if(Shapes.Length == 1) + { + return DealWithSingleShape(Shapes[0]); + } + else + { + return new Nest(Shapes.Select(s => DealWithSingleShape(s))); + } + } + public IEnumerator GetEnumerator() { foreach (var shape in Shapes) diff --git a/src/TensorFlowNET.Core/Common/Types/INest.cs b/src/TensorFlowNET.Core/Common/Types/INest.cs new file mode 100644 index 00000000..001141dd --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/INest.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// This interface indicates that a class may have a nested structure and provide + /// methods to manipulate with the structure. + /// + public interface INestStructure: INestable + { + /// + /// Flatten the Nestable object. Node that if the object contains only one value, + /// it will be flattened to an enumerable with one element. + /// + /// + IEnumerable Flatten(); + /// + /// Construct a new object with the same nested structure. + /// + /// + /// + /// + INestStructure MapStructure(Func func); + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/INestable.cs b/src/TensorFlowNET.Core/Common/Types/INestable.cs new file mode 100644 index 00000000..7ce49f85 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/INestable.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public interface INestable + { + Nest AsNest(); + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs new file mode 100644 index 00000000..b67d11f4 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs @@ -0,0 +1,62 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public static class Nest + { + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + /// + /// + public static Nest PackSequenceAs(INestable template, T[] flatItems) + { + return template.AsNest().PackSequence(flatItems); + } + + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + /// + /// + public static Nest PackSequenceAs(INestable template, List flatItems) + { + return template.AsNest().PackSequence(flatItems.ToArray()); + } + + /// + /// Flatten the nested object. + /// + /// + /// + /// + public static IEnumerable Flatten(INestable nestedObject) + { + return nestedObject.AsNest().Flatten(); + } + + /// + /// Map the structure with specified function. + /// + /// + /// + /// + /// + /// + public static INestStructure MapStructure(Func func, INestable nestedObject) + { + return nestedObject.AsNest().MapStructure(func); + } + + public static bool IsNested(INestable obj) + { + return obj.AsNest().IsNested(); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.cs b/src/TensorFlowNET.Core/Common/Types/Nest.cs new file mode 100644 index 00000000..84a60402 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/Nest.cs @@ -0,0 +1,458 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Common.Extensions; + +namespace Tensorflow.Common.Types +{ + public enum NestType + { + Empty, + Node, + List, + Dictionary + } + + /// + /// A nested structure which may inclulde value, list and dictionary. + /// Note that dictionary does not ensure the data order. When using it as IEnumerable, + /// its order is depth-first. + /// + /// + public class Nest : INestStructure, IEnumerable + { + private static readonly Nest _empty = new Nest() + { + NestType = NestType.Empty, + }; + public static Nest Empty => _empty; + public NestType NestType { get; protected set; } + public string? Name { get; set; } + public T? Value { get; protected set; } + public List>? ListValue { get; protected set; } + public Dictionary>? DictValue { get; protected set; } + + protected Nest() { } + + public Nest(T value, string? name = null) + { + Value = value; + Name = name; + NestType = NestType.Node; + } + + public Nest(IEnumerable> values, string? name = null) + { + ListValue = values.ToList(); + Name = name; + NestType = NestType.List; + } + + public Nest(Dictionary> value, string? name = null) + { + DictValue = value; + Name = name; + NestType = NestType.Dictionary; + } + + public Nest(Nest other) + { + NestType = other.NestType; + Value = other.Value; + DictValue = other.DictValue; + ListValue = other.ListValue; + Name = other.Name; + } + + public virtual IEnumerable Flatten() + { + return FlattenInternal(this); + } + public virtual INestStructure MapStructure(Func func) + { + return MapStructureInternal(func); + } + + /// + /// Pack the flat items to a nested sequence by the template. + /// + /// + /// + public virtual Nest PackSequence(T[] flatItems) + { + if(flatItems.Length == 0) + { + return Nest.Empty; + } + int index = 0; + return PackSequenceInternal(this, flatItems, ref index); + } + + private static Nest PackSequenceInternal(Nest template, T[] flatItems, ref int index) + { + if(template.NestType == NestType.Node) + { + if(index >= flatItems.Length) + { + throw new InvalidArgumentError("The template and flat items are not matched."); + } + return new Nest(flatItems[index++]); + } + else if(template.NestType == NestType.List) + { + List> nestedObjects = new List>(); + for (int i = 0; i < template.ListValue!.Count; i++) + { + nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); + } + return new Nest(nestedObjects); + } + else if(template.NestType == NestType.Node) + { + Dictionary> dict = new Dictionary>(); + foreach(var (key, value) in template.DictValue!) + { + dict[key] = PackSequenceInternal(value, flatItems, ref index); + } + return new Nest(dict); + } + // Consider Empty as invalid type. + throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); + } + + public virtual Nest AsNest() + { + return this; + } + + public virtual Nest MergeWith(Nest? other) + { + if(other is null || other == Nest.Empty) + { + return this; + } + if(this == Nest.Empty) + { + return other; + } + if(NestType == NestType.Node && other.NestType == NestType.Node) + { + return new Nest(new Nest[] { this, other }); + } + else if(NestType == NestType.List && other.NestType == NestType.List) + { + return new Nest(this.ListValue!.Concat(other.ListValue!)); + } + else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) + { + return new Nest(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); + } + else + { + return new Nest(new Nest[] { this, other }); + } + } + + /// + /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not + /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. + /// + /// + public bool IsNested() + { + if(NestType is NestType.Empty or NestType.Node) + { + return false; + } + else if(NestType is NestType.List) + { + foreach(var item in ListValue!) + { + if(item.NestType is NestType.List or NestType.Dictionary) + { + return true; + } + } + return false; + } + else + { + foreach (var item in DictValue!.Values) + { + if (item.NestType is NestType.List or NestType.Dictionary) + { + return true; + } + } + return false; + } + } + + [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] + public T this[int index] + { + get + { + bool success = FindInternal(this, index, out var result); + if (success) + { + return result; + } + else + { + throw new IndexOutOfRangeException(); + } + } + set + { + bool success = SetInternal(this, index, value); + if (!success) + { + throw new IndexOutOfRangeException(); + } + } + } + + /// + /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it + /// to `Nest[T]`. + /// + /// + /// + /// + public static Nest ReduceFrom(INestStructure input) where TOut: INestStructure + { + var nested = input.AsNest(); + return ReduceInternal(nested); + } + + private static Nest ReduceInternal(Nest node) where TOut : INestStructure + { + if(node.NestType == NestType.Empty) + { + return Nest.Empty; + } + else if(node.NestType == NestType.Node) + { + return node.Value!.AsNest(); + } + else if(node.NestType == NestType.List) + { + return new Nest(node.ListValue!.Select(x => ReduceInternal(x))); + } + else // Dictionary type + { + return new Nest(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); + } + } + + private static bool FindInternal(Nest node, int index, out T? result) + { + if (node.NestType == NestType.Node) + { + if(index == 0) + { + result = node.Value!; + return true; + } + result = default(T); + return false; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + if(index == 0) + { + return FindInternal(item, index, out result); + } + index--; + } + result = default(T); + return false; + } + else if(node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + if (index == 0) + { + return FindInternal(item, index, out result); + } + index--; + } + result = default(T); + return false; + } + else + { + result = default(T); + return false; + } + } + + private static bool SetInternal(Nest node, int index, T newValue) + { + if (node.NestType == NestType.Node) + { + if (index == 0) + { + node.Value = newValue; + return true; + } + return false; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + if (index == 0) + { + return SetInternal(item, index, newValue); + } + index--; + } + return false; + } + else if (node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + if (index == 0) + { + return SetInternal(item, index, newValue); + } + index--; + } + return false; + } + else + { + return false; + } + } + + private static IEnumerable FlattenInternal(Nest node) + { + if (node.NestType == NestType.Node) + { + yield return node.Value!; + } + else if (node.NestType == NestType.List) + { + foreach (var item in node.ListValue!) + { + foreach(var val in FlattenInternal(item)) + { + yield return val; + } + } + } + else if (node.NestType == NestType.Dictionary) + { + foreach (var item in node.DictValue!.Values) + { + foreach (var val in FlattenInternal(item)) + { + yield return val; + } + } + } + } + + private Nest MapStructureInternal(Func func) + { + if (NestType == NestType.Node) + { + return new Nest(func(Value!)); + } + else if (NestType == NestType.List) + { + List> outs = new List>(); + foreach (var item in ListValue!) + { + outs.Add(item.MapStructureInternal(func)); + } + return new Nest(outs); + } + else if (NestType == NestType.Dictionary) + { + Dictionary> outs = new Dictionary>(); + foreach (var (key, value) in DictValue!) + { + outs.Add(key, value.MapStructureInternal(func)); + } + return new Nest(outs); + } + else + { + return Nest.Empty; + } + } + + public IEnumerator GetEnumerator() + { + return Flatten().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public override string ToString() + { + StringBuilder sb = new StringBuilder(); + sb.Append("("); + WriteString(this, sb); + sb.Append(")"); + return sb.ToString(); + } + + private static void WriteString(Nest node, StringBuilder sb) + { + if (!string.IsNullOrEmpty(node.Name)) + { + sb.Append($"{node.Name}: "); + } + if (node.NestType == NestType.Node) + { + sb.Append(node.Value!.ToString()); + } + else if (node.NestType == NestType.List) + { + sb.Append("["); + for(int i = 0; i < node.ListValue!.Count; i++) + { + WriteString(node.ListValue![i], sb); + if(i != node.ListValue!.Count - 1) + { + sb.Append(", "); + } + } + sb.Append("]"); + } + else if (node.NestType == NestType.Dictionary) + { + sb.Append("{"); + int count = node.DictValue!.Count; + int i = 0; + foreach (var (key, value) in node.DictValue!) + { + sb.Append($"{key}: "); + WriteString(value, sb); + if (i != count - 1) + { + sb.Append(", "); + } + i++; + } + sb.Append("}"); + } + else + { + sb.Append(""); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs new file mode 100644 index 00000000..554ca526 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs @@ -0,0 +1,99 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + public class NestDictionary : INestStructure, IDictionary where TKey : notnull + { + public IDictionary Value { get; set; } + public NestDictionary(IDictionary dict) + { + Value = dict; + } + public IEnumerable Flatten() + { + return Value.Select(x => x.Value); + } + public INestStructure MapStructure(Func func) + { + return new NestList(Value.Select(x => func(x.Value))); + } + + public Nest AsNest() + { + return new Nest(Value.Values.Select(x => new Nest(x))); + } + + // Required IDictionary members + public int Count => Value.Count; + + public bool IsReadOnly => Value.IsReadOnly; + + public ICollection Keys => Value.Keys; + + public ICollection Values => Value.Values; + + public void Add(TKey key, TValue value) + { + Value.Add(key, value); + } + + public void Add(KeyValuePair item) + { + Value.Add(item); + } + + public void Clear() + { + Value.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return Value.Contains(item); + } + + public bool ContainsKey(TKey key) + { + return Value.ContainsKey(key); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + Value.CopyTo(array, arrayIndex); + } + + public IEnumerator> GetEnumerator() + { + return Value.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public bool Remove(TKey key) + { + return Value.Remove(key); + } + + public bool Remove(KeyValuePair item) + { + return Value.Remove(item); + } + + public bool TryGetValue(TKey key, out TValue value) + { + return Value.TryGetValue(key, out value); + } + + // Optional IDictionary members + public TValue this[TKey key] + { + get => Value[key]; + set => Value[key] = value; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestList.cs b/src/TensorFlowNET.Core/Common/Types/NestList.cs new file mode 100644 index 00000000..08218718 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestList.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// The implementation of a list that support nest structure, in which the depth is 1. + /// + /// + public sealed class NestList : INestStructure, IEnumerable + { + public List Value { get; set; } + public NestList(IEnumerable values) + { + Value = new List(values); + } + public IEnumerable Flatten() + { + return Value; + } + public INestStructure MapStructure(Func func) + { + return new NestList(Value.Select(x => func(x))); + } + + public Nest AsNest() + { + return new Nest(Value.Select(x => new Nest(x))); + } + + // Enumerator implementation + public IEnumerator GetEnumerator() + { + return Value.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/NestNode.cs b/src/TensorFlowNET.Core/Common/Types/NestNode.cs new file mode 100644 index 00000000..1dad421d --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/NestNode.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// A nested structure with only one element. + /// + /// + public class NestNode : INestStructure + { + public T Value { get; set; } + public NestNode(T value) + { + Value = value; + } + public IEnumerable Flatten() + { + yield return Value; + } + public INestStructure MapStructure(Func func) + { + return new NestNode(func(Value)); + } + + public Nest AsNest() + { + return new Nest(Value); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index caa36b76..cba8f954 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -3,6 +3,7 @@ using System; using System.Collections; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; namespace Tensorflow { @@ -13,16 +14,14 @@ namespace Tensorflow /// and Tensor[] from Tensors implicitily. /// It works for tuple and scalar as well. /// - public class Tensors : IEnumerable, IDisposable + public sealed class Tensors : Nest, IDisposable { - List items = new List(); - - public TF_DataType dtype => items.First().dtype; - public Shape shape => items.First().shape; - public int rank => items.First().rank; - public Graph graph => items.First().graph; + public TF_DataType dtype => this.First().dtype; + public Shape shape => this.First().shape; + public int rank => this.First().rank; + public Graph graph => this.First().graph; public bool IsList { get; set; } - public int Length => items.Count(); + public int Length => this.Count(); /// /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. /// @@ -35,7 +34,7 @@ namespace Tensorflow throw new ValueError("Tensors with more than one tensor cannot be " + "implicitly converted to Tensor."); } - return items.First(); + return this.First(); } } @@ -52,150 +51,194 @@ namespace Tensorflow throw new ValueError($"Tensors with {Length} tensor cannot be " + "implicitly converted to Tensor."); } - return items.FirstOrDefault(); + return this.FirstOrDefault(); } } - public Tensor this[int index] + public Tensor this[params string[] slices] + => this.First()[slices]; + + public Tensors(Tensor tensor) : base(tensor) { - get => items[index]; - set => items[index] = value; + } - public Tensor this[params string[] slices] - => items.First()[slices]; - public Tensors(params Tensor[] tensors) + private Tensors(Nest nested) : base(nested) { - items.AddRange(tensors); + } - public Tensors(IEnumerable tensors) + public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest(x))) { - items.AddRange(tensors); + } - public Tensors(NDArray nd) + public Tensors(IEnumerable tensors): base(tensors.Select(x => new Nest(x))) { - items.Add(ops.convert_to_tensor(nd)); + } - public IEnumerator GetEnumerator() + public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) { - foreach (var tensor in items) - yield return tensor; + } + public bool IsSingle() + { + return Length == 1; + } + + public new Tensors MergeWith(Nest? other) + { + return FromNest(base.MergeWith(other)); + } + + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + + "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] public void Add(Tensor tensor) - => items.Add(tensor); + { + if(NestType == NestType.Dictionary) + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + else if(NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(Value), new Nest(tensor) }; + Value = null; + } + else + { + ListValue.Add(new Nest(tensor)); + } + } + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + + "some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] public void AddRange(IEnumerable tensors) - => items.AddRange(tensors); + { + if (NestType == NestType.Dictionary) + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + else if (NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(Value) }; + ListValue.AddRange(tensors.Select(x => new Nest(x))); + Value = null; + } + else + { + ListValue.AddRange(tensors.Select(x => new Nest(x))); + } + } + [Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + + "a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] public void Insert(int index, Tensor tensor) - => items.Insert(index, tensor); - - IEnumerator IEnumerable.GetEnumerator() - => GetEnumerator(); + { + if (NestType == NestType.List) + { + ListValue.Insert(index, new Nest(tensor)); + } + else if(NestType == NestType.Node) + { + NestType = NestType.List; + ListValue = new() { new Nest(Value) }; + ListValue.Insert(index, new Nest(tensor)); + Value = null; + } + else + { + throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); + } + } public string[] StringData() { - EnsureSingleTensor(this, "nnumpy"); - return this[0].StringData(); + return Single.StringData(); } public string StringData(int index) { - EnsureSingleTensor(this, "nnumpy"); - return this[0].StringData(index); + return Single.StringData(index); } public NDArray numpy() { - EnsureSingleTensor(this, "nnumpy"); - return this[0].numpy(); + return Single.numpy(); } + [Obsolete] public T[] ToArray() where T: unmanaged { - EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); - return this[0].ToArray(); + return Single.ToArray(); } #region Explicit Conversions public unsafe static explicit operator bool(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to bool"); - return (bool)tensor[0]; + return (bool)tensor.Single; } public unsafe static explicit operator sbyte(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to sbyte"); - return (sbyte)tensor[0]; + return (sbyte)tensor.Single; } public unsafe static explicit operator byte(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to byte"); - return (byte)tensor[0]; + return (byte)tensor.Single; } public unsafe static explicit operator ushort(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to ushort"); - return (ushort)tensor[0]; + return (ushort)tensor.Single; } public unsafe static explicit operator short(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to short"); - return (short)tensor[0]; + return (short)tensor.Single; } public unsafe static explicit operator int(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to int"); - return (int)tensor[0]; + return (int)tensor.Single; } public unsafe static explicit operator uint(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to uint"); - return (uint)tensor[0]; + return (uint)tensor.Single; } public unsafe static explicit operator long(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to long"); - return (long)tensor[0]; + return (long)tensor.Single; } public unsafe static explicit operator ulong(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to ulong"); - return (ulong)tensor[0]; + return (ulong)tensor.Single; } public unsafe static explicit operator float(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to byte"); - return (byte)tensor[0]; + return (byte)tensor.Single; } public unsafe static explicit operator double(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to double"); - return (double)tensor[0]; + return (double)tensor.Single; } public unsafe static explicit operator string(Tensors tensor) { - EnsureSingleTensor(tensor, "explicit conversion to string"); - return (string)tensor[0]; + return (string)tensor.Single; } public static explicit operator object[](Tensors tensors) - => tensors.items.ToArray(); + => tensors.Flatten().ToArray(); #endregion #region Implicit Conversions @@ -219,52 +262,40 @@ namespace Tensorflow => tensors?.SingleOrNull; public static implicit operator Tensor[](Tensors tensors) - => tensors.items.ToArray(); - + => tensors.Flatten().ToArray(); #endregion - public void Deconstruct(out Tensor a, out Tensors? b) + public static Tensors? FromNest(Nest nested) { - a = items[0]; - b = Length == 1? null : new Tensors(items.Skip(1)); + if(nested == Nest.Empty) + { + return null; + } + return new Tensors(nested); } - private static void EnsureSingleTensor(Tensors tensors, string methodnName) + public void Deconstruct(out Tensor a, out Tensors? b) { - if(tensors.Length == 0) - { - throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); - } - else if(tensors.Length > 1) - { - throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); - } + a = this.First(); + b = Length == 1? null : new Tensors(this.Skip(1)); } public override string ToString() { - if(items.Count == 1) + if(Length == 1) { - return items[0].ToString(); + return this.First().ToString(); } else { - StringBuilder sb = new StringBuilder(); - sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); - for(int i = 0; i < items.Count; i++) - { - var tensor = items[i]; - sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); - } - sb.Append("]\n"); - return sb.ToString(); + return $"Totally {Length} tensors: {base.ToString()}"; } } public void Dispose() { - foreach (var item in items) - item.Dispose(); + foreach (var tensor in this) + tensor.Dispose(); } } } diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs index ab6f56b3..3ba3ce78 100644 --- a/src/TensorFlowNET.Core/Util/nest.py.cs +++ b/src/TensorFlowNET.Core/Util/nest.py.cs @@ -36,6 +36,7 @@ namespace Tensorflow.Util // (np.array([3, 4]), tf.constant([3, 4])))` // + [Obsolete] public static class nest { @@ -170,39 +171,6 @@ namespace Tensorflow.Util throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); } - public static bool is_nested(object obj) - { - // Refer to https://www.tensorflow.org/api_docs/python/tf/nest - //if (obj is IList || obj is IDictionary || obj is ITuple) - // return true; - if (obj is IList || obj is IDictionary) - return true; - - if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType - || obj is ISet || obj is ISet || obj is ISet) - return false; - - if (obj.GetType().IsNested) return true; - // Check if the object is an IEnumerable - if (obj is IEnumerable) - { - // If it is, check if it is a nested structure - foreach (object item in (IEnumerable)obj) - { - if (is_nested(item)) - { - return true; - } - } - return true; - } - else - { - // If it is not, return false - return false; - } - } - /// /// Yields the next value from the given iterable. ///