
feat: support model building with RNN.

tags/v0.110.0-LSTM-Model
Yaohui Liu 2 years ago
parent commit f1fbcf2016
53 changed files with 3662 additions and 507 deletions
  1. src/TensorFlowNET.Core/APIs/c_api.cs (+14 -0)
  2. src/TensorFlowNET.Core/APIs/tf.control_flow.cs (+5 -5)
  3. src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs (+6 -1)
  4. src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs (+20 -0)
  5. src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs (+42 -98)
  6. src/TensorFlowNET.Core/Common/Types/INestStructure.cs (+13 -0)
  7. src/TensorFlowNET.Core/Common/Types/Nest.Static.cs (+1 -1)
  8. src/TensorFlowNET.Core/Common/Types/Nest.cs (+79 -38)
  9. src/TensorFlowNET.Core/Common/Types/NestDictionary.cs (+4 -0)
  10. src/TensorFlowNET.Core/Common/Types/NestList.cs (+11 -6)
  11. src/TensorFlowNET.Core/Common/Types/NestNode.cs (+4 -0)
  12. src/TensorFlowNET.Core/Data/DatasetV2.cs (+2 -2)
  13. src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs (+2 -0)
  14. src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs (+13 -0)
  15. src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs (+89 -0)
  16. src/TensorFlowNET.Core/Framework/function_def_lib.cs (+2 -2)
  17. src/TensorFlowNET.Core/Functions/ConcreteFunction.cs (+13 -0)
  18. src/TensorFlowNET.Core/Graphs/FuncGraph.cs (+2 -2)
  19. src/TensorFlowNET.Core/Graphs/Graph.cs (+1 -1)
  20. src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs (+9 -3)
  21. src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs (+4 -0)
  22. src/TensorFlowNET.Core/Operations/OpDefLibrary.cs (+49 -0)
  23. src/TensorFlowNET.Core/Operations/Operation.Output.cs (+1 -1)
  24. src/TensorFlowNET.Core/Operations/Operation.cs (+3 -2)
  25. src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs (+2 -4)
  26. src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs (+175 -4)
  27. src/TensorFlowNET.Core/Operations/array_ops.cs (+24 -0)
  28. src/TensorFlowNET.Core/Operations/control_flow_ops.cs (+5 -4)
  29. src/TensorFlowNET.Core/Operations/control_flow_util.py.cs (+77 -0)
  30. src/TensorFlowNET.Core/Operations/gen_functional_ops.cs (+985 -81)
  31. src/TensorFlowNET.Core/Operations/gen_list_ops.cs (+1227 -0)
  32. src/TensorFlowNET.Core/Operations/list_ops.cs (+111 -0)
  33. src/TensorFlowNET.Core/Operations/tensor_array_ops.cs (+16 -4)
  34. src/TensorFlowNET.Core/Operations/while_v2.cs (+401 -0)
  35. src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs (+7 -0)
  36. src/TensorFlowNET.Core/Tensors/TensorArray.cs (+24 -0)
  37. src/TensorFlowNET.Core/Tensors/Tensors.cs (+41 -13)
  38. src/TensorFlowNET.Core/ops.cs (+1 -1)
  39. src/TensorFlowNET.Keras/BackendImpl.cs (+48 -47)
  40. src/TensorFlowNET.Keras/Engine/Model.Build.cs (+1 -1)
  41. src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs (+2 -2)
  42. src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+1 -1)
  43. src/TensorFlowNET.Keras/Engine/Model.Train.cs (+1 -1)
  44. src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs (+10 -1)
  45. src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs (+8 -31)
  46. src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs (+0 -24)
  47. src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs (+3 -4)
  48. src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs (+53 -99)
  49. src/TensorFlowNET.Keras/Utils/RnnUtils.cs (+23 -12)
  50. test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs (+2 -2)
  51. tools/Tensorflow.CodeGen/FunctionGenerator.cs (+18 -6)
  52. tools/Tensorflow.CodeGen/Program.cs (+1 -1)
  53. tools/Tensorflow.CodeGen/Utils.cs (+6 -2)

src/TensorFlowNET.Core/APIs/c_api.cs (+14 -0)

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
@@ -50,6 +51,19 @@ namespace Tensorflow
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(IntPtr handle)
{
byte* str_data = (byte*)handle.ToPointer();
List<byte> bytes = new List<byte>();
byte current = 255;
while (current != ((byte)'\0'))
{
current = *(str_data++);
bytes.Add(current);
}
return bytes.Take(bytes.Count - 1).ToArray();
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);
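
Usage sketch (illustrative, not part of the commit): decoding the bytes returned by `ByteStringPiece` into a managed string. The handle value below is a placeholder; in real code it comes from a TensorFlow C API call that returns a NUL-terminated string. Assumes `using System; using System.Text; using Tensorflow;`.

IntPtr nativeHandle = IntPtr.Zero; // placeholder: obtain from a TF C API call in practice
if (nativeHandle != IntPtr.Zero)
{
    byte[] raw = c_api.ByteStringPiece(nativeHandle);   // reads up to the trailing '\0', which is stripped
    string text = Encoding.UTF8.GetString(raw);
}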



src/TensorFlowNET.Core/APIs/tf.control_flow.cs (+5 -5)

@@ -46,10 +46,10 @@ namespace Tensorflow
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
Func<Tensors, Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
Func<Tensors, Tensors> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ namespace Tensorflow
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,
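
A minimal counting-loop sketch through the single-tensor overload above (not part of the commit; assumes eager mode, `using static Tensorflow.Binding;`, and the usual comparison/arithmetic operator overloads on `Tensor`):

var x = tf.constant(0);
var result = tf.while_loop(
    cond: t => t < 10,   // wrapped internally into the Tensors-based cond
    body: t => t + 1,    // wrapped internally into the Tensors-based body
    loop_vars: x);
// result holds 10 once the loop finishes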


src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs (+6 -1)

@@ -18,7 +18,12 @@ namespace Tensorflow.Common.Extensions
return sequence.Take(sequence.Count() - count);
}
#endif
public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
public static Tensors ToTensors(this Tensor[] tensors)
{
return new Tensors(tensors);
}

public static Tensors ToTensors(this IList<Tensor> tensors)
{
return new Tensors(tensors);
}
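
Usage sketch (not part of the commit): both overloads wrap existing tensors into a `Tensors` collection. Assumes `using System.Collections.Generic; using Tensorflow.Common.Extensions;` and `using static Tensorflow.Binding;`.

Tensor[] array = new[] { tf.constant(1), tf.constant(2) };
Tensors fromArray = array.ToTensors();          // Tensor[] overload

IList<Tensor> list = new List<Tensor> { tf.constant(3) };
Tensors fromList = list.ToTensors();            // IList<Tensor> overload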


src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs (+20 -0)

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This is a temp solution, which should be removed after refactoring `Tensors`
/// </summary>
[Obsolete]
public class FakeTensorByTensorArray: Tensor
{
public TensorArray TensorArray { get; set; }

public FakeTensorByTensorArray(TensorArray array)
{
TensorArray = array;
}
}
}

src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs (+42 -98)

@@ -5,136 +5,80 @@ using System.Text;

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
public class GeneralizedTensorShape: Nest<Shape>
{
public TensorShapeConfig[] Shapes { get; set; }
/// <summary>
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
public GeneralizedTensorShape(int dim, int size = 1)
{
var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
Shapes = Enumerable.Repeat(elem, size).ToArray();
//Shapes = new TensorShapeConfig[size];
//Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
//Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
}
////public TensorShapeConfig[] Shapes { get; set; }
///// <summary>
///// create a single-dim generalized Tensor shape.
///// </summary>
///// <param name="dim"></param>
//public GeneralizedTensorShape(int dim, int size = 1)
//{
// var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
// Shapes = Enumerable.Repeat(elem, size).ToArray();
// //Shapes = new TensorShapeConfig[size];
// //Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
// //Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
// ////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
//}

public GeneralizedTensorShape(Shape shape)
public GeneralizedTensorShape(Shape value, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
NodeValue = value;
NestType = NestType.Node;
}

public GeneralizedTensorShape(TensorShapeConfig shape)
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
Name = name;
NestType = NestType.List;
}

public GeneralizedTensorShape(TensorShapeConfig[] shapes)
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
{
Shapes = shapes;
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
Name = name;
NestType = NestType.Dictionary;
}

public GeneralizedTensorShape(IEnumerable<Shape> shape)
public GeneralizedTensorShape(Nest<Shape> other)
{
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public Shape ToSingleShape()
{
if (Shapes.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var shape_config = Shapes[0];
Debug.Assert(shape_config is not null);
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
return shapes[0];
}

public long ToNumber()
{
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1 || shapes[0].ndim != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var res = Shapes[0].Items[0];
return res is null ? -1 : res.Value;
}

public Shape[] ToShapeArray()
{
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
}

public IEnumerable<long?> Flatten()
{
List<long?> result = new List<long?>();
foreach(var shapeConfig in Shapes)
{
result.AddRange(shapeConfig.Items);
}
return result;
}
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
{
List<Nest<TOut>> lists = new();
foreach(var shapeConfig in Shapes)
{
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
}
return new Nest<TOut>(lists);
}

public Nest<long?> AsNest()
{
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
{
if (config.Items.Length == 0)
{
return Nest<long?>.Empty;
}
else if (config.Items.Length == 1)
{
return new Nest<long?>(config.Items[0]);
}
else
{
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
}
}

if(Shapes.Length == 0)
{
return Nest<long?>.Empty;
}
else if(Shapes.Length == 1)
{
return DealWithSingleShape(Shapes[0]);
}
else
{
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
return shapes[0].dims[0];
}


public static implicit operator GeneralizedTensorShape(int dims)
=> new GeneralizedTensorShape(dims);

public IEnumerator<long?[]> GetEnumerator()
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
{
foreach (var shape in Shapes)
{
yield return shape.Items;
}
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
}

IEnumerator IEnumerable.GetEnumerator()
public static implicit operator GeneralizedTensorShape(Shape shape)
{
return GetEnumerator();
return new GeneralizedTensorShape(shape);
}
}
}
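
A hedged example of the reworked type, which is now simply a `Nest<Shape>` (not part of the commit; assumes `using System.Linq; using Tensorflow.Common.Types;`):

GeneralizedTensorShape single = new Shape(32, 64);   // implicit Shape conversion creates a Node nest
Shape shape = single.ToSingleShape();                // throws if more than one shape is nested

var stacked = new GeneralizedTensorShape(new[] { new Shape(32), new Shape(64) });
Shape[] all = stacked.Flatten().ToArray();           // Nest<Shape>.Flatten over the leaf shapes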

src/TensorFlowNET.Core/Common/Types/INest.cs → src/TensorFlowNET.Core/Common/Types/INestStructure.cs (+13 -0)

@@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types
/// </summary>
public interface INestStructure<T>: INestable<T>
{
NestType NestType { get; }

/// <summary>
/// The item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
/// </summary>
int ShallowNestedCount { get; }
/// <summary>
/// The total count of leaf items in the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
int TotalNestedCount { get; }

/// <summary>
/// Flatten the Nestable object. Note that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
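
Illustrative sketch (not part of the commit) of the two counts on a structure shaped like [1, 2, [3, 4, 5]], built with the `Nest<T>` constructors introduced in this commit:

var nested = new Nest<int>(new INestStructure<int>[]
{
    new Nest<int>(1),
    new Nest<int>(2),
    new Nest<int>(new INestStructure<int>[] { new Nest<int>(3), new Nest<int>(4), new Nest<int>(5) })
});
// nested.ShallowNestedCount == 3   (two leaf nodes plus one inner list)
// nested.TotalNestedCount   == 5   (all leaf values)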

src/TensorFlowNET.Core/Common/Types/Nest.Static.cs (+1 -1)

@@ -13,7 +13,7 @@ namespace Tensorflow.Common.Types
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
{
return template.AsNest().PackSequence(flatItems);
}
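
A hedged sketch (not from the commit): the new `TOut` type parameter lets flat items of a different element type be packed into the template's layout.

var template = new Nest<int>(new INestStructure<int>[] { new Nest<int>(0), new Nest<int>(0) });
string[] flat = { "a", "b" };
Nest<string> packed = Nest.PackSequenceAs(template, flat);   // same nesting, string leaves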


src/TensorFlowNET.Core/Common/Types/Nest.cs (+79 -38)

@@ -28,27 +28,58 @@ namespace Tensorflow.Common.Types
public static Nest<T> Empty => _empty;
public NestType NestType { get; protected set; }
public string? Name { get; set; }
public T? Value { get; protected set; }
public List<Nest<T>>? ListValue { get; protected set; }
public Dictionary<string, Nest<T>>? DictValue { get; protected set; }
public T? NodeValue { get; protected set; }
public List<INestStructure<T>>? ListValue { get; protected set; }
public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; }

public int ShallowNestedCount
{
get
{
if (NestType == NestType.Empty)
{
return 0;
}
else if (NestType == NestType.Node)
{
return 1;
}
else if (NestType == NestType.List)
{
return ListValue!.Count;
}
else // dict
{
return DictValue!.Count;
}
}
}

public int TotalNestedCount
{
get
{
return Flatten().Count();
}
}

protected Nest() { }

public Nest(T value, string? name = null)
{
Value = value;
NodeValue = value;
Name = name;
NestType = NestType.Node;
}

public Nest(IEnumerable<Nest<T>> values, string? name = null)
public Nest(IEnumerable<INestStructure<T>> values, string? name = null)
{
ListValue = values.ToList();
Name = name;
NestType = NestType.List;
}

public Nest(Dictionary<string, Nest<T>> value, string? name = null)
public Nest(Dictionary<string, INestStructure<T>> value, string? name = null)
{
DictValue = value;
Name = name;
@@ -58,7 +89,7 @@ namespace Tensorflow.Common.Types
public Nest(Nest<T> other)
{
NestType = other.NestType;
Value = other.Value;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
@@ -78,17 +109,17 @@ namespace Tensorflow.Common.Types
/// </summary>
/// <param name="flatItems"></param>
/// <returns></returns>
public virtual Nest<T> PackSequence(T[] flatItems)
public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems)
{
if(flatItems.Length == 0)
{
return Nest<T>.Empty;
return Nest<TOut>.Empty;
}
int index = 0;
return PackSequenceInternal(this, flatItems, ref index);
}

private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index)
private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index)
{
if(template.NestType == NestType.Node)
{
@@ -96,25 +127,25 @@ namespace Tensorflow.Common.Types
{
throw new InvalidArgumentError("The template and flat items are not matched.");
}
return new Nest<T>(flatItems[index++]);
return new Nest<TOut>(flatItems[index++]);
}
else if(template.NestType == NestType.List)
{
List<Nest<T>> nestedObjects = new List<Nest<T>>();
List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>();
for (int i = 0; i < template.ListValue!.Count; i++)
{
nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index));
nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index));
}
return new Nest<T>(nestedObjects);
return new Nest<TOut>(nestedObjects);
}
else if(template.NestType == NestType.Node)
{
Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>();
Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>();
foreach(var (key, value) in template.DictValue!)
{
dict[key] = PackSequenceInternal(value, flatItems, ref index);
dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index);
}
return new Nest<T>(dict);
return new Nest<TOut>(dict);
}
// Consider Empty as invalid type.
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
@@ -223,10 +254,10 @@ namespace Tensorflow.Common.Types
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T>
{
var nested = input.AsNest();
return ReduceInternal(nested);
return ReduceInternal(nested).AsNest();
}

private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
{
if(node.NestType == NestType.Empty)
{
@@ -234,15 +265,15 @@ namespace Tensorflow.Common.Types
}
else if(node.NestType == NestType.Node)
{
return node.Value!.AsNest();
return node.NodeValue!.AsNest();
}
else if(node.NestType == NestType.List)
{
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x)));
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest())));
}
else // Dictionary type
{
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value)));
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest())));
}
}

@@ -252,7 +283,7 @@ namespace Tensorflow.Common.Types
{
if(index == 0)
{
result = node.Value!;
result = node.NodeValue!;
return true;
}
result = default(T);
@@ -264,7 +295,7 @@ namespace Tensorflow.Common.Types
{
if(index == 0)
{
return FindInternal(item, index, out result);
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
@@ -277,7 +308,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
return FindInternal(item, index, out result);
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
@@ -297,7 +328,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
node.Value = newValue;
node.NodeValue = newValue;
return true;
}
return false;
@@ -308,7 +339,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
return SetInternal(item, index, newValue);
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
@@ -320,7 +351,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
return SetInternal(item, index, newValue);
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
@@ -336,13 +367,13 @@ namespace Tensorflow.Common.Types
{
if (node.NestType == NestType.Node)
{
yield return node.Value!;
yield return node.NodeValue!;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
foreach(var val in FlattenInternal(item))
foreach(var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
@@ -352,7 +383,7 @@ namespace Tensorflow.Common.Types
{
foreach (var item in node.DictValue!.Values)
{
foreach (var val in FlattenInternal(item))
foreach (var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
@@ -364,23 +395,23 @@ namespace Tensorflow.Common.Types
{
if (NestType == NestType.Node)
{
return new Nest<TOut>(func(Value!));
return new Nest<TOut>(func(NodeValue!));
}
else if (NestType == NestType.List)
{
List<Nest<TOut>> outs = new List<Nest<TOut>>();
foreach (var item in ListValue!)
{
outs.Add(item.MapStructureInternal(func));
outs.Add(item.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else if (NestType == NestType.Dictionary)
{
Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>();
Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>();
foreach (var (key, value) in DictValue!)
{
outs.Add(key, value.MapStructureInternal(func));
outs.Add(key, value.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
@@ -417,14 +448,14 @@ namespace Tensorflow.Common.Types
}
if (node.NestType == NestType.Node)
{
sb.Append(node.Value!.ToString());
sb.Append(node.NodeValue!.ToString());
}
else if (node.NestType == NestType.List)
{
sb.Append("[");
for(int i = 0; i < node.ListValue!.Count; i++)
{
WriteString(node.ListValue![i], sb);
WriteString(node.ListValue![i].AsNest(), sb);
if(i != node.ListValue!.Count - 1)
{
sb.Append(", ");
@@ -440,7 +471,7 @@ namespace Tensorflow.Common.Types
foreach (var (key, value) in node.DictValue!)
{
sb.Append($"{key}: ");
WriteString(value, sb);
WriteString(value.AsNest(), sb);
if (i != count - 1)
{
sb.Append(", ");
@@ -454,5 +485,15 @@ namespace Tensorflow.Common.Types
sb.Append("<empty>");
}
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 });
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 });
}
}
}
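
Short example (not part of the commit): the new tuple conversions build a small list nest directly.

INestStructure<int> a = new Nest<int>(1);
INestStructure<int> b = new Nest<int>(2);
INestStructure<int> c = new Nest<int>(3);
Nest<int> pair = (a, b);        // NestType.List with two nodes
Nest<int> triple = (a, b, c);   // NestType.List with three nodes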

src/TensorFlowNET.Core/Common/Types/NestDictionary.cs (+4 -0)

@@ -6,7 +6,11 @@ namespace Tensorflow.Common.Types
{
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull
{
public NestType NestType => NestType.Dictionary;
public IDictionary<TKey, TValue> Value { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;
public NestDictionary(IDictionary<TKey, TValue> dict)
{
Value = dict;


src/TensorFlowNET.Core/Common/Types/NestList.cs (+11 -6)

@@ -10,29 +10,34 @@ namespace Tensorflow.Common.Types
/// <typeparam name="T"></typeparam>
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
{
public List<T> Value { get; set; }
public NestType NestType => NestType.List;
public List<T> Values { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;
public NestList(IEnumerable<T> values)
{
Value = new List<T>(values);
Values = new List<T>(values);
}
public IEnumerable<T> Flatten()
{
return Value;
return Values;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestList<TOut>(Value.Select(x => func(x)));
return new NestList<TOut>(Values.Select(x => func(x)));
}

public Nest<T> AsNest()
{
return new Nest<T>(Value.Select(x => new Nest<T>(x)));
return new Nest<T>(Values.Select(x => new Nest<T>(x)));
}

// Enumerator implementation
public IEnumerator<T> GetEnumerator()
{
return Value.GetEnumerator();
return Values.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()


src/TensorFlowNET.Core/Common/Types/NestNode.cs (+4 -0)

@@ -10,7 +10,11 @@ namespace Tensorflow.Common.Types
/// <typeparam name="T"></typeparam>
public class NestNode<T> : INestStructure<T>
{
public NestType NestType => NestType.Node;
public T Value { get; set; }
public int ShallowNestedCount => 1;

public int TotalNestedCount => 1;
public NestNode(T value)
{
Value = value;


src/TensorFlowNET.Core/Data/DatasetV2.cs (+2 -2)

@@ -161,8 +161,8 @@ namespace Tensorflow
break;
}

yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount)));
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray()));
}
}



src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs (+2 -0)

@@ -359,6 +359,8 @@ namespace Tensorflow.Eager
case TF_AttrType.TF_ATTR_FUNC:
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
else if(value is string str)
c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;


src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs (+13 -0)

@@ -1,4 +1,5 @@
using System.Linq;
using Tensorflow.Eager;

namespace Tensorflow.Framework.Models
{
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models
shapes.Insert(0, dim);
return new TensorSpec(shapes.ToArray(), _dtype);
}

public static TensorSpec FromTensor(Tensor tensor, string? name = null)
{
if(tensor is EagerTensor)
{
return new TensorSpec(tensor.shape, tensor.dtype, name);
}
else
{
return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name);
}
}
}
}
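
A small sketch of `FromTensor` (not part of the commit); for graph tensors the tensor's own name is kept when no name is supplied. Assumes `using static Tensorflow.Binding;` and `using Tensorflow.Framework.Models;`.

var t = tf.constant(new float[,] { { 1f, 2f } });
var spec = TensorSpec.FromTensor(t, name: "input_spec");
// spec describes shape (1, 2) and dtype TF_FLOAT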

src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs (+89 -0)

@@ -0,0 +1,89 @@
using Tensorflow.Graphs;

namespace Tensorflow.Framework
{
internal static class auto_control_deps_utils
{
public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs";
public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph)
{
List<int> result = new List<int>();
// A cache to store the read only resource inputs of an Op.
// Operation -> ObjectIdentitySet of resource handles.
Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs =
new Dictionary<Operation, HashSet<Tensor>>();

for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++)
{
Tensor t = func_graph.Inputs[inputIndex];
if (t.dtype != dtypes.resource)
continue;

bool readOnly = true;
foreach (var op in t.consumers())
{
if (opReadOnlyResourceInputs.ContainsKey(op))
{
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
else
{
List<int> indices = _get_read_only_resource_input_indices_op(op);
opReadOnlyResourceInputs[op] = new HashSet<Tensor>(
indices.Select(i => op.inputs[i]));
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
}

if (readOnly)
result.Add(inputIndex);
}

return result;
}

private static List<int> _get_read_only_resource_input_indices_op(Operation op)
{
// ignore the RESOURCE_READ_OPS

int[] read_only_input_indices;

try
{
read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR);
}
catch (InvalidArgumentError)
{
return new List<int>();
}

int read_only_index = 0;
List<int> result = new();
for (int i = 0; i < op.inputs.Length; i++)
{
if (read_only_index >= read_only_input_indices.Length)
{
break;
}
if (op.inputs[i].dtype != dtypes.resource)
{
continue;
}
if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index])
{
result.Add(i);
read_only_index++;
}
}
return result;
}
}
}

src/TensorFlowNET.Core/Framework/function_def_lib.cs (+2 -2)

@@ -42,10 +42,10 @@ namespace Tensorflow.Framework
func_graph.as_default();
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());

var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
// TODO(Rinne): func_graph.ControlOutputs
_set_handle_data(func_graph, fdef);



src/TensorFlowNET.Core/Functions/ConcreteFunction.cs (+13 -0)

@@ -8,6 +8,7 @@ using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
public IEnumerable<IVariableV1> Variables => func_graph.Variables;
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables;
internal NameAttrList AsNameAttrList
{
get
{
NameAttrList ret = new() { Name = this.Name };
foreach (var (name, value) in _attrs)
{
ret.Attr[name] = value;
}
return ret;
}
}

public ConcreteFunction(string name)
{


src/TensorFlowNET.Core/Graphs/FuncGraph.cs (+2 -2)

@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);
public Dictionary<string, AttrValue> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
internal Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();

public Tensor[] external_captures
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable
var flat_func_args = nest.flatten(func_args as object);
var flat_func_kwargs = nest.flatten(func_kwargs as object);
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
.Where(x => x is Tensor).Select(x => (Tensor)x));
.Where(x => x is Tensor).Select(x => (Tensor)x).ToArray());

//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true);
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true);


src/TensorFlowNET.Core/Graphs/Graph.cs (+1 -1)

@@ -129,7 +129,7 @@ namespace Tensorflow
}
}

protected Graph outer_graph;
internal Graph outer_graph;
public Graph OuterGraph => outer_graph;
public Dictionary<string, EagerDefinedFunction> Functions => _functions;
public SafeGraphHandle c_graph => _handle;


src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs (+9 -3)

@@ -7,13 +7,19 @@ namespace Tensorflow.Keras.Layers.Rnn
{
public interface IRnnCell: ILayer
{
GeneralizedTensorShape StateSize { get; }
GeneralizedTensorShape OutputSize { get; }
bool IsTFRnnCell { get; }
/// <summary>
/// If the derived class does not implement it, this should return null.
/// </summary>
GeneralizedTensorShape? StateSize { get; }
/// <summary>
/// If the derived class does not implement it, this should return null.
/// </summary>
GeneralizedTensorShape? OutputSize { get; }
/// <summary>
/// Whether the optional RNN args are supported when applying the layer.
/// In other words, whether `Apply` is overridden to process `RnnOptionalArgs`.
/// </summary>
bool SupportOptionalArgs { get; }
Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype);
}
}
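
A hypothetical fragment (not part of the commit) showing how a simple cell might report the nullable size properties and `SupportOptionalArgs`; the full `ILayer`/`IRnnCell` surface is omitted so the sketch stays short.

class MyCellFragment
{
    int _units = 16;
    // Return null from either property if the cell does not implement it.
    public GeneralizedTensorShape? StateSize => new GeneralizedTensorShape(new Shape(_units));
    public GeneralizedTensorShape? OutputSize => new GeneralizedTensorShape(new Shape(_units));
    public bool SupportOptionalArgs => false;   // `Apply` is not overridden to process RnnOptionalArgs
}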

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs (+4 -0)

@@ -181,6 +181,10 @@ namespace Tensorflow
{
throw new NotImplementedException();
}
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
throw new NotImplementedException();
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();


src/TensorFlowNET.Core/Operations/OpDefLibrary.cs (+49 -0)

@@ -15,9 +15,11 @@
******************************************************************************/

using Google.Protobuf;
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using static Tensorflow.Binding;
using static Tensorflow.OpDef.Types;

@@ -420,6 +422,12 @@ namespace Tensorflow
case "list(shape)":
attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def)));
break;
case "func":
attr_value.Func = _MakeFunc(value, attr_def.Name);
break;
case "list(func)":
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name));
break;
default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
}
@@ -427,6 +435,47 @@ namespace Tensorflow
return attr_value;
}

private NameAttrList _MakeFunc(object func, string arg_name)
{
if(func is NameAttrList attrList)
{
return attrList;
}
NameAttrList fn_attr;
if(func is string funcStr)
{
fn_attr = new NameAttrList() { Name = funcStr };
}
else if(func is ConcreteFunction concrete)
{
concrete.AddTograph(ops.get_default_graph());
fn_attr = concrete.AsNameAttrList;
}
else if(func is EagerDefinedFunction eager)
{
eager.AddToGraph(ops.get_default_graph());
fn_attr = new NameAttrList() { Name = eager.Name };
}
else
{
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}");
}
return fn_attr;
}

private List<NameAttrList> _MakeFuncList(object funcList, string arg_name)
{
List<NameAttrList> res = new List<NameAttrList>();
if(funcList is IEnumerable enumerable)
{
foreach(var func in enumerable)
{
res.Add(_MakeFunc(func, arg_name));
}
}
return res;
}

private bool _IsListParameter(ArgDef arg)
{
if (!String.IsNullOrEmpty(arg.NumberAttr))


src/TensorFlowNET.Core/Operations/Operation.Output.cs (+1 -1)

@@ -34,7 +34,7 @@ namespace Tensorflow
return num;
}

protected Tensor[] _outputs;
internal Tensor[] _outputs;
public virtual Tensor[] outputs => _outputs;
public Tensor output => _outputs.FirstOrDefault();



src/TensorFlowNET.Core/Operations/Operation.cs (+3 -2)

@@ -46,9 +46,9 @@ namespace Tensorflow
/// </summary>
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
protected IntPtr _handle; // _c_op in python

private readonly Graph _graph;
protected Graph _graph;

internal Func<Operation, object[], Tensor[]> _gradient_function;

@@ -69,6 +69,7 @@ namespace Tensorflow
//private OperationDescription _op_desc;

public NodeDef node_def => GetNodeDef();
protected Operation() { }

public Operation(IntPtr handle, Graph g = null)
{


src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs (+2 -4)

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -38,10 +39,6 @@ namespace Tensorflow.Operations

bool _infer_shape;
public override bool infer_shape => _infer_shape;
public bool _dynamic_size;
public Shape _element_shape;

public List<Tensor> _colocate_with;

Tensor _handle;
public override Tensor handle => _handle;
@@ -56,6 +53,7 @@ namespace Tensorflow.Operations
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_size = size;
_flow = constant_op.constant(0);
_infer_shape = infer_shape;
_element_shape = element_shape ?? Shape.Null;


src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs (+175 -4)

@@ -16,7 +16,9 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using static Tensorflow.Binding;

@@ -33,18 +35,18 @@ namespace Tensorflow.Operations
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;
public bool colocate_with_first_write_call => _colocate_with_first_write_call;
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;

bool _infer_shape;
public bool infer_shape => _infer_shape;
public bool _dynamic_size;
public override bool infer_shape => _infer_shape;
public List<Shape> _element_shape;

public List<Tensor> _colocate_with;

internal Tensor _handle;
public Tensor handle => _handle;
public override Tensor handle => _handle;
internal Tensor _flow;
public override Tensor flow => _flow;

public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
@@ -55,6 +57,7 @@ namespace Tensorflow.Operations
dynamic_size = dynamic_size ?? false;
_dynamic_size = dynamic_size.Value;
_dtype = dtype;
_size = size;

_colocate_with_first_write_call = colocate_with_first_write_call;
if (colocate_with_first_write_call)
@@ -235,4 +238,172 @@ namespace Tensorflow.Operations
return value;
}
}

public class _GraphTensorArrayV2 : TensorArray
{
internal TF_DataType _dtype;
public override TF_DataType dtype => _dtype;

/// <summary>
/// Used to keep track of what tensors the TensorArray should be
/// colocated with. We choose to colocate the TensorArray with the
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;

bool _infer_shape;
public override bool infer_shape => _infer_shape;
public Shape _element_shape;

public List<Tensor> _colocate_with;

internal Tensor _handle;
public override Tensor handle => _handle;
internal Tensor _flow;
public override Tensor flow => _flow;

public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
Debug.Assert(handle is null);
dynamic_size = dynamic_size ?? false;
_dynamic_size = dynamic_size.Value;
_size = size;

if(flow is not null && flow.dtype != dtypes.variant)
{
throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead");
}
if(flow is null && size is null)
{
throw new ValueError("Argument `size` must be provided if argument `flow` is not provided.");
}
if(flow is not null && size is not null)
{
throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time.");
}
if(flow is not null && element_shape is not null)
{
throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time.");
}

_dtype = dtype;

_element_shape = element_shape;
_infer_shape = infer_shape;
tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope =>
{
if (flow is null)
{
_flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name);
}
else
{
_flow = flow;
}
});

_colocate_with_first_write_call = false;
_colocate_with = null;
}

public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray());
return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow);
return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public override Tensor read<T>(T index, string name = null)
{
if(index is Tensor tensor)
{
return read(tensor, name);
}
else
{
throw new TypeError("Please use non-generic method instead.");
}
}

public Tensor read(Tensor index, string name = null)
{
return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope =>
{
return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name);
});
}

public override TensorArray write(Tensor index, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name);

return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor);
}

private Tensor size(string name = null)
{
if(!_dynamic_size && _size is not null)
{
return ops.convert_to_tensor(_size, dtypes.int32);
}
else
{
return gen_list_ops.tensor_list_length(_flow, name);
}
}

public override Tensor stack(string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate
{
int ta_size;
if(!_dynamic_size && (_size is not null))
{
ta_size = (int)tensor_util.constant_value(_size);
}
else
{
ta_size = -1;
}
var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape);
return value;
});
}

public override Tensor gather(Tensor indices, string name = null)
{
return list_ops.tensor_list_gather(_flow, indices, _dtype, _element_shape, name);
}
}
}
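
A minimal sketch based on the constructor and methods shown above (not part of the commit; intended for graph mode, where the array is backed by a variant tensor list). `write` returns a new array carrying the updated flow tensor.

var ta = new _GraphTensorArrayV2(TF_DataType.TF_FLOAT, size: tf.constant(3),
                                 element_shape: new Shape(new long[0]));   // scalar elements
var written = ta.write(tf.constant(0), tf.constant(1f));   // tensor_list_set_item under the hood
Tensor stacked = written.stack();                          // reads the list back as a single tensor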

src/TensorFlowNET.Core/Operations/array_ops.cs (+24 -0)

@@ -119,6 +119,27 @@ namespace Tensorflow
}
}

public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
Tensor shapeTensor;
if(shape.Length > 1)
{
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32);
if(shapeTensor.ndim > 1)
{
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1));
}
}
else
{
shapeTensor = shape[0];
}
var output = fill(shapeTensor, array_ops.constant(0, dtype), name);
Debug.Assert(output.dtype.as_base_dtype() == dtype);
return output;
}

public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
{
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate
@@ -307,6 +328,9 @@ namespace Tensorflow
public static Tensor fill<T>(Shape dims, T value, string name = null)
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name);

public static Tensor fill<T>(Tensor dims, T value, string name = null)
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name);

/// <summary>
/// Returns the rank of a tensor.
/// </summary>
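
Usage sketch of the new dynamic-shape `zeros` overload above (not part of the commit): the shape is supplied as tensors, which is what the RNN time loop needs, and the `ToTensors` extension added in this commit builds the `Tensors` argument.

var batch = tf.constant(8);
var units = tf.constant(16);
Tensors shape = new[] { batch, units }.ToTensors();
Tensor z = array_ops.zeros(shape, tf.float32);   // shape is only known at run time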


src/TensorFlowNET.Core/Operations/control_flow_ops.cs (+5 -4)

@@ -675,16 +675,17 @@ namespace Tensorflow
}
}

public static Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public static Tensors while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
{
var executing_eagerly = tf.Context.executing_eagerly();
if (!executing_eagerly)
{
throw new NotImplementedException("");
return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations,
name: name);
}

return tf_with(ops.name_scope("name", "while"), delegate


src/TensorFlowNET.Core/Operations/control_flow_util.py.cs (+77 -0)

@@ -16,12 +16,20 @@

using System;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class control_flow_util
{
public static readonly bool ENABLE_CONTROL_FLOW_V2 = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0" ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0");
/// <summary>
/// Return true if `op` is an Exit.
/// </summary>
@@ -196,5 +204,74 @@ namespace Tensorflow
}
return null;
}

public static bool EnableControlFlowV2(Graph graph)
{
return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0);
}

public static string create_new_tf_function(FuncGraph func_graph)
{
var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary<string, AttrValue>());
func.AddToGraph(func_graph);
return func_graph.Name;
}

public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs)
{
if(inputs.Length == 0)
{
return (null, new Tensor[0]);
}
else
{
return (inputs[0], inputs);
}
}

public static Tensor[] run_as_function_for_tape_gradients(Func<Tensor[], Tensor[]> make_op, Tensor[] inputs)
{
if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
&& !(ops.get_default_graph().building_function))
{
throw new NotImplementedException();
}
else
{
return make_op(inputs);
}
}

public static string unique_fn_name(string scope, string name)
{
return $"{scope}{name}_{ops.uid()}".Replace("/", "_");
}

public static bool output_all_intermediates()
{
if (in_defun())
{
return false;
}
if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR")
{
return false;
}
// TODO(Rinne): check this after refactoring keras building.
return false;
}

public static bool in_defun()
{
if (tf.Context.executing_eagerly())
{
return false;
}

var graph = ops.get_default_graph();
// TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph
return graph is FuncGraph;
}
}
}

src/TensorFlowNET.Core/Operations/gen_functional_ops.cs (+985 -81)
File diff suppressed because it is too large


src/TensorFlowNET.Core/Operations/gen_list_ops.cs (+1227 -0)
File diff suppressed because it is too large


src/TensorFlowNET.Core/Operations/list_ops.cs (+111 -0)

@@ -0,0 +1,111 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow.Operations
{
internal class list_ops
{
private static void _set_handle_data(Tensor list_handle, Shape element_shape, TF_DataType element_dtype)
{
if(list_handle is EagerTensor eagerTensor)
{
var handle_data = new CppShapeInferenceResult.Types.HandleData();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType()
{
Shape = element_shape.as_proto(),
Dtype = element_dtype.as_datatype_enum(),
Type = new FullTypeDef() { TypeId = FullTypeId.TftArray }
});
list_handle.HandleData = handle_data;
}
}

private static Tensor _build_element_shape(Shape? shape)
{
if(shape is null || shape.IsNull)
{
return ops.convert_to_tensor(-1);
}
else
{
return ops.convert_to_tensor(shape);
}
}

public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null)
{
var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name);
_set_handle_data(result, shape, element_dtype);
return result;
}

public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null)
{
var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name);
_set_handle_data(result, tensor.shape, tensor.dtype);
return result;
}

public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape),
element_dtype, name);
}

public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item,
bool resize_if_index_out_of_bounds = false, string? name = null)
{
if (resize_if_index_out_of_bounds)
{
var input_list_size = gen_list_ops.tensor_list_length(input_handle);
input_handle = control_flow_ops.cond(index >= input_list_size,
() => gen_list_ops.tensor_list_resize(input_handle, index + 1),
() => input_handle);
}
var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name);
handle_data_util.copy_handle_data(input_handle, output_handle);
return output_handle;
}

public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name);
}

public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name);
}

public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null,
string? name = null)
{
if(input_handle is not null)
{
var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name);
handle_data_util.copy_handle_data(input_handle, output_handle);
return output_handle;
}
else
{
var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape),
constant_op.constant(-1), name);
_set_handle_data(output_handle, element_shape, tensor.dtype);
return output_handle;
}
}

public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1,
string? name = null)
{
return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype,
max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name);
}
}
}

src/TensorFlowNET.Core/Operations/tensor_array_ops.cs (+16 -4)

@@ -13,11 +13,23 @@ namespace Tensorflow
/// <returns></returns>
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
{
var new_ta = tf.TensorArray(
dtype: old_ta.dtype,
infer_shape: old_ta.infer_shape,
if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
{
throw new NotImplementedException("Attempting to build a graph-mode TF2-style "
+ "TensorArray from either an eager-mode "
+ "TensorArray or a TF1-style TensorArray. "
+ "This is not currently supported. You may be "
+ "attempting to capture a TensorArray "
+ "inside a tf.function or tf.data map function. "
+ "Instead, construct a new TensorArray inside "
+ "the function.");
}
var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape,
colocate_with_first_write_call: old_ta.colocate_with_first_write_call);

new_ta._dynamic_size = old_ta._dynamic_size;
new_ta._size = old_ta._size;
new_ta._colocate_with = old_ta._colocate_with;
new_ta._element_shape = old_ta._element_shape;
return new_ta;
}



src/TensorFlowNET.Core/Operations/while_v2.cs (+401 -0)

@@ -0,0 +1,401 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
class _OperationWithOutputs : Operation
{
public _OperationWithOutputs(IntPtr handle, Graph g = null)
{
_handle = handle;
_graph = g;
_outputs = null;
g._add_op(this);
}
}
internal class while_v2
{
public static Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int maximum_iterations = -1,
int parallel_iterations = 10,
string name = null,
bool back_prop = true,
bool return_same_structure = true)
{
var orig_loop_vars = loop_vars;
var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray();
int len_orig_loop_vars = orig_loop_vars.Length;

loop_vars = _tensor_array_to_flow(loop_vars);
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors();

var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars));

var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray();

if(string.IsNullOrEmpty(name))
{
name = "while";
}

return tf_with<ITensorFlowObject, Tensor[]>(ops.name_scope(name), nameScopeWhile =>
{
string scope = (nameScopeWhile as ops.NameScope).scope_name;
string cond_name = control_flow_util.unique_fn_name(scope, "cond");
string body_name = control_flow_util.unique_fn_name(scope, "body");

var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations);
var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype,
name: "loop_counter");
loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray();

var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)}
.Concat(loop_vars_signature.Flatten()).ToArray();

// TODO(Rinne): possible wrong implementation here.
var add_control_dependencies = false;

object[] wrapped_cond(object[] inputs)
{
Tensor loop_counter = (Tensor)inputs[0];
Tensor maximum_iterations_arg = (Tensor)inputs[1];
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();
var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
if(pred.shape.IsNull || pred.shape.ndim > 0)
{
pred = array_ops.squeeze(pred);
}
if(maximum_iterations == -1)
{
return new object[] { pred };
}
else
{
return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) };
}
}

var cond_graph = FuncGraph.func_graph_from_func("cond", wrapped_cond, null,
null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies);

bool stateful_parallelism = false;

object[] wrapped_body(object[] inputs)
{
Tensor loop_counter = (Tensor)inputs[0];
Tensor maximum_iterations_arg = (Tensor)inputs[1];
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();

_copy_handle_data(loop_vars.Flatten().Skip(2), args);

foreach(var t in cond_graph.external_captures)
{
var graph = (FuncGraph)(ops.get_default_graph());
graph.capture(t);
}

var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
outputs = _tensor_array_to_flow(outputs);

return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray();
}

var body_graph = FuncGraph.func_graph_from_func("body", wrapped_body, null, null, func_graph_signature,
add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism);

// TODO(Rinne): possible wrong implementation here.
NestList<Tensors> loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() });
body_graph.Outputs.AddRange(body_graph.internal_captures);
cond_graph.as_default();
int num_cond_captures = cond_graph.external_captures.Length;
Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray()));
_duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray());
cond_graph.Exit();

int first_loop_var_index = 2;

int num_flattened_oututs = orig_loop_vars.Length;
int num_original_outputs = body_graph.Outputs.Length;
if (back_prop && control_flow_util.output_all_intermediates())
{
var intermediate_tensors = _get_intermediates(body_graph);

foreach(var intermediate_tensor in intermediate_tensors)
{
var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations);
loop_vars_list.Values.Add(tensor_list);

cond_graph.as_default();
cond_graph.capture(tensor_list);
cond_graph.Exit();

body_graph.as_default();
var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor);
body_graph.Outputs.Add(appended_tensor_list);
body_graph.Exit();
}
}

List<Tensor> flattened_loop_vars = new();
foreach(var item in loop_vars_list.Values)
{
flattened_loop_vars.AddRange(item.Flatten());
}
// skip the check

// TODO(Rinne): deal with control dependencies
var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray();
var span = new Span<Shape>(output_shapes).Slice(first_loop_var_index, num_flattened_oututs);
for(int i = 0; i < span.Length; i++)
{
span[i] = flat_shape_invariants[i];
}

Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations,
(nameScopeWhile as ops.NameScope).scope_name, num_original_outputs, stateful_parallelism);

if (!ops.get_default_graph().building_function)
{
outputs = outputs.Select(t => array_ops.identity(t)).ToArray();
}

var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_outputs).ToArray();

if (!back_prop)
{
output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray();
}
outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars);

return outputs;
});
}

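/// <summary>
/// Substitutes the flow tensor for any TensorArray (wrapped as <see cref="FakeTensorByTensorArray"/>)
/// found at the node or first list level of the loop variables, so the structure can be threaded
/// through the while loop as plain tensors.
/// </summary>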
private static Tensors _tensor_array_to_flow(Tensors loop_vars)
{
if(loop_vars.NestType == NestType.Node)
{
if(loop_vars.NodeValue is FakeTensorByTensorArray fake)
{
return new Tensors(fake.TensorArray.flow);
}
else
{
return new Tensors(loop_vars.NodeValue!);
}
}
else if(loop_vars.NestType == NestType.List)
{
List<INestStructure<Tensor>> list = new();
foreach(var item in loop_vars.ListValue!)
{
if(item.NestType == NestType.Node)
{
var nested = item.AsNest();
if (nested.NodeValue is FakeTensorByTensorArray fake)
{
list.Add(new Nest<Tensor>(fake.TensorArray.flow));
}
else
{
list.Add(new Nest<Tensor>(nested.NodeValue!));
}
}
else
{
list.Add(new Nest<Tensor>(item.AsNest()));
}
}
return Tensors.FromNest(new Nest<Tensor>(list));
}
else
{
throw new NotImplementedException();
}
}

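/// <summary>
/// Builds the `While` (or `StatelessWhile`) op from the cond and body graphs, copies handle data
/// to its outputs and records the read-only resource inputs and original output count as attributes.
/// </summary>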
private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph,
Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism)
{
var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op);
var body_stateful_ops = body_graph.get_operations().Select(x => x.op);

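// NOTE: no per-op statefulness check is performed here; the loop is treated as stateful
// whenever either branch graph contains any operation at all.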
bool is_stateful = cond_stateful_ops.Any() || body_stateful_ops.Any();

Tensor[] _make_op(Tensor[] inputs)
{
Tensor[] outputs;
if (is_stateful)
{
outputs = gen_functional_ops._while(
inputs,
control_flow_util.create_new_tf_function(cond_graph),
control_flow_util.create_new_tf_function(body_graph),
output_shapes,
parallel_iterations,
name
);
}
else
{
outputs = gen_functional_ops.stateless_while(
inputs,
control_flow_util.create_new_tf_function(cond_graph),
control_flow_util.create_new_tf_function(body_graph),
output_shapes,
parallel_iterations,
name
);
}
var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs);
_copy_handle_data(body_graph.Outputs, tensors);
_set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph});
while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs });
while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism });

cond_graph.outer_graph = ops.get_default_graph();
body_graph.outer_graph = ops.get_default_graph();
// TODO(Rinne): set the two graphs to while_op
return tensors;
}

return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars);
}

/// <summary>
/// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies.
/// </summary>
/// <param name="op"></param>
/// <param name="branch_graphs"></param>
private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs)
{
List<int> read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList();
foreach(var branch_graph in branch_graphs)
{
if (read_only_indices.Count == 0)
{
break;
}
var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph);
read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList();
}
AttrValue.Types.ListValue listValue = new();
listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x));
op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue()
{
List = listValue
});
}

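/// <summary>
/// Packs the flat loop variables back into the structure described by the signature, rebuilding
/// TensorArrays from their flow tensors wherever the original flat value was a
/// <see cref="FakeTensorByTensorArray"/>.
/// </summary>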
private static Tensors _pack_sequence_as<T>(INestStructure<T> loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars)
{
var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item =>
{
var (flow, y) = item;
if (y is FakeTensorByTensorArray ta)
{
return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow));
}
else
{
return flow;
}
}).ToArray();
return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors();
}

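/// <summary>
/// Collects the intermediate tensors of a function graph: op outputs that are not the first graph
/// input, not of resource dtype and not already captured. Identity and MutexLock ops are skipped.
/// </summary>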
private static Tensor[] _get_intermediates(FuncGraph func_graph)
{
List<Tensor> intermediates = new();
var reversed_captures = func_graph.captures.ToDictionary(x => x.Item2, x => x.Item1);

foreach(var op in func_graph.get_operations())
{
Debug.Assert(op is Operation);
var oper = (Operation)op;
if(oper.type == "Identity" || oper.type == "MutexLock")
{
continue;
}
foreach(var o in op.outputs)
{
if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !reversed_captures.ContainsKey(o))
{
intermediates.Add(o);
}
}
}
return intermediates.ToArray();
}

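/// <summary>
/// Creates placeholders in the cond graph for the captures that exist only in the body graph, so
/// that both branch graphs expose the same set of inputs to the while op.
/// </summary>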
private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures)
{
var types = body_graph_captures.Select(t => t.dtype).ToList();
var c_graph = cond_graph.c_graph;
var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList();

var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList();

List<Tensor> tensors = new();
foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types))
{
var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph);
op._outputs = new Tensor[] { tensor };
tensors.Add(tensor);
}

var tuples = zip(body_graph_captures, tensors).ToList();
var keys = body_graph_captures.Select(t => t.Id).ToList();
cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2));
cond_graph.Inputs.AddRange(tensors);
}

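/// <summary>
/// Creates a raw Placeholder op of the given dtype directly through the C API and returns its
/// first output.
/// </summary>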
private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype);
var op = c_api.TF_FinishOperation(desc, tf.Status);
tf.Status.Check(true);
var output = new TF_Output();
output.oper = op;
output.index = 0;
return output;
}

private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph)
{
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder");
}

private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype,
string name)
{
return ops.convert_to_tensor(value, dtype, name, false);
}

private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1)
{
return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations");
}

private static void _copy_handle_data(IEnumerable<Tensor> src_tensors, IEnumerable<Tensor> dst_tensors)
{
foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors))
{
handle_data_util.copy_handle_data(src_t, dst_t);
}
}
}
}
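
For reference, the rebuilt while loop is exercised through the public tf.while_loop API. The
snippet below is a minimal sketch mirroring the updated ControlFlowApiTest further down in this
diff; the variable names and bounds are illustrative only, not part of the commit.

var i = tf.constant(0);
var n = tf.constant(10);
Func<Tensors, Tensor> cond = x => tf.less(x[0], n);
Func<Tensors, Tensors> body = x => new[] { tf.add(x[0], 1) };
var r = tf.while_loop(cond, body, new[] { i });
// r[0] evaluates to 10 once the loop has finished.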

+ 7
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -105,6 +105,13 @@ namespace Tensorflow
_id = ops.uid();
}

internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output)
{
Tensor ret = new Tensor(op, value_index, dtype);
ret._tf_output = tf_output;
return ret;
}

protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);


+ 24
- 0
src/TensorFlowNET.Core/Tensors/TensorArray.cs View File

@@ -14,7 +14,9 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Common.Types;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -44,5 +46,27 @@ namespace Tensorflow

public abstract Tensor stack(string name = null);
public abstract Tensor gather(Tensor indices, string name = null);

internal bool _dynamic_size;
internal Tensor _size;
internal List<Tensor> _colocate_with;
internal Shape _element_shape;

public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant))
{
return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
infer_shape, element_shape, colocate_with_first_write_call, name);
}
else
{
return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
infer_shape, element_shape, colocate_with_first_write_call, name);
}
}
}
}
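
A minimal usage sketch of the new TensorArray.Create factory (the calls mirror how BackendImpl.cs
uses it later in this diff; the concrete shapes are illustrative assumptions):

var inputs = tf.zeros(new Shape(4, 8));          // (time_steps, features)
var steps = tf.shape(inputs)[0];                 // dynamic time dimension
var ta = TensorArray.Create(dtype: inputs.dtype, size: steps);
ta = ta.unstack(inputs);                         // one slot per time step
var first = ta.read(tf.constant(0));             // tensor of shape (8)
var restacked = ta.stack();                      // back to shape (4, 8)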

+ 41
- 13
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -4,6 +4,8 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Operations;
using Tensorflow.Common.Extensions;

namespace Tensorflow
{
@@ -58,7 +60,7 @@ namespace Tensorflow
public Tensor this[params string[] slices]
=> this.First()[slices];

private Tensors(Nest<Tensor> nested) : base(nested)
internal Tensors(Nest<Tensor> nested) : base(nested)
{

}
@@ -68,9 +70,9 @@ namespace Tensorflow
}

public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
public Tensors(IList<Tensor> tensors) : base(tensors.Select(x => new Nest<Tensor>(x)))
{
}

public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
@@ -78,6 +80,32 @@ namespace Tensorflow
}

/// <summary>
/// Gets the element at the shallow (top) level. For example, for ts = [1, [2, 3], 4],
/// the common (flattened) indexer gives ts[1] = 2, while the shallow indexer gives ts[1] = [2, 3].
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public Tensors GetShallow(int index)
{
if(NestType == NestType.Node)
{
if(index > 0)
{
throw new IndexOutOfRangeException();
}
return this;
}
else if(NestType == NestType.List)
{
return ListValue![index].AsNest().ToTensors();
}
else
{
throw new NotImplementedException();
}
}

private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
{
if (tensors.Length == 0)
@@ -115,8 +143,8 @@ namespace Tensorflow
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
Value = null;
ListValue = new() { new Nest<Tensor>(NodeValue), new Nest<Tensor>(tensor) };
NodeValue = null;
}
else if(NestType == NestType.List)
{
@@ -125,7 +153,7 @@ namespace Tensorflow
else //Empty
{
NestType = NestType.Node;
Value = tensor;
NodeValue = tensor;
}
}

@@ -140,9 +168,9 @@ namespace Tensorflow
else if (NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue = new() { new Nest<Tensor>(NodeValue) };
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
Value = null;
NodeValue = null;
}
else if(NestType == NestType.List)
{
@@ -151,7 +179,7 @@ namespace Tensorflow
else // empty
{
NestType = NestType.List;
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
ListValue = tensors.Select(x => new Nest<Tensor>(x) as INestStructure<Tensor>).ToList();
}
}

@@ -166,9 +194,9 @@ namespace Tensorflow
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue = new() { new Nest<Tensor>(NodeValue) };
ListValue.Insert(index, new Nest<Tensor>(tensor));
Value = null;
NodeValue = null;
}
else
{
@@ -283,7 +311,7 @@ namespace Tensorflow
=> tensors?.SingleOrNull;

public static implicit operator Tensor[](Tensors tensors)
=> tensors.Flatten().ToArray();
=> tensors.Flatten().ToArray();
#endregion

public static Tensors? FromNest(Nest<Tensor> nested)
@@ -298,7 +326,7 @@ namespace Tensorflow
public void Deconstruct(out Tensor a, out Tensors? b)
{
a = this.First();
b = Length == 1? null : new Tensors(this.Skip(1));
b = Length == 1? null : new Tensors(this.Skip(1).ToArray());
}

public override string ToString()

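A minimal sketch of the new GetShallow indexer versus the default (flattening) indexer; the nest
is built through Nest/Tensors.FromNest, which this commit already uses internally, and the exact
constructor overloads and values are assumptions for illustration only.

var a = tf.constant(1); var b = tf.constant(2); var c = tf.constant(3);
var inner = new List<INestStructure<Tensor>> { new Nest<Tensor>(b), new Nest<Tensor>(c) };
var outer = new List<INestStructure<Tensor>> { new Nest<Tensor>(a), new Nest<Tensor>(inner) };
var ts = Tensors.FromNest(new Nest<Tensor>(outer));   // ts ~ [1, [2, 3]]
var flat = ts[1];                // default indexer flattens: the tensor holding 2
var shallow = ts.GetShallow(1);  // shallow indexer keeps the group: [2, 3]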

+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -576,7 +576,7 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data));
}

public static void dismantle_graph(Graph graph)


+ 48
- 47
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;
using Tensorflow.Util;
using Tensorflow.Common.Types;
using System.Diagnostics;

namespace Tensorflow.Keras
{
@@ -485,7 +486,7 @@ namespace Tensorflow.Keras
var first_flatted_input = flatted_inptus[0];
var time_steps = first_flatted_input.shape[0];
var batch = first_flatted_input.shape[1];
var time_steps_t = (int)first_flatted_input.shape[0];
var time_steps_t = tf.shape(first_flatted_input)[0];

foreach (var input_ in flatted_inptus)
{
@@ -704,7 +705,7 @@ namespace Tensorflow.Keras
var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t));
input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t));
}

foreach(var (ta, input_) in zip(input_ta, flatted_inptus))
@@ -730,18 +731,15 @@ namespace Tensorflow.Keras
(output_time_zero, _) = step_function(input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

int output_ta_size = return_all_outputs ? time_steps_t : 1;
Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1);
var output_ta = new List<TensorArray>();
for (int i = 0; i < output_time_zero.ToList().Count; i++)
foreach(var output in output_time_zero.Flatten())
{
var Out = output_time_zero.ToList()[i];
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape));
}

var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



Func<Tensor, Tensor>? masking_fn;
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
if (mask != null)
@@ -750,7 +748,7 @@ namespace Tensorflow.Keras
{
mask = tf.reverse(mask, axis: new[] { 0 });
}
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
mask_ta = mask_ta.unstack(mask);

masking_fn = (time) =>
@@ -810,9 +808,9 @@ namespace Tensorflow.Keras
masking_fn = null;
}

Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
Func<Tensors, Tensor> cond = (time) => (time[0] < time_steps_t);
int parallel_iterations = 32;
new_states = states;
Tensors final_outputs;
if (masking_fn != null)
{
// Mask for the T output will be based on the output of T - 1. In the
@@ -825,7 +823,7 @@ namespace Tensorflow.Keras

var prev_output = flat_zero_output;
var output_ta_t = output_ta;
Tensor _step(Tensor time)
Tensors _step(Tensors tensors)
{
/*
RNN step function.
@@ -838,23 +836,28 @@ namespace Tensorflow.Keras
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
*/

Tensor time = tensors[0];
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray;
Tensors prev_output = tensors.GetShallow(2);
Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray());

var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001), deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var mask_t = masking_fn(time);
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
var (output, new_states) = step_function(current_input, states.MergeWith(constants));
// mask output
var flat_output = Nest.Flatten(output).ToList();

var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList();

// TODO(Wanglongzhi2001), deal with compute_masked_output's third parameter's type
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// mask states
var flat_state = states.ToList();
var flat_new_state = new_states_internal.ToList();
var flat_state = states.Flatten().ToList();
var flat_new_state = new_states.Flatten().ToList();

foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
@@ -865,38 +868,37 @@ namespace Tensorflow.Keras
}

var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();
new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors();

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
{
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();
Debug.Assert(flat_output.Count() == 1);
output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First());


new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();

output_ta = output_ta_t;
new_states = new_states_internal;
return time + 1;
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states)
.ToArray().ToTensors();

}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }
.Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors();
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations);
new_states = final_outputs.Skip(3).ToList();
}
else
{
var output_ta_t = output_ta;
new_states = states;
Tensor _step(Tensor time)
Tensors _step(Tensors tensors)
{
Tensor time = tensors[0];
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray;
Tensors states = new Tensors(tensors.Skip(2).ToArray());
var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001), deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
var (output, new_states) = step_function(current_input, states.MergeWith(constants));
var flat_state = new_states.Flatten().ToList();
var flat_new_state = new_states_internal.Flatten().ToList();
var flat_new_state = new_states.Flatten().ToList();
foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
@@ -906,24 +908,23 @@ namespace Tensorflow.Keras
}
var flat_output = Nest.Flatten(output);
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
output_ta_t = zip(output_ta_t, flat_output).Select(item =>
{
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();

new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
output_ta = output_ta_t;
new_states = new_states_internal;
return time + 1;
Debug.Assert(flat_output.Count() == 1);
output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First());

new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors();
}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
Debug.Assert(output_ta.Count == 1);
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors();
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations);
new_states = final_outputs.Skip(2).ToList();
}
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();

output_ta = new List<TensorArray> { (final_outputs[1] as FakeTensorByTensorArray).TensorArray };
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors();
}

Func<Tensor, Tensor> set_shape;


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Build.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
graph.as_default();
var shapes = input_shape.ToShapeArray();
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)));
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray());
try
{
Call(x, training: false);


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Engine
{
var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(x),
X = new Tensors(x.ToArray()),
Y = y,
Model = this,
StepsPerExecution = _steps_per_execution
@@ -188,7 +188,7 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Engine

var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(train_x),
X = new Tensors(train_x.ToArray()),
Y = train_y,
BatchSize = batch_size,
InitialEpoch = initial_epoch,


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Train.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}


+ 10
- 1
src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs View File

@@ -4,10 +4,11 @@ using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
public abstract class DropoutRNNCellMixin: RnnCellBase
public abstract class DropoutRNNCellMixin: Layer, IRnnCell
{
public float dropout;
public float recurrent_dropout;
@@ -17,6 +18,14 @@ namespace Tensorflow.Keras.Layers.Rnn

}

public abstract GeneralizedTensorShape StateSize { get; }
public abstract GeneralizedTensorShape OutputSize { get; }
public abstract bool SupportOptionalArgs { get; }
public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
}

protected void _create_non_trackable_mask_cache()
{


+ 8
- 31
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -206,7 +206,6 @@ namespace Tensorflow.Keras.Layers.Rnn
// append batch dim
state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
return new InputSpec(shape: state_spec_shape);

}

// Check whether the input shape contains any nested shapes. It could be
@@ -298,7 +297,7 @@ namespace Tensorflow.Keras.Layers.Rnn

// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
Func<Tensors, Tensors, (Tensors, Tensors)> step;
bool is_tf_rnn_cell = _cell.IsTFRnnCell;
bool is_tf_rnn_cell = false;
if (constants is not null)
{
if (!_cell.SupportOptionalArgs)
@@ -310,8 +309,8 @@ namespace Tensorflow.Keras.Layers.Rnn

step = (inputs, states) =>
{
constants = new Tensors(states.TakeLast(_num_constants));
states = new Tensors(states.SkipLast(_num_constants));
constants = new Tensors(states.TakeLast(_num_constants).ToArray());
states = new Tensors(states.SkipLast(_num_constants).ToArray());
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
return (output, new_states.Single);
@@ -395,12 +394,12 @@ namespace Tensorflow.Keras.Layers.Rnn
{
if (_num_constants != 0)
{
initial_state = new Tensors(inputs.Skip(1));
initial_state = new Tensors(inputs.Skip(1).ToArray());
}
else
{
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants));
constants = new Tensors(inputs.TakeLast(_num_constants));
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray());
constants = new Tensors(inputs.TakeLast(_num_constants).ToArray());
}
if (len(initial_state) == 0)
initial_state = null;
@@ -558,36 +557,14 @@ namespace Tensorflow.Keras.Layers.Rnn

protected Tensors get_initial_state(Tensors inputs)
{
var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");

var input = inputs[0];
var input_shape = inputs.shape;
var input_shape = array_ops.shape(inputs);
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;

Tensors init_state = new Tensors();

if(get_initial_state_fn != null)
{
init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
}
//if (_cell is RnnCellBase rnn_base_cell)
//{
// init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
//}
else
{
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);
}
Tensors init_state = _cell.GetInitialState(null, batch_size, dtype);

return init_state;
}

// Check whether the state_size contains multiple states.
public static bool is_multiple_state(GeneralizedTensorShape state_size)
{
return state_size.Shapes.Length > 1;
}
}
}

+ 0
- 24
src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs View File

@@ -1,24 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
public abstract class RnnCellBase: Layer, IRnnCell
{
public RnnCellBase(LayerArgs args) : base(args) { }
public abstract GeneralizedTensorShape StateSize { get; }
public abstract GeneralizedTensorShape OutputSize { get; }
public abstract bool IsTFRnnCell { get; }
public abstract bool SupportOptionalArgs { get; }
public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
}
}
}

+ 3
- 4
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -7,6 +7,7 @@ using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Utils;
using Tensorflow.Graphs;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -28,7 +29,6 @@ namespace Tensorflow.Keras.Layers.Rnn

public override GeneralizedTensorShape StateSize => _state_size;
public override GeneralizedTensorShape OutputSize => _output_size;
public override bool IsTFRnnCell => true;
public override bool SupportOptionalArgs => false;

public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
@@ -98,7 +98,6 @@ namespace Tensorflow.Keras.Layers.Rnn
{
prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}
var tmp = _recurrent_kernel.AsTensor();
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());

if (_args.Activation != null)
@@ -117,9 +116,9 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
public Tensors get_initial_state(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
}
}
}

+ 53
- 99
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers.Rnn
public class StackedRNNCells : Layer, IRnnCell
{
public IList<IRnnCell> Cells { get; set; }
public bool reverse_state_order;
public bool _reverse_state_order;

public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
{
@@ -23,22 +23,11 @@ namespace Tensorflow.Keras.Layers.Rnn
{
args.Kwargs = new Dictionary<string, object>();
}
foreach (var cell in args.Cells)
{
//Type type = cell.GetType();
//var CallMethodInfo = type.GetMethod("Call");
//if (CallMethodInfo == null)
//{
// throw new ValueError(
// "All cells must have a `Call` method. " +
// $"Received cell without a `Call` method: {cell}");
//}
}
Cells = args.Cells;
reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false);
_reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false);

if (reverse_state_order)
if (_reverse_state_order)
{
throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " +
"be deprecated. Please update the code to work with the " +
@@ -47,49 +36,37 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

public bool SupportOptionalArgs => false;

public GeneralizedTensorShape StateSize
{
get
{
GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count);
if (reverse_state_order && Cells.Count > 0)
if (_reverse_state_order)
{
var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell));
foreach (var cell in idxAndCell)
{
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
}
var state_sizes = Cells.Reverse().Select(cell => cell.StateSize);
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s))));
}
else
{
//foreach (var cell in Cells)
//{
// state_size.Shapes.add(cell.StateSize.Shapes.First());

//}
var idxAndCell = Cells.Select((cell, idx) => (idx, cell));
foreach (var cell in idxAndCell)
{
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
}
var state_sizes = Cells.Select(cell => cell.StateSize);
return new GeneralizedTensorShape(new Nest<Shape>(state_sizes.Select(s => new Nest<Shape>(s))));
}
return state_size;
}
}

public object output_size
public GeneralizedTensorShape OutputSize
{
get
{
var lastCell = Cells.LastOrDefault();
if (lastCell.OutputSize.ToSingleShape() != -1)
var lastCell = Cells.Last();
if(lastCell.OutputSize is not null)
{
return lastCell.OutputSize;
}
else if (RNN.is_multiple_state(lastCell.StateSize))
else if (RnnUtils.is_multiple_state(lastCell.StateSize))
{
return lastCell.StateSize.First();
//throw new NotImplementedException("");
}
else
{
@@ -98,79 +75,65 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
var cells = reverse_state_order ? Cells.Reverse() : Cells;
Tensors initial_states = new Tensors();
var cells = _reverse_state_order ? Cells.Reverse() : Cells;
List<Tensor> initial_states = new List<Tensor>();
foreach (var cell in cells)
{
var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state");
if (get_initial_state_fn != null)
{
var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype });
initial_states.Add(result);
}
else
{
initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value));
}
initial_states.Add(cell.GetInitialState(inputs, batch_size, dtype));
}
return initial_states;
return new Tensors(initial_states);
}

protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
// Recover per-cell states.
var state_size = reverse_state_order ? StateSize.Reverse() : StateSize;
var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten();
var state_size = _reverse_state_order ? new GeneralizedTensorShape(StateSize.Reverse()) : StateSize;
var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray());


var new_nest_states = new Tensors();
var new_nest_states = Nest<Tensor>.Empty;
// Call the cells in order and store the returned states.
foreach (var (cell, states) in zip(Cells, nested_states))
foreach (var (cell, internal_states) in zip(Cells, nested_states))
{
// states = states if tf.nest.is_nested(states) else [states]
var type = cell.GetType();
bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null;
state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state;

RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
Tensors? constants = rnn_optional_args?.Constants;

Tensors new_states;
(inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
(inputs, new_states) = cell.Apply(inputs, internal_states, optional_args: new RnnOptionalArgs() { Constants = constants });

new_nest_states.Add(new_states);
new_nest_states = new_nest_states.MergeWith(new_states);
}
new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray();
return new Nest<Tensor>(new List<Nest<Tensor>> {
new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(inputs.Single()) }), new Nest<Tensor>(new_nest_states) })
.ToTensors();
return Tensors.FromNest((inputs, Nest.PackSequenceAs(state_size, Nest.Flatten(new_nest_states).ToArray())));
}

public void build()
public override void build(KerasShapesWrapper input_shape)
{
built = true;
// @tf_utils.shape_type_conversion
// def build(self, input_shape) :
// if isinstance(input_shape, list) :
// input_shape = input_shape[0]
// for cell in self.cells:
// if isinstance(cell, Layer) and not cell.built:
// with K.name_scope(cell.name):
// cell.build(input_shape)
// cell.built = True
// if getattr(cell, 'output_size', None) is not None:
// output_dim = cell.output_size
// elif _is_multiple_state(cell.state_size) :
// output_dim = cell.state_size[0]
// else:
// output_dim = cell.state_size
// input_shape = tuple([input_shape[0]] +
// tensor_shape.TensorShape(output_dim).as_list())
// self.built = True
var shape = input_shape.ToSingleShape();
foreach(var cell in Cells)
{
if(cell is Layer layer && !layer.Built)
{
// ignored the name scope.
layer.build(shape);
layer.Built = true;
}
GeneralizedTensorShape output_dim;
if(cell.OutputSize is not null)
{
output_dim = cell.OutputSize;
}
else if (RnnUtils.is_multiple_state(cell.StateSize))
{
output_dim = cell.StateSize.First();
}
else
{
output_dim = cell.StateSize;
}
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.ToSingleShape().dims).ToArray());
}
this.Built = true;
}

public override IKerasConfig get_config()
@@ -198,14 +161,5 @@ namespace Tensorflow.Keras.Layers.Rnn
// deserialize_layer(cell_config, custom_objects = custom_objects))
// return cls(cells, **config)
}

public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
{
throw new NotImplementedException();
}

public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => true;
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 23
- 12
src/TensorFlowNET.Keras/Utils/RnnUtils.cs View File

@@ -10,20 +10,21 @@ namespace Tensorflow.Keras.Utils
{
internal static class RnnUtils
{
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
{
Func<GeneralizedTensorShape, Tensor> create_zeros;
create_zeros = (GeneralizedTensorShape unnested_state_size) =>
{
var flat_dims = unnested_state_size.ToSingleShape().dims;
var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray();
return array_ops.zeros(new Shape(init_state_size), dtype: dtype);
var init_state_size = new Tensor[] { batch_size_tensor }.
Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray();
return array_ops.zeros(init_state_size, dtype: dtype);
};

// TODO(Rinne): map structure with nested tensors.
if(state_size.Shapes.Length > 1)
if(state_size.TotalNestedCount > 1)
{
return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s))));
return new Tensors(state_size.Flatten().Select(s => create_zeros(new GeneralizedTensorShape(s))).ToArray());
}
else
{
@@ -32,11 +33,11 @@ namespace Tensorflow.Keras.Utils

}

internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype)
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype)
{
if (inputs != null)
if (inputs is not null)
{
batch_size = inputs.shape[0];
batch_size = array_ops.shape(inputs)[0];
dtype = inputs.dtype;
}
return generate_zero_filled_state(batch_size, cell.StateSize, dtype);
@@ -77,17 +78,27 @@ namespace Tensorflow.Keras.Utils
Debug.Assert(initial_state is null && constants is null);
if(num_constants > 0)
{
constants = inputs.TakeLast(num_constants).ToTensors();
inputs = inputs.SkipLast(num_constants).ToTensors();
constants = inputs.TakeLast(num_constants).ToArray().ToTensors();
inputs = inputs.SkipLast(num_constants).ToArray().ToTensors();
}
if(inputs.Length > 1)
{
initial_state = inputs.Skip(1).ToTensors();
inputs = inputs.Take(1).ToTensors();
initial_state = inputs.Skip(1).ToArray().ToTensors();
inputs = inputs.Take(1).ToArray().ToTensors();
}
}

return (inputs, initial_state, constants);
}

/// <summary>
/// Check whether the state_size contains multiple states.
/// </summary>
/// <param name="state_size"></param>
/// <returns></returns>
public static bool is_multiple_state(GeneralizedTensorShape state_size)
{
return state_size.TotalNestedCount > 1;
}
}
}

+ 2
- 2
test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs View File

@@ -28,8 +28,8 @@ namespace TensorFlowNET.UnitTest.ManagedAPI

var i = tf.constant(2);
var j = tf.constant(3);
Func<Tensor[], Tensor> c = (x) => tf.less(x[0] + x[1], 10);
Func<Tensor[], Tensor[]> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };
Func<Tensors, Tensor> c = (x) => tf.less(x[0] + x[1], 10);
Func<Tensors, Tensors> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };
var r = tf.while_loop(c, b, new[] { i, j });
Assert.AreEqual(5, (int)r[0]);
Assert.AreEqual(6, (int)r[1]);


+ 18
- 6
tools/Tensorflow.CodeGen/FunctionGenerator.cs View File

@@ -21,7 +21,8 @@ namespace Tensorflow.CodeGen
{
sb.Append("Operation ");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.Append("Tensor ");
}
@@ -70,7 +71,8 @@ namespace Tensorflow.CodeGen
{
sb.AppendLine("return null;");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.AppendLine("return _fast_path_result[0];");
}
@@ -149,7 +151,8 @@ namespace Tensorflow.CodeGen
{
sb.AppendLine("return _op;");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.AppendLine("return _result[0];");
}
@@ -174,7 +177,7 @@ namespace Tensorflow.CodeGen
{
argName = $"{argName}_";
}
if (!string.IsNullOrEmpty(arg.NumberAttr))
if (!string.IsNullOrEmpty(arg.NumberAttr) || !string.IsNullOrEmpty(arg.TypeListAttr))
{
sb.Append($"Tensors {argName}, ");
}
@@ -273,7 +276,8 @@ namespace Tensorflow.CodeGen
{
sb.Append("Operation ");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.Append("Tensor ");
}
@@ -366,6 +370,13 @@ namespace Tensorflow.CodeGen
sb.Append($"\"{attr.Name}\", {attrRealName}, ");
}
}
else if(attr.Type == "list(type)")
{
if (op.InputArg.Any(x => x.TypeListAttr == attr.Name))
{
continue;
}
}
else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name))
{
bool found = false;
@@ -408,7 +419,8 @@ namespace Tensorflow.CodeGen
{
sb.AppendLine("return null;");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.AppendLine("return _result[0];");
}


+ 1
- 1
tools/Tensorflow.CodeGen/Program.cs View File

@@ -5,7 +5,7 @@ using System.Text;
using System.Xml.Linq;
using Tensorflow.CodeGen;

GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops",
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops_v2",
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt");


+ 6
- 2
tools/Tensorflow.CodeGen/Utils.cs View File

@@ -155,6 +155,10 @@ namespace Tensorflow.CodeGen
}
else if (attr.Type == "list(type)")
{
if(op.InputArg.Any(x => x.TypeListAttr == attr.Name))
{
continue;
}
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
List<TF_DataType> values = new();
@@ -231,11 +235,11 @@ namespace Tensorflow.CodeGen
}
else if (attr.Type == "func")
{
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE"));
res.Add((attr.Name, "object", "NOVALUE"));
}
else if (attr.Type == "list(func)")
{
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE"));
res.Add((attr.Name, "object[]", "NOVALUE"));
}
else if (attr.Type == "tensor")
{

