diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md
index 141bba35..4002aa21 100644
--- a/src/SciSharp.TensorFlow.Redist/README.md
+++ b/src/SciSharp.TensorFlow.Redist/README.md
@@ -26,7 +26,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
#### Download pre-build package
-[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)
+[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.10.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.10.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.10.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.10.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.10.0.zip)
@@ -35,6 +35,6 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
-2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
+2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.10.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
diff --git a/src/TensorFlowNET.Console/Program.cs b/src/TensorFlowNET.Console/Program.cs
index 4b7f52de..638fe0a3 100644
--- a/src/TensorFlowNET.Console/Program.cs
+++ b/src/TensorFlowNET.Console/Program.cs
@@ -10,6 +10,9 @@ namespace Tensorflow
var diag = new Diagnostician();
// diag.Diagnose(@"D:\memory.txt");
+ var rnn = new SimpleRnnTest();
+ rnn.Run();
+
// this class is used explor new features.
var exploring = new Exploring();
// exploring.Run();
diff --git a/src/TensorFlowNET.Console/SimpleRnnTest.cs b/src/TensorFlowNET.Console/SimpleRnnTest.cs
new file mode 100644
index 00000000..b61cee9c
--- /dev/null
+++ b/src/TensorFlowNET.Console/SimpleRnnTest.cs
@@ -0,0 +1,31 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras;
+using Tensorflow.NumPy;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow
+{
+ public class SimpleRnnTest
+ {
+ public void Run()
+ {
+ tf.keras = new KerasInterface();
+ var inputs = np.random.random((32, 10, 8)).astype(np.float32);
+ var simple_rnn = tf.keras.layers.SimpleRNN(4);
+ var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
+ if (output.shape == (32, 4))
+ {
+
+ }
+ /*simple_rnn = tf.keras.layers.SimpleRNN(
+ 4, return_sequences = True, return_state = True)
+
+ # whole_sequence_output has shape `[32, 10, 4]`.
+ # final_state has shape `[32, 4]`.
+ whole_sequence_output, final_state = simple_rnn(inputs)*/
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
index 058722eb..e66c7033 100644
--- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj
+++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
@@ -6,7 +6,7 @@
Tensorflow
Tensorflow
AnyCPU;x64
- 9.0
+ 11.0
@@ -20,7 +20,7 @@
-
+
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs
deleted file mode 100644
index 0a2555a6..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs
+++ /dev/null
@@ -1,22 +0,0 @@
-namespace Tensorflow.Keras.ArgsDefinition
-{
- public class LSTMArgs : RNNArgs
- {
- public int Units { get; set; }
- public Activation Activation { get; set; }
- public Activation RecurrentActivation { get; set; }
- public IInitializer KernelInitializer { get; set; }
- public IInitializer RecurrentInitializer { get; set; }
- public IInitializer BiasInitializer { get; set; }
- public bool UnitForgetBias { get; set; }
- public float Dropout { get; set; }
- public float RecurrentDropout { get; set; }
- public int Implementation { get; set; }
- public bool ReturnSequences { get; set; }
- public bool ReturnState { get; set; }
- public bool GoBackwards { get; set; }
- public bool Stateful { get; set; }
- public bool TimeMajor { get; set; }
- public bool Unroll { get; set; }
- }
-}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs
new file mode 100644
index 00000000..b08d21d8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs
@@ -0,0 +1,12 @@
+using Tensorflow.Keras.ArgsDefinition.Rnn;
+
+namespace Tensorflow.Keras.ArgsDefinition.Lstm
+{
+ public class LSTMArgs : RNNArgs
+ {
+ public bool UnitForgetBias { get; set; }
+ public float Dropout { get; set; }
+ public float RecurrentDropout { get; set; }
+ public int Implementation { get; set; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs
similarity index 53%
rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs
rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs
index 62f9a0c4..fb0868dc 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs
@@ -1,4 +1,4 @@
-namespace Tensorflow.Keras.ArgsDefinition
+namespace Tensorflow.Keras.ArgsDefinition.Lstm
{
public class LSTMCellArgs : LayerArgs
{
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs
deleted file mode 100644
index 3ebcf617..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs
+++ /dev/null
@@ -1,21 +0,0 @@
-using System.Collections.Generic;
-
-namespace Tensorflow.Keras.ArgsDefinition
-{
- public class RNNArgs : LayerArgs
- {
- public interface IRnnArgCell : ILayer
- {
- object state_size { get; }
- }
-
- public IRnnArgCell Cell { get; set; } = null;
- public bool ReturnSequences { get; set; } = false;
- public bool ReturnState { get; set; } = false;
- public bool GoBackwards { get; set; } = false;
- public bool Stateful { get; set; } = false;
- public bool Unroll { get; set; } = false;
- public bool TimeMajor { get; set; } = false;
- public Dictionary Kwargs { get; set; } = null;
- }
-}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
new file mode 100644
index 00000000..da527925
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
@@ -0,0 +1,45 @@
+using System.Collections.Generic;
+
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
+{
+ public class RNNArgs : LayerArgs
+ {
+ public interface IRnnArgCell : ILayer
+ {
+ object state_size { get; }
+ }
+
+ public IRnnArgCell Cell { get; set; } = null;
+ public bool ReturnSequences { get; set; } = false;
+ public bool ReturnState { get; set; } = false;
+ public bool GoBackwards { get; set; } = false;
+ public bool Stateful { get; set; } = false;
+ public bool Unroll { get; set; } = false;
+ public bool TimeMajor { get; set; } = false;
+ public Dictionary Kwargs { get; set; } = null;
+
+ public int Units { get; set; }
+ public Activation Activation { get; set; }
+ public Activation RecurrentActivation { get; set; }
+ public bool UseBias { get; set; } = true;
+ public IInitializer KernelInitializer { get; set; }
+ public IInitializer RecurrentInitializer { get; set; }
+ public IInitializer BiasInitializer { get; set; }
+
+ // kernel_regularizer=None,
+ // recurrent_regularizer=None,
+ // bias_regularizer=None,
+ // activity_regularizer=None,
+ // kernel_constraint=None,
+ // recurrent_constraint=None,
+ // bias_constraint=None,
+ // dropout=0.,
+ // recurrent_dropout=0.,
+ // return_sequences=False,
+ // return_state=False,
+ // go_backwards=False,
+ // stateful=False,
+ // unroll=False,
+ // **kwargs):
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs
new file mode 100644
index 00000000..fcfd694d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs
@@ -0,0 +1,7 @@
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
+{
+ public class SimpleRNNArgs : RNNArgs
+ {
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
similarity index 82%
rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs
rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
index 9b910e17..fdfadab8 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
@@ -1,6 +1,6 @@
using System.Collections.Generic;
-namespace Tensorflow.Keras.ArgsDefinition
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs
deleted file mode 100644
index 65815587..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs
+++ /dev/null
@@ -1,30 +0,0 @@
-namespace Tensorflow.Keras.ArgsDefinition
-{
- public class SimpleRNNArgs : RNNArgs
- {
- public int Units { get; set; }
- public Activation Activation { get; set; }
-
- // units,
- // activation='tanh',
- // use_bias=True,
- // kernel_initializer='glorot_uniform',
- // recurrent_initializer='orthogonal',
- // bias_initializer='zeros',
- // kernel_regularizer=None,
- // recurrent_regularizer=None,
- // bias_regularizer=None,
- // activity_regularizer=None,
- // kernel_constraint=None,
- // recurrent_constraint=None,
- // bias_constraint=None,
- // dropout=0.,
- // recurrent_dropout=0.,
- // return_sequences=False,
- // return_state=False,
- // go_backwards=False,
- // stateful=False,
- // unroll=False,
- // **kwargs):
- }
-}
diff --git a/src/TensorFlowNET.Core/Keras/IKerasApi.cs b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
new file mode 100644
index 00000000..660dcbde
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
@@ -0,0 +1,12 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.Layers;
+
+namespace Tensorflow.Keras
+{
+ public interface IKerasApi
+ {
+ public ILayersApi layers { get; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/IPreprocessing.cs b/src/TensorFlowNET.Core/Keras/IPreprocessing.cs
new file mode 100644
index 00000000..28eea0f5
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/IPreprocessing.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras
+{
+ public interface IPreprocessing
+ {
+ public ILayer Resizing(int height, int width, string interpolation = "bilinear");
+ public ILayer TextVectorization(Func standardize = null,
+ string split = "whitespace",
+ int max_tokens = -1,
+ string output_mode = "int",
+ int output_sequence_length = -1);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs
new file mode 100644
index 00000000..73a6787c
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs
@@ -0,0 +1,20 @@
+using System;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.NumPy;
+using Tensorflow.Operations.Activation;
+
+namespace Tensorflow.Keras.Layers
+{
+ public partial interface ILayersApi
+ {
+ public ILayer ELU(float alpha = 0.1f);
+ public ILayer SELU();
+ public ILayer Softmax(Axis axis);
+ public ILayer Softplus();
+ public ILayer HardSigmoid();
+ public ILayer Softsign();
+ public ILayer Swish();
+ public ILayer Tanh();
+ public ILayer Exponential();
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs
new file mode 100644
index 00000000..22fb50d3
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs
@@ -0,0 +1,28 @@
+using System;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.Layers
+{
+ public partial interface ILayersApi
+ {
+ public ILayer Attention(bool use_scale = false,
+ string score_mode = "dot",
+ bool causal = false,
+ float dropout = 0f);
+ public ILayer MultiHeadAttention(int num_heads,
+ int key_dim,
+ int? value_dim = null,
+ float dropout = 0f,
+ bool use_bias = true,
+ Shape output_shape = null,
+ Shape attention_axes = null,
+ IInitializer kernel_initializer = null,
+ IInitializer bias_initializer = null,
+ IRegularizer kernel_regularizer = null,
+ IRegularizer bias_regularizer = null,
+ IRegularizer activity_regularizer = null,
+ Action kernel_constraint = null,
+ Action bias_constraint = null);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs
new file mode 100644
index 00000000..602e7a88
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs
@@ -0,0 +1,13 @@
+using System;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.Layers
+{
+ public partial interface ILayersApi
+ {
+ public ILayer Cropping1D(NDArray cropping);
+ public ILayer Cropping2D(NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last);
+ public ILayer Cropping3D(NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs
new file mode 100644
index 00000000..d0a7f09f
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs
@@ -0,0 +1,10 @@
+using System;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.Layers
+{
+ public partial interface ILayersApi
+ {
+ public ILayer Concatenate(int axis = -1);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs
new file mode 100644
index 00000000..d41e0688
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs
@@ -0,0 +1,18 @@
+using System;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.Layers
+{
+ public partial interface ILayersApi
+ {
+ public ILayer Reshape(Shape target_shape);
+ public ILayer Reshape(object[] target_shape);
+
+ public ILayer UpSampling2D(Shape size = null,
+ string data_format = null,
+ string interpolation = "nearest");
+
+ public ILayer ZeroPadding2D(NDArray padding);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
new file mode 100644
index 00000000..5945bb55
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -0,0 +1,169 @@
+using System;
+using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
+
+namespace Tensorflow.Keras.Layers
+{
+ public partial interface ILayersApi
+ {
+ public IPreprocessing preprocessing { get; }
+
+ public ILayer Add();
+
+ public ILayer AveragePooling2D(Shape pool_size = null,
+ Shape strides = null,
+ string padding = "valid",
+ string data_format = null);
+
+ public ILayer BatchNormalization(int axis = -1,
+ float momentum = 0.99f,
+ float epsilon = 0.001f,
+ bool center = true,
+ bool scale = true,
+ IInitializer beta_initializer = null,
+ IInitializer gamma_initializer = null,
+ IInitializer moving_mean_initializer = null,
+ IInitializer moving_variance_initializer = null,
+ bool trainable = true,
+ string name = null,
+ bool renorm = false,
+ float renorm_momentum = 0.99f);
+
+ public ILayer Conv1D(int filters,
+ Shape kernel_size,
+ int strides = 1,
+ string padding = "valid",
+ string data_format = "channels_last",
+ int dilation_rate = 1,
+ int groups = 1,
+ string activation = null,
+ bool use_bias = true,
+ string kernel_initializer = "glorot_uniform",
+ string bias_initializer = "zeros");
+
+ public ILayer Conv2D(int filters,
+ Shape kernel_size = null,
+ Shape strides = null,
+ string padding = "valid",
+ string data_format = null,
+ Shape dilation_rate = null,
+ int groups = 1,
+ Activation activation = null,
+ bool use_bias = true,
+ IInitializer kernel_initializer = null,
+ IInitializer bias_initializer = null,
+ IRegularizer kernel_regularizer = null,
+ IRegularizer bias_regularizer = null,
+ IRegularizer activity_regularizer = null);
+
+ public ILayer Conv2D(int filters,
+ Shape kernel_size = null,
+ Shape strides = null,
+ string padding = "valid",
+ string data_format = null,
+ Shape dilation_rate = null,
+ int groups = 1,
+ string activation = null,
+ bool use_bias = true,
+ string kernel_initializer = "glorot_uniform",
+ string bias_initializer = "zeros");
+
+ public ILayer Dense(int units);
+ public ILayer Dense(int units,
+ string activation = null,
+ Shape input_shape = null);
+ public ILayer Dense(int units,
+ Activation activation = null,
+ IInitializer kernel_initializer = null,
+ bool use_bias = true,
+ IInitializer bias_initializer = null,
+ Shape input_shape = null);
+
+ public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null);
+
+ public ILayer Embedding(int input_dim,
+ int output_dim,
+ IInitializer embeddings_initializer = null,
+ bool mask_zero = false,
+ Shape input_shape = null,
+ int input_length = -1);
+
+ public ILayer EinsumDense(string equation,
+ Shape output_shape,
+ string bias_axes,
+ Activation activation = null,
+ IInitializer kernel_initializer = null,
+ IInitializer bias_initializer = null,
+ IRegularizer kernel_regularizer = null,
+ IRegularizer bias_regularizer = null,
+ IRegularizer activity_regularizer = null,
+ Action kernel_constraint = null,
+ Action bias_constraint = null);
+
+ public ILayer Flatten(string data_format = null);
+
+ public ILayer GlobalAveragePooling1D(string data_format = "channels_last");
+ public ILayer GlobalAveragePooling2D();
+ public ILayer GlobalAveragePooling2D(string data_format = "channels_last");
+ public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
+ public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
+
+ public Tensors Input(Shape shape,
+ string name = null,
+ bool sparse = false,
+ bool ragged = false);
+ public ILayer InputLayer(Shape input_shape,
+ string name = null,
+ bool sparse = false,
+ bool ragged = false);
+
+ public ILayer LayerNormalization(Axis? axis,
+ float epsilon = 1e-3f,
+ bool center = true,
+ bool scale = true,
+ IInitializer beta_initializer = null,
+ IInitializer gamma_initializer = null);
+
+ public ILayer LeakyReLU(float alpha = 0.3f);
+
+ public ILayer LSTM(int units,
+ Activation activation = null,
+ Activation recurrent_activation = null,
+ bool use_bias = true,
+ IInitializer kernel_initializer = null,
+ IInitializer recurrent_initializer = null,
+ IInitializer bias_initializer = null,
+ bool unit_forget_bias = true,
+ float dropout = 0f,
+ float recurrent_dropout = 0f,
+ int implementation = 2,
+ bool return_sequences = false,
+ bool return_state = false,
+ bool go_backwards = false,
+ bool stateful = false,
+ bool time_major = false,
+ bool unroll = false);
+
+ public ILayer MaxPooling1D(int? pool_size = null,
+ int? strides = null,
+ string padding = "valid",
+ string data_format = null);
+ public ILayer MaxPooling2D(Shape pool_size = null,
+ Shape strides = null,
+ string padding = "valid",
+ string data_format = null);
+
+ public ILayer Permute(int[] dims);
+
+ public ILayer Rescaling(float scale,
+ float offset = 0,
+ Shape input_shape = null);
+
+ public ILayer SimpleRNN(int units,
+ string activation = "tanh",
+ string kernel_initializer = "glorot_uniform",
+ string recurrent_initializer = "orthogonal",
+ string bias_initializer = "zeros");
+
+ public ILayer Subtract();
+ }
+}
diff --git a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs
index 222b10bb..064c7362 100644
--- a/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs
+++ b/src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs
@@ -20,11 +20,11 @@ namespace Tensorflow.NumPy
Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize);
}
- public NDArray rand(params int[] shape)
- => throw new NotImplementedException("");
+ public NDArray random(Shape size)
+ => uniform(low: 0, high: 1, size: size);
[AutoNumPy]
- public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType dtype = TF_DataType.TF_INT32)
+ public NDArray randint(int low, int? high = null, Shape? size = null, TF_DataType dtype = TF_DataType.TF_INT32)
{
if(high == null)
{
@@ -41,11 +41,11 @@ namespace Tensorflow.NumPy
=> new NDArray(random_ops.random_normal(shape ?? Shape.Scalar));
[AutoNumPy]
- public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null)
+ public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape? size = null)
=> new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale));
[AutoNumPy]
- public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape size = null)
+ public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape? size = null)
=> new NDArray(random_ops.random_uniform(size ?? Shape.Scalar, low, high));
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 7c5b21b6..041268b7 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -18,6 +18,7 @@ using System;
using System.Collections.Generic;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;
using Tensorflow.Util;
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 4bd0a490..36449826 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -5,8 +5,8 @@
Tensorflow.Binding
Tensorflow
2.2.0
- 0.70.2
- 9.0
+ 0.100.0
+ 10.0
enable
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
@@ -20,9 +20,9 @@
Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io
- 0.70.1.0
+ 0.100.0.0
- tf.net 0.70.x and above are based on tensorflow native 2.7.0
+ tf.net 0.100.x and above are based on tensorflow native 2.10.0
* Eager Mode is added finally.
* tf.keras is partially working.
@@ -35,14 +35,17 @@ https://tensorflownet.readthedocs.io
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.
tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.
- tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
- 0.70.1.0
+ tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
+ tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
+
+ 0.100.0.0
LICENSE
true
true
Open.snk
AnyCPU;x64
TensorFlow.NET
+ Debug;Release;GPU
@@ -51,6 +54,12 @@ https://tensorflownet.readthedocs.io
AnyCPU
+
+ true
+ TRACE;DEBUG;TRACK_TENSOR_LIFE_1
+ AnyCPU
+
+
true
TRACE;DEBUG;TRACK_TENSOR_LIFE1
@@ -58,6 +67,13 @@ https://tensorflownet.readthedocs.io
TensorFlow.NET.xml
+
+ true
+ TRACE;DEBUG;TRACK_TENSOR_LIFE1
+ x64
+ TensorFlow.NET.xml
+
+
true
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index 8a2c78a7..e02723b7 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -20,6 +20,7 @@ using System.Threading;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Gradients;
+using Tensorflow.Keras;
namespace Tensorflow
{
@@ -51,6 +52,8 @@ namespace Tensorflow
ThreadLocal _runner = new ThreadLocal(() => new EagerRunner());
public IEagerRunner Runner => _runner.Value;
+ public IKerasApi keras { get; set; }
+
public tensorflow()
{
Logger = new LoggerConfiguration()
diff --git a/src/TensorFlowNET.Keras/KerasApi.cs b/src/TensorFlowNET.Keras/KerasApi.cs
index d10ced0c..f79c2b5f 100644
--- a/src/TensorFlowNET.Keras/KerasApi.cs
+++ b/src/TensorFlowNET.Keras/KerasApi.cs
@@ -2,6 +2,9 @@
namespace Tensorflow
{
+ ///
+ /// Deprecated, will use tf.keras
+ ///
public static class KerasApi
{
public static KerasInterface keras { get; } = new KerasInterface();
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index 02362a55..5bf9f97f 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -10,18 +10,17 @@ using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
-using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using System.Threading;
namespace Tensorflow.Keras
{
- public class KerasInterface
+ public class KerasInterface : IKerasApi
{
public KerasDataset datasets { get; } = new KerasDataset();
public Initializers initializers { get; } = new Initializers();
public Regularizers regularizers { get; } = new Regularizers();
- public LayersApi layers { get; } = new LayersApi();
+ public ILayersApi layers { get; } = new LayersApi();
public LossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs
index 0978d0d3..24a56839 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs
@@ -7,16 +7,16 @@ using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Layers {
public partial class LayersApi {
- public ELU ELU ( float alpha = 0.1f )
+ public ILayer ELU ( float alpha = 0.1f )
=> new ELU(new ELUArgs { Alpha = alpha });
- public SELU SELU ()
+ public ILayer SELU ()
=> new SELU(new LayerArgs { });
- public Softmax Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis });
- public Softplus Softplus () => new Softplus(new LayerArgs { });
- public HardSigmoid HardSigmoid () => new HardSigmoid(new LayerArgs { });
- public Softsign Softsign () => new Softsign(new LayerArgs { });
- public Swish Swish () => new Swish(new LayerArgs { });
- public Tanh Tanh () => new Tanh(new LayerArgs { });
- public Exponential Exponential () => new Exponential(new LayerArgs { });
+ public ILayer Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis });
+ public ILayer Softplus () => new Softplus(new LayerArgs { });
+ public ILayer HardSigmoid () => new HardSigmoid(new LayerArgs { });
+ public ILayer Softsign () => new Softsign(new LayerArgs { });
+ public ILayer Swish () => new Swish(new LayerArgs { });
+ public ILayer Tanh () => new Tanh(new LayerArgs { });
+ public ILayer Exponential () => new Exponential(new LayerArgs { });
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
index 5effd175..859e9c14 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers
{
public partial class LayersApi
{
- public Attention Attention(bool use_scale = false,
+ public ILayer Attention(bool use_scale = false,
string score_mode = "dot",
bool causal = false,
float dropout = 0f) =>
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers
causal = causal,
dropout = dropout
});
- public MultiHeadAttention MultiHeadAttention(int num_heads,
+ public ILayer MultiHeadAttention(int num_heads,
int key_dim,
int? value_dim = null,
float dropout = 0f,
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs
index f4d2230c..339ddb85 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers {
/// Cropping layer for 1D input
///
/// cropping size
- public Cropping1D Cropping1D ( NDArray cropping )
+ public ILayer Cropping1D ( NDArray cropping )
=> new Cropping1D(new CroppingArgs {
cropping = cropping
});
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers {
///
/// Cropping layer for 2D input
///
- public Cropping2D Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last )
+ public ILayer Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last )
=> new Cropping2D(new Cropping2DArgs {
cropping = cropping,
data_format = data_format
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers {
///
/// Cropping layer for 3D input
///
- public Cropping3D Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last )
+ public ILayer Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last )
=> new Cropping3D(new Cropping3DArgs {
cropping = cropping,
data_format = data_format
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
index ecf8c0a6..d94bfb4d 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers
///
/// Axis along which to concatenate.
///
- public Concatenate Concatenate(int axis = -1)
+ public ILayer Concatenate(int axis = -1)
=> new Concatenate(new MergeArgs
{
Axis = axis
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
index 5cfec89e..d3db1d66 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
///
///
///
- public ZeroPadding2D ZeroPadding2D ( NDArray padding )
+ public ILayer ZeroPadding2D ( NDArray padding )
=> new ZeroPadding2D(new ZeroPadding2DArgs {
Padding = padding
});
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers {
///
///
///
- public UpSampling2D UpSampling2D ( Shape size = null,
+ public ILayer UpSampling2D ( Shape size = null,
string data_format = null,
string interpolation = "nearest" )
=> new UpSampling2D(new UpSampling2DArgs {
@@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers {
///
/// Permutes the dimensions of the input according to a given pattern.
///
- public Permute Permute ( int[] dims )
+ public ILayer Permute ( int[] dims )
=> new Permute(new PermuteArgs {
dims = dims
});
@@ -44,12 +44,12 @@ namespace Tensorflow.Keras.Layers {
///
///
///
- public Reshape Reshape ( Shape target_shape )
- => new Reshape(new ReshapeArgs {
- TargetShape = target_shape
- });
+ public ILayer Reshape ( Shape target_shape )
+ => new Reshape(new ReshapeArgs {
+ TargetShape = target_shape
+ });
- public Reshape Reshape ( object[] target_shape )
+ public ILayer Reshape ( object[] target_shape )
=> new Reshape(new ReshapeArgs {
TargetShapeObjects = target_shape
});
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 48856735..8498f5ac 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -1,16 +1,18 @@
using System;
-using Tensorflow.NumPy;
-using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Lstm;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers.Lstm;
+using Tensorflow.Keras.Layers.Rnn;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Layers
{
- public partial class LayersApi
+ public partial class LayersApi : ILayersApi
{
- public Preprocessing preprocessing { get; } = new Preprocessing();
+ public IPreprocessing preprocessing { get; } = new Preprocessing();
///
/// Layer that normalizes its inputs.
@@ -38,7 +40,7 @@ namespace Tensorflow.Keras.Layers
/// Note that momentum is still applied to get the means and variances for inference.
///
/// Tensor of the same shape as input.
- public BatchNormalization BatchNormalization(int axis = -1,
+ public ILayer BatchNormalization(int axis = -1,
float momentum = 0.99f,
float epsilon = 0.001f,
bool center = true,
@@ -84,7 +86,7 @@ namespace Tensorflow.Keras.Layers
/// Initializer for the kernel weights matrix (see keras.initializers).
/// Initializer for the bias vector (see keras.initializers).
/// A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias).
- public Conv1D Conv1D(int filters,
+ public ILayer Conv1D(int filters,
Shape kernel_size,
int strides = 1,
string padding = "valid",
@@ -131,7 +133,7 @@ namespace Tensorflow.Keras.Layers
/// Regularizer function applied to the bias vector (see keras.regularizers).
/// Regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).
/// A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).
- public Conv2D Conv2D(int filters,
+ public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
@@ -184,7 +186,7 @@ namespace Tensorflow.Keras.Layers
/// The name of the regularizer function applied to the bias vector (see keras.regularizers).
/// The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).
/// A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).
- public Conv2D Conv2D(int filters,
+ public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
@@ -228,7 +230,7 @@ namespace Tensorflow.Keras.Layers
/// The name of the regularizer function applied to the bias vector (see keras.regularizers).
/// The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).
/// A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).
- public Conv2DTranspose Conv2DTranspose(int filters,
+ public ILayer Conv2DTranspose(int filters,
Shape kernel_size = null,
Shape strides = null,
string output_padding = "valid",
@@ -270,7 +272,7 @@ namespace Tensorflow.Keras.Layers
/// Initializer for the bias vector.
/// N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).
/// N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).
- public Dense Dense(int units,
+ public ILayer Dense(int units,
Activation activation = null,
IInitializer kernel_initializer = null,
bool use_bias = true,
@@ -294,7 +296,7 @@ namespace Tensorflow.Keras.Layers
///
/// Positive integer, dimensionality of the output space.
/// N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).
- public Dense Dense(int units)
+ public ILayer Dense(int units)
=> new Dense(new DenseArgs
{
Units = units,
@@ -312,7 +314,7 @@ namespace Tensorflow.Keras.Layers
/// Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).
/// N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).
/// N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).
- public Dense Dense(int units,
+ public ILayer Dense(int units,
string activation = null,
Shape input_shape = null)
=> new Dense(new DenseArgs
@@ -364,7 +366,7 @@ namespace Tensorflow.Keras.Layers
}
- public EinsumDense EinsumDense(string equation,
+ public ILayer EinsumDense(string equation,
Shape output_shape,
string bias_axes,
Activation activation = null,
@@ -402,7 +404,7 @@ namespace Tensorflow.Keras.Layers
///
/// An integer to use as random seed.
///
- public Dropout Dropout(float rate, Shape noise_shape = null, int? seed = null)
+ public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null)
=> new Dropout(new DropoutArgs
{
Rate = rate,
@@ -421,7 +423,7 @@ namespace Tensorflow.Keras.Layers
/// Initializer for the embeddings matrix (see keras.initializers).
///
///
- public Embedding Embedding(int input_dim,
+ public ILayer Embedding(int input_dim,
int output_dim,
IInitializer embeddings_initializer = null,
bool mask_zero = false,
@@ -446,7 +448,7 @@ namespace Tensorflow.Keras.Layers
/// If you never set it, then it will be "channels_last".
///
///
- public Flatten Flatten(string data_format = null)
+ public ILayer Flatten(string data_format = null)
=> new Flatten(new FlattenArgs
{
DataFormat = data_format
@@ -482,7 +484,7 @@ namespace Tensorflow.Keras.Layers
return input_layer.InboundNodes[0].Outputs;
}
- public InputLayer InputLayer(Shape input_shape,
+ public ILayer InputLayer(Shape input_shape,
string name = null,
bool sparse = false,
bool ragged = false)
@@ -502,7 +504,7 @@ namespace Tensorflow.Keras.Layers
///
///
///
- public AveragePooling2D AveragePooling2D(Shape pool_size = null,
+ public ILayer AveragePooling2D(Shape pool_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null)
@@ -527,7 +529,7 @@ namespace Tensorflow.Keras.Layers
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
///
///
- public MaxPooling1D MaxPooling1D(int? pool_size = null,
+ public ILayer MaxPooling1D(int? pool_size = null,
int? strides = null,
string padding = "valid",
string data_format = null)
@@ -564,7 +566,7 @@ namespace Tensorflow.Keras.Layers
/// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json.
/// If you never set it, then it will be "channels_last"
///
- public MaxPooling2D MaxPooling2D(Shape pool_size = null,
+ public ILayer MaxPooling2D(Shape pool_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null)
@@ -618,7 +620,7 @@ namespace Tensorflow.Keras.Layers
return layer.Apply(inputs);
}
- public Layer LayerNormalization(Axis? axis,
+ public ILayer LayerNormalization(Axis? axis,
float epsilon = 1e-3f,
bool center = true,
bool scale = true,
@@ -638,45 +640,30 @@ namespace Tensorflow.Keras.Layers
///
/// Negative slope coefficient.
///
- public Layer LeakyReLU(float alpha = 0.3f)
+ public ILayer LeakyReLU(float alpha = 0.3f)
=> new LeakyReLu(new LeakyReLuArgs
{
Alpha = alpha
});
- ///
- /// Fully-connected RNN where the output is to be fed back to input.
- ///
- /// Positive integer, dimensionality of the output space.
- ///
- public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh");
-
- ///
- /// Fully-connected RNN where the output is to be fed back to input.
- ///
- /// Positive integer, dimensionality of the output space.
- /// Activation function to use. If you pass null, no activation is applied (ie. "linear" activation: a(x) = x).
- ///
- public Layer SimpleRNN(int units,
- Activation activation = null)
- => new SimpleRNN(new SimpleRNNArgs
- {
- Units = units,
- Activation = activation
- });
-
///
///
///
/// Positive integer, dimensionality of the output space.
/// The name of the activation function to use. Default: hyperbolic tangent (tanh)..
///
- public Layer SimpleRNN(int units,
- string activation = "tanh")
+ public ILayer SimpleRNN(int units,
+ string activation = "tanh",
+ string kernel_initializer = "glorot_uniform",
+ string recurrent_initializer = "orthogonal",
+ string bias_initializer = "zeros")
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
- Activation = GetActivationByName(activation)
+ Activation = GetActivationByName(activation),
+ KernelInitializer = GetInitializerByName(kernel_initializer),
+ RecurrentInitializer= GetInitializerByName(recurrent_initializer),
+ BiasInitializer= GetInitializerByName(bias_initializer)
});
///
@@ -706,7 +693,7 @@ namespace Tensorflow.Keras.Layers
/// although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
///
///
- public Layer LSTM(int units,
+ public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,
bool use_bias = true,
@@ -749,7 +736,7 @@ namespace Tensorflow.Keras.Layers
///
///
///
- public Rescaling Rescaling(float scale,
+ public ILayer Rescaling(float scale,
float offset = 0,
Shape input_shape = null)
=> new Rescaling(new RescalingArgs
@@ -763,21 +750,21 @@ namespace Tensorflow.Keras.Layers
///
///
///
- public Add Add()
+ public ILayer Add()
=> new Add(new MergeArgs { });
///
///
///
///
- public Subtract Subtract()
+ public ILayer Subtract()
=> new Subtract(new MergeArgs { });
///
/// Global max pooling operation for spatial data.
///
///
- public GlobalAveragePooling2D GlobalAveragePooling2D()
+ public ILayer GlobalAveragePooling2D()
=> new GlobalAveragePooling2D(new Pooling2DArgs { });
///
@@ -787,7 +774,7 @@ namespace Tensorflow.Keras.Layers
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
///
///
- public GlobalAveragePooling1D GlobalAveragePooling1D(string data_format = "channels_last")
+ public ILayer GlobalAveragePooling1D(string data_format = "channels_last")
=> new GlobalAveragePooling1D(new Pooling1DArgs { DataFormat = data_format });
///
@@ -796,7 +783,7 @@ namespace Tensorflow.Keras.Layers
/// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs.
/// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).
///
- public GlobalAveragePooling2D GlobalAveragePooling2D(string data_format = "channels_last")
+ public ILayer GlobalAveragePooling2D(string data_format = "channels_last")
=> new GlobalAveragePooling2D(new Pooling2DArgs { DataFormat = data_format });
///
@@ -807,7 +794,7 @@ namespace Tensorflow.Keras.Layers
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
///
///
- public GlobalMaxPooling1D GlobalMaxPooling1D(string data_format = "channels_last")
+ public ILayer GlobalMaxPooling1D(string data_format = "channels_last")
=> new GlobalMaxPooling1D(new Pooling1DArgs { DataFormat = data_format });
///
@@ -816,7 +803,7 @@ namespace Tensorflow.Keras.Layers
/// A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs.
/// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).
///
- public GlobalMaxPooling2D GlobalMaxPooling2D(string data_format = "channels_last")
+ public ILayer GlobalMaxPooling2D(string data_format = "channels_last")
=> new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format });
@@ -848,6 +835,7 @@ namespace Tensorflow.Keras.Layers
"glorot_uniform" => tf.glorot_uniform_initializer,
"zeros" => tf.zeros_initializer,
"ones" => tf.ones_initializer,
+ "orthogonal" => tf.orthogonal_initializer,
_ => tf.glorot_uniform_initializer
};
}
diff --git a/src/TensorFlowNET.Keras/Layers/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs
similarity index 87%
rename from src/TensorFlowNET.Keras/Layers/LSTM.cs
rename to src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs
index 73a2df12..b7d97384 100644
--- a/src/TensorFlowNET.Keras/Layers/LSTM.cs
+++ b/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs
@@ -1,8 +1,9 @@
using System.Linq;
-using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Lstm;
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers.Rnn;
-namespace Tensorflow.Keras.Layers
+namespace Tensorflow.Keras.Layers.Lstm
{
///
/// Long Short-Term Memory layer - Hochreiter 1997.
diff --git a/src/TensorFlowNET.Keras/Layers/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs
similarity index 72%
rename from src/TensorFlowNET.Keras/Layers/LSTMCell.cs
rename to src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs
index dda279a7..3cd35a09 100644
--- a/src/TensorFlowNET.Keras/Layers/LSTMCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs
@@ -1,7 +1,7 @@
-using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Lstm;
using Tensorflow.Keras.Engine;
-namespace Tensorflow.Keras.Layers
+namespace Tensorflow.Keras.Layers.Lstm
{
public class LSTMCell : Layer
{
diff --git a/src/TensorFlowNET.Keras/Layers/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
similarity index 95%
rename from src/TensorFlowNET.Keras/Layers/RNN.cs
rename to src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index 293c27fb..c2b86ae4 100644
--- a/src/TensorFlowNET.Keras/Layers/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -1,10 +1,12 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers.Lstm;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
-namespace Tensorflow.Keras.Layers
+namespace Tensorflow.Keras.Layers.Rnn
{
public class RNN : Layer
{
@@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Layers
private object _states = null;
private object constants_spec = null;
private int _num_constants = 0;
+ protected IVariableV1 kernel;
+ protected IVariableV1 bias;
public RNN(RNNArgs args) : base(PreConstruct(args))
{
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
new file mode 100644
index 00000000..58b700fe
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
@@ -0,0 +1,31 @@
+using System.Data;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
+using Tensorflow.Operations.Activation;
+using static HDF.PInvoke.H5Z;
+using static Tensorflow.ApiDef.Types;
+
+namespace Tensorflow.Keras.Layers.Rnn
+{
+ public class SimpleRNN : RNN
+ {
+ SimpleRNNArgs args;
+ SimpleRNNCell cell;
+ public SimpleRNN(SimpleRNNArgs args) : base(args)
+ {
+ this.args = args;
+ }
+
+ protected override void build(Tensors inputs)
+ {
+ var input_shape = inputs.shape;
+ var input_dim = input_shape[-1];
+
+ kernel = add_weight("kernel", (input_shape[-1], args.Units),
+ initializer: args.KernelInitializer
+ //regularizer = self.kernel_regularizer,
+ //constraint = self.kernel_constraint,
+ //caching_device = default_caching_device,
+ );
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
new file mode 100644
index 00000000..de50c361
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Layers.Rnn
+{
+ public class SimpleRNNCell : Layer
+ {
+ public SimpleRNNCell(SimpleRNNArgs args) : base(args)
+ {
+
+ }
+
+ protected override void build(Tensors inputs)
+ {
+
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
similarity index 98%
rename from src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs
rename to src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
index 2da206ca..eead274a 100644
--- a/src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
@@ -2,9 +2,10 @@
using System.Collections.Generic;
using System.ComponentModel;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
-namespace Tensorflow.Keras.Layers
+namespace Tensorflow.Keras.Layers.Rnn
{
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
{
diff --git a/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs
deleted file mode 100644
index c1fc4afd..00000000
--- a/src/TensorFlowNET.Keras/Layers/SimpleRNN.cs
+++ /dev/null
@@ -1,14 +0,0 @@
-using Tensorflow.Keras.ArgsDefinition;
-
-namespace Tensorflow.Keras.Layers
-{
- public class SimpleRNN : RNN
- {
-
- public SimpleRNN(RNNArgs args) : base(args)
- {
-
- }
-
- }
-}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs
index 5e93f583..0be7f1e6 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras
///
///
///
- public Resizing Resizing(int height, int width, string interpolation = "bilinear")
+ public ILayer Resizing(int height, int width, string interpolation = "bilinear")
=> new Resizing(new ResizingArgs
{
Height = height,
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
index 994a36d6..94fc4a20 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
@@ -5,7 +5,7 @@ using Tensorflow.Keras.Preprocessings;
namespace Tensorflow.Keras
{
- public partial class Preprocessing
+ public partial class Preprocessing : IPreprocessing
{
public Sequence sequence => new Sequence();
public DatasetUtils dataset_utils => new DatasetUtils();
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras
private static TextApi _text = new TextApi();
- public TextVectorization TextVectorization(Func standardize = null,
+ public ILayer TextVectorization(Func standardize = null,
string split = "whitespace",
int max_tokens = -1,
string output_mode = "int",
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 3d448454..0c3eff9f 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -3,11 +3,11 @@
netstandard2.0
Tensorflow.Keras
- 9.0
+ 10.0
enable
Tensorflow.Keras
AnyCPU;x64
- 0.7.0
+ 0.10.0
Haiping Chen
Keras for .NET
Apache 2.0, Haiping Chen 2021
@@ -37,9 +37,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
Git
true
Open.snk
- 0.7.0.0
- 0.7.0.0
+ 0.10.0.0
+ 0.10.0.0
LICENSE
+ Debug;Release;GPU
@@ -47,6 +48,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
false
+
+ DEBUG;TRACE
+ false
+
+
false
@@ -55,6 +61,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
Tensorflow.Keras.xml
+
+ Tensorflow.Keras.xml
+
+
diff --git a/src/TensorFlowNET.Keras/tf.layers.cs b/src/TensorFlowNET.Keras/tf.layers.cs
index 3f5ed01c..da7c2347 100644
--- a/src/TensorFlowNET.Keras/tf.layers.cs
+++ b/src/TensorFlowNET.Keras/tf.layers.cs
@@ -134,7 +134,7 @@ namespace Tensorflow.Keras
///
///
///
- public Tensor max_pooling2d(Tensor inputs,
+ public Tensor MaxPooling2D(Tensor inputs,
int[] pool_size,
int[] strides,
string padding = "valid",
diff --git a/src/python/.vscode/launch.json b/src/python/.vscode/launch.json
new file mode 100644
index 00000000..2b2502c6
--- /dev/null
+++ b/src/python/.vscode/launch.json
@@ -0,0 +1,16 @@
+{
+ // Use IntelliSense to learn about possible attributes.
+ // Hover to view descriptions of existing attributes.
+ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+ "version": "0.2.0",
+ "configurations": [
+ {
+ "name": "Python: Current File",
+ "type": "python",
+ "request": "launch",
+ "program": "${file}",
+ "console": "integratedTerminal",
+ "justMyCode": false
+ }
+ ]
+}
\ No newline at end of file
diff --git a/src/python/simple_rnn.py b/src/python/simple_rnn.py
new file mode 100644
index 00000000..97f9f3f3
--- /dev/null
+++ b/src/python/simple_rnn.py
@@ -0,0 +1,15 @@
+import numpy as np
+import tensorflow as tf
+
+# tf.experimental.numpy
+inputs = np.random.random([32, 10, 8]).astype(np.float32)
+simple_rnn = tf.keras.layers.SimpleRNN(4)
+
+output = simple_rnn(inputs) # The output has shape `[32, 4]`.
+
+simple_rnn = tf.keras.layers.SimpleRNN(
+ 4, return_sequences=True, return_state=True)
+
+# whole_sequence_output has shape `[32, 10, 4]`.
+# final_state has shape `[32, 4]`.
+whole_sequence_output, final_state = simple_rnn(inputs)
\ No newline at end of file
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
index 0c02b5db..02298ce8 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
@@ -83,7 +83,7 @@ namespace TensorFlowNET.Keras.UnitTest
{ 2.5f, 2.6f, 2.7f, 2.8f },
{ 3.5f, 3.6f, 3.7f, 3.8f }
} }, dtype: np.float32);
- var attention_layer = keras.layers.Attention();
+ var attention_layer = (Attention)keras.layers.Attention();
//attention_layer.build(((1, 2, 4), (1, 3, 4)));
var actual = attention_layer._calculate_scores(query: q, key: k);
// Expected tensor of shape [1, 2, 3].
@@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest
{ 2.5f, 2.6f, 2.7f, 2.8f },
{ 3.5f, 3.6f, 3.7f, 3.8f }
} }, dtype: np.float32);
- var attention_layer = keras.layers.Attention(score_mode: "concat");
+ var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat");
//attention_layer.concat_score_weight = 1;
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() {
Name = "concat_score_weight",
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
index 53a13394..f4fdf94a 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
@@ -148,10 +148,9 @@ namespace TensorFlowNET.Keras.UnitTest
}
[TestMethod]
- [Ignore]
public void SimpleRNN()
{
- var inputs = np.random.rand(32, 10, 8).astype(np.float32);
+ var inputs = np.random.random((32, 10, 8)).astype(np.float32);
var simple_rnn = keras.layers.SimpleRNN(4);
var output = simple_rnn.Apply(inputs);
Assert.AreEqual((32, 4), output.shape);
diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
index 6d0b1ca3..fc693b1e 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
+++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
@@ -4,7 +4,7 @@
net6.0
false
-
+ 11.0
AnyCPU;x64
diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj
index ffb583c9..36ff4a3d 100644
--- a/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj
@@ -11,7 +11,7 @@
Open.snk
- 9.0
+ 11.0
AnyCPU;x64