@@ -22,5 +22,12 @@ namespace Tensorflow.Common.Extensions | |||||
{ | { | ||||
return new Tensors(tensors); | return new Tensors(tensors); | ||||
} | } | ||||
public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third) | |||||
{ | |||||
first = values.Item1; | |||||
second = values.Item2; | |||||
third = values.Item3; | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,33 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Common.Types; | |||||
namespace Tensorflow.Common.Extensions | |||||
{ | |||||
public static class NestExtensions | |||||
{ | |||||
public static Tensors ToTensors(this INestable<Tensor> tensors) | |||||
{ | |||||
return new Tensors(tensors.AsNest()); | |||||
} | |||||
public static Tensors? ToTensors(this Nest<Tensor> tensors) | |||||
{ | |||||
return Tensors.FromNest(tensors); | |||||
} | |||||
/// <summary> | |||||
/// If the nested object is already a nested type, this function could reduce it. | |||||
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`. | |||||
/// </summary> | |||||
/// <typeparam name="TIn"></typeparam> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="input"></param> | |||||
/// <returns></returns> | |||||
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut> | |||||
{ | |||||
return Nest<TOut>.ReduceFrom(input); | |||||
} | |||||
} | |||||
} |
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow.Common.Types | namespace Tensorflow.Common.Types | ||||
{ | { | ||||
public class GeneralizedTensorShape: IEnumerable<long?[]> | |||||
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?> | |||||
{ | { | ||||
public TensorShapeConfig[] Shapes { get; set; } | public TensorShapeConfig[] Shapes { get; set; } | ||||
/// <summary> | /// <summary> | ||||
@@ -63,6 +63,57 @@ namespace Tensorflow.Common.Types | |||||
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | ||||
} | } | ||||
public IEnumerable<long?> Flatten() | |||||
{ | |||||
List<long?> result = new List<long?>(); | |||||
foreach(var shapeConfig in Shapes) | |||||
{ | |||||
result.AddRange(shapeConfig.Items); | |||||
} | |||||
return result; | |||||
} | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func) | |||||
{ | |||||
List<Nest<TOut>> lists = new(); | |||||
foreach(var shapeConfig in Shapes) | |||||
{ | |||||
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x))))); | |||||
} | |||||
return new Nest<TOut>(lists); | |||||
} | |||||
public Nest<long?> AsNest() | |||||
{ | |||||
Nest<long?> DealWithSingleShape(TensorShapeConfig config) | |||||
{ | |||||
if (config.Items.Length == 0) | |||||
{ | |||||
return Nest<long?>.Empty; | |||||
} | |||||
else if (config.Items.Length == 1) | |||||
{ | |||||
return new Nest<long?>(config.Items[0]); | |||||
} | |||||
else | |||||
{ | |||||
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x))); | |||||
} | |||||
} | |||||
if(Shapes.Length == 0) | |||||
{ | |||||
return Nest<long?>.Empty; | |||||
} | |||||
else if(Shapes.Length == 1) | |||||
{ | |||||
return DealWithSingleShape(Shapes[0]); | |||||
} | |||||
else | |||||
{ | |||||
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s))); | |||||
} | |||||
} | |||||
public IEnumerator<long?[]> GetEnumerator() | public IEnumerator<long?[]> GetEnumerator() | ||||
{ | { | ||||
foreach (var shape in Shapes) | foreach (var shape in Shapes) | ||||
@@ -0,0 +1,27 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// This interface indicates that a class may have a nested structure and provide | |||||
/// methods to manipulate with the structure. | |||||
/// </summary> | |||||
public interface INestStructure<T>: INestable<T> | |||||
{ | |||||
/// <summary> | |||||
/// Flatten the Nestable object. Node that if the object contains only one value, | |||||
/// it will be flattened to an enumerable with one element. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
IEnumerable<T> Flatten(); | |||||
/// <summary> | |||||
/// Construct a new object with the same nested structure. | |||||
/// </summary> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="func"></param> | |||||
/// <returns></returns> | |||||
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func); | |||||
} | |||||
} |
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public interface INestable<T> | |||||
{ | |||||
Nest<T> AsNest(); | |||||
} | |||||
} |
@@ -0,0 +1,62 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public static class Nest | |||||
{ | |||||
/// <summary> | |||||
/// Pack the flat items to a nested sequence by the template. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="template"></param> | |||||
/// <param name="flatItems"></param> | |||||
/// <returns></returns> | |||||
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems) | |||||
{ | |||||
return template.AsNest().PackSequence(flatItems); | |||||
} | |||||
/// <summary> | |||||
/// Pack the flat items to a nested sequence by the template. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="template"></param> | |||||
/// <param name="flatItems"></param> | |||||
/// <returns></returns> | |||||
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems) | |||||
{ | |||||
return template.AsNest().PackSequence(flatItems.ToArray()); | |||||
} | |||||
/// <summary> | |||||
/// Flatten the nested object. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
/// <param name="nestedObject"></param> | |||||
/// <returns></returns> | |||||
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject) | |||||
{ | |||||
return nestedObject.AsNest().Flatten(); | |||||
} | |||||
/// <summary> | |||||
/// Map the structure with specified function. | |||||
/// </summary> | |||||
/// <typeparam name="TIn"></typeparam> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="func"></param> | |||||
/// <param name="nestedObject"></param> | |||||
/// <returns></returns> | |||||
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject) | |||||
{ | |||||
return nestedObject.AsNest().MapStructure(func); | |||||
} | |||||
public static bool IsNested<T>(INestable<T> obj) | |||||
{ | |||||
return obj.AsNest().IsNested(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,458 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Common.Extensions; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public enum NestType | |||||
{ | |||||
Empty, | |||||
Node, | |||||
List, | |||||
Dictionary | |||||
} | |||||
/// <summary> | |||||
/// A nested structure which may inclulde value, list and dictionary. | |||||
/// Note that dictionary does not ensure the data order. When using it as IEnumerable, | |||||
/// its order is depth-first. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
public class Nest<T> : INestStructure<T>, IEnumerable<T> | |||||
{ | |||||
private static readonly Nest<T> _empty = new Nest<T>() | |||||
{ | |||||
NestType = NestType.Empty, | |||||
}; | |||||
public static Nest<T> Empty => _empty; | |||||
public NestType NestType { get; protected set; } | |||||
public string? Name { get; set; } | |||||
public T? Value { get; protected set; } | |||||
public List<Nest<T>>? ListValue { get; protected set; } | |||||
public Dictionary<string, Nest<T>>? DictValue { get; protected set; } | |||||
protected Nest() { } | |||||
public Nest(T value, string? name = null) | |||||
{ | |||||
Value = value; | |||||
Name = name; | |||||
NestType = NestType.Node; | |||||
} | |||||
public Nest(IEnumerable<Nest<T>> values, string? name = null) | |||||
{ | |||||
ListValue = values.ToList(); | |||||
Name = name; | |||||
NestType = NestType.List; | |||||
} | |||||
public Nest(Dictionary<string, Nest<T>> value, string? name = null) | |||||
{ | |||||
DictValue = value; | |||||
Name = name; | |||||
NestType = NestType.Dictionary; | |||||
} | |||||
public Nest(Nest<T> other) | |||||
{ | |||||
NestType = other.NestType; | |||||
Value = other.Value; | |||||
DictValue = other.DictValue; | |||||
ListValue = other.ListValue; | |||||
Name = other.Name; | |||||
} | |||||
public virtual IEnumerable<T> Flatten() | |||||
{ | |||||
return FlattenInternal(this); | |||||
} | |||||
public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
{ | |||||
return MapStructureInternal(func); | |||||
} | |||||
/// <summary> | |||||
/// Pack the flat items to a nested sequence by the template. | |||||
/// </summary> | |||||
/// <param name="flatItems"></param> | |||||
/// <returns></returns> | |||||
public virtual Nest<T> PackSequence(T[] flatItems) | |||||
{ | |||||
if(flatItems.Length == 0) | |||||
{ | |||||
return Nest<T>.Empty; | |||||
} | |||||
int index = 0; | |||||
return PackSequenceInternal(this, flatItems, ref index); | |||||
} | |||||
private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index) | |||||
{ | |||||
if(template.NestType == NestType.Node) | |||||
{ | |||||
if(index >= flatItems.Length) | |||||
{ | |||||
throw new InvalidArgumentError("The template and flat items are not matched."); | |||||
} | |||||
return new Nest<T>(flatItems[index++]); | |||||
} | |||||
else if(template.NestType == NestType.List) | |||||
{ | |||||
List<Nest<T>> nestedObjects = new List<Nest<T>>(); | |||||
for (int i = 0; i < template.ListValue!.Count; i++) | |||||
{ | |||||
nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); | |||||
} | |||||
return new Nest<T>(nestedObjects); | |||||
} | |||||
else if(template.NestType == NestType.Node) | |||||
{ | |||||
Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>(); | |||||
foreach(var (key, value) in template.DictValue!) | |||||
{ | |||||
dict[key] = PackSequenceInternal(value, flatItems, ref index); | |||||
} | |||||
return new Nest<T>(dict); | |||||
} | |||||
// Consider Empty as invalid type. | |||||
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||||
} | |||||
public virtual Nest<T> AsNest() | |||||
{ | |||||
return this; | |||||
} | |||||
public virtual Nest<T> MergeWith(Nest<T>? other) | |||||
{ | |||||
if(other is null || other == Nest<T>.Empty) | |||||
{ | |||||
return this; | |||||
} | |||||
if(this == Nest<T>.Empty) | |||||
{ | |||||
return other; | |||||
} | |||||
if(NestType == NestType.Node && other.NestType == NestType.Node) | |||||
{ | |||||
return new Nest<T>(new Nest<T>[] { this, other }); | |||||
} | |||||
else if(NestType == NestType.List && other.NestType == NestType.List) | |||||
{ | |||||
return new Nest<T>(this.ListValue!.Concat(other.ListValue!)); | |||||
} | |||||
else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary) | |||||
{ | |||||
return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value)); | |||||
} | |||||
else | |||||
{ | |||||
return new Nest<T>(new Nest<T>[] { this, other }); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not | |||||
/// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public bool IsNested() | |||||
{ | |||||
if(NestType is NestType.Empty or NestType.Node) | |||||
{ | |||||
return false; | |||||
} | |||||
else if(NestType is NestType.List) | |||||
{ | |||||
foreach(var item in ListValue!) | |||||
{ | |||||
if(item.NestType is NestType.List or NestType.Dictionary) | |||||
{ | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
else | |||||
{ | |||||
foreach (var item in DictValue!.Values) | |||||
{ | |||||
if (item.NestType is NestType.List or NestType.Dictionary) | |||||
{ | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
} | |||||
[Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")] | |||||
public T this[int index] | |||||
{ | |||||
get | |||||
{ | |||||
bool success = FindInternal(this, index, out var result); | |||||
if (success) | |||||
{ | |||||
return result; | |||||
} | |||||
else | |||||
{ | |||||
throw new IndexOutOfRangeException(); | |||||
} | |||||
} | |||||
set | |||||
{ | |||||
bool success = SetInternal(this, index, value); | |||||
if (!success) | |||||
{ | |||||
throw new IndexOutOfRangeException(); | |||||
} | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it | |||||
/// to `Nest[T]`. | |||||
/// </summary> | |||||
/// <typeparam name="TOut"></typeparam> | |||||
/// <param name="input"></param> | |||||
/// <returns></returns> | |||||
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||||
{ | |||||
var nested = input.AsNest(); | |||||
return ReduceInternal(nested); | |||||
} | |||||
private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||||
{ | |||||
if(node.NestType == NestType.Empty) | |||||
{ | |||||
return Nest<T>.Empty; | |||||
} | |||||
else if(node.NestType == NestType.Node) | |||||
{ | |||||
return node.Value!.AsNest(); | |||||
} | |||||
else if(node.NestType == NestType.List) | |||||
{ | |||||
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x))); | |||||
} | |||||
else // Dictionary type | |||||
{ | |||||
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); | |||||
} | |||||
} | |||||
private static bool FindInternal(Nest<T> node, int index, out T? result) | |||||
{ | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
if(index == 0) | |||||
{ | |||||
result = node.Value!; | |||||
return true; | |||||
} | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
foreach (var item in node.ListValue!) | |||||
{ | |||||
if(index == 0) | |||||
{ | |||||
return FindInternal(item, index, out result); | |||||
} | |||||
index--; | |||||
} | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
else if(node.NestType == NestType.Dictionary) | |||||
{ | |||||
foreach (var item in node.DictValue!.Values) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
return FindInternal(item, index, out result); | |||||
} | |||||
index--; | |||||
} | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
else | |||||
{ | |||||
result = default(T); | |||||
return false; | |||||
} | |||||
} | |||||
private static bool SetInternal(Nest<T> node, int index, T newValue) | |||||
{ | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
node.Value = newValue; | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
foreach (var item in node.ListValue!) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
return SetInternal(item, index, newValue); | |||||
} | |||||
index--; | |||||
} | |||||
return false; | |||||
} | |||||
else if (node.NestType == NestType.Dictionary) | |||||
{ | |||||
foreach (var item in node.DictValue!.Values) | |||||
{ | |||||
if (index == 0) | |||||
{ | |||||
return SetInternal(item, index, newValue); | |||||
} | |||||
index--; | |||||
} | |||||
return false; | |||||
} | |||||
else | |||||
{ | |||||
return false; | |||||
} | |||||
} | |||||
private static IEnumerable<T> FlattenInternal(Nest<T> node) | |||||
{ | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
yield return node.Value!; | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
foreach (var item in node.ListValue!) | |||||
{ | |||||
foreach(var val in FlattenInternal(item)) | |||||
{ | |||||
yield return val; | |||||
} | |||||
} | |||||
} | |||||
else if (node.NestType == NestType.Dictionary) | |||||
{ | |||||
foreach (var item in node.DictValue!.Values) | |||||
{ | |||||
foreach (var val in FlattenInternal(item)) | |||||
{ | |||||
yield return val; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func) | |||||
{ | |||||
if (NestType == NestType.Node) | |||||
{ | |||||
return new Nest<TOut>(func(Value!)); | |||||
} | |||||
else if (NestType == NestType.List) | |||||
{ | |||||
List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||||
foreach (var item in ListValue!) | |||||
{ | |||||
outs.Add(item.MapStructureInternal(func)); | |||||
} | |||||
return new Nest<TOut>(outs); | |||||
} | |||||
else if (NestType == NestType.Dictionary) | |||||
{ | |||||
Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>(); | |||||
foreach (var (key, value) in DictValue!) | |||||
{ | |||||
outs.Add(key, value.MapStructureInternal(func)); | |||||
} | |||||
return new Nest<TOut>(outs); | |||||
} | |||||
else | |||||
{ | |||||
return Nest<TOut>.Empty; | |||||
} | |||||
} | |||||
public IEnumerator<T> GetEnumerator() | |||||
{ | |||||
return Flatten().GetEnumerator(); | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
public override string ToString() | |||||
{ | |||||
StringBuilder sb = new StringBuilder(); | |||||
sb.Append("("); | |||||
WriteString(this, sb); | |||||
sb.Append(")"); | |||||
return sb.ToString(); | |||||
} | |||||
private static void WriteString(Nest<T> node, StringBuilder sb) | |||||
{ | |||||
if (!string.IsNullOrEmpty(node.Name)) | |||||
{ | |||||
sb.Append($"{node.Name}: "); | |||||
} | |||||
if (node.NestType == NestType.Node) | |||||
{ | |||||
sb.Append(node.Value!.ToString()); | |||||
} | |||||
else if (node.NestType == NestType.List) | |||||
{ | |||||
sb.Append("["); | |||||
for(int i = 0; i < node.ListValue!.Count; i++) | |||||
{ | |||||
WriteString(node.ListValue![i], sb); | |||||
if(i != node.ListValue!.Count - 1) | |||||
{ | |||||
sb.Append(", "); | |||||
} | |||||
} | |||||
sb.Append("]"); | |||||
} | |||||
else if (node.NestType == NestType.Dictionary) | |||||
{ | |||||
sb.Append("{"); | |||||
int count = node.DictValue!.Count; | |||||
int i = 0; | |||||
foreach (var (key, value) in node.DictValue!) | |||||
{ | |||||
sb.Append($"{key}: "); | |||||
WriteString(value, sb); | |||||
if (i != count - 1) | |||||
{ | |||||
sb.Append(", "); | |||||
} | |||||
i++; | |||||
} | |||||
sb.Append("}"); | |||||
} | |||||
else | |||||
{ | |||||
sb.Append("<empty>"); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,99 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||||
{ | |||||
public IDictionary<TKey, TValue> Value { get; set; } | |||||
public NestDictionary(IDictionary<TKey, TValue> dict) | |||||
{ | |||||
Value = dict; | |||||
} | |||||
public IEnumerable<TValue> Flatten() | |||||
{ | |||||
return Value.Select(x => x.Value); | |||||
} | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func) | |||||
{ | |||||
return new NestList<TOut>(Value.Select(x => func(x.Value))); | |||||
} | |||||
public Nest<TValue> AsNest() | |||||
{ | |||||
return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x))); | |||||
} | |||||
// Required IDictionary<TKey, TValue> members | |||||
public int Count => Value.Count; | |||||
public bool IsReadOnly => Value.IsReadOnly; | |||||
public ICollection<TKey> Keys => Value.Keys; | |||||
public ICollection<TValue> Values => Value.Values; | |||||
public void Add(TKey key, TValue value) | |||||
{ | |||||
Value.Add(key, value); | |||||
} | |||||
public void Add(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
Value.Add(item); | |||||
} | |||||
public void Clear() | |||||
{ | |||||
Value.Clear(); | |||||
} | |||||
public bool Contains(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
return Value.Contains(item); | |||||
} | |||||
public bool ContainsKey(TKey key) | |||||
{ | |||||
return Value.ContainsKey(key); | |||||
} | |||||
public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) | |||||
{ | |||||
Value.CopyTo(array, arrayIndex); | |||||
} | |||||
public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() | |||||
{ | |||||
return Value.GetEnumerator(); | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
public bool Remove(TKey key) | |||||
{ | |||||
return Value.Remove(key); | |||||
} | |||||
public bool Remove(KeyValuePair<TKey, TValue> item) | |||||
{ | |||||
return Value.Remove(item); | |||||
} | |||||
public bool TryGetValue(TKey key, out TValue value) | |||||
{ | |||||
return Value.TryGetValue(key, out value); | |||||
} | |||||
// Optional IDictionary<TKey, TValue> members | |||||
public TValue this[TKey key] | |||||
{ | |||||
get => Value[key]; | |||||
set => Value[key] = value; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,43 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// The implementation of a list that support nest structure, in which the depth is 1. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||||
{ | |||||
public List<T> Value { get; set; } | |||||
public NestList(IEnumerable<T> values) | |||||
{ | |||||
Value = new List<T>(values); | |||||
} | |||||
public IEnumerable<T> Flatten() | |||||
{ | |||||
return Value; | |||||
} | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
{ | |||||
return new NestList<TOut>(Value.Select(x => func(x))); | |||||
} | |||||
public Nest<T> AsNest() | |||||
{ | |||||
return new Nest<T>(Value.Select(x => new Nest<T>(x))); | |||||
} | |||||
// Enumerator implementation | |||||
public IEnumerator<T> GetEnumerator() | |||||
{ | |||||
return Value.GetEnumerator(); | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,32 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Common.Types | |||||
{ | |||||
/// <summary> | |||||
/// A nested structure with only one element. | |||||
/// </summary> | |||||
/// <typeparam name="T"></typeparam> | |||||
public class NestNode<T> : INestStructure<T> | |||||
{ | |||||
public T Value { get; set; } | |||||
public NestNode(T value) | |||||
{ | |||||
Value = value; | |||||
} | |||||
public IEnumerable<T> Flatten() | |||||
{ | |||||
yield return Value; | |||||
} | |||||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||||
{ | |||||
return new NestNode<TOut>(func(Value)); | |||||
} | |||||
public Nest<T> AsNest() | |||||
{ | |||||
return new Nest<T>(Value); | |||||
} | |||||
} | |||||
} |
@@ -3,6 +3,7 @@ using System; | |||||
using System.Collections; | using System.Collections; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Common.Types; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -13,16 +14,14 @@ namespace Tensorflow | |||||
/// and Tensor[] from Tensors implicitily. | /// and Tensor[] from Tensors implicitily. | ||||
/// It works for tuple and scalar as well. | /// It works for tuple and scalar as well. | ||||
/// </summary> | /// </summary> | ||||
public class Tensors : IEnumerable<Tensor>, IDisposable | |||||
public sealed class Tensors : Nest<Tensor>, IDisposable | |||||
{ | { | ||||
List<Tensor> items = new List<Tensor>(); | |||||
public TF_DataType dtype => items.First().dtype; | |||||
public Shape shape => items.First().shape; | |||||
public int rank => items.First().rank; | |||||
public Graph graph => items.First().graph; | |||||
public TF_DataType dtype => this.First().dtype; | |||||
public Shape shape => this.First().shape; | |||||
public int rank => this.First().rank; | |||||
public Graph graph => this.First().graph; | |||||
public bool IsList { get; set; } | public bool IsList { get; set; } | ||||
public int Length => items.Count(); | |||||
public int Length => this.Count(); | |||||
/// <summary> | /// <summary> | ||||
/// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. | /// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception. | ||||
/// </summary> | /// </summary> | ||||
@@ -35,7 +34,7 @@ namespace Tensorflow | |||||
throw new ValueError("Tensors with more than one tensor cannot be " + | throw new ValueError("Tensors with more than one tensor cannot be " + | ||||
"implicitly converted to Tensor."); | "implicitly converted to Tensor."); | ||||
} | } | ||||
return items.First(); | |||||
return this.First(); | |||||
} | } | ||||
} | } | ||||
@@ -52,150 +51,194 @@ namespace Tensorflow | |||||
throw new ValueError($"Tensors with {Length} tensor cannot be " + | throw new ValueError($"Tensors with {Length} tensor cannot be " + | ||||
"implicitly converted to Tensor."); | "implicitly converted to Tensor."); | ||||
} | } | ||||
return items.FirstOrDefault(); | |||||
return this.FirstOrDefault(); | |||||
} | } | ||||
} | } | ||||
public Tensor this[int index] | |||||
public Tensor this[params string[] slices] | |||||
=> this.First()[slices]; | |||||
public Tensors(Tensor tensor) : base(tensor) | |||||
{ | { | ||||
get => items[index]; | |||||
set => items[index] = value; | |||||
} | } | ||||
public Tensor this[params string[] slices] | |||||
=> items.First()[slices]; | |||||
public Tensors(params Tensor[] tensors) | |||||
private Tensors(Nest<Tensor> nested) : base(nested) | |||||
{ | { | ||||
items.AddRange(tensors); | |||||
} | } | ||||
public Tensors(IEnumerable<Tensor> tensors) | |||||
public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||||
{ | { | ||||
items.AddRange(tensors); | |||||
} | } | ||||
public Tensors(NDArray nd) | |||||
public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||||
{ | { | ||||
items.Add(ops.convert_to_tensor(nd)); | |||||
} | } | ||||
public IEnumerator<Tensor> GetEnumerator() | |||||
public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | |||||
{ | { | ||||
foreach (var tensor in items) | |||||
yield return tensor; | |||||
} | } | ||||
public bool IsSingle() | |||||
{ | |||||
return Length == 1; | |||||
} | |||||
public new Tensors MergeWith(Nest<Tensor>? other) | |||||
{ | |||||
return FromNest(base.MergeWith(other)); | |||||
} | |||||
[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||||
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||||
public void Add(Tensor tensor) | public void Add(Tensor tensor) | ||||
=> items.Add(tensor); | |||||
{ | |||||
if(NestType == NestType.Dictionary) | |||||
{ | |||||
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||||
} | |||||
else if(NestType == NestType.Node) | |||||
{ | |||||
NestType = NestType.List; | |||||
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) }; | |||||
Value = null; | |||||
} | |||||
else | |||||
{ | |||||
ListValue.Add(new Nest<Tensor>(tensor)); | |||||
} | |||||
} | |||||
[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " + | |||||
"some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")] | |||||
public void AddRange(IEnumerable<Tensor> tensors) | public void AddRange(IEnumerable<Tensor> tensors) | ||||
=> items.AddRange(tensors); | |||||
{ | |||||
if (NestType == NestType.Dictionary) | |||||
{ | |||||
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||||
} | |||||
else if (NestType == NestType.Node) | |||||
{ | |||||
NestType = NestType.List; | |||||
ListValue = new() { new Nest<Tensor>(Value) }; | |||||
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||||
Value = null; | |||||
} | |||||
else | |||||
{ | |||||
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||||
} | |||||
} | |||||
[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " + | |||||
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")] | |||||
public void Insert(int index, Tensor tensor) | public void Insert(int index, Tensor tensor) | ||||
=> items.Insert(index, tensor); | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
=> GetEnumerator(); | |||||
{ | |||||
if (NestType == NestType.List) | |||||
{ | |||||
ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||||
} | |||||
else if(NestType == NestType.Node) | |||||
{ | |||||
NestType = NestType.List; | |||||
ListValue = new() { new Nest<Tensor>(Value) }; | |||||
ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||||
Value = null; | |||||
} | |||||
else | |||||
{ | |||||
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors."); | |||||
} | |||||
} | |||||
public string[] StringData() | public string[] StringData() | ||||
{ | { | ||||
EnsureSingleTensor(this, "nnumpy"); | |||||
return this[0].StringData(); | |||||
return Single.StringData(); | |||||
} | } | ||||
public string StringData(int index) | public string StringData(int index) | ||||
{ | { | ||||
EnsureSingleTensor(this, "nnumpy"); | |||||
return this[0].StringData(index); | |||||
return Single.StringData(index); | |||||
} | } | ||||
public NDArray numpy() | public NDArray numpy() | ||||
{ | { | ||||
EnsureSingleTensor(this, "nnumpy"); | |||||
return this[0].numpy(); | |||||
return Single.numpy(); | |||||
} | } | ||||
[Obsolete] | |||||
public T[] ToArray<T>() where T: unmanaged | public T[] ToArray<T>() where T: unmanaged | ||||
{ | { | ||||
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||||
return this[0].ToArray<T>(); | |||||
return Single.ToArray<T>(); | |||||
} | } | ||||
#region Explicit Conversions | #region Explicit Conversions | ||||
public unsafe static explicit operator bool(Tensors tensor) | public unsafe static explicit operator bool(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||||
return (bool)tensor[0]; | |||||
return (bool)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator sbyte(Tensors tensor) | public unsafe static explicit operator sbyte(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||||
return (sbyte)tensor[0]; | |||||
return (sbyte)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator byte(Tensors tensor) | public unsafe static explicit operator byte(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
return (byte)tensor[0]; | |||||
return (byte)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator ushort(Tensors tensor) | public unsafe static explicit operator ushort(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||||
return (ushort)tensor[0]; | |||||
return (ushort)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator short(Tensors tensor) | public unsafe static explicit operator short(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to short"); | |||||
return (short)tensor[0]; | |||||
return (short)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator int(Tensors tensor) | public unsafe static explicit operator int(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to int"); | |||||
return (int)tensor[0]; | |||||
return (int)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator uint(Tensors tensor) | public unsafe static explicit operator uint(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||||
return (uint)tensor[0]; | |||||
return (uint)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator long(Tensors tensor) | public unsafe static explicit operator long(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to long"); | |||||
return (long)tensor[0]; | |||||
return (long)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator ulong(Tensors tensor) | public unsafe static explicit operator ulong(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||||
return (ulong)tensor[0]; | |||||
return (ulong)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator float(Tensors tensor) | public unsafe static explicit operator float(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
return (byte)tensor[0]; | |||||
return (byte)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator double(Tensors tensor) | public unsafe static explicit operator double(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to double"); | |||||
return (double)tensor[0]; | |||||
return (double)tensor.Single; | |||||
} | } | ||||
public unsafe static explicit operator string(Tensors tensor) | public unsafe static explicit operator string(Tensors tensor) | ||||
{ | { | ||||
EnsureSingleTensor(tensor, "explicit conversion to string"); | |||||
return (string)tensor[0]; | |||||
return (string)tensor.Single; | |||||
} | } | ||||
public static explicit operator object[](Tensors tensors) | public static explicit operator object[](Tensors tensors) | ||||
=> tensors.items.ToArray(); | |||||
=> tensors.Flatten().ToArray(); | |||||
#endregion | #endregion | ||||
#region Implicit Conversions | #region Implicit Conversions | ||||
@@ -219,52 +262,40 @@ namespace Tensorflow | |||||
=> tensors?.SingleOrNull; | => tensors?.SingleOrNull; | ||||
public static implicit operator Tensor[](Tensors tensors) | public static implicit operator Tensor[](Tensors tensors) | ||||
=> tensors.items.ToArray(); | |||||
=> tensors.Flatten().ToArray(); | |||||
#endregion | #endregion | ||||
public void Deconstruct(out Tensor a, out Tensors? b) | |||||
public static Tensors? FromNest(Nest<Tensor> nested) | |||||
{ | { | ||||
a = items[0]; | |||||
b = Length == 1? null : new Tensors(items.Skip(1)); | |||||
if(nested == Nest<Tensor>.Empty) | |||||
{ | |||||
return null; | |||||
} | |||||
return new Tensors(nested); | |||||
} | } | ||||
private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||||
public void Deconstruct(out Tensor a, out Tensors? b) | |||||
{ | { | ||||
if(tensors.Length == 0) | |||||
{ | |||||
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||||
} | |||||
else if(tensors.Length > 1) | |||||
{ | |||||
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||||
} | |||||
a = this.First(); | |||||
b = Length == 1? null : new Tensors(this.Skip(1)); | |||||
} | } | ||||
public override string ToString() | public override string ToString() | ||||
{ | { | ||||
if(items.Count == 1) | |||||
if(Length == 1) | |||||
{ | { | ||||
return items[0].ToString(); | |||||
return this.First().ToString(); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
StringBuilder sb = new StringBuilder(); | |||||
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n"); | |||||
for(int i = 0; i < items.Count; i++) | |||||
{ | |||||
var tensor = items[i]; | |||||
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n"); | |||||
} | |||||
sb.Append("]\n"); | |||||
return sb.ToString(); | |||||
return $"Totally {Length} tensors: {base.ToString()}"; | |||||
} | } | ||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
foreach (var item in items) | |||||
item.Dispose(); | |||||
foreach (var tensor in this) | |||||
tensor.Dispose(); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -36,6 +36,7 @@ namespace Tensorflow.Util | |||||
// (np.array([3, 4]), tf.constant([3, 4])))` | // (np.array([3, 4]), tf.constant([3, 4])))` | ||||
// | // | ||||
[Obsolete] | |||||
public static class nest | public static class nest | ||||
{ | { | ||||
@@ -170,39 +171,6 @@ namespace Tensorflow.Util | |||||
throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | ||||
} | } | ||||
public static bool is_nested(object obj) | |||||
{ | |||||
// Refer to https://www.tensorflow.org/api_docs/python/tf/nest | |||||
//if (obj is IList || obj is IDictionary || obj is ITuple) | |||||
// return true; | |||||
if (obj is IList || obj is IDictionary) | |||||
return true; | |||||
if (obj is NDArray || obj is Tensor || obj is string || obj.GetType().IsGenericType | |||||
|| obj is ISet<int> || obj is ISet<float> || obj is ISet<double>) | |||||
return false; | |||||
if (obj.GetType().IsNested) return true; | |||||
// Check if the object is an IEnumerable | |||||
if (obj is IEnumerable) | |||||
{ | |||||
// If it is, check if it is a nested structure | |||||
foreach (object item in (IEnumerable)obj) | |||||
{ | |||||
if (is_nested(item)) | |||||
{ | |||||
return true; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
else | |||||
{ | |||||
// If it is not, return false | |||||
return false; | |||||
} | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Yields the next value from the given iterable. | /// Yields the next value from the given iterable. | ||||
/// </summary> | /// </summary> | ||||