Browse Source

Merge pull request #1110 from Wanglongzhi2001/master

feat: Support training of RNN and LSTM.
tags/v0.110.0-LSTM-Model
Rinne GitHub 2 years ago
parent
commit
0454c7b068
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
100 changed files with 9416 additions and 571 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/APIs/c_api.cs
  2. +5
    -5
      src/TensorFlowNET.Core/APIs/tf.control_flow.cs
  3. +3
    -3
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Binding.Util.cs
  5. +0
    -0
      src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs
  6. +3
    -3
      src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs
  7. +38
    -0
      src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs
  8. +33
    -0
      src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs
  9. +0
    -0
      src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs
  10. +20
    -0
      src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs
  11. +69
    -0
      src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs
  12. +40
    -0
      src/TensorFlowNET.Core/Common/Types/INestStructure.cs
  13. +11
    -0
      src/TensorFlowNET.Core/Common/Types/INestable.cs
  14. +21
    -0
      src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs
  15. +0
    -0
      src/TensorFlowNET.Core/Common/Types/NamedTuple.cs
  16. +62
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.Static.cs
  17. +485
    -0
      src/TensorFlowNET.Core/Common/Types/Nest.cs
  18. +103
    -0
      src/TensorFlowNET.Core/Common/Types/NestDictionary.cs
  19. +53
    -0
      src/TensorFlowNET.Core/Common/Types/NestList.cs
  20. +36
    -0
      src/TensorFlowNET.Core/Common/Types/NestNode.cs
  21. +1
    -1
      src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs
  22. +2
    -2
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  23. +7
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  24. +7
    -3
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  25. +6
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
  26. +19
    -0
      src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs
  27. +13
    -0
      src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs
  28. +89
    -0
      src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs
  29. +2
    -2
      src/TensorFlowNET.Core/Framework/function_def_lib.cs
  30. +13
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  31. +2
    -3
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  32. +2
    -2
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  33. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs
  34. +0
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
  35. +30
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  36. +8
    -24
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  37. +14
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs
  38. +27
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
  39. +2
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
  40. +3
    -2
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  41. +46
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  42. +25
    -0
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
  43. +12
    -0
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs
  44. +1
    -0
      src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs
  45. +1
    -0
      src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs
  46. +0
    -5
      src/TensorFlowNET.Core/NumPy/Axis.cs
  47. +9
    -9
      src/TensorFlowNET.Core/NumPy/NDArrayRender.cs
  48. +23
    -1
      src/TensorFlowNET.Core/Numpy/Shape.cs
  49. +22
    -0
      src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs
  50. +2
    -3
      src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
  51. +2
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  52. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  53. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs
  54. +18
    -2
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  55. +57
    -1
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  56. +1
    -1
      src/TensorFlowNET.Core/Operations/Operation.Output.cs
  57. +13
    -3
      src/TensorFlowNET.Core/Operations/Operation.cs
  58. +113
    -24
      src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs
  59. +180
    -5
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  60. +78
    -22
      src/TensorFlowNET.Core/Operations/array_ops.cs
  61. +5
    -4
      src/TensorFlowNET.Core/Operations/control_flow_ops.cs
  62. +77
    -0
      src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  63. +489
    -10
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  64. +1042
    -81
      src/TensorFlowNET.Core/Operations/gen_functional_ops.cs
  65. +827
    -109
      src/TensorFlowNET.Core/Operations/gen_io_ops.cs
  66. +1308
    -0
      src/TensorFlowNET.Core/Operations/gen_list_ops.cs
  67. +585
    -0
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  68. +409
    -0
      src/TensorFlowNET.Core/Operations/gen_nn_ops.cs
  69. +1469
    -104
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  70. +3
    -3
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  71. +111
    -0
      src/TensorFlowNET.Core/Operations/list_ops.cs
  72. +1
    -1
      src/TensorFlowNET.Core/Operations/logging_ops.cs
  73. +1
    -1
      src/TensorFlowNET.Core/Operations/sort_ops.cs
  74. +16
    -4
      src/TensorFlowNET.Core/Operations/tensor_array_ops.cs
  75. +401
    -0
      src/TensorFlowNET.Core/Operations/while_v2.cs
  76. +2
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  77. +6
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  78. +7
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  79. +24
    -0
      src/TensorFlowNET.Core/Tensors/TensorArray.cs
  80. +201
    -89
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  81. +1
    -2
      src/TensorFlowNET.Core/Training/Trackable.cs
  82. +1
    -0
      src/TensorFlowNET.Core/Util/nest.py.cs
  83. +20
    -3
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  84. +1
    -1
      src/TensorFlowNET.Core/ops.cs
  85. +525
    -0
      src/TensorFlowNET.Keras/BackendImpl.cs
  86. +3
    -2
      src/TensorFlowNET.Keras/Engine/Functional.cs
  87. +6
    -3
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  88. +2
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs
  89. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Build.cs
  90. +3
    -2
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  91. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  92. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Train.cs
  93. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.cs
  94. +2
    -1
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  95. +4
    -0
      src/TensorFlowNET.Keras/IsExternalInit.cs
  96. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
  97. +2
    -2
      src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
  98. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs
  99. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs
  100. +2
    -1
      src/TensorFlowNET.Keras/Layers/Activation/SELU.cs

+ 14
- 0
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Runtime.InteropServices;
using static Tensorflow.CppShapeInferenceResult.Types;

namespace Tensorflow
{
@@ -50,6 +51,19 @@ namespace Tensorflow
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
}

public unsafe static byte[] ByteStringPiece(IntPtr handle)
{
byte* str_data = (byte*)handle.ToPointer();
List<byte> bytes = new List<byte>();
byte current = 255;
while (current != ((byte)'\0'))
{
current = *(str_data++);
bytes.Add(current);
}
return bytes.Take(bytes.Count - 1).ToArray();
}

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);



+ 5
- 5
src/TensorFlowNET.Core/APIs/tf.control_flow.cs View File

@@ -46,10 +46,10 @@ namespace Tensorflow
Tensor loop_vars,
int parallel_iterations = 10)
{
Func<Tensor[], Tensor> cond1 = x
Func<Tensors, Tensor> cond1 = x
=> cond(x[0]);

Func<Tensor[], Tensor[]> body1 = x
Func<Tensors, Tensors> body1 = x
=> new[] { body(x[0]) };

var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ namespace Tensorflow
return results[0];
}

public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
=> control_flow_ops.while_loop(cond, body, loop_vars,


+ 3
- 3
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -71,15 +71,15 @@ namespace Tensorflow
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis,
name: name);

public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
axis: axis,
num_or_size_splits: num_split,
axis: ops.convert_to_tensor(axis),
name: name);

public Tensor ensure_shape(Tensor x, Shape shape, string name = null)


+ 1
- 1
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -524,7 +524,7 @@ namespace Tensorflow
case Tensors tensors:
return tensors.dtype;
case IEnumerable<Tensor> tensors:
return tensors.First().dtype;
return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:


src/TensorFlowNET.Core/Extensions/DictionaryExtension.cs → src/TensorFlowNET.Core/Common/Extensions/DictionaryExtension.cs View File


src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs → src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs View File

@@ -3,16 +3,16 @@ using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Extensions
namespace Tensorflow.Common.Extensions
{
public static class JObjectExtensions
{
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
{
var res = obj[key];
if(res is null)
if (res is null)
{
return default(T);
return default;
}
else
{

+ 38
- 0
src/TensorFlowNET.Core/Common/Extensions/LinqExtensions.cs View File

@@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.Common.Extensions
{
public static class LinqExtensions
{
#if NETSTANDARD2_0
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
{
return sequence.Skip(sequence.Count() - count);
}

public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
{
return sequence.Take(sequence.Count() - count);
}
#endif
public static Tensors ToTensors(this Tensor[] tensors)
{
return new Tensors(tensors);
}

public static Tensors ToTensors(this IList<Tensor> tensors)
{
return new Tensors(tensors);
}

public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
{
first = values.Item1;
second = values.Item2;
third = values.Item3;
}
}
}

+ 33
- 0
src/TensorFlowNET.Core/Common/Extensions/NestExtensions.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Common.Extensions
{
public static class NestExtensions
{
public static Tensors ToTensors(this INestable<Tensor> tensors)
{
return new Tensors(tensors.AsNest());
}

public static Tensors? ToTensors(this Nest<Tensor> tensors)
{
return Tensors.FromNest(tensors);
}

/// <summary>
/// If the nested object is already a nested type, this function could reduce it.
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="input"></param>
/// <returns></returns>
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut>
{
return Nest<TOut>.ReduceFrom(input);
}
}
}

src/TensorFlowNET.Core/Extensions/OneofExtension.cs → src/TensorFlowNET.Core/Common/Extensions/OneofExtension.cs View File


+ 20
- 0
src/TensorFlowNET.Core/Common/Types/FakeTensorByTensorArray.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This is a temp solution, which should be removed after refactoring `Tensors`
/// </summary>
[Obsolete]
public class FakeTensorByTensorArray: Tensor
{
public TensorArray TensorArray { get; set; }

public FakeTensorByTensorArray(TensorArray array)
{
TensorArray = array;
}
}
}

+ 69
- 0
src/TensorFlowNET.Core/Common/Types/GeneralizedTensorShape.cs View File

@@ -0,0 +1,69 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;

namespace Tensorflow.Common.Types
{
public class GeneralizedTensorShape: Nest<Shape>
{
public GeneralizedTensorShape(Shape value, string? name = null)
{
NodeValue = value;
NestType = NestType.Node;
}

public GeneralizedTensorShape(IEnumerable<Shape> values, string? name = null)
{
ListValue = values.Select(s => new Nest<Shape>(s) as INestStructure<Shape>).ToList();
Name = name;
NestType = NestType.List;
}

public GeneralizedTensorShape(Dictionary<string, Shape> value, string? name = null)
{
DictValue = value.ToDictionary(x => x.Key, x => new Nest<Shape>(x.Value) as INestStructure<Shape>);
Name = name;
NestType = NestType.Dictionary;
}

public GeneralizedTensorShape(Nest<Shape> other)
{
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public Shape ToSingleShape()
{
var shapes = Flatten().ToList();
if (shapes.Count != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
return shapes[0];
}

public long ToNumber()
{
var shapes = Flatten().ToList();
if (shapes.Count != 1 || shapes[0].ndim != 1)
{
throw new ValueError("The generalized shape contains more than 1 dim.");
}
return shapes[0].dims[0];
}

public INestStructure<TensorShapeConfig> ToTensorShapeConfigs()
{
return MapStructure(s => new TensorShapeConfig() { Items = s.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() });
}

public static implicit operator GeneralizedTensorShape(Shape shape)
{
return new GeneralizedTensorShape(shape);
}
}
}

+ 40
- 0
src/TensorFlowNET.Core/Common/Types/INestStructure.cs View File

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface indicates that a class may have a nested structure and provide
/// methods to manipulate with the structure.
/// </summary>
public interface INestStructure<T>: INestable<T>
{
NestType NestType { get; }

/// <summary>
/// The item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has ShallowNestedCount = 3.
/// </summary>
int ShallowNestedCount { get; }
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
int TotalNestedCount { get; }

/// <summary>
/// Flatten the Nestable object. Node that if the object contains only one value,
/// it will be flattened to an enumerable with one element.
/// </summary>
/// <returns></returns>
IEnumerable<T> Flatten();
/// <summary>
/// Construct a new object with the same nested structure.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <param name="func"></param>
/// <returns></returns>
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func);
}
}

+ 11
- 0
src/TensorFlowNET.Core/Common/Types/INestable.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public interface INestable<T>
{
Nest<T> AsNest();
}
}

+ 21
- 0
src/TensorFlowNET.Core/Common/Types/IOptionalArgs.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// This interface is used when some corresponding python methods have optional args.
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
/// as the parameter of the method.
/// </summary>
public interface IOptionalArgs
{
/// <summary>
/// The identifier of the class. It is not an argument but only something to
/// separate different OptionalArgs.
/// </summary>
string Identifier { get; }
}
}

src/TensorFlowNET.Core/Extensions/NamedTuple.cs → src/TensorFlowNET.Core/Common/Types/NamedTuple.cs View File


+ 62
- 0
src/TensorFlowNET.Core/Common/Types/Nest.Static.cs View File

@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public static class Nest
{
/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<TOut> PackSequenceAs<T, TOut>(INestable<T> template, TOut[] flatItems)
{
return template.AsNest().PackSequence(flatItems);
}

/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="template"></param>
/// <param name="flatItems"></param>
/// <returns></returns>
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems)
{
return template.AsNest().PackSequence(flatItems.ToArray());
}

/// <summary>
/// Flatten the nested object.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="nestedObject"></param>
/// <returns></returns>
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject)
{
return nestedObject.AsNest().Flatten();
}

/// <summary>
/// Map the structure with specified function.
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="func"></param>
/// <param name="nestedObject"></param>
/// <returns></returns>
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject)
{
return nestedObject.AsNest().MapStructure(func);
}

public static bool IsNested<T>(INestable<T> obj)
{
return obj.AsNest().IsNested();
}
}
}

+ 485
- 0
src/TensorFlowNET.Core/Common/Types/Nest.cs View File

@@ -0,0 +1,485 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Common.Types
{
public enum NestType
{
Empty,
Node,
List,
Dictionary
}

/// <summary>
/// A nested structure which may inclulde value, list and dictionary.
/// Note that dictionary does not ensure the data order. When using it as IEnumerable,
/// its order is depth-first.
/// </summary>
/// <typeparam name="T"></typeparam>
public class Nest<T> : INestStructure<T>, IEnumerable<T>
{
private static readonly Nest<T> _empty = new Nest<T>()
{
NestType = NestType.Empty,
};
public static Nest<T> Empty => _empty;
public NestType NestType { get; protected set; }
public string? Name { get; set; }
public T? NodeValue { get; protected set; }
public List<INestStructure<T>>? ListValue { get; protected set; }
public Dictionary<string, INestStructure<T>>? DictValue { get; protected set; }

public int ShallowNestedCount
{
get
{
if (NestType == NestType.Empty)
{
return 0;
}
else if (NestType == NestType.Node)
{
return 1;
}
else if (NestType == NestType.List)
{
return ListValue!.Count;
}
else // dict
{
return DictValue!.Count;
}
}
}

public int TotalNestedCount
{
get
{
return Flatten().Count();
}
}

protected Nest() { }

public Nest(T value, string? name = null)
{
NodeValue = value;
Name = name;
NestType = NestType.Node;
}

public Nest(IEnumerable<INestStructure<T>> values, string? name = null)
{
ListValue = values.ToList();
Name = name;
NestType = NestType.List;
}

public Nest(Dictionary<string, INestStructure<T>> value, string? name = null)
{
DictValue = value;
Name = name;
NestType = NestType.Dictionary;
}

public Nest(Nest<T> other)
{
NestType = other.NestType;
NodeValue = other.NodeValue;
DictValue = other.DictValue;
ListValue = other.ListValue;
Name = other.Name;
}

public virtual IEnumerable<T> Flatten()
{
return FlattenInternal(this);
}
public virtual INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return MapStructureInternal(func);
}

/// <summary>
/// Pack the flat items to a nested sequence by the template.
/// </summary>
/// <param name="flatItems"></param>
/// <returns></returns>
public virtual Nest<TOut> PackSequence<TOut>(TOut[] flatItems)
{
if(flatItems.Length == 0)
{
return Nest<TOut>.Empty;
}
int index = 0;
return PackSequenceInternal(this, flatItems, ref index);
}

private static Nest<TOut> PackSequenceInternal<TOut>(Nest<T> template, TOut[] flatItems, ref int index)
{
if(template.NestType == NestType.Node)
{
if(index >= flatItems.Length)
{
throw new InvalidArgumentError("The template and flat items are not matched.");
}
return new Nest<TOut>(flatItems[index++]);
}
else if(template.NestType == NestType.List)
{
List<Nest<TOut>> nestedObjects = new List<Nest<TOut>>();
for (int i = 0; i < template.ListValue!.Count; i++)
{
nestedObjects.Add(PackSequenceInternal(template.ListValue![i].AsNest(), flatItems, ref index));
}
return new Nest<TOut>(nestedObjects);
}
else if(template.NestType == NestType.Node)
{
Dictionary<string, INestStructure<TOut>> dict = new Dictionary<string, INestStructure<TOut>>();
foreach(var (key, value) in template.DictValue!)
{
dict[key] = PackSequenceInternal(value.AsNest(), flatItems, ref index);
}
return new Nest<TOut>(dict);
}
// Consider Empty as invalid type.
throw new InvalidArgumentError("When using `PackSequenceAs`, the template cannot contain empty node.");
}

public virtual Nest<T> AsNest()
{
return this;
}

public virtual Nest<T> MergeWith(Nest<T>? other)
{
if(other is null || other == Nest<T>.Empty)
{
return this;
}
if(this == Nest<T>.Empty)
{
return other;
}
if(NestType == NestType.Node && other.NestType == NestType.Node)
{
return new Nest<T>(new Nest<T>[] { this, other });
}
else if(NestType == NestType.List && other.NestType == NestType.List)
{
return new Nest<T>(this.ListValue!.Concat(other.ListValue!));
}
else if(NestType == NestType.Dictionary && other.NestType == NestType.Dictionary)
{
return new Nest<T>(this.DictValue!.Concat(other.DictValue!).ToDictionary(x => x.Key, x => x.Value));
}
else
{
return new Nest<T>(new Nest<T>[] { this, other });
}
}

/// <summary>
/// To see if the nested object is really nested. Despite being called `Nest`, sometimes it's actually not
/// nested. For example, [1, 2, 3] is not nested, while [1, [2, 3]] is nested.
/// </summary>
/// <returns></returns>
public bool IsNested()
{
if(NestType is NestType.Empty or NestType.Node)
{
return false;
}
else if(NestType is NestType.List)
{
return ListValue!.Count > 0;
}
else
{
return DictValue!.Count > 0;
}
}

[Obsolete("The indexer of Tensors is not encouraged because it leads to unclear meanings.")]
public T this[int index]
{
get
{
bool success = FindInternal(this, index, out var result);
if (success)
{
return result;
}
else
{
throw new IndexOutOfRangeException();
}
}
set
{
bool success = SetInternal(this, index, value);
if (!success)
{
throw new IndexOutOfRangeException();
}
}
}

/// <summary>
/// If the existing nested structure if of type `Nest[INestStructure[T]]`, we can reduce it
/// to `Nest[T]`.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <param name="input"></param>
/// <returns></returns>
public static Nest<T> ReduceFrom<TOut>(INestStructure<TOut> input) where TOut: INestStructure<T>
{
var nested = input.AsNest();
return ReduceInternal(nested).AsNest();
}

private static INestStructure<T> ReduceInternal<TOut>(Nest<TOut> node) where TOut : INestStructure<T>
{
if(node.NestType == NestType.Empty)
{
return Nest<T>.Empty;
}
else if(node.NestType == NestType.Node)
{
return node.NodeValue!.AsNest();
}
else if(node.NestType == NestType.List)
{
return new Nest<T>(node.ListValue!.Select(x => ReduceInternal(x.AsNest())));
}
else // Dictionary type
{
return new Nest<T>(node.DictValue!.ToDictionary(x => x.Key, x => ReduceInternal(x.Value.AsNest())));
}
}

private static bool FindInternal(Nest<T> node, int index, out T? result)
{
if (node.NestType == NestType.Node)
{
if(index == 0)
{
result = node.NodeValue!;
return true;
}
result = default(T);
return false;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
if(index == 0)
{
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
result = default(T);
return false;
}
else if(node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
if (index == 0)
{
return FindInternal(item.AsNest(), index, out result);
}
index--;
}
result = default(T);
return false;
}
else
{
result = default(T);
return false;
}
}

private static bool SetInternal(Nest<T> node, int index, T newValue)
{
if (node.NestType == NestType.Node)
{
if (index == 0)
{
node.NodeValue = newValue;
return true;
}
return false;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
if (index == 0)
{
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
return false;
}
else if (node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
if (index == 0)
{
return SetInternal(item.AsNest(), index, newValue);
}
index--;
}
return false;
}
else
{
return false;
}
}

private static IEnumerable<T> FlattenInternal(Nest<T> node)
{
if (node.NestType == NestType.Node)
{
yield return node.NodeValue!;
}
else if (node.NestType == NestType.List)
{
foreach (var item in node.ListValue!)
{
foreach(var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
}
}
else if (node.NestType == NestType.Dictionary)
{
foreach (var item in node.DictValue!.Values)
{
foreach (var val in FlattenInternal(item.AsNest()))
{
yield return val;
}
}
}
}

private Nest<TOut> MapStructureInternal<TOut>(Func<T, TOut> func)
{
if (NestType == NestType.Node)
{
return new Nest<TOut>(func(NodeValue!));
}
else if (NestType == NestType.List)
{
List<Nest<TOut>> outs = new List<Nest<TOut>>();
foreach (var item in ListValue!)
{
outs.Add(item.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else if (NestType == NestType.Dictionary)
{
Dictionary<string, INestStructure<TOut>> outs = new Dictionary<string, INestStructure<TOut>>();
foreach (var (key, value) in DictValue!)
{
outs.Add(key, value.AsNest().MapStructureInternal(func));
}
return new Nest<TOut>(outs);
}
else
{
return Nest<TOut>.Empty;
}
}

public IEnumerator<T> GetEnumerator()
{
return Flatten().GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}

public override string ToString()
{
StringBuilder sb = new StringBuilder();
sb.Append("(");
WriteString(this, sb);
sb.Append(")");
return sb.ToString();
}

private static void WriteString(Nest<T> node, StringBuilder sb)
{
if (!string.IsNullOrEmpty(node.Name))
{
sb.Append($"{node.Name}: ");
}
if (node.NestType == NestType.Node)
{
sb.Append(node.NodeValue!.ToString());
}
else if (node.NestType == NestType.List)
{
sb.Append("[");
for(int i = 0; i < node.ListValue!.Count; i++)
{
WriteString(node.ListValue![i].AsNest(), sb);
if(i != node.ListValue!.Count - 1)
{
sb.Append(", ");
}
}
sb.Append("]");
}
else if (node.NestType == NestType.Dictionary)
{
sb.Append("{");
int count = node.DictValue!.Count;
int i = 0;
foreach (var (key, value) in node.DictValue!)
{
sb.Append($"{key}: ");
WriteString(value.AsNest(), sb);
if (i != count - 1)
{
sb.Append(", ");
}
i++;
}
sb.Append("}");
}
else
{
sb.Append("<empty>");
}
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2 });
}

public static implicit operator Nest<T>((INestStructure<T>, INestStructure<T>, INestStructure<T>) inputs)
{
return new Nest<T>(new INestStructure<T>[] { inputs.Item1, inputs.Item2, inputs.Item3 });
}
}
}

+ 103
- 0
src/TensorFlowNET.Core/Common/Types/NestDictionary.cs View File

@@ -0,0 +1,103 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
public class NestDictionary<TKey, TValue> : INestStructure<TValue>, IDictionary<TKey, TValue> where TKey : notnull
{
public NestType NestType => NestType.Dictionary;
public IDictionary<TKey, TValue> Value { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;
public NestDictionary(IDictionary<TKey, TValue> dict)
{
Value = dict;
}
public IEnumerable<TValue> Flatten()
{
return Value.Select(x => x.Value);
}
public INestStructure<TOut> MapStructure<TOut>(Func<TValue, TOut> func)
{
return new NestList<TOut>(Value.Select(x => func(x.Value)));
}

public Nest<TValue> AsNest()
{
return new Nest<TValue>(Value.Values.Select(x => new Nest<TValue>(x)));
}

// Required IDictionary<TKey, TValue> members
public int Count => Value.Count;

public bool IsReadOnly => Value.IsReadOnly;

public ICollection<TKey> Keys => Value.Keys;

public ICollection<TValue> Values => Value.Values;

public void Add(TKey key, TValue value)
{
Value.Add(key, value);
}

public void Add(KeyValuePair<TKey, TValue> item)
{
Value.Add(item);
}

public void Clear()
{
Value.Clear();
}

public bool Contains(KeyValuePair<TKey, TValue> item)
{
return Value.Contains(item);
}

public bool ContainsKey(TKey key)
{
return Value.ContainsKey(key);
}

public void CopyTo(KeyValuePair<TKey, TValue>[] array, int arrayIndex)
{
Value.CopyTo(array, arrayIndex);
}

public IEnumerator<KeyValuePair<TKey, TValue>> GetEnumerator()
{
return Value.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}

public bool Remove(TKey key)
{
return Value.Remove(key);
}

public bool Remove(KeyValuePair<TKey, TValue> item)
{
return Value.Remove(item);
}

public bool TryGetValue(TKey key, out TValue value)
{
return Value.TryGetValue(key, out value);
}

// Optional IDictionary<TKey, TValue> members
public TValue this[TKey key]
{
get => Value[key];
set => Value[key] = value;
}
}
}

+ 53
- 0
src/TensorFlowNET.Core/Common/Types/NestList.cs View File

@@ -0,0 +1,53 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// The implementation of a list that support nest structure, in which the depth is 1.
/// </summary>
/// <typeparam name="T"></typeparam>
public sealed class NestList<T> : INestStructure<T>, IEnumerable<T>
{
public NestType NestType => NestType.List;
public List<T> Values { get; set; }
public int ShallowNestedCount => Values.Count;

public int TotalNestedCount => Values.Count;

public NestList(params T[] values)
{
Values = new List<T>(values);
}

public NestList(IEnumerable<T> values)
{
Values = new List<T>(values);
}
public IEnumerable<T> Flatten()
{
return Values;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestList<TOut>(Values.Select(x => func(x)));
}

public Nest<T> AsNest()
{
return new Nest<T>(Values.Select(x => new Nest<T>(x)));
}

// Enumerator implementation
public IEnumerator<T> GetEnumerator()
{
return Values.GetEnumerator();
}

IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

+ 36
- 0
src/TensorFlowNET.Core/Common/Types/NestNode.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Common.Types
{
/// <summary>
/// A nested structure with only one element.
/// </summary>
/// <typeparam name="T"></typeparam>
public class NestNode<T> : INestStructure<T>
{
public NestType NestType => NestType.Node;
public T Value { get; set; }
public int ShallowNestedCount => 1;

public int TotalNestedCount => 1;
public NestNode(T value)
{
Value = value;
}
public IEnumerable<T> Flatten()
{
yield return Value;
}
public INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func)
{
return new NestNode<TOut>(func(Value));
}

public Nest<T> AsNest()
{
return new Nest<T>(Value);
}
}
}

src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs → src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs View File

@@ -3,7 +3,7 @@ using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
namespace Tensorflow.Common.Types
{
public class TensorShapeConfig
{

+ 2
- 2
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -161,8 +161,8 @@ namespace Tensorflow
break;
}

yield return (new Tensors(results.Take(FirstInputTensorCount)), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount)));
yield return (new Tensors(results.Take(FirstInputTensorCount).ToArray()), results.Length == FirstInputTensorCount ?
null : new Tensors(results.Skip(FirstInputTensorCount).ToArray()));
}
}



+ 7
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -352,13 +352,19 @@ namespace Tensorflow.Eager
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
break;
case TF_AttrType.TF_ATTR_SHAPE:
var dims = (value as long[]).ToArray();
long[] dims;
if (value is Shape shape) dims = shape.dims.ToArray();
else if (value is long[] longs) dims = longs.ToArray();
else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
else dims = ((long[])value).ToArray();
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
status.Check(true);
break;
case TF_AttrType.TF_ATTR_FUNC:
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
else if(value is string str)
c_api.TFE_OpSetAttrFunctionName(op, key, str, str.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;


+ 7
- 3
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
{
outgrad_vec = output_gradients.ToList();
}
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);


bool unconnected_gradients_zero = unconnected_gradients == "zero";
@@ -137,7 +137,6 @@ namespace Tensorflow.Eager
{
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
}
Shape tensor_shape = new(dims);

if(status.Code != TF_Code.TF_OK)
{
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager
}
else
{
Shape tensor_shape = new(dims);
return new TapeTensor(id, dtype, tensor_shape);
}
}
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager
return dtype == dtypes.variant || dtype == dtypes.resource;
}

bool ListContainNone(long[] list)
bool ListContainNone(long[]? list)
{
if(list is null)
{
return true;
}
int len = list.Length;
if(len == 0)
{


+ 6
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs View File

@@ -10,6 +10,11 @@ namespace Tensorflow.Eager
var str = NDArrayRender.ToString(nd);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
public string ToString(int maxLength)
{
var nd = new NDArray(this);
var str = NDArrayRender.ToString(nd, maxLength);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
}
}

+ 19
- 0
src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Exceptions
{
public class NotOkStatusException : TensorflowException
{
public NotOkStatusException() : base()
{

}

public NotOkStatusException(string message) : base(message)
{

}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs View File

@@ -1,4 +1,5 @@
using System.Linq;
using Tensorflow.Eager;

namespace Tensorflow.Framework.Models
{
@@ -24,5 +25,17 @@ namespace Tensorflow.Framework.Models
shapes.Insert(0, dim);
return new TensorSpec(shapes.ToArray(), _dtype);
}

public static TensorSpec FromTensor(Tensor tensor, string? name = null)
{
if(tensor is EagerTensor)
{
return new TensorSpec(tensor.shape, tensor.dtype, name);
}
else
{
return new TensorSpec(tensor.shape, tensor.dtype, name ?? tensor.name);
}
}
}
}

+ 89
- 0
src/TensorFlowNET.Core/Framework/auto_control_deps_utils.cs View File

@@ -0,0 +1,89 @@
using Tensorflow.Graphs;

namespace Tensorflow.Framework
{
internal static class auto_control_deps_utils
{
public static readonly string READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs";
public static List<int> get_read_only_resource_input_indices_graph(FuncGraph func_graph)
{
List<int> result = new List<int>();
// A cache to store the read only resource inputs of an Op.
// Operation -> ObjectIdentitySet of resource handles.
Dictionary<Operation, HashSet<Tensor>> opReadOnlyResourceInputs =
new Dictionary<Operation, HashSet<Tensor>>();

for (int inputIndex = 0; inputIndex < func_graph.Inputs.Length; inputIndex++)
{
Tensor t = func_graph.Inputs[inputIndex];
if (t.dtype != dtypes.resource)
continue;

bool readOnly = true;
foreach (var op in t.consumers())
{
if (opReadOnlyResourceInputs.ContainsKey(op))
{
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
else
{
List<int> indices = _get_read_only_resource_input_indices_op(op);
opReadOnlyResourceInputs[op] = new HashSet<Tensor>(
indices.Select(i => op.inputs[i]));
if (!opReadOnlyResourceInputs[op].Contains(t))
{
readOnly = false;
break;
}
}
}

if (readOnly)
result.Add(inputIndex);
}

return result;
}

private static List<int> _get_read_only_resource_input_indices_op(Operation op)
{
// ignore the RESOURCE_READ_OPS

int[] read_only_input_indices;

try
{
read_only_input_indices = op.get_attr<int[]>(READ_ONLY_RESOURCE_INPUTS_ATTR);
}
catch (InvalidArgumentError)
{
return new List<int>();
}

int read_only_index = 0;
List<int> result = new();
for (int i = 0; i < op.inputs.Length; i++)
{
if (read_only_index >= read_only_input_indices.Length)
{
break;
}
if (op.inputs[i].dtype != dtypes.resource)
{
continue;
}
if (read_only_index < read_only_input_indices.Length && i == read_only_input_indices[read_only_index])
{
result.Add(i);
read_only_index++;
}
}
return result;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Framework/function_def_lib.cs View File

@@ -42,10 +42,10 @@ namespace Tensorflow.Framework
func_graph.as_default();
importer.import_graph_def(graph_def, name: "", validate_colocation_constraints: false);
var input_tensor_names = fdef.Signature.InputArg.Select(x => nested_to_flat_tensor_name[x.Name]);
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Inputs = new Tensors(input_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());

var output_tensor_names = fdef.Signature.OutputArg.Select(x => nested_to_flat_tensor_name[fdef.Ret[x.Name]]);
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)));
func_graph.Outputs = new Tensors(output_tensor_names.Select(x => func_graph.get_tensor_by_name(x)).ToArray());
// TODO(Rinne): func_graph.ControlOutputs
_set_handle_data(func_graph, fdef);



+ 13
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Gradients;
using Tensorflow.Graphs;
using Tensorflow.Train;
using Tensorflow.Util;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;

namespace Tensorflow.Functions
@@ -40,6 +41,18 @@ namespace Tensorflow.Functions
public Tensor[] FlatStructuredOutputs => func_graph.FlatStructuredOutputs;
public IEnumerable<IVariableV1> Variables => func_graph.Variables;
public IEnumerable<IVariableV1> TrainableVariables => func_graph.TrainableVariables;
internal NameAttrList AsNameAttrList
{
get
{
NameAttrList ret = new() { Name = this.Name };
foreach (var (name, value) in _attrs)
{
ret.Attr[name] = value;
}
return ret;
}
}

public ConcreteFunction(string name)
{


+ 2
- 3
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients
? input_values[0].rank + dim_int
: dim_int % input_values[0].rank;
var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
var sizes_tensor = constant_op.constant(sizes);
out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList();
}
else if (constant_op.is_constant(concat_dim))
{
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice);
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList();
}
else
{


+ 2
- 2
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -81,7 +81,7 @@ public class FuncGraph : Graph, IDisposable
public IEnumerable<IVariableV1> TrainableVariables => Variables.Where(v => v.Trainable);
public Dictionary<string, AttrValue> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
internal Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();

public Tensor[] external_captures
@@ -399,7 +399,7 @@ public class FuncGraph : Graph, IDisposable
var flat_func_args = nest.flatten(func_args as object);
var flat_func_kwargs = nest.flatten(func_kwargs as object);
func_graph.Inputs = new Tensors(flat_func_args.concat(flat_func_kwargs)
.Where(x => x is Tensor).Select(x => (Tensor)x));
.Where(x => x is Tensor).Select(x => (Tensor)x).ToArray());

//var func_args_before = nest.pack_sequence_as(func_args, flat_func_args, true);
//var func_kwargs_before = nest.pack_sequence_as(func_kwargs, flat_func_kwargs, true);


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -129,7 +129,7 @@ namespace Tensorflow
}
}

protected Graph outer_graph;
internal Graph outer_graph;
public Graph OuterGraph => outer_graph;
public Dictionary<string, EagerDefinedFunction> Functions => _functions;
public SafeGraphHandle c_graph => _handle;


+ 0
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs View File

@@ -4,8 +4,6 @@
{
// TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; }
public float Dropout { get; set; }
public float RecurrentDropout { get; set; }
public int Implementation { get; set; }
}
}

+ 30
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs View File

@@ -1,7 +1,35 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
using Newtonsoft.Json;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
// TODO: complete the implementation
public class LSTMCellArgs : LayerArgs
public class LSTMCellArgs : AutoSerializeLayerArgs
{
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("recurrent_activation")]
public Activation RecurrentActivation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }
[JsonProperty("unit_forget_bias")]
public bool UnitForgetBias { get; set; } = true;
[JsonProperty("implementation")]
public int Implementation { get; set; } = 2;

}
}

+ 8
- 24
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -1,17 +1,12 @@
using Newtonsoft.Json;
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
// TODO(Rinne): add regularizers.
public class RNNArgs : AutoSerializeLayerArgs
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnArgCell Cell { get; set; } = null;
[JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")]
@@ -24,8 +19,10 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public bool Unroll { get; set; } = false;
[JsonProperty("time_major")]
public bool TimeMajor { get; set; } = false;

public int? InputDim { get; set; }
public int? InputLength { get; set; }
// TODO: Add `num_constants` and `zero_output_for_mask`.
public Dictionary<string, object> Kwargs { get; set; } = null;

public int Units { get; set; }
public Activation Activation { get; set; }
@@ -34,21 +31,8 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }

// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
public float Dropout { get; set; } = .0f;
public bool ZeroOutputForMask { get; set; } = false;
public float RecurrentDropout { get; set; } = .0f;
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class RnnOptionalArgs: IOptionalArgs
{
public string Identifier => "Rnn";
public Tensor Mask { get; set; } = null;
public Tensors Constants { get; set; } = null;
}
}

+ 27
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs View File

@@ -0,0 +1,27 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class SimpleRNNCellArgs: AutoSerializeLayerArgs
{
[JsonProperty("units")]
public int Units { get; set; }
// TODO(Rinne): lack of initialized value of Activation. Merging keras
// into tf.net could resolve it.
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; }
[JsonProperty("recurrent_initializer")]
public IInitializer RecurrentInitializer { get; set; }
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; }

}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs View File

@@ -1,10 +1,10 @@
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<RnnCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
public bool ReverseStateOrder = false;
}
}

+ 3
- 2
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,4 +1,5 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Training;
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras
List<ILayer> Layers { get; }
List<INode> InboundNodes { get; }
List<INode> OutboundNodes { get; }
Tensors Apply(Tensors inputs, Tensor state = null, bool training = false);
Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null);
List<IVariableV1> TrainableVariables { get; }
List<IVariableV1> TrainableWeights { get; }
List<IVariableV1> NonTrainableWeights { get; }


+ 46
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -1,5 +1,6 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

@@ -159,6 +160,18 @@ namespace Tensorflow.Keras.Layers
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
public ILayer LeakyReLU(float alpha = 0.3f);

public IRnnCell LSTMCell(int uints,
string activation = "tanh",
string recurrent_activation = "sigmoid",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2);

public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,
@@ -192,6 +205,19 @@ namespace Tensorflow.Keras.Layers
float offset = 0,
Shape input_shape = null);

public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros",
float dropout = 0f,
float recurrent_dropout = 0f);

public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells);

public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
@@ -200,6 +226,26 @@ namespace Tensorflow.Keras.Layers
bool return_sequences = false,
bool return_state = false);

public ILayer RNN(
IRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer RNN(
IEnumerable<IRnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false
);

public ILayer Subtract();
}
}

+ 25
- 0
src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Rnn
{
public interface IRnnCell: ILayer
{
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
INestStructure<long>? StateSize { get; }
/// <summary>
/// If the derived class tends to not implement it, please return null.
/// </summary>
INestStructure<long>? OutputSize { get; }
/// <summary>
/// Whether the optional RNN args are supported when appying the layer.
/// In other words, whether `Apply` is overwrited with process of `RnnOptionalArgs`.
/// </summary>
bool SupportOptionalArgs { get; }
Tensors GetInitialState(Tensors inputs, Tensor batch_size, TF_DataType dtype);
}
}

+ 12
- 0
src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Layers.Rnn
{
public interface IStackedRnnCells : IRnnCell
{
int Count { get; }
IRnnCell this[int idx] { get; }
}
}

+ 1
- 0
src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs View File

@@ -3,6 +3,7 @@ using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Saving.Json
{


+ 1
- 0
src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs View File

@@ -6,6 +6,7 @@ using System.Text;
using System.Diagnostics;
using OneOf.Types;
using Tensorflow.Keras.Saving.Json;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Saving
{


+ 0
- 5
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -74,8 +74,3 @@ namespace Tensorflow
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})";
}
}

namespace System.Runtime.CompilerServices
{
internal static class IsExternalInit { }
}

+ 9
- 9
src/TensorFlowNET.Core/NumPy/NDArrayRender.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy
{
public class NDArrayRender
{
public static string ToString(NDArray array)
public static string ToString(NDArray array, int maxLength = 10)
{
Shape shape = array.shape;
if (shape.IsScalar)
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy

var s = new StringBuilder();
s.Append("array(");
Build(s, array);
Build(s, array, maxLength);
s.Append(")");
return s.ToString();
}

static void Build(StringBuilder s, NDArray array)
static void Build(StringBuilder s, NDArray array, int maxLength)
{
var shape = array.shape;

@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy
var len = shape[0];
s.Append("[");

if (len <= 10)
if (len <= maxLength)
{
for (int i = 0; i < len; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy
}
else
{
for (int i = 0; i < 5; i++)
for (int i = 0; i < maxLength / 2; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy
s.Append(" ... ");
s.AppendLine();

for (int i = (int)len - 5; i < len; i++)
for (int i = (int)len - maxLength / 2; i < len; i++)
{
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");


+ 23
- 1
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -19,13 +19,14 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving.Common;
using Tensorflow.NumPy;

namespace Tensorflow
{
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape
public class Shape : INestStructure<long>
{
public int ndim => _dims == null ? -1 : _dims.Length;
long[] _dims;
@@ -41,6 +42,27 @@ namespace Tensorflow
}
}

public NestType NestType => NestType.List;

public int ShallowNestedCount => ndim;
/// <summary>
/// The total item count of depth 1 of the nested structure.
/// For example, [1, 2, [3, 4, 5]] has TotalNestedCount = 5.
/// </summary>
public int TotalNestedCount => ndim;

public IEnumerable<long> Flatten() => dims.Select(x => x);

public INestStructure<TOut> MapStructure<TOut>(Func<long, TOut> func)
{
return new NestList<TOut>(dims.Select(x => func(x)));
}

public Nest<long> AsNest()
{
return new NestList<long>(Flatten()).AsNest();
}

#region https://docs.microsoft.com/en-us/dotnet/csharp/language-reference/proposals/csharp-8.0/ranges
public int Length => ndim;
public long[] Slice(int start, int length)


+ 22
- 0
src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Operations.Initializers
{
/// <summary>
/// An initializer specially used for debugging (to load weights from disk).
/// </summary>
class NpyLoadInitializer : IInitializer
{
string _path;
public NpyLoadInitializer(string path) { _path = path; }
public string ClassName => "";
public IDictionary<string, object> Config => new Dictionary<string, object>();
public Tensor Apply(InitializerArgs args)
{
return np.load(_path);
}
}
}

+ 2
- 3
src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs View File

@@ -53,13 +53,12 @@ public class Orthogonal : IInitializer
// Compute the qr factorization
var (q, r) = tf.linalg.qr(a, full_matrices: false);
// Make Q uniform
var d = tf.linalg.tensor_diag_part(r);
var d = tf.linalg.tensor_diag_part(r.Single);
q *= tf.sign(d);

if (num_rows < num_cols)
{
// q = tf.linalg.matrix_transpose(q);
throw new NotImplementedException("");
q = array_ops.matrix_transpose(q);
}

return _gain * tf.reshape(q, shape);


+ 2
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -11,6 +11,7 @@ namespace Tensorflow
/// Basic LSTM recurrent network cell.
/// The implementation is based on: http://arxiv.org/abs/1409.2329.
/// </summary>
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class BasicLstmCell : LayerRnnCell
{
int _num_units;
@@ -88,7 +89,7 @@ namespace Tensorflow
gate_inputs = nn_ops.bias_add(gate_inputs, _bias);

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -20,6 +20,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class BasicRnnCell : LayerRnnCell
{
int _num_units;


+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs View File

@@ -19,6 +19,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public class LayerRnnCell : RnnCell
{
protected InputSpec inputSpec;


+ 18
- 2
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -16,10 +16,12 @@

using System;
using System.Collections.Generic;
using Tensorflow.Common.Types;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Operations;
@@ -50,7 +52,8 @@ namespace Tensorflow
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`.
/// </summary>
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
public abstract class RnnCell : ILayer, IRnnCell
{
/// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight
@@ -142,7 +145,7 @@ namespace Tensorflow
throw new NotImplementedException("_zero_state_tensors");
}

public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null)
{
throw new NotImplementedException();
}
@@ -173,5 +176,18 @@ namespace Tensorflow
{
throw new NotImplementedException();
}

public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
{
throw new NotImplementedException();
}
public Tensors GetInitialState(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
throw new NotImplementedException();
}
public INestStructure<long> StateSize => throw new NotImplementedException();
public INestStructure<long> OutputSize => throw new NotImplementedException();
public bool IsTFRnnCell => throw new NotImplementedException();
public bool SupportOptionalArgs => throw new NotImplementedException();
}
}

+ 57
- 1
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -15,9 +15,11 @@
******************************************************************************/

using Google.Protobuf;
using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Functions;
using static Tensorflow.Binding;
using static Tensorflow.OpDef.Types;

@@ -387,9 +389,13 @@ namespace Tensorflow
case "list(type)":
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
break;
case "list(float)":
if (value != null)
attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray());
break;
case "list(int)":
if (value != null)
attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x)));
break;
case "bool":
attr_value.B = (bool)value;
@@ -420,6 +426,15 @@ namespace Tensorflow
case "list(shape)":
attr_value.List.Shape.AddRange((value as Shape[]).Select(x => _MakeShape(x, attr_def)));
break;
case "func":
attr_value.Func = _MakeFunc(value, attr_def.Name);
break;
case "list(func)":
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name));
break;
case "list(string)":
attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x)));
break;
default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
}
@@ -427,6 +442,47 @@ namespace Tensorflow
return attr_value;
}

private NameAttrList _MakeFunc(object func, string arg_name)
{
if(func is NameAttrList attrList)
{
return attrList;
}
NameAttrList fn_attr;
if(func is string funcStr)
{
fn_attr = new NameAttrList() { Name = funcStr };
}
else if(func is ConcreteFunction concrete)
{
concrete.AddTograph(ops.get_default_graph());
fn_attr = concrete.AsNameAttrList;
}
else if(func is EagerDefinedFunction eager)
{
eager.AddToGraph(ops.get_default_graph());
fn_attr = new NameAttrList() { Name = eager.Name };
}
else
{
throw new TypeError($"Don't know how to convert {func} to a func for argument {arg_name}");
}
return fn_attr;
}

private List<NameAttrList> _MakeFuncList(object funcList, string arg_name)
{
List<NameAttrList> res = new List<NameAttrList>();
if(funcList is IEnumerable enumerable)
{
foreach(var func in enumerable)
{
res.Add(_MakeFunc(func, arg_name));
}
}
return res;
}

private bool _IsListParameter(ArgDef arg)
{
if (!String.IsNullOrEmpty(arg.NumberAttr))


+ 1
- 1
src/TensorFlowNET.Core/Operations/Operation.Output.cs View File

@@ -34,7 +34,7 @@ namespace Tensorflow
return num;
}

protected Tensor[] _outputs;
internal Tensor[] _outputs;
public virtual Tensor[] outputs => _outputs;
public Tensor output => _outputs.FirstOrDefault();



+ 13
- 3
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -46,9 +46,9 @@ namespace Tensorflow
/// </summary>
public partial class Operation : ITensorOrOperation
{
private readonly IntPtr _handle; // _c_op in python
protected IntPtr _handle; // _c_op in python

private readonly Graph _graph;
protected Graph _graph;

internal Func<Operation, object[], Tensor[]> _gradient_function;

@@ -69,6 +69,7 @@ namespace Tensorflow
//private OperationDescription _op_desc;

public NodeDef node_def => GetNodeDef();
protected Operation() { }

public Operation(IntPtr handle, Graph g = null)
{
@@ -185,7 +186,16 @@ namespace Tensorflow
}

public virtual T get_attr<T>(string name)
=> (T)get_attr(name);
{
if (typeof(T).IsValueType)
{
return (T)Convert.ChangeType(get_attr(name), typeof(T));
}
else
{
return (T)get_attr(name);
}
}

internal unsafe TF_DataType _get_attr_type(string name)
{


+ 113
- 24
src/TensorFlowNET.Core/Operations/_EagerTensorArray.cs View File

@@ -17,6 +17,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -37,10 +39,6 @@ namespace Tensorflow.Operations

bool _infer_shape;
public override bool infer_shape => _infer_shape;
public bool _dynamic_size;
public Shape _element_shape;

public List<Tensor> _colocate_with;

Tensor _handle;
public override Tensor handle => _handle;
@@ -48,12 +46,14 @@ namespace Tensorflow.Operations
public override Tensor flow => _flow;
bool _clear_after_read;
List<Tensor> _tensor_array;
List<int> _previous_read_indices;

public _EagerTensorArray(TF_DataType dtype, Tensor size, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
_size = size;
_flow = constant_op.constant(0);
_infer_shape = infer_shape;
_element_shape = element_shape ?? Shape.Null;
@@ -61,16 +61,20 @@ namespace Tensorflow.Operations
_dtype = dtype.as_base_dtype();
_dynamic_size = dynamic_size;
_clear_after_read = clear_after_read;
_tensor_array = new List<Tensor>();
_tensor_array = Enumerable.Repeat<Tensor>(null, size.numpy()).ToList();
_previous_read_indices = new();
}

public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _handle, value }), delegate
var tensors = array_ops.unstack(value, name: name);
if(tensors.Length > _tensor_array.Count && !_dynamic_size)
{
var num_elements = array_ops.shape(value)[0];
return scatter(indices: math_ops.range(0, num_elements), value: value, name: name);
});
throw new ValueError($"Cannot unstack {tensors.Length} tensors into a TensorArray of static size {_tensor_array.Count}");
}
_tensor_array = tensors.ToList();
// TODO(Rinne): revise the implementation. Here we should return `parent()`.
return this;
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
@@ -103,7 +107,19 @@ namespace Tensorflow.Operations

return ta;
});*/
throw new NotImplementedException("");
//if (indices is EagerTensor)
//{
// indices = indices as EagerTensor;
// indices = indices.numpy();
//}

//foreach (var (index, val) in zip(indices.ToArray<int>(), array_ops.unstack(value)))
//{
// this.write(index, val);
//}
//return base;
//throw new NotImplementedException("");
return this;
}

public void _merge_element_shape(Shape shape)
@@ -116,9 +132,19 @@ namespace Tensorflow.Operations
_colocate_with.Add(value);
}

private Tensor _maybe_zero(int ix)
{
var val = _tensor_array[ix];
if(val is null)
{
val = _tensor_array[ix] = array_ops.zeros(_element_shape, _dtype);
}
return val;
}

public override Tensor read<T>(T index, string name = null)
{
int index_int = -1;
int index_int;
if (index is int int_index)
index_int = int_index;
else if (index is Tensor tensor_index)
@@ -126,27 +152,75 @@ namespace Tensorflow.Operations
else
throw new ValueError("");

if(index_int >= _tensor_array.Count)
{
throw new OutOfRangeError($"Tried to read from index {index_int} but array size is: {_tensor_array.Count} ");
}

var res = _tensor_array[index_int];
if(res is null)
{
if (_previous_read_indices.Contains(index_int))
{
throw new InvalidArgumentError($"Could not read index {index_int} twice because it was cleared after " +
$"a previous read (perhaps try setting clear_after_read = false?)");
}
else
{
res = _maybe_zero(index_int);
}
}

if (_clear_after_read)
{
_tensor_array[index_int] = null;
_previous_read_indices.Add(index_int);
}

return _tensor_array[index_int];
return res;
}

public override TensorArray write(Tensor index, Tensor value, string name = null)
{
if (_infer_shape)
_element_shape = _element_shape.merge_with(value.shape);
_tensor_array.add(value);
return this;
int index_int;
if(index is EagerTensor eager)
{
return write<Tensor>(eager.numpy(), value, name);
}
throw new InvalidArgumentError("The index is supposed to be an EagerTensor");
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor, name: name);
int size = _tensor_array.Count;
if(index >= size)
{
if (!_dynamic_size)
{
throw new OutOfRangeError($"Tried to write to index {index} but array is not resizeable and size " +
$"is: {size} ");
}
_tensor_array.AddRange(Enumerable.Repeat<Tensor>(null, index - size + 1));
}

Tensor tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
if(_dtype != tensor.dtype)
{
throw new InvalidArgumentError($"TensorArray dtype is {_dtype.as_python_name()} but Op is " +
$"trying to write dtype {tensor.dtype.as_python_name()} ");
}

if (!_element_shape.is_compatible_with(tensor.shape))
{
throw new ValueError($"Incompatible shape for value ({tensor.shape}), expected ({_element_shape})");
}

if (_infer_shape)
{
_element_shape = _element_shape.merge_with(tensor.shape);
}
_tensor_array[index] = tensor;
return this;
}

private Tensor size(string name = null)
@@ -156,11 +230,26 @@ namespace Tensorflow.Operations

public override Tensor stack(string name = null)
{
ops.colocate_with(_handle);
return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
if(_tensor_array.Count > 0)
{
for(int i = 0; i < _tensor_array.Count; i++)
{
_maybe_zero(i);
}
}
if(_tensor_array.Count == 0 && _element_shape.IsFullyDefined)
{
return ops.convert_to_tensor(new Shape(new long[] { 0 }.Concat(_element_shape.dims).ToArray()), name: name, dtype: _dtype);
}
else
{
return gather(math_ops.range(0, size()), name: name);
});
return ops.convert_to_tensor(_tensor_array, name: name, dtype: _dtype);
}
//ops.colocate_with(_handle);
//return tf_with(ops.name_scope(name, "TensorArrayStack", new { _handle }), delegate
//{
// return gather(math_ops.range(0, size()), name: name);
//});
}

public override Tensor gather(Tensor indices, string name = null)


+ 180
- 5
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -16,7 +16,10 @@

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
@@ -32,18 +35,18 @@ namespace Tensorflow.Operations
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;
public bool colocate_with_first_write_call => _colocate_with_first_write_call;
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;

bool _infer_shape;
public bool infer_shape => _infer_shape;
public bool _dynamic_size;
public override bool infer_shape => _infer_shape;
public List<Shape> _element_shape;

public List<Tensor> _colocate_with;

internal Tensor _handle;
public Tensor handle => _handle;
public override Tensor handle => _handle;
internal Tensor _flow;
public override Tensor flow => _flow;

public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
@@ -54,6 +57,7 @@ namespace Tensorflow.Operations
dynamic_size = dynamic_size ?? false;
_dynamic_size = dynamic_size.Value;
_dtype = dtype;
_size = size;

_colocate_with_first_write_call = colocate_with_first_write_call;
if (colocate_with_first_write_call)
@@ -146,7 +150,9 @@ namespace Tensorflow.Operations

return ta;
});*/
throw new NotImplementedException("");

//throw new NotImplementedException("");
return this;
}

public void _merge_element_shape(Shape shape)
@@ -232,4 +238,173 @@ namespace Tensorflow.Operations
return value;
}
}

public class _GraphTensorArrayV2 : TensorArray
{
internal TF_DataType _dtype;
public override TF_DataType dtype => _dtype;

/// <summary>
/// Used to keep track of what tensors the TensorArray should be
/// colocated with. We choose to colocate the TensorArray with the
/// first tensor written to it.
/// </summary>
bool _colocate_with_first_write_call;
public override bool colocate_with_first_write_call => _colocate_with_first_write_call;

bool _infer_shape;
public override bool infer_shape => _infer_shape;
public Shape _element_shape;

public List<Tensor> _colocate_with;

internal Tensor _handle;
public override Tensor handle => _handle;
internal Tensor _flow;
public override Tensor flow => _flow;

public _GraphTensorArrayV2(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
Debug.Assert(handle is null);
dynamic_size = dynamic_size ?? false;
_dynamic_size = dynamic_size.Value;
_size = size;

if(flow is not null && flow.dtype != dtypes.variant)
{
throw new TypeError($"Expected `flow` to be a variant tensor, but received `{flow.dtype}` instead");
}
if(flow is null && size is null)
{
throw new ValueError("Argument `size` must be provided if argument `flow` is not provided.");
}
if(flow is not null && size is not null)
{
throw new ValueError("Cannot provide both `flow` and `size` arguments at the same time.");
}
if(flow is not null && element_shape is not null)
{
throw new ValueError("Cannot provide both `flow` and `element_shape` arguments at the same time.");
}

_dtype = dtype;

_element_shape = element_shape;
_infer_shape = infer_shape;
tf_with(ops.name_scope(name, "TensorArrayV2", new object[] { size, flow }), scope =>
{
if (flow is null)
{
_flow = list_ops.tensor_list_reserve(element_shape, size, dtype, scope.scope_name);
}
else
{
_flow = flow;
}
});

_colocate_with_first_write_call = false;
_colocate_with = null;
}

public override TensorArray unstack(Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayUnstack", new { _flow, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_from_tensor(value, value.shape.dims.Skip(1).ToArray());
return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public TensorArray scatter(Tensor indices, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayScatter", new { _flow, value, indices }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_scatter(value, indices, _element_shape, _flow);
return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public override Tensor read<T>(T index, string name = null)
{
if(index is Tensor tensor)
{
return read(tensor, name);
}
else
{
throw new TypeError("Please use non-generic method instead.");
}
}

public Tensor read(Tensor index, string name = null)
{
return tf_with(tf.name_scope(name, "TensorArrayV2Read", new object[] { _flow, index }), scope =>
{
return list_ops.tensor_list_get_item(_flow, index, _dtype, _element_shape, name);
});
}

public override TensorArray write(Tensor index, Tensor value, string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayV2Write", new { _flow, index, value }), delegate
{
value = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
Debug.Assert(value.dtype == _dtype);
var flow_out = list_ops.tensor_list_set_item(_flow, index, value, _dynamic_size, name);

return tensor_array_ops.build_ta_with_new_flow(this, flow_out);
});
}

public override TensorArray write<T>(int index, T value, string name = null)
{
var value_tensor = ops.convert_to_tensor(value, preferred_dtype: _dtype, name: "value");
var index_tensor = ops.convert_to_tensor(index, name: "index");
return write(index_tensor, value_tensor);
}

private Tensor size(string name = null)
{
if(!_dynamic_size && _size is not null)
{
return ops.convert_to_tensor(_size, dtypes.int32);
}
else
{
return gen_list_ops.tensor_list_length(_flow, name);
}
}

public override Tensor stack(string name = null)
{
return tf_with(ops.name_scope(name, "TensorArrayV2Stack", _flow), delegate
{
int ta_size;
if(!_dynamic_size && (_size is not null))
{
var size_tensor = tensor_util.constant_value(_size);
ta_size = size_tensor is null ? -1 : (int)size_tensor;
}
else
{
ta_size = -1;
}
var value = list_ops.tensor_list_stack(_flow, _dtype, ta_size, _element_shape);
return value;
});
}

public override Tensor gather(Tensor indices, string name = null)
{
return list_ops.tensor_list_gather(_flow, indices, _dtype, _element_shape, name);
}
}
}

+ 78
- 22
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -119,6 +119,27 @@ namespace Tensorflow
}
}

public static Tensor zeros(Tensors shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
Tensor shapeTensor;
if(shape.Length > 1)
{
shapeTensor = ops.convert_to_tensor(shape, dtypes.int32);
if(shapeTensor.ndim > 1)
{
shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1));
}
}
else
{
shapeTensor = shape[0];
}
var output = fill(shapeTensor, array_ops.constant(0, dtype), name);
Debug.Assert(output.dtype.as_base_dtype() == dtype);
return output;
}

public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
{
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate
@@ -307,6 +328,9 @@ namespace Tensorflow
public static Tensor fill<T>(Shape dims, T value, string name = null)
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name);

public static Tensor fill<T>(Tensor dims, T value, string name = null)
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name);

/// <summary>
/// Returns the rank of a tensor.
/// </summary>
@@ -947,38 +971,70 @@ namespace Tensorflow
});
}

public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
string name = "split")
/// <summary>
/// Transposes last two dimensions of tensor `a`.
/// For example:
/// <code> python
/// x = tf.constant([[1, 2, 3], [4, 5, 6]])
/// tf.matrix_transpose(x) # [[1, 4],
/// # [2, 5],
/// # [3, 6]]
/// </code>
/// Matrix with two batch dimensions.
/// x.shape is [1, 2, 3, 4]
/// tf.linalg.matrix_transpose(x) is shape [1, 2, 4, 3]
/// </summary>
/// <param name="a"></param>
/// <param name="name"></param>
/// <param name="conjugate"></param>
/// <returns></returns>
/// <exception cref="ValueError"></exception>
public static Tensor matrix_transpose(Tensor a, string name = "matrix_transpose", bool conjugate = false)
{
if (num == -1)
num = (int)size_splits.shape[0];

return gen_array_ops.split_v(value, size_splits, tf.convert_to_tensor(axis), num, name: name);
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_shape = a.shape;
var ndims = a.shape.ndim;
Axis perm;
if(ndims != 0)
{
if (ndims < 2)
{
throw new ValueError("Argument `a` should be a (batch) matrix with rank " +
$">= 2. Received `a` = {a} with shape: {a_shape}");
}
perm = new Axis(Enumerable.Range(0, ndims - 2).Concat(new int[] { ndims - 1, ndims - 2 }).ToArray());
}
else
{
var a_rank = a.rank;
perm = new Axis(Enumerable.Range(0, a_rank - 2).Concat(new int[] { a_rank - 1, a_rank - 2 }).ToArray());
}
return transpose(a, perm:perm, conjugate:conjugate);
});
}

public static Tensor[] split<T>(Tensor value, int num_split, T axis,
public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis = null,
string name = "split")
{
var size_splits = ops.convert_to_tensor(num_split);
return gen_array_ops.split(split_dim: axis, value: value, num_split: num_or_size_splits, name);
}

if (tf.Context.executing_eagerly())
public static Tensor[] split(Tensor value, int[] num_or_size_splits, Tensor axis = null, int num = -1,
string name = "split")
{
if(num_or_size_splits.Length == 0)
{
return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.Context);
throw new ValueError("Rank-0 tensors are not supported as the num_or_size_splits argument to split.");
}
var size_splits = ops.convert_to_tensor(num_or_size_splits);

var _op = tf.OpDefLib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
return _op.outputs;
}

private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null)
{
var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { value });
var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32);
var _inputs_flat = new List<Tensor> { axis_tensor };
_inputs_flat.AddRange(input);
var _attrs = new object[] { "num_split", num_split, "T", _attr_T };
if(num == -1)
{
num = (int)size_splits.shape[0];
}

return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
return gen_array_ops.split_v(value: value, size_splits: size_splits, split_dim: axis, num_split: num, name: name);
}

public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)


+ 5
- 4
src/TensorFlowNET.Core/Operations/control_flow_ops.cs View File

@@ -675,16 +675,17 @@ namespace Tensorflow
}
}

public static Tensor[] while_loop(Func<Tensor[], Tensor> cond,
Func<Tensor[], Tensor[]> body,
Tensor[] loop_vars,
public static Tensors while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int parallel_iterations = 10,
string name = null)
{
var executing_eagerly = tf.Context.executing_eagerly();
if (!executing_eagerly)
{
throw new NotImplementedException("");
return while_v2.while_loop(cond, body, loop_vars, parallel_iterations: parallel_iterations,
name: name);
}

return tf_with(ops.name_scope("name", "while"), delegate


+ 77
- 0
src/TensorFlowNET.Core/Operations/control_flow_util.py.cs View File

@@ -16,12 +16,20 @@

using System;
using System.Linq;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class control_flow_util
{
public static readonly bool ENABLE_CONTROL_FLOW_V2 = !string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0" ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_CONTROL_FLOW_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_COND_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_WHILE_V2") != "0") ||
(!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2")) && Environment.GetEnvironmentVariable("TF_ENABLE_TENSOR_ARRAY_V2") != "0");
/// <summary>
/// Return true if `op` is an Exit.
/// </summary>
@@ -196,5 +204,74 @@ namespace Tensorflow
}
return null;
}

public static bool EnableControlFlowV2(Graph graph)
{
return ENABLE_CONTROL_FLOW_V2 || graph.building_function && (graph is not FuncGraph func || func.captures.Length == 0);
}

public static string create_new_tf_function(FuncGraph func_graph)
{
var func = new EagerDefinedFunction(func_graph.Name, func_graph, func_graph.Inputs, func_graph.Outputs, new Dictionary<string, AttrValue>());
func.AddToGraph(func_graph);
return func_graph.Name;
}

public static (Operation, Tensor[]) get_op_and_outputs(Tensor[] inputs)
{
if(inputs.Length == 0)
{
return (null, new Tensor[0]);
}
else
{
return (inputs[0], inputs);
}
}

public static Tensor[] run_as_function_for_tape_gradients(Func<Tensor[], Tensor[]> make_op, Tensor[] inputs)
{
if(gradients_util.PossibleTapeGradientTypes(inputs) == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER
&& !(ops.get_default_graph().building_function))
{
throw new NotImplementedException();
}
else
{
return make_op(inputs);
}
}

public static string unique_fn_name(string scope, string name)
{
return $"{scope}{name}_{ops.uid()}".Replace("/", "_");
}

public static bool output_all_intermediates()
{
if (in_defun())
{
return false;
}
if(tf.Context.FunctionCallOptions.ExecutorType == "SINGLE_THREADED_EXECUTOR")
{
return false;
}
// TODO(Rinne): check this after refactoring keras building.
return false;
}

public static bool in_defun()
{
if (tf.Context.executing_eagerly())
{
return false;
}

var graph = ops.get_default_graph();
// TODO(Rinne): CondBranchFuncGraph, WhileBodyFuncGraph, WhileCondFuncGraph
return graph is FuncGraph;
}
}
}

+ 489
- 10
src/TensorFlowNET.Core/Operations/gen_array_ops.cs
File diff suppressed because it is too large
View File


+ 1042
- 81
src/TensorFlowNET.Core/Operations/gen_functional_ops.cs
File diff suppressed because it is too large
View File


+ 827
- 109
src/TensorFlowNET.Core/Operations/gen_io_ops.cs
File diff suppressed because it is too large
View File


+ 1308
- 0
src/TensorFlowNET.Core/Operations/gen_list_ops.cs
File diff suppressed because it is too large
View File


+ 585
- 0
src/TensorFlowNET.Core/Operations/gen_math_ops.cs
File diff suppressed because it is too large
View File


+ 409
- 0
src/TensorFlowNET.Core/Operations/gen_nn_ops.cs
File diff suppressed because it is too large
View File


+ 1469
- 104
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
File diff suppressed because it is too large
View File


+ 3
- 3
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -1778,10 +1778,10 @@ new_height, new_width");
{
// a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3]
var a_xy_minmax = array_ops.split(
value: boxes_a, num_split: 4, axis: 2);
value: boxes_a, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
// b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3]
var b_xy_minmax = array_ops.split(
value: boxes_b, num_split: 4, axis: 2);
value: boxes_b, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));

var i_xmin = math_ops.maximum(
a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 }));
@@ -1943,7 +1943,7 @@ new_height, new_width");
using (ops.name_scope("canonicalize_coordinates"))
{
// y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3]
var yx = array_ops.split(value: boxes, num_split: 4, axis: 2);
var yx = array_ops.split(value: boxes, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
var y_1_is_min = math_ops.reduce_all(
gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0]));
var y_minmax = control_flow_ops.cond(


+ 111
- 0
src/TensorFlowNET.Core/Operations/list_ops.cs View File

@@ -0,0 +1,111 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;

namespace Tensorflow.Operations
{
internal class list_ops
{
private static void _set_handle_data(Tensor list_handle, Shape element_shape, TF_DataType element_dtype)
{
if(list_handle is EagerTensor eagerTensor)
{
var handle_data = new CppShapeInferenceResult.Types.HandleData();
handle_data.IsSet = true;
handle_data.ShapeAndType.Add(new CppShapeInferenceResult.Types.HandleShapeAndType()
{
Shape = element_shape.as_proto(),
Dtype = element_dtype.as_datatype_enum(),
Type = new FullTypeDef() { TypeId = FullTypeId.TftArray }
});
list_handle.HandleData = handle_data;
}
}

private static Tensor _build_element_shape(Shape? shape)
{
if(shape is null || shape.IsNull)
{
return ops.convert_to_tensor(-1);
}
else
{
return ops.convert_to_tensor(shape);
}
}

public static Tensor tensor_list_reserve(Shape? shape, Tensor num_elements, TF_DataType element_dtype, string name = null)
{
var result = gen_list_ops.tensor_list_reserve(_build_element_shape(shape), num_elements, element_dtype, name);
_set_handle_data(result, shape, element_dtype);
return result;
}

public static Tensor tensor_list_from_tensor(Tensor tensor, Shape element_shape, string? name = null)
{
var result = gen_list_ops.tensor_list_from_tensor(tensor, _build_element_shape(element_shape), name);
_set_handle_data(result, tensor.shape, tensor.dtype);
return result;
}

public static Tensor tensor_list_get_item(Tensor input_handle, Tensor index, TF_DataType element_dtype,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_get_item(input_handle, index, _build_element_shape(element_shape),
element_dtype, name);
}

public static Tensor tensor_list_set_item(Tensor input_handle, Tensor index, Tensor item,
bool resize_if_index_out_of_bounds = false, string? name = null)
{
if (resize_if_index_out_of_bounds)
{
var input_list_size = gen_list_ops.tensor_list_length(input_handle);
input_handle = control_flow_ops.cond(index >= input_list_size,
() => gen_list_ops.tensor_list_resize(input_handle, index + 1),
() => input_handle);
}
var output_handle = gen_list_ops.tensor_list_set_item(input_handle, index, item, name);
handle_data_util.copy_handle_data(input_handle, output_handle);
return output_handle;
}

public static Tensor tensor_list_stack(Tensor input_handle, TF_DataType element_dtype, int num_elements = -1,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_stack(input_handle, _build_element_shape(element_shape), element_dtype, num_elements, name);
}

public static Tensor tensor_list_gather(Tensor input_handle, Tensor indices, TF_DataType element_dtype,
Shape? element_shape = null, string? name = null)
{
return gen_list_ops.tensor_list_gather(input_handle, indices, _build_element_shape(element_shape), element_dtype, name);
}

public static Tensor tensor_list_scatter(Tensor tensor, Tensor indices, Shape? element_shape = null, Tensor? input_handle = null,
string? name = null)
{
if(input_handle is not null)
{
var output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(input_handle, tensor, indices, name);
handle_data_util.copy_handle_data(input_handle, output_handle);
return output_handle;
}
else
{
var output_handle = gen_list_ops.tensor_list_scatter_v2(tensor, indices, _build_element_shape(element_shape),
constant_op.constant(-1), name);
_set_handle_data(output_handle, element_shape, tensor.dtype);
return output_handle;
}
}

public static Tensor empty_tensor_list(Shape? element_shape, TF_DataType element_dtype, int max_num_elements = -1,
string? name = null)
{
return gen_list_ops.empty_tensor_list(_build_element_shape(element_shape), element_dtype: element_dtype,
max_num_elements: ops.convert_to_tensor(max_num_elements, dtype: dtypes.int32), name: name);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/logging_ops.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
name: name);

return tf.Context.ExecuteOp("PrintV2", name, new ExecuteOpArgs(formatted_string)
.SetAttributes(new { output_stream, end }));
.SetAttributes(new { output_stream, end })).SingleOrNull;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/sort_ops.cs View File

@@ -44,7 +44,7 @@ namespace Tensorflow
{
sorted = true
}));
return indices;
return indices.Single;
}

public static Tensor sort(Tensor values, Axis axis, string direction = "ASCENDING", string? name = null)


+ 16
- 4
src/TensorFlowNET.Core/Operations/tensor_array_ops.cs View File

@@ -13,11 +13,23 @@ namespace Tensorflow
/// <returns></returns>
public static TensorArray build_ta_with_new_flow(TensorArray old_ta, Tensor flow)
{
var new_ta = tf.TensorArray(
dtype: old_ta.dtype,
infer_shape: old_ta.infer_shape,
if (!tf.Context.executing_eagerly() && old_ta is not _GraphTensorArrayV2 && control_flow_util.EnableControlFlowV2(ops.get_default_graph()))
{
throw new NotImplementedException("Attempting to build a graph-mode TF2-style "
+ "TensorArray from either an eager-mode "
+ "TensorArray or a TF1-style TensorArray. "
+ "This is not currently supported. You may be "
+ "attempting to capture a TensorArray "
+ "inside a tf.function or tf.data map function. "
+ "Instead, construct a new TensorArray inside "
+ "the function.");
}
var new_ta = TensorArray.Create(old_ta.dtype, handle: old_ta.handle, flow: flow, infer_shape: old_ta.infer_shape,
colocate_with_first_write_call: old_ta.colocate_with_first_write_call);

new_ta._dynamic_size = old_ta._dynamic_size;
new_ta._size = old_ta._size;
new_ta._colocate_with = old_ta._colocate_with;
new_ta._element_shape = old_ta._element_shape;
return new_ta;
}



+ 401
- 0
src/TensorFlowNET.Core/Operations/while_v2.cs View File

@@ -0,0 +1,401 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Eager;
using Tensorflow.Framework;
using Tensorflow.Framework.Models;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
class _OperationWithOutputs : Operation
{
public _OperationWithOutputs(IntPtr handle, Graph g = null)
{
_handle = handle;
_graph = g;
_outputs = null;
g._add_op(this);
}
}
internal class while_v2
{
public static Tensor[] while_loop(Func<Tensors, Tensor> cond,
Func<Tensors, Tensors> body,
Tensors loop_vars,
int maximum_iterations = -1,
int parallel_iterations = 10,
string name = null,
bool back_prop = true,
bool return_same_structure = true)
{
var orig_loop_vars = loop_vars;
var flat_orig_loop_vars = orig_loop_vars.Flatten().ToArray();
int len_orig_loop_vars = orig_loop_vars.Length;

loop_vars = _tensor_array_to_flow(loop_vars);
loop_vars = Nest.MapStructure(x => _convert_to_tensor_or_indexed_slices(x, TF_DataType.DtInvalid, null), loop_vars).ToTensors();

var loop_vars_signature = Nest.MapStructure(x => new TensorSpec(x.shape, x.dtype), _tensor_array_to_flow(loop_vars));

var flat_shape_invariants = Nest.Flatten(loop_vars_signature).Select(x => x.shape).ToArray();

if(string.IsNullOrEmpty(name))
{
name = "while";
}

return tf_with<ITensorFlowObject, Tensor[]>(ops.name_scope(name), nameScopeWhile =>
{
string scope = (nameScopeWhile as ops.NameScope).scope_name;
string cond_name = control_flow_util.unique_fn_name(scope, "cond");
string body_name = control_flow_util.unique_fn_name(scope, "body");

var maximum_iterations_loop_var = _build_maximum_iterations_loop_var(maximum_iterations);
var loop_counter = constant_op.constant(0, maximum_iterations == -1 ? TF_DataType.DtInvalid : maximum_iterations_loop_var.dtype,
name: "loop_counter");
loop_vars = new Tensor[] { loop_counter, maximum_iterations_loop_var }.Concat(loop_vars).ToArray();

var func_graph_signature = new TensorSpec[] {TensorSpec.FromTensor(loop_counter),TensorSpec.FromTensor(maximum_iterations_loop_var)}
.Concat(loop_vars_signature.Flatten()).ToArray();

// TODO(Rinne): possible wrong implemenation here.
var add_control_dependencies = false;

object[] wrapped_cond(object[] inputs)
{
Tensor loop_counter = (Tensor)inputs[0];
Tensor maximum_iterations_arg = (Tensor)inputs[1];
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();
var pred = cond(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
if(pred.shape.IsNull || pred.shape.ndim > 0)
{
pred = array_ops.squeeze(pred);
}
if(maximum_iterations == -1)
{
return new object[] { pred };
}
else
{
return new object[] { math_ops.logical_and(loop_counter < maximum_iterations_arg, pred) };
}
}

var cond_graph = FuncGraph.func_graph_from_func(cond_name, wrapped_cond, null,
null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies);

bool stateful_parallelism = false;

object[] wrapped_body(object[] inputs)
{
Tensor loop_counter = (Tensor)inputs[0];
Tensor maximum_iterations_arg = (Tensor)inputs[1];
Tensor[] args = inputs.Skip(2).Select(x => (Tensor)x).ToArray();

_copy_handle_data(loop_vars.Flatten().Skip(2), args);

foreach(var t in cond_graph.external_captures)
{
var graph = (FuncGraph)(ops.get_default_graph());
graph.capture(t);
}

var outputs = body(_pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, args));
outputs = _tensor_array_to_flow(outputs);

return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray();
}

var body_graph = FuncGraph.func_graph_from_func(body_name, wrapped_body, null, null, func_graph_signature,
add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism);

// TODO(Rinne): possible wrong implementation here.
NestList<Tensors> loop_vars_list = new(new Tensors[] { loop_vars, body_graph.external_captures.ToTensors() });
body_graph.Outputs.AddRange(body_graph.internal_captures);
cond_graph.as_default();
int num_cond_captures = cond_graph.external_captures.Length;
Debug.Assert(cond_graph.external_captures.SequenceEqual(body_graph.external_captures.Take(num_cond_captures).ToArray()));
_duplicate_body_captures_in_cond(cond_graph, body_graph.external_captures.Skip(num_cond_captures).ToArray());
cond_graph.Exit();

int first_loop_var_index = 2;

int num_flattened_oututs = orig_loop_vars.Length;
int num_original_outputs = body_graph.Outputs.Length;
if (back_prop && control_flow_util.output_all_intermediates())
{
var intermediate_tensors = _get_intermediates(body_graph);

foreach(var intermediate_tensor in intermediate_tensors)
{
var tensor_list = list_ops.empty_tensor_list(intermediate_tensor.shape, intermediate_tensor.dtype, maximum_iterations);
loop_vars_list.Values.Add(tensor_list);

cond_graph.as_default();
cond_graph.capture(tensor_list);
cond_graph.Exit();

body_graph.as_default();
var appended_tensor_list = gen_ops.tensor_list_push_back(tensor_list, intermediate_tensor);
body_graph.Outputs.Add(appended_tensor_list);
body_graph.Exit();
}
}

List<Tensor> flattened_loop_vars = new();
foreach(var item in loop_vars_list.Values)
{
flattened_loop_vars.AddRange(item.Flatten());
}
// skip the check

// TODO(Rinne): deal with control dependencies
var output_shapes = body_graph.Outputs.Select(t => t.shape).ToArray();
var span = new Span<Shape>(output_shapes).Slice(first_loop_var_index, num_flattened_oututs);
for(int i = 0; i < span.Length; i++)
{
span[i] = flat_shape_invariants[i];
}

Tensor[] outputs = _build_while_op(flattened_loop_vars.ToArray(), cond_graph, body_graph, output_shapes, parallel_iterations,
(nameScopeWhile as ops.NameScope).scope_name, num_original_outputs, stateful_parallelism);

if (!ops.get_default_graph().building_function)
{
outputs = outputs.Select(t => array_ops.identity(t)).ToArray();
}

var output_loop_vars = outputs.Skip(first_loop_var_index).Take(num_flattened_oututs).ToArray();

if (!back_prop)
{
output_loop_vars = output_loop_vars.Select(t => array_ops.stop_gradient(t)).ToArray();
}
outputs = _pack_sequence_as(loop_vars_signature, flat_orig_loop_vars, output_loop_vars);

return outputs;
});
}

private static Tensors _tensor_array_to_flow(Tensors loop_vars)
{
if(loop_vars.NestType == NestType.Node)
{
if(loop_vars.NodeValue is FakeTensorByTensorArray fake)
{
return new Tensors(fake.TensorArray.flow);
}
else
{
return new Tensors(loop_vars.NodeValue!);
}
}
else if(loop_vars.NestType == NestType.List)
{
List<INestStructure<Tensor>> list = new();
foreach(var item in loop_vars.ListValue!)
{
if(item.NestType == NestType.Node)
{
var nested = item.AsNest();
if (nested.NodeValue is FakeTensorByTensorArray fake)
{
list.Add(new Nest<Tensor>(fake.TensorArray.flow));
}
else
{
list.Add(new Nest<Tensor>(nested.NodeValue!));
}
}
else
{
list.Add(new Nest<Tensor>(item.AsNest()));
}
}
return Tensors.FromNest(new Nest<Tensor>(list));
}
else
{
throw new NotImplementedException();
}
}

private static Tensor[] _build_while_op(Tensor[] loop_vars, FuncGraph cond_graph, FuncGraph body_graph,
Shape[] output_shapes, int parallel_iterations, string name, int num_original_outputs, bool stateful_parallelism)
{
var cond_stateful_ops = cond_graph.get_operations().Select(x => x.op);
var body_stateful_ops = body_graph.get_operations().Select(x => x.op);

bool is_stateful = cond_stateful_ops.Count() > 0 || body_stateful_ops.Count() > 0;

Tensor[] _make_op(Tensor[] inputs)
{
Tensor[] outputs;
if (is_stateful)
{
outputs = gen_functional_ops._while(
inputs,
control_flow_util.create_new_tf_function(cond_graph),
control_flow_util.create_new_tf_function(body_graph),
output_shapes,
parallel_iterations,
name
);
}
else
{
outputs = gen_functional_ops.stateless_while(
inputs,
control_flow_util.create_new_tf_function(cond_graph),
control_flow_util.create_new_tf_function(body_graph),
output_shapes,
parallel_iterations,
name
);
}
var (while_op, tensors) = control_flow_util.get_op_and_outputs(outputs);
_copy_handle_data(body_graph.Outputs, tensors);
_set_read_only_resource_inputs_attr(while_op, new FuncGraph[]{cond_graph, body_graph});
while_op._set_attr("_num_original_outputs", new AttrValue() { I = num_original_outputs });
while_op._set_attr("_stateful_parallelism", new AttrValue() { B = stateful_parallelism });

cond_graph.outer_graph = ops.get_default_graph();
body_graph.outer_graph = ops.get_default_graph();
// TODO(Rinne): set the two graphs to while_op
return tensors;
}

return control_flow_util.run_as_function_for_tape_gradients(_make_op, loop_vars);
}

/// <summary>
/// Sets the list of resource inputs which are read-only. This is used by AutomaticControlDependencies.
/// </summary>
/// <param name="op"></param>
/// <param name="branch_graphs"></param>
private static void _set_read_only_resource_inputs_attr(Operation op, FuncGraph[] branch_graphs)
{
List<int> read_only_indices = Enumerable.Range(0, op.inputs.Length).ToList();
foreach(var branch_graph in branch_graphs)
{
if (read_only_indices.Count == 0)
{
break;
}
var branch_read_only_indices = auto_control_deps_utils.get_read_only_resource_input_indices_graph(branch_graph);
read_only_indices = read_only_indices.Intersect(branch_read_only_indices).ToList();
}
AttrValue.Types.ListValue listValue = new();
listValue.I.AddRange(read_only_indices.OrderBy(x => x).Select(x => (long)x));
op._set_attr(auto_control_deps_utils.READ_ONLY_RESOURCE_INPUTS_ATTR, new AttrValue()
{
List = listValue
});
}

private static Tensors _pack_sequence_as<T>(INestStructure<T> loop_vars_signature, Tensor[] flat_orig_loop_vars, Tensor[] loop_vars)
{
var flattened_loop_vars = zip(loop_vars, flat_orig_loop_vars).Select<(Tensor, Tensor), Tensor>(item =>
{
var (flow, y) = item;
if (y is FakeTensorByTensorArray ta)
{
return new FakeTensorByTensorArray(tensor_array_ops.build_ta_with_new_flow(ta.TensorArray, flow));
}
else
{
return flow;
}
}).ToArray();
return Nest.PackSequenceAs(loop_vars_signature, flattened_loop_vars).ToTensors();
}

private static Tensor[] _get_intermediates(FuncGraph func_graph)
{
List<Tensor> intermediates = new();
var reversed_captures = func_graph.captures.ToDictionary(x => x.Item2, x => x.Item1);

foreach(var op in func_graph.get_operations())
{
Debug.Assert(op is Operation);
var oper = (Operation)op;
if(oper.type == "Identity" || oper.type == "MutexLock")
{
continue;
}
foreach(var o in op.outputs)
{
if(o != func_graph.Inputs[0] && o.dtype != dtypes.resource && !reversed_captures.ContainsKey(o))
{
intermediates.Add(o);
}
}
}
return intermediates.ToArray();
}

private static void _duplicate_body_captures_in_cond(FuncGraph cond_graph, Tensor[] body_graph_captures)
{
var types = body_graph_captures.Select(t => t.dtype).ToList();
var c_graph = cond_graph.c_graph;
var placeholders = types.Select(x => CreatePlaceholder(c_graph, _build_cond_placeholders_name_prefix(cond_graph), x)).ToList();

var placeholder_ops = placeholders.Select(ph => new _OperationWithOutputs(ph.oper, cond_graph)).ToList();

List<Tensor> tensors = new();
foreach(var (op, ph, dtype) in zip(placeholder_ops, placeholders, types))
{
var tensor = Tensor._create_with_tf_output(op, 0, dtype, ph);
op._outputs = new Tensor[] { tensor };
tensors.Add(tensor);
}

var tuples = zip(body_graph_captures, tensors).ToList();
var keys = body_graph_captures.Select(t => t.Id).ToList();
cond_graph._captures.Update(zip(keys, tuples).ToDictionary(x => x.Item1, x => x.Item2));
cond_graph.Inputs.AddRange(tensors);
}

private static TF_Output CreatePlaceholder(SafeGraphHandle graph, string name, TF_DataType dtype)
{
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype);
var op = c_api.TF_FinishOperation(desc, tf.Status);
tf.Status.Check(true);
var output = new TF_Output();
output.oper = op;
output.index = 0;
return output;
}

private static string _build_cond_placeholders_name_prefix(FuncGraph cond_graph)
{
return cond_graph.unique_name(cond_graph.Name + "___redundant_placeholder");
}

private static Tensor _convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype,
string name)
{
return ops.convert_to_tensor(value, dtype, name, false);
}

private static Tensor _build_maximum_iterations_loop_var(int maximum_iterations = -1)
{
return ops.convert_to_tensor(maximum_iterations, dtypes.int32, "maximum_iterations");
}

private static void _copy_handle_data(IEnumerable<Tensor> src_tensors, IEnumerable<Tensor> dst_tensors)
{
foreach(var (src_t, dst_t) in zip(src_tensors, dst_tensors))
{
handle_data_util.copy_handle_data(src_t, dst_t);
}
}
}
}

+ 2
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using Tensorflow.Exceptions;
using Tensorflow.Util;
using static Tensorflow.c_api;

@@ -88,7 +89,7 @@ namespace Tensorflow
case TF_Code.TF_INVALID_ARGUMENT:
throw new InvalidArgumentError(message);
default:
throw new TensorflowException(message);
throw new NotOkStatusException(message);
}
}
}


+ 6
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -111,7 +111,12 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="IsExternalInit" Version="1.0.3" PrivateAssets="all" />
<PackageReference Include="System.Memory" Version="4.5.4" PrivateAssets="all" />
</ItemGroup>
</Project>

+ 7
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -105,6 +105,13 @@ namespace Tensorflow
_id = ops.uid();
}

internal static Tensor _create_with_tf_output(Operation op, int value_index, TF_DataType dtype, TF_Output tf_output)
{
Tensor ret = new Tensor(op, value_index, dtype);
ret._tf_output = tf_output;
return ret;
}

protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);


+ 24
- 0
src/TensorFlowNET.Core/Tensors/TensorArray.cs View File

@@ -14,7 +14,9 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Common.Types;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -44,5 +46,27 @@ namespace Tensorflow

public abstract Tensor stack(string name = null);
public abstract Tensor gather(Tensor indices, string name = null);

internal bool _dynamic_size;
internal Tensor _size;
internal List<Tensor> _colocate_with;
internal Shape _element_shape;

public static TensorArray Create(TF_DataType dtype, Tensor size = null, bool dynamic_size = false,
bool clear_after_read = true, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
bool infer_shape = true, Shape? element_shape = null,
bool colocate_with_first_write_call = true, string name = null)
{
if (tf.Context.executing_eagerly() && (flow is null || flow.dtype != dtypes.variant))
{
return new _EagerTensorArray(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
infer_shape, element_shape, colocate_with_first_write_call, name);
}
else
{
return new _GraphTensorArrayV2(dtype, size, dynamic_size, clear_after_read, tensor_array_name, handle, flow,
infer_shape, element_shape, colocate_with_first_write_call, name);
}
}
}
}

+ 201
- 89
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -3,6 +3,9 @@ using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Operations;
using Tensorflow.Common.Extensions;

namespace Tensorflow
{
@@ -13,157 +16,278 @@ namespace Tensorflow
/// and Tensor[] from Tensors implicitily.
/// It works for tuple and scalar as well.
/// </summary>
public class Tensors : IEnumerable<Tensor>, IDisposable
public sealed class Tensors : Nest<Tensor>, IDisposable
{
List<Tensor> items = new List<Tensor>();

public TF_DataType dtype => items.First().dtype;
public Shape shape => items.First().shape;
public int rank => items.First().rank;
public Graph graph => items.First().graph;
public TF_DataType dtype => this.First().dtype;
public Shape shape => this.First().shape;
public int rank => this.First().rank;
public Graph graph => this.First().graph;
public bool IsList { get; set; }
public int Length => items.Count();
public int Length => this.Count();
/// <summary>
/// Return a Tensor if `Tensors` has only one tensor, otherwise throw an exception.
/// </summary>
public Tensor Single
{
get
{
if (Length != 1)
{
throw new ValueError("Tensors with more than one tensor cannot be " +
"implicitly converted to Tensor.");
}
return this.First();
}
}

public Tensor this[int index]
/// <summary>
/// Return a Tensor if `Tensors` has only one tensor, and return null when `Tensors` is empty,
/// otherwise throw an exception.
/// </summary>
public Tensor? SingleOrNull
{
get => items[index];
set => items[index] = value;
get
{
if (Length > 1)
{
throw new ValueError($"Tensors with {Length} tensor cannot be " +
"implicitly converted to Tensor.");
}
return this.FirstOrDefault();
}
}

public Tensor this[params string[] slices]
=> items.First()[slices];
public Tensors(params Tensor[] tensors)
=> this.First()[slices];

internal Tensors(Nest<Tensor> nested) : base(nested)
{
items.AddRange(tensors);
}

public Tensors(IEnumerable<Tensor> tensors)
public Tensors(params Tensor[] tensors): base(DealWithConstructorArrayInput(tensors))
{
items.AddRange(tensors);
}

public Tensors(NDArray nd)
public Tensors(IList<Tensor> tensors) : base(tensors.Select(x => new Nest<Tensor>(x)))
{
items.Add(ops.convert_to_tensor(nd));
}

public IEnumerator<Tensor> GetEnumerator()
public Tensors(NDArray nd): base(ops.convert_to_tensor(nd))
{
foreach (var tensor in items)
yield return tensor;
}

/// <summary>
/// Get the element in shallow level. For example, for ts = [1, [2, 3], 4],
/// common indexer has ts[1] = 2. Shallow indexer has ts[1] = [2, 3]
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public Tensors GetShallow(int index)
{
if(NestType == NestType.Node)
{
if(index > 0)
{
throw new IndexOutOfRangeException();
}
return this;
}
else if(NestType == NestType.List)
{
return ListValue![index].AsNest().ToTensors();
}
else
{
throw new NotImplementedException();
}
}

private static Nest<Tensor> DealWithConstructorArrayInput(Tensor[] tensors)
{
if (tensors.Length == 0)
{
return Nest<Tensor>.Empty;
}
else if(tensors.Length == 1)
{
return new Nest<Tensor>(tensors[0]);
}
else
{
return new Nest<Tensor>(tensors.Select(x => new Nest<Tensor>(x)));
}
}

public bool IsSingle()
{
return Length == 1;
}

public new Tensors MergeWith(Nest<Tensor>? other)
{
return FromNest(base.MergeWith(other));
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " +
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")]
public void Add(Tensor tensor)
=> items.Add(tensor);
{
if(NestType == NestType.Dictionary)
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(NodeValue), new Nest<Tensor>(tensor) };
NodeValue = null;
}
else if(NestType == NestType.List)
{
ListValue!.Add(new Nest<Tensor>(tensor));
}
else //Empty
{
NestType = NestType.Node;
NodeValue = tensor;
}
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to add " +
"some tensors to `Tensors`, creating a new instance with your newly added tensors is a better choice.")]
public void AddRange(IEnumerable<Tensor> tensors)
=> items.AddRange(tensors);
{
if (NestType == NestType.Dictionary)
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
else if (NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(NodeValue) };
ListValue.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
NodeValue = null;
}
else if(NestType == NestType.List)
{
ListValue!.AddRange(tensors.Select(x => new Nest<Tensor>(x)));
}
else // empty
{
NestType = NestType.List;
ListValue = tensors.Select(x => new Nest<Tensor>(x) as INestStructure<Tensor>).ToList();
}
}

[Obsolete("This method is not encouraged to be used. It may be removed in the future. If you do want to insert " +
"a tensor to `Tensors`, creating a new instance with your newly added tensor is a better choice.")]
public void Insert(int index, Tensor tensor)
=> items.Insert(index, tensor);

IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();
{
if (NestType == NestType.List)
{
ListValue.Insert(index, new Nest<Tensor>(tensor));
}
else if(NestType == NestType.Node)
{
NestType = NestType.List;
ListValue = new() { new Nest<Tensor>(NodeValue) };
ListValue.Insert(index, new Nest<Tensor>(tensor));
NodeValue = null;
}
else
{
throw new ValueError("Cannot add a tensor to dictionary type of nested tensors.");
}
}

public string[] StringData()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData();
return Single.StringData();
}

public string StringData(int index)
{
EnsureSingleTensor(this, "nnumpy");
return this[0].StringData(index);
return Single.StringData(index);
}

public NDArray numpy()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].numpy();
return Single.numpy();
}

[Obsolete]
public T[] ToArray<T>() where T: unmanaged
{
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
return this[0].ToArray<T>();
return Single.ToArray<T>();
}

#region Explicit Conversions
public static explicit operator bool(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to bool");
return (bool)tensor[0];
return (bool)tensor.Single;
}

public static explicit operator sbyte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to sbyte");
return (sbyte)tensor[0];
return (sbyte)tensor.Single;
}

public static explicit operator byte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
return (byte)tensor.Single;
}

public static explicit operator ushort(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ushort");
return (ushort)tensor[0];
return (ushort)tensor.Single;
}

public static explicit operator short(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to short");
return (short)tensor[0];
return (short)tensor.Single;
}

public static explicit operator int(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to int");
return (int)tensor[0];
return (int)tensor.Single;
}

public static explicit operator uint(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to uint");
return (uint)tensor[0];
return (uint)tensor.Single;
}

public static explicit operator long(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to long");
return (long)tensor[0];
return (long)tensor.Single;
}

public static explicit operator ulong(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ulong");
return (ulong)tensor[0];
return (ulong)tensor.Single;
}

public static explicit operator float(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
return (byte)tensor.Single;
}

public static explicit operator double(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to double");
return (double)tensor[0];
return (double)tensor.Single;
}

public static explicit operator string(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to string");
return (string)tensor[0];
return (string)tensor.Single;
}

public static explicit operator object[](Tensors tensors)
=> tensors.items.ToArray();
=> tensors.Flatten().ToArray();
#endregion

#region Implicit Conversions
@@ -183,56 +307,44 @@ namespace Tensorflow
public static implicit operator Tensors(List<Tensor> tensors)
=> new Tensors(tensors.ToArray());

public static implicit operator Tensor(Tensors tensors)
=> tensors.FirstOrDefault();
public static implicit operator Tensor(Tensors? tensors)
=> tensors?.SingleOrNull;

public static implicit operator Tensor[](Tensors tensors)
=> tensors.items.ToArray();

=> tensors.Flatten().ToArray();
#endregion

public void Deconstruct(out Tensor a, out Tensor b)
public static Tensors? FromNest(Nest<Tensor> nested)
{
a = items[0];
b = items[1];
if(nested == Nest<Tensor>.Empty)
{
return null;
}
return new Tensors(nested);
}

private static void EnsureSingleTensor(Tensors tensors, string methodnName)
public void Deconstruct(out Tensor a, out Tensors? b)
{
if(tensors.Length == 0)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
}
else if(tensors.Length > 1)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
}
a = this.First();
b = Length == 1? null : new Tensors(this.Skip(1).ToArray());
}

public override string ToString()
{
if(items.Count == 1)
if(Length == 1)
{
return items[0].ToString();
return this.First().ToString();
}
else
{
StringBuilder sb = new StringBuilder();
sb.Append($"Totally {items.Count} tensors, which are {string.Join(", ", items.Select(x => x.name))}\n[\n");
for(int i = 0; i < items.Count; i++)
{
var tensor = items[i];
sb.Append($"Tensor {i}({tensor.name}): {tensor.ToString()}\n");
}
sb.Append("]\n");
return sb.ToString();
return $"Totally {Length} tensors: {base.ToString()}";
}
}

public void Dispose()
{
foreach (var item in items)
item.Dispose();
foreach (var tensor in this)
tensor.Dispose();
}
}
}

+ 1
- 2
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -179,8 +179,7 @@ namespace Tensorflow.Train
// handles slot variables.
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
{
var temp = new_variable as Trackable;
var res = _track_trackable(temp, args.Name, args.Overwrite);
var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite);
Debug.Assert(res is IVariableV1);
return res as IVariableV1;
}


+ 1
- 0
src/TensorFlowNET.Core/Util/nest.py.cs View File

@@ -36,6 +36,7 @@ namespace Tensorflow.Util
// (np.array([3, 4]), tf.constant([3, 4])))`
//

[Obsolete]
public static class nest
{



+ 20
- 3
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -170,11 +170,28 @@ namespace Tensorflow
public Tensor value()
=> GraphElement ?? _read_variable_op();

protected Tensor _read_variable_op()
protected Tensor _read_variable_op(bool no_copy = false)
{
variable_accessed(this);
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);

Tensor read_and_set_handle(bool no_copy)
{
if (no_copy)
{
gen_resource_variable_ops.disable_copy_on_read(handle);
}
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);
return result;
}

// TODO(Rinne): deal with caching device.
var result = read_and_set_handle(no_copy);
if (!tf.Context.executing_eagerly())
{
tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle },
backward_function: (x, _) => x);
}

// have to set shape when converting to substituent placeholder
if (result.shape.ndim == -1)


+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -576,7 +576,7 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
return HandleData.Parser.ParseFrom(tf.compat.as_bytes(c_api.StringPiece(handle_data)));
return HandleData.Parser.ParseFrom(c_api.ByteStringPiece(handle_data));
}

public static void dismantle_graph(Graph graph)


+ 525
- 0
src/TensorFlowNET.Keras/BackendImpl.cs View File

@@ -20,8 +20,12 @@ using System.Linq;
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using Tensorflow.Common.Extensions;
using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;
using Tensorflow.Util;
using Tensorflow.Common.Types;
using System.Diagnostics;

namespace Tensorflow.Keras
{
@@ -450,5 +454,526 @@ namespace Tensorflow.Keras

return x;
}

public (Tensors, Tensors, Tensors) rnn(
Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states
Tensors inputs, // inputs is a tuple of tensors (one per input sequence)
Tensors initial_states,
bool go_backwards = false,
Tensor? mask = null,
Tensors? constants = null,
bool unroll = false,
Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not
bool time_major = false,
bool zero_output_for_mask = false,
bool return_all_outputs = true)
{

Tensor swap_batch_timestep(Tensor input_t)
{
var axes = Enumerable.Range(0, input_t.rank).ToArray();
axes[0] = 1;
axes[1] = 0;
return tf.transpose(input_t, axes);
}

if (!time_major)
{
inputs = Nest.MapStructure(swap_batch_timestep, inputs).ToTensors();
}

var flatted_inptus = Nest.Flatten(inputs).ToList();
var first_flatted_input = flatted_inptus[0];
var time_steps = first_flatted_input.shape[0];
var batch = first_flatted_input.shape[1];
var time_steps_t = tf.shape(first_flatted_input)[0];

foreach (var input_ in flatted_inptus)
{
input_.shape.with_rank_at_least(3);
}

if (mask != null)
{
if (mask.dtype != TF_DataType.TF_BOOL)
{
mask = tf.cast(mask, TF_DataType.TF_BOOL);
}

if (mask.rank == 2)
{
mask = tf.expand_dims(mask, -1);
}

if (!time_major)
{
mask = swap_batch_timestep(mask);
}

}
// tf.where needs its condition tensor to be the same shape as its two
// result tensors, but in our case the condition (mask) tensor is
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
// So we need to broadcast the mask to match the shape of inputs.
// That's what the tile call does, it just repeats the mask along its
// second dimension n times.

Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
{
if (!mask_t.IsSingle())
{
throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}");
}

if (!input_t.IsSingle())
{
throw new ValueError($"input_t is expected to be tensor, but got {input_t}");
}

var rank_diff = input_t.rank - mask_t.rank;
for (int i = 0; i < rank_diff; i++)
{
mask_t = tf.expand_dims(mask_t, -1);
}
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().Skip(fixed_dim).ToArray());
return tf.tile(mask_t, multiples);
}

Tensors outputs = new Tensors();
Tensors output_time_zero = new Tensors();
Tensors last_output = new Tensors();
Tensors new_states = new Tensors();
if (unroll)
{
if (time_steps == 0)
{
throw new ValueError("Unrolling requires a fixed number of timesteps.");
}

// Process the input tensors. The input tensor need to be split on the
// time_step dim, and reverse if go_backwards is True. In the case of
// nested input, the input is flattened and then transformed
// individually. The result of this will be a tuple of lists, each of
// the item in tuple is list of the tensor with shape (batch, feature)


// TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple
//var states = Tuple.Create(initial_states);
var states = initial_states;

var successive_states = new Tensors();
var successive_outputs = new Tensors();

// Process the input tensors. The input tensor need to be split on the
// time_step dim, and reverse if go_backwards is True. In the case of
// nested input, the input is flattened and then transformed
// individually. The result of this will be a tuple of lists, each of
// the item in tuple is list of the tensor with shape (batch, feature)

Tensors _process_single_input_t(Tensor input_t)
{
var unstaked_input_t = array_ops.unstack(input_t); // unstack for time_step dim
if (go_backwards)
{
unstaked_input_t = unstaked_input_t.Reverse().ToArray();
}
return unstaked_input_t;
}

// TODO(Wanglongzhi2001)
Tensors processed_input;
if (!inputs.IsSingle())
{
processed_input = inputs.MapStructure(_process_single_input_t).ReduceTo<Tensors, Tensor>().ToTensors();
}
else
{
processed_input = _process_single_input_t(inputs);
}

object _get_input_tensor(int time)
{
List<Tensor> inp = new List<Tensor>();
foreach (var t_ in processed_input)
{
inp.Add(t_[time]);
}
return Nest.PackSequenceAs(inputs, inp);
}

if (mask != null)
{
var mask_list = tf.unstack(mask);
if (go_backwards)
{
mask_list.Reverse().ToArray();
}

for (int i = 0; i < time_steps; i++)
{
// TODO(Wanglongzhi2001),deal with _get_input_tensor
var inp = _get_input_tensor(i);
var mask_t = mask_list[i];
// TODO
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));

var tiled_mask_t = _expand_mask(mask_t, output);

Tensors prev_output;
if (successive_outputs == null)
{
prev_output = tf.zeros_like(output);
}
else
{
prev_output = successive_outputs.Last();
}

// output could be a tensor
output = tf.where(tiled_mask_t, output, prev_output);

var flat_states = Nest.Flatten(states).ToList();
var flat_new_states = Nest.Flatten(newStates).ToList();

var tiledMaskT = flat_states
.Select(s => _expand_mask(mask_t, s))
.ToArray();
var tuple = Tuple.Create(tiledMaskT);

List<Tensor> flat_final_states = new List<Tensor>();
foreach (var (m, s, ps) in zip(tiled_mask_t.ToList(), flat_new_states, flat_states))
{
flat_final_states.Add(tf.where(m, s, ps));
}

states = Nest.PackSequenceAs(states, flat_final_states).ToTensors();
if (return_all_outputs)
{
successive_outputs = successive_outputs.MergeWith(output);
successive_outputs = successive_states.MergeWith(states);
}
else
{
successive_outputs = new Tensors(output);
successive_states = new Tensors(states);
}

}
last_output = successive_outputs.Last();
new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);

if (zero_output_for_mask)
{
last_output = tf.where(_expand_mask(mask_list.Last(), last_output), last_output, tf.zeros_like(last_output));
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs));
}
else // mask is null
{
for (int i = 0; i < time_steps; i++)
{
var inp = _get_input_tensor(i);
var (output, newStates) = step_function((Tensors)inp, states.MergeWith(constants));
states = newStates;

if (return_all_outputs)
{
successive_outputs.Add(output);
successive_states.Add(newStates);
}
else
{
successive_outputs = new Tensors { output };
successive_states = new Tensors { newStates };
}
}
last_output = successive_outputs.Last();
new_states = successive_states.Last();
outputs = tf.stack(successive_outputs);
}
}
}
else // unroll == false
{
var states = initial_states;
// Create input tensor array, if the inputs is nested tensors, then it
// will be flattened first, and tensor array will be created one per
// flattened tensor.


var input_ta = new List<TensorArray>();
for (int i = 0; i < flatted_inptus.Count; i++)
{
input_ta.Add(TensorArray.Create(dtype: flatted_inptus[i].dtype, size: time_steps_t));
}

foreach(var (ta, input_) in zip(input_ta, flatted_inptus))
{
if (!go_backwards)
{
ta.unstack(input_);
}
else
{
ta.unstack(reverse(input_, 0));
}
}


// Get the time(0) input and compute the output for that, the output will
// be used to determine the dtype of output tensor array. Don't read from
// input_ta due to TensorArray clear_after_read default to True.
var input_time_zero = Nest.PackSequenceAs(inputs, flatted_inptus.Select(x => x[0]).ToArray()).ToTensors();

// output_time_zero is used to determine the cell output shape and its
// dtype. the value is discarded.
(output_time_zero, _) = step_function(input_time_zero,
constants is null ? initial_states : initial_states.MergeWith(constants));

Tensor output_ta_size = return_all_outputs ? time_steps_t : constant_op.constant(1);
var output_ta = new List<TensorArray>();
foreach(var output in output_time_zero.Flatten())
{
output_ta.Add(TensorArray.Create(dtype: output.dtype, size: output_ta_size, element_shape: output.shape));
}

var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time");

Func<Tensor, Tensor>? masking_fn;
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null;
if (mask != null)
{
if (go_backwards)
{
mask = tf.reverse(mask, axis: new[] { 0 });
}
var mask_ta = TensorArray.Create(dtype: TF_DataType.TF_BOOL, size: time_steps_t);
mask_ta = mask_ta.unstack(mask);

masking_fn = (time) =>
{
return mask_ta.read(time);
};

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var tiled_mask_t = new Tensors();
foreach (var o in flat_out)
{
tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank));
}

Tensors res = new Tensors();
foreach (var (m, o, fm) in zip(tiled_mask_t.ToList(), flat_out.ToList(), flat_mask.ToList()))
{
res.Add(tf.where(m, o, fm));
}
return res;
};
}
// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor), it could be an integer or tensor
else if (input_length is Tensor)
{
if (go_backwards)
{
var max_len = tf.reduce_max(input_length, axis: 0);
var rev_input_length = tf.subtract(max_len - 1, input_length);

masking_fn = (time) =>
{
return tf.less(rev_input_length, time);
};
}
else
{
masking_fn = (time) =>
{
return tf.greater(input_length, time);
};
}

compute_masked_output = (mask_t, flat_out, flat_mask) =>
{
var res = new List<Tensor>();
foreach (var (o, zo) in zip(flat_out, flat_mask))
{
res.Add(tf.where(mask_t, o, zo));
}
return res;
};
}
else
{
masking_fn = null;
}

Func<Tensors, Tensor> cond = (time) => (time[0] < time_steps_t);
int parallel_iterations = 32;
Tensors final_outputs;
if (masking_fn != null)
{
// Mask for the T output will be base on the output of T - 1. In the
// case T = 0, a zero filled tensor will be used.
var flat_zero_output = new Tensors();
foreach (var o in Nest.Flatten(output_time_zero))
{
flat_zero_output.Add(tf.zeros_like(o));
}

var prev_output = flat_zero_output;
var output_ta_t = output_ta;
Tensors _step(Tensors tensors)
{
/*
RNN step function.
Args:
time: Current timestep value.
output_ta_t: TensorArray.
prev_output: tuple of outputs from time - 1.
*states: List of states.
Returns:
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)`
*/

Tensor time = tensors[0];
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray;
Tensors prev_output = tensors.GetShallow(2);
Tensors states = new Tensors(tensors.Skip(2 + prev_output.Length).ToArray());

var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var mask_t = masking_fn(time);
var (output, new_states) = step_function(current_input, states.MergeWith(constants));
// mask output
var flat_output = Nest.Flatten(output).ToList();

var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.Flatten().ToList();

// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output);

// mask states
var flat_state = states.Flatten().ToList();
var flat_new_state = new_states.Flatten().ToList();

foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.shape = state.shape;
}
}

var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state);
new_states = Nest.PackSequenceAs(new_states, flat_final_state.ToArray()).ToTensors();

var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
Debug.Assert(flat_output.Count() == 1);
output_ta_t = output_ta_t.write(ta_index_to_write, flat_new_output.First());

return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(flat_new_output).Concat(new_states)
.ToArray().ToTensors();

}
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }
.Concat(flat_zero_output.Flatten()).Concat(states).ToArray().ToTensors();
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations);
new_states = final_outputs.Skip(3).ToList();
}
else
{
var output_ta_t = output_ta;
new_states = states;
Tensors _step(Tensors tensors)
{
Tensor time = tensors[0];
TensorArray output_ta_t = (tensors[1] as FakeTensorByTensorArray).TensorArray;
Tensors states = new Tensors(tensors.Skip(2).ToArray());
var flat_current_input = input_ta.Select(x => x.read(time)).ToList();
// maybe set shape
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type
var current_input = Nest.PackSequenceAs(inputs, flat_current_input).ToTensors();
var (output, new_states) = step_function(current_input, states.MergeWith(constants));
var flat_state = new_states.Flatten().ToList();
var flat_new_state = new_states.Flatten().ToList();
foreach (var (state, new_state) in zip(flat_state, flat_new_state))
{
if (new_state is Tensor)
{
new_state.shape = state.shape;
}
}
var flat_output = Nest.Flatten(output);
var ta_index_to_write = return_all_outputs ? time : tf.constant(0);
Debug.Assert(flat_output.Count() == 1);
output_ta_t = output_ta_t.write(ta_index_to_write, flat_output.First());

new_states = Nest.PackSequenceAs(initial_states, flat_new_state).ToTensors();
return new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta_t) }.Concat(new_states).ToArray().ToTensors();
}
Debug.Assert(output_ta.Count == 1);
var loop_vars = new Tensor[] { time + 1, new FakeTensorByTensorArray(output_ta[0]) }.Concat(states).ToArray().ToTensors();
final_outputs = control_flow_ops.while_loop(cond: cond, body: _step, loop_vars: loop_vars, parallel_iterations: parallel_iterations);
new_states = final_outputs.Skip(2).ToList();
}

output_ta = new List<TensorArray> { (final_outputs[1] as FakeTensorByTensorArray).TensorArray };
outputs = outputs.MergeWith(output_ta.Select(o => o.stack()).ToArray().ToTensors());
last_output = last_output.MergeWith(outputs.Select(o => o[-1]).ToArray().ToTensors());
outputs = Nest.PackSequenceAs(output_time_zero, (Tensor[])outputs).ToTensors();
last_output = Nest.PackSequenceAs(output_time_zero, (Tensor[])last_output).ToTensors();
}

Func<Tensor, Tensor> set_shape;
set_shape = (output_) =>
{
if (output_ is Tensor)
{
var shape = output_.shape.as_int_list();
if (return_all_outputs)
{
shape[0] = (int)time_steps;
}
else
{
shape[0] = 1;
}
shape[1] = (int)batch;
output_.shape = shape;
}
return output_;
};

outputs = Nest.MapStructure(set_shape, outputs).ToTensors();
if (!time_major)
{
outputs = Nest.MapStructure(swap_batch_timestep, outputs).ToTensors();
}
return (last_output, outputs, new_states);

}

public Tensor reverse(Tensor input, int axis)
{
return reverse(input, new int[] { axis });
}

public Tensor reverse(Tensor input, int[] axes)
{
return tf.reverse(input, axes);
}

public Tensor maybe_convert_to_ragged(bool is_ragged_output, Tensor output, int nested_row_lengths, bool go_backwards = false)
{
if (!is_ragged_output)
{
return output;
}

throw new NotImplementedException("Not implemented currently, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues");
}
}
}

+ 3
- 2
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
@@ -81,7 +82,7 @@ namespace Tensorflow.Keras.Engine
}
else
{
_buildInputShape = new Saving.TensorShapeConfig();
_buildInputShape = new TensorShapeConfig();
}

if (outputs.Any(x => x.KerasHistory == null))
@@ -325,7 +326,7 @@ namespace Tensorflow.Keras.Engine
nodes_in_decreasing_depth.append(node);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
var tensor_dict = new Dictionary<long, Queue<Tensor>>();
// map input values


+ 6
- 3
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -1,4 +1,5 @@
using System.Threading;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
@@ -8,11 +9,11 @@ namespace Tensorflow.Keras.Engine
/// <summary>
/// Wraps `call`, applying pre- and post-processing steps.
/// </summary>
/// <param name="input"></param>
/// <param name="inputs"></param>
/// <param name="state"></param>
/// <param name="training"></param>
/// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
public virtual Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null)
{
if (callContext.Value == null)
callContext.Value = new CallContext();
@@ -30,13 +31,15 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

var outputs = Call(inputs, state: state, training: training);
var outputs = Call(inputs, state: states, training: training);

// memory leak
// _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);

// TODO(Rinne): set save spec if null

scope.__exit__();

return outputs;


+ 2
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -32,7 +32,7 @@ using Tensorflow.Util;
using static Tensorflow.Binding;
using Tensorflow.Framework;
using Tensorflow.Sessions;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Engine
{
@@ -332,7 +332,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param>
/// <param name="training"></param>
/// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected virtual Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if(ReplacedCall is not null)
{


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Build.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Engine
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
graph.as_default();
var shapes = input_shape.ToShapeArray();
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)));
var x = new Tensors(shapes.Select(x => base_layer_utils.generate_placeholders_from_shape(x)).ToArray());
try
{
Call(x, training: false);


+ 3
- 2
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -72,7 +72,7 @@ namespace Tensorflow.Keras.Engine
{
var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(x),
X = new Tensors(x.ToArray()),
Y = y,
Model = this,
StepsPerExecution = _steps_per_execution
@@ -168,7 +168,8 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
{
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
}


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Engine

var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(train_x),
X = new Tensors(train_x.ToArray()),
Y = train_y,
BatchSize = batch_size,
InitialEpoch = initial_epoch,


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Train.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -1,8 +1,8 @@
using System.Diagnostics;
using Tensorflow.Common.Types;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.Keras.Utils;
using Tensorflow.Train;


+ 2
- 1
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -21,6 +21,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Engine
{
@@ -143,7 +144,7 @@ namespace Tensorflow.Keras.Engine
}
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
if (!_has_explicit_input_shape)
{


+ 4
- 0
src/TensorFlowNET.Keras/IsExternalInit.cs View File

@@ -0,0 +1,4 @@
namespace System.Runtime.CompilerServices
{
internal static class IsExternalInit { }
}

+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/ELU.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -29,7 +30,7 @@ namespace Tensorflow.Keras.Layers {
base.build(input_shape);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor output = inputs;
output = tf.where(output > 0f, output,


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs View File

@@ -4,7 +4,7 @@ using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using static Tensorflow.Binding;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers {
public class Exponential : Layer
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers {
{
base.build(input_shape);
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
Tensor output = inputs;
return tf.exp(output);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/HardSigmoid.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers {
@@ -10,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
public HardSigmoid ( LayerArgs args ) : base(args) {
// hard sigmoid has no arguments
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null ) {
Tensor x = inputs;
return tf.clip_by_value(
tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f);


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
@@ -19,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
{
return tf.nn.leaky_relu(inputs, alpha: alpha);
}


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Activation/SELU.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
@@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Layers {
}
base.build(input_shape);
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
protected override Tensors Call ( Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) {
Tensor output = inputs;
return tf.where(output > 0f,
tf.multiply(scale, output),


Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save