@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Runtime.InteropServices; | |||
using static Tensorflow.CppShapeInferenceResult.Types; | |||
namespace Tensorflow | |||
{ | |||
@@ -50,6 +51,19 @@ namespace Tensorflow | |||
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle); | |||
} | |||
public unsafe static byte[] ByteStringPiece(IntPtr handle) | |||
{ | |||
byte* str_data = (byte*)handle.ToPointer(); | |||
List<byte> bytes = new List<byte>(); | |||
byte current = 255; | |||
while (current != ((byte)'\0')) | |||
{ | |||
current = *(str_data++); | |||
bytes.Add(current); | |||
} | |||
return bytes.Take(bytes.Count - 1).ToArray(); | |||
} | |||
[UnmanagedFunctionPointer(CallingConvention.Winapi)] | |||
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); | |||
@@ -46,10 +46,10 @@ namespace Tensorflow | |||
Tensor loop_vars, | |||
int parallel_iterations = 10) | |||
{ | |||
Func<Tensor[], Tensor> cond1 = x | |||
Func<Tensors, Tensor> cond1 = x | |||
=> cond(x[0]); | |||
Func<Tensor[], Tensor[]> body1 = x | |||
Func<Tensors, Tensors> body1 = x | |||
=> new[] { body(x[0]) }; | |||
var results = control_flow_ops.while_loop(cond1, | |||
@@ -58,9 +58,9 @@ namespace Tensorflow | |||
return results[0]; | |||
} | |||
public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
Func<Tensor[], Tensor[]> body, | |||
Tensor[] loop_vars, | |||
public Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||
Func<Tensors, Tensors> body, | |||
Tensors loop_vars, | |||
int parallel_iterations = 10, | |||
string name = null) | |||
=> control_flow_ops.while_loop(cond, body, loop_vars, | |||
@@ -18,7 +18,12 @@ namespace Tensorflow.Common.Extensions | |||
return sequence.Take(sequence.Count() - count); | |||
} | |||
#endif | |||
public static Tensors ToTensors(this IEnumerable<Tensor> tensors) | |||
public static Tensors ToTensors(this Tensor[] tensors) | |||
{ | |||
return new Tensors(tensors); | |||
} | |||
public static Tensors ToTensors(this IList<Tensor> tensors) | |||
{ | |||
return new Tensors(tensors); | |||
} | |||
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
/// <summary> | |||
/// This is a temp solution, which should be removed after refactoring `Tensors` | |||
/// </summary> | |||
[Obsolete] | |||
public class FakeTensorByTensorArray: Tensor | |||
{ | |||
public TensorArray TensorArray { get; set; } | |||
public FakeTensorByTensorArray(TensorArray array) | |||
{ | |||
TensorArray = array; | |||
} | |||
} | |||
} |
@@ -5,136 +5,80 @@ using System.Text; | |||
namespace Tensorflow.Common.Types | |||
{ | |||
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?> | |||
public class GeneralizedTensorShape: Nest<Shape> | |||
{ | |||
public TensorShapeConfig[] Shapes { get; set; } | |||
/// <summary> | |||
/// create a single-dim generalized Tensor shape. | |||
/// </summary> | |||
/// <param name="dim"></param> | |||
public GeneralizedTensorShape(int dim, int size = 1) | |||
{ | |||
var elem = new TensorShapeConfig() { Items = new long?[] { dim } }; | |||
Shapes = Enumerable.Repeat(elem, size).ToArray(); | |||
//Shapes = new TensorShapeConfig[size]; | |||
//Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } }); | |||
//Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } }); | |||
////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; | |||
} | |||
////public TensorShapeConfig[] Shapes { get; set; } | |||
///// <summary> | |||
///// create a single-dim generalized Tensor shape. | |||
///// </summary> | |||
///// <param name="dim"></param> | |||
//public GeneralizedTensorShape(int dim, int size = 1) | |||
//{ | |||
// var elem = new TensorShapeConfig() { Items = new long?[] { dim } }; | |||
// Shapes = Enumerable.Repeat(elem, size).ToArray(); | |||
// //Shapes = new TensorShapeConfig[size]; | |||
// //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } }); | |||
// //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } }); | |||
// ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; | |||
//} | |||
public GeneralizedTensorShape(Shape shape) | |||
public GeneralizedTensorShape(Shape value, string? name = null) | |||
{ | |||
Shapes = new TensorShapeConfig[] { shape }; | |||
NodeValue = value; | |||
NestType = NestType.Node; | |||
} | |||
public GeneralizedTensorShape(TensorShapeConfig shape) | |||
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null) | |||
{ | |||
Shapes = new TensorShapeConfig[] { shape }; | |||
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList(); | |||
Name = name; | |||
NestType = NestType.List; | |||
} | |||
public GeneralizedTensorShape(TensorShapeConfig[] shapes) | |||
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null) | |||
{ | |||
Shapes = shapes; | |||
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>); | |||
Name = name; | |||
NestType = NestType.Dictionary; | |||
} | |||
public GeneralizedTensorShape(IEnumerable<Shape> shape) | |||
public GeneralizedTensorShape(Nest<Shape> other) | |||
{ | |||
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); | |||
NestType = other.NestType; | |||
NodeValue = other.NodeValue; | |||
DictValue = other.DictValue; | |||
ListValue = other.ListValue; | |||
Name = other.Name; | |||
} | |||
public Shape ToSingleShape() | |||
{ | |||
if (Shapes.Length != 1) | |||
var shapes = Flatten().ToList(); | |||
if (shapes.Count != 1) | |||
{ | |||
throw new ValueError("The generalized shape contains more than 1 dim."); | |||
} | |||
var shape_config = Shapes[0]; | |||
Debug.Assert(shape_config is not null); | |||
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray()); | |||
return shapes[0]; | |||
} | |||
public long ToNumber() | |||
{ | |||
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1) | |||
var shapes = Flatten().ToList(); | |||
if (shapes.Count != 1 || shapes[0].ndim != 1) | |||
{ | |||
throw new ValueError("The generalized shape contains more than 1 dim."); | |||
} | |||
var res = Shapes[0].Items[0]; | |||
return res is null ? -1 : res.Value; | |||
} | |||
public Shape[] ToShapeArray() | |||
{ | |||
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); | |||
} | |||
public IEnumerable<long?> Flatten() | |||
{ | |||
List<long?> result = new List<long?>(); | |||
foreach(var shapeConfig in Shapes) | |||
{ | |||
result.AddRange(shapeConfig.Items); | |||
} | |||
return result; | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func) | |||
{ | |||
List<Nest<TOut>> lists = new(); | |||
foreach(var shapeConfig in Shapes) | |||
{ | |||
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x))))); | |||
} | |||
return new Nest<TOut>(lists); | |||
} | |||
public Nest<long?> AsNest() | |||
{ | |||
Nest<long?> DealWithSingleShape(TensorShapeConfig config) | |||
{ | |||
if (config.Items.Length == 0) | |||
{ | |||
return Nest<long?>.Empty; | |||
} | |||
else if (config.Items.Length == 1) | |||
{ | |||
return new Nest<long?>(config.Items[0]); | |||
} | |||
else | |||
{ | |||
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x))); | |||
} | |||
} | |||
if(Shapes.Length == 0) | |||
{ | |||
return Nest<long?>.Empty; | |||
} | |||
else if(Shapes.Length == 1) | |||
{ | |||
return DealWithSingleShape(Shapes[0]); | |||
} | |||
else | |||
{ | |||
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s))); | |||
} | |||
return shapes[0].dims[0]; | |||
} | |||
public static implicit operator GeneralizedTensorShape(int dims) | |||
=> new GeneralizedTensorShape(dims); | |||
public IEnumerator<long?[]> GetEnumerator() | |||
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs() | |||
{ | |||
foreach (var shape in Shapes) | |||
{ | |||
yield return shape.Items; | |||
} | |||
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
public static implicit operator GeneralizedTensorShape(Shape shape) | |||
{ | |||
return GetEnumerator(); | |||
return new GeneralizedTensorShape(shape); | |||
} | |||
} | |||
} |
@@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types | |||
/// </summary> | |||
public interface INestStructure<T>: INestable<T> | |||
{ | |||
NestType NestType { get; } | |||
/// <summary> | |||
/// The item count of depth 1 of the nested structure. | |||
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. | |||
/// </summary> | |||
int ShallowNestedCount { get; } | |||
/// <summary> | |||
/// The total item count of depth 1 of the nested structure. | |||
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. | |||
/// </summary> | |||
int TotalNestedCount { get; } | |||
/// <summary> | |||
/// Flatten the Nestable object. Node that if the object contains only one value, | |||
/// it will be flattened to an enumerable with one element. |
@@ -13,7 +13,7 @@ namespace Tensorflow.Common.Types | |||
/// <param name="template"></param> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems) | |||
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems) | |||
{ | |||
return template.AsNest().PackSequence(flatItems); | |||
} | |||
@@ -28,27 +28,58 @@ namespace Tensorflow.Common.Types | |||
public static Nest<T> Empty => _empty; | |||
public NestType NestType { get; protected set; } | |||
public string? Name { get; set; } | |||
public T? Value { get; protected set; } | |||
public List<Nest<T>>? ListValue { get; protected set; } | |||
public Dictionary<string, Nest<T>>? DictValue { get; protected set; } | |||
public T? NodeValue { get; protected set; } | |||
public List<INestStructure<T>>? ListValue { get; protected set; } | |||
public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; } | |||
public int ShallowNestedCount | |||
{ | |||
get | |||
{ | |||
if (NestType == NestType.Empty) | |||
{ | |||
return 0; | |||
} | |||
else if (NestType == NestType.Node) | |||
{ | |||
return 1; | |||
} | |||
else if (NestType == NestType.List) | |||
{ | |||
return ListValue!.Count; | |||
} | |||
else // dict | |||
{ | |||
return DictValue!.Count; | |||
} | |||
} | |||
} | |||
public int TotalNestedCount | |||
{ | |||
get | |||
{ | |||
return Flatten().Count(); | |||
} | |||
} | |||
protected Nest() { } | |||
public Nest(T value, string? name = null) | |||
{ | |||
Value = value; | |||
NodeValue = value; | |||
Name = name; | |||
NestType = NestType.Node; | |||
} | |||
public Nest(IEnumerable<Nest<T>> values, string? name = null) | |||
public Nest(IEnumerable<INestStructure<T>> values, string? name = null) | |||
{ | |||
ListValue = values.ToList(); | |||
Name = name; | |||
NestType = NestType.List; | |||
} | |||
public Nest(Dictionary<string, Nest<T>> value, string? name = null) | |||
public Nest(Dictionary<string, INestStructure<T>> value, string? name = null) | |||
{ | |||
DictValue = value; | |||
Name = name; | |||
@@ -58,7 +89,7 @@ namespace Tensorflow.Common.Types | |||
public Nest(Nest<T> other) | |||
{ | |||
NestType = other.NestType; | |||
Value = other.Value; | |||
NodeValue = other.NodeValue; | |||
DictValue = other.DictValue; | |||
ListValue = other.ListValue; | |||
Name = other.Name; | |||
@@ -78,17 +109,17 @@ namespace Tensorflow.Common.Types | |||
/// </summary> | |||
/// <param name="flatItems"></param> | |||
/// <returns></returns> | |||
public virtual Nest<T> PackSequence(T[] flatItems) | |||
public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems) | |||
{ | |||
if(flatItems.Length == 0) | |||
{ | |||
return Nest<T>.Empty; | |||
return Nest<TOut>.Empty; | |||
} | |||
int index = 0; | |||
return PackSequenceInternal(this, flatItems, ref index); | |||
} | |||
private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index) | |||
private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index) | |||
{ | |||
if(template.NestType == NestType.Node) | |||
{ | |||
@@ -96,25 +127,25 @@ namespace Tensorflow.Common.Types | |||
{ | |||
throw new InvalidArgumentError("The template and flat items are not matched."); | |||
} | |||
return new Nest<T>(flatItems[index++]); | |||
return new Nest<TOut>(flatItems[index++]); | |||
} | |||
else if(template.NestType == NestType.List) | |||
{ | |||
List<Nest<T>> nestedObjects = new List<Nest<T>>(); | |||
List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>(); | |||
for (int i = 0; i < template.ListValue!.Count; i++) | |||
{ | |||
nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); | |||
nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); | |||
} | |||
return new Nest<T>(nestedObjects); | |||
return new Nest<TOut>(nestedObjects); | |||
} | |||
else if(template.NestType == NestType.Node) | |||
{ | |||
Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>(); | |||
Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>(); | |||
foreach(var (key, value) in template.DictValue!) | |||
{ | |||
dict[key] = PackSequenceInternal(value, flatItems, ref index); | |||
dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); | |||
} | |||
return new Nest<T>(dict); | |||
return new Nest<TOut>(dict); | |||
} | |||
// Consider Empty as invalid type. | |||
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); | |||
@@ -223,10 +254,10 @@ namespace Tensorflow.Common.Types | |||
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T> | |||
{ | |||
var nested = input.AsNest(); | |||
return ReduceInternal(nested); | |||
return ReduceInternal(nested).AsNest(); | |||
} | |||
private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T> | |||
{ | |||
if(node.NestType == NestType.Empty) | |||
{ | |||
@@ -234,15 +265,15 @@ namespace Tensorflow.Common.Types | |||
} | |||
else if(node.NestType == NestType.Node) | |||
{ | |||
return node.Value!.AsNest(); | |||
return node.NodeValue!.AsNest(); | |||
} | |||
else if(node.NestType == NestType.List) | |||
{ | |||
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x))); | |||
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); | |||
} | |||
else // Dictionary type | |||
{ | |||
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); | |||
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); | |||
} | |||
} | |||
@@ -252,7 +283,7 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if(index == 0) | |||
{ | |||
result = node.Value!; | |||
result = node.NodeValue!; | |||
return true; | |||
} | |||
result = default(T); | |||
@@ -264,7 +295,7 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if(index == 0) | |||
{ | |||
return FindInternal(item, index, out result); | |||
return FindInternal(item.AsNest(), index, out result); | |||
} | |||
index--; | |||
} | |||
@@ -277,7 +308,7 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if (index == 0) | |||
{ | |||
return FindInternal(item, index, out result); | |||
return FindInternal(item.AsNest(), index, out result); | |||
} | |||
index--; | |||
} | |||
@@ -297,7 +328,7 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if (index == 0) | |||
{ | |||
node.Value = newValue; | |||
node.NodeValue = newValue; | |||
return true; | |||
} | |||
return false; | |||
@@ -308,7 +339,7 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if (index == 0) | |||
{ | |||
return SetInternal(item, index, newValue); | |||
return SetInternal(item.AsNest(), index, newValue); | |||
} | |||
index--; | |||
} | |||
@@ -320,7 +351,7 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if (index == 0) | |||
{ | |||
return SetInternal(item, index, newValue); | |||
return SetInternal(item.AsNest(), index, newValue); | |||
} | |||
index--; | |||
} | |||
@@ -336,13 +367,13 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if (node.NestType == NestType.Node) | |||
{ | |||
yield return node.Value!; | |||
yield return node.NodeValue!; | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
foreach (var item in node.ListValue!) | |||
{ | |||
foreach(var val in FlattenInternal(item)) | |||
foreach(var val in FlattenInternal(item.AsNest())) | |||
{ | |||
yield return val; | |||
} | |||
@@ -352,7 +383,7 @@ namespace Tensorflow.Common.Types | |||
{ | |||
foreach (var item in node.DictValue!.Values) | |||
{ | |||
foreach (var val in FlattenInternal(item)) | |||
foreach (var val in FlattenInternal(item.AsNest())) | |||
{ | |||
yield return val; | |||
} | |||
@@ -364,23 +395,23 @@ namespace Tensorflow.Common.Types | |||
{ | |||
if (NestType == NestType.Node) | |||
{ | |||
return new Nest<TOut>(func(Value!)); | |||
return new Nest<TOut>(func(NodeValue!)); | |||
} | |||
else if (NestType == NestType.List) | |||
{ | |||
List<Nest<TOut>> outs = new List<Nest<TOut>>(); | |||
foreach (var item in ListValue!) | |||
{ | |||
outs.Add(item.MapStructureInternal(func)); | |||
outs.Add(item.AsNest().MapStructureInternal(func)); | |||
} | |||
return new Nest<TOut>(outs); | |||
} | |||
else if (NestType == NestType.Dictionary) | |||
{ | |||
Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>(); | |||
Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>(); | |||
foreach (var (key, value) in DictValue!) | |||
{ | |||
outs.Add(key, value.MapStructureInternal(func)); | |||
outs.Add(key, value.AsNest().MapStructureInternal(func)); | |||
} | |||
return new Nest<TOut>(outs); | |||
} | |||
@@ -417,14 +448,14 @@ namespace Tensorflow.Common.Types | |||
} | |||
if (node.NestType == NestType.Node) | |||
{ | |||
sb.Append(node.Value!.ToString()); | |||
sb.Append(node.NodeValue!.ToString()); | |||
} | |||
else if (node.NestType == NestType.List) | |||
{ | |||
sb.Append("["); | |||
for(int i = 0; i < node.ListValue!.Count; i++) | |||
{ | |||
WriteString(node.ListValue![i], sb); | |||
WriteString(node.ListValue![i].AsNest(), sb); | |||
if(i != node.ListValue!.Count - 1) | |||
{ | |||
sb.Append(", "); | |||
@@ -440,7 +471,7 @@ namespace Tensorflow.Common.Types | |||
foreach (var (key, value) in node.DictValue!) | |||
{ | |||
sb.Append($"{key}: "); | |||
WriteString(value, sb); | |||
WriteString(value.AsNest(), sb); | |||
if (i != count - 1) | |||
{ | |||
sb.Append(", "); | |||
@@ -454,5 +485,15 @@ namespace Tensorflow.Common.Types | |||
sb.Append("<empty>"); | |||
} | |||
} | |||
public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs) | |||
{ | |||
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 }); | |||
} | |||
public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs) | |||
{ | |||
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 }); | |||
} | |||
} | |||
} |
@@ -6,7 +6,11 @@ namespace Tensorflow.Common.Types | |||
{ | |||
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull | |||
{ | |||
public NestType NestType => NestType.Dictionary; | |||
public IDictionary<TKey, TValue> Value { get; set; } | |||
public int ShallowNestedCount => Values.Count; | |||
public int TotalNestedCount => Values.Count; | |||
public NestDictionary(IDictionary<TKey, TValue> dict) | |||
{ | |||
Value = dict; | |||
@@ -10,29 +10,34 @@ namespace Tensorflow.Common.Types | |||
/// <typeparam name="T"></typeparam> | |||
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T> | |||
{ | |||
public List<T> Value { get; set; } | |||
public NestType NestType => NestType.List; | |||
public List<T> Values { get; set; } | |||
public int ShallowNestedCount => Values.Count; | |||
public int TotalNestedCount => Values.Count; | |||
public NestList(IEnumerable<T> values) | |||
{ | |||
Value = new List<T>(values); | |||
Values = new List<T>(values); | |||
} | |||
public IEnumerable<T> Flatten() | |||
{ | |||
return Value; | |||
return Values; | |||
} | |||
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func) | |||
{ | |||
return new NestList<TOut>(Value.Select(x => func(x))); | |||
return new NestList<TOut>(Values.Select(x => func(x))); | |||
} | |||
public Nest<T> AsNest() | |||
{ | |||
return new Nest<T>(Value.Select(x => new Nest<T>(x))); | |||
return new Nest<T>(Values.Select(x => new Nest<T>(x))); | |||
} | |||
// Enumerator implementation | |||
public IEnumerator<T> GetEnumerator() | |||
{ | |||
return Value.GetEnumerator(); | |||
return Values.GetEnumerator(); | |||
} | |||
IEnumerator IEnumerable.GetEnumerator() | |||
@@ -10,7 +10,11 @@ namespace Tensorflow.Common.Types | |||
/// <typeparam name="T"></typeparam> | |||
public class NestNode<T> : INestStructure<T> | |||
{ | |||
public NestType NestType => NestType.Node; | |||
public T Value { get; set; } | |||
public int ShallowNestedCount => 1; | |||
public int TotalNestedCount => 1; | |||
public NestNode(T value) | |||
{ | |||
Value = value; | |||
@@ -161,8 +161,8 @@ namespace Tensorflow | |||
break; | |||
} | |||
yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ? | |||
null : new Tensors(results.Skip(FirstInputTensorCount))); | |||
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? | |||
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); | |||
} | |||
} | |||
@@ -359,6 +359,8 @@ namespace Tensorflow.Eager | |||
case TF_AttrType.TF_ATTR_FUNC: | |||
if (value is ConcreteFunction func) | |||
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||
else if(value is string str) | |||
c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); | |||
else | |||
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | |||
break; | |||
@@ -1,4 +1,5 @@ | |||
using System.Linq; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow.Framework.Models | |||
{ | |||
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models | |||
shapes.Insert(0, dim); | |||
return new TensorSpec(shapes.ToArray(), _dtype); | |||
} | |||
public static TensorSpec FromTensor(Tensor tensor, string? name = null) | |||
{ | |||
if(tensor is EagerTensor) | |||
{ | |||
return new TensorSpec(tensor.shape, tensor.dtype, name); | |||
} | |||
else | |||
{ | |||
return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,89 @@ | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow.Framework | |||
{ | |||
internal static class auto_control_deps_utils | |||
{ | |||
public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; | |||
public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph) | |||
{ | |||
List<int> result = new List<int>(); | |||
// A cache to store the read only resource inputs of an Op. | |||
// Operation -> ObjectIdentitySet of resource handles. | |||
Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs = | |||
new Dictionary<Operation, HashSet<Tensor>>(); | |||
for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) | |||
{ | |||
Tensor t = func_graph.Inputs[inputIndex]; | |||
if (t.dtype != dtypes.resource) | |||
continue; | |||
bool readOnly = true; | |||
foreach (var op in t.consumers()) | |||
{ | |||
if (opReadOnlyResourceInputs.ContainsKey(op)) | |||
{ | |||
if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
{ | |||
readOnly = false; | |||
break; | |||
} | |||
} | |||
else | |||
{ | |||
List<int> indices = _get_read_only_resource_input_indices_op(op); | |||
opReadOnlyResourceInputs[op] = new HashSet<Tensor>( | |||
indices.Select(i => op.inputs[i])); | |||
if (!opReadOnlyResourceInputs[op].Contains(t)) | |||
{ | |||
readOnly = false; | |||
break; | |||
} | |||
} | |||
} | |||
if (readOnly) | |||
result.Add(inputIndex); | |||
} | |||
return result; | |||
} | |||
private static List<int> _get_read_only_resource_input_indices_op(Operation op) | |||
{ | |||
// ignore the RESOURCE_READ_OPS | |||
int[] read_only_input_indices; | |||
try | |||
{ | |||
read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR); | |||
} | |||
catch (InvalidArgumentError) | |||
{ | |||
return new List<int>(); | |||
} | |||
int read_only_index = 0; | |||
List<int> result = new(); | |||
for (int i = 0; i < op.inputs.Length; i++) | |||
{ | |||
if (read_only_index >= read_only_input_indices.Length) | |||
{ | |||
break; | |||
} | |||
if (op.inputs[i].dtype != dtypes.resource) | |||
{ | |||
continue; | |||
} | |||
if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) | |||
{ | |||
result.Add(i); | |||
read_only_index++; | |||
} | |||
} | |||
return result; | |||
} | |||
} | |||
} |
@@ -42,10 +42,10 @@ namespace Tensorflow.Framework | |||
func_graph.as_default(); | |||
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); | |||
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); | |||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); | |||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); | |||
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); | |||
// TODO(Rinne): func_graph.ControlOutputs | |||
_set_handle_data(func_graph, fdef); | |||
@@ -8,6 +8,7 @@ using Tensorflow.Gradients; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Train; | |||
using Tensorflow.Util; | |||
using Tensorflow.Common.Extensions; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Functions | |||
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions | |||
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs; | |||
public IEnumerable<IVariableV1> Variables => func_graph.Variables; | |||
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables; | |||
internal NameAttrList AsNameAttrList | |||
{ | |||
get | |||
{ | |||
NameAttrList ret = new() { Name = this.Name }; | |||
foreach (var (name, value) in _attrs) | |||
{ | |||
ret.Attr[name] = value; | |||
} | |||
return ret; | |||
} | |||
} | |||
public ConcreteFunction(string name) | |||
{ | |||
@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable | |||
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable); | |||
public Dictionary<string, AttrValue> Attrs { get; set; } | |||
Dictionary<long, (Tensor, Tensor)> _captures | |||
internal Dictionary<long, (Tensor, Tensor)> _captures | |||
= new Dictionary<long, (Tensor, Tensor)>(); | |||
public Tensor[] external_captures | |||
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable | |||
var flat_func_args = nest.flatten(func_args as object); | |||
var flat_func_kwargs = nest.flatten(func_kwargs as object); | |||
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) | |||
.Where(x => x is Tensor).Select(x => (Tensor)x)); | |||
.Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); | |||
//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); | |||
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); | |||
@@ -129,7 +129,7 @@ namespace Tensorflow | |||
} | |||
} | |||
protected Graph outer_graph; | |||
internal Graph outer_graph; | |||
public Graph OuterGraph => outer_graph; | |||
public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||
public SafeGraphHandle c_graph => _handle; | |||
@@ -7,13 +7,19 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
public interface IRnnCell: ILayer | |||
{ | |||
GeneralizedTensorShape StateSize { get; } | |||
GeneralizedTensorShape OutputSize { get; } | |||
bool IsTFRnnCell { get; } | |||
/// <summary> | |||
/// If the derived class tends to not implement it, please return null. | |||
/// </summary> | |||
GeneralizedTensorShape? StateSize { get; } | |||
/// <summary> | |||
/// If the derived class tends to not implement it, please return null. | |||
/// </summary> | |||
GeneralizedTensorShape? OutputSize { get; } | |||
/// <summary> | |||
/// Whether the optional RNN args are supported when appying the layer. | |||
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. | |||
/// </summary> | |||
bool SupportOptionalArgs { get; } | |||
Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); | |||
} | |||
} |
@@ -181,6 +181,10 @@ namespace Tensorflow | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public GeneralizedTensorShape StateSize => throw new NotImplementedException(); | |||
public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | |||
public bool IsTFRnnCell => throw new NotImplementedException(); | |||
@@ -15,9 +15,11 @@ | |||
******************************************************************************/ | |||
using Google.Protobuf; | |||
using Google.Protobuf.Collections; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Functions; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.OpDef.Types; | |||
@@ -420,6 +422,12 @@ namespace Tensorflow | |||
case "list(shape)": | |||
attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); | |||
break; | |||
case "func": | |||
attr_value.Func = _MakeFunc(value, attr_def.Name); | |||
break; | |||
case "list(func)": | |||
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); | |||
break; | |||
default: | |||
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); | |||
} | |||
@@ -427,6 +435,47 @@ namespace Tensorflow | |||
return attr_value; | |||
} | |||
private NameAttrList _MakeFunc(object func, string arg_name) | |||
{ | |||
if(func is NameAttrList attrList) | |||
{ | |||
return attrList; | |||
} | |||
NameAttrList fn_attr; | |||
if(func is string funcStr) | |||
{ | |||
fn_attr = new NameAttrList() { Name = funcStr }; | |||
} | |||
else if(func is ConcreteFunction concrete) | |||
{ | |||
concrete.AddTograph(ops.get_default_graph()); | |||
fn_attr = concrete.AsNameAttrList; | |||
} | |||
else if(func is EagerDefinedFunction eager) | |||
{ | |||
eager.AddToGraph(ops.get_default_graph()); | |||
fn_attr = new NameAttrList() { Name = eager.Name }; | |||
} | |||
else | |||
{ | |||
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); | |||
} | |||
return fn_attr; | |||
} | |||
private List<NameAttrList> _MakeFuncList(object funcList, string arg_name) | |||
{ | |||
List<NameAttrList> res = new List<NameAttrList>(); | |||
if(funcList is IEnumerable enumerable) | |||
{ | |||
foreach(var func in enumerable) | |||
{ | |||
res.Add(_MakeFunc(func, arg_name)); | |||
} | |||
} | |||
return res; | |||
} | |||
private bool _IsListParameter(ArgDef arg) | |||
{ | |||
if (!String.IsNullOrEmpty(arg.NumberAttr)) | |||
@@ -34,7 +34,7 @@ namespace Tensorflow | |||
return num; | |||
} | |||
protected Tensor[] _outputs; | |||
internal Tensor[] _outputs; | |||
public virtual Tensor[] outputs => _outputs; | |||
public Tensor output => _outputs.FirstOrDefault(); | |||
@@ -46,9 +46,9 @@ namespace Tensorflow | |||
/// </summary> | |||
public partial class Operation : ITensorOrOperation | |||
{ | |||
private readonly IntPtr _handle; // _c_op in python | |||
protected IntPtr _handle; // _c_op in python | |||
private readonly Graph _graph; | |||
protected Graph _graph; | |||
internal Func<Operation, object[], Tensor[]> _gradient_function; | |||
@@ -69,6 +69,7 @@ namespace Tensorflow | |||
//private OperationDescription _op_desc; | |||
public NodeDef node_def => GetNodeDef(); | |||
protected Operation() { } | |||
public Operation(IntPtr handle, Graph g = null) | |||
{ | |||
@@ -17,6 +17,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
@@ -38,10 +39,6 @@ namespace Tensorflow.Operations | |||
bool _infer_shape; | |||
public override bool infer_shape => _infer_shape; | |||
public bool _dynamic_size; | |||
public Shape _element_shape; | |||
public List<Tensor> _colocate_with; | |||
Tensor _handle; | |||
public override Tensor handle => _handle; | |||
@@ -56,6 +53,7 @@ namespace Tensorflow.Operations | |||
bool infer_shape = true, Shape? element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
_size = size; | |||
_flow = constant_op.constant(0); | |||
_infer_shape = infer_shape; | |||
_element_shape = element_shape ?? Shape.Null; | |||
@@ -16,7 +16,9 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Linq; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
@@ -33,18 +35,18 @@ namespace Tensorflow.Operations | |||
/// first tensor written to it. | |||
/// </summary> | |||
bool _colocate_with_first_write_call; | |||
public bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
bool _infer_shape; | |||
public bool infer_shape => _infer_shape; | |||
public bool _dynamic_size; | |||
public override bool infer_shape => _infer_shape; | |||
public List<Shape> _element_shape; | |||
public List<Tensor> _colocate_with; | |||
internal Tensor _handle; | |||
public Tensor handle => _handle; | |||
public override Tensor handle => _handle; | |||
internal Tensor _flow; | |||
public override Tensor flow => _flow; | |||
public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
@@ -55,6 +57,7 @@ namespace Tensorflow.Operations | |||
dynamic_size = dynamic_size ?? false; | |||
_dynamic_size = dynamic_size.Value; | |||
_dtype = dtype; | |||
_size = size; | |||
_colocate_with_first_write_call = colocate_with_first_write_call; | |||
if (colocate_with_first_write_call) | |||
@@ -235,4 +238,172 @@ namespace Tensorflow.Operations | |||
return value; | |||
} | |||
} | |||
public class _GraphTensorArrayV2 : TensorArray | |||
{ | |||
internal TF_DataType _dtype; | |||
public override TF_DataType dtype => _dtype; | |||
/// <summary> | |||
/// Used to keep track of what tensors the TensorArray should be | |||
/// colocated with. We choose to colocate the TensorArray with the | |||
/// first tensor written to it. | |||
/// </summary> | |||
bool _colocate_with_first_write_call; | |||
public override bool colocate_with_first_write_call => _colocate_with_first_write_call; | |||
bool _infer_shape; | |||
public override bool infer_shape => _infer_shape; | |||
public Shape _element_shape; | |||
public List<Tensor> _colocate_with; | |||
internal Tensor _handle; | |||
public override Tensor handle => _handle; | |||
internal Tensor _flow; | |||
public override Tensor flow => _flow; | |||
public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null, | |||
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, Shape? element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
Debug.Assert(handle is null); | |||
dynamic_size = dynamic_size ?? false; | |||
_dynamic_size = dynamic_size.Value; | |||
_size = size; | |||
if(flow is not null && flow.dtype != dtypes.variant) | |||
{ | |||
throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead"); | |||
} | |||
if(flow is null && size is null) | |||
{ | |||
throw new ValueError("Argument `size` must be provided if argument `flow` is not provided."); | |||
} | |||
if(flow is not null && size is not null) | |||
{ | |||
throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time."); | |||
} | |||
if(flow is not null && element_shape is not null) | |||
{ | |||
throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time."); | |||
} | |||
_dtype = dtype; | |||
_element_shape = element_shape; | |||
_infer_shape = infer_shape; | |||
tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope => | |||
{ | |||
if (flow is null) | |||
{ | |||
_flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name); | |||
} | |||
else | |||
{ | |||
_flow = flow; | |||
} | |||
}); | |||
_colocate_with_first_write_call = false; | |||
_colocate_with = null; | |||
} | |||
public override TensorArray unstack(Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate | |||
{ | |||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
Debug.Assert(value.dtype == _dtype); | |||
var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray()); | |||
return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||
}); | |||
} | |||
public TensorArray scatter(Tensor indices, Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate | |||
{ | |||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
Debug.Assert(value.dtype == _dtype); | |||
var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow); | |||
return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||
}); | |||
} | |||
public override Tensor read<T>(T index, string name = null) | |||
{ | |||
if(index is Tensor tensor) | |||
{ | |||
return read(tensor, name); | |||
} | |||
else | |||
{ | |||
throw new TypeError("Please use non-generic method instead."); | |||
} | |||
} | |||
public Tensor read(Tensor index, string name = null) | |||
{ | |||
return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope => | |||
{ | |||
return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name); | |||
}); | |||
} | |||
public override TensorArray write(Tensor index, Tensor value, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate | |||
{ | |||
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
Debug.Assert(value.dtype == _dtype); | |||
var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name); | |||
return tensor_array_ops.build_ta_with_new_flow(this, flow_out); | |||
}); | |||
} | |||
public override TensorArray write<T>(int index, T value, string name = null) | |||
{ | |||
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); | |||
var index_tensor = ops.convert_to_tensor(index, name: "index"); | |||
return write(index_tensor, value_tensor); | |||
} | |||
private Tensor size(string name = null) | |||
{ | |||
if(!_dynamic_size && _size is not null) | |||
{ | |||
return ops.convert_to_tensor(_size, dtypes.int32); | |||
} | |||
else | |||
{ | |||
return gen_list_ops.tensor_list_length(_flow, name); | |||
} | |||
} | |||
public override Tensor stack(string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate | |||
{ | |||
int ta_size; | |||
if(!_dynamic_size && (_size is not null)) | |||
{ | |||
ta_size = (int)tensor_util.constant_value(_size); | |||
} | |||
else | |||
{ | |||
ta_size = -1; | |||
} | |||
var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape); | |||
return value; | |||
}); | |||
} | |||
public override Tensor gather(Tensor indices, string name = null) | |||
{ | |||
return list_ops.tensor_list_gather(_flow, indices, _dtype, _element_shape, name); | |||
} | |||
} | |||
} |
@@ -119,6 +119,27 @@ namespace Tensorflow | |||
} | |||
} | |||
public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||
{ | |||
dtype = dtype.as_base_dtype(); | |||
Tensor shapeTensor; | |||
if(shape.Length > 1) | |||
{ | |||
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); | |||
if(shapeTensor.ndim > 1) | |||
{ | |||
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); | |||
} | |||
} | |||
else | |||
{ | |||
shapeTensor = shape[0]; | |||
} | |||
var output = fill(shapeTensor, array_ops.constant(0, dtype), name); | |||
Debug.Assert(output.dtype.as_base_dtype() == dtype); | |||
return output; | |||
} | |||
public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) | |||
{ | |||
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate | |||
@@ -307,6 +328,9 @@ namespace Tensorflow | |||
public static Tensor fill<T>(Shape dims, T value, string name = null) | |||
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | |||
public static Tensor fill<T>(Tensor dims, T value, string name = null) | |||
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | |||
/// <summary> | |||
/// Returns the rank of a tensor. | |||
/// </summary> | |||
@@ -675,16 +675,17 @@ namespace Tensorflow | |||
} | |||
} | |||
public static Tensor[] while_loop(Func<Tensor[], Tensor> cond, | |||
Func<Tensor[], Tensor[]> body, | |||
Tensor[] loop_vars, | |||
public static Tensors while_loop(Func<Tensors, Tensor> cond, | |||
Func<Tensors, Tensors> body, | |||
Tensors loop_vars, | |||
int parallel_iterations = 10, | |||
string name = null) | |||
{ | |||
var executing_eagerly = tf.Context.executing_eagerly(); | |||
if (!executing_eagerly) | |||
{ | |||
throw new NotImplementedException(""); | |||
return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations, | |||
name: name); | |||
} | |||
return tf_with(ops.name_scope("name", "while"), delegate | |||
@@ -16,12 +16,20 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Functions; | |||
using Tensorflow.Graphs; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class control_flow_util | |||
{ | |||
public static readonly bool ENABLE_CONTROL_FLOW_V2 = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0" || | |||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") || | |||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") || | |||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") || | |||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0"); | |||
/// <summary> | |||
/// Return true if `op` is an Exit. | |||
/// </summary> | |||
@@ -196,5 +204,74 @@ namespace Tensorflow | |||
} | |||
return null; | |||
} | |||
public static bool EnableControlFlowV2(Graph graph) | |||
{ | |||
return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0); | |||
} | |||
public static string create_new_tf_function(FuncGraph func_graph) | |||
{ | |||
var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary<string, AttrValue>()); | |||
func.AddToGraph(func_graph); | |||
return func_graph.Name; | |||
} | |||
public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs) | |||
{ | |||
if(inputs.Length == 0) | |||
{ | |||
return (null, new Tensor[0]); | |||
} | |||
else | |||
{ | |||
return (inputs[0], inputs); | |||
} | |||
} | |||
public static Tensor[] run_as_function_for_tape_gradients(Func<Tensor[], Tensor[]> make_op, Tensor[] inputs) | |||
{ | |||
if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER | |||
&& !(ops.get_default_graph().building_function)) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
else | |||
{ | |||
return make_op(inputs); | |||
} | |||
} | |||
public static string unique_fn_name(string scope, string name) | |||
{ | |||
return $"{scope}{name}_{ops.uid()}".Replace("/", "_"); | |||
} | |||
public static bool output_all_intermediates() | |||
{ | |||
if (in_defun()) | |||
{ | |||
return false; | |||
} | |||
if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR") | |||
{ | |||
return false; | |||
} | |||
// TODO(Rinne): check this after refactoring keras building. | |||
return false; | |||
} | |||
public static bool in_defun() | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
return false; | |||
} | |||
var graph = ops.get_default_graph(); | |||
// TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph | |||
return graph is FuncGraph; | |||
} | |||
} | |||
} |
@@ -0,0 +1,111 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
namespace Tensorflow.Operations | |||
{ | |||
internal class list_ops | |||
{ | |||
private static void _set_handle_data(Tensor list_handle, Shape element_shape, TF_DataType element_dtype) | |||
{ | |||
if(list_handle is EagerTensor eagerTensor) | |||
{ | |||
var handle_data = new CppShapeInferenceResult.Types.HandleData(); | |||
handle_data.IsSet = true; | |||
handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType() | |||
{ | |||
Shape = element_shape.as_proto(), | |||
Dtype = element_dtype.as_datatype_enum(), | |||
Type = new FullTypeDef() { TypeId = FullTypeId.TftArray } | |||
}); | |||
list_handle.HandleData = handle_data; | |||
} | |||
} | |||
private static Tensor _build_element_shape(Shape? shape) | |||
{ | |||
if(shape is null || shape.IsNull) | |||
{ | |||
return ops.convert_to_tensor(-1); | |||
} | |||
else | |||
{ | |||
return ops.convert_to_tensor(shape); | |||
} | |||
} | |||
public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null) | |||
{ | |||
var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name); | |||
_set_handle_data(result, shape, element_dtype); | |||
return result; | |||
} | |||
public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null) | |||
{ | |||
var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name); | |||
_set_handle_data(result, tensor.shape, tensor.dtype); | |||
return result; | |||
} | |||
public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype, | |||
Shape? element_shape = null, string? name = null) | |||
{ | |||
return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape), | |||
element_dtype, name); | |||
} | |||
public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item, | |||
bool resize_if_index_out_of_bounds = false, string? name = null) | |||
{ | |||
if (resize_if_index_out_of_bounds) | |||
{ | |||
var input_list_size = gen_list_ops.tensor_list_length(input_handle); | |||
input_handle = control_flow_ops.cond(index >= input_list_size, | |||
() => gen_list_ops.tensor_list_resize(input_handle, index + 1), | |||
() => input_handle); | |||
} | |||
var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name); | |||
handle_data_util.copy_handle_data(input_handle, output_handle); | |||
return output_handle; | |||
} | |||
public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1, | |||
Shape? element_shape = null, string? name = null) | |||
{ | |||
return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name); | |||
} | |||
public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype, | |||
Shape? element_shape = null, string? name = null) | |||
{ | |||
return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name); | |||
} | |||
public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null, | |||
string? name = null) | |||
{ | |||
if(input_handle is not null) | |||
{ | |||
var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name); | |||
handle_data_util.copy_handle_data(input_handle, output_handle); | |||
return output_handle; | |||
} | |||
else | |||
{ | |||
var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape), | |||
constant_op.constant(-1), name); | |||
_set_handle_data(output_handle, element_shape, tensor.dtype); | |||
return output_handle; | |||
} | |||
} | |||
public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1, | |||
string? name = null) | |||
{ | |||
return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype, | |||
max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name); | |||
} | |||
} | |||
} |
@@ -13,11 +13,23 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow) | |||
{ | |||
var new_ta = tf.TensorArray( | |||
dtype: old_ta.dtype, | |||
infer_shape: old_ta.infer_shape, | |||
if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph())) | |||
{ | |||
throw new NotImplementedException("Attempting to build a graph-mode TF2-style " | |||
+ "TensorArray from either an eager-mode " | |||
+ "TensorArray or a TF1-style TensorArray. " | |||
+ "This is not currently supported. You may be " | |||
+ "attempting to capture a TensorArray " | |||
+ "inside a tf.function or tf.data map function. " | |||
+ "Instead, construct a new TensorArray inside " | |||
+ "the function."); | |||
} | |||
var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape, | |||
colocate_with_first_write_call: old_ta.colocate_with_first_write_call); | |||
new_ta._dynamic_size = old_ta._dynamic_size; | |||
new_ta._size = old_ta._size; | |||
new_ta._colocate_with = old_ta._colocate_with; | |||
new_ta._element_shape = old_ta._element_shape; | |||
return new_ta; | |||
} | |||
@@ -0,0 +1,401 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics; | |||
using System.Text; | |||
using Tensorflow.Common.Extensions; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Framework.Models; | |||
using Tensorflow.Graphs; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Operations | |||
{ | |||
class _OperationWithOutputs : Operation | |||
{ | |||
public _OperationWithOutputs(IntPtr handle, Graph g = null) | |||
{ | |||
_handle = handle; | |||
_graph = g; | |||
_outputs = null; | |||
g._add_op(this); | |||
} | |||
} | |||
internal class while_v2 | |||
{ | |||
public static Tensor[] while_loop(Func<Tensors, Tensor> cond, | |||
Func<Tensors, Tensors> body, | |||
Tensors loop_vars, | |||
int maximum_iterations = -1, | |||
int parallel_iterations = 10, | |||
string name = null, | |||
bool back_prop = true, | |||
bool return_same_structure = true) | |||
{ | |||
var orig_loop_vars = loop_vars; | |||
var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray(); | |||
int len_orig_loop_vars = orig_loop_vars.Length; | |||
loop_vars = _tensor_array_to_flow(loop_vars); | |||
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors(); | |||
var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars)); | |||
var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray(); | |||
if(string.IsNullOrEmpty(name)) | |||
{ | |||
name = "while"; | |||
} | |||
return tf_with<ITensorFlowObject, Tensor[]>(ops.name_scope(name), nameScopeWhile => | |||
{ | |||
string scope = (nameScopeWhile as ops.NameScope).scope_name; | |||
string cond_name = control_flow_util.unique_fn_name(scope, "cond"); | |||
string body_name = control_flow_util.unique_fn_name(scope, "body"); | |||
var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations); | |||
var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype, | |||
name: "loop_counter"); | |||
loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray(); | |||
var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)} | |||
.Concat(loop_vars_signature.Flatten()).ToArray(); | |||
// TODO(Rinne): possible wrong implemenation here. | |||
var add_control_dependencies = false; | |||
object[] wrapped_cond(object[] inputs) | |||
{ | |||
Tensor loop_counter = (Tensor)inputs[0]; | |||
Tensor maximum_iterations_arg = (Tensor)inputs[1]; | |||
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); | |||
var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); | |||
if(pred.shape.IsNull || pred.shape.ndim > 0) | |||
{ | |||
pred = array_ops.squeeze(pred); | |||
} | |||
if(maximum_iterations == -1) | |||
{ | |||
return new object[] { pred }; | |||
} | |||
else | |||
{ | |||
return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) }; | |||
} | |||
} | |||
var cond_graph = FuncGraph.func_graph_from_func("cond", wrapped_cond, null, | |||
null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies); | |||
bool stateful_parallelism = false; | |||
object[] wrapped_body(object[] inputs) | |||
{ | |||
Tensor loop_counter = (Tensor)inputs[0]; | |||
Tensor maximum_iterations_arg = (Tensor)inputs[1]; | |||
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray(); | |||
_copy_handle_data(loop_vars.Flatten().Skip(2), args); | |||
foreach(var t in cond_graph.external_captures) | |||
{ | |||
var graph = (FuncGraph)(ops.get_default_graph()); | |||
graph.capture(t); | |||
} | |||
var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args)); | |||
outputs = _tensor_array_to_flow(outputs); | |||
return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray(); | |||
} | |||
var body_graph = FuncGraph.func_graph_from_func("body", wrapped_body, null, null, func_graph_signature, | |||
add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism); | |||
// TODO(Rinne): possible wrong implementation here. | |||
NestList<Tensors> loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() }); | |||
body_graph.Outputs.AddRange(body_graph.internal_captures); | |||
cond_graph.as_default(); | |||
int num_cond_captures = cond_graph.external_captures.Length; | |||
Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray())); | |||
_duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray()); | |||
cond_graph.Exit(); | |||
int first_loop_var_index = 2; | |||
int num_flattened_oututs = orig_loop_vars.Length; | |||
int num_original_outputs = body_graph.Outputs.Length; | |||
if (back_prop && control_flow_util.output_all_intermediates()) | |||
{ | |||
var intermediate_tensors = _get_intermediates(body_graph); | |||
foreach(var intermediate_tensor in intermediate_tensors) | |||
{ | |||
var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations); | |||
loop_vars_list.Values.Add(tensor_list); | |||
cond_graph.as_default(); | |||
cond_graph.capture(tensor_list); | |||
cond_graph.Exit(); | |||
body_graph.as_default(); | |||
var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor); | |||
body_graph.Outputs.Add(appended_tensor_list); | |||
body_graph.Exit(); | |||
} | |||
} | |||
List<Tensor> flattened_loop_vars = new(); | |||
foreach(var item in loop_vars_list.Values) | |||
{ | |||
flattened_loop_vars.AddRange(item.Flatten()); | |||
} | |||
// skip the check | |||
// TODO(Rinne): deal with control dependencies | |||
var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray(); | |||
var span = new Span<Shape>(output_shapes).Slice(first_loop_var_index, num_flattened_oututs); | |||
for(int i = 0; i < span.Length; i++) | |||
{ | |||
span[i] = flat_shape_invariants[i]; | |||
} | |||
Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations, | |||
(nameScopeWhile as ops.NameScope).scope_name, num_original_outputs, stateful_parallelism); | |||
if (!ops.get_default_graph().building_function) | |||
{ | |||
outputs = outputs.Select(t => array_ops.identity(t)).ToArray(); | |||
} | |||
var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_oututs).ToArray(); | |||
if (!back_prop) | |||
{ | |||
output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray(); | |||
} | |||
outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars); | |||
return outputs; | |||
}); | |||
} | |||
private static Tensors _tensor_array_to_flow(Tensors loop_vars) | |||
{ | |||
if(loop_vars.NestType == NestType.Node) | |||
{ | |||
if(loop_vars.NodeValue is FakeTensorByTensorArray fake) | |||
{ | |||
return new Tensors(fake.TensorArray.flow); | |||
} | |||
else | |||
{ | |||
return new Tensors(loop_vars.NodeValue!); | |||
} | |||
} | |||
else if(loop_vars.NestType == NestType.List) | |||
{ | |||
List<INestStructure<Tensor>> list = new(); | |||
foreach(var item in loop_vars.ListValue!) | |||
{ | |||
if(item.NestType == NestType.Node) | |||
{ | |||
var nested = item.AsNest(); | |||
if (nested.NodeValue is FakeTensorByTensorArray fake) | |||
{ | |||
list.Add(new Nest<Tensor>(fake.TensorArray.flow)); | |||
} | |||
else | |||
{ | |||
list.Add(new Nest<Tensor>(nested.NodeValue!)); | |||
} | |||
} | |||
else | |||
{ | |||
list.Add(new Nest<Tensor>(item.AsNest())); | |||
} | |||
} | |||
return Tensors.FromNest(new Nest<Tensor>(list)); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph, | |||
Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism) | |||
{ | |||
var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op); | |||
var body_stateful_ops = body_graph.get_operations().Select(x => x.op); | |||
bool is_stateful = cond_stateful_ops.Count() > 0 || body_stateful_ops.Count() > 0; | |||
Tensor[] _make_op(Tensor[] inputs) | |||
{ | |||
Tensor[] outputs; | |||
if (is_stateful) | |||
{ | |||
outputs = gen_functional_ops._while( | |||
inputs, | |||
control_flow_util.create_new_tf_function(cond_graph), | |||
control_flow_util.create_new_tf_function(body_graph), | |||
output_shapes, | |||
parallel_iterations, | |||
name | |||
); | |||
} | |||
else | |||
{ | |||
outputs = gen_functional_ops.stateless_while( | |||
inputs, | |||
control_flow_util.create_new_tf_function(cond_graph), | |||
control_flow_util.create_new_tf_function(body_graph), | |||
output_shapes, | |||
parallel_iterations, | |||
name | |||
); | |||
} | |||
var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs); | |||
_copy_handle_data(body_graph.Outputs, tensors); | |||
_set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph}); | |||
while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs }); | |||
while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism }); | |||
cond_graph.outer_graph = ops.get_default_graph(); | |||
body_graph.outer_graph = ops.get_default_graph(); | |||
// TODO(Rinne): set the two graphs to while_op | |||
return tensors; | |||
} | |||
return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars); | |||
} | |||
/// <summary> | |||
/// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies. | |||
/// </summary> | |||
/// <param name="op"></param> | |||
/// <param name="branch_graphs"></param> | |||
private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs) | |||
{ | |||
List<int> read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList(); | |||
foreach(var branch_graph in branch_graphs) | |||
{ | |||
if (read_only_indices.Count == 0) | |||
{ | |||
break; | |||
} | |||
var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph); | |||
read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList(); | |||
} | |||
AttrValue.Types.ListValue listValue = new(); | |||
listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x)); | |||
op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue() | |||
{ | |||
List = listValue | |||
}); | |||
} | |||
private static Tensors _pack_sequence_as<T>(INestStructure<T> loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars) | |||
{ | |||
var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item => | |||
{ | |||
var (flow, y) = item; | |||
if (y is FakeTensorByTensorArray ta) | |||
{ | |||
return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow)); | |||
} | |||
else | |||
{ | |||
return flow; | |||
} | |||
}).ToArray(); | |||
return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors(); | |||
} | |||
private static Tensor[] _get_intermediates(FuncGraph func_graph) | |||
{ | |||
List<Tensor> intermediates = new(); | |||
var reversed_captures = func_graph.captures.ToDictionary(x => x.Item2, x => x.Item1); | |||
foreach(var op in func_graph.get_operations()) | |||
{ | |||
Debug.Assert(op is Operation); | |||
var oper = (Operation)op; | |||
if(oper.type == "Identity" || oper.type == "MutexLock") | |||
{ | |||
continue; | |||
} | |||
foreach(var o in op.outputs) | |||
{ | |||
if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !reversed_captures.ContainsKey(o)) | |||
{ | |||
intermediates.Add(o); | |||
} | |||
} | |||
} | |||
return intermediates.ToArray(); | |||
} | |||
private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures) | |||
{ | |||
var types = body_graph_captures.Select(t => t.dtype).ToList(); | |||
var c_graph = cond_graph.c_graph; | |||
var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList(); | |||
var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList(); | |||
List<Tensor> tensors = new(); | |||
foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types)) | |||
{ | |||
var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph); | |||
op._outputs = new Tensor[] { tensor }; | |||
tensors.Add(tensor); | |||
} | |||
var tuples = zip(body_graph_captures, tensors).ToList(); | |||
var keys = body_graph_captures.Select(t => t.Id).ToList(); | |||
cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2)); | |||
cond_graph.Inputs.AddRange(tensors); | |||
} | |||
private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype) | |||
{ | |||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||
c_api.TF_SetAttrType(desc, "dtype", dtype); | |||
var op = c_api.TF_FinishOperation(desc, tf.Status); | |||
tf.Status.Check(true); | |||
var output = new TF_Output(); | |||
output.oper = op; | |||
output.index = 0; | |||
return output; | |||
} | |||
private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph) | |||
{ | |||
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder"); | |||
} | |||
private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype, | |||
string name) | |||
{ | |||
return ops.convert_to_tensor(value, dtype, name, false); | |||
} | |||
private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1) | |||
{ | |||
return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations"); | |||
} | |||
private static void _copy_handle_data(IEnumerable<Tensor> src_tensors, IEnumerable<Tensor> dst_tensors) | |||
{ | |||
foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors)) | |||
{ | |||
handle_data_util.copy_handle_data(src_t, dst_t); | |||
} | |||
} | |||
} | |||
} |
@@ -105,6 +105,13 @@ namespace Tensorflow | |||
_id = ops.uid(); | |||
} | |||
internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output) | |||
{ | |||
Tensor ret = new Tensor(op, value_index, dtype); | |||
ret._tf_output = tf_output; | |||
return ret; | |||
} | |||
protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | |||
{ | |||
_handle = TF_NewTensor(shape, dtype, null); | |||
@@ -14,7 +14,9 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -44,5 +46,27 @@ namespace Tensorflow | |||
public abstract Tensor stack(string name = null); | |||
public abstract Tensor gather(Tensor indices, string name = null); | |||
internal bool _dynamic_size; | |||
internal Tensor _size; | |||
internal List<Tensor> _colocate_with; | |||
internal Shape _element_shape; | |||
public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false, | |||
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, | |||
bool infer_shape = true, Shape? element_shape = null, | |||
bool colocate_with_first_write_call = true, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant)) | |||
{ | |||
return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, | |||
infer_shape, element_shape, colocate_with_first_write_call, name); | |||
} | |||
else | |||
{ | |||
return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow, | |||
infer_shape, element_shape, colocate_with_first_write_call, name); | |||
} | |||
} | |||
} | |||
} |
@@ -4,6 +4,8 @@ using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Common.Extensions; | |||
namespace Tensorflow | |||
{ | |||
@@ -58,7 +60,7 @@ namespace Tensorflow | |||
public Tensor this[params string[] slices] | |||
=> this.First()[slices]; | |||
private Tensors(Nest<Tensor> nested) : base(nested) | |||
internal Tensors(Nest<Tensor> nested) : base(nested) | |||
{ | |||
} | |||
@@ -68,9 +70,9 @@ namespace Tensorflow | |||
} | |||
public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x))) | |||
public Tensors(IList<Tensor> tensors) : base(tensors.Select(x => new Nest<Tensor>(x))) | |||
{ | |||
} | |||
public Tensors(NDArray nd): base(ops.convert_to_tensor(nd)) | |||
@@ -78,6 +80,32 @@ namespace Tensorflow | |||
} | |||
/// <summary> | |||
/// Get the element in shallow level. For example, for ts = [1, [2, 3], 4], | |||
/// common indexer has ts[1] = 2. Shallow indexer has ts[1] = [2, 3] | |||
/// </summary> | |||
/// <param name="index"></param> | |||
/// <returns></returns> | |||
public Tensors GetShallow(int index) | |||
{ | |||
if(NestType == NestType.Node) | |||
{ | |||
if(index > 0) | |||
{ | |||
throw new IndexOutOfRangeException(); | |||
} | |||
return this; | |||
} | |||
else if(NestType == NestType.List) | |||
{ | |||
return ListValue![index].AsNest().ToTensors(); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors) | |||
{ | |||
if (tensors.Length == 0) | |||
@@ -115,8 +143,8 @@ namespace Tensorflow | |||
else if(NestType == NestType.Node) | |||
{ | |||
NestType = NestType.List; | |||
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) }; | |||
Value = null; | |||
ListValue = new() { new Nest<Tensor>(NodeValue), new Nest<Tensor>(tensor) }; | |||
NodeValue = null; | |||
} | |||
else if(NestType == NestType.List) | |||
{ | |||
@@ -125,7 +153,7 @@ namespace Tensorflow | |||
else //Empty | |||
{ | |||
NestType = NestType.Node; | |||
Value = tensor; | |||
NodeValue = tensor; | |||
} | |||
} | |||
@@ -140,9 +168,9 @@ namespace Tensorflow | |||
else if (NestType == NestType.Node) | |||
{ | |||
NestType = NestType.List; | |||
ListValue = new() { new Nest<Tensor>(Value) }; | |||
ListValue = new() { new Nest<Tensor>(NodeValue) }; | |||
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x))); | |||
Value = null; | |||
NodeValue = null; | |||
} | |||
else if(NestType == NestType.List) | |||
{ | |||
@@ -151,7 +179,7 @@ namespace Tensorflow | |||
else // empty | |||
{ | |||
NestType = NestType.List; | |||
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList(); | |||
ListValue = tensors.Select(x => new Nest<Tensor>(x) as INestStructure<Tensor>).ToList(); | |||
} | |||
} | |||
@@ -166,9 +194,9 @@ namespace Tensorflow | |||
else if(NestType == NestType.Node) | |||
{ | |||
NestType = NestType.List; | |||
ListValue = new() { new Nest<Tensor>(Value) }; | |||
ListValue = new() { new Nest<Tensor>(NodeValue) }; | |||
ListValue.Insert(index, new Nest<Tensor>(tensor)); | |||
Value = null; | |||
NodeValue = null; | |||
} | |||
else | |||
{ | |||
@@ -283,7 +311,7 @@ namespace Tensorflow | |||
=> tensors?.SingleOrNull; | |||
public static implicit operator Tensor[](Tensors tensors) | |||
=> tensors.Flatten().ToArray(); | |||
=> tensors.Flatten().ToArray(); | |||
#endregion | |||
public static Tensors? FromNest(Nest<Tensor> nested) | |||
@@ -298,7 +326,7 @@ namespace Tensorflow | |||
public void Deconstruct(out Tensor a, out Tensors? b) | |||
{ | |||
a = this.First(); | |||
b = Length == 1? null : new Tensors(this.Skip(1)); | |||
b = Length == 1? null : new Tensors(this.Skip(1).ToArray()); | |||
} | |||
public override string ToString() | |||
@@ -576,7 +576,7 @@ namespace Tensorflow | |||
public static HandleData get_resource_handle_data(Tensor graph_op) | |||
{ | |||
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); | |||
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); | |||
return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); | |||
} | |||
public static void dismantle_graph(Graph graph) | |||
@@ -25,6 +25,7 @@ using static Tensorflow.Binding; | |||
using static Tensorflow.Graphs.SubGraphUtility; | |||
using Tensorflow.Util; | |||
using Tensorflow.Common.Types; | |||
using System.Diagnostics; | |||
namespace Tensorflow.Keras | |||
{ | |||
@@ -485,7 +486,7 @@ namespace Tensorflow.Keras | |||
var first_flatted_input = flatted_inptus[0]; | |||
var time_steps = first_flatted_input.shape[0]; | |||
var batch = first_flatted_input.shape[1]; | |||
var time_steps_t = (int)first_flatted_input.shape[0]; | |||
var time_steps_t = tf.shape(first_flatted_input)[0]; | |||
foreach (var input_ in flatted_inptus) | |||
{ | |||
@@ -704,7 +705,7 @@ namespace Tensorflow.Keras | |||
var input_ta = new List<TensorArray>(); | |||
for (int i = 0; i < flatted_inptus.Count; i++) | |||
{ | |||
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t)); | |||
input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t)); | |||
} | |||
foreach(var (ta, input_) in zip(input_ta, flatted_inptus)) | |||
@@ -730,18 +731,15 @@ namespace Tensorflow.Keras | |||
(output_time_zero, _) = step_function(input_time_zero, | |||
constants is null ? initial_states : initial_states.MergeWith(constants)); | |||
int output_ta_size = return_all_outputs ? time_steps_t : 1; | |||
Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1); | |||
var output_ta = new List<TensorArray>(); | |||
for (int i = 0; i < output_time_zero.ToList().Count; i++) | |||
foreach(var output in output_time_zero.Flatten()) | |||
{ | |||
var Out = output_time_zero.ToList()[i]; | |||
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); | |||
output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape)); | |||
} | |||
var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); | |||
Func<Tensor, Tensor>? masking_fn; | |||
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; | |||
if (mask != null) | |||
@@ -750,7 +748,7 @@ namespace Tensorflow.Keras | |||
{ | |||
mask = tf.reverse(mask, axis: new[] { 0 }); | |||
} | |||
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t); | |||
var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t); | |||
mask_ta = mask_ta.unstack(mask); | |||
masking_fn = (time) => | |||
@@ -810,9 +808,9 @@ namespace Tensorflow.Keras | |||
masking_fn = null; | |||
} | |||
Func<Tensor, Tensor> cond = (time) => (time < time_steps_t); | |||
Func<Tensors, Tensor> cond = (time) => (time[0] < time_steps_t); | |||
int parallel_iterations = 32; | |||
new_states = states; | |||
Tensors final_outputs; | |||
if (masking_fn != null) | |||
{ | |||
// Mask for the T output will be base on the output of T - 1. In the | |||
@@ -825,7 +823,7 @@ namespace Tensorflow.Keras | |||
var prev_output = flat_zero_output; | |||
var output_ta_t = output_ta; | |||
Tensor _step(Tensor time) | |||
Tensors _step(Tensors tensors) | |||
{ | |||
/* | |||
RNN step function. | |||
@@ -838,23 +836,28 @@ namespace Tensorflow.Keras | |||
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` | |||
*/ | |||
Tensor time = tensors[0]; | |||
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; | |||
Tensors prev_output = tensors.GetShallow(2); | |||
Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray()); | |||
var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | |||
// maybe set shape | |||
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | |||
var mask_t = masking_fn(time); | |||
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); | |||
var (output, new_states) = step_function(current_input, states.MergeWith(constants)); | |||
// mask output | |||
var flat_output = Nest.Flatten(output).ToList(); | |||
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); | |||
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList(); | |||
// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type | |||
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); | |||
// mask states | |||
var flat_state = states.ToList(); | |||
var flat_new_state = new_states_internal.ToList(); | |||
var flat_state = states.Flatten().ToList(); | |||
var flat_new_state = new_states.Flatten().ToList(); | |||
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||
{ | |||
@@ -865,38 +868,37 @@ namespace Tensorflow.Keras | |||
} | |||
var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); | |||
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); | |||
new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors(); | |||
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||
output_ta_t = zip(output_ta_t, flat_new_output).Select(item => | |||
{ | |||
var (ta, out_) = item; | |||
return ta.write(ta_index_to_write, out_); | |||
}).ToList(); | |||
Debug.Assert(flat_output.Count() == 1); | |||
output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First()); | |||
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); | |||
output_ta = output_ta_t; | |||
new_states = new_states_internal; | |||
return time + 1; | |||
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states) | |||
.ToArray().ToTensors(); | |||
} | |||
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); | |||
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) } | |||
.Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors(); | |||
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); | |||
new_states = final_outputs.Skip(3).ToList(); | |||
} | |||
else | |||
{ | |||
var output_ta_t = output_ta; | |||
new_states = states; | |||
Tensor _step(Tensor time) | |||
Tensors _step(Tensors tensors) | |||
{ | |||
Tensor time = tensors[0]; | |||
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; | |||
Tensors states = new Tensors(tensors.Skip(2).ToArray()); | |||
var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); | |||
// maybe set shape | |||
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); | |||
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); | |||
var (output, new_states) = step_function(current_input, states.MergeWith(constants)); | |||
var flat_state = new_states.Flatten().ToList(); | |||
var flat_new_state = new_states_internal.Flatten().ToList(); | |||
var flat_new_state = new_states.Flatten().ToList(); | |||
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||
{ | |||
if (new_state is Tensor) | |||
@@ -906,24 +908,23 @@ namespace Tensorflow.Keras | |||
} | |||
var flat_output = Nest.Flatten(output); | |||
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||
output_ta_t = zip(output_ta_t, flat_output).Select(item => | |||
{ | |||
var (ta, out_) = item; | |||
return ta.write(ta_index_to_write, out_); | |||
}).ToList(); | |||
new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); | |||
output_ta = output_ta_t; | |||
new_states = new_states_internal; | |||
return time + 1; | |||
Debug.Assert(flat_output.Count() == 1); | |||
output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First()); | |||
new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); | |||
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors(); | |||
} | |||
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); | |||
Debug.Assert(output_ta.Count == 1); | |||
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors(); | |||
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); | |||
new_states = final_outputs.Skip(2).ToList(); | |||
} | |||
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors()); | |||
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors()); | |||
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); | |||
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); | |||
output_ta = new List<TensorArray> { (final_outputs[1] as FakeTensorByTensorArray).TensorArray }; | |||
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors()); | |||
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors()); | |||
outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors(); | |||
last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors(); | |||
} | |||
Func<Tensor, Tensor> set_shape; | |||
@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine | |||
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); | |||
graph.as_default(); | |||
var shapes = input_shape.ToShapeArray(); | |||
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x))); | |||
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray()); | |||
try | |||
{ | |||
Call(x, training: false); | |||
@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = new Tensors(x), | |||
X = new Tensors(x.ToArray()), | |||
Y = y, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
@@ -188,7 +188,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
var data = iterator.next(); | |||
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | |||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
return outputs; | |||
} | |||
@@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Engine | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = new Tensors(train_x), | |||
X = new Tensors(train_x.ToArray()), | |||
Y = train_y, | |||
BatchSize = batch_size, | |||
InitialEpoch = initial_epoch, | |||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
var data = iterator.next(); | |||
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; | |||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); | |||
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); | |||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||
return outputs; | |||
} | |||
@@ -4,10 +4,11 @@ using System.Text; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
public abstract class DropoutRNNCellMixin: RnnCellBase | |||
public abstract class DropoutRNNCellMixin: Layer, IRnnCell | |||
{ | |||
public float dropout; | |||
public float recurrent_dropout; | |||
@@ -17,6 +18,14 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
} | |||
public abstract GeneralizedTensorShape StateSize { get; } | |||
public abstract GeneralizedTensorShape OutputSize { get; } | |||
public abstract bool SupportOptionalArgs { get; } | |||
public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype) | |||
{ | |||
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); | |||
} | |||
protected void _create_non_trackable_mask_cache() | |||
{ | |||
@@ -206,7 +206,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
// append bacth dim | |||
state_spec_shape = new int[] { -1 }.concat(state_spec_shape); | |||
return new InputSpec(shape: state_spec_shape); | |||
} | |||
// Check whether the input shape contains any nested shapes. It could be | |||
@@ -298,7 +297,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) | |||
Func<Tensors, Tensors, (Tensors, Tensors)> step; | |||
bool is_tf_rnn_cell = _cell.IsTFRnnCell; | |||
bool is_tf_rnn_cell = false; | |||
if (constants is not null) | |||
{ | |||
if (!_cell.SupportOptionalArgs) | |||
@@ -310,8 +309,8 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
step = (inputs, states) => | |||
{ | |||
constants = new Tensors(states.TakeLast(_num_constants)); | |||
states = new Tensors(states.SkipLast(_num_constants)); | |||
constants = new Tensors(states.TakeLast(_num_constants).ToArray()); | |||
states = new Tensors(states.SkipLast(_num_constants).ToArray()); | |||
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; | |||
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||
return (output, new_states.Single); | |||
@@ -395,12 +394,12 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
if (_num_constants != 0) | |||
{ | |||
initial_state = new Tensors(inputs.Skip(1)); | |||
initial_state = new Tensors(inputs.Skip(1).ToArray()); | |||
} | |||
else | |||
{ | |||
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants)); | |||
constants = new Tensors(inputs.TakeLast(_num_constants)); | |||
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray()); | |||
constants = new Tensors(inputs.TakeLast(_num_constants).ToArray()); | |||
} | |||
if (len(initial_state) == 0) | |||
initial_state = null; | |||
@@ -558,36 +557,14 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
protected Tensors get_initial_state(Tensors inputs) | |||
{ | |||
var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state"); | |||
var input = inputs[0]; | |||
var input_shape = inputs.shape; | |||
var input_shape = array_ops.shape(inputs); | |||
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | |||
var dtype = input.dtype; | |||
Tensors init_state = new Tensors(); | |||
if(get_initial_state_fn != null) | |||
{ | |||
init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype }); | |||
} | |||
//if (_cell is RnnCellBase rnn_base_cell) | |||
//{ | |||
// init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype); | |||
//} | |||
else | |||
{ | |||
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); | |||
} | |||
Tensors init_state = _cell.GetInitialState(null, batch_size, dtype); | |||
return init_state; | |||
} | |||
// Check whether the state_size contains multiple states. | |||
public static bool is_multiple_state(GeneralizedTensorShape state_size) | |||
{ | |||
return state_size.Shapes.Length > 1; | |||
} | |||
} | |||
} |
@@ -1,24 +0,0 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
public abstract class RnnCellBase: Layer, IRnnCell | |||
{ | |||
public RnnCellBase(LayerArgs args) : base(args) { } | |||
public abstract GeneralizedTensorShape StateSize { get; } | |||
public abstract GeneralizedTensorShape OutputSize { get; } | |||
public abstract bool IsTFRnnCell { get; } | |||
public abstract bool SupportOptionalArgs { get; } | |||
public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype) | |||
{ | |||
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); | |||
} | |||
} | |||
} |
@@ -7,6 +7,7 @@ using Tensorflow.Keras.Saving; | |||
using Tensorflow.Common.Types; | |||
using Tensorflow.Common.Extensions; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Graphs; | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
@@ -28,7 +29,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
public override GeneralizedTensorShape StateSize => _state_size; | |||
public override GeneralizedTensorShape OutputSize => _output_size; | |||
public override bool IsTFRnnCell => true; | |||
public override bool SupportOptionalArgs => false; | |||
public SimpleRNNCell(SimpleRNNCellArgs args) : base(args) | |||
@@ -98,7 +98,6 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
prev_output = math_ops.multiply(prev_output, rec_dp_mask); | |||
} | |||
var tmp = _recurrent_kernel.AsTensor(); | |||
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); | |||
if (_args.Activation != null) | |||
@@ -117,9 +116,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
} | |||
} | |||
public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) | |||
public Tensors get_initial_state(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value); | |||
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); | |||
} | |||
} | |||
} |
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
public class StackedRNNCells : Layer, IRnnCell | |||
{ | |||
public IList<IRnnCell> Cells { get; set; } | |||
public bool reverse_state_order; | |||
public bool _reverse_state_order; | |||
public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | |||
{ | |||
@@ -23,22 +23,11 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
args.Kwargs = new Dictionary<string, object>(); | |||
} | |||
foreach (var cell in args.Cells) | |||
{ | |||
//Type type = cell.GetType(); | |||
//var CallMethodInfo = type.GetMethod("Call"); | |||
//if (CallMethodInfo == null) | |||
//{ | |||
// throw new ValueError( | |||
// "All cells must have a `Call` method. " + | |||
// $"Received cell without a `Call` method: {cell}"); | |||
//} | |||
} | |||
Cells = args.Cells; | |||
reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); | |||
_reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); | |||
if (reverse_state_order) | |||
if (_reverse_state_order) | |||
{ | |||
throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " + | |||
"be deprecated. Please update the code to work with the " + | |||
@@ -47,49 +36,37 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
} | |||
} | |||
public bool SupportOptionalArgs => false; | |||
public GeneralizedTensorShape StateSize | |||
{ | |||
get | |||
{ | |||
GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count); | |||
if (reverse_state_order && Cells.Count > 0) | |||
if (_reverse_state_order) | |||
{ | |||
var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell)); | |||
foreach (var cell in idxAndCell) | |||
{ | |||
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); | |||
} | |||
var state_sizes = Cells.Reverse().Select(cell => cell.StateSize); | |||
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s)))); | |||
} | |||
else | |||
{ | |||
//foreach (var cell in Cells) | |||
//{ | |||
// state_size.Shapes.add(cell.StateSize.Shapes.First()); | |||
//} | |||
var idxAndCell = Cells.Select((cell, idx) => (idx, cell)); | |||
foreach (var cell in idxAndCell) | |||
{ | |||
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); | |||
} | |||
var state_sizes = Cells.Select(cell => cell.StateSize); | |||
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s)))); | |||
} | |||
return state_size; | |||
} | |||
} | |||
public object output_size | |||
public GeneralizedTensorShape OutputSize | |||
{ | |||
get | |||
{ | |||
var lastCell = Cells.LastOrDefault(); | |||
if (lastCell.OutputSize.ToSingleShape() != -1) | |||
var lastCell = Cells.Last(); | |||
if(lastCell.OutputSize is not null) | |||
{ | |||
return lastCell.OutputSize; | |||
} | |||
else if (RNN.is_multiple_state(lastCell.StateSize)) | |||
else if (RnnUtils.is_multiple_state(lastCell.StateSize)) | |||
{ | |||
return lastCell.StateSize.First(); | |||
//throw new NotImplementedException(""); | |||
} | |||
else | |||
{ | |||
@@ -98,79 +75,65 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
} | |||
} | |||
public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) | |||
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
var cells = reverse_state_order ? Cells.Reverse() : Cells; | |||
Tensors initial_states = new Tensors(); | |||
var cells = _reverse_state_order ? Cells.Reverse() : Cells; | |||
List<Tensor> initial_states = new List<Tensor>(); | |||
foreach (var cell in cells) | |||
{ | |||
var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state"); | |||
if (get_initial_state_fn != null) | |||
{ | |||
var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype }); | |||
initial_states.Add(result); | |||
} | |||
else | |||
{ | |||
initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value)); | |||
} | |||
initial_states.Add(cell.GetInitialState(inputs, batch_size, dtype)); | |||
} | |||
return initial_states; | |||
return new Tensors(initial_states); | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) | |||
{ | |||
// Recover per-cell states. | |||
var state_size = reverse_state_order ? StateSize.Reverse() : StateSize; | |||
var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten(); | |||
var state_size = _reverse_state_order ? new GeneralizedTensorShape(StateSize.Reverse()) : StateSize; | |||
var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray()); | |||
var new_nest_states = new Tensors(); | |||
var new_nest_states = Nest<Tensor>.Empty; | |||
// Call the cells in order and store the returned states. | |||
foreach (var (cell, states) in zip(Cells, nested_states)) | |||
foreach (var (cell, internal_states) in zip(Cells, nested_states)) | |||
{ | |||
// states = states if tf.nest.is_nested(states) else [states] | |||
var type = cell.GetType(); | |||
bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null; | |||
state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state; | |||
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; | |||
Tensors? constants = rnn_optional_args?.Constants; | |||
Tensors new_states; | |||
(inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||
(inputs, new_states) = cell.Apply(inputs, internal_states, optional_args: new RnnOptionalArgs() { Constants = constants }); | |||
new_nest_states.Add(new_states); | |||
new_nest_states = new_nest_states.MergeWith(new_states); | |||
} | |||
new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray(); | |||
return new Nest<Tensor>(new List<Nest<Tensor>> { | |||
new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(inputs.Single()) }), new Nest<Tensor>(new_nest_states) }) | |||
.ToTensors(); | |||
return Tensors.FromNest((inputs, Nest.PackSequenceAs(state_size, Nest.Flatten(new_nest_states).ToArray()))); | |||
} | |||
public void build() | |||
public override void build(KerasShapesWrapper input_shape) | |||
{ | |||
built = true; | |||
// @tf_utils.shape_type_conversion | |||
// def build(self, input_shape) : | |||
// if isinstance(input_shape, list) : | |||
// input_shape = input_shape[0] | |||
// for cell in self.cells: | |||
// if isinstance(cell, Layer) and not cell.built: | |||
// with K.name_scope(cell.name): | |||
// cell.build(input_shape) | |||
// cell.built = True | |||
// if getattr(cell, 'output_size', None) is not None: | |||
// output_dim = cell.output_size | |||
// elif _is_multiple_state(cell.state_size) : | |||
// output_dim = cell.state_size[0] | |||
// else: | |||
// output_dim = cell.state_size | |||
// input_shape = tuple([input_shape[0]] + | |||
// tensor_shape.TensorShape(output_dim).as_list()) | |||
// self.built = True | |||
var shape = input_shape.ToSingleShape(); | |||
foreach(var cell in Cells) | |||
{ | |||
if(cell is Layer layer && !layer.Built) | |||
{ | |||
// ignored the name scope. | |||
layer.build(shape); | |||
layer.Built = true; | |||
} | |||
GeneralizedTensorShape output_dim; | |||
if(cell.OutputSize is not null) | |||
{ | |||
output_dim = cell.OutputSize; | |||
} | |||
else if (RnnUtils.is_multiple_state(cell.StateSize)) | |||
{ | |||
output_dim = cell.StateSize.First(); | |||
} | |||
else | |||
{ | |||
output_dim = cell.StateSize; | |||
} | |||
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.ToSingleShape().dims).ToArray()); | |||
} | |||
this.Built = true; | |||
} | |||
public override IKerasConfig get_config() | |||
@@ -198,14 +161,5 @@ namespace Tensorflow.Keras.Layers.Rnn | |||
// deserialize_layer(cell_config, custom_objects = custom_objects)) | |||
// return cls(cells, **config) | |||
} | |||
public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); | |||
public bool IsTFRnnCell => true; | |||
public bool SupportOptionalArgs => throw new NotImplementedException(); | |||
} | |||
} |
@@ -10,20 +10,21 @@ namespace Tensorflow.Keras.Utils | |||
{ | |||
internal static class RnnUtils | |||
{ | |||
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) | |||
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) | |||
{ | |||
Func<GeneralizedTensorShape, Tensor> create_zeros; | |||
create_zeros = (GeneralizedTensorShape unnested_state_size) => | |||
{ | |||
var flat_dims = unnested_state_size.ToSingleShape().dims; | |||
var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray(); | |||
return array_ops.zeros(new Shape(init_state_size), dtype: dtype); | |||
var init_state_size = new Tensor[] { batch_size_tensor }. | |||
Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray(); | |||
return array_ops.zeros(init_state_size, dtype: dtype); | |||
}; | |||
// TODO(Rinne): map structure with nested tensors. | |||
if(state_size.Shapes.Length > 1) | |||
if(state_size.TotalNestedCount > 1) | |||
{ | |||
return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s)))); | |||
return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray()); | |||
} | |||
else | |||
{ | |||
@@ -32,11 +33,11 @@ namespace Tensorflow.Keras.Utils | |||
} | |||
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype) | |||
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) | |||
{ | |||
if (inputs != null) | |||
if (inputs is not null) | |||
{ | |||
batch_size = inputs.shape[0]; | |||
batch_size = array_ops.shape(inputs)[0]; | |||
dtype = inputs.dtype; | |||
} | |||
return generate_zero_filled_state(batch_size, cell.StateSize, dtype); | |||
@@ -77,17 +78,27 @@ namespace Tensorflow.Keras.Utils | |||
Debug.Assert(initial_state is null && constants is null); | |||
if(num_constants > 0) | |||
{ | |||
constants = inputs.TakeLast(num_constants).ToTensors(); | |||
inputs = inputs.SkipLast(num_constants).ToTensors(); | |||
constants = inputs.TakeLast(num_constants).ToArray().ToTensors(); | |||
inputs = inputs.SkipLast(num_constants).ToArray().ToTensors(); | |||
} | |||
if(inputs.Length > 1) | |||
{ | |||
initial_state = inputs.Skip(1).ToTensors(); | |||
inputs = inputs.Take(1).ToTensors(); | |||
initial_state = inputs.Skip(1).ToArray().ToTensors(); | |||
inputs = inputs.Take(1).ToArray().ToTensors(); | |||
} | |||
} | |||
return (inputs, initial_state, constants); | |||
} | |||
/// <summary> | |||
/// Check whether the state_size contains multiple states. | |||
/// </summary> | |||
/// <param name="state_size"></param> | |||
/// <returns></returns> | |||
public static bool is_multiple_state(GeneralizedTensorShape state_size) | |||
{ | |||
return state_size.TotalNestedCount > 1; | |||
} | |||
} | |||
} |
@@ -28,8 +28,8 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var i = tf.constant(2); | |||
var j = tf.constant(3); | |||
Func<Tensor[], Tensor> c = (x) => tf.less(x[0] + x[1], 10); | |||
Func<Tensor[], Tensor[]> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) }; | |||
Func<Tensors, Tensor> c = (x) => tf.less(x[0] + x[1], 10); | |||
Func<Tensors, Tensors> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) }; | |||
var r = tf.while_loop(c, b, new[] { i, j }); | |||
Assert.AreEqual(5, (int)r[0]); | |||
Assert.AreEqual(6, (int)r[1]); | |||
@@ -21,7 +21,8 @@ namespace Tensorflow.CodeGen | |||
{ | |||
sb.Append("Operation "); | |||
} | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||
{ | |||
sb.Append("Tensor "); | |||
} | |||
@@ -70,7 +71,8 @@ namespace Tensorflow.CodeGen | |||
{ | |||
sb.AppendLine("return null;"); | |||
} | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||
{ | |||
sb.AppendLine("return _fast_path_result[0];"); | |||
} | |||
@@ -149,7 +151,8 @@ namespace Tensorflow.CodeGen | |||
{ | |||
sb.AppendLine("return _op;"); | |||
} | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||
{ | |||
sb.AppendLine("return _result[0];"); | |||
} | |||
@@ -174,7 +177,7 @@ namespace Tensorflow.CodeGen | |||
{ | |||
argName = $"{argName}_"; | |||
} | |||
if (!string.IsNullOrEmpty(arg.NumberAttr)) | |||
if (!string.IsNullOrEmpty(arg.NumberAttr) || !string.IsNullOrEmpty(arg.TypeListAttr)) | |||
{ | |||
sb.Append($"Tensors {argName}, "); | |||
} | |||
@@ -273,7 +276,8 @@ namespace Tensorflow.CodeGen | |||
{ | |||
sb.Append("Operation "); | |||
} | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||
{ | |||
sb.Append("Tensor "); | |||
} | |||
@@ -366,6 +370,13 @@ namespace Tensorflow.CodeGen | |||
sb.Append($"\"{attr.Name}\", {attrRealName}, "); | |||
} | |||
} | |||
else if(attr.Type == "list(type)") | |||
{ | |||
if (op.InputArg.Any(x => x.TypeListAttr == attr.Name)) | |||
{ | |||
continue; | |||
} | |||
} | |||
else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name)) | |||
{ | |||
bool found = false; | |||
@@ -408,7 +419,8 @@ namespace Tensorflow.CodeGen | |||
{ | |||
sb.AppendLine("return null;"); | |||
} | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) | |||
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) | |||
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) | |||
{ | |||
sb.AppendLine("return _result[0];"); | |||
} | |||
@@ -5,7 +5,7 @@ using System.Text; | |||
using System.Xml.Linq; | |||
using Tensorflow.CodeGen; | |||
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", | |||
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops_v2", | |||
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", | |||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", | |||
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); | |||
@@ -155,6 +155,10 @@ namespace Tensorflow.CodeGen | |||
} | |||
else if (attr.Type == "list(type)") | |||
{ | |||
if(op.InputArg.Any(x => x.TypeListAttr == attr.Name)) | |||
{ | |||
continue; | |||
} | |||
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) | |||
{ | |||
List<TF_DataType> values = new(); | |||
@@ -231,11 +235,11 @@ namespace Tensorflow.CodeGen | |||
} | |||
else if (attr.Type == "func") | |||
{ | |||
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE")); | |||
res.Add((attr.Name, "object", "NOVALUE")); | |||
} | |||
else if (attr.Type == "list(func)") | |||
{ | |||
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE")); | |||
res.Add((attr.Name, "object[]", "NOVALUE")); | |||
} | |||
else if (attr.Type == "tensor") | |||
{ | |||