@@ -22,5 +22,12 @@ namespace Tensorflow.Common.Extensions | |||
{ | |||
return new Tensors(tensors); | |||
} | |||
public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||
{ | |||
first = values.Item1; | |||
second = values.Item2; | |||
third = values.Item3; | |||
} | |||
} | |||
} |
@@ -0,0 +1,33 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Types; | |||
namespace Tensorflow.Common.Extensions | |||
{ | |||
public static class NestExtensions | |||
{ | |||
public static Tensors ToTensors(this INestable<Tensor> tensors) | |||
{ | |||
return new Tensors(tensors.AsNest()); | |||
} | |||
public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||
{ | |||
return Tensors.FromNest(tensors); | |||
} | |||
/// <summary> | |||
/// If the nested object is already a nested type, this function could reduce it. | |||
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||
/// </summary> | |||
/// <typeparam name="TIn"></typeparam> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="input"></param> | |||
/// <returns></returns> | |||
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||
{ | |||
return Nest<TOut>.ReduceFrom(input); | |||
} | |||
} | |||
} |
@@ -5,7 +5,7 @@ using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public class GeneralizedTensorShape: IEnumerable<long?[]> | |||
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?> | |||
{ | |||
public TensorShapeConfig[] Shapes { get; set; } | |||
/// <summary> | |||
@@ -63,6 +63,57 @@ namespace Tensorflow.Common.Types | |||
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | |||
} | |||
public IEnumerable<long?> Flatten() | |||
{ | |||
List<long?> result = new List<long?>(); | |||
foreach(var shapeConfig in Shapes) | |||
{ | |||
result.AddRange(shapeConfig.Items); | |||
} | |||
return result; | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func) | |||
{ | |||
List<Nest<TOut>> lists = new(); | |||
foreach(var shapeConfig in Shapes) | |||
{ | |||
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x))))); | |||
} | |||
return new Nest<TOut>(lists); | |||
} | |||
public Nest<long?> AsNest() | |||
{ | |||
Nest<long?> DealWithSingleShape(TensorShapeConfig config) | |||
{ | |||
if (config.Items.Length == 0) | |||
{ | |||
return Nest<long?>.Empty; | |||
} | |||
else if (config.Items.Length == 1) | |||
{ | |||
return new Nest<long?>(config.Items[0]); | |||
} | |||
else | |||
{ | |||
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x))); | |||
} | |||
} | |||
if(Shapes.Length == 0) | |||
{ | |||
return Nest<long?>.Empty; | |||
} | |||
else if(Shapes.Length == 1) | |||
{ | |||
return DealWithSingleShape(Shapes[0]); | |||
} | |||
else | |||
{ | |||
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s))); | |||
} | |||
} | |||
public IEnumerator<long?[]> GetEnumerator() | |||
{ | |||
foreach (var shape in Shapes) | |||
@@ -0,0 +1,27 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// This interface indicates that a class may have a nested structure and provide | |||
/// methods to manipulate with the structure. | |||
/// </summary> | |||
public interface INestStructure<T>: INestable<T> | |||
{ | |||
/// <summary> | |||
/// Flatten the Nestable object. Node that if the object contains only one value, | |||
/// it will be flattened to an enumerable with one element. | |||
/// </summary> | |||
/// <returns></returns> | |||
IEnumerable<T> Flatten(); | |||
/// <summary> | |||
/// Construct a new object with the same nested structure. | |||
/// </summary> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="func"></param> | |||
/// <returns></returns> | |||
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||
} | |||
} |
@@ -0,0 +1,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public interface INestable<T> | |||
{ | |||
Nest<T> AsNest(); | |||
} | |||
} |
@@ -0,0 +1,62 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public static class Nest | |||
{ | |||
/// <summary> | |||
/// Pack the flat items to a nested sequence by the template. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="template"></param> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems) | |||
{ | |||
return template.AsNest().PackSequence(flatItems); | |||
} | |||
/// <summary> | |||
/// Pack the flat items to a nested sequence by the template. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="template"></param> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||
{ | |||
return template.AsNest().PackSequence(flatItems.ToArray()); | |||
} | |||
/// <summary> | |||
/// Flatten the nested object. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <param name="nestedObject"></param> | |||
/// <returns></returns> | |||
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||
{ | |||
return nestedObject.AsNest().Flatten(); | |||
} | |||
/// <summary> | |||
/// Map the structure with specified function. | |||
/// </summary> | |||
/// <typeparam name="TIn"></typeparam> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="func"></param> | |||
/// <param name="nestedObject"></param> | |||
/// <returns></returns> | |||
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||
{ | |||
return nestedObject.AsNest().MapStructure(func); | |||
} | |||
public static bool IsNested<T>(INestable<T> obj) | |||
{ | |||
return obj.AsNest().IsNested(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,458 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Extensions; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public enum NestType | |||
{ | |||
Empty, | |||
Node, | |||
List, | |||
Dictionary | |||
} | |||
/// <summary> | |||
/// A nested structure which may inclulde value, list and dictionary. | |||
/// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||
/// its order is depth-first. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||
{ | |||
private static readonly Nest<T> _empty = new Nest<T>() | |||
{ | |||
NestType = NestType.Empty, | |||
}; | |||
public static Nest<T> Empty => _empty; | |||
public NestType NestType { get; protected set; } | |||
public string? Name { get; set; } | |||
public T? Value { get; protected set; } | |||
public List<Nest<T>>? ListValue { get; protected set; } | |||
public Dictionary<string, Nest<T>>? DictValue { get; protected set; } | |||
protected Nest() { } | |||
public Nest(T value, string? name = null) | |||
{ | |||
Value = value; | |||
Name = name; | |||
NestType = NestType.Node; | |||
} | |||
public Nest(IEnumerable<Nest<T>> values, string? name = null) | |||
{ | |||
ListValue = values.ToList(); | |||
Name = name; | |||
NestType = NestType.List; | |||
} | |||
public Nest(Dictionary<string, Nest<T>> value, string? name = null) | |||
{ | |||
DictValue = value; | |||
Name = name; | |||
NestType = NestType.Dictionary; | |||
} | |||
public Nest(Nest<T> other) | |||
{ | |||
NestType = other.NestType; | |||
Value = other.Value; | |||
DictValue = other.DictValue; | |||
ListValue = other.ListValue; | |||
Name = other.Name; | |||
} | |||
public virtual IEnumerable<T> Flatten() | |||
{ | |||
return FlattenInternal(this); | |||
} | |||
public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
{ | |||
return MapStructureInternal(func); | |||
} | |||
/// <summary> | |||
/// Pack the flat items to a nested sequence by the template. | |||
/// </summary> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public virtual Nest<T> PackSequence(T[] flatItems) | |||
{ | |||
if(flatItems.Length == 0) | |||
{ | |||
return Nest<T>.Empty; | |||
} | |||
int index = 0; | |||
return PackSequenceInternal(this, flatItems, ref index); | |||
} | |||
private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index) | |||
{ | |||
if(template.NestType == NestType.Node) | |||
{ | |||
if(index >= flatItems.Length) | |||
{ | |||
throw new InvalidArgumentError("The template and flat items are not matched."); | |||
} | |||
return new Nest<T>(flatItems[index++]); | |||
} | |||
else if(template.NestType == NestType.List) | |||
{ | |||
List<Nest<T>> nestedObjects = new List<Nest<T>>(); | |||
for (int i = 0; i < template.ListValue!.Count; i++) | |||
{ | |||
nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); | |||
} | |||
return new Nest<T>(nestedObjects); | |||
} | |||
else if(template.NestType == NestType.Node) | |||
{ | |||
Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>(); | |||
foreach(var (key, value) in template.DictValue!) | |||
{ | |||
dict[key] = PackSequenceInternal(value, flatItems, ref index); | |||
} | |||
return new Nest<T>(dict); | |||
} | |||
// Consider Empty as invalid type. | |||
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||
} | |||
public virtual Nest<T> AsNest() | |||
{ | |||
return this; | |||
} | |||
public virtual Nest<T> MergeWith(Nest<T>? other) | |||
{ | |||
if(other is null || other == Nest<T>.Empty) | |||
{ | |||
return this; | |||
} | |||
if(this == Nest<T>.Empty) | |||
{ | |||
return other; | |||
} | |||
if(NestType == NestType.Node && other.NestType == NestType.Node) | |||
{ | |||
return new Nest<T>(new Nest<T>[] { this, other }); | |||
} | |||
else if(NestType == NestType.List && other.NestType == NestType.List) | |||
{ | |||
return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||
} | |||
else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||
{ | |||
return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||
} | |||
else | |||
{ | |||
return new Nest<T>(new Nest<T>[] { this, other }); | |||
} | |||
} | |||
/// <summary> | |||
/// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||
/// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||
/// </summary> | |||
/// <returns></returns> | |||
public bool IsNested() | |||
{ | |||
if(NestType is NestType.Empty or NestType.Node) | |||
{ | |||
return false; | |||
} | |||
else if(NestType is NestType.List) | |||
{ | |||
foreach(var item in ListValue!) | |||
{ | |||
if(item.NestType is NestType.List or NestType.Dictionary) | |||
{ | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
else | |||
{ | |||
foreach (var item in DictValue!.Values) | |||
{ | |||
if (item.NestType is NestType.List or NestType.Dictionary) | |||
{ | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
} | |||
[Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||
public T this[int index] | |||
{ | |||
get | |||
{ | |||
bool success = FindInternal(this, index, out var result); | |||
if (success) | |||
{ | |||
return result; | |||
} | |||
else | |||
{ | |||
throw new IndexOutOfRangeException(); | |||
} | |||
} | |||
set | |||
{ | |||
bool success = SetInternal(this, index, value); | |||
if (!success) | |||
{ | |||
throw new IndexOutOfRangeException(); | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||
/// to `Nest[T]`. | |||
/// </summary> | |||
/// <typeparam name="TOut"></typeparam> | |||
/// <param name="input"></param> | |||
/// <returns></returns> | |||
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||
{ | |||
var nested = input.AsNest(); | |||
return ReduceInternal(nested); | |||
} | |||
private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
{ | |||
if(node.NestType == NestType.Empty) | |||
{ | |||
return Nest<T>.Empty; | |||
} | |||
else if(node.NestType == NestType.Node) | |||
{ | |||
return node.Value!.AsNest(); | |||
} | |||
else if(node.NestType == NestType.List) | |||
{ | |||
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x))); | |||
} | |||
else // Dictionary type | |||
{ | |||
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); | |||
} | |||
} | |||
private static bool FindInternal(Nest<T> node, int index, out T? result) | |||
{ | |||
if (node.NestType == NestType.Node) | |||
{ | |||
if(index == 0) | |||
{ | |||
result = node.Value!; | |||
return true; | |||
} | |||
result = default(T); | |||
return false; | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
foreach (var item in node.ListValue!) | |||
{ | |||
if(index == 0) | |||
{ | |||
return FindInternal(item, index, out result); | |||
} | |||
index--; | |||
} | |||
result = default(T); | |||
return false; | |||
} | |||
else if(node.NestType == NestType.Dictionary) | |||
{ | |||
foreach (var item in node.DictValue!.Values) | |||
{ | |||
if (index == 0) | |||
{ | |||
return FindInternal(item, index, out result); | |||
} | |||
index--; | |||
} | |||
result = default(T); | |||
return false; | |||
} | |||
else | |||
{ | |||
result = default(T); | |||
return false; | |||
} | |||
} | |||
private static bool SetInternal(Nest<T> node, int index, T newValue) | |||
{ | |||
if (node.NestType == NestType.Node) | |||
{ | |||
if (index == 0) | |||
{ | |||
node.Value = newValue; | |||
return true; | |||
} | |||
return false; | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
foreach (var item in node.ListValue!) | |||
{ | |||
if (index == 0) | |||
{ | |||
return SetInternal(item, index, newValue); | |||
} | |||
index--; | |||
} | |||
return false; | |||
} | |||
else if (node.NestType == NestType.Dictionary) | |||
{ | |||
foreach (var item in node.DictValue!.Values) | |||
{ | |||
if (index == 0) | |||
{ | |||
return SetInternal(item, index, newValue); | |||
} | |||
index--; | |||
} | |||
return false; | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||
{ | |||
if (node.NestType == NestType.Node) | |||
{ | |||
yield return node.Value!; | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
foreach (var item in node.ListValue!) | |||
{ | |||
foreach(var val in FlattenInternal(item)) | |||
{ | |||
yield return val; | |||
} | |||
} | |||
} | |||
else if (node.NestType == NestType.Dictionary) | |||
{ | |||
foreach (var item in node.DictValue!.Values) | |||
{ | |||
foreach (var val in FlattenInternal(item)) | |||
{ | |||
yield return val; | |||
} | |||
} | |||
} | |||
} | |||
private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||
{ | |||
if (NestType == NestType.Node) | |||
{ | |||
return new Nest<TOut>(func(Value!)); | |||
} | |||
else if (NestType == NestType.List) | |||
{ | |||
List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||
foreach (var item in ListValue!) | |||
{ | |||
outs.Add(item.MapStructureInternal(func)); | |||
} | |||
return new Nest<TOut>(outs); | |||
} | |||
else if (NestType == NestType.Dictionary) | |||
{ | |||
Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>(); | |||
foreach (var (key, value) in DictValue!) | |||
{ | |||
outs.Add(key, value.MapStructureInternal(func)); | |||
} | |||
return new Nest<TOut>(outs); | |||
} | |||
else | |||
{ | |||
return Nest<TOut>.Empty; | |||
} | |||
} | |||
public IEnumerator<T> GetEnumerator() | |||
{ | |||
return Flatten().GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
public override string ToString() | |||
{ | |||
StringBuilder sb = new StringBuilder(); | |||
sb.Append("("); | |||
WriteString(this, sb); | |||
sb.Append(")"); | |||
return sb.ToString(); | |||
} | |||
private static void WriteString(Nest<T> node, StringBuilder sb) | |||
{ | |||
if (!string.IsNullOrEmpty(node.Name)) | |||
{ | |||
sb.Append($"{node.Name}: "); | |||
} | |||
if (node.NestType == NestType.Node) | |||
{ | |||
sb.Append(node.Value!.ToString()); | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
sb.Append("["); | |||
for(int i = 0; i < node.ListValue!.Count; i++) | |||
{ | |||
WriteString(node.ListValue![i], sb); | |||
if(i != node.ListValue!.Count - 1) | |||
{ | |||
sb.Append(", "); | |||
} | |||
} | |||
sb.Append("]"); | |||
} | |||
else if (node.NestType == NestType.Dictionary) | |||
{ | |||
sb.Append("{"); | |||
int count = node.DictValue!.Count; | |||
int i = 0; | |||
foreach (var (key, value) in node.DictValue!) | |||
{ | |||
sb.Append($"{key}: "); | |||
WriteString(value, sb); | |||
if (i != count - 1) | |||
{ | |||
sb.Append(", "); | |||
} | |||
i++; | |||
} | |||
sb.Append("}"); | |||
} | |||
else | |||
{ | |||
sb.Append("<empty>"); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,99 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||
{ | |||
public IDictionary<TKey, TValue> Value { get; set; } | |||
public NestDictionary(IDictionary<TKey, TValue> dict) | |||
{ | |||
Value = dict; | |||
} | |||
public IEnumerable<TValue> Flatten() | |||
{ | |||
return Value.Select(x => x.Value); | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||
{ | |||
return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||
} | |||
public Nest<TValue> AsNest() | |||
{ | |||
return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||
} | |||
// Required IDictionary<TKey, TValue> members | |||
public int Count => Value.Count; | |||
public bool IsReadOnly => Value.IsReadOnly; | |||
public ICollection<TKey> Keys => Value.Keys; | |||
public ICollection<TValue> Values => Value.Values; | |||
public void Add(TKey key, TValue value) | |||
{ | |||
Value.Add(key, value); | |||
} | |||
public void Add(KeyValuePair<TKey, TValue> item) | |||
{ | |||
Value.Add(item); | |||
} | |||
public void Clear() | |||
{ | |||
Value.Clear(); | |||
} | |||
public bool Contains(KeyValuePair<TKey, TValue> item) | |||
{ | |||
return Value.Contains(item); | |||
} | |||
public bool ContainsKey(TKey key) | |||
{ | |||
return Value.ContainsKey(key); | |||
} | |||
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||
{ | |||
Value.CopyTo(array, arrayIndex); | |||
} | |||
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||
{ | |||
return Value.GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
public bool Remove(TKey key) | |||
{ | |||
return Value.Remove(key); | |||
} | |||
public bool Remove(KeyValuePair<TKey, TValue> item) | |||
{ | |||
return Value.Remove(item); | |||
} | |||
public bool TryGetValue(TKey key, out TValue value) | |||
{ | |||
return Value.TryGetValue(key, out value); | |||
} | |||
// Optional IDictionary<TKey, TValue> members | |||
public TValue this[TKey key] | |||
{ | |||
get => Value[key]; | |||
set => Value[key] = value; | |||
} | |||
} | |||
} |
@@ -0,0 +1,43 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// The implementation of a list that support nest structure, in which the depth is 1. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||
{ | |||
public List<T> Value { get; set; } | |||
public NestList(IEnumerable<T> values) | |||
{ | |||
Value = new List<T>(values); | |||
} | |||
public IEnumerable<T> Flatten() | |||
{ | |||
return Value; | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
{ | |||
return new NestList<TOut>(Value.Select(x => func(x))); | |||
} | |||
public Nest<T> AsNest() | |||
{ | |||
return new Nest<T>(Value.Select(x => new Nest<T>(x))); | |||
} | |||
// Enumerator implementation | |||
public IEnumerator<T> GetEnumerator() | |||
{ | |||
return Value.GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
{ | |||
return GetEnumerator(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,32 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// A nested structure with only one element. | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
public class NestNode<T> : INestStructure<T> | |||
{ | |||
public T Value { get; set; } | |||
public NestNode(T value) | |||
{ | |||
Value = value; | |||
} | |||
public IEnumerable<T> Flatten() | |||
{ | |||
yield return Value; | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
{ | |||
return new NestNode<TOut>(func(Value)); | |||
} | |||
public Nest<T> AsNest() | |||
{ | |||
return new Nest<T>(Value); | |||
} | |||
} | |||
} |
@@ -3,6 +3,7 @@ using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Common.Types; | |||
namespace Tensorflow | |||
{ | |||
@@ -13,16 +14,14 @@ namespace Tensorflow | |||
/// and Tensor[] from Tensors implicitily. | |||
/// It works for tuple and scalar as well. | |||
/// </summary> | |||
public class Tensors : IEnumerable<Tensor>, IDisposable | |||
public sealed class Tensors : Nest<Tensor>, IDisposable | |||
{ | |||
List<Tensor> items = new List<Tensor>(); | |||
public TF_DataType dtype => items.First().dtype; | |||
public Shape shape => items.First().shape; | |||
public int rank => items.First().rank; | |||
public Graph graph => items.First().graph; | |||
public TF_DataType dtype => this.First().dtype; | |||
public Shape shape => this.First().shape; | |||
public int rank => this.First().rank; | |||
public Graph graph => this.First().graph; | |||
public bool IsList { get; set; } | |||
public int Length => items.Count(); | |||
public int Length => this.Count(); | |||
/// <summary> | |||
/// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. | |||
/// </summary> | |||
@@ -35,7 +34,7 @@ namespace Tensorflow | |||
throw new ValueError("Tensors with more than one tensor cannot be " + | |||
"implicitly converted to Tensor."); | |||
} | |||
return items.First(); | |||
return this.First(); | |||
} | |||
} | |||
@@ -52,150 +51,194 @@ namespace Tensorflow | |||
throw new ValueError($"Tensors with {Length} tensor cannot be " + | |||
"implicitly converted to Tensor."); | |||
} | |||
return items.FirstOrDefault(); | |||
return this.FirstOrDefault(); | |||
} | |||
} | |||
public Tensor this[int index] | |||
public Tensor this[params string[] slices] | |||
=> this.First()[slices]; | |||
public Tensors(Tensor tensor) : base(tensor) | |||
{ | |||
get => items[index]; | |||
set => items[index] = value; | |||
} | |||
public Tensor this[params string[] slices] | |||
=> items.First()[slices]; | |||
public Tensors(params Tensor[] tensors) | |||
private Tensors(Nest<Tensor> nested) : base(nested) | |||
{ | |||
items.AddRange(tensors); | |||
} | |||
public Tensors(IEnumerable<Tensor> tensors) | |||
public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||
{ | |||
items.AddRange(tensors); | |||
} | |||
public Tensors(NDArray nd) | |||
public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||
{ | |||
items.Add(ops.convert_to_tensor(nd)); | |||
} | |||
public IEnumerator<Tensor> GetEnumerator() | |||
public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | |||
{ | |||
foreach (var tensor in items) | |||
yield return tensor; | |||
} | |||
public bool IsSingle() | |||
{ | |||
return Length == 1; | |||
} | |||
public new Tensors MergeWith(Nest<Tensor>? other) | |||
{ | |||
return FromNest(base.MergeWith(other)); | |||
} | |||
[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
public void Add(Tensor tensor) | |||
=> items.Add(tensor); | |||
{ | |||
if(NestType == NestType.Dictionary) | |||
{ | |||
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
} | |||
else if(NestType == NestType.Node) | |||
{ | |||
NestType = NestType.List; | |||
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) }; | |||
Value = null; | |||
} | |||
else | |||
{ | |||
ListValue.Add(new Nest<Tensor>(tensor)); | |||
} | |||
} | |||
[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||
"some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] | |||
public void AddRange(IEnumerable<Tensor> tensors) | |||
=> items.AddRange(tensors); | |||
{ | |||
if (NestType == NestType.Dictionary) | |||
{ | |||
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
} | |||
else if (NestType == NestType.Node) | |||
{ | |||
NestType = NestType.List; | |||
ListValue = new() { new Nest<Tensor>(Value) }; | |||
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
Value = null; | |||
} | |||
else | |||
{ | |||
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
} | |||
} | |||
[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + | |||
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||
public void Insert(int index, Tensor tensor) | |||
=> items.Insert(index, tensor); | |||
IEnumerator IEnumerable.GetEnumerator() | |||
=> GetEnumerator(); | |||
{ | |||
if (NestType == NestType.List) | |||
{ | |||
ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
} | |||
else if(NestType == NestType.Node) | |||
{ | |||
NestType = NestType.List; | |||
ListValue = new() { new Nest<Tensor>(Value) }; | |||
ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
Value = null; | |||
} | |||
else | |||
{ | |||
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||
} | |||
} | |||
public string[] StringData() | |||
{ | |||
EnsureSingleTensor(this, "nnumpy"); | |||
return this[0].StringData(); | |||
return Single.StringData(); | |||
} | |||
public string StringData(int index) | |||
{ | |||
EnsureSingleTensor(this, "nnumpy"); | |||
return this[0].StringData(index); | |||
return Single.StringData(index); | |||
} | |||
public NDArray numpy() | |||
{ | |||
EnsureSingleTensor(this, "nnumpy"); | |||
return this[0].numpy(); | |||
return Single.numpy(); | |||
} | |||
[Obsolete] | |||
public T[] ToArray<T>() where T: unmanaged | |||
{ | |||
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||
return this[0].ToArray<T>(); | |||
return Single.ToArray<T>(); | |||
} | |||
#region Explicit Conversions | |||
public unsafe static explicit operator bool(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||
return (bool)tensor[0]; | |||
return (bool)tensor.Single; | |||
} | |||
public unsafe static explicit operator sbyte(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||
return (sbyte)tensor[0]; | |||
return (sbyte)tensor.Single; | |||
} | |||
public unsafe static explicit operator byte(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
return (byte)tensor[0]; | |||
return (byte)tensor.Single; | |||
} | |||
public unsafe static explicit operator ushort(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||
return (ushort)tensor[0]; | |||
return (ushort)tensor.Single; | |||
} | |||
public unsafe static explicit operator short(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to short"); | |||
return (short)tensor[0]; | |||
return (short)tensor.Single; | |||
} | |||
public unsafe static explicit operator int(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to int"); | |||
return (int)tensor[0]; | |||
return (int)tensor.Single; | |||
} | |||
public unsafe static explicit operator uint(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||
return (uint)tensor[0]; | |||
return (uint)tensor.Single; | |||
} | |||
public unsafe static explicit operator long(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to long"); | |||
return (long)tensor[0]; | |||
return (long)tensor.Single; | |||
} | |||
public unsafe static explicit operator ulong(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||
return (ulong)tensor[0]; | |||
return (ulong)tensor.Single; | |||
} | |||
public unsafe static explicit operator float(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||
return (byte)tensor[0]; | |||
return (byte)tensor.Single; | |||
} | |||
public unsafe static explicit operator double(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to double"); | |||
return (double)tensor[0]; | |||
return (double)tensor.Single; | |||
} | |||
public unsafe static explicit operator string(Tensors tensor) | |||
{ | |||
EnsureSingleTensor(tensor, "explicit conversion to string"); | |||
return (string)tensor[0]; | |||
return (string)tensor.Single; | |||
} | |||
public static explicit operator object[](Tensors tensors) | |||
=> tensors.items.ToArray(); | |||
=> tensors.Flatten().ToArray(); | |||
#endregion | |||
#region Implicit Conversions | |||
@@ -219,52 +262,40 @@ namespace Tensorflow | |||
=> tensors?.SingleOrNull; | |||
public static implicit operator Tensor[](Tensors tensors) | |||
=> tensors.items.ToArray(); | |||
=> tensors.Flatten().ToArray(); | |||
#endregion | |||
public void Deconstruct(out Tensor a, out Tensors? b) | |||
public static Tensors? FromNest(Nest<Tensor> nested) | |||
{ | |||
a = items[0]; | |||
b = Length == 1? null : new Tensors(items.Skip(1)); | |||
if(nested == Nest<Tensor>.Empty) | |||
{ | |||
return null; | |||
} | |||
return new Tensors(nested); | |||
} | |||
private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||
public void Deconstruct(out Tensor a, out Tensors? b) | |||
{ | |||
if(tensors.Length == 0) | |||
{ | |||
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||
} | |||
else if(tensors.Length > 1) | |||
{ | |||
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||
} | |||
a = this.First(); | |||
b = Length == 1? null : new Tensors(this.Skip(1)); | |||
} | |||
public override string ToString() | |||
{ | |||
if(items.Count == 1) | |||
if(Length == 1) | |||
{ | |||
return items[0].ToString(); | |||
return this.First().ToString(); | |||
} | |||
else | |||
{ | |||
StringBuilder sb = new StringBuilder(); | |||
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||
for(int i = 0; i < items.Count; i++) | |||
{ | |||
var tensor = items[i]; | |||
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||
} | |||
sb.Append("]\n"); | |||
return sb.ToString(); | |||
return $"Totally {Length} tensors: {base.ToString()}"; | |||
} | |||
} | |||
public void Dispose() | |||
{ | |||
foreach (var item in items) | |||
item.Dispose(); | |||
foreach (var tensor in this) | |||
tensor.Dispose(); | |||
} | |||
} | |||
} |
@@ -36,6 +36,7 @@ namespace Tensorflow.Util | |||
// (np.array([3, 4]), tf.constant([3, 4])))` | |||
// | |||
[Obsolete] | |||
public static class nest | |||
{ | |||
@@ -170,39 +171,6 @@ namespace Tensorflow.Util | |||
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||
} | |||
public static bool is_nested(object obj) | |||
{ | |||
// Refer to https://www.tensorflow.org/api_docs/python/tf/nest | |||
//if (obj is IList || obj is IDictionary || obj is ITuple) | |||
// return true; | |||
if (obj is IList || obj is IDictionary) | |||
return true; | |||
if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType | |||
|| obj is ISet<int> || obj is ISet<float> || obj is ISet<double>) | |||
return false; | |||
if (obj.GetType().IsNested) return true; | |||
// Check if the object is an IEnumerable | |||
if (obj is IEnumerable) | |||
{ | |||
// If it is, check if it is a nested structure | |||
foreach (object item in (IEnumerable)obj) | |||
{ | |||
if (is_nested(item)) | |||
{ | |||
return true; | |||
} | |||
} | |||
return true; | |||
} | |||
else | |||
{ | |||
// If it is not, return false | |||
return false; | |||
} | |||
} | |||
/// <summary> | |||
/// Yields the next value from the given iterable. | |||
/// </summary> | |||