From f1fbcf20166fa1902e399998aaf1c738493f9785 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Fri, 16 Jun 2023 14:30:54 +0800 Subject: [PATCH] feat: support model building with RNN. --- src/TensorFlowNET.Core/APIs/c_api.cs | 14 + .../APIs/tf.control_flow.cs | 10 +- .../Common/Extensions/LinqExtensions.cs | 7 +- .../Common/Types/FakeTensorByTensorArray.cs | 20 + .../Common/Types/GeneralizedTensorShape.cs | 140 +- .../Types/{INest.cs => INestStructure.cs} | 13 + .../Common/Types/Nest.Static.cs | 2 +- src/TensorFlowNET.Core/Common/Types/Nest.cs | 117 +- .../Common/Types/NestDictionary.cs | 4 + .../Common/Types/NestList.cs | 17 +- .../Common/Types/NestNode.cs | 4 + src/TensorFlowNET.Core/Data/DatasetV2.cs | 4 +- .../Eager/EagerRunner.TFE_FastPathExecute.cs | 2 + .../Framework/Models/TensorSpec.cs | 13 + .../Framework/auto_control_deps_utils.cs | 89 ++ .../Framework/function_def_lib.cs | 4 +- .../Functions/ConcreteFunction.cs | 13 + src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 4 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 2 +- .../Keras/Layers/Rnn/IRnnCell.cs | 12 +- .../Operations/NnOps/RNNCell.cs | 4 + .../Operations/OpDefLibrary.cs | 49 + .../Operations/Operation.Output.cs | 2 +- .../Operations/Operation.cs | 5 +- .../Operations/_EagerTensorArray.cs | 6 +- .../Operations/_GraphTensorArray.cs | 179 ++- .../Operations/array_ops.cs | 24 + .../Operations/control_flow_ops.cs | 9 +- .../Operations/control_flow_util.py.cs | 77 ++ .../Operations/gen_functional_ops.cs | 1066 ++++++++++++-- .../Operations/gen_list_ops.cs | 1227 +++++++++++++++++ src/TensorFlowNET.Core/Operations/list_ops.cs | 111 ++ .../Operations/tensor_array_ops.cs | 20 +- src/TensorFlowNET.Core/Operations/while_v2.cs | 401 ++++++ .../Tensors/Tensor.Creation.cs | 7 + src/TensorFlowNET.Core/Tensors/TensorArray.cs | 24 + src/TensorFlowNET.Core/Tensors/Tensors.cs | 54 +- src/TensorFlowNET.Core/ops.cs | 2 +- src/TensorFlowNET.Keras/BackendImpl.cs | 95 +- src/TensorFlowNET.Keras/Engine/Model.Build.cs | 2 +- 
.../Engine/Model.Evaluate.cs | 4 +- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 2 +- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 2 +- .../Layers/Rnn/DropoutRNNCellMixin.cs | 11 +- src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 39 +- .../Layers/Rnn/RnnCellBase.cs | 24 - .../Layers/Rnn/SimpleRNNCell.cs | 7 +- .../Layers/Rnn/StackedRNNCells.cs | 152 +- src/TensorFlowNET.Keras/Utils/RnnUtils.cs | 35 +- .../ManagedAPI/ControlFlowApiTest.cs | 4 +- tools/Tensorflow.CodeGen/FunctionGenerator.cs | 24 +- tools/Tensorflow.CodeGen/Program.cs | 2 +- tools/Tensorflow.CodeGen/Utils.cs | 8 +- 53 files changed, 3662 insertions(+), 507 deletions(-) create mode 100644 src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs rename src/TensorFlowNET.Core/Common/Types/{INest.cs => INestStructure.cs} (65%) create mode 100644 src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs create mode 100644 src/TensorFlowNET.Core/Operations/gen_list_ops.cs create mode 100644 src/TensorFlowNET.Core/Operations/list_ops.cs create mode 100644 src/TensorFlowNET.Core/Operations/while_v2.cs delete mode 100644 src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 10f678e0..6049c95c 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -16,6 +16,7 @@ using System; using System.Runtime.InteropServices; +using static Tensorflow.CppShapeInferenceResult.Types; namespace Tensorflow { @@ -50,6 +51,19 @@ namespace Tensorflow return handle == IntPtr.Zero ? 
String.Empty : Marshal.PtrToStringAnsi(handle); } + public unsafe static byte[] ByteStringPiece(IntPtr handle) + { + byte* str_data = (byte*)handle.ToPointer(); + List bytes = new List(); + byte current = 255; + while (current != ((byte)'\0')) + { + current = *(str_data++); + bytes.Add(current); + } + return bytes.Take(bytes.Count - 1).ToArray(); + } + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args); diff --git a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs index 239487e0..cd5a71e5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.control_flow.cs +++ b/src/TensorFlowNET.Core/APIs/tf.control_flow.cs @@ -46,10 +46,10 @@ namespace Tensorflow Tensor loop_vars, int parallel_iterations = 10) { - Func cond1 = x + Func cond1 = x => cond(x[0]); - Func body1 = x + Func body1 = x => new[] { body(x[0]) }; var results = control_flow_ops.while_loop(cond1, @@ -58,9 +58,9 @@ namespace Tensorflow return results[0]; } - public Tensor[] while_loop(Func cond, - Func body, - Tensor[] loop_vars, + public Tensor[] while_loop(Func cond, + Func body, + Tensors loop_vars, int parallel_iterations = 10, string name = null) => control_flow_ops.while_loop(cond, body, loop_vars, diff --git a/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs index 6cf62e7b..287b48cc 100644 --- a/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs +++ b/src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs @@ -18,7 +18,12 @@ namespace Tensorflow.Common.Extensions return sequence.Take(sequence.Count() - count); } #endif - public static Tensors ToTensors(this IEnumerable tensors) + public static Tensors ToTensors(this Tensor[] tensors) + { + return new Tensors(tensors); + } + + public static Tensors ToTensors(this IList tensors) { return new Tensors(tensors); } diff --git 
a/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs b/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs new file mode 100644 index 00000000..d0c35ee7 --- /dev/null +++ b/src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Common.Types +{ + /// + /// This is a temp solution, which should be removed after refactoring `Tensors` + /// + [Obsolete] + public class FakeTensorByTensorArray: Tensor + { + public TensorArray TensorArray { get; set; } + + public FakeTensorByTensorArray(TensorArray array) + { + TensorArray = array; + } + } +} diff --git a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs index c61d04b2..40190315 100644 --- a/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs +++ b/src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs @@ -5,136 +5,80 @@ using System.Text; namespace Tensorflow.Common.Types { - public class GeneralizedTensorShape: IEnumerable, INestStructure, INestable + public class GeneralizedTensorShape: Nest { - public TensorShapeConfig[] Shapes { get; set; } - /// - /// create a single-dim generalized Tensor shape. - /// - /// - public GeneralizedTensorShape(int dim, int size = 1) - { - var elem = new TensorShapeConfig() { Items = new long?[] { dim } }; - Shapes = Enumerable.Repeat(elem, size).ToArray(); - //Shapes = new TensorShapeConfig[size]; - //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } }); - //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } }); - ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; - } + ////public TensorShapeConfig[] Shapes { get; set; } + ///// + ///// create a single-dim generalized Tensor shape. 
+ ///// + ///// + //public GeneralizedTensorShape(int dim, int size = 1) + //{ + // var elem = new TensorShapeConfig() { Items = new long?[] { dim } }; + // Shapes = Enumerable.Repeat(elem, size).ToArray(); + // //Shapes = new TensorShapeConfig[size]; + // //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } }); + // //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } }); + // ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } }; + //} - public GeneralizedTensorShape(Shape shape) + public GeneralizedTensorShape(Shape value, string? name = null) { - Shapes = new TensorShapeConfig[] { shape }; + NodeValue = value; + NestType = NestType.Node; } - public GeneralizedTensorShape(TensorShapeConfig shape) + public GeneralizedTensorShape(IEnumerable values, string? name = null) { - Shapes = new TensorShapeConfig[] { shape }; + ListValue = values.Select(s => new Nest(s) as INestStructure).ToList(); + Name = name; + NestType = NestType.List; } - public GeneralizedTensorShape(TensorShapeConfig[] shapes) + public GeneralizedTensorShape(Dictionary value, string? name = null) { - Shapes = shapes; + DictValue = value.ToDictionary(x => x.Key, x => new Nest(x.Value) as INestStructure); + Name = name; + NestType = NestType.Dictionary; } - public GeneralizedTensorShape(IEnumerable shape) + public GeneralizedTensorShape(Nest other) { - Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray(); + NestType = other.NestType; + NodeValue = other.NodeValue; + DictValue = other.DictValue; + ListValue = other.ListValue; + Name = other.Name; } public Shape ToSingleShape() { - if (Shapes.Length != 1) + var shapes = Flatten().ToList(); + if (shapes.Count != 1) { throw new ValueError("The generalized shape contains more than 1 dim."); } - var shape_config = Shapes[0]; - Debug.Assert(shape_config is not null); - return new Shape(shape_config.Items.Select(x => x is null ? 
-1 : x.Value).ToArray()); + return shapes[0]; } public long ToNumber() { - if(Shapes.Length != 1 || Shapes[0].Items.Length != 1) + var shapes = Flatten().ToList(); + if (shapes.Count != 1 || shapes[0].ndim != 1) { throw new ValueError("The generalized shape contains more than 1 dim."); } - var res = Shapes[0].Items[0]; - return res is null ? -1 : res.Value; - } - - public Shape[] ToShapeArray() - { - return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray(); - } - - public IEnumerable Flatten() - { - List result = new List(); - foreach(var shapeConfig in Shapes) - { - result.AddRange(shapeConfig.Items); - } - return result; - } - public INestStructure MapStructure(Func func) - { - List> lists = new(); - foreach(var shapeConfig in Shapes) - { - lists.Add(new Nest(shapeConfig.Items.Select(x => new Nest(func(x))))); - } - return new Nest(lists); - } - - public Nest AsNest() - { - Nest DealWithSingleShape(TensorShapeConfig config) - { - if (config.Items.Length == 0) - { - return Nest.Empty; - } - else if (config.Items.Length == 1) - { - return new Nest(config.Items[0]); - } - else - { - return new Nest(config.Items.Select(x => new Nest(x))); - } - } - - if(Shapes.Length == 0) - { - return Nest.Empty; - } - else if(Shapes.Length == 1) - { - return DealWithSingleShape(Shapes[0]); - } - else - { - return new Nest(Shapes.Select(s => DealWithSingleShape(s))); - } + return shapes[0].dims[0]; } - - - public static implicit operator GeneralizedTensorShape(int dims) - => new GeneralizedTensorShape(dims); - - public IEnumerator GetEnumerator() + public INestStructure ToTensorShapeConfigs() { - foreach (var shape in Shapes) - { - yield return shape.Items; - } + return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select(x => x == -1 ? 
null : x).ToArray() }); } - IEnumerator IEnumerable.GetEnumerator() + public static implicit operator GeneralizedTensorShape(Shape shape) { - return GetEnumerator(); + return new GeneralizedTensorShape(shape); } } } diff --git a/src/TensorFlowNET.Core/Common/Types/INest.cs b/src/TensorFlowNET.Core/Common/Types/INestStructure.cs similarity index 65% rename from src/TensorFlowNET.Core/Common/Types/INest.cs rename to src/TensorFlowNET.Core/Common/Types/INestStructure.cs index 001141dd..32b66293 100644 --- a/src/TensorFlowNET.Core/Common/Types/INest.cs +++ b/src/TensorFlowNET.Core/Common/Types/INestStructure.cs @@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types /// public interface INestStructure: INestable { + NestType NestType { get; } + + /// + /// The item count of depth 1 of the nested structure. + /// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3. + /// + int ShallowNestedCount { get; } + /// + /// The total item count of depth 1 of the nested structure. + /// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5. + /// + int TotalNestedCount { get; } + /// /// Flatten the Nestable object. Node that if the object contains only one value, /// it will be flattened to an enumerable with one element. 
diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs index b67d11f4..dc7fd3a1 100644 --- a/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs +++ b/src/TensorFlowNET.Core/Common/Types/Nest.Static.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Common.Types /// /// /// - public static Nest PackSequenceAs(INestable template, T[] flatItems) + public static Nest PackSequenceAs(INestable template, TOut[] flatItems) { return template.AsNest().PackSequence(flatItems); } diff --git a/src/TensorFlowNET.Core/Common/Types/Nest.cs b/src/TensorFlowNET.Core/Common/Types/Nest.cs index 84a60402..4de7d1fa 100644 --- a/src/TensorFlowNET.Core/Common/Types/Nest.cs +++ b/src/TensorFlowNET.Core/Common/Types/Nest.cs @@ -28,27 +28,58 @@ namespace Tensorflow.Common.Types public static Nest Empty => _empty; public NestType NestType { get; protected set; } public string? Name { get; set; } - public T? Value { get; protected set; } - public List>? ListValue { get; protected set; } - public Dictionary>? DictValue { get; protected set; } + public T? NodeValue { get; protected set; } + public List>? ListValue { get; protected set; } + public Dictionary>? DictValue { get; protected set; } + + public int ShallowNestedCount + { + get + { + if (NestType == NestType.Empty) + { + return 0; + } + else if (NestType == NestType.Node) + { + return 1; + } + else if (NestType == NestType.List) + { + return ListValue!.Count; + } + else // dict + { + return DictValue!.Count; + } + } + } + + public int TotalNestedCount + { + get + { + return Flatten().Count(); + } + } protected Nest() { } public Nest(T value, string? name = null) { - Value = value; + NodeValue = value; Name = name; NestType = NestType.Node; } - public Nest(IEnumerable> values, string? name = null) + public Nest(IEnumerable> values, string? name = null) { ListValue = values.ToList(); Name = name; NestType = NestType.List; } - public Nest(Dictionary> value, string? 
name = null) + public Nest(Dictionary> value, string? name = null) { DictValue = value; Name = name; @@ -58,7 +89,7 @@ namespace Tensorflow.Common.Types public Nest(Nest other) { NestType = other.NestType; - Value = other.Value; + NodeValue = other.NodeValue; DictValue = other.DictValue; ListValue = other.ListValue; Name = other.Name; @@ -78,17 +109,17 @@ namespace Tensorflow.Common.Types /// /// /// - public virtual Nest PackSequence(T[] flatItems) + public virtual Nest PackSequence(TOut[] flatItems) { if(flatItems.Length == 0) { - return Nest.Empty; + return Nest.Empty; } int index = 0; return PackSequenceInternal(this, flatItems, ref index); } - private static Nest PackSequenceInternal(Nest template, T[] flatItems, ref int index) + private static Nest PackSequenceInternal(Nest template, TOut[] flatItems, ref int index) { if(template.NestType == NestType.Node) { @@ -96,25 +127,25 @@ namespace Tensorflow.Common.Types { throw new InvalidArgumentError("The template and flat items are not matched."); } - return new Nest(flatItems[index++]); + return new Nest(flatItems[index++]); } else if(template.NestType == NestType.List) { - List> nestedObjects = new List>(); + List> nestedObjects = new List>(); for (int i = 0; i < template.ListValue!.Count; i++) { - nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index)); + nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index)); } - return new Nest(nestedObjects); + return new Nest(nestedObjects); } else if(template.NestType == NestType.Node) { - Dictionary> dict = new Dictionary>(); + Dictionary> dict = new Dictionary>(); foreach(var (key, value) in template.DictValue!) { - dict[key] = PackSequenceInternal(value, flatItems, ref index); + dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index); } - return new Nest(dict); + return new Nest(dict); } // Consider Empty as invalid type. 
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node."); @@ -223,10 +254,10 @@ namespace Tensorflow.Common.Types public static Nest ReduceFrom(INestStructure input) where TOut: INestStructure { var nested = input.AsNest(); - return ReduceInternal(nested); + return ReduceInternal(nested).AsNest(); } - private static Nest ReduceInternal(Nest node) where TOut : INestStructure + private static INestStructure ReduceInternal(Nest node) where TOut : INestStructure { if(node.NestType == NestType.Empty) { @@ -234,15 +265,15 @@ namespace Tensorflow.Common.Types } else if(node.NestType == NestType.Node) { - return node.Value!.AsNest(); + return node.NodeValue!.AsNest(); } else if(node.NestType == NestType.List) { - return new Nest(node.ListValue!.Select(x => ReduceInternal(x))); + return new Nest(node.ListValue!.Select(x => ReduceInternal(x.AsNest()))); } else // Dictionary type { - return new Nest(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value))); + return new Nest(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest()))); } } @@ -252,7 +283,7 @@ namespace Tensorflow.Common.Types { if(index == 0) { - result = node.Value!; + result = node.NodeValue!; return true; } result = default(T); @@ -264,7 +295,7 @@ namespace Tensorflow.Common.Types { if(index == 0) { - return FindInternal(item, index, out result); + return FindInternal(item.AsNest(), index, out result); } index--; } @@ -277,7 +308,7 @@ namespace Tensorflow.Common.Types { if (index == 0) { - return FindInternal(item, index, out result); + return FindInternal(item.AsNest(), index, out result); } index--; } @@ -297,7 +328,7 @@ namespace Tensorflow.Common.Types { if (index == 0) { - node.Value = newValue; + node.NodeValue = newValue; return true; } return false; @@ -308,7 +339,7 @@ namespace Tensorflow.Common.Types { if (index == 0) { - return SetInternal(item, index, newValue); + return SetInternal(item.AsNest(), index, newValue); 
} index--; } @@ -320,7 +351,7 @@ namespace Tensorflow.Common.Types { if (index == 0) { - return SetInternal(item, index, newValue); + return SetInternal(item.AsNest(), index, newValue); } index--; } @@ -336,13 +367,13 @@ namespace Tensorflow.Common.Types { if (node.NestType == NestType.Node) { - yield return node.Value!; + yield return node.NodeValue!; } else if (node.NestType == NestType.List) { foreach (var item in node.ListValue!) { - foreach(var val in FlattenInternal(item)) + foreach(var val in FlattenInternal(item.AsNest())) { yield return val; } @@ -352,7 +383,7 @@ namespace Tensorflow.Common.Types { foreach (var item in node.DictValue!.Values) { - foreach (var val in FlattenInternal(item)) + foreach (var val in FlattenInternal(item.AsNest())) { yield return val; } @@ -364,23 +395,23 @@ namespace Tensorflow.Common.Types { if (NestType == NestType.Node) { - return new Nest(func(Value!)); + return new Nest(func(NodeValue!)); } else if (NestType == NestType.List) { List> outs = new List>(); foreach (var item in ListValue!) { - outs.Add(item.MapStructureInternal(func)); + outs.Add(item.AsNest().MapStructureInternal(func)); } return new Nest(outs); } else if (NestType == NestType.Dictionary) { - Dictionary> outs = new Dictionary>(); + Dictionary> outs = new Dictionary>(); foreach (var (key, value) in DictValue!) 
{ - outs.Add(key, value.MapStructureInternal(func)); + outs.Add(key, value.AsNest().MapStructureInternal(func)); } return new Nest(outs); } @@ -417,14 +448,14 @@ namespace Tensorflow.Common.Types } if (node.NestType == NestType.Node) { - sb.Append(node.Value!.ToString()); + sb.Append(node.NodeValue!.ToString()); } else if (node.NestType == NestType.List) { sb.Append("["); for(int i = 0; i < node.ListValue!.Count; i++) { - WriteString(node.ListValue![i], sb); + WriteString(node.ListValue![i].AsNest(), sb); if(i != node.ListValue!.Count - 1) { sb.Append(", "); @@ -440,7 +471,7 @@ namespace Tensorflow.Common.Types foreach (var (key, value) in node.DictValue!) { sb.Append($"{key}: "); - WriteString(value, sb); + WriteString(value.AsNest(), sb); if (i != count - 1) { sb.Append(", "); @@ -454,5 +485,15 @@ namespace Tensorflow.Common.Types sb.Append(""); } } + + public static implicit operator Nest((INestStructure, INestStructure) inputs) + { + return new Nest(new INestStructure[] { inputs.Item1, inputs.Item2 }); + } + + public static implicit operator Nest((INestStructure, INestStructure, INestStructure) inputs) + { + return new Nest(new INestStructure[] { inputs.Item1, inputs.Item2, inputs.Item3 }); + } } } diff --git a/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs index 554ca526..cf199455 100644 --- a/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs +++ b/src/TensorFlowNET.Core/Common/Types/NestDictionary.cs @@ -6,7 +6,11 @@ namespace Tensorflow.Common.Types { public class NestDictionary : INestStructure, IDictionary where TKey : notnull { + public NestType NestType => NestType.Dictionary; public IDictionary Value { get; set; } + public int ShallowNestedCount => Values.Count; + + public int TotalNestedCount => Values.Count; public NestDictionary(IDictionary dict) { Value = dict; diff --git a/src/TensorFlowNET.Core/Common/Types/NestList.cs b/src/TensorFlowNET.Core/Common/Types/NestList.cs index 
08218718..e38675da 100644 --- a/src/TensorFlowNET.Core/Common/Types/NestList.cs +++ b/src/TensorFlowNET.Core/Common/Types/NestList.cs @@ -10,29 +10,34 @@ namespace Tensorflow.Common.Types /// public sealed class NestList : INestStructure, IEnumerable { - public List Value { get; set; } + public NestType NestType => NestType.List; + public List Values { get; set; } + public int ShallowNestedCount => Values.Count; + + public int TotalNestedCount => Values.Count; + public NestList(IEnumerable values) { - Value = new List(values); + Values = new List(values); } public IEnumerable Flatten() { - return Value; + return Values; } public INestStructure MapStructure(Func func) { - return new NestList(Value.Select(x => func(x))); + return new NestList(Values.Select(x => func(x))); } public Nest AsNest() { - return new Nest(Value.Select(x => new Nest(x))); + return new Nest(Values.Select(x => new Nest(x))); } // Enumerator implementation public IEnumerator GetEnumerator() { - return Value.GetEnumerator(); + return Values.GetEnumerator(); } IEnumerator IEnumerable.GetEnumerator() diff --git a/src/TensorFlowNET.Core/Common/Types/NestNode.cs b/src/TensorFlowNET.Core/Common/Types/NestNode.cs index 1dad421d..701aade9 100644 --- a/src/TensorFlowNET.Core/Common/Types/NestNode.cs +++ b/src/TensorFlowNET.Core/Common/Types/NestNode.cs @@ -10,7 +10,11 @@ namespace Tensorflow.Common.Types /// public class NestNode : INestStructure { + public NestType NestType => NestType.Node; public T Value { get; set; } + public int ShallowNestedCount => 1; + + public int TotalNestedCount => 1; public NestNode(T value) { Value = value; diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 324d7e83..c1762d67 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -161,8 +161,8 @@ namespace Tensorflow break; } - yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == 
FirstInputTensorCount ? - null : new Tensors(results.Skip(FirstInputTensorCount))); + yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ? + null : new Tensors(results.Skip(FirstInputTensorCount).ToArray())); } } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index f1a09ed7..5f156fd9 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -359,6 +359,8 @@ namespace Tensorflow.Eager case TF_AttrType.TF_ATTR_FUNC: if (value is ConcreteFunction func) c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); + else if(value is string str) + c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length); else throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); break; diff --git a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs index 083d4813..ac099ae2 100644 --- a/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs +++ b/src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs @@ -1,4 +1,5 @@ using System.Linq; +using Tensorflow.Eager; namespace Tensorflow.Framework.Models { @@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models shapes.Insert(0, dim); return new TensorSpec(shapes.ToArray(), _dtype); } + + public static TensorSpec FromTensor(Tensor tensor, string? name = null) + { + if(tensor is EagerTensor) + { + return new TensorSpec(tensor.shape, tensor.dtype, name); + } + else + { + return new TensorSpec(tensor.shape, tensor.dtype, name ?? 
tensor.name); + } + } } } diff --git a/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs b/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs new file mode 100644 index 00000000..28d9e500 --- /dev/null +++ b/src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs @@ -0,0 +1,89 @@ +using Tensorflow.Graphs; + +namespace Tensorflow.Framework +{ + internal static class auto_control_deps_utils + { + public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"; + public static List get_read_only_resource_input_indices_graph(FuncGraph func_graph) + { + List result = new List(); + // A cache to store the read only resource inputs of an Op. + // Operation -> ObjectIdentitySet of resource handles. + Dictionary> opReadOnlyResourceInputs = + new Dictionary>(); + + for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++) + { + Tensor t = func_graph.Inputs[inputIndex]; + if (t.dtype != dtypes.resource) + continue; + + bool readOnly = true; + foreach (var op in t.consumers()) + { + if (opReadOnlyResourceInputs.ContainsKey(op)) + { + if (!opReadOnlyResourceInputs[op].Contains(t)) + { + readOnly = false; + break; + } + } + else + { + List indices = _get_read_only_resource_input_indices_op(op); + opReadOnlyResourceInputs[op] = new HashSet( + indices.Select(i => op.inputs[i])); + if (!opReadOnlyResourceInputs[op].Contains(t)) + { + readOnly = false; + break; + } + } + } + + if (readOnly) + result.Add(inputIndex); + } + + return result; + } + + private static List _get_read_only_resource_input_indices_op(Operation op) + { + // ignore the RESOURCE_READ_OPS + + int[] read_only_input_indices; + + try + { + read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR); + } + catch (InvalidArgumentError) + { + return new List(); + } + + int read_only_index = 0; + List result = new(); + for (int i = 0; i < op.inputs.Length; i++) + { + if (read_only_index >= read_only_input_indices.Length) + { + break; + } + 
if (op.inputs[i].dtype != dtypes.resource) + { + continue; + } + if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index]) + { + result.Add(i); + read_only_index++; + } + } + return result; + } + } +} diff --git a/src/TensorFlowNET.Core/Framework/function_def_lib.cs b/src/TensorFlowNET.Core/Framework/function_def_lib.cs index 67f8d324..488c6b65 100644 --- a/src/TensorFlowNET.Core/Framework/function_def_lib.cs +++ b/src/TensorFlowNET.Core/Framework/function_def_lib.cs @@ -42,10 +42,10 @@ namespace Tensorflow.Framework func_graph.as_default(); importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false); var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]); - func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); + func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]); - func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x))); + func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray()); // TODO(Rinne): func_graph.ControlOutputs _set_handle_data(func_graph, fdef); diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 88dce7d9..8742e453 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -8,6 +8,7 @@ using Tensorflow.Gradients; using Tensorflow.Graphs; using Tensorflow.Train; using Tensorflow.Util; +using Tensorflow.Common.Extensions; using static Tensorflow.Binding; namespace Tensorflow.Functions @@ -40,6 +41,18 @@ namespace Tensorflow.Functions public Tensor[] FlatStructuredOutputs => 
func_graph.FlatStructuredOutputs; public IEnumerable Variables => func_graph.Variables; public IEnumerable TrainableVariables => func_graph.TrainableVariables; + internal NameAttrList AsNameAttrList + { + get + { + NameAttrList ret = new() { Name = this.Name }; + foreach (var (name, value) in _attrs) + { + ret.Attr[name] = value; + } + return ret; + } + } public ConcreteFunction(string name) { diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 3bce52ea..ba7d7068 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable public IEnumerable TrainableVariables => Variables.Where(v => v.Trainable); public Dictionary Attrs { get; set; } - Dictionary _captures + internal Dictionary _captures = new Dictionary(); public Tensor[] external_captures @@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable var flat_func_args = nest.flatten(func_args as object); var flat_func_kwargs = nest.flatten(func_kwargs as object); func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs) - .Where(x => x is Tensor).Select(x => (Tensor)x)); + .Where(x => x is Tensor).Select(x => (Tensor)x).ToArray()); //var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true); //var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index eb8df581..9e879a0f 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -129,7 +129,7 @@ namespace Tensorflow } } - protected Graph outer_graph; + internal Graph outer_graph; public Graph OuterGraph => outer_graph; public Dictionary Functions => _functions; public SafeGraphHandle c_graph => _handle; diff --git a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs 
b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs index d12ed1ad..8614391a 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs @@ -7,13 +7,19 @@ namespace Tensorflow.Keras.Layers.Rnn { public interface IRnnCell: ILayer { - GeneralizedTensorShape StateSize { get; } - GeneralizedTensorShape OutputSize { get; } - bool IsTFRnnCell { get; } + /// + /// If the derived class tends to not implement it, please return null. + /// + GeneralizedTensorShape? StateSize { get; } + /// + /// If the derived class tends to not implement it, please return null. + /// + GeneralizedTensorShape? OutputSize { get; } /// /// Whether the optional RNN args are supported when appying the layer. /// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`. /// bool SupportOptionalArgs { get; } + Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype); } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 26646b76..b651089a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -181,6 +181,10 @@ namespace Tensorflow { throw new NotImplementedException(); } + public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) + { + throw new NotImplementedException(); + } public GeneralizedTensorShape StateSize => throw new NotImplementedException(); public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); public bool IsTFRnnCell => throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 76a222ba..5ff5ccff 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -15,9 +15,11 @@ 
******************************************************************************/ using Google.Protobuf; +using Google.Protobuf.Collections; using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Functions; using static Tensorflow.Binding; using static Tensorflow.OpDef.Types; @@ -420,6 +422,12 @@ namespace Tensorflow case "list(shape)": attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def))); break; + case "func": + attr_value.Func = _MakeFunc(value, attr_def.Name); + break; + case "list(func)": + attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); + break; default: throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); } @@ -427,6 +435,47 @@ namespace Tensorflow return attr_value; } + private NameAttrList _MakeFunc(object func, string arg_name) + { + if(func is NameAttrList attrList) + { + return attrList; + } + NameAttrList fn_attr; + if(func is string funcStr) + { + fn_attr = new NameAttrList() { Name = funcStr }; + } + else if(func is ConcreteFunction concrete) + { + concrete.AddTograph(ops.get_default_graph()); + fn_attr = concrete.AsNameAttrList; + } + else if(func is EagerDefinedFunction eager) + { + eager.AddToGraph(ops.get_default_graph()); + fn_attr = new NameAttrList() { Name = eager.Name }; + } + else + { + throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}"); + } + return fn_attr; + } + + private List _MakeFuncList(object funcList, string arg_name) + { + List res = new List(); + if(funcList is IEnumerable enumerable) + { + foreach(var func in enumerable) + { + res.Add(_MakeFunc(func, arg_name)); + } + } + return res; + } + private bool _IsListParameter(ArgDef arg) { if (!String.IsNullOrEmpty(arg.NumberAttr)) diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index 2955a13f..2329a478 100644 --- 
a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -34,7 +34,7 @@ namespace Tensorflow return num; } - protected Tensor[] _outputs; + internal Tensor[] _outputs; public virtual Tensor[] outputs => _outputs; public Tensor output => _outputs.FirstOrDefault(); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index a789c5f4..5e689c65 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -46,9 +46,9 @@ namespace Tensorflow /// public partial class Operation : ITensorOrOperation { - private readonly IntPtr _handle; // _c_op in python + protected IntPtr _handle; // _c_op in python - private readonly Graph _graph; + protected Graph _graph; internal Func _gradient_function; @@ -69,6 +69,7 @@ namespace Tensorflow //private OperationDescription _op_desc; public NodeDef node_def => GetNodeDef(); + protected Operation() { } public Operation(IntPtr handle, Graph g = null) { diff --git a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs index 08e73fe6..59176060 100644 --- a/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs @@ -17,6 +17,7 @@ using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -38,10 +39,6 @@ namespace Tensorflow.Operations bool _infer_shape; public override bool infer_shape => _infer_shape; - public bool _dynamic_size; - public Shape _element_shape; - - public List _colocate_with; Tensor _handle; public override Tensor handle => _handle; @@ -56,6 +53,7 @@ namespace Tensorflow.Operations bool infer_shape = true, Shape? 
element_shape = null, bool colocate_with_first_write_call = true, string name = null) { + _size = size; _flow = constant_op.constant(0); _infer_shape = infer_shape; _element_shape = element_shape ?? Shape.Null; diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs index dde2624a..4c3fde31 100644 --- a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs +++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs @@ -16,7 +16,9 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Common.Types; using Tensorflow.Eager; using static Tensorflow.Binding; @@ -33,18 +35,18 @@ namespace Tensorflow.Operations /// first tensor written to it. /// bool _colocate_with_first_write_call; - public bool colocate_with_first_write_call => _colocate_with_first_write_call; + public override bool colocate_with_first_write_call => _colocate_with_first_write_call; bool _infer_shape; - public bool infer_shape => _infer_shape; - public bool _dynamic_size; + public override bool infer_shape => _infer_shape; public List _element_shape; public List _colocate_with; internal Tensor _handle; - public Tensor handle => _handle; + public override Tensor handle => _handle; internal Tensor _flow; + public override Tensor flow => _flow; public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null, bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, @@ -55,6 +57,7 @@ namespace Tensorflow.Operations dynamic_size = dynamic_size ?? 
false; _dynamic_size = dynamic_size.Value; _dtype = dtype; + _size = size; _colocate_with_first_write_call = colocate_with_first_write_call; if (colocate_with_first_write_call) @@ -235,4 +238,172 @@ namespace Tensorflow.Operations return value; } } + + public class _GraphTensorArrayV2 : TensorArray + { + internal TF_DataType _dtype; + public override TF_DataType dtype => _dtype; + + /// + /// Used to keep track of what tensors the TensorArray should be + /// colocated with. We choose to colocate the TensorArray with the + /// first tensor written to it. + /// + bool _colocate_with_first_write_call; + public override bool colocate_with_first_write_call => _colocate_with_first_write_call; + + bool _infer_shape; + public override bool infer_shape => _infer_shape; + public Shape _element_shape; + + public List _colocate_with; + + internal Tensor _handle; + public override Tensor handle => _handle; + internal Tensor _flow; + public override Tensor flow => _flow; + + public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null, + bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null, + bool infer_shape = true, Shape? element_shape = null, + bool colocate_with_first_write_call = true, string name = null) + { + Debug.Assert(handle is null); + dynamic_size = dynamic_size ?? 
false; + _dynamic_size = dynamic_size.Value; + _size = size; + + if(flow is not null && flow.dtype != dtypes.variant) + { + throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead"); + } + if(flow is null && size is null) + { + throw new ValueError("Argument `size` must be provided if argument `flow` is not provided."); + } + if(flow is not null && size is not null) + { + throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time."); + } + if(flow is not null && element_shape is not null) + { + throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time."); + } + + _dtype = dtype; + + _element_shape = element_shape; + _infer_shape = infer_shape; + tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope => + { + if (flow is null) + { + _flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name); + } + else + { + _flow = flow; + } + }); + + _colocate_with_first_write_call = false; + _colocate_with = null; + } + + public override TensorArray unstack(Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + Debug.Assert(value.dtype == _dtype); + var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray()); + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + public TensorArray scatter(Tensor indices, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + Debug.Assert(value.dtype == _dtype); + var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow); + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); 
+ }); + } + + public override Tensor read(T index, string name = null) + { + if(index is Tensor tensor) + { + return read(tensor, name); + } + else + { + throw new TypeError("Please use non-generic method instead."); + } + } + + public Tensor read(Tensor index, string name = null) + { + return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope => + { + return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name); + }); + } + + public override TensorArray write(Tensor index, Tensor value, string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate + { + value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + Debug.Assert(value.dtype == _dtype); + var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name); + + return tensor_array_ops.build_ta_with_new_flow(this, flow_out); + }); + } + + public override TensorArray write(int index, T value, string name = null) + { + var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value"); + var index_tensor = ops.convert_to_tensor(index, name: "index"); + return write(index_tensor, value_tensor); + } + + private Tensor size(string name = null) + { + if(!_dynamic_size && _size is not null) + { + return ops.convert_to_tensor(_size, dtypes.int32); + } + else + { + return gen_list_ops.tensor_list_length(_flow, name); + } + } + + public override Tensor stack(string name = null) + { + return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate + { + int ta_size; + if(!_dynamic_size && (_size is not null)) + { + ta_size = (int)tensor_util.constant_value(_size); + } + else + { + ta_size = -1; + } + var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape); + return value; + }); + } + + public override Tensor gather(Tensor indices, string name = null) + { + return list_ops.tensor_list_gather(_flow, indices, 
_dtype, _element_shape, name); + } + } } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index a0b47aac..ca9e5fae 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -119,6 +119,27 @@ namespace Tensorflow } } + public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + Tensor shapeTensor; + if(shape.Length > 1) + { + shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); + if(shapeTensor.ndim > 1) + { + shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); + } + } + else + { + shapeTensor = shape[0]; + } + var output = fill(shapeTensor, array_ops.constant(0, dtype), name); + Debug.Assert(output.dtype.as_base_dtype() == dtype); + return output; + } + public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) { return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate @@ -307,6 +328,9 @@ namespace Tensorflow public static Tensor fill(Shape dims, T value, string name = null) => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); + public static Tensor fill(Tensor dims, T value, string name = null) + => gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); + /// /// Returns the rank of a tensor. 
/// diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs index 862b636f..efd9aba3 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.cs @@ -675,16 +675,17 @@ namespace Tensorflow } } - public static Tensor[] while_loop(Func cond, - Func body, - Tensor[] loop_vars, + public static Tensors while_loop(Func cond, + Func body, + Tensors loop_vars, int parallel_iterations = 10, string name = null) { var executing_eagerly = tf.Context.executing_eagerly(); if (!executing_eagerly) { - throw new NotImplementedException(""); + return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations, + name: name); } return tf_with(ops.name_scope("name", "while"), delegate diff --git a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs index c8891119..536d4e3c 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_util.py.cs @@ -16,12 +16,20 @@ using System; using System.Linq; +using Tensorflow.Functions; +using Tensorflow.Graphs; using Tensorflow.Operations; +using static Tensorflow.Binding; namespace Tensorflow { public class control_flow_util { + public static readonly bool ENABLE_CONTROL_FLOW_V2 = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0" || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && 
Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") || + (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0"); /// /// Return true if `op` is an Exit. /// @@ -196,5 +204,74 @@ namespace Tensorflow } return null; } + + public static bool EnableControlFlowV2(Graph graph) + { + return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0); + + } + + public static string create_new_tf_function(FuncGraph func_graph) + { + var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary()); + func.AddToGraph(func_graph); + return func_graph.Name; + } + + public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs) + { + if(inputs.Length == 0) + { + return (null, new Tensor[0]); + } + else + { + return (inputs[0], inputs); + } + } + + public static Tensor[] run_as_function_for_tape_gradients(Func make_op, Tensor[] inputs) + { + if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER + && !(ops.get_default_graph().building_function)) + { + throw new NotImplementedException(); + } + else + { + return make_op(inputs); + } + } + + public static string unique_fn_name(string scope, string name) + { + return $"{scope}{name}_{ops.uid()}".Replace("/", "_"); + } + + public static bool output_all_intermediates() + { + if (in_defun()) + { + return false; + } + if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR") + { + return false; + } + // TODO(Rinne): check this after refactoring keras building. 
+ return false; + } + + public static bool in_defun() + { + if (tf.Context.executing_eagerly()) + { + return false; + } + + var graph = ops.get_default_graph(); + // TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph + return graph is FuncGraph; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs index 5663f9c9..e1cf1c13 100644 --- a/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_functional_ops.cs @@ -1,128 +1,1032 @@ -using System; -using System.Collections.Generic; -using System.Text; -using System.Xml.Linq; -using Tensorflow.Contexts; +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ + using Tensorflow.Eager; -using Tensorflow.Functions; +using Tensorflow.Contexts; using static Tensorflow.Binding; -namespace Tensorflow.Operations +namespace Tensorflow; + +public static class gen_functional_ops { - public class gen_functional_ops + /// + /// An n-way switch statement which calls a single branch function. + /// + /// + /// + /// An n-way switch statement, implementing the following: + /// ``` + /// switch (branch_index) { + /// case 0: + /// output = branches[0](input); + /// break; + /// case 1: + /// output = branches[1](input); + /// break; + /// ... + /// case [[nbranches-1]]: + /// default: + /// output = branches[nbranches-1](input); + /// break; + /// } + /// ``` + /// + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A list of functions each of which takes 'inputs' and returns a list of + /// tensors, whose types are the same as what every other branch returns. + /// + /// + /// + /// + public static Tensor[] _case(Tensor branch_index, Tensors input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "Case", name) { args = new object[] { branch_index, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["branches"] = branches, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return case_eager_fallback(branch_index, input, Tout: Tout, branches: branches, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["branch_index"] = branch_index; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["branches"] = branches; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("Case", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "branches", _op.get_attr("branches"), "output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("Case", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] case_eager_fallback(Tensor branch_index, Tensor input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { branch_index, input }; + object[] _attrs = new object[] { "branches", branches, "output_shapes", output_shapes }; + var _result = _execute.execute("Case", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("Case", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Return the index of device the op runs. + /// + /// + /// + /// Given a list of device names, this operation returns the index of the device + /// this op runs. 
The length of the list is returned in two cases: + /// (1) Device does not exist in the given device list. + /// (2) It is in XLA compilation. + /// + /// + /// + /// + public static Tensor device_index(string[] device_names, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "DeviceIndex", name) { args = new object[] { }, attrs = new Dictionary() { ["device_names"] = device_names } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return device_index_eager_fallback(device_names: device_names, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["device_names"] = device_names; + var _op = tf.OpDefLib._apply_op_helper("DeviceIndex", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "device_names", _op.get_attr("device_names") }; + _execute.record_gradient("DeviceIndex", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor device_index_eager_fallback(string[] device_names, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "device_names", device_names }; + var _result = _execute.execute("DeviceIndex", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("DeviceIndex", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// ~~%~~ This op is used as a placeholder in If branch functions. It doesn't provide a~~%~~ valid output when run, so must either be removed (e.g. replaced with a~~%~~ function input) or guaranteed not to be used (e.g. if mirroring an~~%~~ intermediate output needed for the gradient computation of the other branch).~~%~~ + /// + /// + /// The type of the output. 
+ /// + /// + /// + /// The purported shape of the output. This is only used for shape inference; + /// the output will not necessarily have this shape. Can be a partial shape. + /// + /// + /// + public static Tensor fake_param(TF_DataType dtype, Shape shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "FakeParam", name) { args = new object[] { }, attrs = new Dictionary() { ["dtype"] = dtype, ["shape"] = shape } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return fake_param_eager_fallback(dtype: dtype, shape: shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["dtype"] = dtype; + keywords["shape"] = shape; + var _op = tf.OpDefLib._apply_op_helper("FakeParam", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "dtype", _op._get_attr_type("dtype"), "shape", _op.get_attr("shape") }; + _execute.record_gradient("FakeParam", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor fake_param_eager_fallback(TF_DataType dtype, Shape shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { }; + object[] _attrs = new object[] { "dtype", dtype, "shape", shape }; + var _result = _execute.execute("FakeParam", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("FakeParam", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Applies a for loop. + /// + /// + /// + /// ```python + /// output = input; + /// for i in range(start, limit, delta) + /// output = body(i, output); + /// ``` + /// + /// + /// + /// + /// + /// + /// + /// + /// A function that takes a list of tensors (int32, T) and returns another + /// list of tensors (T). 
+ /// + /// + /// + public static Tensor[] _for(Tensor start, Tensor limit, Tensor delta, Tensors input, object body, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "For", name) { args = new object[] { start, limit, delta, input }, attrs = new Dictionary() { ["body"] = body } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return for_eager_fallback(start, limit, delta, input, body: body, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["start"] = start; + keywords["limit"] = limit; + keywords["delta"] = delta; + keywords["input"] = input; + keywords["body"] = body; + var _op = tf.OpDefLib._apply_op_helper("For", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T"), "body", _op.get_attr("body") }; + _execute.record_gradient("For", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] for_eager_fallback(Tensor start, Tensor limit, Tensor delta, Tensor input, object body, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { start, limit, delta, input }; + object[] _attrs = new object[] { "body", body }; + var _result = _execute.execute("For", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("For", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// output = cond ? then_branch(input) : else_branch(input) + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what else_branch returns. 
+ /// + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what then_branch returns. + /// + /// + /// + /// + public static Tensor[] _if(Tensor cond, Tensors input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string? name = null) { - public static Tensor[] partitioned_call(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, - string config = "", string config_proto = "", string executor_type = "", string name = null) + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) { - var ctx = tf.Context; - if (ctx.executing_eagerly()) + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "If", name) { args = new object[] { cond, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["then_branch"] = then_branch, ["else_branch"] = else_branch, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (Exception) { - try - { - return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "PartitionedCall", name, - args, tout, f, config, config_proto, executor_type)); - } - catch (Exception) - { + } + try + { + return if_eager_fallback(cond, input, Tout: Tout, then_branch: then_branch, else_branch: else_branch, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["cond"] = cond; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["then_branch"] = then_branch; + keywords["else_branch"] = else_branch; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("If", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tcond", _op._get_attr_type("Tcond"), "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "then_branch", _op.get_attr("then_branch"), "else_branch", _op.get_attr("else_branch"), 
"output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("If", _op.inputs, _attrs, _result); + } + return _result; + } - } + public static Tensor[] if_eager_fallback(Tensor cond, Tensor input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { cond, input }; + object[] _attrs = new object[] { "Tcond", cond.dtype, "then_branch", then_branch, "else_branch", else_branch, "output_shapes", output_shapes }; + var _result = _execute.execute("If", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("If", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// returns `f(inputs)`, where `f`'s body is placed and partitioned. + /// + /// + /// + /// Asynchronously executes a function, potentially across multiple devices but + /// within a single process. The kernel places and partitions a given function's + /// underlying graph, and executes each of the partitioned subgraphs as a function. + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'args', a list of tensors, and returns 'output', + /// another list of tensors. Input and output types are specified by 'Tin' + /// and 'Tout'. The function body of f will be placed and partitioned across + /// devices, setting this op apart from the regular Call op. + /// + /// + /// + /// + /// + /// + public static Tensor[] partitioned_call(Tensors args, TF_DataType[] Tout, object f, string config = "", string config_proto = "", string executor_type = "", string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "PartitionedCall", name) { args = new object[] { args }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f, ["config"] = config, ["config_proto"] = config_proto, ["executor_type"] = executor_type } }); + return _fast_path_result; } + catch (Exception) + { + } + try + { + return partitioned_call_eager_fallback(args, Tout: Tout, f: f, config: config, config_proto: config_proto, executor_type: executor_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + if (config is null) + { + config = ""; + } + if (config_proto is null) + { + config_proto = ""; + } + if (executor_type is null) + { + executor_type = ""; + } + Dictionary keywords = new(); + keywords["args"] = args; + keywords["Tout"] = Tout; + keywords["f"] = f; + keywords["config"] = config; + keywords["config_proto"] = config_proto; + keywords["executor_type"] = executor_type; + var _op = tf.OpDefLib._apply_op_helper("PartitionedCall", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f"), "config", _op.get_attr("config"), "config_proto", _op.get_attr("config_proto"), "executor_type", _op.get_attr("executor_type") }; + _execute.record_gradient("PartitionedCall", _op.inputs, _attrs, _result); + } + return _result; + } - if (config is null) + public static Tensor[] partitioned_call_eager_fallback(Tensor args, TF_DataType[] Tout, object f, string config, string config_proto, string executor_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { args }; + object[] _attrs = new object[] { "f", f, "config", config, "config_proto", config_proto, "executor_type", executor_type }; + var _result = _execute.execute("PartitionedCall", 1, inputs: _inputs_flat, attrs: 
_attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("PartitionedCall", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Runs function `f` on a remote device indicated by `target`. + /// + /// + /// + /// + /// + /// The type list for the return values. + /// + /// + /// + /// + /// The function to run remotely. + /// + /// + /// + public static Tensor[] remote_call(Tensor target, Tensors args, TF_DataType[] Tout, object f, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try { - config = ""; + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "RemoteCall", name) { args = new object[] { target, args }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f } }); + return _fast_path_result; } - if (config_proto is null) + catch (Exception) { - config_proto = ""; } - if (executor_type is null) + try { - executor_type = ""; + return remote_call_eager_fallback(target, args, Tout: Tout, f: f, name: name, ctx: _ctx); } - Dictionary kwargs = new(); - kwargs["args"] = args; - kwargs["Tout"] = tout; - kwargs["f"] = f; - kwargs["config"] = config; - kwargs["config_proto"] = config_proto; - kwargs["executor_type"] = executor_type; - var output = tf.OpDefLib._apply_op_helper("PartitionedCall", - name, kwargs); - var result = output.outputs; - if (_execute.must_record_gradient()) + catch (Exception) { - throw new NotImplementedException(); } - return result; } + Dictionary keywords = new(); + keywords["target"] = target; + keywords["args"] = args; + keywords["Tout"] = Tout; + keywords["f"] = f; + var _op = tf.OpDefLib._apply_op_helper("RemoteCall", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f") }; + _execute.record_gradient("RemoteCall", _op.inputs, _attrs, _result); + } + return 
_result; + } - public static Tensor[] partitioned_call_eager_fallback(Tensors args, TF_DataType[] tout, EagerDefinedFunction f, - string config, string config_proto, string executor_type, string name, Context ctx) + public static Tensor[] remote_call_eager_fallback(Tensor target, Tensor args, TF_DataType[] Tout, object f, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { target, args }; + object[] _attrs = new object[] { "f", f }; + var _result = _execute.execute("RemoteCall", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("RemoteCall", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// returns `f(inputs)`, where `f`'s body is placed and partitioned. + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'args', a list of tensors, and returns 'output', + /// another list of tensors. Input and output types are specified by 'Tin' + /// and 'Tout'. The function body of f will be placed and partitioned across + /// devices, setting this op apart from the regular Call op. This op is + /// stateful. + /// + /// + /// + /// + /// + /// + public static Tensor[] stateful_partitioned_call(Tensors args, TF_DataType[] Tout, object f, string config = "", string config_proto = "", string executor_type = "", string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) { - // TODO(Rinne): implement it. 
- throw new NotImplementedException(); - if(config is null) + try { - config = ""; + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatefulPartitionedCall", name) { args = new object[] { args }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f, ["config"] = config, ["config_proto"] = config_proto, ["executor_type"] = executor_type } }); + return _fast_path_result; } - if(config_proto is null) + catch (Exception) { - config_proto = ""; } - if(executor_type is null) + try { - executor_type = ""; + return stateful_partitioned_call_eager_fallback(args, Tout: Tout, f: f, config: config, config_proto: config_proto, executor_type: executor_type, name: name, ctx: _ctx); } - object[] attrs = new object[] + catch (Exception) { + } + } + if (config is null) + { + config = ""; + } + if (config_proto is null) + { + config_proto = ""; + } + if (executor_type is null) + { + executor_type = ""; + } + Dictionary keywords = new(); + keywords["args"] = args; + keywords["Tout"] = Tout; + keywords["f"] = f; + keywords["config"] = config; + keywords["config_proto"] = config_proto; + keywords["executor_type"] = executor_type; + var _op = tf.OpDefLib._apply_op_helper("StatefulPartitionedCall", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f"), "config", _op.get_attr("config"), "config_proto", _op.get_attr("config_proto"), "executor_type", _op.get_attr("executor_type") }; + _execute.record_gradient("StatefulPartitionedCall", _op.inputs, _attrs, _result); + } + return _result; + } - }; + public static Tensor[] stateful_partitioned_call_eager_fallback(Tensor args, TF_DataType[] Tout, object f, string config, string config_proto, string executor_type, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { args }; + object[] _attrs = new object[] { "f", f, "config", config, 
"config_proto", config_proto, "executor_type", executor_type }; + var _result = _execute.execute("StatefulPartitionedCall", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StatefulPartitionedCall", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// An n-way switch statement which calls a single branch function. + /// + /// + /// + /// An n-way switch statement, implementing the following: + /// ``` + /// switch (branch_index) { + /// case 0: + /// output = branches[0](input); + /// break; + /// case 1: + /// output = branches[1](input); + /// break; + /// ... + /// case [[nbranches-1]]: + /// default: + /// output = branches[nbranches-1](input); + /// break; + /// } + /// ``` + /// + /// This should only be used when the none of branches has stateful ops. + /// + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A list of functions each of which takes 'inputs' and returns a list of + /// tensors, whose types are the same as what every other branch returns. + /// + /// + /// + /// + public static Tensor[] stateless_case(Tensor branch_index, Tensors input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatelessCase", name) { args = new object[] { branch_index, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["branches"] = branches, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return stateless_case_eager_fallback(branch_index, input, Tout: Tout, branches: branches, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["branch_index"] = branch_index; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["branches"] = branches; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("StatelessCase", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "branches", _op.get_attr("branches"), "output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("StatelessCase", _op.inputs, _attrs, _result); } + return _result; + } - public static Tensor[] symbolic_gradient(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name = null) + public static Tensor[] stateless_case_eager_fallback(Tensor branch_index, Tensor input, TF_DataType[] Tout, object[] branches, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { branch_index, input }; + object[] _attrs = new object[] { "branches", branches, "output_shapes", output_shapes }; + var _result = _execute.execute("StatelessCase", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) { - var ctx = tf.Context; - if (ctx.executing_eagerly()) + _execute.record_gradient("StatelessCase", _inputs_flat, _attrs, _result); + } + return _result; 
+ } + /// + /// output = cond ? then_branch(input) : else_branch(input) + /// + /// + /// + /// + /// A list of output types. + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what else_branch returns. + /// + /// + /// + /// + /// A function that takes 'inputs' and returns a list of tensors, whose + /// types are the same as what then_branch returns. + /// + /// + /// + /// + public static Tensor[] stateless_if(Tensor cond, Tensors input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatelessIf", name) { args = new object[] { cond, input }, attrs = new Dictionary() { ["Tout"] = Tout, ["then_branch"] = then_branch, ["else_branch"] = else_branch, ["output_shapes"] = output_shapes } }); + return _fast_path_result; + } + catch (Exception) + { + } + try { - try - { - var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( - tf.Context, "SymbolicGradient", name, input, Tout, f)); - return _result; - } - catch (Exception) - { + return stateless_if_eager_fallback(cond, input, Tout: Tout, then_branch: then_branch, else_branch: else_branch, output_shapes: output_shapes, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["cond"] = cond; + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["then_branch"] = then_branch; + keywords["else_branch"] = else_branch; + keywords["output_shapes"] = output_shapes; + var _op = tf.OpDefLib._apply_op_helper("StatelessIf", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tcond", _op._get_attr_type("Tcond"), "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "then_branch", 
_op.get_attr("then_branch"), "else_branch", _op.get_attr("else_branch"), "output_shapes", _op.get_attr("output_shapes") }; + _execute.record_gradient("StatelessIf", _op.inputs, _attrs, _result); + } + return _result; + } - } + public static Tensor[] stateless_if_eager_fallback(Tensor cond, Tensor input, TF_DataType[] Tout, object then_branch, object else_branch, Shape[] output_shapes, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { cond, input }; + object[] _attrs = new object[] { "Tcond", cond.dtype, "then_branch", then_branch, "else_branch", else_branch, "output_shapes", output_shapes }; + var _result = _execute.execute("StatelessIf", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("StatelessIf", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// output = input; While (Cond(output)) { output = Body(output) } + /// + /// + /// + /// + /// A function takes 'input' and returns a tensor. If the tensor is + /// a scalar of non-boolean, the scalar is converted to a boolean + /// according to the following rule: if the scalar is a numerical + /// value, non-zero means True and zero means False; if the scalar is + /// a string, non-empty means True and empty means False. If the + /// tensor is not a scalar, non-emptiness means True and False + /// otherwise. + /// + /// This should only be used when the while condition and body functions + /// do not have stateful ops. + /// + /// + /// + /// + /// A function that takes a list of tensors and returns another + /// list of tensors. Both lists have the same types as specified + /// by T. + /// + /// + /// + /// + /// + public static Tensor[] stateless_while(Tensors input, object cond, object body, Shape[] output_shapes, int parallel_iterations = 10, string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "StatelessWhile", name) { args = new object[] { input }, attrs = new Dictionary() { ["cond"] = cond, ["body"] = body, ["output_shapes"] = output_shapes, ["parallel_iterations"] = parallel_iterations } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return stateless_while_eager_fallback(input, cond: cond, body: body, output_shapes: output_shapes, parallel_iterations: parallel_iterations, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["cond"] = cond; + keywords["body"] = body; + keywords["output_shapes"] = output_shapes; + keywords["parallel_iterations"] = parallel_iterations; + var _op = tf.OpDefLib._apply_op_helper("StatelessWhile", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body", _op.get_attr("body"), "output_shapes", _op.get_attr("output_shapes"), "parallel_iterations", _op._get_attr_int("parallel_iterations") }; + _execute.record_gradient("StatelessWhile", _op.inputs, _attrs, _result); + } + return _result; + } - try - { - return symbolic_gradient_eager_fallback(input, Tout, f, name, ctx); - } - catch (Exception) - { + public static Tensor[] stateless_while_eager_fallback(Tensor input, object cond, object body, Shape[] output_shapes, int parallel_iterations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "cond", cond, "body", body, "output_shapes", output_shapes, "parallel_iterations", parallel_iterations }; + var _result = _execute.execute("StatelessWhile", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + 
_execute.record_gradient("StatelessWhile", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Computes the gradient function for function f via backpropagation. + /// + /// + /// + /// + /// the type list for the input list. + /// + /// + /// + /// + /// The function we want to compute the gradient for. + /// + /// The function 'f' must be a numerical function which takes N inputs and + /// produces M outputs. Its gradient function 'g', which is computed by + /// this SymbolicGradient op is a function taking N + M inputs and + /// produces N outputs. + /// + /// I.e. if we have + /// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), + /// then, g is + /// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, + /// dL/dy1, dL/dy2, ..., dL/dy_M), + /// + /// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the + /// loss function). dL/dx_i is the partial derivative of L with respect + /// to x_i. + /// + /// (Needs some math expert to say the comment above better.) + /// + /// + /// + public static Tensor[] symbolic_gradient(Tensors input, TF_DataType[] Tout, object f, string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "SymbolicGradient", name) { args = new object[] { input }, attrs = new Dictionary() { ["Tout"] = Tout, ["f"] = f } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return symbolic_gradient_eager_fallback(input, Tout: Tout, f: f, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["Tout"] = Tout; + keywords["f"] = f; + var _op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f", _op.get_attr("f") }; + _execute.record_gradient("SymbolicGradient", _op.inputs, _attrs, _result); + } + return _result; + } - } + public static Tensor[] symbolic_gradient_eager_fallback(Tensor input, TF_DataType[] Tout, object f, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "f", f }; + var _result = _execute.execute("SymbolicGradient", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("SymbolicGradient", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// Converts a tensor to a scalar predicate. + /// + /// + /// + /// Converts a tensor to a scalar predicate with the following rules: + /// + /// - For 0D tensors, truthiness is determined by comparing against a "zero" + /// value. For numerical types it is the obvious zero. For strings it is the + /// empty string. + /// + /// - For >0D tensors, truthiness is determined by looking at the number of + /// elements. If has zero elements, then the result is false. Otherwise the + /// result is true. 
+ /// + /// This matches the behavior of If and While for determining if a tensor counts + /// as true/false for a branch condition. + /// + /// + /// + /// + public static Tensor to_bool(Tensor input, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "ToBool", name) { args = new object[] { input }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { } - var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); - var result = op.outputs; - if (_execute.must_record_gradient()) + try { - throw new NotImplementedException(); + return to_bool_eager_fallback(input, name: name, ctx: _ctx); } - return result; + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input"] = input; + var _op = tf.OpDefLib._apply_op_helper("ToBool", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", _op._get_attr_type("T") }; + _execute.record_gradient("ToBool", _op.inputs, _attrs, _result); } + return _result[0]; + } - public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) + public static Tensor to_bool_eager_fallback(Tensor input, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "T", input.dtype }; + var _result = _execute.execute("ToBool", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("ToBool", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// output = input; While (Cond(output)) { output = Body(output) } + /// + /// + /// + /// + /// A function takes 'input' and returns a tensor. 
If the tensor is + /// a scalar of non-boolean, the scalar is converted to a boolean + /// according to the following rule: if the scalar is a numerical + /// value, non-zero means True and zero means False; if the scalar is + /// a string, non-empty means True and empty means False. If the + /// tensor is not a scalar, non-emptiness means True and False + /// otherwise. + /// + /// + /// + /// + /// A function that takes a list of tensors and returns another + /// list of tensors. Both lists have the same types as specified + /// by T. + /// + /// + /// + /// + /// + public static Tensor[] _while(Tensors input, object cond, object body, Shape[] output_shapes, int parallel_iterations = 10, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) { - object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; - var result = _execute.execute("SymbolicGradient", Tout.Length, input, attrs, ctx, name); - if (_execute.must_record_gradient()) + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "While", name) { args = new object[] { input }, attrs = new Dictionary() { ["cond"] = cond, ["body"] = body, ["output_shapes"] = output_shapes, ["parallel_iterations"] = parallel_iterations } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return while_eager_fallback(input, cond: cond, body: body, output_shapes: output_shapes, parallel_iterations: parallel_iterations, name: name, ctx: _ctx); + } + catch (Exception) { - throw new NotImplementedException(); } - return result; } + Dictionary keywords = new(); + keywords["input"] = input; + keywords["cond"] = cond; + keywords["body"] = body; + keywords["output_shapes"] = output_shapes; + keywords["parallel_iterations"] = parallel_iterations; + var _op = tf.OpDefLib._apply_op_helper("While", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "T", 
_op.get_attr("T"), "cond", _op.get_attr("cond"), "body", _op.get_attr("body"), "output_shapes", _op.get_attr("output_shapes"), "parallel_iterations", _op._get_attr_int("parallel_iterations") }; + _execute.record_gradient("While", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] while_eager_fallback(Tensor input, object cond, object body, Shape[] output_shapes, int parallel_iterations, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input }; + object[] _attrs = new object[] { "cond", cond, "body", body, "output_shapes", output_shapes, "parallel_iterations", parallel_iterations }; + var _result = _execute.execute("While", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("While", _inputs_flat, _attrs, _result); + } + return _result; } } diff --git a/src/TensorFlowNET.Core/Operations/gen_list_ops.cs b/src/TensorFlowNET.Core/Operations/gen_list_ops.cs new file mode 100644 index 00000000..e7253986 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/gen_list_ops.cs @@ -0,0 +1,1227 @@ +/*Wrappers around TensorFlow ops. This file is MACHINE GENERATED! Do not edit.*/ + +using Tensorflow.Eager; +using Tensorflow.Contexts; +using static Tensorflow.Binding; + +namespace Tensorflow; + +public static class gen_list_ops +{ + /// + /// Creates and returns an empty tensor list. + /// + /// + /// + /// All list elements must be tensors of dtype element_dtype and shape compatible + /// with element_shape. + /// + /// handle: an empty tensor list. + /// element_dtype: the type of elements in the list. + /// element_shape: a shape compatible with that of elements in the list. + /// + /// + /// + /// + /// + /// + public static Tensor empty_tensor_list(Tensor element_shape, Tensor max_num_elements, TF_DataType element_dtype, string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "EmptyTensorList", name) { args = new object[] { element_shape, max_num_elements }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return empty_tensor_list_eager_fallback(element_shape, max_num_elements, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["element_shape"] = element_shape; + keywords["max_num_elements"] = max_num_elements; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("EmptyTensorList", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("EmptyTensorList", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor empty_tensor_list_eager_fallback(Tensor element_shape, Tensor max_num_elements, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { element_shape, max_num_elements }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("EmptyTensorList", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("EmptyTensorList", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Concats all tensors in the list along the 0th dimension. + /// + /// + /// + /// Requires that all tensors have the same shape except the first dimension. + /// + /// input_handle: The input list. + /// tensor: The concated result. 
+ /// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] tensor_list_concat(Tensor input_handle, TF_DataType element_dtype, Shape element_shape = null, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListConcat", name) { args = new object[] { input_handle }, attrs = new Dictionary() { ["element_dtype"] = element_dtype, ["element_shape"] = element_shape } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return tensor_list_concat_eager_fallback(input_handle, element_dtype: element_dtype, element_shape: element_shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_dtype"] = element_dtype; + keywords["element_shape"] = element_shape; + var _op = tf.OpDefLib._apply_op_helper("TensorListConcat", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "element_shape", _op.get_attr("element_shape") }; + _execute.record_gradient("TensorListConcat", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] tensor_list_concat_eager_fallback(Tensor input_handle, TF_DataType element_dtype, Shape element_shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "element_shape", element_shape }; + var _result = _execute.execute("TensorListConcat", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListConcat", _inputs_flat, _attrs, _result); + } + 
return _result; + } + /// + /// + /// + /// + /// + /// + /// + public static Tensor tensor_list_concat_lists(Tensor input_a, Tensor input_b, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListConcatLists", name) { args = new object[] { input_a, input_b }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_concat_lists_eager_fallback(input_a, input_b, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_a"] = input_a; + keywords["input_b"] = input_b; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListConcatLists", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListConcatLists", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_concat_lists_eager_fallback(Tensor input_a, Tensor input_b, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_a, input_b }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var _result = _execute.execute("TensorListConcatLists", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListConcatLists", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Concats all tensors in the list along the 0th dimension. + /// + /// + /// + /// Requires that all tensors have the same shape except the first dimension. + /// + /// input_handle: The input list. 
+ /// element_shape: The shape of the uninitialized elements in the list. If the first + /// dimension is not -1, it is assumed that all list elements have the same + /// leading dim. + /// leading_dims: The list of leading dims of uninitialized list elements. Used if + /// the leading dim of input_handle.element_shape or the element_shape input arg + /// is not already set. + /// tensor: The concated result. + /// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient. + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor[] tensor_list_concat_v2(Tensor input_handle, Tensor element_shape, Tensor leading_dims, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListConcatV2", name) { args = new object[] { input_handle, element_shape, leading_dims }, attrs = new Dictionary() { ["element_dtype"] = element_dtype } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return tensor_list_concat_v2_eager_fallback(input_handle, element_shape, leading_dims, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_shape"] = element_shape; + keywords["leading_dims"] = leading_dims; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListConcatV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListConcatV2", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] tensor_list_concat_v2_eager_fallback(Tensor 
input_handle, Tensor element_shape, Tensor leading_dims, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, element_shape, leading_dims }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListConcatV2", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListConcatV2", _inputs_flat, _attrs, _result); + } + return _result; + } + /// + /// The shape of the elements of the given list, as a tensor. + /// + /// + /// + /// input_handle: the list + /// element_shape: the shape of elements of the list + /// + /// + /// + /// + /// + public static Tensor tensor_list_element_shape(Tensor input_handle, TF_DataType shape_type, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListElementShape", name) { args = new object[] { input_handle }, attrs = new Dictionary() { ["shape_type"] = shape_type } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_element_shape_eager_fallback(input_handle, shape_type: shape_type, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["input_handle"] = input_handle; + keywords["shape_type"] = shape_type; + var _op = tf.OpDefLib._apply_op_helper("TensorListElementShape", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListElementShape", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_element_shape_eager_fallback(Tensor input_handle, TF_DataType shape_type, string name, 
Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle }; + object[] _attrs = new object[] { "shape_type", shape_type }; + var _result = _execute.execute("TensorListElementShape", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListElementShape", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// + /// Creates a TensorList which, when stacked, has the value of `tensor`. + /// + /// + /// + /// Each tensor in the result list corresponds to one row of the input tensor. + /// + /// tensor: The input tensor. + /// output_handle: The list. + /// + /// + /// + /// + /// + public static Tensor tensor_list_from_tensor(Tensor tensor, Tensor element_shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListFromTensor", name) { args = new object[] { tensor, element_shape }, attrs = new Dictionary() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_from_tensor_eager_fallback(tensor, element_shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary keywords = new(); + keywords["tensor"] = tensor; + keywords["element_shape"] = element_shape; + var _op = tf.OpDefLib._apply_op_helper("TensorListFromTensor", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListFromTensor", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_from_tensor_eager_fallback(Tensor tensor, Tensor element_shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, element_shape }; + object[] 
_attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListFromTensor", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListFromTensor", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Creates a Tensor by indexing into the TensorList. + /// </summary> + /// <remarks> + /// + /// Each row in the produced Tensor corresponds to the element in the TensorList + /// specified by the given index (see `tf.gather`). + /// + /// input_handle: The input tensor list. + /// indices: The indices used to index into the list. + /// values: The tensor. + /// + /// </remarks> + /// <param name="input_handle"></param> + /// <param name="indices"></param> + /// <param name="element_shape"></param> + /// <param name="element_dtype"></param> + /// <returns></returns> + public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, Tensor element_shape, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListGather", name) { args = new object[] { input_handle, indices, element_shape }, attrs = new Dictionary<string, object>() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_gather_eager_fallback(input_handle, indices, element_shape, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["indices"] = indices; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListGather", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListGather", _op.inputs, _attrs, _result); + } + return 
_result[0]; + } + + public static Tensor tensor_list_gather_eager_fallback(Tensor input_handle, Tensor indices, Tensor element_shape, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, indices, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var _result = _execute.execute("TensorListGather", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListGather", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Returns the item in the list at the given index (TensorListGetItem op). + /// </summary> + /// <param name="input_handle"></param> + /// <param name="index"></param> + /// <param name="element_shape"></param> + /// <param name="element_dtype"></param> + /// <returns></returns> + public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, Tensor element_shape, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListGetItem", name) { args = new object[] { input_handle, index, element_shape }, attrs = new Dictionary<string, object>() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_get_item_eager_fallback(input_handle, index, element_shape, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["index"] = index; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListGetItem", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListGetItem", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_get_item_eager_fallback(Tensor input_handle, Tensor 
index, Tensor element_shape, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, index, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var _result = _execute.execute("TensorListGetItem", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListGetItem", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Returns the number of tensors in the input tensor list. + /// </summary> + /// <remarks> + /// + /// input_handle: the input list + /// length: the number of tensors in the list + /// + /// </remarks> + /// <param name="input_handle"></param> + /// <returns></returns> + public static Tensor tensor_list_length(Tensor input_handle, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListLength", name) { args = new object[] { input_handle }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_length_eager_fallback(input_handle, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + var _op = tf.OpDefLib._apply_op_helper("TensorListLength", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("TensorListLength", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_length_eager_fallback(Tensor input_handle, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("TensorListLength", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + 
_execute.record_gradient("TensorListLength", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Returns the last element of the input list as well as a list with all but that element. + /// </summary> + /// <remarks> + /// + /// Fails if the list is empty. + /// + /// input_handle: the input list + /// tensor: the withdrawn last element of the list + /// element_dtype: the type of elements in the list + /// element_shape: the shape of the output tensor + /// + /// </remarks> + /// <param name="input_handle"></param> + /// <param name="element_shape"></param> + /// <param name="element_dtype"></param> + /// <returns></returns> + public static Tensor[] tensor_list_pop_back(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListPopBack", name) { args = new object[] { input_handle, element_shape }, attrs = new Dictionary<string, object>() { ["element_dtype"] = element_dtype } }); + return _fast_path_result; + } + catch (Exception) + { + } + try + { + return tensor_list_pop_back_eager_fallback(input_handle, element_shape, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListPopBack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListPopBack", _op.inputs, _attrs, _result); + } + return _result; + } + + public static Tensor[] tensor_list_pop_back_eager_fallback(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype }; + var 
_result = _execute.execute("TensorListPopBack", 2, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListPopBack", _inputs_flat, _attrs, _result); + } + return _result; + } + /// <summary> + /// Returns a list which has the passed-in `Tensor` as last element and the other elements of the given list in `input_handle`. + /// </summary> + /// <remarks> + /// + /// tensor: The tensor to put on the list. + /// input_handle: The old list. + /// output_handle: A list with the elements of the old list followed by tensor. + /// element_dtype: the type of elements in the list. + /// element_shape: a shape compatible with that of elements in the list. + /// + /// </remarks> + /// <param name="input_handle"></param> + /// <param name="tensor"></param> + /// <returns></returns> + public static Tensor tensor_list_push_back(Tensor input_handle, Tensor tensor, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListPushBack", name) { args = new object[] { input_handle, tensor }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_push_back_eager_fallback(input_handle, tensor, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["tensor"] = tensor; + var _op = tf.OpDefLib._apply_op_helper("TensorListPushBack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListPushBack", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_push_back_eager_fallback(Tensor input_handle, Tensor tensor, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, tensor }; + object[] _attrs = new object[] { 
"element_dtype", tensor.dtype }; + var _result = _execute.execute("TensorListPushBack", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListPushBack", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Invokes the TensorListPushBackBatch op for the given list handles and tensor. + /// </summary> + /// <param name="input_handles"></param> + /// <param name="tensor"></param> + /// <returns></returns> + public static Tensor tensor_list_push_back_batch(Tensor input_handles, Tensor tensor, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListPushBackBatch", name) { args = new object[] { input_handles, tensor }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_push_back_batch_eager_fallback(input_handles, tensor, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handles"] = input_handles; + keywords["tensor"] = tensor; + var _op = tf.OpDefLib._apply_op_helper("TensorListPushBackBatch", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListPushBackBatch", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_push_back_batch_eager_fallback(Tensor input_handles, Tensor tensor, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handles, tensor }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype }; + var _result = _execute.execute("TensorListPushBackBatch", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListPushBackBatch", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// List of the given size with empty 
elements. + /// </summary> + /// <remarks> + /// + /// element_shape: the shape of the future elements of the list + /// num_elements: the number of elements to reserve + /// handle: the output list + /// element_dtype: the desired type of elements in the list. + /// + /// </remarks> + /// <param name="element_shape"></param> + /// <param name="num_elements"></param> + /// <param name="element_dtype"></param> + /// <returns></returns> + public static Tensor tensor_list_reserve(Tensor element_shape, Tensor num_elements, TF_DataType element_dtype, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListReserve", name) { args = new object[] { element_shape, num_elements }, attrs = new Dictionary<string, object>() { ["element_dtype"] = element_dtype } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_reserve_eager_fallback(element_shape, num_elements, element_dtype: element_dtype, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["element_shape"] = element_shape; + keywords["num_elements"] = num_elements; + keywords["element_dtype"] = element_dtype; + var _op = tf.OpDefLib._apply_op_helper("TensorListReserve", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListReserve", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_reserve_eager_fallback(Tensor element_shape, Tensor num_elements, TF_DataType element_dtype, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { element_shape, num_elements }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListReserve", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if 
(_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListReserve", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Resizes the list. + /// </summary> + /// <remarks> + /// + /// + /// input_handle: the input list + /// size: size of the output list + /// + /// + /// </remarks> + /// <param name="input_handle"></param> + /// <param name="size"></param> + /// <returns></returns> + public static Tensor tensor_list_resize(Tensor input_handle, Tensor size, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListResize", name) { args = new object[] { input_handle, size }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_resize_eager_fallback(input_handle, size, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["size"] = size; + var _op = tf.OpDefLib._apply_op_helper("TensorListResize", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { }; + _execute.record_gradient("TensorListResize", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_resize_eager_fallback(Tensor input_handle, Tensor size, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, size }; + object[] _attrs = new object[] { }; + var _result = _execute.execute("TensorListResize", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListResize", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Creates a TensorList by indexing into a Tensor. + /// </summary> + /// <remarks> + /// + /// Each member of the TensorList corresponds to one row of the input tensor, + /// specified by the given index (see `tf.gather`). 
+ /// + /// tensor: The input tensor. + /// indices: The indices used to index into the list. + /// element_shape: The shape of the elements in the list (can be less specified than + /// the shape of the tensor). + /// output_handle: The TensorList. + /// + /// </remarks> + /// <param name="tensor"></param> + /// <param name="indices"></param> + /// <param name="element_shape"></param> + /// <returns></returns> + public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Tensor element_shape, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListScatter", name) { args = new object[] { tensor, indices, element_shape }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_scatter_eager_fallback(tensor, indices, element_shape, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["element_shape"] = element_shape; + var _op = tf.OpDefLib._apply_op_helper("TensorListScatter", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListScatter", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_scatter_eager_fallback(Tensor tensor, Tensor indices, Tensor element_shape, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, element_shape }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListScatter", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListScatter", _inputs_flat, _attrs, _result); + } + 
return _result[0]; + } + /// <summary> + /// Scatters tensor at indices in an input list. + /// </summary> + /// <remarks> + /// + /// Each member of the TensorList corresponds to one row of the input tensor, + /// specified by the given index (see `tf.gather`). + /// + /// input_handle: The list to scatter into. + /// tensor: The input tensor. + /// indices: The indices used to index into the list. + /// output_handle: The TensorList. + /// + /// </remarks> + /// <param name="input_handle"></param> + /// <param name="tensor"></param> + /// <param name="indices"></param> + /// <returns></returns> + public static Tensor tensor_list_scatter_into_existing_list(Tensor input_handle, Tensor tensor, Tensor indices, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListScatterIntoExistingList", name) { args = new object[] { input_handle, tensor, indices }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_scatter_into_existing_list_eager_fallback(input_handle, tensor, indices, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["tensor"] = tensor; + keywords["indices"] = indices; + var _op = tf.OpDefLib._apply_op_helper("TensorListScatterIntoExistingList", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListScatterIntoExistingList", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_scatter_into_existing_list_eager_fallback(Tensor input_handle, Tensor tensor, Tensor indices, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, tensor, indices }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype }; + var _result = _execute.execute("TensorListScatterIntoExistingList", 1, 
inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListScatterIntoExistingList", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Creates a TensorList by indexing into a Tensor. + /// </summary> + /// <remarks> + /// + /// Each member of the TensorList corresponds to one row of the input tensor, + /// specified by the given index (see `tf.gather`). + /// + /// tensor: The input tensor. + /// indices: The indices used to index into the list. + /// element_shape: The shape of the elements in the list (can be less specified than + /// the shape of the tensor). + /// num_elements: The size of the output list. Must be large enough to accommodate + /// the largest index in indices. If -1, the list is just large enough to include + /// the largest index in indices. + /// output_handle: The TensorList. + /// + /// </remarks> + /// <param name="tensor"></param> + /// <param name="indices"></param> + /// <param name="element_shape"></param> + /// <param name="num_elements"></param> + /// <returns></returns> + public static Tensor tensor_list_scatter_v2(Tensor tensor, Tensor indices, Tensor element_shape, Tensor num_elements, string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListScatterV2", name) { args = new object[] { tensor, indices, element_shape, num_elements }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_scatter_v2_eager_fallback(tensor, indices, element_shape, num_elements, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["tensor"] = tensor; + keywords["indices"] = indices; + keywords["element_shape"] = element_shape; + keywords["num_elements"] = num_elements; + var _op = tf.OpDefLib._apply_op_helper("TensorListScatterV2", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListScatterV2", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_scatter_v2_eager_fallback(Tensor tensor, Tensor indices, Tensor element_shape, Tensor num_elements, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, indices, element_shape, num_elements }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListScatterV2", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListScatterV2", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Sets the item at the given index of the list (TensorListSetItem op). + /// </summary> + /// <param name="input_handle"></param> + /// <param name="index"></param> + /// <param name="item"></param> + /// <returns></returns> + public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item, string? 
name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListSetItem", name) { args = new object[] { input_handle, index, item }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_set_item_eager_fallback(input_handle, index, item, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["index"] = index; + keywords["item"] = item; + var _op = tf.OpDefLib._apply_op_helper("TensorListSetItem", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype") }; + _execute.record_gradient("TensorListSetItem", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_set_item_eager_fallback(Tensor input_handle, Tensor index, Tensor item, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, index, item }; + object[] _attrs = new object[] { "element_dtype", item.dtype }; + var _result = _execute.execute("TensorListSetItem", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListSetItem", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Splits a tensor into a list. + /// </summary> + /// <remarks> + /// + /// list[i] corresponds to lengths[i] tensors from the input tensor. + /// The tensor must have rank at least 1 and contain exactly sum(lengths) elements. + /// + /// tensor: The input tensor. + /// element_shape: A shape compatible with that of elements in the tensor. + /// lengths: Vector of sizes of the 0th dimension of tensors in the list. + /// output_handle: The list. 
+ /// + /// </remarks> + /// <param name="tensor"></param> + /// <param name="element_shape"></param> + /// <param name="lengths"></param> + /// <returns></returns> + public static Tensor tensor_list_split(Tensor tensor, Tensor element_shape, Tensor lengths, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListSplit", name) { args = new object[] { tensor, element_shape, lengths }, attrs = new Dictionary<string, object>() { } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_split_eager_fallback(tensor, element_shape, lengths, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["tensor"] = tensor; + keywords["element_shape"] = element_shape; + keywords["lengths"] = lengths; + var _op = tf.OpDefLib._apply_op_helper("TensorListSplit", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "shape_type", _op._get_attr_type("shape_type") }; + _execute.record_gradient("TensorListSplit", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_split_eager_fallback(Tensor tensor, Tensor element_shape, Tensor lengths, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { tensor, element_shape, lengths }; + object[] _attrs = new object[] { "element_dtype", tensor.dtype, "shape_type", element_shape.dtype }; + var _result = _execute.execute("TensorListSplit", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + _execute.record_gradient("TensorListSplit", _inputs_flat, _attrs, _result); + } + return _result[0]; + } + /// <summary> + /// Stacks all tensors in the list. + /// </summary> + /// <remarks> + /// + /// Requires that all tensors have the same shape. + /// + /// input_handle: the input list + /// tensor: the gathered result + /// num_elements: optional. 
If not -1, the number of elements in the list. + /// + /// + /// </remarks> + /// <param name="input_handle"></param> + /// <param name="element_shape"></param> + /// <param name="element_dtype"></param> + /// <param name="num_elements"></param> + /// <returns></returns> + public static Tensor tensor_list_stack(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, int num_elements = -1, string? name = null) + { + var _ctx = tf.Context; + if (_ctx.executing_eagerly()) + { + try + { + var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "TensorListStack", name) { args = new object[] { input_handle, element_shape }, attrs = new Dictionary<string, object>() { ["element_dtype"] = element_dtype, ["num_elements"] = num_elements } }); + return _fast_path_result[0]; + } + catch (Exception) + { + } + try + { + return tensor_list_stack_eager_fallback(input_handle, element_shape, element_dtype: element_dtype, num_elements: num_elements, name: name, ctx: _ctx); + } + catch (Exception) + { + } + } + Dictionary<string, object> keywords = new(); + keywords["input_handle"] = input_handle; + keywords["element_shape"] = element_shape; + keywords["element_dtype"] = element_dtype; + keywords["num_elements"] = num_elements; + var _op = tf.OpDefLib._apply_op_helper("TensorListStack", name, keywords); + var _result = _op.outputs; + if (_execute.must_record_gradient()) + { + object[] _attrs = new object[] { "element_dtype", _op._get_attr_type("element_dtype"), "num_elements", _op._get_attr_int("num_elements") }; + _execute.record_gradient("TensorListStack", _op.inputs, _attrs, _result); + } + return _result[0]; + } + + public static Tensor tensor_list_stack_eager_fallback(Tensor input_handle, Tensor element_shape, TF_DataType element_dtype, int num_elements, string name, Context ctx) + { + Tensor[] _inputs_flat = new Tensor[] { input_handle, element_shape }; + object[] _attrs = new object[] { "element_dtype", element_dtype, "num_elements", num_elements }; + var _result = _execute.execute("TensorListStack", 1, inputs: _inputs_flat, attrs: _attrs, ctx: ctx, name: name); + if (_execute.must_record_gradient()) + { + 
_execute.record_gradient("TensorListStack", _inputs_flat, _attrs, _result);
+        }
+        return _result[0];
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/list_ops.cs b/src/TensorFlowNET.Core/Operations/list_ops.cs
new file mode 100644
index 00000000..c5e83ee4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/list_ops.cs
@@ -0,0 +1,129 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Eager;
+
+namespace Tensorflow.Operations
+{
+    /// <summary>Wrappers over the generated TensorList ops (gen_list_ops) that accept an optional Shape.</summary>
+    internal class list_ops
+    {
+        /// <summary>
+        /// Records element shape and dtype as handle data on an eager list handle; any other tensor is left unchanged.
+        /// </summary>
+        /// <param name="element_shape">Element shape; null or unknown is stored as an unknown-rank shape.</param>
+        private static void _set_handle_data(Tensor list_handle, Shape? element_shape, TF_DataType element_dtype)
+        {
+            if (list_handle is EagerTensor)
+            {
+                // Callers pass an optional shape; store "unknown rank" (like TensorShape(None)) instead of dereferencing null.
+                var shape_proto = element_shape is null || element_shape.IsNull
+                    ? new TensorShapeProto() { UnknownRank = true }
+                    : element_shape.as_proto();
+                var handle_data = new CppShapeInferenceResult.Types.HandleData();
+                handle_data.IsSet = true;
+                handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType()
+                {
+                    Shape = shape_proto,
+                    Dtype = element_dtype.as_datatype_enum(),
+                    Type = new FullTypeDef() { TypeId = FullTypeId.TftArray }
+                });
+                list_handle.HandleData = handle_data;
+            }
+        }
+
+        /// <summary>Converts an optional shape into a tensor; a null or unknown shape becomes the scalar -1.</summary>
+        private static Tensor _build_element_shape(Shape? shape)
+        {
+            if(shape is null || shape.IsNull)
+            {
+                return ops.convert_to_tensor(-1);
+            }
+            else
+            {
+                return ops.convert_to_tensor(shape);
+            }
+        }
+
+        /// <summary>Reserves a list through gen_list_ops.tensor_list_reserve and records shape/dtype handle data on the result.</summary>
+        public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null)
+        {
+            var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name);
+            _set_handle_data(result, shape, element_dtype);
+            return result;
+        }
+
+        /// <summary>Builds a list from tensor; the recorded handle data uses the shape and dtype of tensor itself.</summary>
+        public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null)
+        {
+            var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name);
+            _set_handle_data(result, tensor.shape, tensor.dtype);
+            return result;
+        }
+
+        /// <summary>Forwards to gen_list_ops.tensor_list_get_item after converting the optional element_shape to a tensor.</summary>
+        public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype,
+            Shape? element_shape = null, string? name = null)
+        {
+            return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape),
+                element_dtype, name);
+        }
+
+        /// <summary>Sets item at index, optionally resizing the list to index + 1 first when index >= its length; handle data is copied to the result.</summary>
+        public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item,
+            bool resize_if_index_out_of_bounds = false, string? name = null)
+        {
+            if (resize_if_index_out_of_bounds)
+            {
+                var input_list_size = gen_list_ops.tensor_list_length(input_handle);
+                input_handle = control_flow_ops.cond(index >= input_list_size,
+                    () => gen_list_ops.tensor_list_resize(input_handle, index + 1),
+                    () => input_handle);
+            }
+            var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name);
+            handle_data_util.copy_handle_data(input_handle, output_handle);
+            return output_handle;
+        }
+
+        /// <summary>Forwards to gen_list_ops.tensor_list_stack after converting the optional element_shape to a tensor.</summary>
+        public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1,
+            Shape? element_shape = null, string? name = null)
+        {
+            return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name);
+        }
+
+        /// <summary>Forwards to gen_list_ops.tensor_list_gather after converting the optional element_shape to a tensor.</summary>
+        public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype,
+            Shape? element_shape = null, string? name = null)
+        {
+            return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name);
+        }
+
+        /// <summary>Scatters into input_handle when one is given; otherwise builds a new list with TensorListScatterV2 and num_elements = -1.</summary>
+        public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null,
+            string? name = null)
+        {
+            if(input_handle is not null)
+            {
+                var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name);
+                handle_data_util.copy_handle_data(input_handle, output_handle);
+                return output_handle;
+            }
+            else
+            {
+                var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape),
+                    constant_op.constant(-1), name);
+                _set_handle_data(output_handle, element_shape, tensor.dtype);
+                return output_handle;
+            }
+        }
+
+        /// <summary>Creates an empty list; max_num_elements is passed to the op as an int32 tensor.</summary>
+        public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1,
+            string? name = null)
+        {
+            return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype,
+                max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
index 7d2da544..6be0706c 100644
--- a/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
@@ -13,11 +13,23 @@ namespace Tensorflow
         ///
         public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
         {
-            var new_ta = tf.TensorArray(
-                dtype: old_ta.dtype,
-                infer_shape: old_ta.infer_shape,
+            if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
+            {
+                throw new NotImplementedException("Attempting to build a graph-mode TF2-style " +
+                    "TensorArray from either an eager-mode " +
+                    "TensorArray or a TF1-style TensorArray. " +
+                    "This is not currently supported. You may be " +
+                    "attempting to capture a TensorArray " +
+                    "inside a tf.function or tf.data map function. 
" +
+                    "Instead, construct a new TensorArray inside " +
+                    "the function.");
+            }
+            var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape,
                 colocate_with_first_write_call: old_ta.colocate_with_first_write_call);
-
+            new_ta._dynamic_size = old_ta._dynamic_size;
+            new_ta._size = old_ta._size;
+            new_ta._colocate_with = old_ta._colocate_with;
+            new_ta._element_shape = old_ta._element_shape;
             return new_ta;
         }
 
diff --git a/src/TensorFlowNET.Core/Operations/while_v2.cs b/src/TensorFlowNET.Core/Operations/while_v2.cs
new file mode 100644
index 00000000..7ee3e9e8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/while_v2.cs
@@ -0,0 +1,476 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using Tensorflow.Common.Extensions;
+using Tensorflow.Common.Types;
+using Tensorflow.Eager;
+using Tensorflow.Framework;
+using Tensorflow.Framework.Models;
+using Tensorflow.Graphs;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Operations
+{
+    /// <summary>
+    /// Operation wrapper around an existing native op handle; the caller assigns its outputs afterwards.
+    /// </summary>
+    class _OperationWithOutputs : Operation
+    {
+        /// <summary>Wraps <paramref name="handle"/> and adds the op to <paramref name="g"/>.</summary>
+        /// <param name="handle">Native operation handle.</param>
+        /// <param name="g">Owning graph. Required despite the default value, because the op is added to it.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="g"/> is null.</exception>
+        public _OperationWithOutputs(IntPtr handle, Graph g = null)
+        {
+            if (g is null)
+            {
+                throw new ArgumentNullException(nameof(g));
+            }
+            _handle = handle;
+            _graph = g;
+            _outputs = null;
+            g._add_op(this);
+        }
+    }
+
+    /// <summary>
+    /// Builds while loops in functional form: cond and body are traced into function graphs
+    /// and emitted as one While / StatelessWhile op.
+    /// </summary>
+    internal class while_v2
+    {
+        /// <summary>
+        /// Traces <paramref name="cond"/> and <paramref name="body"/> into function graphs and emits the while op.
+        /// </summary>
+        /// <param name="cond">Loop predicate, called with the packed loop variables.</param>
+        /// <param name="body">Loop body, mapping the packed loop variables to their next values.</param>
+        /// <param name="loop_vars">Initial loop variables; TensorArray entries travel as their flow tensors.</param>
+        /// <param name="maximum_iterations">Iteration cap; with -1 the predicate ignores the loop counter.</param>
+        /// <param name="parallel_iterations">Forwarded to the generated while op.</param>
+        /// <param name="name">Name scope of the loop; "while" when null or empty.</param>
+        /// <param name="back_prop">When false, the returned loop variables are wrapped in stop_gradient.</param>
+        /// <param name="return_same_structure">Not read by this implementation.</param>
+        /// <returns>The final loop variables, packed with the structure of <paramref name="loop_vars"/>.</returns>
+        public static Tensor[] while_loop(Func<Tensors, Tensor> cond,
+            Func<Tensors, Tensors> body,
+            Tensors loop_vars,
+            int maximum_iterations = -1,
+            int parallel_iterations = 10,
+            string name = null,
+            bool back_prop = true,
+            bool return_same_structure = true)
+        {
+            var orig_loop_vars = loop_vars;
+            var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray();
+
+            loop_vars = _tensor_array_to_flow(loop_vars);
+            loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors();
+
+            var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars));
+
+            var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray();
+
+            if(string.IsNullOrEmpty(name))
+            {
+                name = "while";
+            }
+
+            return tf_with(ops.name_scope(name), nameScopeWhile =>
+            {
+                string scope = (nameScopeWhile as ops.NameScope).scope_name;
+                // NOTE(review): cond_name/body_name are computed but the graphs below are still created with the
+                // literal names "cond" and "body" - confirm whether the unique names were meant to be passed.
+                string cond_name = control_flow_util.unique_fn_name(scope, "cond");
+                string body_name = control_flow_util.unique_fn_name(scope, "body");
+
+                var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations);
+                var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype,
+                    name: "loop_counter");
+                loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray();
+
+                var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)}
+                    .Concat(loop_vars_signature.Flatten()).ToArray();
+
+                // TODO(Rinne): possible wrong implementation here.
+                var add_control_dependencies = false;
+
+                object[] wrapped_cond(object[] inputs)
+                {
+                    Tensor loop_counter = (Tensor)inputs[0];
+                    Tensor maximum_iterations_arg = (Tensor)inputs[1];
+                    Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();
+                    var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
+                    if(pred.shape.IsNull || pred.shape.ndim > 0)
+                    {
+                        pred = array_ops.squeeze(pred);
+                    }
+                    if(maximum_iterations == -1)
+                    {
+                        return new object[] { pred };
+                    }
+                    else
+                    {
+                        return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) };
+                    }
+                }
+
+                var cond_graph = FuncGraph.func_graph_from_func("cond", wrapped_cond, null,
+                    null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies);
+
+                bool stateful_parallelism = false;
+
+                object[] wrapped_body(object[] inputs)
+                {
+                    Tensor loop_counter = (Tensor)inputs[0];
+                    Tensor maximum_iterations_arg = (Tensor)inputs[1];
+                    Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();
+
+                    _copy_handle_data(loop_vars.Flatten().Skip(2), args);
+
+                    foreach(var t in cond_graph.external_captures)
+                    {
+                        var graph = (FuncGraph)(ops.get_default_graph());
+                        graph.capture(t);
+                    }
+
+                    var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
+                    outputs = _tensor_array_to_flow(outputs);
+
+                    return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray();
+                }
+
+                var body_graph = FuncGraph.func_graph_from_func("body", wrapped_body, null, null, func_graph_signature,
+                    add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism);
+
+                // TODO(Rinne): possible wrong implementation here.
+                NestList<Tensors> loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() });
+                body_graph.Outputs.AddRange(body_graph.internal_captures);
+
+                cond_graph.as_default();
+                try
+                {
+                    int num_cond_captures = cond_graph.external_captures.Length;
+                    Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray()));
+                    _duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray());
+                }
+                finally
+                {
+                    cond_graph.Exit();
+                }
+
+                int first_loop_var_index = 2;
+
+                int num_flattened_outputs = orig_loop_vars.Length;
+                int num_original_outputs = body_graph.Outputs.Length;
+                if (back_prop && control_flow_util.output_all_intermediates())
+                {
+                    var intermediate_tensors = _get_intermediates(body_graph);
+
+                    foreach(var intermediate_tensor in intermediate_tensors)
+                    {
+                        var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations);
+                        loop_vars_list.Values.Add(tensor_list);
+
+                        cond_graph.as_default();
+                        try
+                        {
+                            cond_graph.capture(tensor_list);
+                        }
+                        finally
+                        {
+                            cond_graph.Exit();
+                        }
+
+                        body_graph.as_default();
+                        try
+                        {
+                            var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor);
+                            body_graph.Outputs.Add(appended_tensor_list);
+                        }
+                        finally
+                        {
+                            body_graph.Exit();
+                        }
+                    }
+                }
+
+                List<Tensor> flattened_loop_vars = new();
+                foreach(var item in loop_vars_list.Values)
+                {
+                    flattened_loop_vars.AddRange(item.Flatten());
+                }
+                // skip the check
+
+                // TODO(Rinne): deal with control dependencies
+                var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray();
+                var span = new Span<Shape>(output_shapes).Slice(first_loop_var_index, num_flattened_outputs);
+                for(int i = 0; i < span.Length; i++)
+                {
+                    span[i] = flat_shape_invariants[i];
+                }
+
+                Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations,
+                    scope, num_original_outputs, stateful_parallelism);
+
+                if (!ops.get_default_graph().building_function)
+                {
+                    outputs = outputs.Select(t => array_ops.identity(t)).ToArray();
+                }
+
+                var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_outputs).ToArray();
+
+                if (!back_prop)
+                {
+                    output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray();
+                }
+                outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars);
+
+                return outputs;
+            });
+        }
+
+        /// <summary>
+        /// Replaces TensorArray placeholders by their flow tensors. Only a single node or the entries of the
+        /// top-level list are inspected; nested sub-structures are copied unchanged.
+        /// </summary>
+        /// <exception cref="NotImplementedException">The nest is neither a node nor a list.</exception>
+        private static Tensors _tensor_array_to_flow(Tensors loop_vars)
+        {
+            if(loop_vars.NestType == NestType.Node)
+            {
+                if(loop_vars.NodeValue is FakeTensorByTensorArray fake)
+                {
+                    return new Tensors(fake.TensorArray.flow);
+                }
+                else
+                {
+                    return new Tensors(loop_vars.NodeValue!);
+                }
+            }
+            else if(loop_vars.NestType == NestType.List)
+            {
+                List<INestStructure<Tensor>> list = new();
+                foreach(var item in loop_vars.ListValue!)
+                {
+                    if(item.NestType == NestType.Node)
+                    {
+                        var nested = item.AsNest();
+                        if (nested.NodeValue is FakeTensorByTensorArray fake)
+                        {
+                            list.Add(new Nest<Tensor>(fake.TensorArray.flow));
+                        }
+                        else
+                        {
+                            list.Add(new Nest<Tensor>(nested.NodeValue!));
+                        }
+                    }
+                    else
+                    {
+                        list.Add(new Nest<Tensor>(item.AsNest()));
+                    }
+                }
+                return Tensors.FromNest(new Nest<Tensor>(list));
+            }
+            else
+            {
+                throw new NotImplementedException();
+            }
+        }
+
+        /// <summary>
+        /// Emits the While (or StatelessWhile) op for the traced graphs and returns its output tensors.
+        /// </summary>
+        private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph,
+            Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism)
+        {
+            // NOTE(review): every operation is collected here without a statefulness filter, so any non-empty
+            // graph selects the stateful While op - confirm that this is intended.
+            var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op);
+            var body_stateful_ops = body_graph.get_operations().Select(x => x.op);
+
+            bool is_stateful = cond_stateful_ops.Any() || body_stateful_ops.Any();
+
+            Tensor[] _make_op(Tensor[] inputs)
+            {
+                Tensor[] outputs;
+                if (is_stateful)
+                {
+                    outputs = gen_functional_ops._while(
+                        inputs,
+                        control_flow_util.create_new_tf_function(cond_graph),
+                        control_flow_util.create_new_tf_function(body_graph),
+                        output_shapes,
+                        parallel_iterations,
+                        name
+                    );
+                }
+                else
+                {
+                    outputs = gen_functional_ops.stateless_while(
+                        inputs,
+                        control_flow_util.create_new_tf_function(cond_graph),
+                        control_flow_util.create_new_tf_function(body_graph),
+                        output_shapes,
+                        parallel_iterations,
+                        name
+                    );
+                }
+                var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs);
+                _copy_handle_data(body_graph.Outputs, tensors);
+                _set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph});
+                while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs });
+                while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism });
+
+                cond_graph.outer_graph = ops.get_default_graph();
+                body_graph.outer_graph = ops.get_default_graph();
+                // TODO(Rinne): set the two graphs to while_op
+                return tensors;
+            }
+
+            return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars);
+        }
+
+        /// <summary>
+        /// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies.
+        /// </summary>
+        /// <param name="op">The while op that receives the attribute.</param>
+        /// <param name="branch_graphs">Graphs whose read-only resource input indices are intersected.</param>
+        private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs)
+        {
+            List<int> read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList();
+            foreach(var branch_graph in branch_graphs)
+            {
+                if (read_only_indices.Count == 0)
+                {
+                    break;
+                }
+                var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph);
+                read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList();
+            }
+            AttrValue.Types.ListValue listValue = new();
+            listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x));
+            op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue()
+            {
+                List = listValue
+            });
+        }
+
+        /// <summary>
+        /// Packs flat loop variables back into the signature's structure, rebuilding a TensorArray wrapper
+        /// (with the new flow) wherever the original loop variable was one.
+        /// </summary>
+        private static Tensors _pack_sequence_as(INestStructure<TensorSpec> loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars)
+        {
+            var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item =>
+            {
+                var (flow, y) = item;
+                if (y is FakeTensorByTensorArray ta)
+                {
+                    return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow));
+                }
+                else
+                {
+                    return flow;
+                }
+            }).ToArray();
+            return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors();
+        }
+
+        /// <summary>
+        /// Collects op outputs of <paramref name="func_graph"/> that are not the first graph input, not resources
+        /// and not capture placeholders. Outputs of Identity and MutexLock ops are skipped.
+        /// </summary>
+        private static Tensor[] _get_intermediates(FuncGraph func_graph)
+        {
+            List<Tensor> intermediates = new();
+            // Only membership is needed, so a set is used; unlike ToDictionary it does not throw on repeated keys.
+            var captured_placeholders = new HashSet<Tensor>(func_graph.captures.Select(x => x.Item2));
+
+            foreach(var op in func_graph.get_operations())
+            {
+                Debug.Assert(op is Operation);
+                var oper = (Operation)op;
+                if(oper.type == "Identity" || oper.type == "MutexLock")
+                {
+                    continue;
+                }
+                foreach(var o in op.outputs)
+                {
+                    if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !captured_placeholders.Contains(o))
+                    {
+                        intermediates.Add(o);
+                    }
+                }
+            }
+            return intermediates.ToArray();
+        }
+
+        /// <summary>
+        /// Adds one placeholder input to <paramref name="cond_graph"/> per extra body capture and records each
+        /// (capture, placeholder) pair in the cond graph's captures.
+        /// </summary>
+        private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures)
+        {
+            var types = body_graph_captures.Select(t => t.dtype).ToList();
+            var c_graph = cond_graph.c_graph;
+            var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList();
+
+            var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList();
+
+            List<Tensor> tensors = new();
+            foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types))
+            {
+                var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph);
+                op._outputs = new Tensor[] { tensor };
+                tensors.Add(tensor);
+            }
+
+            var tuples = zip(body_graph_captures, tensors).ToList();
+            var keys = body_graph_captures.Select(t => t.Id).ToList();
+            cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2));
+            cond_graph.Inputs.AddRange(tensors);
+        }
+
+        /// <summary>Creates a Placeholder op of the given dtype through the C API and returns its first output.</summary>
+        private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype)
+        {
+            var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
+            c_api.TF_SetAttrType(desc, "dtype", dtype);
+            var op = c_api.TF_FinishOperation(desc, tf.Status);
+            tf.Status.Check(true);
+            var output = new TF_Output();
+            output.oper = op;
+            output.index = 0;
+            return output;
+        }
+
+        /// <summary>Returns a name, unique in <paramref name="cond_graph"/>, for a redundant cond placeholder.</summary>
+        private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph)
+        {
+            return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder");
+        }
+
+        /// <summary>Converts <paramref name="value"/> with <c>ops.convert_to_tensor</c>.</summary>
+        private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype,
+            string name)
+        {
+            return ops.convert_to_tensor(value, dtype, name, false);
+        }
+
+        /// <summary>Wraps <paramref name="maximum_iterations"/> in an int32 tensor named "maximum_iterations".</summary>
+        private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1)
+        {
+            return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations");
+        }
+
+        /// <summary>Copies handle data pairwise from <paramref name="src_tensors"/> to <paramref name="dst_tensors"/>.</summary>
+        private static void _copy_handle_data(IEnumerable<Tensor> src_tensors, IEnumerable<Tensor> dst_tensors)
+        {
+            foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors))
+            {
+                handle_data_util.copy_handle_data(src_t, dst_t);
+            }
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index 498ffda7..e7ff9f74 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -105,6 +105,13 @@ namespace Tensorflow
             _id = ops.uid();
         }
 
+        internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output)
+        {
+            Tensor ret = new Tensor(op, value_index, dtype);
+            ret._tf_output = tf_output;
+            return ret;
+        }
+
         protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
         {
             _handle = TF_NewTensor(shape, dtype, null);
diff --git a/src/TensorFlowNET.Core/Tensors/TensorArray.cs b/src/TensorFlowNET.Core/Tensors/TensorArray.cs
index fb59593c..ff74956a 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorArray.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorArray.cs
@@ -14,7 +14,9 @@
    limitations under the License.
 ******************************************************************************/
 
+using Tensorflow.Common.Types;
 using Tensorflow.Operations;
+using static Tensorflow.Binding;
 
 namespace Tensorflow
 {
@@ -44,5 +46,27 @@ namespace Tensorflow
         public abstract Tensor stack(string name = null);
 
         public abstract Tensor gather(Tensor indices, string name = null);
+
+        internal bool _dynamic_size;
+        internal Tensor _size;
+        internal List<Tensor> _colocate_with;
+        internal Shape _element_shape;
+
+        public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false,
+            bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
+            bool infer_shape = true, Shape? 
element_shape = null,
+            bool colocate_with_first_write_call = true, string name = null)
+        {
+            if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant))
+            {
+                return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
+                    infer_shape, element_shape, colocate_with_first_write_call, name);
+            }
+            else
+            {
+                return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
+                    infer_shape, element_shape, colocate_with_first_write_call, name);
+            }
+        }
     }
 }
diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs
index 259b1eec..38a3e5dc 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensors.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs
@@ -4,6 +4,8 @@ using System.Collections;
 using System.Collections.Generic;
 using System.Linq;
 using Tensorflow.Common.Types;
+using Tensorflow.Operations;
+using Tensorflow.Common.Extensions;
 
 namespace Tensorflow
 {
@@ -58,7 +60,7 @@ namespace Tensorflow
 
         public Tensor this[params string[] slices] => this.First()[slices];
 
-        private Tensors(Nest<Tensor> nested) : base(nested)
+        internal Tensors(Nest<Tensor> nested) : base(nested)
         {
 
         }
@@ -68,9 +70,9 @@ namespace Tensorflow
 
         }
 
-        public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
+        public Tensors(IList<Tensor> tensors) : base(tensors.Select(x => new Nest<Tensor>(x)))
         {
-
+            
         }
 
         public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
@@ -78,6 +80,32 @@ namespace Tensorflow
 
         }
 
+        /// <summary>
+        /// Get the element in shallow level. For example, for ts = [1, [2, 3], 4], 
+        /// common indexer has ts[1] = 2. Shallow indexer has ts[1] = [2, 3]
+        /// </summary>
+        /// <param name="index">Position at the top nesting level; only 0 is valid when this nest is a single node.</param>
+        /// <returns>This instance for a node nest, otherwise the nested element converted to <see cref="Tensors"/>.</returns>
+        public Tensors GetShallow(int index)
+        {
+            if(NestType == NestType.Node)
+            {
+                if(index != 0)
+                {
+                    throw new IndexOutOfRangeException();
+                }
+                return this;
+            }
+            else if(NestType == NestType.List)
+            {
+                return ListValue![index].AsNest().ToTensors();
+            }
+            else
+            {
+                throw new NotImplementedException();
+            }
+        }
+
         private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
         {
             if (tensors.Length == 0)
@@ -115,8 +143,8 @@ namespace Tensorflow
             else if(NestType == NestType.Node)
             {
                 NestType = NestType.List;
-                ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
-                Value = null;
+                ListValue = new() { new Nest<Tensor>(NodeValue), new Nest<Tensor>(tensor) };
+                NodeValue = null;
             }
             else if(NestType == NestType.List)
             {
@@ -125,7 +153,7 @@ namespace Tensorflow
             else //Empty
             {
                 NestType = NestType.Node;
-                Value = tensor;
+                NodeValue = tensor;
             }
         }
 
@@ -140,9 +168,9 @@ namespace Tensorflow
             else if (NestType == NestType.Node)
             {
                 NestType = NestType.List;
-                ListValue = new() { new Nest<Tensor>(Value) };
+                ListValue = new() { new Nest<Tensor>(NodeValue) };
                 ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
-                Value = null;
+                NodeValue = null;
             }
             else if(NestType == NestType.List)
             {
@@ -151,7 +179,7 @@ namespace Tensorflow
             else // empty
             {
                 NestType = NestType.List;
-                ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
+                ListValue = tensors.Select(x => new Nest<Tensor>(x) as INestStructure<Tensor>).ToList();
             }
         }
 
@@ -166,9 +194,9 @@ namespace Tensorflow
             else if(NestType == NestType.Node)
             {
                 NestType = NestType.List;
-                ListValue = new() { new Nest<Tensor>(Value) };
+                ListValue = new() { new Nest<Tensor>(NodeValue) };
                 ListValue.Insert(index, new Nest<Tensor>(tensor));
-                Value = null;
+                NodeValue = null;
             }
             else
             {
@@ -283,7 +311,7 @@ namespace Tensorflow
             => tensors?.SingleOrNull;
 
         public static implicit operator Tensor[](Tensors tensors)
-            => tensors.Flatten().ToArray(); 
+            => tensors.Flatten().ToArray();
         #endregion
 
         public static Tensors? 
FromNest(Nest nested) @@ -298,7 +326,7 @@ namespace Tensorflow public void Deconstruct(out Tensor a, out Tensors? b) { a = this.First(); - b = Length == 1? null : new Tensors(this.Skip(1)); + b = Length == 1? null : new Tensors(this.Skip(1).ToArray()); } public override string ToString() diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 6d1385ca..fb9bccf3 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -576,7 +576,7 @@ namespace Tensorflow public static HandleData get_resource_handle_data(Tensor graph_op) { var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output()); - return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data))); + return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data)); } public static void dismantle_graph(Graph graph) diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 1336e9af..8dbcf90d 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -25,6 +25,7 @@ using static Tensorflow.Binding; using static Tensorflow.Graphs.SubGraphUtility; using Tensorflow.Util; using Tensorflow.Common.Types; +using System.Diagnostics; namespace Tensorflow.Keras { @@ -485,7 +486,7 @@ namespace Tensorflow.Keras var first_flatted_input = flatted_inptus[0]; var time_steps = first_flatted_input.shape[0]; var batch = first_flatted_input.shape[1]; - var time_steps_t = (int)first_flatted_input.shape[0]; + var time_steps_t = tf.shape(first_flatted_input)[0]; foreach (var input_ in flatted_inptus) { @@ -704,7 +705,7 @@ namespace Tensorflow.Keras var input_ta = new List(); for (int i = 0; i < flatted_inptus.Count; i++) { - input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t)); + input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t)); } foreach(var (ta, input_) in zip(input_ta, 
flatted_inptus)) @@ -730,18 +731,15 @@ namespace Tensorflow.Keras (output_time_zero, _) = step_function(input_time_zero, constants is null ? initial_states : initial_states.MergeWith(constants)); - int output_ta_size = return_all_outputs ? time_steps_t : 1; + Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1); var output_ta = new List(); - for (int i = 0; i < output_time_zero.ToList().Count; i++) + foreach(var output in output_time_zero.Flatten()) { - var Out = output_time_zero.ToList()[i]; - output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); + output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape)); } var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); - - Func? masking_fn; Func? compute_masked_output = null; if (mask != null) @@ -750,7 +748,7 @@ namespace Tensorflow.Keras { mask = tf.reverse(mask, axis: new[] { 0 }); } - var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t); + var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t); mask_ta = mask_ta.unstack(mask); masking_fn = (time) => @@ -810,9 +808,9 @@ namespace Tensorflow.Keras masking_fn = null; } - Func cond = (time) => (time < time_steps_t); + Func cond = (time) => (time[0] < time_steps_t); int parallel_iterations = 32; - new_states = states; + Tensors final_outputs; if (masking_fn != null) { // Mask for the T output will be base on the output of T - 1. In the @@ -825,7 +823,7 @@ namespace Tensorflow.Keras var prev_output = flat_zero_output; var output_ta_t = output_ta; - Tensor _step(Tensor time) + Tensors _step(Tensors tensors) { /* RNN step function. 
@@ -838,23 +836,28 @@ namespace Tensorflow.Keras Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` */ + Tensor time = tensors[0]; + TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; + Tensors prev_output = tensors.GetShallow(2); + Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray()); + var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); // maybe set shape // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); var mask_t = masking_fn(time); - var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); + var (output, new_states) = step_function(current_input, states.MergeWith(constants)); // mask output var flat_output = Nest.Flatten(output).ToList(); - var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); + var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList(); // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); // mask states - var flat_state = states.ToList(); - var flat_new_state = new_states_internal.ToList(); + var flat_state = states.Flatten().ToList(); + var flat_new_state = new_states.Flatten().ToList(); foreach (var (state, new_state) in zip(flat_state, flat_new_state)) { @@ -865,38 +868,37 @@ namespace Tensorflow.Keras } var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); - new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors(); + new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors(); var ta_index_to_write = return_all_outputs ? 
time : tf.constant(0); - output_ta_t = zip(output_ta_t, flat_new_output).Select(item => - { - var (ta, out_) = item; - return ta.write(ta_index_to_write, out_); - }).ToList(); + Debug.Assert(flat_output.Count() == 1); + output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First()); - - new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); - - output_ta = output_ta_t; - new_states = new_states_internal; - return time + 1; + return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states) + .ToArray().ToTensors(); } - var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); + var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) } + .Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors(); + final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); + new_states = final_outputs.Skip(3).ToList(); } else { var output_ta_t = output_ta; new_states = states; - Tensor _step(Tensor time) + Tensors _step(Tensors tensors) { + Tensor time = tensors[0]; + TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray; + Tensors states = new Tensors(tensors.Skip(2).ToArray()); var flat_current_input = input_ta.Select(x => x.read(time)).ToList(); // maybe set shape // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors(); - var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants)); + var (output, new_states) = step_function(current_input, states.MergeWith(constants)); var flat_state = new_states.Flatten().ToList(); - var flat_new_state = new_states_internal.Flatten().ToList(); + var flat_new_state = new_states.Flatten().ToList(); foreach (var (state, 
new_state) in zip(flat_state, flat_new_state)) { if (new_state is Tensor) @@ -906,24 +908,23 @@ namespace Tensorflow.Keras } var flat_output = Nest.Flatten(output); var ta_index_to_write = return_all_outputs ? time : tf.constant(0); - output_ta_t = zip(output_ta_t, flat_output).Select(item => - { - var (ta, out_) = item; - return ta.write(ta_index_to_write, out_); - }).ToList(); - - new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); - output_ta = output_ta_t; - new_states = new_states_internal; - return time + 1; + Debug.Assert(flat_output.Count() == 1); + output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First()); + + new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors(); + return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors(); } - var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations); + Debug.Assert(output_ta.Count == 1); + var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors(); + final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations); + new_states = final_outputs.Skip(2).ToList(); } - outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors()); - last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors()); - outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors(); - last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors(); + output_ta = new List { (final_outputs[1] as FakeTensorByTensorArray).TensorArray }; + outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors()); + last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors()); + outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors(); + 
last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors(); } Func set_shape; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Build.cs b/src/TensorFlowNET.Keras/Engine/Model.Build.cs index 69afdef9..23336383 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Build.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Build.cs @@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph(); graph.as_default(); var shapes = input_shape.ToShapeArray(); - var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x))); + var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray()); try { Call(x, training: false); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 185de4f4..d807b204 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Engine { var data_handler = new DataHandler(new DataHandlerArgs { - X = new Tensors(x), + X = new Tensors(x.ToArray()), Y = y, Model = this, StepsPerExecution = _steps_per_execution @@ -188,7 +188,7 @@ namespace Tensorflow.Keras.Engine { var data = iterator.next(); var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; - var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); + var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); return outputs; } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index bb8e18cc..76c592ad 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ 
b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Engine var data_handler = new DataHandler(new DataHandlerArgs { - X = new Tensors(train_x), + X = new Tensors(train_x.ToArray()), Y = train_y, BatchSize = batch_size, InitialEpoch = initial_epoch, diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index 905ea453..48c16e18 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine { var data = iterator.next(); var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount; - var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size))); + var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); return outputs; } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs index 78d3dac9..d2669ccc 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs @@ -4,10 +4,11 @@ using System.Text; using Tensorflow.Common.Types; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Utils; namespace Tensorflow.Keras.Layers.Rnn { - public abstract class DropoutRNNCellMixin: RnnCellBase + public abstract class DropoutRNNCellMixin: Layer, IRnnCell { public float dropout; public float recurrent_dropout; @@ -17,6 +18,14 @@ namespace Tensorflow.Keras.Layers.Rnn } + public abstract GeneralizedTensorShape StateSize { get; } + public abstract GeneralizedTensorShape OutputSize { get; } + public abstract bool SupportOptionalArgs { get; } + public virtual Tensors GetInitialState(Tensors inputs, Tensor 
batch_size, TF_DataType dtype) + { + return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); + } + protected void _create_non_trackable_mask_cache() { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 0ebd7362..77f7d927 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -206,7 +206,6 @@ namespace Tensorflow.Keras.Layers.Rnn // append bacth dim state_spec_shape = new int[] { -1 }.concat(state_spec_shape); return new InputSpec(shape: state_spec_shape); - } // Check whether the input shape contains any nested shapes. It could be @@ -298,7 +297,7 @@ namespace Tensorflow.Keras.Layers.Rnn // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) Func step; - bool is_tf_rnn_cell = _cell.IsTFRnnCell; + bool is_tf_rnn_cell = false; if (constants is not null) { if (!_cell.SupportOptionalArgs) @@ -310,8 +309,8 @@ namespace Tensorflow.Keras.Layers.Rnn step = (inputs, states) => { - constants = new Tensors(states.TakeLast(_num_constants)); - states = new Tensors(states.SkipLast(_num_constants)); + constants = new Tensors(states.TakeLast(_num_constants).ToArray()); + states = new Tensors(states.SkipLast(_num_constants).ToArray()); states = len(states) == 1 && is_tf_rnn_cell ? 
new Tensors(states[0]) : states; var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); return (output, new_states.Single); @@ -395,12 +394,12 @@ namespace Tensorflow.Keras.Layers.Rnn { if (_num_constants != 0) { - initial_state = new Tensors(inputs.Skip(1)); + initial_state = new Tensors(inputs.Skip(1).ToArray()); } else { - initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants)); - constants = new Tensors(inputs.TakeLast(_num_constants)); + initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray()); + constants = new Tensors(inputs.TakeLast(_num_constants).ToArray()); } if (len(initial_state) == 0) initial_state = null; @@ -558,36 +557,14 @@ namespace Tensorflow.Keras.Layers.Rnn protected Tensors get_initial_state(Tensors inputs) { - var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state"); - var input = inputs[0]; - var input_shape = inputs.shape; + var input_shape = array_ops.shape(inputs); var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; var dtype = input.dtype; - Tensors init_state = new Tensors(); - - if(get_initial_state_fn != null) - { - init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype }); - - } - //if (_cell is RnnCellBase rnn_base_cell) - //{ - // init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype); - //} - else - { - init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype); - } + Tensors init_state = _cell.GetInitialState(null, batch_size, dtype); return init_state; } - - // Check whether the state_size contains multiple states. 
- public static bool is_multiple_state(GeneralizedTensorShape state_size) - { - return state_size.Shapes.Length > 1; - } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs deleted file mode 100644 index 751312e5..00000000 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs +++ /dev/null @@ -1,24 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; -using Tensorflow.Common.Types; -using Tensorflow.Keras.ArgsDefinition; -using Tensorflow.Keras.ArgsDefinition.Rnn; -using Tensorflow.Keras.Engine; -using Tensorflow.Keras.Utils; - -namespace Tensorflow.Keras.Layers.Rnn -{ - public abstract class RnnCellBase: Layer, IRnnCell - { - public RnnCellBase(LayerArgs args) : base(args) { } - public abstract GeneralizedTensorShape StateSize { get; } - public abstract GeneralizedTensorShape OutputSize { get; } - public abstract bool IsTFRnnCell { get; } - public abstract bool SupportOptionalArgs { get; } - public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype) - { - return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); - } - } -} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index 39610ff5..3b4b9419 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -7,6 +7,7 @@ using Tensorflow.Keras.Saving; using Tensorflow.Common.Types; using Tensorflow.Common.Extensions; using Tensorflow.Keras.Utils; +using Tensorflow.Graphs; namespace Tensorflow.Keras.Layers.Rnn { @@ -28,7 +29,6 @@ namespace Tensorflow.Keras.Layers.Rnn public override GeneralizedTensorShape StateSize => _state_size; public override GeneralizedTensorShape OutputSize => _output_size; - public override bool IsTFRnnCell => true; public override bool SupportOptionalArgs => false; public SimpleRNNCell(SimpleRNNCellArgs args) 
: base(args) @@ -98,7 +98,6 @@ namespace Tensorflow.Keras.Layers.Rnn { prev_output = math_ops.multiply(prev_output, rec_dp_mask); } - var tmp = _recurrent_kernel.AsTensor(); Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor()); if (_args.Activation != null) @@ -117,9 +116,9 @@ namespace Tensorflow.Keras.Layers.Rnn } } - public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) + public Tensors get_initial_state(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) { - return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value); + return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype); } } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index 56634853..fb74d6d2 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers.Rnn public class StackedRNNCells : Layer, IRnnCell { public IList Cells { get; set; } - public bool reverse_state_order; + public bool _reverse_state_order; public StackedRNNCells(StackedRNNCellsArgs args) : base(args) { @@ -23,22 +23,11 @@ namespace Tensorflow.Keras.Layers.Rnn { args.Kwargs = new Dictionary(); } - foreach (var cell in args.Cells) - { - //Type type = cell.GetType(); - //var CallMethodInfo = type.GetMethod("Call"); - //if (CallMethodInfo == null) - //{ - // throw new ValueError( - // "All cells must have a `Call` method. 
" + - // $"Received cell without a `Call` method: {cell}"); - //} - } Cells = args.Cells; - reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); + _reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false); - if (reverse_state_order) + if (_reverse_state_order) { throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " + "be deprecated. Please update the code to work with the " + @@ -47,49 +36,37 @@ namespace Tensorflow.Keras.Layers.Rnn } } + public bool SupportOptionalArgs => false; + public GeneralizedTensorShape StateSize { get { - GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count); - if (reverse_state_order && Cells.Count > 0) + if (_reverse_state_order) { - var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell)); - foreach (var cell in idxAndCell) - { - state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); - } + var state_sizes = Cells.Reverse().Select(cell => cell.StateSize); + return new GeneralizedTensorShape(new Nest(state_sizes.Select(s => new Nest(s)))); } else { - //foreach (var cell in Cells) - //{ - // state_size.Shapes.add(cell.StateSize.Shapes.First()); - - //} - var idxAndCell = Cells.Select((cell, idx) => (idx, cell)); - foreach (var cell in idxAndCell) - { - state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First(); - } + var state_sizes = Cells.Select(cell => cell.StateSize); + return new GeneralizedTensorShape(new Nest(state_sizes.Select(s => new Nest(s)))); } - return state_size; } } - public object output_size + public GeneralizedTensorShape OutputSize { get { - var lastCell = Cells.LastOrDefault(); - if (lastCell.OutputSize.ToSingleShape() != -1) + var lastCell = Cells.Last(); + if(lastCell.OutputSize is not null) { return lastCell.OutputSize; } - else if (RNN.is_multiple_state(lastCell.StateSize)) + else if (RnnUtils.is_multiple_state(lastCell.StateSize)) { return lastCell.StateSize.First(); - //throw new 
NotImplementedException(""); } else { @@ -98,79 +75,65 @@ namespace Tensorflow.Keras.Layers.Rnn } } - public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null) + public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) { - var cells = reverse_state_order ? Cells.Reverse() : Cells; - Tensors initial_states = new Tensors(); + var cells = _reverse_state_order ? Cells.Reverse() : Cells; + List initial_states = new List(); foreach (var cell in cells) { - var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state"); - if (get_initial_state_fn != null) - { - var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype }); - initial_states.Add(result); - } - else - { - initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value)); - } + initial_states.Add(cell.GetInitialState(inputs, batch_size, dtype)); } - return initial_states; + return new Tensors(initial_states); } - protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) + protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null) { // Recover per-cell states. - var state_size = reverse_state_order ? StateSize.Reverse() : StateSize; - var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten(); + var state_size = _reverse_state_order ? new GeneralizedTensorShape(StateSize.Reverse()) : StateSize; + var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray()); - - var new_nest_states = new Tensors(); + var new_nest_states = Nest.Empty; // Call the cells in order and store the returned states. 
- foreach (var (cell, states) in zip(Cells, nested_states)) + foreach (var (cell, internal_states) in zip(Cells, nested_states)) { - // states = states if tf.nest.is_nested(states) else [states] - var type = cell.GetType(); - bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null; - state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state; - RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs; Tensors? constants = rnn_optional_args?.Constants; Tensors new_states; - (inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); + (inputs, new_states) = cell.Apply(inputs, internal_states, optional_args: new RnnOptionalArgs() { Constants = constants }); - new_nest_states.Add(new_states); + new_nest_states = new_nest_states.MergeWith(new_states); } - new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray(); - return new Nest(new List> { - new Nest(new List> { new Nest(inputs.Single()) }), new Nest(new_nest_states) }) - .ToTensors(); + return Tensors.FromNest((inputs, Nest.PackSequenceAs(state_size, Nest.Flatten(new_nest_states).ToArray()))); } - - - public void build() + public override void build(KerasShapesWrapper input_shape) { - built = true; - // @tf_utils.shape_type_conversion - // def build(self, input_shape) : - // if isinstance(input_shape, list) : - // input_shape = input_shape[0] - // for cell in self.cells: - // if isinstance(cell, Layer) and not cell.built: - // with K.name_scope(cell.name): - // cell.build(input_shape) - // cell.built = True - // if getattr(cell, 'output_size', None) is not None: - // output_dim = cell.output_size - // elif _is_multiple_state(cell.state_size) : - // output_dim = cell.state_size[0] - // else: - // output_dim = cell.state_size - // input_shape = tuple([input_shape[0]] + - // tensor_shape.TensorShape(output_dim).as_list()) - // self.built = True + var shape = input_shape.ToSingleShape(); + 
foreach(var cell in Cells) + { + if(cell is Layer layer && !layer.Built) + { + // ignored the name scope. + layer.build(shape); + layer.Built = true; + } + GeneralizedTensorShape output_dim; + if(cell.OutputSize is not null) + { + output_dim = cell.OutputSize; + } + else if (RnnUtils.is_multiple_state(cell.StateSize)) + { + output_dim = cell.StateSize.First(); + } + else + { + output_dim = cell.StateSize; + } + shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.ToSingleShape().dims).ToArray()); + } + this.Built = true; } public override IKerasConfig get_config() @@ -198,14 +161,5 @@ namespace Tensorflow.Keras.Layers.Rnn // deserialize_layer(cell_config, custom_objects = custom_objects)) // return cls(cells, **config) } - - public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null) - { - throw new NotImplementedException(); - } - - public GeneralizedTensorShape OutputSize => throw new NotImplementedException(); - public bool IsTFRnnCell => true; - public bool SupportOptionalArgs => throw new NotImplementedException(); } } diff --git a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs index 3109eb77..7ff3f9fb 100644 --- a/src/TensorFlowNET.Keras/Utils/RnnUtils.cs +++ b/src/TensorFlowNET.Keras/Utils/RnnUtils.cs @@ -10,20 +10,21 @@ namespace Tensorflow.Keras.Utils { internal static class RnnUtils { - internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) + internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype) { Func create_zeros; create_zeros = (GeneralizedTensorShape unnested_state_size) => { var flat_dims = unnested_state_size.ToSingleShape().dims; - var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray(); - return array_ops.zeros(new Shape(init_state_size), dtype: dtype); + var init_state_size = new Tensor[] { 
batch_size_tensor }. + Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray(); + return array_ops.zeros(init_state_size, dtype: dtype); }; // TODO(Rinne): map structure with nested tensors. - if(state_size.Shapes.Length > 1) + if(state_size.TotalNestedCount > 1) { - return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s)))); + return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray()); } else { @@ -32,11 +33,11 @@ namespace Tensorflow.Keras.Utils } - internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype) + internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype) { - if (inputs != null) + if (inputs is not null) { - batch_size = inputs.shape[0]; + batch_size = array_ops.shape(inputs)[0]; dtype = inputs.dtype; } return generate_zero_filled_state(batch_size, cell.StateSize, dtype); @@ -77,17 +78,27 @@ namespace Tensorflow.Keras.Utils Debug.Assert(initial_state is null && constants is null); if(num_constants > 0) { - constants = inputs.TakeLast(num_constants).ToTensors(); - inputs = inputs.SkipLast(num_constants).ToTensors(); + constants = inputs.TakeLast(num_constants).ToArray().ToTensors(); + inputs = inputs.SkipLast(num_constants).ToArray().ToTensors(); } if(inputs.Length > 1) { - initial_state = inputs.Skip(1).ToTensors(); - inputs = inputs.Take(1).ToTensors(); + initial_state = inputs.Skip(1).ToArray().ToTensors(); + inputs = inputs.Take(1).ToArray().ToTensors(); } } return (inputs, initial_state, constants); } + + /// + /// Check whether the state_size contains multiple states. 
+ /// + /// + /// + public static bool is_multiple_state(GeneralizedTensorShape state_size) + { + return state_size.TotalNestedCount > 1; + } } } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs index 6d7182e0..23dc1d44 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs @@ -28,8 +28,8 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var i = tf.constant(2); var j = tf.constant(3); - Func c = (x) => tf.less(x[0] + x[1], 10); - Func b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) }; + Func c = (x) => tf.less(x[0] + x[1], 10); + Func b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) }; var r = tf.while_loop(c, b, new[] { i, j }); Assert.AreEqual(5, (int)r[0]); Assert.AreEqual(6, (int)r[1]); diff --git a/tools/Tensorflow.CodeGen/FunctionGenerator.cs b/tools/Tensorflow.CodeGen/FunctionGenerator.cs index 93f9ea4e..186e6a27 100644 --- a/tools/Tensorflow.CodeGen/FunctionGenerator.cs +++ b/tools/Tensorflow.CodeGen/FunctionGenerator.cs @@ -21,7 +21,8 @@ namespace Tensorflow.CodeGen { sb.Append("Operation "); } - else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) { sb.Append("Tensor "); } @@ -70,7 +71,8 @@ namespace Tensorflow.CodeGen { sb.AppendLine("return null;"); } - else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) { sb.AppendLine("return _fast_path_result[0];"); } @@ -149,7 +151,8 @@ namespace Tensorflow.CodeGen { sb.AppendLine("return _op;"); } - else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) + else if 
(outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) { sb.AppendLine("return _result[0];"); } @@ -174,7 +177,7 @@ namespace Tensorflow.CodeGen { argName = $"{argName}_"; } - if (!string.IsNullOrEmpty(arg.NumberAttr)) + if (!string.IsNullOrEmpty(arg.NumberAttr) || !string.IsNullOrEmpty(arg.TypeListAttr)) { sb.Append($"Tensors {argName}, "); } @@ -273,7 +276,8 @@ namespace Tensorflow.CodeGen { sb.Append("Operation "); } - else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) { sb.Append("Tensor "); } @@ -366,6 +370,13 @@ namespace Tensorflow.CodeGen sb.Append($"\"{attr.Name}\", {attrRealName}, "); } } + else if(attr.Type == "list(type)") + { + if (op.InputArg.Any(x => x.TypeListAttr == attr.Name)) + { + continue; + } + } else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name)) { bool found = false; @@ -408,7 +419,8 @@ namespace Tensorflow.CodeGen { sb.AppendLine("return null;"); } - else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)) + else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr) + && string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr)) { sb.AppendLine("return _result[0];"); } diff --git a/tools/Tensorflow.CodeGen/Program.cs b/tools/Tensorflow.CodeGen/Program.cs index f9d44ce8..cea52e0b 100644 --- a/tools/Tensorflow.CodeGen/Program.cs +++ b/tools/Tensorflow.CodeGen/Program.cs @@ -5,7 +5,7 @@ using System.Text; using System.Xml.Linq; using Tensorflow.CodeGen; -GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops", +GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops_v2", @"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops", @"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api", 
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt"); diff --git a/tools/Tensorflow.CodeGen/Utils.cs b/tools/Tensorflow.CodeGen/Utils.cs index d3f30d9f..19de6c0e 100644 --- a/tools/Tensorflow.CodeGen/Utils.cs +++ b/tools/Tensorflow.CodeGen/Utils.cs @@ -155,6 +155,10 @@ namespace Tensorflow.CodeGen } else if (attr.Type == "list(type)") { + if(op.InputArg.Any(x => x.TypeListAttr == attr.Name)) + { + continue; + } if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type) { List values = new(); @@ -231,11 +235,11 @@ namespace Tensorflow.CodeGen } else if (attr.Type == "func") { - res.Add((attr.Name, "Func", "NOVALUE")); + res.Add((attr.Name, "object", "NOVALUE")); } else if (attr.Type == "list(func)") { - res.Add((attr.Name, "Func[]", "NOVALUE")); + res.Add((attr.Name, "object[]", "NOVALUE")); } else if (attr.Type == "tensor") {