Browse Source

Merge pull request #1109 from AsakusaRinne/rnn-dev

feat: support training of RNN.
tags/v0.110.0-LSTM-Model
Rinne GitHub 2 years ago
parent
commit
edbf89bd8d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
90 changed files with 8353 additions and 1013 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/APIs/c_api.cs
  2. +5
    -5
      src/TensorFlowNET.Core/APIs/tf.control_flow.cs
  3. +3
    -3
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Binding.Util.cs
  5. +6
    -1
      src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
  6. +20
    -0
      src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
  7. +28
    -99
      src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
  8. +13
    -0
      src/TensorFlowNET.Core/Common/Types/INestStructure.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
  10. +81
    -54
      src/TensorFlowNET.Core/Common/Types/Nest.cs
  11. +4
    -0
      src/TensorFlowNET.Core/Common/Types/NestDictionary.cs
  12. +16
    -6
      src/TensorFlowNET.Core/Common/Types/NestList.cs
  13. +4
    -0
      src/TensorFlowNET.Core/Common/Types/NestNode.cs
  14. +2
    -2
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  15. +7
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  16. +7
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  17. +6
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
  18. +19
    -0
      src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs
  19. +13
    -0
      src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
  20. +89
    -0
      src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs
  21. +2
    -2
      src/TensorFlowNET.Core/Framework/function_def_lib.cs
  22. +13
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  23. +2
    -3
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  24. +2
    -2
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  25. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  26. +0
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
  27. +30
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  28. +3
    -23
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  29. +1
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
  30. +1
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
  31. +12
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  32. +9
    -3
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
  33. +9
    -9
      src/TensorFlowNET.Core/NumPy/NDArrayRender.cs
  34. +23
    -1
      src/TensorFlowNET.Core/Numpy/Shape.cs
  35. +22
    -0
      src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs
  36. +1
    -2
      src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
  37. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  38. +6
    -2
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  39. +57
    -1
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  40. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  41. +13
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  42. +2
    -4
      src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
  43. +176
    -4
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  44. +78
    -22
      src/TensorFlowNET.Core/Operations/array_ops.cs
  45. +5
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  46. +77
    -0
      src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  47. +489
    -10
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  48. +1042
    -81
      src/TensorFlowNET.Core/Operations/gen_functional_ops.cs
  49. +827
    -109
      src/TensorFlowNET.Core/Operations/gen_io_ops.cs
  50. +1308
    -0
      src/TensorFlowNET.Core/Operations/gen_list_ops.cs
  51. +585
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  52. +409
    -0
      src/TensorFlowNET.Core/Operations/gen_nn_ops.cs
  53. +1469
    -104
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  54. +3
    -3
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  55. +111
    -0
      src/TensorFlowNET.Core/Operations/list_ops.cs
  56. +16
    -4
      src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
  57. +401
    -0
      src/TensorFlowNET.Core/Operations/while_v2.cs
  58. +2
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  59. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  60. +7
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  61. +24
    -0
      src/TensorFlowNET.Core/Tensors/TensorArray.cs
  62. +41
    -13
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  63. +1
    -2
      src/TensorFlowNET.Core/Training/Trackable.cs
  64. +20
    -3
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  65. +1
    -1
      src/TensorFlowNET.Core/ops.cs
  66. +48
    -47
      src/TensorFlowNET.Keras/BackendImpl.cs
  67. +2
    -0
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  68. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Build.cs
  69. +2
    -2
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  70. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  71. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Train.cs
  72. +4
    -0
      src/TensorFlowNET.Keras/IsExternalInit.cs
  73. +33
    -9
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  74. +12
    -3
      src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
  75. +93
    -9
      src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
  76. +221
    -4
      src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
  77. +84
    -97
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  78. +0
    -24
      src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs
  79. +3
    -18
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  80. +9
    -15
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  81. +58
    -110
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  82. +26
    -16
      src/TensorFlowNET.Keras/Utils/RnnUtils.cs
  83. +58
    -36
      test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
  84. +2
    -2
      test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs
  85. +26
    -6
      tools/Tensorflow.CodeGen/FunctionGenerator.cs
  86. +1
    -0
      tools/Tensorflow.CodeGen/GenOpsWriter.cs
  87. +1
    -1
      tools/Tensorflow.CodeGen/OpClassifier.cs
  88. +1
    -1
      tools/Tensorflow.CodeGen/Program.cs
  89. +1
    -1
      tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj
  90. +22
    -3
      tools/Tensorflow.CodeGen/Utils.cs

+ 14
- 0
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
@@ -50,6 +51,19 @@ namespace Tensorflow
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(IntPtr handle)
{
byte* str_data = (byte*)handle.ToPointer();
List<byte> bytes = new List<byte>();
byte current = 255;
while (current != ((byte)'\0'))
{
current = *(str_data++);
bytes.Add(current);
}
return bytes.Take(bytes.Count - 1).ToArray();
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);



+ 5
- 5
src/TensorFlowNET.Core/APIs/tf.control_flow.cs View File

@@ -46,10 +46,10 @@ namespace Tensorflow
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
Func<Tensors, Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
Func<Tensors, Tensors> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ namespace Tensorflow
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,


+ 3
- 3
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -71,15 +71,15 @@ namespace Tensorflow
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis,
name: name);

public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
axis: axis,
num_or_size_splits: num_split,
axis: ops.convert_to_tensor(axis),
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)


+ 1
- 1
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -503,7 +503,7 @@ namespace Tensorflow
case Tensors tensors:
return tensors.dtype;
case IEnumerable<Tensor> tensors:
return tensors.First().dtype;
return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:


+ 6
- 1
src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs View File

@@ -18,7 +18,12 @@ namespace Tensorflow.Common.Extensions
return sequence.Take(sequence.Count() - count);
}
#endif
public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
public static Tensors ToTensors(this Tensor[] tensors)
{
return new Tensors(tensors);
}

public static Tensors ToTensors(this IList<Tensor> tensors)
{
return new Tensors(tensors);
}


+ 20
- 0
src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This is a temp solution, which should be removed after refactoring `Tensors`
/// </summary>
[Obsolete]
public class FakeTensorByTensorArray: Tensor
{
public TensorArray TensorArray { get; set; }

public FakeTensorByTensorArray(TensorArray array)
{
TensorArray = array;
}
}
}

+ 28
- 99
src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs View File

@@ -5,136 +5,65 @@ using System.Text;

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
public class GeneralizedTensorShape: Nest<Shape>
{
public TensorShapeConfig[] Shapes { get; set; }
/// <summary>
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
public GeneralizedTensorShape(int dim, int size = 1)
public GeneralizedTensorShape(Shape value, string? name = null)
{
var elem = new TensorShapeConfig() { Items = new long?[] { dim } };
Shapes = Enumerable.Repeat(elem, size).ToArray();
//Shapes = new TensorShapeConfig[size];
//Shapes.Initialize(new TensorShapeConfig() { Items = new long?[] { dim } });
//Array.Initialize(Shapes, new TensorShapeConfig() { Items = new long?[] { dim } });
////Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
NodeValue = value;
NestType = NestType.Node;
}

public GeneralizedTensorShape(Shape shape)
public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
Name = name;
NestType = NestType.List;
}

public GeneralizedTensorShape(TensorShapeConfig shape)
public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
{
Shapes = new TensorShapeConfig[] { shape };
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
Name = name;
NestType = NestType.Dictionary;
}

public GeneralizedTensorShape(TensorShapeConfig[] shapes)
public GeneralizedTensorShape(Nest<Shape> other)
{
Shapes = shapes;
}

public GeneralizedTensorShape(IEnumerable<Shape> shape)
{
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public Shape ToSingleShape()
{
if (Shapes.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var shape_config = Shapes[0];
Debug.Assert(shape_config is not null);
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
return shapes[0];
}

public long ToNumber()
{
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
var shapes = Flatten().ToList();
if (shapes.Count != 1 || shapes[0].ndim != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var res = Shapes[0].Items[0];
return res is null ? -1 : res.Value;
}

public Shape[] ToShapeArray()
{
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
}

public IEnumerable<long?> Flatten()
{
List<long?> result = new List<long?>();
foreach(var shapeConfig in Shapes)
{
result.AddRange(shapeConfig.Items);
}
return result;
}
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
{
List<Nest<TOut>> lists = new();
foreach(var shapeConfig in Shapes)
{
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
}
return new Nest<TOut>(lists);
return shapes[0].dims[0];
}

public Nest<long?> AsNest()
public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
{
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
{
if (config.Items.Length == 0)
{
return Nest<long?>.Empty;
}
else if (config.Items.Length == 1)
{
return new Nest<long?>(config.Items[0]);
}
else
{
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
}
}

if(Shapes.Length == 0)
{
return Nest<long?>.Empty;
}
else if(Shapes.Length == 1)
{
return DealWithSingleShape(Shapes[0]);
}
else
{
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}


public static implicit operator GeneralizedTensorShape(int dims)
=> new GeneralizedTensorShape(dims);

public IEnumerator<long?[]> GetEnumerator()
{
foreach (var shape in Shapes)
{
yield return shape.Items;
}
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
}

IEnumerator IEnumerable.GetEnumerator()
public static implicit operator GeneralizedTensorShape(Shape shape)
{
return GetEnumerator();
return new GeneralizedTensorShape(shape);
}
}
}

src/TensorFlowNET.Core/Common/Types/INest.cs → src/TensorFlowNET.Core/Common/Types/INestStructure.cs View File

@@ -10,6 +10,19 @@ namespace Tensorflow.Common.Types
/// </summary>
public interface INestStructure<T>: INestable<T>
{
NestType NestType { get; }

/// <summary>
/// The item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
/// </summary>
int ShallowNestedCount { get; }
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
int TotalNestedCount { get; }

/// <summary>
/// Flatten the Nestable object. Node that if the object contains only one value,
/// it will be flattened to an enumerable with one element.

+ 1
- 1
src/TensorFlowNET.Core/Common/Types/Nest.Static.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow.Common.Types
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
{
return template.AsNest().PackSequence(flatItems);
}


+ 81
- 54
src/TensorFlowNET.Core/Common/Types/Nest.cs View File

@@ -28,27 +28,58 @@ namespace Tensorflow.Common.Types
public static Nest<T> Empty => _empty;
public NestType NestType { get; protected set; }
public string? Name { get; set; }
public T? Value { get; protected set; }
public List<Nest<T>>? ListValue { get; protected set; }
public Dictionary<string, Nest<T>>? DictValue { get; protected set; }
public T? NodeValue { get; protected set; }
public List<INestStructure<T>>? ListValue { get; protected set; }
public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; }

public int ShallowNestedCount
{
get
{
if (NestType == NestType.Empty)
{
return 0;
}
else if (NestType == NestType.Node)
{
return 1;
}
else if (NestType == NestType.List)
{
return ListValue!.Count;
}
else // dict
{
return DictValue!.Count;
}
}
}

public int TotalNestedCount
{
get
{
return Flatten().Count();
}
}

protected Nest() { }

public Nest(T value, string? name = null)
{
Value = value;
NodeValue = value;
Name = name;
NestType = NestType.Node;
}

public Nest(IEnumerable<Nest<T>> values, string? name = null)
public Nest(IEnumerable<INestStructure<T>> values, string? name = null)
{
ListValue = values.ToList();
Name = name;
NestType = NestType.List;
}

public Nest(Dictionary<string, Nest<T>> value, string? name = null)
public Nest(Dictionary<string, INestStructure<T>> value, string? name = null)
{
DictValue = value;
Name = name;
@@ -58,7 +89,7 @@ namespace Tensorflow.Common.Types
public Nest(Nest<T> other)
{
NestType = other.NestType;
Value = other.Value;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
@@ -78,17 +109,17 @@ namespace Tensorflow.Common.Types
/// </summary>
/// <param name="flatItems"></param>
/// <returns></returns>
public virtual Nest<T> PackSequence(T[] flatItems)
public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems)
{
if(flatItems.Length == 0)
{
return Nest<T>.Empty;
return Nest<TOut>.Empty;
}
int index = 0;
return PackSequenceInternal(this, flatItems, ref index);
}

private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index)
private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index)
{
if(template.NestType == NestType.Node)
{
@@ -96,25 +127,25 @@ namespace Tensorflow.Common.Types
{
throw new InvalidArgumentError("The template and flat items are not matched.");
}
return new Nest<T>(flatItems[index++]);
return new Nest<TOut>(flatItems[index++]);
}
else if(template.NestType == NestType.List)
{
List<Nest<T>> nestedObjects = new List<Nest<T>>();
List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>();
for (int i = 0; i < template.ListValue!.Count; i++)
{
nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index));
nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index));
}
return new Nest<T>(nestedObjects);
return new Nest<TOut>(nestedObjects);
}
else if(template.NestType == NestType.Node)
{
Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>();
Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>();
foreach(var (key, value) in template.DictValue!)
{
dict[key] = PackSequenceInternal(value, flatItems, ref index);
dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index);
}
return new Nest<T>(dict);
return new Nest<TOut>(dict);
}
// Consider Empty as invalid type.
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
@@ -166,25 +197,11 @@ namespace Tensorflow.Common.Types
}
else if(NestType is NestType.List)
{
foreach(var item in ListValue!)
{
if(item.NestType is NestType.List or NestType.Dictionary)
{
return true;
}
}
return false;
return ListValue!.Count > 0;
}
else
{
foreach (var item in DictValue!.Values)
{
if (item.NestType is NestType.List or NestType.Dictionary)
{
return true;
}
}
return false;
return DictValue!.Count > 0;
}
}

@@ -223,10 +240,10 @@ namespace Tensorflow.Common.Types
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T>
{
var nested = input.AsNest();
return ReduceInternal(nested);
return ReduceInternal(nested).AsNest();
}

private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
{
if(node.NestType == NestType.Empty)
{
@@ -234,15 +251,15 @@ namespace Tensorflow.Common.Types
}
else if(node.NestType == NestType.Node)
{
return node.Value!.AsNest();
return node.NodeValue!.AsNest();
}
else if(node.NestType == NestType.List)
{
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x)));
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest())));
}
else // Dictionary type
{
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value)));
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest())));
}
}

@@ -252,7 +269,7 @@ namespace Tensorflow.Common.Types
{
if(index == 0)
{
result = node.Value!;
result = node.NodeValue!;
return true;
}
result = default(T);
@@ -264,7 +281,7 @@ namespace Tensorflow.Common.Types
{
if(index == 0)
{
return FindInternal(item, index, out result);
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
@@ -277,7 +294,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
return FindInternal(item, index, out result);
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
@@ -297,7 +314,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
node.Value = newValue;
node.NodeValue = newValue;
return true;
}
return false;
@@ -308,7 +325,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
return SetInternal(item, index, newValue);
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
@@ -320,7 +337,7 @@ namespace Tensorflow.Common.Types
{
if (index == 0)
{
return SetInternal(item, index, newValue);
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
@@ -336,13 +353,13 @@ namespace Tensorflow.Common.Types
{
if (node.NestType == NestType.Node)
{
yield return node.Value!;
yield return node.NodeValue!;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
foreach(var val in FlattenInternal(item))
foreach(var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
@@ -352,7 +369,7 @@ namespace Tensorflow.Common.Types
{
foreach (var item in node.DictValue!.Values)
{
foreach (var val in FlattenInternal(item))
foreach (var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
@@ -364,23 +381,23 @@ namespace Tensorflow.Common.Types
{
if (NestType == NestType.Node)
{
return new Nest<TOut>(func(Value!));
return new Nest<TOut>(func(NodeValue!));
}
else if (NestType == NestType.List)
{
List<Nest<TOut>> outs = new List<Nest<TOut>>();
foreach (var item in ListValue!)
{
outs.Add(item.MapStructureInternal(func));
outs.Add(item.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else if (NestType == NestType.Dictionary)
{
Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>();
Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>();
foreach (var (key, value) in DictValue!)
{
outs.Add(key, value.MapStructureInternal(func));
outs.Add(key, value.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
@@ -417,14 +434,14 @@ namespace Tensorflow.Common.Types
}
if (node.NestType == NestType.Node)
{
sb.Append(node.Value!.ToString());
sb.Append(node.NodeValue!.ToString());
}
else if (node.NestType == NestType.List)
{
sb.Append("[");
for(int i = 0; i < node.ListValue!.Count; i++)
{
WriteString(node.ListValue![i], sb);
WriteString(node.ListValue![i].AsNest(), sb);
if(i != node.ListValue!.Count - 1)
{
sb.Append(", ");
@@ -440,7 +457,7 @@ namespace Tensorflow.Common.Types
foreach (var (key, value) in node.DictValue!)
{
sb.Append($"{key}: ");
WriteString(value, sb);
WriteString(value.AsNest(), sb);
if (i != count - 1)
{
sb.Append(", ");
@@ -454,5 +471,15 @@ namespace Tensorflow.Common.Types
sb.Append("<empty>");
}
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 });
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 });
}
}
}

+ 4
- 0
src/TensorFlowNET.Core/Common/Types/NestDictionary.cs View File

@@ -6,7 +6,11 @@ namespace Tensorflow.Common.Types
{
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull
{
public NestType NestType => NestType.Dictionary;
public IDictionary<TKey, TValue> Value { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;
public NestDictionary(IDictionary<TKey, TValue> dict)
{
Value = dict;


+ 16
- 6
src/TensorFlowNET.Core/Common/Types/NestList.cs View File

@@ -10,29 +10,39 @@ namespace Tensorflow.Common.Types
/// <typeparam name="T"></typeparam>
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
{
public List<T> Value { get; set; }
public NestType NestType => NestType.List;
public List<T> Values { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;

public NestList(params T[] values)
{
Values = new List<T>(values);
}

public NestList(IEnumerable<T> values)
{
Value = new List<T>(values);
Values = new List<T>(values);
}
public IEnumerable<T> Flatten()
{
return Value;
return Values;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestList<TOut>(Value.Select(x => func(x)));
return new NestList<TOut>(Values.Select(x => func(x)));
}

public Nest<T> AsNest()
{
return new Nest<T>(Value.Select(x => new Nest<T>(x)));
return new Nest<T>(Values.Select(x => new Nest<T>(x)));
}

// Enumerator implementation
public IEnumerator<T> GetEnumerator()
{
return Value.GetEnumerator();
return Values.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()


+ 4
- 0
src/TensorFlowNET.Core/Common/Types/NestNode.cs View File

@@ -10,7 +10,11 @@ namespace Tensorflow.Common.Types
/// <typeparam name="T"></typeparam>
public class NestNode<T> : INestStructure<T>
{
public NestType NestType => NestType.Node;
public T Value { get; set; }
public int ShallowNestedCount => 1;

public int TotalNestedCount => 1;
public NestNode(T value)
{
Value = value;


+ 2
- 2
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -161,8 +161,8 @@ namespace Tensorflow
break;
}

yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount)));
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray()));
}
}



+ 7
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -352,13 +352,19 @@ namespace Tensorflow.Eager
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
break;
case TF_AttrType.TF_ATTR_SHAPE:
var dims = (value as long[]).ToArray();
long[] dims;
if (value is Shape shape) dims = shape.dims.ToArray();
else if (value is long[] longs) dims = longs.ToArray();
else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
else dims = ((long[])value).ToArray();
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
status.Check(true);
break;
case TF_AttrType.TF_ATTR_FUNC:
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
else if(value is string str)
c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;


+ 7
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
{
outgrad_vec = output_gradients.ToList();
}
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);


bool unconnected_gradients_zero = unconnected_gradients == "zero";
@@ -137,7 +137,6 @@ namespace Tensorflow.Eager
{
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
}
Shape tensor_shape = new(dims);

if(status.Code != TF_Code.TF_OK)
{
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager
}
else
{
Shape tensor_shape = new(dims);
return new TapeTensor(id, dtype, tensor_shape);
}
}
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager
return dtype == dtypes.variant || dtype == dtypes.resource;
}

bool ListContainNone(long[] list)
bool ListContainNone(long[]? list)
{
if(list is null)
{
return true;
}
int len = list.Length;
if(len == 0)
{


+ 6
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs View File

@@ -10,6 +10,11 @@ namespace Tensorflow.Eager
var str = NDArrayRender.ToString(nd);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
public string ToString(int maxLength)
{
var nd = new NDArray(this);
var str = NDArrayRender.ToString(nd, maxLength);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
}
}

+ 19
- 0
src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Exceptions
{
public class NotOkStatusException : TensorflowException
{
public NotOkStatusException() : base()
{

}

public NotOkStatusException(string message) : base(message)
{

}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs View File

@@ -1,4 +1,5 @@
using System.Linq;
using Tensorflow.Eager;

namespace Tensorflow.Framework.Models
{
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models
shapes.Insert(0, dim);
return new TensorSpec(shapes.ToArray(), _dtype);
}

public static TensorSpec FromTensor(Tensor tensor, string? name = null)
{
if(tensor is EagerTensor)
{
return new TensorSpec(tensor.shape, tensor.dtype, name);
}
else
{
return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name);
}
}
}
}

+ 89
- 0
src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs View File

@@ -0,0 +1,89 @@
using Tensorflow.Graphs;

namespace Tensorflow.Framework
{
internal static class auto_control_deps_utils
{
public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs";
public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph)
{
List<int> result = new List<int>();
// A cache to store the read only resource inputs of an Op.
// Operation -> ObjectIdentitySet of resource handles.
Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs =
new Dictionary<Operation, HashSet<Tensor>>();

for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++)
{
Tensor t = func_graph.Inputs[inputIndex];
if (t.dtype != dtypes.resource)
continue;

bool readOnly = true;
foreach (var op in t.consumers())
{
if (opReadOnlyResourceInputs.ContainsKey(op))
{
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
else
{
List<int> indices = _get_read_only_resource_input_indices_op(op);
opReadOnlyResourceInputs[op] = new HashSet<Tensor>(
indices.Select(i => op.inputs[i]));
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
}

if (readOnly)
result.Add(inputIndex);
}

return result;
}

private static List<int> _get_read_only_resource_input_indices_op(Operation op)
{
// ignore the RESOURCE_READ_OPS

int[] read_only_input_indices;

try
{
read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR);
}
catch (InvalidArgumentError)
{
return new List<int>();
}

int read_only_index = 0;
List<int> result = new();
for (int i = 0; i < op.inputs.Length; i++)
{
if (read_only_index >= read_only_input_indices.Length)
{
break;
}
if (op.inputs[i].dtype != dtypes.resource)
{
continue;
}
if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index])
{
result.Add(i);
read_only_index++;
}
}
return result;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Framework/function_def_lib.cs View File

@@ -42,10 +42,10 @@ namespace Tensorflow.Framework
func_graph.as_default();
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());

var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
// TODO(Rinne): func_graph.ControlOutputs
_set_handle_data(func_graph, fdef);



+ 13
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
public IEnumerable<IVariableV1> Variables => func_graph.Variables;
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables;
internal NameAttrList AsNameAttrList
{
get
{
NameAttrList ret = new() { Name = this.Name };
foreach (var (name, value) in _attrs)
{
ret.Attr[name] = value;
}
return ret;
}
}

public ConcreteFunction(string name)
{


+ 2
- 3
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients
? input_values[0].rank + dim_int
: dim_int % input_values[0].rank;
var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
var sizes_tensor = constant_op.constant(sizes);
out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList();
}
else if (constant_op.is_constant(concat_dim))
{
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice);
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList();
}
else
{


+ 2
- 2
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);
public Dictionary<string, AttrValue> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
internal Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();

public Tensor[] external_captures
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable
var flat_func_args = nest.flatten(func_args as object);
var flat_func_kwargs = nest.flatten(func_kwargs as object);
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
.Where(x => x is Tensor).Select(x => (Tensor)x));
.Where(x => x is Tensor).Select(x => (Tensor)x).ToArray());

//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true);
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true);


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -129,7 +129,7 @@ namespace Tensorflow
}
}

protected Graph outer_graph;
internal Graph outer_graph;
public Graph OuterGraph => outer_graph;
public Dictionary<string, EagerDefinedFunction> Functions => _functions;
public SafeGraphHandle c_graph => _handle;


+ 0
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs View File

@@ -4,8 +4,6 @@
{
// TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; }
public float Dropout { get; set; }
public float RecurrentDropout { get; set; }
public int Implementation { get; set; }
}
}

+ 30
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs View File

@@ -1,7 +1,35 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
using Newtonsoft.Json;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
// TODO: complete the implementation
public class LSTMCellArgs : LayerArgs
public class LSTMCellArgs : AutoSerializeLayerArgs
{
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("recurrent_activation")]
public Activation RecurrentActivation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }
[JsonProperty("unit_forget_bias")]
public bool UnitForgetBias { get; set; } = true;
[JsonProperty("implementation")]
public int Implementation { get; set; } = 2;

}
}

+ 3
- 23
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -7,12 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
// TODO(Rinne): add regularizers.
public class RNNArgs : AutoSerializeLayerArgs
{
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnCell Cell { get; set; } = null;
[JsonProperty("cells")]
public IList<IRnnCell> Cells { get; set; } = null;

[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
@@ -25,8 +19,10 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public bool Unroll { get; set; } = false;
[JsonProperty("time_major")]
public bool TimeMajor { get; set; } = false;

public int? InputDim { get; set; }
public int? InputLength { get; set; }
// TODO: Add `num_constants` and `zero_output_for_mask`.
public Dictionary<string, object> Kwargs { get; set; } = null;

public int Units { get; set; }
public Activation Activation { get; set; }
@@ -38,21 +34,5 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public float Dropout { get; set; } = .0f;
public bool ZeroOutputForMask { get; set; } = false;
public float RecurrentDropout { get; set; } = .0f;

// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
}
}

+ 1
- 3
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs View File

@@ -1,7 +1,4 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
@@ -25,5 +22,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }

}
}

+ 1
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs View File

@@ -5,7 +5,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<IRnnCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
public bool ReverseStateOrder = false;
}
}

+ 12
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -160,6 +160,18 @@ namespace Tensorflow.Keras.Layers
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
public ILayer LeakyReLU(float alpha = 0.3f);

public IRnnCell LSTMCell(int uints,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2);

public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,


+ 9
- 3
src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs View File

@@ -7,13 +7,19 @@ namespace Tensorflow.Keras.Layers.Rnn
{
public interface IRnnCell: ILayer
{
GeneralizedTensorShape StateSize { get; }
GeneralizedTensorShape OutputSize { get; }
bool IsTFRnnCell { get; }
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
INestStructure<long>? StateSize { get; }
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
INestStructure<long>? OutputSize { get; }
/// <summary>
/// Whether the optional RNN args are supported when appying the layer.
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`.
/// </summary>
bool SupportOptionalArgs { get; }
Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype);
}
}

+ 9
- 9
src/TensorFlowNET.Core/NumPy/NDArrayRender.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy
{
public class NDArrayRender
{
public static string ToString(NDArray array)
public static string ToString(NDArray array, int maxLength = 10)
{
Shape shape = array.shape;
if (shape.IsScalar)
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy

var s = new StringBuilder();
s.Append("array(");
Build(s, array);
Build(s, array, maxLength);
s.Append(")");
return s.ToString();
}

static void Build(StringBuilder s, NDArray array)
static void Build(StringBuilder s, NDArray array, int maxLength)
{
var shape = array.shape;

@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy
var len = shape[0];
s.Append("[");

if (len <= 10)
if (len <= maxLength)
{
for (int i = 0; i < len; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy
}
else
{
for (int i = 0; i < 5; i++)
for (int i = 0; i < maxLength / 2; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy
s.Append(" ... ");
s.AppendLine();

for (int i = (int)len - 5; i < len; i++)
for (int i = (int)len - maxLength / 2; i < len; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");


+ 23
- 1
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -19,13 +19,14 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving.Common;
using Tensorflow.NumPy;

namespace Tensorflow
{
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape
public class Shape : INestStructure<long>
{
public int ndim => _dims == null ? -1 : _dims.Length;
long[] _dims;
@@ -41,6 +42,27 @@ namespace Tensorflow
}
}

public NestType NestType => NestType.List;

public int ShallowNestedCount => ndim;
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
public int TotalNestedCount => ndim;

public IEnumerable<long> Flatten() => dims.Select(x => x);

public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func)
{
return new NestList<TOut>(dims.Select(x => func(x)));
}

public Nest<long> AsNest()
{
return new NestList<long>(Flatten()).AsNest();
}

#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
public int Length => ndim;
public long[] Slice(int start, int length)


+ 22
- 0
src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Operations.Initializers
{
/// <summary>
/// An initializer specially used for debugging (to load weights from disk).
/// </summary>
class NpyLoadInitializer : IInitializer
{
string _path;
public NpyLoadInitializer(string path) { _path = path; }
public string ClassName => "";
public IDictionary<string, object> Config => new Dictionary<string, object>();
public Tensor Apply(InitializerArgs args)
{
return np.load(_path);
}
}
}

+ 1
- 2
src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs View File

@@ -58,8 +58,7 @@ public class Orthogonal : IInitializer

if (num_rows < num_cols)
{
// q = tf.linalg.matrix_transpose(q);
throw new NotImplementedException("");
q = array_ops.matrix_transpose(q);
}

return _gain * tf.reshape(q, shape);


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -89,7 +89,7 @@ namespace Tensorflow
gate_inputs = nn_ops.bias_add(gate_inputs, _bias);

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);


+ 6
- 2
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -181,8 +181,12 @@ namespace Tensorflow
{
throw new NotImplementedException();
}
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
throw new NotImplementedException();
}
public INestStructure<long> StateSize => throw new NotImplementedException();
public INestStructure<long> OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}


+ 57
- 1
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -15,9 +15,11 @@
******************************************************************************/

using Google.Protobuf;
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using static Tensorflow.Binding;
using static Tensorflow.OpDef.Types;

@@ -387,9 +389,13 @@ namespace Tensorflow
case "list(type)":
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
break;
case "list(float)":
if (value != null)
attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray());
break;
case "list(int)":
if (value != null)
attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x)));
break;
case "bool":
attr_value.B = (bool)value;
@@ -420,6 +426,15 @@ namespace Tensorflow
case "list(shape)":
attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def)));
break;
case "func":
attr_value.Func = _MakeFunc(value, attr_def.Name);
break;
case "list(func)":
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name));
break;
case "list(string)":
attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x)));
break;
default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
}
@@ -427,6 +442,47 @@ namespace Tensorflow
return attr_value;
}

private NameAttrList _MakeFunc(object func, string arg_name)
{
if(func is NameAttrList attrList)
{
return attrList;
}
NameAttrList fn_attr;
if(func is string funcStr)
{
fn_attr = new NameAttrList() { Name = funcStr };
}
else if(func is ConcreteFunction concrete)
{
concrete.AddTograph(ops.get_default_graph());
fn_attr = concrete.AsNameAttrList;
}
else if(func is EagerDefinedFunction eager)
{
eager.AddToGraph(ops.get_default_graph());
fn_attr = new NameAttrList() { Name = eager.Name };
}
else
{
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}");
}
return fn_attr;
}

private List<NameAttrList> _MakeFuncList(object funcList, string arg_name)
{
List<NameAttrList> res = new List<NameAttrList>();
if(funcList is IEnumerable enumerable)
{
foreach(var func in enumerable)
{
res.Add(_MakeFunc(func, arg_name));
}
}
return res;
}

private bool _IsListParameter(ArgDef arg)
{
if (!String.IsNullOrEmpty(arg.NumberAttr))


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -34,7 +34,7 @@ namespace Tensorflow
return num;
}

protected Tensor[] _outputs;
internal Tensor[] _outputs;
public virtual Tensor[] outputs => _outputs;
public Tensor output => _outputs.FirstOrDefault();



+ 13
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -46,9 +46,9 @@ namespace Tensorflow
/// </summary>
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
protected IntPtr _handle; // _c_op in python

private readonly Graph _graph;
protected Graph _graph;

internal Func<Operation, object[], Tensor[]> _gradient_function;

@@ -69,6 +69,7 @@ namespace Tensorflow
//private OperationDescription _op_desc;

public NodeDef node_def => GetNodeDef();
protected Operation() { }

public Operation(IntPtr handle, Graph g = null)
{
@@ -185,7 +186,16 @@ namespace Tensorflow
}

public virtual T get_attr<T>(string name)
=> (T)get_attr(name);
{
if (typeof(T).IsValueType)
{
return (T)Convert.ChangeType(get_attr(name), typeof(T));
}
else
{
return (T)get_attr(name);
}
}

internal unsafe TF_DataType _get_attr_type(string name)
{


+ 2
- 4
src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -38,10 +39,6 @@ namespace Tensorflow.Operations

bool _infer_shape;
public override bool infer_shape => _infer_shape;
public bool _dynamic_size;
public Shape _element_shape;

public List<Tensor> _colocate_with;

Tensor _handle;
public override Tensor handle => _handle;
@@ -56,6 +53,7 @@ namespace Tensorflow.Operations
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_size = size;
_flow = constant_op.constant(0);
_infer_shape = infer_shape;
_element_shape = element_shape ?? Shape.Null;


+ 176
- 4
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -16,7 +16,9 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using static Tensorflow.Binding;

@@ -33,18 +35,18 @@ namespace Tensorflow.Operations
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;
public bool colocate_with_first_write_call => _colocate_with_first_write_call;
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;

bool _infer_shape;
public bool infer_shape => _infer_shape;
public bool _dynamic_size;
public override bool infer_shape => _infer_shape;
public List<Shape> _element_shape;

public List<Tensor> _colocate_with;

internal Tensor _handle;
public Tensor handle => _handle;
public override Tensor handle => _handle;
internal Tensor _flow;
public override Tensor flow => _flow;

public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
@@ -55,6 +57,7 @@ namespace Tensorflow.Operations
dynamic_size = dynamic_size ?? false;
_dynamic_size = dynamic_size.Value;
_dtype = dtype;
_size = size;

_colocate_with_first_write_call = colocate_with_first_write_call;
if (colocate_with_first_write_call)
@@ -235,4 +238,173 @@ namespace Tensorflow.Operations
return value;
}
}

public class _GraphTensorArrayV2 : TensorArray
{
internal TF_DataType _dtype;
public override TF_DataType dtype => _dtype;

/// <summary>
/// Used to keep track of what tensors the TensorArray should be
/// colocated with. We choose to colocate the TensorArray with the
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;

bool _infer_shape;
public override bool infer_shape => _infer_shape;
public Shape _element_shape;

public List<Tensor> _colocate_with;

internal Tensor _handle;
public override Tensor handle => _handle;
internal Tensor _flow;
public override Tensor flow => _flow;

public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
Debug.Assert(handle is null);
dynamic_size = dynamic_size ?? false;
_dynamic_size = dynamic_size.Value;
_size = size;

if(flow is not null && flow.dtype != dtypes.variant)
{
throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead");
}
if(flow is null && size is null)
{
throw new ValueError("Argument `size` must be provided if argument `flow` is not provided.");
}
if(flow is not null && size is not null)
{
throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time.");
}
if(flow is not null && element_shape is not null)
{
throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time.");
}

_dtype = dtype;

_element_shape = element_shape;
_infer_shape = infer_shape;
tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope =>
{
if (flow is null)
{
_flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name);
}
else
{
_flow = flow;
}
});

_colocate_with_first_write_call = false;
_colocate_with = null;
}

public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray());
return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow);
return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public override Tensor read<T>(T index, string name = null)
{
if(index is Tensor tensor)
{
return read(tensor, name);
}
else
{
throw new TypeError("Please use non-generic method instead.");
}
}

public Tensor read(Tensor index, string name = null)
{
return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope =>
{
return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name);
});
}

public override TensorArray write(Tensor index, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name);

return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor);
}

private Tensor size(string name = null)
{
if(!_dynamic_size && _size is not null)
{
return ops.convert_to_tensor(_size, dtypes.int32);
}
else
{
return gen_list_ops.tensor_list_length(_flow, name);
}
}

public override Tensor stack(string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate
{
int ta_size;
if(!_dynamic_size && (_size is not null))
{
var size_tensor = tensor_util.constant_value(_size);
ta_size = size_tensor is null ? -1 : (int)size_tensor;
}
else
{
ta_size = -1;
}
var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape);
return value;
});
}

public override Tensor gather(Tensor indices, string name = null)
{
return list_ops.tensor_list_gather(_flow, indices, _dtype, _element_shape, name);
}
}
}

+ 78
- 22
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -119,6 +119,27 @@ namespace Tensorflow
}
}

public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
Tensor shapeTensor;
if(shape.Length > 1)
{
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32);
if(shapeTensor.ndim > 1)
{
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1));
}
}
else
{
shapeTensor = shape[0];
}
var output = fill(shapeTensor, array_ops.constant(0, dtype), name);
Debug.Assert(output.dtype.as_base_dtype() == dtype);
return output;
}

public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
{
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate
@@ -307,6 +328,9 @@ namespace Tensorflow
public static Tensor fill<T>(Shape dims, T value, string name = null)
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name);

public static Tensor fill<T>(Tensor dims, T value, string name = null)
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name);

/// <summary>
/// Returns the rank of a tensor.
/// </summary>
@@ -947,38 +971,70 @@ namespace Tensorflow
});
}

public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
string name = "split")
/// <summary>
/// Transposes last two dimensions of tensor `a`.
/// For example:
/// <code> python
/// x = tf.constant([[1, 2, 3], [4, 5, 6]])
/// tf.matrix_transpose(x) # [[1, 4],
/// # [2, 5],
/// # [3, 6]]
/// </code>
/// Matrix with two batch dimensions.
/// x.shape is [1, 2, 3, 4]
/// tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
/// </summary>
/// <param name="a"></param>
/// <param name="name"></param>
/// <param name="conjugate"></param>
/// <returns></returns>
/// <exception cref="ValueError"></exception>
public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false)
{
if (num == -1)
num = (int)size_splits.shape[0];

return gen_array_ops.split_v(value, size_splits, tf.convert_to_tensor(axis), num, name: name);
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_shape = a.shape;
var ndims = a.shape.ndim;
Axis perm;
if(ndims != 0)
{
if (ndims < 2)
{
throw new ValueError("Argument `a` should be a (batch) matrix with rank " +
$">= 2. Received `a` = {a} with shape: {a_shape}");
}
perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray());
}
else
{
var a_rank = a.rank;
perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray());
}
return transpose(a, perm:perm, conjugate:conjugate);
});
}

public static Tensor[] split<T>(Tensor value, int num_split, T axis,
public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis = null,
string name = "split")
{
var size_splits = ops.convert_to_tensor(num_split);
return gen_array_ops.split(split_dim: axis, value: value, num_split: num_or_size_splits, name);
}

if (tf.Context.executing_eagerly())
public static Tensor[] split(Tensor value, int[] num_or_size_splits, Tensor axis = null, int num = -1,
string name = "split")
{
if(num_or_size_splits.Length == 0)
{
return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.Context);
throw new ValueError("Rank-0 tensors are not supported as the num_or_size_splits argument to split.");
}
var size_splits = ops.convert_to_tensor(num_or_size_splits);

var _op = tf.OpDefLib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
return _op.outputs;
}

private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null)
{
var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { value });
var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32);
var _inputs_flat = new List<Tensor> { axis_tensor };
_inputs_flat.AddRange(input);
var _attrs = new object[] { "num_split", num_split, "T", _attr_T };
if(num == -1)
{
num = (int)size_splits.shape[0];
}

return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
return gen_array_ops.split_v(value: value, size_splits: size_splits, split_dim: axis, num_split: num, name: name);
}

public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)


+ 5
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -675,16 +675,17 @@ namespace Tensorflow
}
}

public static Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public static Tensors while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
{
var executing_eagerly = tf.Context.executing_eagerly();
if (!executing_eagerly)
{
throw new NotImplementedException("");
return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations,
name: name);
}

return tf_with(ops.name_scope("name", "while"), delegate


+ 77
- 0
src/TensorFlowNET.Core/Operations/control_flow_util.py.cs View File

@@ -16,12 +16,20 @@

using System;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class control_flow_util
{
public static readonly bool ENABLE_CONTROL_FLOW_V2 = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0" ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0");
/// <summary>
/// Return true if `op` is an Exit.
/// </summary>
@@ -196,5 +204,74 @@ namespace Tensorflow
}
return null;
}

public static bool EnableControlFlowV2(Graph graph)
{
return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0);
}

public static string create_new_tf_function(FuncGraph func_graph)
{
var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary<string, AttrValue>());
func.AddToGraph(func_graph);
return func_graph.Name;
}

public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs)
{
if(inputs.Length == 0)
{
return (null, new Tensor[0]);
}
else
{
return (inputs[0], inputs);
}
}

public static Tensor[] run_as_function_for_tape_gradients(Func<Tensor[], Tensor[]> make_op, Tensor[] inputs)
{
if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
&& !(ops.get_default_graph().building_function))
{
throw new NotImplementedException();
}
else
{
return make_op(inputs);
}
}

public static string unique_fn_name(string scope, string name)
{
return $"{scope}{name}_{ops.uid()}".Replace("/", "_");
}

public static bool output_all_intermediates()
{
if (in_defun())
{
return false;
}
if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR")
{
return false;
}
// TODO(Rinne): check this after refactoring keras building.
return false;
}

public static bool in_defun()
{
if (tf.Context.executing_eagerly())
{
return false;
}

var graph = ops.get_default_graph();
// TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph
return graph is FuncGraph;
}
}
}

+ 489
- 10
src/TensorFlowNET.Core/Operations/gen_array_ops.cs
File diff suppressed because it is too large
View File


+ 1042
- 81
src/TensorFlowNET.Core/Operations/gen_functional_ops.cs
File diff suppressed because it is too large
View File


+ 827
- 109
src/TensorFlowNET.Core/Operations/gen_io_ops.cs
File diff suppressed because it is too large
View File


+ 1308
- 0
src/TensorFlowNET.Core/Operations/gen_list_ops.cs
File diff suppressed because it is too large
View File


+ 585
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs
File diff suppressed because it is too large
View File


+ 409
- 0
src/TensorFlowNET.Core/Operations/gen_nn_ops.cs
File diff suppressed because it is too large
View File


+ 1469
- 104
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
File diff suppressed because it is too large
View File


+ 3
- 3
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -1778,10 +1778,10 @@ new_height, new_width");
{
// a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3]
var a_xy_minmax = array_ops.split(
value: boxes_a, num_split: 4, axis: 2);
value: boxes_a, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
// b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3]
var b_xy_minmax = array_ops.split(
value: boxes_b, num_split: 4, axis: 2);
value: boxes_b, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));

var i_xmin = math_ops.maximum(
a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 }));
@@ -1943,7 +1943,7 @@ new_height, new_width");
using (ops.name_scope("canonicalize_coordinates"))
{
// y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3]
var yx = array_ops.split(value: boxes, num_split: 4, axis: 2);
var yx = array_ops.split(value: boxes, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
var y_1_is_min = math_ops.reduce_all(
gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0]));
var y_minmax = control_flow_ops.cond(


+ 111
- 0
src/TensorFlowNET.Core/Operations/list_ops.cs View File

@@ -0,0 +1,111 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow.Operations
{
internal class list_ops
{
private static void _set_handle_data(Tensor list_handle, Shape element_shape, TF_DataType element_dtype)
{
if(list_handle is EagerTensor eagerTensor)
{
var handle_data = new CppShapeInferenceResult.Types.HandleData();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType()
{
Shape = element_shape.as_proto(),
Dtype = element_dtype.as_datatype_enum(),
Type = new FullTypeDef() { TypeId = FullTypeId.TftArray }
});
list_handle.HandleData = handle_data;
}
}

private static Tensor _build_element_shape(Shape? shape)
{
if(shape is null || shape.IsNull)
{
return ops.convert_to_tensor(-1);
}
else
{
return ops.convert_to_tensor(shape);
}
}

public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null)
{
var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name);
_set_handle_data(result, shape, element_dtype);
return result;
}

public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null)
{
var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name);
_set_handle_data(result, tensor.shape, tensor.dtype);
return result;
}

public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape),
element_dtype, name);
}

public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item,
bool resize_if_index_out_of_bounds = false, string? name = null)
{
if (resize_if_index_out_of_bounds)
{
var input_list_size = gen_list_ops.tensor_list_length(input_handle);
input_handle = control_flow_ops.cond(index >= input_list_size,
() => gen_list_ops.tensor_list_resize(input_handle, index + 1),
() => input_handle);
}
var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name);
handle_data_util.copy_handle_data(input_handle, output_handle);
return output_handle;
}

public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name);
}

public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name);
}

public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null,
string? name = null)
{
if(input_handle is not null)
{
var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name);
handle_data_util.copy_handle_data(input_handle, output_handle);
return output_handle;
}
else
{
var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape),
constant_op.constant(-1), name);
_set_handle_data(output_handle, element_shape, tensor.dtype);
return output_handle;
}
}

public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1,
string? name = null)
{
return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype,
max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name);
}
}
}

+ 16
- 4
src/TensorFlowNET.Core/Operations/tensor_array_ops.cs View File

@@ -13,11 +13,23 @@ namespace Tensorflow
/// <returns></returns>
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
{
var new_ta = tf.TensorArray(
dtype: old_ta.dtype,
infer_shape: old_ta.infer_shape,
if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
{
throw new NotImplementedException("Attempting to build a graph-mode TF2-style "
+ "TensorArray from either an eager-mode "
+ "TensorArray or a TF1-style TensorArray. "
+ "This is not currently supported. You may be "
+ "attempting to capture a TensorArray "
+ "inside a tf.function or tf.data map function. "
+ "Instead, construct a new TensorArray inside "
+ "the function.");
}
var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape,
colocate_with_first_write_call: old_ta.colocate_with_first_write_call);

new_ta._dynamic_size = old_ta._dynamic_size;
new_ta._size = old_ta._size;
new_ta._colocate_with = old_ta._colocate_with;
new_ta._element_shape = old_ta._element_shape;
return new_ta;
}



+ 401
- 0
src/TensorFlowNET.Core/Operations/while_v2.cs View File

@@ -0,0 +1,401 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
class _OperationWithOutputs : Operation
{
public _OperationWithOutputs(IntPtr handle, Graph g = null)
{
_handle = handle;
_graph = g;
_outputs = null;
g._add_op(this);
}
}
internal class while_v2
{
public static Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int maximum_iterations = -1,
int parallel_iterations = 10,
string name = null,
bool back_prop = true,
bool return_same_structure = true)
{
var orig_loop_vars = loop_vars;
var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray();
int len_orig_loop_vars = orig_loop_vars.Length;

loop_vars = _tensor_array_to_flow(loop_vars);
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors();

var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars));

var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray();

if(string.IsNullOrEmpty(name))
{
name = "while";
}

return tf_with<ITensorFlowObject, Tensor[]>(ops.name_scope(name), nameScopeWhile =>
{
string scope = (nameScopeWhile as ops.NameScope).scope_name;
string cond_name = control_flow_util.unique_fn_name(scope, "cond");
string body_name = control_flow_util.unique_fn_name(scope, "body");

var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations);
var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype,
name: "loop_counter");
loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray();

var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)}
.Concat(loop_vars_signature.Flatten()).ToArray();

// TODO(Rinne): possible wrong implemenation here.
var add_control_dependencies = false;

object[] wrapped_cond(object[] inputs)
{
Tensor loop_counter = (Tensor)inputs[0];
Tensor maximum_iterations_arg = (Tensor)inputs[1];
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();
var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
if(pred.shape.IsNull || pred.shape.ndim > 0)
{
pred = array_ops.squeeze(pred);
}
if(maximum_iterations == -1)
{
return new object[] { pred };
}
else
{
return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) };
}
}

var cond_graph = FuncGraph.func_graph_from_func(cond_name, wrapped_cond, null,
null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies);

bool stateful_parallelism = false;

object[] wrapped_body(object[] inputs)
{
Tensor loop_counter = (Tensor)inputs[0];
Tensor maximum_iterations_arg = (Tensor)inputs[1];
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();

_copy_handle_data(loop_vars.Flatten().Skip(2), args);

foreach(var t in cond_graph.external_captures)
{
var graph = (FuncGraph)(ops.get_default_graph());
graph.capture(t);
}

var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
outputs = _tensor_array_to_flow(outputs);

return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray();
}

var body_graph = FuncGraph.func_graph_from_func(body_name, wrapped_body, null, null, func_graph_signature,
add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism);

// TODO(Rinne): possible wrong implementation here.
NestList<Tensors> loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() });
body_graph.Outputs.AddRange(body_graph.internal_captures);
cond_graph.as_default();
int num_cond_captures = cond_graph.external_captures.Length;
Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray()));
_duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray());
cond_graph.Exit();

int first_loop_var_index = 2;

int num_flattened_oututs = orig_loop_vars.Length;
int num_original_outputs = body_graph.Outputs.Length;
if (back_prop && control_flow_util.output_all_intermediates())
{
var intermediate_tensors = _get_intermediates(body_graph);

foreach(var intermediate_tensor in intermediate_tensors)
{
var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations);
loop_vars_list.Values.Add(tensor_list);

cond_graph.as_default();
cond_graph.capture(tensor_list);
cond_graph.Exit();

body_graph.as_default();
var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor);
body_graph.Outputs.Add(appended_tensor_list);
body_graph.Exit();
}
}

List<Tensor> flattened_loop_vars = new();
foreach(var item in loop_vars_list.Values)
{
flattened_loop_vars.AddRange(item.Flatten());
}
// skip the check

// TODO(Rinne): deal with control dependencies
var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray();
var span = new Span<Shape>(output_shapes).Slice(first_loop_var_index, num_flattened_oututs);
for(int i = 0; i < span.Length; i++)
{
span[i] = flat_shape_invariants[i];
}

Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations,
(nameScopeWhile as ops.NameScope).scope_name, num_original_outputs, stateful_parallelism);

if (!ops.get_default_graph().building_function)
{
outputs = outputs.Select(t => array_ops.identity(t)).ToArray();
}

var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_oututs).ToArray();

if (!back_prop)
{
output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray();
}
outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars);

return outputs;
});
}

private static Tensors _tensor_array_to_flow(Tensors loop_vars)
{
if(loop_vars.NestType == NestType.Node)
{
if(loop_vars.NodeValue is FakeTensorByTensorArray fake)
{
return new Tensors(fake.TensorArray.flow);
}
else
{
return new Tensors(loop_vars.NodeValue!);
}
}
else if(loop_vars.NestType == NestType.List)
{
List<INestStructure<Tensor>> list = new();
foreach(var item in loop_vars.ListValue!)
{
if(item.NestType == NestType.Node)
{
var nested = item.AsNest();
if (nested.NodeValue is FakeTensorByTensorArray fake)
{
list.Add(new Nest<Tensor>(fake.TensorArray.flow));
}
else
{
list.Add(new Nest<Tensor>(nested.NodeValue!));
}
}
else
{
list.Add(new Nest<Tensor>(item.AsNest()));
}
}
return Tensors.FromNest(new Nest<Tensor>(list));
}
else
{
throw new NotImplementedException();
}
}

private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph,
Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism)
{
var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op);
var body_stateful_ops = body_graph.get_operations().Select(x => x.op);

bool is_stateful = cond_stateful_ops.Count() > 0 || body_stateful_ops.Count() > 0;

Tensor[] _make_op(Tensor[] inputs)
{
Tensor[] outputs;
if (is_stateful)
{
outputs = gen_functional_ops._while(
inputs,
control_flow_util.create_new_tf_function(cond_graph),
control_flow_util.create_new_tf_function(body_graph),
output_shapes,
parallel_iterations,
name
);
}
else
{
outputs = gen_functional_ops.stateless_while(
inputs,
control_flow_util.create_new_tf_function(cond_graph),
control_flow_util.create_new_tf_function(body_graph),
output_shapes,
parallel_iterations,
name
);
}
var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs);
_copy_handle_data(body_graph.Outputs, tensors);
_set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph});
while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs });
while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism });

cond_graph.outer_graph = ops.get_default_graph();
body_graph.outer_graph = ops.get_default_graph();
// TODO(Rinne): set the two graphs to while_op
return tensors;
}

return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars);
}

/// <summary>
/// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies.
/// </summary>
/// <param name="op"></param>
/// <param name="branch_graphs"></param>
private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs)
{
List<int> read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList();
foreach(var branch_graph in branch_graphs)
{
if (read_only_indices.Count == 0)
{
break;
}
var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph);
read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList();
}
AttrValue.Types.ListValue listValue = new();
listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x));
op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue()
{
List = listValue
});
}

private static Tensors _pack_sequence_as<T>(INestStructure<T> loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars)
{
var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item =>
{
var (flow, y) = item;
if (y is FakeTensorByTensorArray ta)
{
return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow));
}
else
{
return flow;
}
}).ToArray();
return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors();
}

private static Tensor[] _get_intermediates(FuncGraph func_graph)
{
List<Tensor> intermediates = new();
var reversed_captures = func_graph.captures.ToDictionary(x => x.Item2, x => x.Item1);

foreach(var op in func_graph.get_operations())
{
Debug.Assert(op is Operation);
var oper = (Operation)op;
if(oper.type == "Identity" || oper.type == "MutexLock")
{
continue;
}
foreach(var o in op.outputs)
{
if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !reversed_captures.ContainsKey(o))
{
intermediates.Add(o);
}
}
}
return intermediates.ToArray();
}

private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures)
{
var types = body_graph_captures.Select(t => t.dtype).ToList();
var c_graph = cond_graph.c_graph;
var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList();

var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList();

List<Tensor> tensors = new();
foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types))
{
var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph);
op._outputs = new Tensor[] { tensor };
tensors.Add(tensor);
}

var tuples = zip(body_graph_captures, tensors).ToList();
var keys = body_graph_captures.Select(t => t.Id).ToList();
cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2));
cond_graph.Inputs.AddRange(tensors);
}

private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype);
var op = c_api.TF_FinishOperation(desc, tf.Status);
tf.Status.Check(true);
var output = new TF_Output();
output.oper = op;
output.index = 0;
return output;
}

private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph)
{
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder");
}

private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype,
string name)
{
return ops.convert_to_tensor(value, dtype, name, false);
}

private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1)
{
return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations");
}

private static void _copy_handle_data(IEnumerable<Tensor> src_tensors, IEnumerable<Tensor> dst_tensors)
{
foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors))
{
handle_data_util.copy_handle_data(src_t, dst_t);
}
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using Tensorflow.Exceptions;
using Tensorflow.Util;
using static Tensorflow.c_api;

@@ -88,7 +89,7 @@ namespace Tensorflow
case TF_Code.TF_INVALID_ARGUMENT:
throw new InvalidArgumentError(message);
default:
throw new TensorflowException(message);
throw new NotOkStatusException(message);
}
}
}


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -111,7 +111,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>



+ 7
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -105,6 +105,13 @@ namespace Tensorflow
_id = ops.uid();
}

internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output)
{
Tensor ret = new Tensor(op, value_index, dtype);
ret._tf_output = tf_output;
return ret;
}

protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);


+ 24
- 0
src/TensorFlowNET.Core/Tensors/TensorArray.cs View File

@@ -14,7 +14,9 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Common.Types;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -44,5 +46,27 @@ namespace Tensorflow

public abstract Tensor stack(string name = null);
public abstract Tensor gather(Tensor indices, string name = null);

internal bool _dynamic_size;
internal Tensor _size;
internal List<Tensor> _colocate_with;
internal Shape _element_shape;

public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant))
{
return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
infer_shape, element_shape, colocate_with_first_write_call, name);
}
else
{
return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
infer_shape, element_shape, colocate_with_first_write_call, name);
}
}
}
}

+ 41
- 13
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -4,6 +4,8 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Operations;
using Tensorflow.Common.Extensions;

namespace Tensorflow
{
@@ -58,7 +60,7 @@ namespace Tensorflow
public Tensor this[params string[] slices]
=> this.First()[slices];

private Tensors(Nest<Tensor> nested) : base(nested)
internal Tensors(Nest<Tensor> nested) : base(nested)
{

}
@@ -68,9 +70,9 @@ namespace Tensorflow
}

public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
public Tensors(IList<Tensor> tensors) : base(tensors.Select(x => new Nest<Tensor>(x)))
{
}

public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
@@ -78,6 +80,32 @@ namespace Tensorflow
}

/// <summary>
/// Get the element in shallow level. For example, for ts = [1, [2, 3], 4],
/// common indexer has ts[1] = 2. Shallow indexer has ts[1] = [2, 3]
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public Tensors GetShallow(int index)
{
if(NestType == NestType.Node)
{
if(index > 0)
{
throw new IndexOutOfRangeException();
}
return this;
}
else if(NestType == NestType.List)
{
return ListValue![index].AsNest().ToTensors();
}
else
{
throw new NotImplementedException();
}
}

private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
{
if (tensors.Length == 0)
@@ -115,8 +143,8 @@ namespace Tensorflow
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
Value = null;
ListValue = new() { new Nest<Tensor>(NodeValue), new Nest<Tensor>(tensor) };
NodeValue = null;
}
else if(NestType == NestType.List)
{
@@ -125,7 +153,7 @@ namespace Tensorflow
else //Empty
{
NestType = NestType.Node;
Value = tensor;
NodeValue = tensor;
}
}

@@ -140,9 +168,9 @@ namespace Tensorflow
else if (NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue = new() { new Nest<Tensor>(NodeValue) };
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
Value = null;
NodeValue = null;
}
else if(NestType == NestType.List)
{
@@ -151,7 +179,7 @@ namespace Tensorflow
else // empty
{
NestType = NestType.List;
ListValue = tensors.Select(x => new Nest<Tensor>(x)).ToList();
ListValue = tensors.Select(x => new Nest<Tensor>(x) as INestStructure<Tensor>).ToList();
}
}

@@ -166,9 +194,9 @@ namespace Tensorflow
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue = new() { new Nest<Tensor>(NodeValue) };
ListValue.Insert(index, new Nest<Tensor>(tensor));
Value = null;
NodeValue = null;
}
else
{
@@ -283,7 +311,7 @@ namespace Tensorflow
=> tensors?.SingleOrNull;

public static implicit operator Tensor[](Tensors tensors)
=> tensors.Flatten().ToArray();
=> tensors.Flatten().ToArray();
#endregion

public static Tensors? FromNest(Nest<Tensor> nested)
@@ -298,7 +326,7 @@ namespace Tensorflow
public void Deconstruct(out Tensor a, out Tensors? b)
{
a = this.First();
b = Length == 1? null : new Tensors(this.Skip(1));
b = Length == 1? null : new Tensors(this.Skip(1).ToArray());
}

public override string ToString()


+ 1
- 2
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -179,8 +179,7 @@ namespace Tensorflow.Train
// handles slot variables.
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
{
var temp = new_variable as Trackable;
var res = _track_trackable(temp, args.Name, args.Overwrite);
var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite);
Debug.Assert(res is IVariableV1);
return res as IVariableV1;
}


+ 20
- 3
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -170,11 +170,28 @@ namespace Tensorflow
public Tensor value()
=> GraphElement ?? _read_variable_op();

protected Tensor _read_variable_op()
protected Tensor _read_variable_op(bool no_copy = false)
{
variable_accessed(this);
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);

Tensor read_and_set_handle(bool no_copy)
{
if (no_copy)
{
gen_resource_variable_ops.disable_copy_on_read(handle);
}
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);
return result;
}

// TODO(Rinne): deal with caching device.
var result = read_and_set_handle(no_copy);
if (!tf.Context.executing_eagerly())
{
tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle },
backward_function: (x, _) => x);
}

// have to set shape when converting to substituent placeholder
if (result.shape.ndim == -1)


+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -576,7 +576,7 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data));
}

public static void dismantle_graph(Graph graph)


+ 48
- 47
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -25,6 +25,7 @@ using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;
using Tensorflow.Util;
using Tensorflow.Common.Types;
using System.Diagnostics;

namespace Tensorflow.Keras
{
@@ -485,7 +486,7 @@ namespace Tensorflow.Keras
var first_flatted_input = flatted_inptus[0];
var time_steps = first_flatted_input.shape[0];
var batch = first_flatted_input.shape[1];
var time_steps_t = (int)first_flatted_input.shape[0];
var time_steps_t = tf.shape(first_flatted_input)[0];

foreach (var input_ in flatted_inptus)
{
@@ -704,7 +705,7 @@ namespace Tensorflow.Keras
var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t));
input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t));
}

foreach(var (ta, input_) in zip(input_ta, flatted_inptus))
@@ -730,18 +731,15 @@ namespace Tensorflow.Keras
(output_time_zero, _) = step_function(input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

int output_ta_size = return_all_outputs ? time_steps_t : 1;
Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1);
var output_ta = new List<TensorArray>();
for (int i = 0; i < output_time_zero.ToList().Count; i++)
foreach(var output in output_time_zero.Flatten())
{
var Out = output_time_zero.ToList()[i];
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape));
}

var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



Func<Tensor, Tensor>? masking_fn;
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
if (mask != null)
@@ -750,7 +748,7 @@ namespace Tensorflow.Keras
{
mask = tf.reverse(mask, axis: new[] { 0 });
}
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
mask_ta = mask_ta.unstack(mask);

masking_fn = (time) =>
@@ -810,9 +808,9 @@ namespace Tensorflow.Keras
masking_fn = null;
}

Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
Func<Tensors, Tensor> cond = (time) => (time[0] < time_steps_t);
int parallel_iterations = 32;
new_states = states;
Tensors final_outputs;
if (masking_fn != null)
{
// Mask for the T output will be base on the output of T - 1. In the
@@ -825,7 +823,7 @@ namespace Tensorflow.Keras

var prev_output = flat_zero_output;
var output_ta_t = output_ta;
Tensor _step(Tensor time)
Tensors _step(Tensors tensors)
{
/*
RNN step function.
@@ -838,23 +836,28 @@ namespace Tensorflow.Keras
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
*/

Tensor time = tensors[0];
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray;
Tensors prev_output = tensors.GetShallow(2);
Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray());

var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var mask_t = masking_fn(time);
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
var (output, new_states) = step_function(current_input, states.MergeWith(constants));
// mask output
var flat_output = Nest.Flatten(output).ToList();

var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList();

// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// mask states
var flat_state = states.ToList();
var flat_new_state = new_states_internal.ToList();
var flat_state = states.Flatten().ToList();
var flat_new_state = new_states.Flatten().ToList();

foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
@@ -865,38 +868,37 @@ namespace Tensorflow.Keras
}

var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();
new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors();

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
{
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();
Debug.Assert(flat_output.Count() == 1);
output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First());


new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();

output_ta = output_ta_t;
new_states = new_states_internal;
return time + 1;
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states)
.ToArray().ToTensors();

}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }
.Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors();
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations);
new_states = final_outputs.Skip(3).ToList();
}
else
{
var output_ta_t = output_ta;
new_states = states;
Tensor _step(Tensor time)
Tensors _step(Tensors tensors)
{
Tensor time = tensors[0];
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray;
Tensors states = new Tensors(tensors.Skip(2).ToArray());
var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var (output, new_states_internal) = step_function(current_input, new_states.MergeWith(constants));
var (output, new_states) = step_function(current_input, states.MergeWith(constants));
var flat_state = new_states.Flatten().ToList();
var flat_new_state = new_states_internal.Flatten().ToList();
var flat_new_state = new_states.Flatten().ToList();
foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
@@ -906,24 +908,23 @@ namespace Tensorflow.Keras
}
var flat_output = Nest.Flatten(output);
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
output_ta_t = zip(output_ta_t, flat_output).Select(item =>
{
var (ta, out_) = item;
return ta.write(ta_index_to_write, out_);
}).ToList();

new_states_internal = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
output_ta = output_ta_t;
new_states = new_states_internal;
return time + 1;
Debug.Assert(flat_output.Count() == 1);
output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First());

new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors();
}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
Debug.Assert(output_ta.Count == 1);
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors();
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations);
new_states = final_outputs.Skip(2).ToList();
}
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();

output_ta = new List<TensorArray> { (final_outputs[1] as FakeTensorByTensorArray).TensorArray };
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors();
}

Func<Tensor, Tensor> set_shape;


+ 2
- 0
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -38,6 +38,8 @@ namespace Tensorflow.Keras.Engine
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);

// TODO(Rinne): set save spec if null

scope.__exit__();

return outputs;


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Build.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
graph.as_default();
var shapes = input_shape.ToShapeArray();
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)));
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray());
try
{
Call(x, training: false);


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Engine
{
var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(x),
X = new Tensors(x.ToArray()),
Y = y,
Model = this,
StepsPerExecution = _steps_per_execution
@@ -188,7 +188,7 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Engine

var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(train_x),
X = new Tensors(train_x.ToArray()),
Y = train_y,
BatchSize = batch_size,
InitialEpoch = initial_epoch,


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Train.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}


+ 4
- 0
src/TensorFlowNET.Keras/IsExternalInit.cs View File

@@ -0,0 +1,4 @@
namespace System.Runtime.CompilerServices
{
internal static class IsExternalInit { }
}

+ 33
- 9
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -702,16 +702,14 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
Dropout = dropout,
RecurrentDropout = recurrent_dropout
});

public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells)
=> new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = cells.ToList()
});
=> new StackedRNNCells(cells.ToList(), new StackedRNNCellsArgs());

/// <summary>
///
@@ -756,9 +754,8 @@ namespace Tensorflow.Keras.Layers
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
=> new RNN(cell, new RNNArgs
{
Cell = cell,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
@@ -775,9 +772,8 @@ namespace Tensorflow.Keras.Layers
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
=> new RNN(cell, new RNNArgs
{
Cells = cell.ToList(),
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
@@ -786,6 +782,33 @@ namespace Tensorflow.Keras.Layers
TimeMajor = time_major
});


public IRnnCell LSTMCell(int uints,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal", // TODO(Wanglongzhi2001),glorot_uniform has not been developed.
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2)
=> new LSTMCell(new LSTMCellArgs
{
Units = uints,
Activation = keras.activations.GetActivationFromName(activation),
RecurrentActivation = keras.activations.GetActivationFromName(recurrent_activation),
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
UnitForgetBias = unit_forget_bias,
Dropout = dropout,
RecurrentDropout = recurrent_dropout,
Implementation = implementation
});

/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.
/// </summary>
@@ -846,7 +869,8 @@ namespace Tensorflow.Keras.Layers
GoBackwards = go_backwards,
Stateful = stateful,
TimeMajor = time_major,
Unroll = unroll
Unroll = unroll,
UnitForgetBias = unit_forget_bias
});

/// <summary>


+ 12
- 3
src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs View File

@@ -4,10 +4,11 @@ using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
public abstract class DropoutRNNCellMixin: RnnCellBase
public abstract class DropoutRNNCellMixin: Layer, IRnnCell
{
public float dropout;
public float recurrent_dropout;
@@ -17,6 +18,14 @@ namespace Tensorflow.Keras.Layers.Rnn

}

public abstract INestStructure<long> StateSize { get; }
public abstract INestStructure<long> OutputSize { get; }
public abstract bool SupportOptionalArgs { get; }
public virtual Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
}

protected void _create_non_trackable_mask_cache()
{
@@ -32,7 +41,7 @@ namespace Tensorflow.Keras.Layers.Rnn

}

public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
{
if (dropout == 0f)
return null;
@@ -44,7 +53,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}

// Get the recurrent dropout mask for RNN cell.
public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
public Tensors? get_recurrent_dropout_mask_for_cell(Tensors input, bool training, int count = 1)
{
if (dropout == 0f)
return null;


+ 93
- 9
src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs View File

@@ -2,6 +2,7 @@
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -14,22 +15,105 @@ namespace Tensorflow.Keras.Layers.Rnn
public class LSTM : RNN
{
LSTMArgs args;
InputSpec[] state_spec;
int units => args.Units;
InputSpec[] _state_spec;
InputSpec _input_spec;
bool _could_use_gpu_kernel;

public LSTM(LSTMArgs args) :
base(args)
base(CreateCell(args), args)
{
this.args = args;
state_spec = new[] { units, units }
.Select(dim => new InputSpec(shape: (-1, dim)))
.ToArray();
_input_spec = new InputSpec(ndim: 3);
_state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray();
_could_use_gpu_kernel = args.Activation == keras.activations.Tanh
&& args.RecurrentActivation == keras.activations.Sigmoid
&& args.RecurrentDropout == 0 && !args.Unroll && args.UseBias
&& ops.executing_eagerly_outside_functions();
}

private static IRnnCell CreateCell(LSTMArgs lstmArgs)
{
return new LSTMCell(new LSTMCellArgs()
{
Units = lstmArgs.Units,
Activation = lstmArgs.Activation,
RecurrentActivation = lstmArgs.RecurrentActivation,
UseBias = lstmArgs.UseBias,
KernelInitializer = lstmArgs.KernelInitializer,
RecurrentInitializer = lstmArgs.RecurrentInitializer,
UnitForgetBias = lstmArgs.UnitForgetBias,
BiasInitializer = lstmArgs.BiasInitializer,
// TODO(Rinne): kernel_regularizer
// TODO(Rinne): recurrent_regularizer
// TODO(Rinne): bias_regularizer
// TODO(Rinne): kernel_constriant
// TODO(Rinne): recurrent_constriant
// TODO(Rinne): bias_constriant
Dropout = lstmArgs.Dropout,
RecurrentDropout = lstmArgs.RecurrentDropout,
Implementation = lstmArgs.Implementation,
DType = lstmArgs.DType,
Trainable = lstmArgs.Trainable
});
}

protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return base.Call(inputs, initial_state: state, training: training);
// skip the condition of ragged input

(inputs, initial_state, _) = _process_inputs(inputs, initial_state, null);

Tensor mask = null;
if(optional_args is RnnOptionalArgs rnnArgs)
{
mask = rnnArgs.Mask;
}

var single_input = inputs.Single;
var input_shape = single_input.shape;
var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];

_maybe_reset_cell_dropout_mask(Cell);

Func<Tensors, Tensors, (Tensors, Tensors)> step = (inputs, states) =>
{
var res = Cell.Apply(inputs, states, training is null ? true : training.Value);
var (output, state) = res;
return (output, state);
};

var (last_output, outputs, states) = keras.backend.rnn(
step,
inputs,
initial_state,
constants: null,
go_backwards: args.GoBackwards,
mask: mask,
unroll: args.Unroll,
input_length: ops.convert_to_tensor(timesteps),
time_major: args.TimeMajor,
zero_output_for_mask: args.ZeroOutputForMask,
return_all_outputs: args.ReturnSequences
);

Tensor output;
if (args.ReturnSequences)
{
output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards);
}
else
{
output = last_output;
}

if (args.ReturnState)
{
return new Tensor[] { output }.Concat(states).ToArray().ToTensors();
}
else
{
return output;
}
}
}
}

+ 221
- 4
src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs View File

@@ -1,16 +1,233 @@
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Newtonsoft.Json;
using Serilog.Core;
using System.Diagnostics;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
public class LSTMCell : Layer
/// <summary>
/// Cell class for the LSTM layer.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// This class processes one step within the whole time sequence input, whereas
/// `tf.keras.layer.LSTM` processes the whole sequence.
/// </summary>
public class LSTMCell : DropoutRNNCellMixin
{
LSTMCellArgs args;
LSTMCellArgs _args;
IVariableV1 _kernel;
IVariableV1 _recurrent_kernel;
IInitializer _bias_initializer;
IVariableV1 _bias;
INestStructure<long> _state_size;
INestStructure<long> _output_size;
public override INestStructure<long> StateSize => _state_size;

public override INestStructure<long> OutputSize => _output_size;

public override bool SupportOptionalArgs => false;
public LSTMCell(LSTMCellArgs args)
: base(args)
{
this.args = args;
_args = args;
if (args.Units <= 0)
{
throw new ValueError(
$"units must be a positive integer, got {args.Units}");
}
_args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
_args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
if (_args.RecurrentDropout != 0f && _args.Implementation != 1)
{
Debug.WriteLine("RNN `implementation=2` is not supported when `recurrent_dropout` is set." +
"Using `implementation=1`.");
_args.Implementation = 1;
}

_state_size = new NestList<long>(_args.Units, _args.Units);
_output_size = new NestNode<long>(_args.Units);
}

public override void build(KerasShapesWrapper input_shape)
{
base.build(input_shape);
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];
_kernel = add_weight("kernel", (input_dim, _args.Units * 4),
initializer: _args.KernelInitializer
);

_recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units * 4),
initializer: _args.RecurrentInitializer
);

if (_args.UseBias)
{
if (_args.UnitForgetBias)
{
Tensor bias_initializer()
{
return keras.backend.concatenate(
new Tensors(
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units))),
tf.ones_initializer.Apply(new InitializerArgs(shape: (_args.Units))),
_args.BiasInitializer.Apply(new InitializerArgs(shape: (_args.Units)))), axis: 0);
}
}
else
{
_bias_initializer = _args.BiasInitializer;
}
_bias = add_weight("bias", (_args.Units * 4),
initializer: _bias_initializer
);
}
built = true;
}
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var h_tm1 = states[0]; // previous memory state
var c_tm1 = states[1]; // previous carry state

var dp_mask = get_dropout_mask_for_cell(inputs, training.Value, count: 4);
var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
h_tm1, training.Value, count: 4);

Tensor c;
Tensor o;
if (_args.Implementation == 1)
{
Tensor inputs_i;
Tensor inputs_f;
Tensor inputs_c;
Tensor inputs_o;
if (0f < _args.Dropout && _args.Dropout < 1f)
{
inputs_i = inputs * dp_mask[0];
inputs_f = inputs * dp_mask[1];
inputs_c = inputs * dp_mask[2];
inputs_o = inputs * dp_mask[3];
}
else
{
inputs_i = inputs;
inputs_f = inputs;
inputs_c = inputs;
inputs_o = inputs;
}
var k = tf.split(_kernel.AsTensor(), num_split: 4, axis: 1);
Tensor k_i = k[0], k_f = k[1], k_c = k[2], k_o = k[3];
var x_i = math_ops.matmul(inputs_i, k_i);
var x_f = math_ops.matmul(inputs_f, k_f);
var x_c = math_ops.matmul(inputs_c, k_c);
var x_o = math_ops.matmul(inputs_o, k_o);
if (_args.UseBias)
{
var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
x_i = gen_nn_ops.bias_add(x_i, b_i);
x_f = gen_nn_ops.bias_add(x_f, b_f);
x_c = gen_nn_ops.bias_add(x_c, b_c);
x_o = gen_nn_ops.bias_add(x_o, b_o);
}

Tensor h_tm1_i;
Tensor h_tm1_f;
Tensor h_tm1_c;
Tensor h_tm1_o;
if (0f < _args.RecurrentDropout && _args.RecurrentDropout < 1f)
{
h_tm1_i = h_tm1 * rec_dp_mask[0];
h_tm1_f = h_tm1 * rec_dp_mask[1];
h_tm1_c = h_tm1 * rec_dp_mask[2];
h_tm1_o = h_tm1 * rec_dp_mask[3];
}
else
{
h_tm1_i = h_tm1;
h_tm1_f = h_tm1;
h_tm1_c = h_tm1;
h_tm1_o = h_tm1;
}
var x = new Tensor[] { x_i, x_f, x_c, x_o };
var h_tm1_array = new Tensor[] { h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o };
(c, o) = _compute_carry_and_output(x, h_tm1_array, c_tm1);
}
else
{
if (0f < _args.Dropout && _args.Dropout < 1f)
inputs = inputs * dp_mask[0];
var z = math_ops.matmul(inputs, _kernel.AsTensor());
z += math_ops.matmul(h_tm1, _recurrent_kernel.AsTensor());
if (_args.UseBias)
{
z = tf.nn.bias_add(z, _bias);
}
var z_array = tf.split(z, num_split: 4, axis: 1);
(c, o) = _compute_carry_and_output_fused(z_array, c_tm1);
}
var h = o * _args.Activation.Apply(c);
// 这里是因为 Tensors 类初始化的时候会把第一个元素之后的元素打包成一个数组
return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new NestList<Tensor>(h, c) }).ToTensors();
}

/// <summary>
/// Computes carry and output using split kernels.
/// </summary>
/// <param name="x"></param>
/// <param name="h_tm1"></param>
/// <param name="c_tm1"></param>
/// <returns></returns>
/// <exception cref="NotImplementedException"></exception>
public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm1)
{
Tensor x_i = x[0], x_f = x[1], x_c = x[2], x_o = x[3];
Tensor h_tm1_i = h_tm1[0], h_tm1_f = h_tm1[1], h_tm1_c = h_tm1[2],
h_tm1_o = h_tm1[3];

var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
int startIndex = (int)_recurrent_kernel_tensor.shape[0];
var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, 0 }, new[] { startIndex, _args.Units });
var i = _args.RecurrentActivation.Apply(
x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units }, new[] { startIndex, _args.Units});
var f = _args.RecurrentActivation.Apply(
x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units });
var c = f * c_tm1 + i * _args.Activation.Apply(
x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
var o = _args.Activation.Apply(
x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));

return new Tensors(c, o);
}

/// <summary>
/// Computes carry and output using fused kernels.
/// </summary>
/// <param name="z"></param>
/// <param name="c_tm1"></param>
/// <returns></returns>
public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
{
Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
var i = _args.RecurrentActivation.Apply(z0);
var f = _args.RecurrentActivation.Apply(z1);
var c = f * c_tm1 + i * _args.Activation.Apply(z2);
var o = _args.RecurrentActivation.Apply(z3);
return new Tensors(c, o);
}
}

}

+ 84
- 97
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -11,6 +11,7 @@ using Tensorflow.Common.Extensions;
using System.Linq.Expressions;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;
using System.Runtime.CompilerServices;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers.Rnn
@@ -30,25 +31,39 @@ namespace Tensorflow.Keras.Layers.Rnn
private int _num_constants;
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
protected IRnnCell _cell;

public RNN(RNNArgs args) : base(PreConstruct(args))
private IRnnCell _cell;
protected IRnnCell Cell
{
_args = args;
SupportsMasking = true;

// if is StackedRnncell
if (args.Cells != null)
get
{
_cell = new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = args.Cells
});
return _cell;
}
else
init
{
_cell = args.Cell;
_cell = value;
_self_tracked_trackables.Add(_cell);
}
}

public RNN(IRnnCell cell, RNNArgs args) : base(PreConstruct(args))
{
_args = args;
SupportsMasking = true;

Cell = cell;

// get input_shape
_args = PreConstruct(args);

_num_constants = 0;
}

public RNN(IEnumerable<IRnnCell> cells, RNNArgs args) : base(PreConstruct(args))
{
_args = args;
SupportsMasking = true;

Cell = new StackedRNNCells(cells, new StackedRNNCellsArgs());

// get input_shape
_args = PreConstruct(args);
@@ -65,7 +80,7 @@ namespace Tensorflow.Keras.Layers.Rnn
if (_states == null)
{
// CHECK(Rinne): check if this is correct.
var nested = _cell.StateSize.MapStructure<Tensor?>(x => null);
var nested = Cell.StateSize.MapStructure<Tensor?>(x => null);
_states = nested.AsNest().ToTensors();
}
return _states;
@@ -73,7 +88,7 @@ namespace Tensorflow.Keras.Layers.Rnn
set { _states = value; }
}

private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
private INestStructure<Shape> compute_output_shape(Shape input_shape)
{
var batch = input_shape[0];
var time_step = input_shape[1];
@@ -83,13 +98,15 @@ namespace Tensorflow.Keras.Layers.Rnn
}

// state_size is a array of ints or a positive integer
var state_size = _cell.StateSize.ToSingleShape();
var state_size = Cell.StateSize;
if(state_size?.TotalNestedCount == 1)
{
state_size = new NestList<long>(state_size.Flatten().First());
}

// TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
Func<Shape, Shape> _get_output_shape;
_get_output_shape = (flat_output_size) =>
Func<long, Shape> _get_output_shape = (flat_output_size) =>
{
var output_dim = flat_output_size.as_int_list();
var output_dim = new Shape(flat_output_size).as_int_list();
Shape output_shape;
if (_args.ReturnSequences)
{
@@ -110,33 +127,30 @@ namespace Tensorflow.Keras.Layers.Rnn
return output_shape;
};

Type type = _cell.GetType();
Type type = Cell.GetType();
PropertyInfo output_size_info = type.GetProperty("output_size");
Shape output_shape;
INestStructure<Shape> output_shape;
if (output_size_info != null)
{
output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
// TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
output_shape = Nest.MapStructure(_get_output_shape, Cell.OutputSize);
}
else
{
output_shape = _get_output_shape(state_size);
output_shape = new NestNode<Shape>(_get_output_shape(state_size.Flatten().First()));
}

if (_args.ReturnState)
{
Func<Shape, Shape> _get_state_shape;
_get_state_shape = (flat_state) =>
Func<long, Shape> _get_state_shape = (flat_state) =>
{
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
var state_shape = new int[] { (int)batch }.concat(new Shape(flat_state).as_int_list());
return new Shape(state_shape);
};


var state_shape = _get_state_shape(state_size);
var state_shape = Nest.MapStructure(_get_state_shape, state_size);

return new List<Shape> { output_shape, state_shape };
return new Nest<Shape>(new[] { output_shape, state_shape } );
}
else
{
@@ -171,7 +185,9 @@ namespace Tensorflow.Keras.Layers.Rnn

public override void build(KerasShapesWrapper input_shape)
{
object get_input_spec(Shape shape)
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);

InputSpec get_input_spec(Shape shape)
{
var input_spec_shape = shape.as_int_list();

@@ -206,7 +222,6 @@ namespace Tensorflow.Keras.Layers.Rnn
// append bacth dim
state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
return new InputSpec(shape: state_spec_shape);

}

// Check whether the input shape contains any nested shapes. It could be
@@ -214,10 +229,13 @@ namespace Tensorflow.Keras.Layers.Rnn
// numpy inputs.


if (!_cell.Built)
if (Cell is Layer layer && !layer.Built)
{
_cell.build(input_shape);
layer.build(input_shape);
layer.Built = true;
}

this.built = true;
}

/// <summary>
@@ -248,10 +266,10 @@ namespace Tensorflow.Keras.Layers.Rnn

(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);

_maybe_reset_cell_dropout_mask(_cell);
if (_cell is StackedRNNCells)
_maybe_reset_cell_dropout_mask(Cell);
if (Cell is StackedRNNCells)
{
var stack_cell = _cell as StackedRNNCells;
var stack_cell = Cell as StackedRNNCells;
foreach (IRnnCell cell in stack_cell.Cells)
{
_maybe_reset_cell_dropout_mask(cell);
@@ -298,23 +316,23 @@ namespace Tensorflow.Keras.Layers.Rnn

// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
Func<Tensors, Tensors, (Tensors, Tensors)> step;
bool is_tf_rnn_cell = _cell.IsTFRnnCell;
bool is_tf_rnn_cell = false;
if (constants is not null)
{
if (!_cell.SupportOptionalArgs)
if (!Cell.SupportOptionalArgs)
{
throw new ValueError(
$"RNN cell {_cell} does not support constants." +
$"RNN cell {Cell} does not support constants." +
$"Received: constants={constants}");
}

step = (inputs, states) =>
{
constants = new Tensors(states.TakeLast(_num_constants));
states = new Tensors(states.SkipLast(_num_constants));
constants = new Tensors(states.TakeLast(_num_constants).ToArray());
states = new Tensors(states.SkipLast(_num_constants).ToArray());
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
return (output, new_states.Single);
var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
return (output, new_states);
};
}
else
@@ -322,7 +340,7 @@ namespace Tensorflow.Keras.Layers.Rnn
step = (inputs, states) =>
{
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
var (output, new_states) = _cell.Apply(inputs, states);
var (output, new_states) = Cell.Apply(inputs, states);
return (output, new_states);
};
}
@@ -366,6 +384,11 @@ namespace Tensorflow.Keras.Layers.Rnn
}
else
{
//var tapeSet = tf.GetTapeSet();
//foreach(var tape in tapeSet)
//{
// tape.Watch(output);
//}
return output;
}
}
@@ -389,18 +412,18 @@ namespace Tensorflow.Keras.Layers.Rnn
throw new NotImplementedException();
}

private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
protected (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
{
if (inputs.Length > 1)
{
if (_num_constants != 0)
{
initial_state = new Tensors(inputs.Skip(1));
initial_state = new Tensors(inputs.Skip(1).ToArray());
}
else
{
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants));
constants = new Tensors(inputs.TakeLast(_num_constants));
initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants).ToArray());
constants = new Tensors(inputs.TakeLast(_num_constants).ToArray());
}
if (len(initial_state) == 0)
initial_state = null;
@@ -418,7 +441,7 @@ namespace Tensorflow.Keras.Layers.Rnn
tmp.add(tf.math.count_nonzero(s.Single()));
}
var non_zero_count = tf.add_n(tmp);
//initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
initial_state = tf.cond(non_zero_count > 0, States, initial_state);
if ((int)non_zero_count.numpy() > 0)
{
initial_state = States;
@@ -428,16 +451,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
initial_state = States;
}
// TODO(Wanglongzhi2001),
// initial_state = tf.nest.map_structure(
//# When the layer has a inferred dtype, use the dtype from the
//# cell.
// lambda v: tf.cast(
// v, self.compute_dtype or self.cell.compute_dtype
// ),
// initial_state,
// )

//initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state);
}
else if (initial_state is null)
{
@@ -477,7 +491,7 @@ namespace Tensorflow.Keras.Layers.Rnn

}

void _maybe_reset_cell_dropout_mask(ILayer cell)
protected void _maybe_reset_cell_dropout_mask(ILayer cell)
{
if (cell is DropoutRNNCellMixin CellDRCMixin)
{
@@ -488,26 +502,21 @@ namespace Tensorflow.Keras.Layers.Rnn

private static RNNArgs PreConstruct(RNNArgs args)
{
if (args.Kwargs == null)
{
args.Kwargs = new Dictionary<string, object>();
}

// If true, the output for masked timestep will be zeros, whereas in the
// false case, output from previous timestep is returned for masked timestep.
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
var zeroOutputForMask = args.ZeroOutputForMask;

Shape input_shape;
var propIS = (Shape)args.Kwargs.Get("input_shape", null);
var propID = (int?)args.Kwargs.Get("input_dim", null);
var propIL = (int?)args.Kwargs.Get("input_length", null);
var propIS = args.InputShape;
var propID = args.InputDim;
var propIL = args.InputLength;

if (propIS == null && (propID != null || propIL != null))
{
input_shape = new Shape(
propIL ?? -1,
propID ?? -1);
args.Kwargs["input_shape"] = input_shape;
args.InputShape = input_shape;
}

return args;
@@ -558,36 +567,14 @@ namespace Tensorflow.Keras.Layers.Rnn

protected Tensors get_initial_state(Tensors inputs)
{
var get_initial_state_fn = _cell.GetType().GetMethod("get_initial_state");

var input = inputs[0];
var input_shape = inputs.shape;
var input_shape = array_ops.shape(inputs);
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;

Tensors init_state = new Tensors();

if(get_initial_state_fn != null)
{
init_state = (Tensors)get_initial_state_fn.Invoke(_cell, new object[] { inputs, batch_size, dtype });
}
//if (_cell is RnnCellBase rnn_base_cell)
//{
// init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
//}
else
{
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);
}
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);

return init_state;
}

// Check whether the state_size contains multiple states.
public static bool is_multiple_state(GeneralizedTensorShape state_size)
{
return state_size.Shapes.Length > 1;
}
}
}

+ 0
- 24
src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs View File

@@ -1,24 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
public abstract class RnnCellBase: Layer, IRnnCell
{
public RnnCellBase(LayerArgs args) : base(args) { }
public abstract GeneralizedTensorShape StateSize { get; }
public abstract GeneralizedTensorShape OutputSize { get; }
public abstract bool IsTFRnnCell { get; }
public abstract bool SupportOptionalArgs { get; }
public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
}
}
}

+ 3
- 18
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -10,14 +10,14 @@ namespace Tensorflow.Keras.Layers.Rnn
public class SimpleRNN : RNN
{
SimpleRNNArgs args;
public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args))
public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args), args)
{
this.args = args;
}

private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args)
private static SimpleRNNCell CreateCellForArgs(SimpleRNNArgs args)
{
args.Cell = new SimpleRNNCell(new SimpleRNNCellArgs()
return new SimpleRNNCell(new SimpleRNNCellArgs()
{
Units = args.Units,
Activation = args.Activation,
@@ -30,21 +30,6 @@ namespace Tensorflow.Keras.Layers.Rnn
DType = args.DType,
Trainable = args.Trainable,
});
return args;
}

public override void build(KerasShapesWrapper input_shape)
{
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];
_buildInputShape = input_shape;

_kernel = add_weight("kernel", (single_shape[-1], args.Units),
initializer: args.KernelInitializer
//regularizer = self.kernel_regularizer,
//constraint = self.kernel_constraint,
//caching_device = default_caching_device,
);
}
}
}

+ 9
- 15
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -7,6 +7,7 @@ using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Utils;
using Tensorflow.Graphs;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -23,12 +24,11 @@ namespace Tensorflow.Keras.Layers.Rnn
IVariableV1 _kernel;
IVariableV1 _recurrent_kernel;
IVariableV1 _bias;
GeneralizedTensorShape _state_size;
GeneralizedTensorShape _output_size;
INestStructure<long> _state_size;
INestStructure<long> _output_size;

public override GeneralizedTensorShape StateSize => _state_size;
public override GeneralizedTensorShape OutputSize => _output_size;
public override bool IsTFRnnCell => true;
public override INestStructure<long> StateSize => _state_size;
public override INestStructure<long> OutputSize => _output_size;
public override bool SupportOptionalArgs => false;

public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
@@ -41,8 +41,8 @@ namespace Tensorflow.Keras.Layers.Rnn
}
this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
_state_size = new GeneralizedTensorShape(args.Units);
_output_size = new GeneralizedTensorShape(args.Units);
_state_size = new NestNode<long>(args.Units);
_output_size = new NestNode<long>(args.Units);
}

public override void build(KerasShapesWrapper input_shape)
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.Layers.Rnn
{
// TODO(Rinne): check if it will have multiple tensors when not nested.
Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value);
var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);
var dp_mask = get_dropout_mask_for_cell(inputs, training.Value);
var rec_dp_mask = get_recurrent_dropout_mask_for_cell(prev_output, training.Value);

Tensor h;
var ranks = inputs.rank;
@@ -98,7 +98,6 @@ namespace Tensorflow.Keras.Layers.Rnn
{
prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}
var tmp = _recurrent_kernel.AsTensor();
Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());

if (_args.Activation != null)
@@ -116,10 +115,5 @@ namespace Tensorflow.Keras.Layers.Rnn
return new Tensors(output, output);
}
}

public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size.Value, dtype.Value);
}
}
}

+ 58
- 110
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -1,10 +1,8 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -15,30 +13,15 @@ namespace Tensorflow.Keras.Layers.Rnn
public class StackedRNNCells : Layer, IRnnCell
{
public IList<IRnnCell> Cells { get; set; }
public bool reverse_state_order;
public bool _reverse_state_order;

public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
public StackedRNNCells(IEnumerable<IRnnCell> cells, StackedRNNCellsArgs args) : base(args)
{
if (args.Kwargs == null)
{
args.Kwargs = new Dictionary<string, object>();
}
foreach (var cell in args.Cells)
{
//Type type = cell.GetType();
//var CallMethodInfo = type.GetMethod("Call");
//if (CallMethodInfo == null)
//{
// throw new ValueError(
// "All cells must have a `Call` method. " +
// $"Received cell without a `Call` method: {cell}");
//}
}
Cells = args.Cells;
reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false);
Cells = cells.ToList();

if (reverse_state_order)
_reverse_state_order = args.ReverseStateOrder;

if (_reverse_state_order)
{
throw new WarningException("reverse_state_order=True in StackedRNNCells will soon " +
"be deprecated. Please update the code to work with the " +
@@ -47,49 +30,37 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

public GeneralizedTensorShape StateSize
public bool SupportOptionalArgs => false;

public INestStructure<long> StateSize
{
get
{
GeneralizedTensorShape state_size = new GeneralizedTensorShape(1, Cells.Count);
if (reverse_state_order && Cells.Count > 0)
if (_reverse_state_order)
{
var idxAndCell = Cells.Reverse().Select((cell, idx) => (idx, cell));
foreach (var cell in idxAndCell)
{
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
}
var state_sizes = Cells.Reverse().Select(cell => cell.StateSize);
return new Nest<long>(state_sizes);
}
else
{
//foreach (var cell in Cells)
//{
// state_size.Shapes.add(cell.StateSize.Shapes.First());

//}
var idxAndCell = Cells.Select((cell, idx) => (idx, cell));
foreach (var cell in idxAndCell)
{
state_size.Shapes[cell.idx] = cell.cell.StateSize.Shapes.First();
}
var state_sizes = Cells.Select(cell => cell.StateSize);
return new Nest<long>(state_sizes);
}
return state_size;
}
}

public object output_size
public INestStructure<long> OutputSize
{
get
{
var lastCell = Cells.LastOrDefault();
if (lastCell.OutputSize.ToSingleShape() != -1)
var lastCell = Cells.Last();
if(lastCell.OutputSize is not null)
{
return lastCell.OutputSize;
}
else if (RNN.is_multiple_state(lastCell.StateSize))
else if (RnnUtils.is_multiple_state(lastCell.StateSize))
{
return lastCell.StateSize.First();
//throw new NotImplementedException("");
return new NestNode<long>(lastCell.StateSize.Flatten().First());
}
else
{
@@ -98,79 +69,65 @@ namespace Tensorflow.Keras.Layers.Rnn
}
}

public Tensors get_initial_state(Tensors inputs = null, long? batch_size = null, TF_DataType? dtype = null)
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
var cells = reverse_state_order ? Cells.Reverse() : Cells;
Tensors initial_states = new Tensors();
var cells = _reverse_state_order ? Cells.Reverse() : Cells;
List<Tensor> initial_states = new List<Tensor>();
foreach (var cell in cells)
{
var get_initial_state_fn = cell.GetType().GetMethod("get_initial_state");
if (get_initial_state_fn != null)
{
var result = (Tensors)get_initial_state_fn.Invoke(cell, new object[] { inputs, batch_size, dtype });
initial_states.Add(result);
}
else
{
initial_states.Add(RnnUtils.generate_zero_filled_state_for_cell(cell, inputs, batch_size.Value, dtype.Value));
}
initial_states.Add(cell.GetInitialState(inputs, batch_size, dtype));
}
return initial_states;
return new Tensors(initial_states);
}

protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
// Recover per-cell states.
var state_size = reverse_state_order ? StateSize.Reverse() : StateSize;
var nested_states = reverse_state_order ? state.Flatten().Reverse() : state.Flatten();

var state_size = _reverse_state_order ? new NestList<long>(StateSize.Flatten().Reverse()) : StateSize;
var nested_states = Nest.PackSequenceAs(state_size, Nest.Flatten(states).ToArray());

var new_nest_states = new Tensors();
var new_nest_states = Nest<Tensor>.Empty;
// Call the cells in order and store the returned states.
foreach (var (cell, states) in zip(Cells, nested_states))
foreach (var (cell, internal_states) in zip(Cells, nested_states))
{
// states = states if tf.nest.is_nested(states) else [states]
var type = cell.GetType();
bool IsTFRnnCell = type.GetProperty("IsTFRnnCell") != null;
state = len(state) == 1 && IsTFRnnCell ? state.FirstOrDefault() : state;

RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
Tensors? constants = rnn_optional_args?.Constants;

Tensors new_states;
(inputs, new_states) = cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
(inputs, new_states) = cell.Apply(inputs, internal_states, optional_args: new RnnOptionalArgs() { Constants = constants });

new_nest_states.Add(new_states);
new_nest_states = new_nest_states.MergeWith(new_states);
}
new_nest_states = reverse_state_order ? new_nest_states.Reverse().ToArray() : new_nest_states.ToArray();
return new Nest<Tensor>(new List<Nest<Tensor>> {
new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(inputs.Single()) }), new Nest<Tensor>(new_nest_states) })
.ToTensors();
return Tensors.FromNest((inputs, Nest.PackSequenceAs(state_size, Nest.Flatten(new_nest_states).ToArray())));
}

public void build()
public override void build(KerasShapesWrapper input_shape)
{
built = true;
// @tf_utils.shape_type_conversion
// def build(self, input_shape) :
// if isinstance(input_shape, list) :
// input_shape = input_shape[0]
// for cell in self.cells:
// if isinstance(cell, Layer) and not cell.built:
// with K.name_scope(cell.name):
// cell.build(input_shape)
// cell.built = True
// if getattr(cell, 'output_size', None) is not None:
// output_dim = cell.output_size
// elif _is_multiple_state(cell.state_size) :
// output_dim = cell.state_size[0]
// else:
// output_dim = cell.state_size
// input_shape = tuple([input_shape[0]] +
// tensor_shape.TensorShape(output_dim).as_list())
// self.built = True
var shape = input_shape.ToSingleShape();
foreach(var cell in Cells)
{
if(cell is Layer layer && !layer.Built)
{
// ignored the name scope.
layer.build(shape);
layer.Built = true;
}
INestStructure<long> output_dim;
if(cell.OutputSize is not null)
{
output_dim = cell.OutputSize;
}
else if (RnnUtils.is_multiple_state(cell.StateSize))
{
output_dim = new NestNode<long>(cell.StateSize.Flatten().First());
}
else
{
output_dim = cell.StateSize;
}
shape = new Shape(new long[] { shape.dims[0] }.Concat(output_dim.Flatten()).ToArray());
}
this.Built = true;
}

public override IKerasConfig get_config()
@@ -198,14 +155,5 @@ namespace Tensorflow.Keras.Layers.Rnn
// deserialize_layer(cell_config, custom_objects = custom_objects))
// return cls(cells, **config)
}

public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
{
throw new NotImplementedException();
}

public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => true;
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 26
- 16
src/TensorFlowNET.Keras/Utils/RnnUtils.cs View File

@@ -10,33 +10,33 @@ namespace Tensorflow.Keras.Utils
{
internal static class RnnUtils
{
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
internal static Tensors generate_zero_filled_state(Tensor batch_size_tensor, INestStructure<long> state_size, TF_DataType dtype)
{
Func<GeneralizedTensorShape, Tensor> create_zeros;
create_zeros = (GeneralizedTensorShape unnested_state_size) =>
Func<long, Tensor> create_zeros = (unnested_state_size) =>
{
var flat_dims = unnested_state_size.ToSingleShape().dims;
var init_state_size = new long[] { batch_size_tensor }.Concat(flat_dims).ToArray();
return array_ops.zeros(new Shape(init_state_size), dtype: dtype);
var flat_dims = new Shape(unnested_state_size).dims;
var init_state_size = new Tensor[] { batch_size_tensor }.
Concat(flat_dims.Select(x => tf.constant(x, dtypes.int32))).ToArray();
return array_ops.zeros(init_state_size, dtype: dtype);
};

// TODO(Rinne): map structure with nested tensors.
if(state_size.Shapes.Length > 1)
if(state_size.TotalNestedCount > 1)
{
return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s))));
return new Tensors(state_size.Flatten().Select(s => create_zeros(s)).ToArray());
}
else
{
return create_zeros(state_size);
return create_zeros(state_size.Flatten().First());
}

}

internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype)
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, Tensor batch_size, TF_DataType dtype)
{
if (inputs != null)
if (inputs is not null)
{
batch_size = inputs.shape[0];
batch_size = array_ops.shape(inputs)[0];
dtype = inputs.dtype;
}
return generate_zero_filled_state(batch_size, cell.StateSize, dtype);
@@ -77,17 +77,27 @@ namespace Tensorflow.Keras.Utils
Debug.Assert(initial_state is null && constants is null);
if(num_constants > 0)
{
constants = inputs.TakeLast(num_constants).ToTensors();
inputs = inputs.SkipLast(num_constants).ToTensors();
constants = inputs.TakeLast(num_constants).ToArray().ToTensors();
inputs = inputs.SkipLast(num_constants).ToArray().ToTensors();
}
if(inputs.Length > 1)
{
initial_state = inputs.Skip(1).ToTensors();
inputs = inputs.Take(1).ToTensors();
initial_state = inputs.Skip(1).ToArray().ToTensors();
inputs = inputs.Take(1).ToArray().ToTensors();
}
}

return (inputs, initial_state, constants);
}

/// <summary>
/// Check whether the state_size contains multiple states.
/// </summary>
/// <param name="state_size"></param>
/// <returns></returns>
public static bool is_multiple_state(INestStructure<long> state_size)
{
return state_size.TotalNestedCount > 1;
}
}
}

+ 58
- 36
test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs View File

@@ -21,21 +21,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
[TestMethod]
public void SimpleRNNCell()
{
//var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
//var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
//var x = tf.random.normal((4, 100));
//var (y, h1) = cell.Apply(inputs: x, states: h0);
//var h2 = h1;
//Assert.AreEqual((4, 64), y.shape);
//Assert.AreEqual((4, 64), h2[0].shape);

//var model = keras.Sequential(new List<ILayer>
//{
// keras.layers.InputLayer(input_shape: (4,100)),
// keras.layers.SimpleRNNCell(64)
//});
//model.summary();

var cell = tf.keras.layers.SimpleRNNCell(64, dropout: 0.5f, recurrent_dropout: 0.5f);
var h0 = new Tensors { tf.zeros(new Shape(4, 64)) };
var x = tf.random.normal((4, 100));
@@ -60,24 +45,63 @@ namespace Tensorflow.Keras.UnitTest.Layers
}

[TestMethod]
public void SimpleRNN()
public void LSTMCell()
{
//var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
///*var simple_rnn = keras.layers.SimpleRNN(4);
//var output = simple_rnn.Apply(inputs);
//Assert.AreEqual((32, 4), output.shape);*/
var inputs = tf.ones((2, 100));
var states = new Tensors { tf.zeros((2, 4)), tf.zeros((2, 4)) };
var rnn = tf.keras.layers.LSTMCell(4);
var (output, new_states) = rnn.Apply(inputs, states);
Assert.AreEqual((2, 4), output.shape);
Assert.AreEqual((2, 4), new_states[0].shape);
}

//var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
//var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
//Assert.AreEqual((6, 10, 4), whole_sequence_output.shape);
//Assert.AreEqual((6, 4), final_state.shape);
[TestMethod]
public void TrainLSTMWithMnist()
{
var input = keras.Input((784));
var x = keras.layers.Reshape((28, 28)).Apply(input);
x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
x = keras.layers.LSTM(100).Apply(x);
var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

var inputs = keras.Input(shape: (10, 8));
var x = keras.layers.SimpleRNN(4).Apply(inputs);
var output = keras.layers.Dense(10).Apply(x);
var model = keras.Model(inputs, output);
var model = keras.Model(input, output);
model.summary();
model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = true,
ValidationSize = 55000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
}

[TestMethod]
public void SimpleRNN()
{
var input = keras.Input((784));
var x = keras.layers.Reshape((28, 28)).Apply(input);
x = keras.layers.SimpleRNN(10).Apply(x);
var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

var model = keras.Model(input, output);
model.summary();
model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 2);
}

[TestMethod]
public void RNNForSimpleRNNCell()
{
@@ -100,15 +124,13 @@ namespace Tensorflow.Keras.UnitTest.Layers
}

[TestMethod]
public void WlzTest()
public void RNNForLSTMCell()
{
long[] b = { 1, 2, 3 };
Shape a = new Shape(Unknown).concatenate(b);
Console.WriteLine(a);
var inputs = tf.ones((5, 10, 8));
var rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(4));
var output = rnn.Apply(inputs);
Console.WriteLine($"output: {output}");
Assert.AreEqual((5, 4), output.shape);
}


}
}

+ 2
- 2
test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs View File

@@ -28,8 +28,8 @@ namespace TensorFlowNET.UnitTest.ManagedAPI

var i = tf.constant(2);
var j = tf.constant(3);
Func<Tensor[], Tensor> c = (x) => tf.less(x[0] + x[1], 10);
Func<Tensor[], Tensor[]> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };
Func<Tensors, Tensor> c = (x) => tf.less(x[0] + x[1], 10);
Func<Tensors, Tensors> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };
var r = tf.while_loop(c, b, new[] { i, j });
Assert.AreEqual(5, (int)r[0]);
Assert.AreEqual(6, (int)r[1]);


+ 26
- 6
tools/Tensorflow.CodeGen/FunctionGenerator.cs View File

@@ -21,7 +21,8 @@ namespace Tensorflow.CodeGen
{
sb.Append("Operation ");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.Append("Tensor ");
}
@@ -70,7 +71,8 @@ namespace Tensorflow.CodeGen
{
sb.AppendLine("return null;");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.AppendLine("return _fast_path_result[0];");
}
@@ -81,6 +83,14 @@ namespace Tensorflow.CodeGen

sb.AppendLine("}"); // try

sb.Append("catch(NotOkStatusException ex1)\n{\n");
sb.AppendLine("throw ex1;");
sb.AppendLine("}"); // catch

sb.Append("catch(InvalidArgumentError ex2)\n{\n");
sb.AppendLine("throw ex2;");
sb.AppendLine("}"); // catch

sb.Append("catch(Exception)\n{\n");
sb.AppendLine("}"); // catch

@@ -149,7 +159,8 @@ namespace Tensorflow.CodeGen
{
sb.AppendLine("return _op;");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.AppendLine("return _result[0];");
}
@@ -174,7 +185,7 @@ namespace Tensorflow.CodeGen
{
argName = $"{argName}_";
}
if (!string.IsNullOrEmpty(arg.NumberAttr))
if (!string.IsNullOrEmpty(arg.NumberAttr) || !string.IsNullOrEmpty(arg.TypeListAttr))
{
sb.Append($"Tensors {argName}, ");
}
@@ -273,7 +284,8 @@ namespace Tensorflow.CodeGen
{
sb.Append("Operation ");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.Append("Tensor ");
}
@@ -366,6 +378,13 @@ namespace Tensorflow.CodeGen
sb.Append($"\"{attr.Name}\", {attrRealName}, ");
}
}
else if(attr.Type == "list(type)")
{
if (op.InputArg.Any(x => x.TypeListAttr == attr.Name))
{
continue;
}
}
else if(attr.Type == "int" && op.InputArg.Any(x => x.NumberAttr == attr.Name))
{
bool found = false;
@@ -408,7 +427,8 @@ namespace Tensorflow.CodeGen
{
sb.AppendLine("return null;");
}
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr))
else if (outputArgsCount == 1 && string.IsNullOrEmpty(op.OutputArg[0].NumberAttr)
&& string.IsNullOrEmpty(op.OutputArg[0].TypeListAttr))
{
sb.AppendLine("return _result[0];");
}


+ 1
- 0
tools/Tensorflow.CodeGen/GenOpsWriter.cs View File

@@ -39,6 +39,7 @@ namespace Tensorflow.CodeGen
// Add commonly used namespaces.
sb.AppendLine("using Tensorflow.Eager;");
sb.AppendLine("using Tensorflow.Contexts;");
sb.AppendLine("using Tensorflow.Exceptions;");
sb.AppendLine("using static Tensorflow.Binding;");
sb.AppendLine();



+ 1
- 1
tools/Tensorflow.CodeGen/OpClassifier.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.CodeGen
{
public class OpClassifier
{
private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$";
private static readonly string _filenamePattern = @"^gen_[a-z_]*_ops.py$";
private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):";
private Dictionary<string, HashSet<string>> _opSet = new();
public Dictionary<string, HashSet<string>> OpSet => _opSet;


+ 1
- 1
tools/Tensorflow.CodeGen/Program.cs View File

@@ -5,7 +5,7 @@ using System.Text;
using System.Xml.Linq;
using Tensorflow.CodeGen;

GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops",
GenOpsWriter writer = new(@"D:\development\tf.net\gen_ops_v2",
@"D:\Apps\miniconda3\envs\tf2.11\Lib\site-packages\tensorflow\python\ops",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\api_def\base_api",
@"D:\development\tf.net\tensorflow-2.11.0\tensorflow\core\ops\ops.pbtxt");


+ 1
- 1
tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj View File

@@ -9,7 +9,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
</ItemGroup>

<ItemGroup>


+ 22
- 3
tools/Tensorflow.CodeGen/Utils.cs View File

@@ -155,6 +155,10 @@ namespace Tensorflow.CodeGen
}
else if (attr.Type == "list(type)")
{
if(op.InputArg.Any(x => x.TypeListAttr == attr.Name))
{
continue;
}
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.Type)
{
List<TF_DataType> values = new();
@@ -174,10 +178,25 @@ namespace Tensorflow.CodeGen
else if (attr.Type == "list(shape)")
{
res.Add((attr.Name, "Shape[]", "NOVALUE"));
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<string> exps = new();
foreach (var value in attr.DefaultValue.List.Shape)
{
exps.Add($"new Shape({string.Join(", ", value.Dim.Select(x => x.Size))})");
}
string expression = "new Shape[]{" + $"{string.Join(", ", exps)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "string[]", $"null"));
}
else
{
res.Add((attr.Name, "string[]", "NOVALUE"));
}
}
else if (attr.Type == "list(string)")
{
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<string> values = new();
foreach (var value in attr.DefaultValue.List.S)
@@ -231,11 +250,11 @@ namespace Tensorflow.CodeGen
}
else if (attr.Type == "func")
{
res.Add((attr.Name, "Func<Tensors, Tensors>", "NOVALUE"));
res.Add((attr.Name, "object", "NOVALUE"));
}
else if (attr.Type == "list(func)")
{
res.Add((attr.Name, "Func<Tensors, Tensors>[]", "NOVALUE"));
res.Add((attr.Name, "object[]", "NOVALUE"));
}
else if (attr.Type == "tensor")
{


Loading…
Cancel
Save