Browse Source

Merge pull request #1097 from AsakusaRinne/rnn-dev

feat: add rnn basic modules
tags/v0.110.0-LSTM-Model
Rinne GitHub 2 years ago
parent
commit
9da157f5b8
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
98 changed files with 2809 additions and 293 deletions
  1. +0
    -0
      src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs
  3. +33
    -0
      src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
  4. +33
    -0
      src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs
  5. +0
    -0
      src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs
  6. +130
    -0
      src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
  7. +27
    -0
      src/TensorFlowNET.Core/Common/Types/INest.cs
  8. +11
    -0
      src/TensorFlowNET.Core/Common/Types/INestable.cs
  9. +21
    -0
      src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs
  10. +0
    -0
      src/TensorFlowNET.Core/Common/Types/NamedTuple.cs
  11. +62
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
  12. +458
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.cs
  13. +99
    -0
      src/TensorFlowNET.Core/Common/Types/NestDictionary.cs
  14. +43
    -0
      src/TensorFlowNET.Core/Common/Types/NestList.cs
  15. +32
    -0
      src/TensorFlowNET.Core/Common/Types/NestNode.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs
  17. +6
    -5
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  18. +14
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs
  19. +29
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
  20. +3
    -2
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  21. +19
    -0
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
  22. +12
    -0
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs
  23. +1
    -0
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs
  24. +1
    -0
      src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs
  25. +0
    -5
      src/TensorFlowNET.Core/NumPy/Axis.cs
  26. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
  27. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  28. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  29. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs
  30. +14
    -2
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  31. +98
    -19
      src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
  32. +1
    -1
      src/TensorFlowNET.Core/Operations/logging_ops.cs
  33. +1
    -1
      src/TensorFlowNET.Core/Operations/sort_ops.cs
  34. +5
    -0
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  35. +152
    -89
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  36. +1
    -0
      src/TensorFlowNET.Core/Util/nest.py.cs
  37. +533
    -0
      src/TensorFlowNET.Keras/BackendImpl.cs
  38. +3
    -2
      src/TensorFlowNET.Keras/Engine/Functional.cs
  39. +4
    -3
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  40. +2
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs
  41. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.cs
  42. +2
    -1
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  43. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
  44. +2
    -2
      src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
  45. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs
  46. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs
  47. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
  48. +3
    -2
      src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs
  49. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs
  50. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs
  51. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/Swish.cs
  52. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs
  53. +2
    -1
      src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
  54. +3
    -2
      src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
  55. +2
    -1
      src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
  56. +2
    -1
      src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
  57. +2
    -1
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs
  58. +2
    -1
      src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
  59. +2
    -1
      src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
  60. +2
    -1
      src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
  61. +2
    -1
      src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
  62. +2
    -1
      src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
  63. +2
    -1
      src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs
  64. +2
    -1
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs
  65. +2
    -1
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs
  66. +2
    -1
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs
  67. +2
    -1
      src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs
  68. +2
    -1
      src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs
  69. +2
    -1
      src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs
  70. +2
    -2
      src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs
  71. +2
    -1
      src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs
  72. +2
    -1
      src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs
  73. +3
    -2
      src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs
  74. +3
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs
  75. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs
  76. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs
  77. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
  78. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs
  79. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
  80. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs
  81. +2
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs
  82. +85
    -0
      src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
  83. +3
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
  84. +499
    -72
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  85. +13
    -0
      src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs
  86. +24
    -0
      src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs
  87. +20
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  88. +83
    -16
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  89. +12
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  90. +2
    -1
      src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs
  91. +1
    -1
      src/TensorFlowNET.Keras/Metrics/metrics_utils.cs
  92. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
  93. +1
    -1
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  94. +93
    -0
      src/TensorFlowNET.Keras/Utils/RnnUtils.cs
  95. +2
    -1
      src/TensorflowNET.Hub/KerasLayer.cs
  96. +0
    -11
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
  97. +28
    -0
      test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
  98. +1
    -1
      tools/TensorFlowNET.Console/SimpleRnnTest.cs

src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs → src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs View File


src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs → src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs View File

@@ -3,16 +3,16 @@ using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Extensions
namespace Tensorflow.Common.Extensions
{
public static class JObjectExtensions
{
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
{
var res = obj[key];
if(res is null)
if (res is null)
{
return default(T);
return default;
}
else
{

+ 33
- 0
src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Common.Extensions
{
public static class LinqExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Polyfill of <c>Enumerable.TakeLast</c> for netstandard2.0: returns the last
    /// <paramref name="count"/> elements of <paramref name="sequence"/>.
    /// </summary>
    /// <remarks>
    /// The original implementation called <c>Count()</c> and then <c>Skip</c> on the same
    /// enumerable, enumerating it twice — wrong (or throwing) for one-shot enumerables.
    /// We buffer the sequence once instead. As before, <c>count</c> larger than the
    /// sequence length yields the whole sequence.
    /// </remarks>
    public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
    {
        // Reuse an existing collection when possible; otherwise materialize once.
        var buffer = sequence as IReadOnlyCollection<T> ?? sequence.ToList();
        return buffer.Skip(buffer.Count - count);
    }

    /// <summary>
    /// Polyfill of <c>Enumerable.SkipLast</c> for netstandard2.0: returns all but the last
    /// <paramref name="count"/> elements. <c>count</c> larger than the sequence length
    /// yields an empty sequence (Take of a negative number), matching the original behavior.
    /// </summary>
    public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
    {
        var buffer = sequence as IReadOnlyCollection<T> ?? sequence.ToList();
        return buffer.Take(buffer.Count - count);
    }
#endif
    /// <summary>
    /// Wrap a sequence of tensors into a <see cref="Tensors"/> container.
    /// </summary>
    public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
    {
        return new Tensors(tensors);
    }

    /// <summary>
    /// Deconstruction support for 3-tuples, enabling `var (a, b, c) = tuple;`.
    /// </summary>
    public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
    {
        first = values.Item1;
        second = values.Item2;
        third = values.Item3;
    }
}
}

+ 33
- 0
src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Common.Extensions
{
public static class NestExtensions
{
    /// <summary>
    /// Wrap any nestable collection of tensors into a <see cref="Tensors"/> container.
    /// </summary>
    public static Tensors ToTensors(this INestable<Tensor> tensors) => new Tensors(tensors.AsNest());

    /// <summary>
    /// Convert a tensor nest into a <see cref="Tensors"/> container (may return null).
    /// </summary>
    public static Tensors? ToTensors(this Nest<Tensor> tensors) => Tensors.FromNest(tensors);

    /// <summary>
    /// If the nested object is already a nested type, this function could reduce it.
    /// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
    /// </summary>
    /// <typeparam name="TIn">The nested element type of the input.</typeparam>
    /// <typeparam name="TOut">The leaf element type after reduction.</typeparam>
    /// <param name="input">The doubly-nested structure to flatten one level.</param>
    /// <returns>A single-level <see cref="Nest{T}"/>.</returns>
    public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn : INestStructure<TOut>
        => Nest<TOut>.ReduceFrom(input);
}
}

src/TensorFlowNET.Core/Extensions/OneofExtension.cs → src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs View File


+ 130
- 0
src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs View File

@@ -0,0 +1,130 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

namespace Tensorflow.Common.Types
{
// Represents "a shape or a list of shapes" (e.g. the state_size of an RNN cell, which may
// be a single shape or one shape per state tensor). Dimensions are nullable; null means
// an unknown dimension and is converted to -1 when producing a concrete Shape.
// NOTE(review): this class calls `.Select(...)` and implements non-generic IEnumerable,
// but the visible usings lack `using System.Linq;` and `using System.Collections;` — confirm
// they are present in the committed file.
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
{
// The underlying storage: one TensorShapeConfig per shape.
public TensorShapeConfig[] Shapes { get; set; }
/// <summary>
/// create a single-dim generalized Tensor shape.
/// </summary>
/// <param name="dim"></param>
public GeneralizedTensorShape(int dim)
{
Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
}

/// <summary>
/// Wrap a single concrete shape (implicit Shape -> TensorShapeConfig conversion).
/// </summary>
public GeneralizedTensorShape(Shape shape)
{
Shapes = new TensorShapeConfig[] { shape };
}

/// <summary>
/// Wrap a single shape config.
/// </summary>
public GeneralizedTensorShape(TensorShapeConfig shape)
{
Shapes = new TensorShapeConfig[] { shape };
}

/// <summary>
/// Wrap multiple shape configs (the array is taken by reference, not copied).
/// </summary>
public GeneralizedTensorShape(TensorShapeConfig[] shapes)
{
Shapes = shapes;
}

/// <summary>
/// Wrap multiple concrete shapes.
/// </summary>
public GeneralizedTensorShape(IEnumerable<Shape> shape)
{
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
}

/// <summary>
/// Convert to a single concrete Shape; unknown (null) dims become -1.
/// Throws if this instance holds more than one shape.
/// </summary>
public Shape ToSingleShape()
{
if (Shapes.Length != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var shape_config = Shapes[0];
Debug.Assert(shape_config is not null);
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
}

/// <summary>
/// Convert to a single scalar dimension; unknown (null) becomes -1.
/// Requires exactly one shape with exactly one dimension.
/// </summary>
public long ToNumber()
{
// NOTE(review): the message also fires when Items.Length != 1, where "more than
// 1 dim" is slightly misleading — consider a more precise message.
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
var res = Shapes[0].Items[0];
return res is null ? -1 : res.Value;
}

/// <summary>
/// Convert every stored shape to a concrete Shape; unknown dims become -1.
/// </summary>
public Shape[] ToShapeArray()
{
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
}

/// <summary>
/// Flatten all dimensions of all shapes into one sequence, in order.
/// </summary>
public IEnumerable<long?> Flatten()
{
List<long?> result = new List<long?>();
foreach(var shapeConfig in Shapes)
{
result.AddRange(shapeConfig.Items);
}
return result;
}
/// <summary>
/// Map every dimension through <paramref name="func"/>, producing a nest of results.
/// NOTE(review): unlike AsNest, each shape always becomes a nested list here, even when
/// it has a single dimension — confirm callers do not depend on the exact nesting.
/// </summary>
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
{
List<Nest<TOut>> lists = new();
foreach(var shapeConfig in Shapes)
{
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
}
return new Nest<TOut>(lists);
}

/// <summary>
/// Convert to a Nest, collapsing single-element levels: a single shape with a single
/// dimension becomes a leaf node rather than a one-element list.
/// </summary>
public Nest<long?> AsNest()
{
// Local helper: nest one shape, collapsing 0/1-dimension cases.
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
{
if (config.Items.Length == 0)
{
return Nest<long?>.Empty;
}
else if (config.Items.Length == 1)
{
return new Nest<long?>(config.Items[0]);
}
else
{
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
}
}

if(Shapes.Length == 0)
{
return Nest<long?>.Empty;
}
else if(Shapes.Length == 1)
{
return DealWithSingleShape(Shapes[0]);
}
else
{
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
}
}

/// <summary>
/// Enumerate the raw dimension arrays, one per stored shape.
/// </summary>
public IEnumerator<long?[]> GetEnumerator()
{
foreach (var shape in Shapes)
{
yield return shape.Items;
}
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

+ 27
- 0
src/TensorFlowNET.Core/Common/Types/INest.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface indicates that a class may have a nested structure and provide
/// methods to manipulate with the structure.
/// </summary>
public interface INestStructure<T>: INestable<T>
{
/// <summary>
/// Flatten the Nestable object. Note that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
/// </summary>
/// <returns>All leaf values, in structure order.</returns>
IEnumerable<T> Flatten();
/// <summary>
/// Construct a new object with the same nested structure.
/// </summary>
/// <typeparam name="TOut">The leaf type of the resulting structure.</typeparam>
/// <param name="func">The mapping applied to every leaf value.</param>
/// <returns>A structure of the same shape whose leaves are the mapped values.</returns>
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func);
}
}

+ 11
- 0
src/TensorFlowNET.Core/Common/Types/INestable.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// Implemented by any type that can represent itself as a <see cref="Nest{T}"/>,
/// the canonical nested-structure representation used across the library.
/// </summary>
public interface INestable<T>
{
// Canonical conversion; implementations that already are a Nest may return themselves.
Nest<T> AsNest();
}

+ 21
- 0
src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface is used when some corresponding python methods have optional args.
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
/// as the parameter of the method.
/// </summary>
public interface IOptionalArgs
{
/// <summary>
/// The identifier of the class. It is not an argument but only something to
/// separate different OptionalArgs.
/// </summary>
string Identifier { get; }
}
}

src/TensorFlowNET.Core/Extensions/NamedTuple.cs → src/TensorFlowNET.Core/Common/Types/NamedTuple.cs View File


+ 62
- 0
src/TensorFlowNET.Core/Common/Types/Nest.Static.cs View File

@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// Static convenience entry points over <see cref="Nest{T}"/>: each method simply
/// converts the argument to its canonical nest form and delegates.
/// </summary>
public static class Nest
{
    /// <summary>
    /// Pack the flat items to a nested sequence by the template.
    /// </summary>
    /// <typeparam name="T">The leaf element type.</typeparam>
    /// <param name="template">The structure to mimic.</param>
    /// <param name="flatItems">The leaves to distribute into the structure.</param>
    public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
        => template.AsNest().PackSequence(flatItems);

    /// <summary>
    /// Pack the flat items to a nested sequence by the template.
    /// </summary>
    /// <typeparam name="T">The leaf element type.</typeparam>
    /// <param name="template">The structure to mimic.</param>
    /// <param name="flatItems">The leaves to distribute into the structure.</param>
    public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems)
        => template.AsNest().PackSequence(flatItems.ToArray());

    /// <summary>
    /// Flatten the nested object into its leaf values, depth-first.
    /// </summary>
    public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject)
        => nestedObject.AsNest().Flatten();

    /// <summary>
    /// Map the structure with the specified function, preserving its shape.
    /// </summary>
    public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject)
        => nestedObject.AsNest().MapStructure(func);

    /// <summary>
    /// Whether the object is genuinely nested (contains a list/dictionary inside
    /// a list/dictionary), as opposed to a flat sequence or a single value.
    /// </summary>
    public static bool IsNested<T>(INestable<T> obj) => obj.AsNest().IsNested();
}
}

+ 458
- 0
src/TensorFlowNET.Core/Common/Types/Nest.cs View File

@@ -0,0 +1,458 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Common.Types
{
// The kind of node a Nest<T> instance represents; exactly one of the node's
// Value/ListValue/DictValue members is meaningful for each kind.
public enum NestType
{
// No content at all (the Nest<T>.Empty singleton).
Empty,
// A single leaf value (Value is set).
Node,
// An ordered list of child nests (ListValue is set).
List,
// Named child nests (DictValue is set); order is not guaranteed.
Dictionary
}

/// <summary>
/// A nested structure which may include value, list and dictionary.
/// Note that dictionary does not ensure the data order. When using it as IEnumerable,
/// its order is depth-first.
/// </summary>
/// <typeparam name="T">The type of the leaf values.</typeparam>
public class Nest<T> : INestStructure<T>, IEnumerable<T>
{
    private static readonly Nest<T> _empty = new Nest<T>()
    {
        NestType = NestType.Empty,
    };
    /// <summary>
    /// The singleton empty nest. Code in this class compares against it by reference.
    /// </summary>
    public static Nest<T> Empty => _empty;
    /// <summary>
    /// Selects which of <see cref="Value"/>, <see cref="ListValue"/> and
    /// <see cref="DictValue"/> is meaningful for this node.
    /// </summary>
    public NestType NestType { get; protected set; }
    public string? Name { get; set; }
    public T? Value { get; protected set; }
    public List<Nest<T>>? ListValue { get; protected set; }
    public Dictionary<string, Nest<T>>? DictValue { get; protected set; }

    protected Nest() { }

    /// <summary>Create a leaf node holding a single value.</summary>
    public Nest(T value, string? name = null)
    {
        Value = value;
        Name = name;
        NestType = NestType.Node;
    }

    /// <summary>Create a list node from the given child nests.</summary>
    public Nest(IEnumerable<Nest<T>> values, string? name = null)
    {
        ListValue = values.ToList();
        Name = name;
        NestType = NestType.List;
    }

    /// <summary>Create a dictionary node from named child nests.</summary>
    public Nest(Dictionary<string, Nest<T>> value, string? name = null)
    {
        DictValue = value;
        Name = name;
        NestType = NestType.Dictionary;
    }

    /// <summary>Shallow copy: child collections are shared with <paramref name="other"/>.</summary>
    public Nest(Nest<T> other)
    {
        NestType = other.NestType;
        Value = other.Value;
        DictValue = other.DictValue;
        ListValue = other.ListValue;
        Name = other.Name;
    }

    /// <summary>
    /// Enumerate all leaf values, depth-first (dictionary children in DictValue order).
    /// </summary>
    public virtual IEnumerable<T> Flatten()
    {
        return FlattenInternal(this);
    }

    /// <summary>
    /// Apply <paramref name="func"/> to every leaf, producing a nest of the same shape.
    /// </summary>
    public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
    {
        return MapStructureInternal(func);
    }

    /// <summary>
    /// Pack the flat items to a nested sequence by the template.
    /// </summary>
    /// <param name="flatItems">Leaves to distribute; must match the number of leaves in the template.</param>
    /// <returns>A new nest with this node's structure and the given leaves.</returns>
    public virtual Nest<T> PackSequence(T[] flatItems)
    {
        if (flatItems.Length == 0)
        {
            return Nest<T>.Empty;
        }
        int index = 0;
        return PackSequenceInternal(this, flatItems, ref index);
    }

    private static Nest<T> PackSequenceInternal(Nest<T> template, T[] flatItems, ref int index)
    {
        if (template.NestType == NestType.Node)
        {
            if (index >= flatItems.Length)
            {
                throw new InvalidArgumentError("The template and flat items are not matched.");
            }
            return new Nest<T>(flatItems[index++]);
        }
        else if (template.NestType == NestType.List)
        {
            List<Nest<T>> nestedObjects = new List<Nest<T>>();
            for (int i = 0; i < template.ListValue!.Count; i++)
            {
                nestedObjects.Add(PackSequenceInternal(template.ListValue![i], flatItems, ref index));
            }
            return new Nest<T>(nestedObjects);
        }
        // BUGFIX: this branch previously re-tested `NestType.Node` (already handled above),
        // making it unreachable — dictionary templates always fell through to the
        // exception below. It must test `NestType.Dictionary`.
        else if (template.NestType == NestType.Dictionary)
        {
            Dictionary<string, Nest<T>> dict = new Dictionary<string, Nest<T>>();
            foreach (var (key, value) in template.DictValue!)
            {
                dict[key] = PackSequenceInternal(value, flatItems, ref index);
            }
            return new Nest<T>(dict);
        }
        // Consider Empty as invalid type.
        throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
    }

    public virtual Nest<T> AsNest()
    {
        return this;
    }

    /// <summary>
    /// Combine this nest with another. Empty operands are identity; two lists or two
    /// dictionaries are concatenated; any other combination becomes a two-element list.
    /// </summary>
    public virtual Nest<T> MergeWith(Nest<T>? other)
    {
        if (other is null || other == Nest<T>.Empty)
        {
            return this;
        }
        if (this == Nest<T>.Empty)
        {
            return other;
        }
        if (NestType == NestType.Node && other.NestType == NestType.Node)
        {
            return new Nest<T>(new Nest<T>[] { this, other });
        }
        else if (NestType == NestType.List && other.NestType == NestType.List)
        {
            return new Nest<T>(this.ListValue!.Concat(other.ListValue!));
        }
        else if (NestType == NestType.Dictionary && other.NestType == NestType.Dictionary)
        {
            // NOTE: throws on duplicate keys (ToDictionary semantics).
            return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value));
        }
        else
        {
            return new Nest<T>(new Nest<T>[] { this, other });
        }
    }

    /// <summary>
    /// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not
    /// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested.
    /// </summary>
    /// <returns>True if any direct child is itself a list or dictionary.</returns>
    public bool IsNested()
    {
        if (NestType is NestType.Empty or NestType.Node)
        {
            return false;
        }
        else if (NestType is NestType.List)
        {
            foreach (var item in ListValue!)
            {
                if (item.NestType is NestType.List or NestType.Dictionary)
                {
                    return true;
                }
            }
            return false;
        }
        else
        {
            foreach (var item in DictValue!.Values)
            {
                if (item.NestType is NestType.List or NestType.Dictionary)
                {
                    return true;
                }
            }
            return false;
        }
    }

    // NOTE(review): the index is decremented once per direct child, not once per leaf,
    // so for children that themselves contain multiple leaves the flat index does not
    // match Flatten() order. Behavior kept as-is since the member is obsolete.
    [Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")]
    public T this[int index]
    {
        get
        {
            bool success = FindInternal(this, index, out var result);
            if (success)
            {
                return result;
            }
            else
            {
                throw new IndexOutOfRangeException();
            }
        }
        set
        {
            bool success = SetInternal(this, index, value);
            if (!success)
            {
                throw new IndexOutOfRangeException();
            }
        }
    }

    /// <summary>
    /// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it
    /// to `Nest[T]`.
    /// </summary>
    /// <typeparam name="TOut">The inner nestable type.</typeparam>
    /// <param name="input">The doubly-nested structure.</param>
    /// <returns>A single-level nest.</returns>
    public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut : INestStructure<T>
    {
        var nested = input.AsNest();
        return ReduceInternal(nested);
    }

    private static Nest<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
    {
        if (node.NestType == NestType.Empty)
        {
            return Nest<T>.Empty;
        }
        else if (node.NestType == NestType.Node)
        {
            // A leaf of the outer nest is itself nestable; splice its nest in.
            return node.Value!.AsNest();
        }
        else if (node.NestType == NestType.List)
        {
            return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x)));
        }
        else // Dictionary type
        {
            return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value)));
        }
    }

    private static bool FindInternal(Nest<T> node, int index, out T? result)
    {
        if (node.NestType == NestType.Node)
        {
            if (index == 0)
            {
                result = node.Value!;
                return true;
            }
            result = default(T);
            return false;
        }
        else if (node.NestType == NestType.List)
        {
            foreach (var item in node.ListValue!)
            {
                if (index == 0)
                {
                    return FindInternal(item, index, out result);
                }
                index--;
            }
            result = default(T);
            return false;
        }
        else if (node.NestType == NestType.Dictionary)
        {
            foreach (var item in node.DictValue!.Values)
            {
                if (index == 0)
                {
                    return FindInternal(item, index, out result);
                }
                index--;
            }
            result = default(T);
            return false;
        }
        else
        {
            result = default(T);
            return false;
        }
    }

    private static bool SetInternal(Nest<T> node, int index, T newValue)
    {
        if (node.NestType == NestType.Node)
        {
            if (index == 0)
            {
                node.Value = newValue;
                return true;
            }
            return false;
        }
        else if (node.NestType == NestType.List)
        {
            foreach (var item in node.ListValue!)
            {
                if (index == 0)
                {
                    return SetInternal(item, index, newValue);
                }
                index--;
            }
            return false;
        }
        else if (node.NestType == NestType.Dictionary)
        {
            foreach (var item in node.DictValue!.Values)
            {
                if (index == 0)
                {
                    return SetInternal(item, index, newValue);
                }
                index--;
            }
            return false;
        }
        else
        {
            return false;
        }
    }

    // Depth-first leaf enumeration; Empty nodes yield nothing.
    private static IEnumerable<T> FlattenInternal(Nest<T> node)
    {
        if (node.NestType == NestType.Node)
        {
            yield return node.Value!;
        }
        else if (node.NestType == NestType.List)
        {
            foreach (var item in node.ListValue!)
            {
                foreach (var val in FlattenInternal(item))
                {
                    yield return val;
                }
            }
        }
        else if (node.NestType == NestType.Dictionary)
        {
            foreach (var item in node.DictValue!.Values)
            {
                foreach (var val in FlattenInternal(item))
                {
                    yield return val;
                }
            }
        }
    }

    // Structure-preserving map; Empty maps to Empty.
    private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func)
    {
        if (NestType == NestType.Node)
        {
            return new Nest<TOut>(func(Value!));
        }
        else if (NestType == NestType.List)
        {
            List<Nest<TOut>> outs = new List<Nest<TOut>>();
            foreach (var item in ListValue!)
            {
                outs.Add(item.MapStructureInternal(func));
            }
            return new Nest<TOut>(outs);
        }
        else if (NestType == NestType.Dictionary)
        {
            Dictionary<string, Nest<TOut>> outs = new Dictionary<string, Nest<TOut>>();
            foreach (var (key, value) in DictValue!)
            {
                outs.Add(key, value.MapStructureInternal(func));
            }
            return new Nest<TOut>(outs);
        }
        else
        {
            return Nest<TOut>.Empty;
        }
    }

    public IEnumerator<T> GetEnumerator()
    {
        return Flatten().GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    /// <summary>
    /// Render the nest as e.g. `([1, 2], {a: 3})` with names prefixed as `name: `.
    /// </summary>
    public override string ToString()
    {
        StringBuilder sb = new StringBuilder();
        sb.Append("(");
        WriteString(this, sb);
        sb.Append(")");
        return sb.ToString();
    }

    private static void WriteString(Nest<T> node, StringBuilder sb)
    {
        if (!string.IsNullOrEmpty(node.Name))
        {
            sb.Append($"{node.Name}: ");
        }
        if (node.NestType == NestType.Node)
        {
            sb.Append(node.Value!.ToString());
        }
        else if (node.NestType == NestType.List)
        {
            sb.Append("[");
            for (int i = 0; i < node.ListValue!.Count; i++)
            {
                WriteString(node.ListValue![i], sb);
                if (i != node.ListValue!.Count - 1)
                {
                    sb.Append(", ");
                }
            }
            sb.Append("]");
        }
        else if (node.NestType == NestType.Dictionary)
        {
            sb.Append("{");
            int count = node.DictValue!.Count;
            int i = 0;
            foreach (var (key, value) in node.DictValue!)
            {
                sb.Append($"{key}: ");
                WriteString(value, sb);
                if (i != count - 1)
                {
                    sb.Append(", ");
                }
                i++;
            }
            sb.Append("}");
        }
        else
        {
            sb.Append("<empty>");
        }
    }
}
}

+ 99
- 0
src/TensorFlowNET.Core/Common/Types/NestDictionary.cs View File

@@ -0,0 +1,99 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// A dictionary that participates in the nest-structure protocol while still behaving
/// as a plain <see cref="IDictionary{TKey, TValue}"/> by delegating to a wrapped dictionary.
/// </summary>
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull
{
    /// <summary>The wrapped dictionary holding the actual data.</summary>
    public IDictionary<TKey, TValue> Value { get; set; }

    public NestDictionary(IDictionary<TKey, TValue> dict)
    {
        Value = dict;
    }

    /// <summary>Flatten to the dictionary's values, in its enumeration order.</summary>
    public IEnumerable<TValue> Flatten() => Value.Select(kv => kv.Value);

    // NOTE(review): the mapped result is a NestList, so key information is dropped —
    // confirm no caller relies on a dictionary-shaped result.
    public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func)
        => new NestList<TOut>(Value.Select(kv => func(kv.Value)));

    /// <summary>Represent the values as a flat list nest (keys are not carried over).</summary>
    public Nest<TValue> AsNest()
        => new Nest<TValue>(Value.Values.Select(v => new Nest<TValue>(v)));

    // IDictionary<TKey, TValue> members: straight delegation to the wrapped dictionary.
    public int Count => Value.Count;
    public bool IsReadOnly => Value.IsReadOnly;
    public ICollection<TKey> Keys => Value.Keys;
    public ICollection<TValue> Values => Value.Values;

    public TValue this[TKey key]
    {
        get => Value[key];
        set => Value[key] = value;
    }

    public void Add(TKey key, TValue value) => Value.Add(key, value);
    public void Add(KeyValuePair<TKey, TValue> item) => Value.Add(item);
    public void Clear() => Value.Clear();
    public bool Contains(KeyValuePair<TKey, TValue> item) => Value.Contains(item);
    public bool ContainsKey(TKey key) => Value.ContainsKey(key);
    public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex) => Value.CopyTo(array, arrayIndex);
    public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator() => Value.GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    public bool Remove(TKey key) => Value.Remove(key);
    public bool Remove(KeyValuePair<TKey, TValue> item) => Value.Remove(item);
    public bool TryGetValue(TKey key, out TValue value) => Value.TryGetValue(key, out value);
}
}

+ 43
- 0
src/TensorFlowNET.Core/Common/Types/NestList.cs View File

@@ -0,0 +1,43 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// The implementation of a list that support nest structure, in which the depth is 1.
/// </summary>
/// <typeparam name="T">The element type of the list.</typeparam>
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
{
    /// <summary>The backing list (a copy of the values passed to the constructor).</summary>
    public List<T> Value { get; set; }

    public NestList(IEnumerable<T> values) => Value = new List<T>(values);

    /// <summary>A depth-1 list is already flat: return its elements directly.</summary>
    public IEnumerable<T> Flatten() => Value;

    /// <summary>Map every element, producing another depth-1 list.</summary>
    public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
        => new NestList<TOut>(Value.Select(item => func(item)));

    public Nest<T> AsNest() => new Nest<T>(Value.Select(item => new Nest<T>(item)));

    public IEnumerator<T> GetEnumerator() => Value.GetEnumerator();

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}

+ 32
- 0
src/TensorFlowNET.Core/Common/Types/NestNode.cs View File

@@ -0,0 +1,32 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// A nested structure with only one element.
/// </summary>
/// <typeparam name="T">The type of the single value.</typeparam>
public class NestNode<T> : INestStructure<T>
{
    /// <summary>The single value held by this node.</summary>
    public T Value { get; set; }

    public NestNode(T value) => Value = value;

    /// <summary>Flattening a single node yields exactly one element.</summary>
    public IEnumerable<T> Flatten()
    {
        yield return Value;
    }

    /// <summary>Map the single value, producing another single-value node.</summary>
    public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
        => new NestNode<TOut>(func(Value));

    public Nest<T> AsNest() => new Nest<T>(Value);
}
}

src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs → src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs View File

@@ -3,7 +3,7 @@ using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
namespace Tensorflow.Common.Types
{
public class TensorShapeConfig
{

+ 6
- 5
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -1,17 +1,15 @@
using Newtonsoft.Json;
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
// TODO(Rinne): add regularizers.
public class RNNArgs : AutoSerializeLayerArgs
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnArgCell Cell { get; set; } = null;
public IRnnCell Cell { get; set; } = null;
[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
@@ -34,6 +32,9 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }
public float Dropout { get; set; } = .0f;
public bool ZeroOutputForMask { get; set; } = false;
public float RecurrentDropout { get; set; } = .0f;

// kernel_regularizer=None,
// recurrent_regularizer=None,


+ 14
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
// Optional arguments accepted by RNN layers' Apply beyond the standard three
// (inputs, states, training); dispatched via the IOptionalArgs mechanism.
public class RnnOptionalArgs: IOptionalArgs
{
// Distinguishes these optional args from those of other layer kinds.
public string Identifier => "Rnn";
// Optional mask tensor; null means no masking.
public Tensor Mask { get; set; } = null;
// Optional constant tensors passed to the cell; null means none.
public Tensors Constants { get; set; } = null;
}
}

+ 29
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs View File

@@ -0,0 +1,29 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
// Constructor arguments for SimpleRNNCell; JSON property names mirror the Keras config keys.
public class SimpleRNNCellArgs: AutoSerializeLayerArgs
{
// Dimensionality of the output space.
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
// Fraction of the input units to drop, in [0, 1).
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
// Fraction of the recurrent state units to drop, in [0, 1).
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }
}
}

+ 3
- 2
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Training;
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }
Tensors Apply(Tensors inputs, Tensor state = null, bool training = false);
Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null);
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }


+ 19
- 0
src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Rnn
{
/// <summary>
/// Contract implemented by Keras RNN cells.
/// </summary>
public interface IRnnCell: ILayer
{
/// <summary>The size(s) of the cell's state.</summary>
GeneralizedTensorShape StateSize { get; }
/// <summary>The size(s) of the cell's output.</summary>
GeneralizedTensorShape OutputSize { get; }
// NOTE(review): exact semantics of the TF-RNN-cell distinction are not
// visible here — confirm against implementers before relying on it.
bool IsTFRnnCell { get; }
/// <summary>
/// Whether the optional RNN args are supported when applying the layer.
/// In other words, whether `Apply` is overridden with processing of `RnnOptionalArgs`.
/// </summary>
bool SupportOptionalArgs { get; }
}
}

+ 12
- 0
src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Layers.Rnn
{
/// <summary>
/// An RNN cell composed of a sequence of inner cells, indexable by position.
/// </summary>
public interface IStackedRnnCells : IRnnCell
{
/// <summary>Number of stacked inner cells.</summary>
int Count { get; }
/// <summary>The inner cell at position <paramref name="idx"/>.</summary>
IRnnCell this[int idx] { get; }
}
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs View File

@@ -3,6 +3,7 @@ using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Saving.Json
{


+ 1
- 0
src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs View File

@@ -6,6 +6,7 @@ using System.Text;
using System.Diagnostics;
using OneOf.Types;
using Tensorflow.Keras.Saving.Json;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Saving
{


+ 0
- 5
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -74,8 +74,3 @@ namespace Tensorflow
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})";
}
}

namespace System.Runtime.CompilerServices
{
// Polyfill that enables C# 9 `init`-only setters on targets predating .NET 5;
// the compiler resolves this type purely by name. This changeset replaces it
// with the `IsExternalInit` NuGet package for netstandard2.0 (see the csproj).
internal static class IsExternalInit { }
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs View File

@@ -53,7 +53,7 @@ public class Orthogonal : IInitializer
// Compute the qr factorization
var (q, r) = tf.linalg.qr(a, full_matrices: false);
// Make Q uniform
var d = tf.linalg.tensor_diag_part(r);
var d = tf.linalg.tensor_diag_part(r.Single);
q *= tf.sign(d);

if (num_rows < num_cols)


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -11,6 +11,7 @@ namespace Tensorflow
/// Basic LSTM recurrent network cell.
/// The implementation is based on: http://arxiv.org/abs/1409.2329.
/// </summary>
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class BasicLstmCell : LayerRnnCell
{
int _num_units;


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -20,6 +20,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class BasicRnnCell : LayerRnnCell
{
int _num_units;


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs View File

@@ -19,6 +19,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class LayerRnnCell : RnnCell
{
protected InputSpec inputSpec;


+ 14
- 2
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -16,10 +16,12 @@

using System;
using System.Collections.Generic;
using Tensorflow.Common.Types;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Operations;
@@ -50,7 +52,8 @@ namespace Tensorflow
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`.
/// </summary>
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public abstract class RnnCell : ILayer, IRnnCell
{
/// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight
@@ -142,7 +145,7 @@ namespace Tensorflow
throw new NotImplementedException("_zero_state_tensors");
}

public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null)
{
throw new NotImplementedException();
}
@@ -173,5 +176,14 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

// Stub required by IRnnCell; not supported on this obsolete v1 base class —
// use the Keras RNN cells instead.
public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
{
throw new NotImplementedException();
}
// IRnnCell members are intentionally unimplemented on this obsolete v1 base class.
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 98
- 19
src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -48,6 +49,7 @@ namespace Tensorflow.Operations
public override Tensor flow => _flow;
bool _clear_after_read;
List<Tensor> _tensor_array;
List<int> _previous_read_indices;

public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
@@ -61,16 +63,20 @@ namespace Tensorflow.Operations
_dtype = dtype.as_base_dtype();
_dynamic_size = dynamic_size;
_clear_after_read = clear_after_read;
_tensor_array = new List<Tensor>();
_tensor_array = Enumerable.Repeat<Tensor>(null, size.numpy()).ToList();
_previous_read_indices = new();
}

public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
var tensors = array_ops.unstack(value, name: name);
if(tensors.Length > _tensor_array.Count && !_dynamic_size)
{
var num_elements = array_ops.shape(value)[0];
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
});
throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}");
}
_tensor_array = tensors.ToList();
// TODO(Rinne): revise the implementation. Here we should return `parent()`.
return this;
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
@@ -116,9 +122,19 @@ namespace Tensorflow.Operations
_colocate_with.Add(value);
}

// Return the tensor stored at `ix`, lazily filling an unwritten (null) slot
// with a zero tensor of the array's element shape and dtype.
private Tensor _maybe_zero(int ix)
{
return _tensor_array[ix] ??= array_ops.zeros(_element_shape, _dtype);
}

public override Tensor read<T>(T index, string name = null)
{
int index_int = -1;
int index_int;
if (index is int int_index)
index_int = int_index;
else if (index is Tensor tensor_index)
@@ -126,27 +142,75 @@ namespace Tensorflow.Operations
else
throw new ValueError("");

if(index_int >= _tensor_array.Count)
{
throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} ");
}

var res = _tensor_array[index_int];
if(res is null)
{
if (_previous_read_indices.Contains(index_int))
{
throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " +
$"a previous read (perhaps try setting clear_after_read = false?)");
}
else
{
res = _maybe_zero(index_int);
}
}

if (_clear_after_read)
{
_tensor_array[index_int] = null;
_previous_read_indices.Add(index_int);
}

return _tensor_array[index_int];
return res;
}

public override TensorArray write(Tensor index, Tensor value, string name = null)
{
if (_infer_shape)
_element_shape = _element_shape.merge_with(value.shape);
_tensor_array.add(value);
return this;
int index_int;
if(index is EagerTensor eager)
{
return write<Tensor>(eager.numpy(), value, name);
}
throw new InvalidArgumentError("The index is supposed to be an EagerTensor");
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor, name: name);
int size = _tensor_array.Count;
if(index >= size)
{
if (!_dynamic_size)
{
throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " +
$"is: {size} ");
}
_tensor_array.AddRange(Enumerable.Repeat<Tensor>(null, index - size + 1));
}

Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
if(_dtype != tensor.dtype)
{
throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " +
$"trying to write dtype {tensor.dtype.as_python_name()} ");
}

if (!_element_shape.is_compatible_with(tensor.shape))
{
throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})");
}

if (_infer_shape)
{
_element_shape = _element_shape.merge_with(tensor.shape);
}
_tensor_array[index] = tensor;
return this;
}

private Tensor size(string name = null)
@@ -156,11 +220,26 @@ namespace Tensorflow.Operations

public override Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
if(_tensor_array.Count > 0)
{
return gather(math_ops.range(0, size()), name: name);
});
for(int i = 0; i < _tensor_array.Count; i++)
{
_maybe_zero(i);
}
}
if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined)
{
return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype);
}
else
{
return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype);
}
//ops.colocate_with(_handle);
//return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
//{
// return gather(math_ops.range(0, size()), name: name);
//});
}

public override Tensor gather(Tensor indices, string name = null)


+ 1
- 1
src/TensorFlowNET.Core/Operations/logging_ops.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
name: name);

return tf.Context.ExecuteOp("PrintV2", name, new ExecuteOpArgs(formatted_string)
.SetAttributes(new { output_stream, end }));
.SetAttributes(new { output_stream, end })).SingleOrNull;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/sort_ops.cs View File

@@ -44,7 +44,7 @@ namespace Tensorflow
{
sorted = true
}));
return indices;
return indices.Single;
}

public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null)


+ 5
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -114,4 +114,9 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="IsExternalInit" Version="1.0.3" PrivateAssets="all" />
<PackageReference Include="System.Memory" Version="4.5.4" PrivateAssets="all" />
</ItemGroup>
</Project>

+ 152
- 89
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;

namespace Tensorflow
{
@@ -13,157 +14,231 @@ namespace Tensorflow
/// and Tensor[] from Tensors implicitily.
/// It works for tuple and scalar as well.
/// </summary>
public class Tensors : IEnumerable<Tensor>, IDisposable
public sealed class Tensors : Nest<Tensor>, IDisposable
{
List<Tensor> items = new List<Tensor>();

public TF_DataType dtype => items.First().dtype;
public Shape shape => items.First().shape;
public int rank => items.First().rank;
public Graph graph => items.First().graph;
public TF_DataType dtype => this.First().dtype;
public Shape shape => this.First().shape;
public int rank => this.First().rank;
public Graph graph => this.First().graph;
public bool IsList { get; set; }
public int Length => items.Count();
public int Length => this.Count();
/// <summary>
/// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception.
/// </summary>
/// <summary>
/// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception.
/// </summary>
public Tensor Single
{
get
{
if (Length == 1)
return this.First();
throw new ValueError("Tensors with more than one tensor cannot be " +
"implicitly converted to Tensor.");
}
}

public Tensor this[int index]
/// <summary>
/// Return a Tensor if `Tensors` has only one tensor, and return null when `Tensors` is empty,
/// otherwise throw an exception.
/// </summary>
public Tensor? SingleOrNull
{
get => items[index];
set => items[index] = value;
get
{
if (Length > 1)
{
throw new ValueError($"Tensors with {Length} tensor cannot be " +
"implicitly converted to Tensor.");
}
return this.FirstOrDefault();
}
}

public Tensor this[params string[] slices]
=> items.First()[slices];
public Tensors(params Tensor[] tensors)
=> this.First()[slices];

public Tensors(Tensor tensor) : base(tensor)
{

}

private Tensors(Nest<Tensor> nested) : base(nested)
{

}

public Tensors(params Tensor[] tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
{
}

public Tensors(IEnumerable<Tensor> tensors): base(tensors.Select(x => new Nest<Tensor>(x)))
{
items.AddRange(tensors);
}

public Tensors(IEnumerable<Tensor> tensors)
public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
{
items.AddRange(tensors);
}

public Tensors(NDArray nd)
public bool IsSingle()
{
items.Add(ops.convert_to_tensor(nd));
return Length == 1;
}

public IEnumerator<Tensor> GetEnumerator()
public new Tensors MergeWith(Nest<Tensor>? other)
{
foreach (var tensor in items)
yield return tensor;
return FromNest(base.MergeWith(other));
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " +
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")]
public void Add(Tensor tensor)
=> items.Add(tensor);
{
if(NestType == NestType.Dictionary)
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value), new Nest<Tensor>(tensor) };
Value = null;
}
else
{
ListValue.Add(new Nest<Tensor>(tensor));
}
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " +
"some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")]
public void AddRange(IEnumerable<Tensor> tensors)
=> items.AddRange(tensors);
{
if (NestType == NestType.Dictionary)
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
else if (NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
Value = null;
}
else
{
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
}
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " +
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")]
public void Insert(int index, Tensor tensor)
=> items.Insert(index, tensor);

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
{
if (NestType == NestType.List)
{
ListValue.Insert(index, new Nest<Tensor>(tensor));
}
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(Value) };
ListValue.Insert(index, new Nest<Tensor>(tensor));
Value = null;
}
else
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
}

public string[] StringData()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData();
return Single.StringData();
}

public string StringData(int index)
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData(index);
return Single.StringData(index);
}

public NDArray numpy()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].numpy();
return Single.numpy();
}

[Obsolete]
public T[] ToArray<T>() where T: unmanaged
{
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
return this[0].ToArray<T>();
return Single.ToArray<T>();
}

#region Explicit Conversions
public unsafe static explicit operator bool(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to bool");
return (bool)tensor[0];
return (bool)tensor.Single;
}

public unsafe static explicit operator sbyte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to sbyte");
return (sbyte)tensor[0];
return (sbyte)tensor.Single;
}

public unsafe static explicit operator byte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
return (byte)tensor.Single;
}

public unsafe static explicit operator ushort(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ushort");
return (ushort)tensor[0];
return (ushort)tensor.Single;
}

public unsafe static explicit operator short(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to short");
return (short)tensor[0];
return (short)tensor.Single;
}

public unsafe static explicit operator int(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to int");
return (int)tensor[0];
return (int)tensor.Single;
}

public unsafe static explicit operator uint(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to uint");
return (uint)tensor[0];
return (uint)tensor.Single;
}

public unsafe static explicit operator long(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to long");
return (long)tensor[0];
return (long)tensor.Single;
}

public unsafe static explicit operator ulong(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ulong");
return (ulong)tensor[0];
return (ulong)tensor.Single;
}

/// <summary>
/// Explicit conversion of a single-element <c>Tensors</c> to <c>float</c>.
/// </summary>
public unsafe static explicit operator float(Tensors tensor)
{
// BUG fix: the previous implementation cast through `byte` (a copy-paste
// from the byte operator), truncating the value; convert to float instead,
// consistent with the double operator below.
return (float)tensor.Single;
}

public unsafe static explicit operator double(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to double");
return (double)tensor[0];
return (double)tensor.Single;
}

public unsafe static explicit operator string(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to string");
return (string)tensor[0];
return (string)tensor.Single;
}

public static explicit operator object[](Tensors tensors)
=> tensors.items.ToArray();
=> tensors.Flatten().ToArray();
#endregion

#region Implicit Conversions
@@ -183,56 +258,44 @@ namespace Tensorflow
public static implicit operator Tensors(List<Tensor> tensors)
=> new Tensors(tensors.ToArray());

public static implicit operator Tensor(Tensors tensors)
=> tensors.FirstOrDefault();
public static implicit operator Tensor(Tensors? tensors)
=> tensors?.SingleOrNull;

public static implicit operator Tensor[](Tensors tensors)
=> tensors.items.ToArray();

=> tensors.Flatten().ToArray();
#endregion

public void Deconstruct(out Tensor a, out Tensor b)
public static Tensors? FromNest(Nest<Tensor> nested)
{
a = items[0];
b = items[1];
if(nested == Nest<Tensor>.Empty)
{
return null;
}
return new Tensors(nested);
}

private static void EnsureSingleTensor(Tensors tensors, string methodnName)
public void Deconstruct(out Tensor a, out Tensors? b)
{
if(tensors.Length == 0)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
}
else if(tensors.Length > 1)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
}
a = this.First();
b = Length == 1? null : new Tensors(this.Skip(1));
}

public override string ToString()
{
if(items.Count == 1)
if(Length == 1)
{
return items[0].ToString();
return this.First().ToString();
}
else
{
StringBuilder sb = new StringBuilder();
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n");
for(int i = 0; i < items.Count; i++)
{
var tensor = items[i];
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n");
}
sb.Append("]\n");
return sb.ToString();
return $"Totally {Length} tensors: {base.ToString()}";
}
}

public void Dispose()
{
foreach (var item in items)
item.Dispose();
foreach (var tensor in this)
tensor.Dispose();
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -36,6 +36,7 @@ namespace Tensorflow.Util
// (np.array([3, 4]), tf.constant([3, 4])))`
//

[Obsolete]
public static class nest
{



+ 533
- 0
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -20,8 +20,11 @@ using System.Linq;
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;
using Tensorflow.Util;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras
{
@@ -450,5 +453,535 @@ namespace Tensorflow.Keras

return x;
}

public (Tensors, Tensors, Tensors) rnn(
Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states
Tensors inputs, // inputs is a tuple of tensors (one per input sequence)
Tensors initial_states,
bool go_backwards = false,
Tensor? mask = null,
Tensors? constants = null,
bool unroll = false,
Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not
bool time_major = false,
bool zero_output_for_mask = false,
bool return_all_outputs = true)
{

Tensor swap_batch_timestep(Tensor input_t)
{
var axes = Enumerable.Range(0, input_t.rank).ToArray();
axes[0] = 1;
axes[1] = 0;
return tf.transpose(input_t, axes);
}

if (!time_major)
{
inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors();
}

var flatted_inptus = Nest.Flatten(inputs).ToList();
var first_flatted_input = flatted_inptus[0];
var time_steps = first_flatted_input.shape[0];
var batch = first_flatted_input.shape[1];
var time_steps_t = (int)first_flatted_input.shape[0];

foreach (var input_ in flatted_inptus)
{
input_.shape.with_rank_at_least(3);
}

if (mask != null)
{
if (mask.dtype != TF_DataType.TF_BOOL)
{
mask = tf.cast(mask, TF_DataType.TF_BOOL);
}

if (mask.rank == 2)
{
mask = tf.expand_dims(mask, -1);
}

if (!time_major)
{
mask = swap_batch_timestep(mask);
}

}

// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
// So we need to broadcast the mask to match the shape of inputs.
// That's what the tile call does, it just repeats the mask along its
// second dimension n times.

Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
{
if (!mask_t.IsSingle())
{
throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}");
}

if (!input_t.IsSingle())
{
throw new ValueError($"input_t is expected to be tensor, but got {input_t}");
}

var rank_diff = input_t.rank - mask_t.rank;
for (int i = 0; i < rank_diff; i++)
{
mask_t = tf.expand_dims(mask_t, -1);
}
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank));
return tf.tile(mask_t, multiples);
}

Tensors outputs = new Tensors();
Tensors output_time_zero = new Tensors();
Tensors last_output = new Tensors();
Tensors new_states = new Tensors();
if (unroll)
{
if (time_steps == 0)
{
throw new ValueError("Unrolling requires a fixed number of timesteps.");
}

// Process the input tensors. The input tensor need to be split on the
// time_step dim, and reverse if go_backwards is True. In the case of
// nested input, the input is flattened and then transformed
// individually. The result of this will be a tuple of lists, each of
// the item in tuple is list of the tensor with shape (batch, feature)


// TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple
//var states = Tuple.Create(initial_states);
var states = initial_states;

var successive_states = new Tensors();
var successive_outputs = new Tensors();

// Process the input tensors. The input tensor need to be split on the
// time_step dim, and reverse if go_backwards is True. In the case of
// nested input, the input is flattened and then transformed
// individually. The result of this will be a tuple of lists, each of
// the item in tuple is list of the tensor with shape (batch, feature)




Tensors _process_single_input_t(Tensor input_t)
{
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
if (go_backwards)
{
unstaked_input_t = unstaked_input_t.Reverse().ToArray();
}
return unstaked_input_t;
}

// TODO(Wanglongzhi2001)
Tensors processed_input;
if (!inputs.IsSingle())
{
processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo<Tensors, Tensor>().ToTensors();
}
else
{
processed_input = _process_single_input_t(inputs);
}

object _get_input_tensor(int time)
{
List<Tensor> inp = new List<Tensor>();
foreach (var t_ in processed_input)
{
inp.Add(t_[time]);
}
return Nest.PackSequenceAs(inputs, inp);
}

if (mask != null)
{
var mask_list = tf.unstack(mask);
if (go_backwards)
{
mask_list.Reverse();
}

for (int i = 0; i < time_steps; i++)
{
// TODO(Wanglongzhi2001),deal with _get_input_tensor
var inp = _get_input_tensor(i);
var mask_t = mask_list[i];
// TODO
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));

var tiled_mask_t = _expand_mask(mask_t, output);

Tensors prev_output;
if (successive_outputs == null)
{
prev_output = tf.zeros_like(output);
}
else
{
prev_output = successive_outputs[successive_outputs.Length - 1];
}

output = tf.where(tiled_mask_t, output, prev_output);

var flat_states = Nest.Flatten(states).ToList();
var flat_new_states = Nest.Flatten(newStates).ToList();

var tiledMaskT = flat_states
.Select(s => _expand_mask(mask_t, s))
.ToArray();
var tuple = Tuple.Create(tiledMaskT);

List<Tensor> flat_final_states = new List<Tensor>();
foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states))
{
flat_final_states.Add(tf.where(m, s, ps));
}

states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(states);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { states };
}

}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
else // mask is null
{
for (int i = 0; i < time_steps; i++)
{
var inp = _get_input_tensor(i);
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));
states = newStates;

if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(newStates);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { newStates };
}
}
last_output = successive_outputs[successive_outputs.Length - 1];
new_states = successive_states[successive_states.Length - 1];
outputs = tf.stack(successive_outputs);
}
}
}
else // unroll == false
{
var states = initial_states;
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.
var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_steps_t));
}

foreach(var (ta, input_) in zip(input_ta, flatted_inptus))
{
if (!go_backwards)
{
ta.unstack(input_);
}
else
{
ta.unstack(reverse(input_, 0));
}
}

// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
var inps = new Tensors();
foreach (var inp in flatted_inptus)
{
inps.Add(inp[0]);
}
var input_time_zero = Nest.PackSequenceAs(inputs, inps).ToTensors();

// output_time_zero is used to determine the cell output shape and its
// dtype. the value is discarded.
(output_time_zero, _) = step_function((Tensor)input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

int output_ta_size = return_all_outputs ? time_steps_t : 1;
var output_ta = new List<TensorArray>();
for (int i = 0; i < output_time_zero.ToList().Count; i++)
{
var Out = output_time_zero.ToList()[i];
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape));
}

var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");



Func<Tensor, Tensor>? masking_fn;
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
if (mask != null)
{
if (go_backwards)
{
mask = tf.reverse(mask, axis: new[] { 0 });
}
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
mask_ta = mask_ta.unstack(mask);

masking_fn = (time) =>
{
return mask_ta.read(time);
};

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var tiled_mask_t = new Tensors();
foreach (var o in flat_out)
{
tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
}

Tensors res = new Tensors();
foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList()))
{
res.Add(tf.where(m, o, fm));
}
return res;
};
}
// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
else if (input_length is Tensor)
{
if (go_backwards)
{
var max_len = tf.reduce_max(input_length, axis: 0);
var rev_input_length = tf.subtract(max_len - 1, input_length);

masking_fn = (time) =>
{
return tf.less(rev_input_length, time);
};
}
else
{
masking_fn = (time) =>
{
return tf.greater(input_length, time);
};
}

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var res = new List<Tensor>();
foreach (var (o, zo) in zip(flat_out, flat_mask))
{
res.Add(tf.where(mask_t, o, zo));
}
return res;
};
}
else
{
masking_fn = null;
}

Func<Tensor, Tensor> cond = (time) => (time < time_steps_t);
int parallel_iterations = 32;
if (masking_fn != null)
{
// Mask for the T output will be base on the output of T - 1. In the
// case T = 0, a zero filled tensor will be used.
var flat_zero_output = new Tensors();
foreach (var o in Nest.Flatten(output_time_zero))
{
flat_zero_output.Add(tf.zeros_like(o));
}

var prev_output = flat_zero_output;
var output_ta_t = output_ta;
Tensor _step(Tensor time)
{
    /*
    RNN step function for the masking branch.
    Args:
        time: Current timestep value (scalar int tensor).
    Captured from the enclosing scope:
        output_ta_t: list of TensorArrays collecting per-step outputs.
        prev_output: flat outputs from time - 1 (zero-filled at time 0),
                     used as the fallback when the mask is off.
        states: the recurrent states carried between steps.
    Returns:
        time + 1; outputs and states are propagated via the captured locals
        (output_ta / new_states), mirroring Keras backend.rnn.
    */

    var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
    // maybe set shape
    // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
    var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
    var mask_t = masking_fn(time);
    var (output, new_states_internal) = step_function(current_input, states.MergeWith(constants));
    // mask output: where the mask is off, fall through to either zeros or the
    // previous step's (already masked) output.
    var flat_output = Nest.Flatten(output).ToList();

    var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList();

    // TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
    var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

    // BUGFIX: advance prev_output so that at time t+1 the fallback is the masked
    // output of time t, not forever the zero tensor (matches Keras backend.rnn).
    prev_output = flat_new_output;

    // mask states: keep the old state wherever the mask is off.
    var flat_state = states.ToList();
    var flat_new_state = new_states_internal.ToList();

    foreach (var (state, new_state) in zip(flat_state, flat_new_state))
    {
        if (new_state is Tensor)
        {
            // Static shape info can be lost through the step function; restore it.
            new_state.shape = state.shape;
        }
    }

    var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
    // BUGFIX: pack the *masked* final states. Previously a second PackSequenceAs
    // over the unmasked flat_new_state immediately overwrote this result,
    // silently discarding the state masking computed just above.
    new_states_internal = Nest.PackSequenceAs(new_states, flat_final_state).ToTensors();

    var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
    // BUGFIX: rebuild the TensorArray list instead of calling Add() on
    // output_ta_t while zip() is still enumerating it — List<T> throws
    // InvalidOperationException when mutated during enumeration, and Add()
    // would grow the list instead of replacing the written arrays.
    // This now mirrors the non-masking branch of this method.
    output_ta_t = zip(output_ta_t, flat_new_output).Select(item =>
    {
        var (ta, Out) = item;
        return ta.write(ta_index_to_write, Out);
    }).ToList();

    output_ta = output_ta_t;
    new_states = new_states_internal;
    return time + 1;

}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
}
else
{
var output_ta_t = output_ta;
new_states = states;
Tensor _step(Tensor time)
{
    // RNN step function for the non-masking branch: read the inputs for this
    // timestep, run the user step function, restore any lost static shape info
    // on the states, and record the outputs into the TensorArrays.
    var step_inputs_flat = input_ta.Select(ta => ta.read(time)).ToList();
    // maybe set shape
    // TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
    var packed_input = Nest.PackSequenceAs(inputs, step_inputs_flat).ToTensors();
    var (step_output, states_out) = step_function(packed_input, new_states.MergeWith(constants));

    var old_states_flat = new_states.Flatten().ToList();
    var updated_states_flat = states_out.Flatten().ToList();
    for (int i = 0; i < old_states_flat.Count && i < updated_states_flat.Count; i++)
    {
        // The step function may drop static shape information; copy it back.
        if (updated_states_flat[i] is Tensor)
        {
            updated_states_flat[i].shape = old_states_flat[i].shape;
        }
    }

    var outputs_flat = Nest.Flatten(step_output);
    // When only the last output is requested, keep overwriting slot 0.
    var write_index = return_all_outputs ? time : tf.constant(0);
    output_ta_t = zip(output_ta_t, outputs_flat)
        .Select(pair => pair.Item1.write(write_index, pair.Item2))
        .ToList();

    states_out = Nest.PackSequenceAs(initial_states, updated_states_flat).ToTensors();
    output_ta = output_ta_t;
    new_states = states_out;
    return time + 1;
}
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: time, parallel_iterations: parallel_iterations);
}
//Tensors outputs = new Tensors();
foreach (var o in output_ta)
{
outputs.Add(o.stack());
}
foreach (var o in outputs)
{
last_output.Add(o[-1]);
}
outputs = Nest.PackSequenceAs(output_time_zero, outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, last_output).ToTensors();

}

Func<Tensor, Tensor> set_shape;
set_shape = (output_) =>
{
if (output_ is Tensor)
{
var shape = output_.shape.as_int_list();
if (return_all_outputs)
{
shape[0] = (int)time_steps;
}
else
{
shape[0] = 1;
}
shape[1] = (int)batch;
output_.shape = shape;
}
return output_;
};

outputs = Nest.MapStructure(set_shape, outputs).ToTensors();
if (!time_major)
{
outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors();
}
return (last_output, outputs, new_states);

}

/// <summary>
/// Reverses a tensor along a single axis; convenience overload that
/// delegates to the multi-axis <see cref="reverse(Tensor, int[])"/>.
/// </summary>
/// <param name="input">The tensor to reverse.</param>
/// <param name="axis">The axis to reverse along.</param>
/// <returns>The reversed tensor.</returns>
public Tensor reverse(Tensor input, int axis)
    => reverse(input, new[] { axis });

/// <summary>
/// Reverses a tensor along the given axes via <c>tf.reverse</c>.
/// </summary>
/// <param name="input">The tensor to reverse.</param>
/// <param name="axes">The axes to reverse along.</param>
/// <returns>The reversed tensor.</returns>
public Tensor reverse(Tensor input, int[] axes)
    => tf.reverse(input, axes);

/// <summary>
/// Converts <paramref name="output"/> to a ragged tensor when ragged output
/// was requested. Ragged conversion itself is not yet implemented; when
/// <paramref name="is_ragged_output"/> is false the input passes through unchanged.
/// </summary>
/// <param name="is_ragged_output">Whether a ragged tensor should be produced.</param>
/// <param name="output">The dense output tensor.</param>
/// <param name="nested_row_lengths">Row lengths describing the ragged structure.</param>
/// <param name="go_backwards">Whether the sequence was processed in reverse (currently unused).</param>
/// <returns>The (unconverted) output tensor.</returns>
/// <exception cref="NotImplementedException">Thrown when ragged output is requested.</exception>
public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false)
{
    if (is_ragged_output)
    {
        throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
    }

    return output;
}
}
}

+ 3
- 2
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
@@ -81,7 +82,7 @@ namespace Tensorflow.Keras.Engine
}
else
{
_buildInputShape = new Saving.TensorShapeConfig();
_buildInputShape = new TensorShapeConfig();
}

if (outputs.Any(x => x.KerasHistory == null))
@@ -325,7 +326,7 @@ namespace Tensorflow.Keras.Engine
nodes_in_decreasing_depth.append(node);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var tensor_dict = new Dictionary<long, Queue<Tensor>>();
// map input values


+ 4
- 3
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -1,4 +1,5 @@
using System.Threading;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
@@ -8,11 +9,11 @@ namespace Tensorflow.Keras.Engine
/// <summary>
/// Wraps `call`, applying pre- and post-processing steps.
/// </summary>
/// <param name="input"></param>
/// <param name="inputs"></param>
/// <param name="state"></param>
/// <param name="training"></param>
/// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null)
{
if (callContext.Value == null)
callContext.Value = new CallContext();
@@ -30,7 +31,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

var outputs = Call(inputs, state: state, training: training);
var outputs = Call(inputs, state: states, training: training);

// memory leak
// _set_connectivity_metadata_(inputs, outputs);


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -32,7 +32,7 @@ using Tensorflow.Util;
using static Tensorflow.Binding;
using Tensorflow.Framework;
using Tensorflow.Sessions;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Engine
{
@@ -332,7 +332,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param>
/// <param name="training"></param>
/// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected virtual Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if(ReplacedCall is not null)
{


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -1,8 +1,8 @@
using System.Diagnostics;
using Tensorflow.Common.Types;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;


+ 2
- 1
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Engine
{
@@ -143,7 +144,7 @@ namespace Tensorflow.Keras.Engine
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (!_has_explicit_input_shape)
{


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/ELU.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -29,7 +30,7 @@ namespace Tensorflow.Keras.Layers {
base.build(input_shape);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor output = inputs;
output = tf.where(output > 0f, output,


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs View File

@@ -4,7 +4,7 @@ using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers {
public class Exponential : Layer
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers {
{
base.build(input_shape);
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor output = inputs;
return tf.exp(output);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers {
@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
public HardSigmoid ( LayerArgs args ) : base(args) {
// hard sigmoid has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null ) {
Tensor x = inputs;
return tf.clip_by_value(
tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return tf.nn.leaky_relu(inputs, alpha: alpha);
}


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/SELU.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Layers {
}
base.build(input_shape);
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) {
Tensor output = inputs;
return tf.where(output > 0f,
tf.multiply(scale, output),


+ 3
- 2
src/TensorFlowNET.Keras/Layers/Activation/Softmax.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
@@ -11,8 +12,8 @@ namespace Tensorflow.Keras.Layers {
public Softmax ( SoftmaxArgs args ) : base(args) {
axis = args.axis;
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9)
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) {
Tensor x = inputs.Length == 2 ? inputs[0] + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9)
: inputs;
Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true)));
Tensor s = tf.reduce_sum(e, axis: this.axis, keepdims: true);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/Softplus.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
public Softplus ( LayerArgs args ) : base(args) {
// Softplus has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) {
Tensor x = inputs;
return tf.log(
tf.add(tf.exp(x), 1f));


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/Softsign.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
public Softsign ( LayerArgs args ) : base(args) {
// Softsign has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) {
Tensor x = inputs;
// x / (abs(x) + 1)
return tf.div(x, tf.add(1f, tf.abs(x)));


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/Swish.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
public Swish ( LayerArgs args ) : base(args) {
// Swish has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) {
Tensor x = inputs;

// x / (1 + exp(-x))


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/Tanh.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -13,7 +14,7 @@ namespace Tensorflow.Keras.Layers
{
// Tanh has no arguments
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor x = inputs;



+ 2
- 1
src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs View File

@@ -6,6 +6,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

/// <summary>
/// Base class for attention layers that can be used in sequence DNN/CNN models.
@@ -114,7 +115,7 @@ namespace Tensorflow.Keras.Layers
return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensors _inp;
Tensors _mask = null;


+ 3
- 2
src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs View File

@@ -6,6 +6,7 @@ using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using System;
using System.Linq;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -252,7 +253,7 @@ namespace Tensorflow.Keras.Layers
return (attention_output, attention_scores);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensors _inp;
Tensor _mask = null;
@@ -349,7 +350,7 @@ namespace Tensorflow.Keras.Layers
//}

if (return_attention_scores)
return (attention_output, attention_scores);
return (attention_output, attention_scores.Single);
return attention_output;
}
}

+ 2
- 1
src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs View File

@@ -20,6 +20,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -83,7 +84,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var inputs_shape = array_ops.shape(inputs);
var batch_size = inputs_shape[0];


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -103,7 +104,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = false, IOptionalArgs? optional_args = null)
{
var outputs = _convolution_op.Apply(inputs, kernel.AsTensor());
if (use_bias)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -18,6 +18,7 @@ using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -69,7 +70,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor outputs = null;
var rank = inputs.rank;


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs View File

@@ -7,6 +7,7 @@ using System.Text.RegularExpressions;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -189,7 +190,7 @@ namespace Tensorflow.Keras.Layers
// return new dict(base_config.items().ToList() + config.items().ToList());
//}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor()));
if (this.bias != null)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Core/Embedding.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -66,7 +67,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Merging/Merge.cs View File

@@ -5,6 +5,7 @@ using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Layers
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return _merge_function(inputs);
}


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers
return false;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor outputs = null;
var training_tensor = training == null


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -101,7 +102,7 @@ namespace Tensorflow.Keras.Layers
return input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor outputs = null;
var inputs_dtype = inputs.dtype.as_base_dtype();


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Normalization/Normalization.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;

@@ -157,7 +158,7 @@ namespace Tensorflow.Keras.Layers
base.adapt(data, batch_size: batch_size, steps: steps);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (_args.Invert)
{


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling1D.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (data_format == "channels_last")
return math_ops.reduce_mean(inputs, 1, false);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (data_format == "channels_last")
return math_ops.reduce_mean(inputs, (1, 2), false);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling1D.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (data_format == "channels_last")
return math_ops.reduce_max(inputs, 1, false);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Pooling/GlobalMaxPooling2D.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -12,7 +13,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (data_format == "channels_last")
return math_ops.reduce_max(inputs, (1, 2), false);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Pooling/Pooling1D.cs View File

@@ -18,6 +18,7 @@ using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 3);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
int pad_axis = args.DataFormat == "channels_first" ? 2 : 3;
inputs = tf.expand_dims(inputs, pad_axis);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs View File

@@ -17,6 +17,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -36,7 +37,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
int[] pool_shape;
int[] strides;


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Preprocessing/CategoryEncoding.cs View File

@@ -1,6 +1,6 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
namespace Tensorflow.Keras.Layers
{
/// <summary>
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var depth = args.NumTokens;
var max_value = tf.reduce_max(inputs);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs View File

@@ -1,5 +1,6 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -17,7 +18,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
scale = constant_op.constant(args.Scale, args.DType);
offset = constant_op.constant(args.Offset, args.DType);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs View File

@@ -4,6 +4,7 @@ using System;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation);
}


+ 3
- 2
src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
@@ -15,7 +16,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (training == null)
training = false;


+ 3
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs View File

@@ -1,6 +1,8 @@
using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Reshaping
{
@@ -27,7 +29,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor output = inputs;
if (output.rank != 3)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs View File

@@ -1,6 +1,7 @@
using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Reshaping
{
@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
built = true;
_buildInputShape = input_shape;
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor output = inputs;
if (output.rank != 4)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs View File

@@ -1,6 +1,7 @@
using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Reshaping
{
@@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Layers.Reshaping
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor output = inputs;
if (output.rank != 5)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Framework;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
@@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Layers
_channels_first = args.DataFormat == "channels_first";
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (_channels_first)
{


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs View File

@@ -6,6 +6,7 @@ using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers {
public class Permute : Layer
@@ -28,7 +29,7 @@ namespace Tensorflow.Keras.Layers {
built = true;
_buildInputShape = input_shape;
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor outputs = inputs;
return tf.transpose(outputs, new Axis(permute));


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs View File

@@ -4,6 +4,7 @@ using static Tensorflow.Binding;
using System.Collections.Generic;
using System;
using System.Linq;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var shapes = new List<Tensor>();
shapes.Add(array_ops.shape(inputs)[0]);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs View File

@@ -6,6 +6,7 @@ using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -24,7 +25,7 @@ namespace Tensorflow.Keras.Layers
inputSpec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return keras.backend.resize_images(inputs,
size[0], size[1],


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs View File

@@ -2,6 +2,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
@@ -26,7 +27,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return keras.backend.spatial_2d_padding(inputs,
padding: padding,


+ 85
- 0
src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs View File

@@ -0,0 +1,85 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers.Rnn
{
/// <summary>
/// Mixin-style base for RNN cells that apply dropout to their inputs and to
/// their recurrent state. Mirrors Keras' `DropoutRNNCellMixin`.
/// </summary>
public abstract class DropoutRNNCellMixin: RnnCellBase
{
    // Dropout rate applied to the cell inputs, in [0, 1].
    public float dropout;
    // Dropout rate applied to the recurrent state, in [0, 1].
    public float recurrent_dropout;
    // TODO(Rinne): deal with cache.
    public DropoutRNNCellMixin(LayerArgs args): base(args)
    {

    }

    /// <summary>
    /// Returns the input dropout mask(s) for the cell, or null when input
    /// dropout is disabled (rate == 0).
    /// </summary>
    /// <param name="input">Template tensor whose shape the mask copies.</param>
    /// <param name="training">Whether dropout should actually be applied.</param>
    /// <param name="count">Number of independent masks to generate.</param>
    public Tensors? get_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
    {
        if (dropout == 0f)
            return null;
        return _generate_dropout_mask(
            tf.ones_like(input),
            dropout,
            training,
            count);
    }

    /// <summary>
    /// Returns the recurrent dropout mask(s) for the cell, or null when
    /// recurrent dropout is disabled.
    /// </summary>
    public Tensors? get_recurrent_dropout_maskcell_for_cell(Tensors input, bool training, int count = 1)
    {
        // BUGFIX: this guard previously tested `dropout`; the recurrent mask
        // must be gated by `recurrent_dropout` (as in Keras' implementation),
        // otherwise recurrent dropout is silently skipped/applied incorrectly.
        if (recurrent_dropout == 0f)
            return null;
        return _generate_dropout_mask(
            tf.ones_like(input),
            recurrent_dropout,
            training,
            count);
    }

    // Builds a fresh input-dropout mask regardless of caching.
    public Tensors _create_dropout_mask(Tensors input, bool training, int count = 1)
    {
        return _generate_dropout_mask(
            tf.ones_like(input),
            dropout,
            training,
            count);
    }

    // Builds a fresh recurrent-dropout mask regardless of caching.
    public Tensors _create_recurrent_dropout_mask(Tensors input, bool training, int count = 1)
    {
        return _generate_dropout_mask(
            tf.ones_like(input),
            recurrent_dropout,
            training,
            count);
    }

    /// <summary>
    /// Generates `count` dropout masks by running a Dropout layer over a
    /// ones-tensor of the desired shape.
    /// </summary>
    public Tensors _generate_dropout_mask(Tensor ones, float rate, bool training, int count = 1)
    {
        Tensors dropped_inputs()
        {
            DropoutArgs args = new DropoutArgs();
            args.Rate = rate;
            var DropoutLayer = new Dropout(args);
            var mask = DropoutLayer.Apply(ones, training: training);
            return mask;
        }

        if (count > 1)
        {
            Tensors results = new Tensors();
            for (int i = 0; i < count; i++)
            {
                results.Add(dropped_inputs());
            }
            return results;
        }

        return dropped_inputs();
    }
}
}

+ 3
- 2
src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs View File

@@ -1,6 +1,7 @@
using System.Linq;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Rnn
{
@@ -26,9 +27,9 @@ namespace Tensorflow.Keras.Layers.Rnn
.ToArray();
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return base.Call(inputs, state: state, training: training);
return base.Call(inputs, initial_state: state, training: training);
}
}
}

+ 499
- 72
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -1,53 +1,468 @@
using System;
using OneOf;
using System;
using System.Collections.Generic;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using System.Linq.Expressions;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers.Rnn
{
public class RNN : Layer
/// <summary>
/// Base class for recurrent layers.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// </summary>
public class RNN : RnnBase
{
private RNNArgs args;
private object input_spec = null; // or NoneValue??
private object state_spec = null;
private object _states = null;
private object constants_spec = null;
private int _num_constants = 0;
protected IVariableV1 kernel;
protected IVariableV1 bias;
protected ILayer cell;
private RNNArgs _args;
private object _input_spec = null; // or NoneValue??
private object _state_spec = null;
private Tensors _states = null;
private object _constants_spec = null;
private int _num_constants;
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
protected IRnnCell _cell;

public RNN(RNNArgs args) : base(PreConstruct(args))
{
this.args = args;
_args = args;
SupportsMasking = true;

// The input shape is unknown yet, it could have nested tensor inputs, and
// the input spec will be the list of specs for nested inputs, the structure
// of the input_spec will be the same as the input.
// if is StackedRnncell
_cell = args.Cell;

// get input_shape
_args = PreConstruct(args);

_num_constants = 0;
}

// States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...)
// state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape
public Tensors States
{
get
{
if (_states == null)
{
// CHECK(Rinne): check if this is correct.
var nested = _cell.StateSize.MapStructure<Tensor?>(x => null);
_states = nested.AsNest().ToTensors();
}
return _states;
}
set { _states = value; }
}

/// <summary>
/// Computes the layer's output shape(s) from the input shape.
/// Returns a single Shape, or a list [output_shape, state_shape] when
/// ReturnState is set.
/// </summary>
private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
{
    var batch = input_shape[0];
    var time_step = input_shape[1];
    // In time-major layout the leading axes are (time, batch) instead of (batch, time).
    if (_args.TimeMajor)
    {
        (batch, time_step) = (time_step, batch);
    }

    // state_size is an array of ints or a positive integer.
    var state_size = _cell.StateSize.ToSingleShape();

    // TODO(wanglongzhi2001): should flat_output_size be a Shape or a Tensor?
    Func<Shape, Shape> _get_output_shape;
    _get_output_shape = (flat_output_size) =>
    {
        var output_dim = flat_output_size.as_int_list();
        Shape output_shape;
        if (_args.ReturnSequences)
        {
            // Keep the time axis; its position depends on TimeMajor.
            if (_args.TimeMajor)
            {
                output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim));
            }
            else
            {
                output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim));

            }
        }
        else
        {
            // Only the last step is returned: the time axis is dropped.
            output_shape = new Shape(new int[] { (int)batch }.concat(output_dim));
        }
        return output_shape;
    };

    // Cells that expose an `output_size` property may output a different
    // dimensionality than their state; prefer it when present.
    Type type = _cell.GetType();
    PropertyInfo output_size_info = type.GetProperty("output_size");
    Shape output_shape;
    if (output_size_info != null)
    {
        output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
        // TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape?
        output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
    }
    else
    {
        output_shape = _get_output_shape(state_size);
    }

    if (_args.ReturnState)
    {
        // Each state keeps the batch axis followed by the state dims.
        Func<Shape, Shape> _get_state_shape;
        _get_state_shape = (flat_state) =>
        {
            var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list());
            return new Shape(state_shape);
        };
        var state_shape = _get_state_shape(state_size);

        return new List<Shape> { output_shape, state_shape };
    }
    else
    {
        return output_shape;
    }

    //if(stateful)
    //{
    //    if (ds_context.has_strategy()) // ds_context????
    //    {
    //        throw new Exception("RNNs with stateful=True not yet supported with tf.distribute.Strategy");
    //    }
    //}
}

/// <summary>
/// Computes the output mask(s) from the input mask. Sequences propagate the
/// mask only when ReturnSequences is set; states are never masked.
/// </summary>
private Tensors compute_mask(Tensors inputs, Tensors mask)
{
    // Time step masks must be the same for each input.
    // This is because the mask for an RNN is of size [batch, time_steps, 1],
    // and specifies which time steps should be skipped, and a time step
    // must be skipped for all inputs.

    mask = nest.flatten(mask)[0];
    var output_mask = _args.ReturnSequences ? mask : null;
    if (_args.ReturnState)
    {
        // One null mask per returned state tensor (states carry no mask).
        var state_mask = new List<Tensor>();
        for (int i = 0; i < len(States); i++)
        {
            state_mask.Add(null);
        }
        return new List<Tensor> { output_mask }.concat(state_mask);
    }
    else
    {
        return output_mask;
    }
}

public override void build(KerasShapesWrapper input_shape)
{
if (!cell.Built)
object get_input_spec(Shape shape)
{
var input_spec_shape = shape.as_int_list();

var (batch_index, time_step_index) = _args.TimeMajor ? (1, 0) : (0, 1);
if (!_args.Stateful)
{
input_spec_shape[batch_index] = -1;
}
input_spec_shape[time_step_index] = -1;
return new InputSpec(shape: input_spec_shape);
}

Shape get_step_input_shape(Shape shape)
{

// return shape[1:] if self.time_major else (shape[0],) + shape[2:]
if (_args.TimeMajor)
{
return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray();
}
else
{
return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray());
}


}

object get_state_spec(Shape shape)
{
var state_spec_shape = shape.as_int_list();
// append bacth dim
state_spec_shape = new int[] { -1 }.concat(state_spec_shape);
return new InputSpec(shape: state_spec_shape);

}

// Check whether the input shape contains any nested shapes. It could be
// (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from
// numpy inputs.


if (!_cell.Built)
{
_cell.build(input_shape);
}
}

/// <summary>
/// Runs the wrapped cell across the time dimension of the inputs.
/// </summary>
/// <param name="inputs">Input tensors of shape [batch, timesteps, ...] (or [timesteps, batch, ...] when TimeMajor).</param>
/// <param name="initial_state">List of initial state tensors to be passed to the first call of the cell.</param>
/// <param name="training">Whether the layer runs in training mode.</param>
/// <param name="optional_args">Must be a <see cref="RnnOptionalArgs"/>; carries the mask (binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked) and the constants passed to the cell at each timestep.</param>
/// <returns>The sequence output (or last output), followed by the final states when ReturnState is set.</returns>
/// <exception cref="ValueError"></exception>
/// <exception cref="NotImplementedException"></exception>
protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
    RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
    if (optional_args is not null && rnn_optional_args is null)
    {
        // BUGFIX: fixed the typo ("shhould") in the original message.
        throw new ArgumentException("The optional args should be of type `RnnOptionalArgs`");
    }
    Tensors? constants = rnn_optional_args?.Constants;
    Tensors? mask = rnn_optional_args?.Mask;
    //var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs);
    // Ragged tensors are not supported yet.
    int row_length = 0; // TODO(Rinne): support this param.
    bool is_ragged_input = false;
    _validate_args_if_ragged(is_ragged_input, mask);

    (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);

    _maybe_reset_cell_dropout_mask(_cell);
    if (_cell is StackedRNNCells)
    {
        var stack_cell = _cell as StackedRNNCells;
        foreach (var cell in stack_cell.Cells)
        {
            _maybe_reset_cell_dropout_mask(cell);
        }
    }

    if (mask != null)
    {
        // Time step masks must be the same for each input.
        mask = mask.Flatten().First();
    }

    Shape input_shape;
    if (!inputs.IsSingle())
    {
        // In the case of nested input, use the first element for shape check.
        input_shape = inputs.Flatten().First().shape;
    }
    else
    {
        input_shape = inputs.shape;
    }

    var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1];

    // BUGFIX: the original guard was `timesteps != null`, which is always true
    // for a non-nullable value (CS0472) and made every unrolled call throw.
    // Unknown dimensions are encoded as -1 in TF.NET shapes — TODO confirm
    // this sentinel against Shape's conventions.
    if (_args.Unroll && timesteps == -1)
    {
        throw new ValueError(
            "Cannot unroll a RNN if the " +
            "time dimension is undefined. \n" +
            "- If using a Sequential model, " +
            "specify the time dimension by passing " +
            "an `input_shape` or `batch_input_shape` " +
            "argument to your first layer. If your " +
            "first layer is an Embedding, you can " +
            "also use the `input_length` argument.\n" +
            "- If using the functional API, specify " +
            "the time dimension by passing a `shape` " +
            "or `batch_shape` argument to your Input layer."
        );
    }

    // cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call)
    Func<Tensors, Tensors, (Tensors, Tensors)> step;
    bool is_tf_rnn_cell = _cell.IsTFRnnCell;
    if (constants is not null)
    {
        if (!_cell.SupportOptionalArgs)
        {
            throw new ValueError(
                $"RNN cell {_cell} does not support constants." +
                $"Received: constants={constants}");
        }

        // Constants travel appended to the state list; split them back out
        // before each cell invocation.
        step = (inputs, states) =>
        {
            constants = new Tensors(states.TakeLast(_num_constants));
            states = new Tensors(states.SkipLast(_num_constants));
            states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
            var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
            // TODO(Wanglongzhi2001): should cell_call_fn's return value be (Tensors, Tensors)?
            return (output, new_states.Single);
        };
    }
    else
    {
        step = (inputs, states) =>
        {
            states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
            var (output, new_states) = _cell.Apply(inputs, states);
            return (output, new_states.Single);
        };
    }

    // BUGFIX: the original passed `row_length != null ? new Tensor(row_length) : new Tensor(timesteps)`,
    // but `row_length` is a non-nullable int so the condition was always true
    // and `input_length` was a constant 0. Ragged inputs are unsupported here,
    // so the correct length is `timesteps`.
    var (last_output, outputs, states) = keras.backend.rnn(step,
        inputs,
        initial_state,
        constants: constants,
        go_backwards: _args.GoBackwards,
        mask: mask,
        unroll: _args.Unroll,
        input_length: new Tensor(timesteps),
        time_major: _args.TimeMajor,
        zero_output_for_mask: _args.ZeroOutputForMask,
        return_all_outputs: _args.ReturnSequences);

    if (_args.Stateful)
    {
        throw new NotImplementedException("this argument hasn't been implemented yet.");
    }

    Tensors output = new Tensors();
    if (_args.ReturnSequences)
    {
        // TODO(Rinne): add go_backwards parameter and revise the `row_length` param
        output = keras.backend.maybe_convert_to_ragged(is_ragged_input, outputs, row_length, false);
    }
    else
    {
        output = last_output;
    }

    if (_args.ReturnState)
    {
        // Append the final states after the output tensor(s).
        foreach (var state in states)
        {
            output.Add(state);
        }
        return output;
    }
    else
    {
        return output;
    }
}

public override Tensors Apply(Tensors inputs, Tensors initial_states = null, bool training = false, IOptionalArgs? optional_args = null)
{
RnnOptionalArgs? rnn_optional_args = optional_args as RnnOptionalArgs;
if (optional_args is not null && rnn_optional_args is null)
{
cell.build(input_shape);
throw new ArgumentException("The type of optional args should be `RnnOptionalArgs`.");
}
Tensors? constants = rnn_optional_args?.Constants;
(inputs, initial_states, constants) = RnnUtils.standardize_args(inputs, initial_states, constants, _num_constants);

if(initial_states is null && constants is null)
{
return base.Apply(inputs);
}

// TODO(Rinne): implement it.
throw new NotImplementedException();
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
/// <summary>
/// Splits a combined `inputs` list into (inputs, initial_state, constants)
/// and resolves the effective initial state: stored States for stateful
/// layers, zero-filled defaults otherwise.
/// </summary>
private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
{
    if (inputs.Length > 1)
    {
        // Extra tensors appended to `inputs` are, in order, the initial
        // states followed by `_num_constants` constants.
        // BUGFIX: the two branches below were swapped in the original code —
        // the constants can only be split off when `_num_constants != 0`
        // (the old else-branch called TakeLast(0)/SkipLast(0), doing nothing).
        if (_num_constants != 0)
        {
            initial_state = new Tensors(inputs.Skip(1).SkipLast(_num_constants));
            constants = new Tensors(inputs.TakeLast(_num_constants));
        }
        else
        {
            initial_state = new Tensors(inputs.Skip(1));
        }
        if (len(initial_state) == 0)
            initial_state = null;
        inputs = inputs[0];
    }

    if (_args.Stateful)
    {
        if (initial_state != null)
        {
            // Prefer the stored States when they contain any non-zero entry.
            // NOTE(review): `tmp` is built via an `add` extension on an empty
            // array — confirm this extension appends rather than no-ops.
            var tmp = new Tensor[] { };
            foreach (var s in nest.flatten(States))
            {
                tmp.add(tf.math.count_nonzero((Tensor)s));
            }
            var non_zero_count = tf.add_n(tmp);
            //initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
            if ((int)non_zero_count.numpy() > 0)
            {
                initial_state = States;
            }
        }
        else
        {
            initial_state = States;
        }

    }
    else if (initial_state is null)
    {
        initial_state = get_initial_state(inputs);
    }

    if (initial_state.Length != States.Length)
    {
        throw new ValueError(
            $"Layer {this} expects {States.Length} state(s), " +
            $"but it received {initial_state.Length} " +
            $"initial state(s). Input received: {inputs}");
    }

    return (inputs, initial_state, constants);
}

/// <summary>
/// Rejects configurations that cannot be combined with ragged inputs
/// (unrolling and explicit masks). No-op for dense inputs.
/// </summary>
private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask)
{
    // Dense (non-ragged) inputs need no validation.
    if (!is_ragged_input)
        return;

    if (_args.Unroll)
    {
        throw new ValueError("The input received contains RaggedTensors and does " +
            "not support unrolling. Disable unrolling by passing " +
            "`unroll=False` in the RNN Layer constructor.");
    }

    if (mask is not null)
    {
        throw new ValueError($"The mask that was passed in was {mask}, which " +
            "cannot be applied to RaggedTensor inputs. Please " +
            "make sure that there is no mask injected by upstream " +
            "layers.");
    }
}

// Resets any cached dropout masks on cells that support them.
// NOTE(review): currently a deliberate no-op — DropoutRNNCellMixin does not
// yet expose reset_dropout_mask/reset_recurrent_dropout_mask, so the body
// below stays disabled until mask caching is implemented.
void _maybe_reset_cell_dropout_mask(ILayer cell)
{
    //if (cell is DropoutRNNCellMixin)
    //{
    //    cell.reset_dropout_mask();
    //    cell.reset_recurrent_dropout_mask();
    //}
}

private static RNNArgs PreConstruct(RNNArgs args)
@@ -77,60 +492,72 @@ namespace Tensorflow.Keras.Layers.Rnn
return args;
}

public RNN New(LayerRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = cell,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

public RNN New(IList<RnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});


protected Tensor get_initial_state(Tensor inputs)
public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null)
{
return _generate_zero_filled_state_for_cell(null, null);
throw new NotImplementedException();
}

Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size)
// 好像不能cell不能传接口类型
//public RNN New(IRnnArgCell cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });

//public RNN New(List<IRnnArgCell> cell,
// bool return_sequences = false,
// bool return_state = false,
// bool go_backwards = false,
// bool stateful = false,
// bool unroll = false,
// bool time_major = false)
// => new RNN(new RNNArgs
// {
// Cell = cell,
// ReturnSequences = return_sequences,
// ReturnState = return_state,
// GoBackwards = go_backwards,
// Stateful = stateful,
// Unroll = unroll,
// TimeMajor = time_major
// });


protected Tensors get_initial_state(Tensors inputs)
{
throw new NotImplementedException("");
var input = inputs[0];
var input_shape = input.shape;
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;
Tensors init_state;
if (_cell is RnnCellBase rnn_base_cell)
{
init_state = rnn_base_cell.GetInitialState(null, batch_size, dtype);
}
else
{
init_state = RnnUtils.generate_zero_filled_state(batch_size, _cell.StateSize, dtype);
}

return init_state;
}

// Check whether the state_size contains multiple states.
public static bool _is_multiple_state(object state_size)
public static bool is_multiple_state(GeneralizedTensorShape state_size)
{
var myIndexerProperty = state_size.GetType().GetProperty("Item");
return myIndexerProperty != null
&& myIndexerProperty.GetIndexParameters().Length == 1
&& !(state_size.GetType() == typeof(Shape));
return state_size.Shapes.Length > 1;
}
}
}

+ 13
- 0
src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers.Rnn
{
/// <summary>
/// Common abstract base for the recurrent layers in this namespace;
/// adds no behavior beyond the Keras <see cref="Layer"/> base type.
/// </summary>
public abstract class RnnBase: Layer
{
    // Forwards the layer configuration to the Layer base constructor.
    public RnnBase(LayerArgs args): base(args) { }
}
}

+ 24
- 0
src/TensorFlowNET.Keras/Layers/Rnn/RnnCellBase.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
{
/// <summary>
/// Base class for single-step RNN cells, bridging the Keras
/// <see cref="Layer"/> machinery and the IRnnCell contract.
/// </summary>
public abstract class RnnCellBase: Layer, IRnnCell
{
    public RnnCellBase(LayerArgs args) : base(args) { }
    // Shape(s) of the recurrent state carried between time steps.
    public abstract GeneralizedTensorShape StateSize { get; }
    // Shape(s) of the cell output for a single time step.
    public abstract GeneralizedTensorShape OutputSize { get; }
    // Whether the cell follows the TF RNN-cell calling convention.
    public abstract bool IsTFRnnCell { get; }
    // Whether the cell accepts extra per-step arguments (e.g. constants).
    public abstract bool SupportOptionalArgs { get; }
    /// <summary>
    /// Returns the default initial state: zero-filled tensors matching
    /// <see cref="StateSize"/> (delegates to RnnUtils).
    /// </summary>
    public virtual Tensors GetInitialState(Tensors inputs, long batch_size, TF_DataType dtype)
    {
        return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
    }
}
}

+ 20
- 2
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -10,18 +10,36 @@ namespace Tensorflow.Keras.Layers.Rnn
public class SimpleRNN : RNN
{
SimpleRNNArgs args;
public SimpleRNN(SimpleRNNArgs args) : base(args)
public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args))
{
this.args = args;
}

private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args)
{
args.Cell = new SimpleRNNCell(new SimpleRNNCellArgs()
{
Units = args.Units,
Activation = args.Activation,
UseBias = args.UseBias,
KernelInitializer = args.KernelInitializer,
RecurrentInitializer = args.RecurrentInitializer,
BiasInitializer = args.BiasInitializer,
Dropout = args.Dropout,
RecurrentDropout = args.RecurrentDropout,
DType = args.DType,
Trainable = args.Trainable,
});
return args;
}

public override void build(KerasShapesWrapper input_shape)
{
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];
_buildInputShape = input_shape;

kernel = add_weight("kernel", (single_shape[-1], args.Units),
_kernel = add_weight("kernel", (single_shape[-1], args.Units),
initializer: args.KernelInitializer
//regularizer = self.kernel_regularizer,
//constraint = self.kernel_constraint,


+ 83
- 16
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -4,47 +4,114 @@ using System.Text;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Keras.Layers.Rnn
{
public class SimpleRNNCell : Layer
/// <summary>
/// Cell class for SimpleRNN.
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
/// for details about the usage of RNN API.
/// This class processes one step within the whole time sequence input, whereas
/// `tf.keras.layer.SimpleRNN` processes the whole sequence.
/// </summary>
public class SimpleRNNCell : DropoutRNNCellMixin
{
SimpleRNNArgs args;
IVariableV1 kernel;
IVariableV1 recurrent_kernel;
IVariableV1 bias;
SimpleRNNCellArgs _args;
IVariableV1 _kernel;
IVariableV1 _recurrent_kernel;
IVariableV1 _bias;
GeneralizedTensorShape _state_size;
GeneralizedTensorShape _output_size;

public SimpleRNNCell(SimpleRNNArgs args) : base(args)
public override GeneralizedTensorShape StateSize => _state_size;
public override GeneralizedTensorShape OutputSize => _output_size;
public override bool IsTFRnnCell => true;
public override bool SupportOptionalArgs => false;

public SimpleRNNCell(SimpleRNNCellArgs args) : base(args)
{
this.args = args;
this._args = args;
if (args.Units <= 0)
{
throw new ValueError(
$"units must be a positive integer, got {args.Units}");
}
this._args.Dropout = Math.Min(1f, Math.Max(0f, this._args.Dropout));
this._args.RecurrentDropout = Math.Min(1f, Math.Max(0f, this._args.RecurrentDropout));
_state_size = new GeneralizedTensorShape(args.Units);
_output_size = new GeneralizedTensorShape(args.Units);
}

public override void build(KerasShapesWrapper input_shape)
{
// TODO(Rinne): add the cache.
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];

kernel = add_weight("kernel", (single_shape[-1], args.Units),
initializer: args.KernelInitializer
_kernel = add_weight("kernel", (single_shape[-1], _args.Units),
initializer: _args.KernelInitializer
);

recurrent_kernel = add_weight("recurrent_kernel", (args.Units, args.Units),
initializer: args.RecurrentInitializer
_recurrent_kernel = add_weight("recurrent_kernel", (_args.Units, _args.Units),
initializer: _args.RecurrentInitializer
);

if (args.UseBias)
if (_args.UseBias)
{
bias = add_weight("bias", (args.Units),
initializer: args.BiasInitializer
_bias = add_weight("bias", (_args.Units),
initializer: _args.BiasInitializer
);
}

built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
// TODO(Rinne): revise the trining param (with refactoring of the framework)
protected override Tensors Call(Tensors inputs, Tensors states = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return base.Call(inputs, state, training);
// TODO(Rinne): check if it will have multiple tensors when not nested.
Tensors prev_output = Nest.IsNested(states) ? new Tensors(states[0]) : states;
var dp_mask = get_dropout_maskcell_for_cell(inputs, training.Value);
var rec_dp_mask = get_recurrent_dropout_maskcell_for_cell(prev_output, training.Value);

Tensor h;
if (dp_mask != null)
{
h = math_ops.matmul(math_ops.multiply(inputs.Single, dp_mask.Single), _kernel.AsTensor());
}
else
{
h = math_ops.matmul(inputs, _kernel.AsTensor());
}

if (_bias != null)
{
h = tf.nn.bias_add(h, _bias);
}

if (rec_dp_mask != null)
{
prev_output = math_ops.multiply(prev_output, rec_dp_mask);
}

Tensor output = h + math_ops.matmul(prev_output, _recurrent_kernel.AsTensor());

if (_args.Activation != null)
{
output = _args.Activation.Apply(output);
}
if (Nest.IsNested(states))
{
return new Nest<Tensor>(new List<Nest<Tensor>> {
new Nest<Tensor>(new List<Nest<Tensor>> { new Nest<Tensor>(output) }), new Nest<Tensor>(output) })
.ToTensors();
}
else
{
return new Tensors(output, output);
}
}
}
}

+ 12
- 2
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
@@ -8,7 +9,7 @@ using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers.Rnn
{
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
public class StackedRNNCells : Layer, IRnnCell
{
public IList<RnnCell> Cells { get; set; }
public bool reverse_state_order;
@@ -51,7 +52,7 @@ namespace Tensorflow.Keras.Layers.Rnn
{
return lastCell.output_size;
}
else if (RNN._is_multiple_state(lastCell.state_size))
else if (RNN.is_multiple_state(lastCell.StateSize))
{
// return ((dynamic)Cells[-1].state_size)[0];
throw new NotImplementedException("");
@@ -162,5 +163,14 @@ namespace Tensorflow.Keras.Layers.Rnn
// deserialize_layer(cell_config, custom_objects = custom_objects))
// return cls(cells, **config)
}

// NOTE(review): the members below satisfy the IRnnCell contract but are not
// yet implemented for stacked cells — every one of them currently throws.
public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
{
    throw new NotImplementedException();
}
// Placeholder IRnnCell properties — TODO: aggregate these from the inner Cells.
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 2
- 1
src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs View File

@@ -10,6 +10,7 @@ using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using Tensorflow.Functions;
using System.Threading;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers
{
@@ -34,7 +35,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (tf.Context.executing_eagerly())
return DeFunCall(inputs);


+ 1
- 1
src/TensorFlowNET.Keras/Metrics/metrics_utils.cs View File

@@ -304,7 +304,7 @@ public class metrics_utils
var NEG_INF = -1e10;
var (_, top_k_idx) = tf.math.top_k(x, k, sorted: false);
var top_k_mask = tf.reduce_sum(
tf.one_hot(top_k_idx, (int)x.shape[-1], axis: -1), axis: -2);
tf.one_hot(top_k_idx.Single, (int)x.shape[-1], axis: -1), axis: -2);
return x * top_k_mask + NEG_INF * (1 - top_k_mask);
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -129,7 +129,7 @@ namespace Tensorflow.Keras
var indices = z.map(m =>
{
var (i, positions) = m;
return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
return tf.range(positions.Single[i], positions.Single[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor);
}, num_parallel_calls: -1);
var dataset = sequences_from_indices(data, indices, start_index, end_index);



+ 1
- 1
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -8,7 +8,7 @@ using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Text.RegularExpressions;
using Tensorflow.Extensions;
using Tensorflow.Common.Extensions;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;


+ 93
- 0
src/TensorFlowNET.Keras/Utils/RnnUtils.cs View File

@@ -0,0 +1,93 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Keras.Utils
{
internal static class RnnUtils
{
/// <summary>
/// Builds zero-filled state tensors of shape [batch, ...state_dims] for the
/// given (possibly multi-part) state size.
/// </summary>
internal static Tensors generate_zero_filled_state(long batch_size_tensor, GeneralizedTensorShape state_size, TF_DataType dtype)
{
    // Creates one zero tensor for a single (non-nested) state entry.
    Tensor create_zeros(GeneralizedTensorShape unnested_state_size)
    {
        var state_dims = unnested_state_size.ToSingleShape().dims;
        var full_dims = new long[] { batch_size_tensor }.Concat(state_dims).ToArray();
        return array_ops.zeros(new Shape(full_dims), dtype: dtype);
    }

    // TODO(Rinne): map structure with nested tensors.
    if (state_size.Shapes.Length > 1)
    {
        // Multiple states (e.g. LSTM h/c): one zero tensor per sub-shape.
        return new Tensors(state_size.ToShapeArray().Select(s => create_zeros(new GeneralizedTensorShape(s))));
    }

    return create_zeros(state_size);
}

/// <summary>
/// Builds zero-filled initial states for a cell; when concrete inputs are
/// available, their batch size and dtype override the supplied ones.
/// </summary>
internal static Tensors generate_zero_filled_state_for_cell(IRnnCell cell, Tensors inputs, long batch_size, TF_DataType dtype)
{
    if (inputs is not null)
    {
        batch_size = inputs.shape[0];
        dtype = inputs.dtype;
    }
    return generate_zero_filled_state(batch_size, cell.StateSize, dtype);
}

/// <summary>
/// Standardizes `__call__` to a single list of tensor inputs.
///
/// When running a model loaded from a file, the input tensors
/// `initial_state` and `constants` can be passed to `RNN.__call__()` as part
/// of `inputs` instead of by the dedicated keyword arguments.This method
/// makes sure the arguments are separated and that `initial_state` and
/// `constants` are lists of tensors(or None).
/// </summary>
/// <param name="inputs">Tensor or list/tuple of tensors. which may include constants
/// and initial states.In that case `num_constant` must be specified.</param>
/// <param name="initial_state">Tensor or list of tensors or None, initial states.</param>
/// <param name="constants">Tensor or list of tensors or None, constant tensors.</param>
/// <param name="num_constants">Expected number of constants (if constants are passed as
/// part of the `inputs` list.</param>
/// <returns></returns>
internal static (Tensors, Tensors, Tensors) standardize_args(Tensors inputs, Tensors initial_state, Tensors constants, int num_constants)
{
if(inputs.Length > 1)
{
// There are several situations here:
// In the graph mode, __call__ will be only called once. The initial_state
// and constants could be in inputs (from file loading).
// In the eager mode, __call__ will be called twice, once during
// rnn_layer(inputs=input_t, constants=c_t, ...), and second time will be
// model.fit/train_on_batch/predict with real np data. In the second case,
// the inputs will contain initial_state and constants as eager tensor.
//
// For either case, the real input is the first item in the list, which
// could be a nested structure itself. Then followed by initial_states, which
// could be a list of items, or list of list if the initial_state is complex
// structure, and finally followed by constants which is a flat list.
Debug.Assert(initial_state is null && constants is null);
if(num_constants > 0)
{
constants = inputs.TakeLast(num_constants).ToTensors();
inputs = inputs.SkipLast(num_constants).ToTensors();
}
if(inputs.Length > 1)
{
initial_state = inputs.Skip(1).ToTensors();
inputs = inputs.Take(1).ToTensors();
}
}

return (inputs, initial_state, constants);
}
}
}

+ 2
- 1
src/TensorflowNET.Hub/KerasLayer.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Train;
using Tensorflow.Training;
@@ -89,7 +90,7 @@ namespace Tensorflow.Hub
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optionalArgs = null)
{
_check_trainability();



+ 0
- 11
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -144,17 +144,6 @@ namespace Tensorflow.Keras.UnitTest.Layers
Assert.AreEqual(expected_output, actual_output);
}

[TestMethod, Ignore("WIP")]
public void SimpleRNN()
{
var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
/*var simple_rnn = keras.layers.SimpleRNN(4);
var output = simple_rnn.Apply(inputs);
Assert.AreEqual((32, 4), output.shape);*/
var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
}

[TestMethod]
public void Resizing()
{


+ 28
- 0
test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs View File

@@ -0,0 +1,28 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.UnitTest.Layers
{
[TestClass]
public class Rnn
{
    /// <summary>
    /// Smoke test for SimpleRNN with both return_sequences and return_state:
    /// feeds a (6, 10, 8) batch (6 sequences, 10 time steps, 8 features) through
    /// a 4-unit layer and checks the emitted shapes. The original test only
    /// printed the outputs and asserted nothing.
    /// </summary>
    [TestMethod]
    public void SimpleRNN()
    {
        var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
        var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
        var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);

        // return_sequences: one 4-unit output per time step -> (batch, steps, units).
        Assert.AreEqual(new Shape(6, 10, 4), whole_sequence_output.shape);
        // return_state: the last hidden state -> (batch, units).
        Assert.AreEqual(new Shape(6, 4), final_state.shape);
    }
}
}

+ 1
- 1
tools/TensorFlowNET.Console/SimpleRnnTest.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow

// whole_sequence_output has shape `[32, 10, 4]`.
// final_state has shape `[32, 4]`.
var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
var (whole_sequence_output, final_states) = simple_rnn.Apply(inputs);
}
}
}

Loading…
Cancel
Save