Browse Source

Define Keras interface in core project (WIP).

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
33707231ee
49 changed files with 610 additions and 205 deletions
  1. +2
    -2
      src/SciSharp.TensorFlow.Redist/README.md
  2. +3
    -0
      src/TensorFlowNET.Console/Program.cs
  3. +31
    -0
      src/TensorFlowNET.Console/SimpleRnnTest.cs
  4. +2
    -2
      src/TensorFlowNET.Console/Tensorflow.Console.csproj
  5. +0
    -22
      src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs
  6. +12
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs
  8. +0
    -21
      src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs
  9. +45
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  10. +7
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
  12. +0
    -30
      src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs
  13. +12
    -0
      src/TensorFlowNET.Core/Keras/IKerasApi.cs
  14. +16
    -0
      src/TensorFlowNET.Core/Keras/IPreprocessing.cs
  15. +20
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs
  16. +28
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs
  17. +13
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs
  18. +10
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs
  19. +18
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs
  20. +169
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  21. +5
    -5
      src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs
  22. +1
    -0
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  23. +22
    -6
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  24. +3
    -0
      src/TensorFlowNET.Core/tensorflow.cs
  25. +3
    -0
      src/TensorFlowNET.Keras/KerasApi.cs
  26. +2
    -3
      src/TensorFlowNET.Keras/KerasInterface.cs
  27. +9
    -9
      src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs
  28. +2
    -2
      src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs
  29. +3
    -3
      src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs
  30. +1
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs
  31. +8
    -8
      src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
  32. +43
    -55
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  33. +3
    -2
      src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs
  34. +2
    -2
      src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs
  35. +5
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  36. +31
    -0
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  37. +21
    -0
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  38. +2
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  39. +0
    -14
      src/TensorFlowNET.Keras/Layers/SimpleRNN.cs
  40. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs
  41. +2
    -2
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs
  42. +14
    -4
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  43. +1
    -1
      src/TensorFlowNET.Keras/tf.layers.cs
  44. +16
    -0
      src/python/.vscode/launch.json
  45. +15
    -0
      src/python/simple_rnn.py
  46. +2
    -2
      test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
  47. +1
    -2
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs
  48. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
  49. +1
    -1
      test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj

+ 2
- 2
src/SciSharp.TensorFlow.Redist/README.md View File

@@ -26,7 +26,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5

#### Download pre-build package

[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.10.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.10.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.10.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.10.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.10.0.zip)



@@ -35,6 +35,6 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.

1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.10.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`



+ 3
- 0
src/TensorFlowNET.Console/Program.cs View File

@@ -10,6 +10,9 @@ namespace Tensorflow
var diag = new Diagnostician();
// diag.Diagnose(@"D:\memory.txt");

var rnn = new SimpleRnnTest();
rnn.Run();

// this class is used explor new features.
var exploring = new Exploring();
// exploring.Run();


+ 31
- 0
src/TensorFlowNET.Console/SimpleRnnTest.cs View File

@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow
{
public class SimpleRnnTest
{
public void Run()
{
tf.keras = new KerasInterface();
var inputs = np.random.random((32, 10, 8)).astype(np.float32);
var simple_rnn = tf.keras.layers.SimpleRNN(4);
var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
if (output.shape == (32, 4))
{

}
/*simple_rnn = tf.keras.layers.SimpleRNN(
4, return_sequences = True, return_state = True)

# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = simple_rnn(inputs)*/
}
}
}

+ 2
- 2
src/TensorFlowNET.Console/Tensorflow.Console.csproj View File

@@ -6,7 +6,7 @@
<RootNamespace>Tensorflow</RootNamespace>
<AssemblyName>Tensorflow</AssemblyName>
<Platforms>AnyCPU;x64</Platforms>
<LangVersion>9.0</LangVersion>
<LangVersion>11.0</LangVersion>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -20,7 +20,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.7.0" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.10.0" />
</ItemGroup>

<ItemGroup>


+ 0
- 22
src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMArgs.cs View File

@@ -1,22 +0,0 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class LSTMArgs : RNNArgs
{
public int Units { get; set; }
public Activation Activation { get; set; }
public Activation RecurrentActivation { get; set; }
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }
public bool UnitForgetBias { get; set; }
public float Dropout { get; set; }
public float RecurrentDropout { get; set; }
public int Implementation { get; set; }
public bool ReturnSequences { get; set; }
public bool ReturnState { get; set; }
public bool GoBackwards { get; set; }
public bool Stateful { get; set; }
public bool TimeMajor { get; set; }
public bool Unroll { get; set; }
}
}

+ 12
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs View File

@@ -0,0 +1,12 @@
using Tensorflow.Keras.ArgsDefinition.Rnn;

namespace Tensorflow.Keras.ArgsDefinition.Lstm
{
public class LSTMArgs : RNNArgs
{
public bool UnitForgetBias { get; set; }
public float Dropout { get; set; }
public float RecurrentDropout { get; set; }
public int Implementation { get; set; }
}
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/LSTMCellArgs.cs → src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs View File

@@ -1,4 +1,4 @@
namespace Tensorflow.Keras.ArgsDefinition
namespace Tensorflow.Keras.ArgsDefinition.Lstm
{
public class LSTMCellArgs : LayerArgs
{

+ 0
- 21
src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs View File

@@ -1,21 +0,0 @@
using System.Collections.Generic;

namespace Tensorflow.Keras.ArgsDefinition
{
public class RNNArgs : LayerArgs
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}

public IRnnArgCell Cell { get; set; } = null;
public bool ReturnSequences { get; set; } = false;
public bool ReturnState { get; set; } = false;
public bool GoBackwards { get; set; } = false;
public bool Stateful { get; set; } = false;
public bool Unroll { get; set; } = false;
public bool TimeMajor { get; set; } = false;
public Dictionary<string, object> Kwargs { get; set; } = null;
}
}

+ 45
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -0,0 +1,45 @@
using System.Collections.Generic;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class RNNArgs : LayerArgs
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}

public IRnnArgCell Cell { get; set; } = null;
public bool ReturnSequences { get; set; } = false;
public bool ReturnState { get; set; } = false;
public bool GoBackwards { get; set; } = false;
public bool Stateful { get; set; } = false;
public bool Unroll { get; set; } = false;
public bool TimeMajor { get; set; } = false;
public Dictionary<string, object> Kwargs { get; set; } = null;

public int Units { get; set; }
public Activation Activation { get; set; }
public Activation RecurrentActivation { get; set; }
public bool UseBias { get; set; } = true;
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }

// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
}
}

+ 7
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs View File

@@ -0,0 +1,7 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class SimpleRNNArgs : RNNArgs
{

}
}

src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs → src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs View File

@@ -1,6 +1,6 @@
using System.Collections.Generic;

namespace Tensorflow.Keras.ArgsDefinition
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class StackedRNNCellsArgs : LayerArgs
{

+ 0
- 30
src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs View File

@@ -1,30 +0,0 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class SimpleRNNArgs : RNNArgs
{
public int Units { get; set; }
public Activation Activation { get; set; }
// units,
// activation='tanh',
// use_bias=True,
// kernel_initializer='glorot_uniform',
// recurrent_initializer='orthogonal',
// bias_initializer='zeros',
// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
}
}

+ 12
- 0
src/TensorFlowNET.Core/Keras/IKerasApi.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras
{
public interface IKerasApi
{
public ILayersApi layers { get; }
}
}

+ 16
- 0
src/TensorFlowNET.Core/Keras/IPreprocessing.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
public interface IPreprocessing
{
public ILayer Resizing(int height, int width, string interpolation = "bilinear");
public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null,
string split = "whitespace",
int max_tokens = -1,
string output_mode = "int",
int output_sequence_length = -1);
}
}

+ 20
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Activation.cs View File

@@ -0,0 +1,20 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;
using Tensorflow.Operations.Activation;

namespace Tensorflow.Keras.Layers
{
public partial interface ILayersApi
{
public ILayer ELU(float alpha = 0.1f);
public ILayer SELU();
public ILayer Softmax(Axis axis);
public ILayer Softplus();
public ILayer HardSigmoid();
public ILayer Softsign();
public ILayer Swish();
public ILayer Tanh();
public ILayer Exponential();
}
}

+ 28
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Attention.cs View File

@@ -0,0 +1,28 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.Layers
{
public partial interface ILayersApi
{
public ILayer Attention(bool use_scale = false,
string score_mode = "dot",
bool causal = false,
float dropout = 0f);
public ILayer MultiHeadAttention(int num_heads,
int key_dim,
int? value_dim = null,
float dropout = 0f,
bool use_bias = true,
Shape output_shape = null,
Shape attention_axes = null,
IInitializer kernel_initializer = null,
IInitializer bias_initializer = null,
IRegularizer kernel_regularizer = null,
IRegularizer bias_regularizer = null,
IRegularizer activity_regularizer = null,
Action kernel_constraint = null,
Action bias_constraint = null);
}
}

+ 13
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs View File

@@ -0,0 +1,13 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.Layers
{
public partial interface ILayersApi
{
public ILayer Cropping1D(NDArray cropping);
public ILayer Cropping2D(NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last);
public ILayer Cropping3D(NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last);
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Merging.cs View File

@@ -0,0 +1,10 @@
using System;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.Layers
{
public partial interface ILayersApi
{
public ILayer Concatenate(int axis = -1);
}
}

+ 18
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Reshaping.cs View File

@@ -0,0 +1,18 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.Layers
{
public partial interface ILayersApi
{
public ILayer Reshape(Shape target_shape);
public ILayer Reshape(object[] target_shape);

public ILayer UpSampling2D(Shape size = null,
string data_format = null,
string interpolation = "nearest");

public ILayer ZeroPadding2D(NDArray padding);
}
}

+ 169
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -0,0 +1,169 @@
using System;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

namespace Tensorflow.Keras.Layers
{
public partial interface ILayersApi
{
public IPreprocessing preprocessing { get; }

public ILayer Add();

public ILayer AveragePooling2D(Shape pool_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null);

public ILayer BatchNormalization(int axis = -1,
float momentum = 0.99f,
float epsilon = 0.001f,
bool center = true,
bool scale = true,
IInitializer beta_initializer = null,
IInitializer gamma_initializer = null,
IInitializer moving_mean_initializer = null,
IInitializer moving_variance_initializer = null,
bool trainable = true,
string name = null,
bool renorm = false,
float renorm_momentum = 0.99f);

public ILayer Conv1D(int filters,
Shape kernel_size,
int strides = 1,
string padding = "valid",
string data_format = "channels_last",
int dilation_rate = 1,
int groups = 1,
string activation = null,
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null,
Shape dilation_rate = null,
int groups = 1,
Activation activation = null,
bool use_bias = true,
IInitializer kernel_initializer = null,
IInitializer bias_initializer = null,
IRegularizer kernel_regularizer = null,
IRegularizer bias_regularizer = null,
IRegularizer activity_regularizer = null);

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null,
Shape dilation_rate = null,
int groups = 1,
string activation = null,
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");

public ILayer Dense(int units);
public ILayer Dense(int units,
string activation = null,
Shape input_shape = null);
public ILayer Dense(int units,
Activation activation = null,
IInitializer kernel_initializer = null,
bool use_bias = true,
IInitializer bias_initializer = null,
Shape input_shape = null);

public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null);

public ILayer Embedding(int input_dim,
int output_dim,
IInitializer embeddings_initializer = null,
bool mask_zero = false,
Shape input_shape = null,
int input_length = -1);

public ILayer EinsumDense(string equation,
Shape output_shape,
string bias_axes,
Activation activation = null,
IInitializer kernel_initializer = null,
IInitializer bias_initializer = null,
IRegularizer kernel_regularizer = null,
IRegularizer bias_regularizer = null,
IRegularizer activity_regularizer = null,
Action kernel_constraint = null,
Action bias_constraint = null);

public ILayer Flatten(string data_format = null);

public ILayer GlobalAveragePooling1D(string data_format = "channels_last");
public ILayer GlobalAveragePooling2D();
public ILayer GlobalAveragePooling2D(string data_format = "channels_last");
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");

public Tensors Input(Shape shape,
string name = null,
bool sparse = false,
bool ragged = false);
public ILayer InputLayer(Shape input_shape,
string name = null,
bool sparse = false,
bool ragged = false);

public ILayer LayerNormalization(Axis? axis,
float epsilon = 1e-3f,
bool center = true,
bool scale = true,
IInitializer beta_initializer = null,
IInitializer gamma_initializer = null);

public ILayer LeakyReLU(float alpha = 0.3f);

public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,
bool use_bias = true,
IInitializer kernel_initializer = null,
IInitializer recurrent_initializer = null,
IInitializer bias_initializer = null,
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
int implementation = 2,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool time_major = false,
bool unroll = false);

public ILayer MaxPooling1D(int? pool_size = null,
int? strides = null,
string padding = "valid",
string data_format = null);
public ILayer MaxPooling2D(Shape pool_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null);

public ILayer Permute(int[] dims);

public ILayer Rescaling(float scale,
float offset = 0,
Shape input_shape = null);

public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros");

public ILayer Subtract();
}
}

+ 5
- 5
src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs View File

@@ -20,11 +20,11 @@ namespace Tensorflow.NumPy
Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize);
}

public NDArray rand(params int[] shape)
=> throw new NotImplementedException("");
public NDArray random(Shape size)
=> uniform(low: 0, high: 1, size: size);

[AutoNumPy]
public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType dtype = TF_DataType.TF_INT32)
public NDArray randint(int low, int? high = null, Shape? size = null, TF_DataType dtype = TF_DataType.TF_INT32)
{
if(high == null)
{
@@ -41,11 +41,11 @@ namespace Tensorflow.NumPy
=> new NDArray(random_ops.random_normal(shape ?? Shape.Scalar));

[AutoNumPy]
public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null)
public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape? size = null)
=> new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale));

[AutoNumPy]
public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape size = null)
public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape? size = null)
=> new NDArray(random_ops.random_uniform(size ?? Shape.Scalar, low, high));
}
}

+ 1
- 0
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -18,6 +18,7 @@ using System;
using System.Collections.Generic;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;
using Tensorflow.Util;


+ 22
- 6
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,8 +5,8 @@
<AssemblyName>Tensorflow.Binding</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.70.2</Version>
<LangVersion>9.0</LangVersion>
<Version>0.100.0</Version>
<LangVersion>10.0</LangVersion>
<Nullable>enable</Nullable>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
@@ -20,9 +20,9 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.70.1.0</AssemblyVersion>
<AssemblyVersion>0.100.0.0</AssemblyVersion>
<PackageReleaseNotes>
tf.net 0.70.x and above are based on tensorflow native 2.7.0
tf.net 0.100.x and above are based on tensorflow native 2.10.0

* Eager Mode is added finally.
* tf.keras is partially working.
@@ -35,14 +35,17 @@ https://tensorflownet.readthedocs.io</Description>

tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.
tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.</PackageReleaseNotes>
<FileVersion>0.70.1.0</FileVersion>
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
</PackageReleaseNotes>
<FileVersion>0.100.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<Platforms>AnyCPU;x64</Platforms>
<PackageId>TensorFlow.NET</PackageId>
<Configurations>Debug;Release;GPU</Configurations>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -51,6 +54,12 @@ https://tensorflownet.readthedocs.io</Description>
<PlatformTarget>AnyCPU</PlatformTarget>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE_1</DefineConstants>
<PlatformTarget>AnyCPU</PlatformTarget>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants>
@@ -58,6 +67,13 @@ https://tensorflownet.readthedocs.io</Description>
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants>
<PlatformTarget>x64</PlatformTarget>
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>


+ 3
- 0
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -20,6 +20,7 @@ using System.Threading;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Gradients;
using Tensorflow.Keras;

namespace Tensorflow
{
@@ -51,6 +52,8 @@ namespace Tensorflow
ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner());
public IEagerRunner Runner => _runner.Value;

public IKerasApi keras { get; set; }

public tensorflow()
{
Logger = new LoggerConfiguration()


+ 3
- 0
src/TensorFlowNET.Keras/KerasApi.cs View File

@@ -2,6 +2,9 @@

namespace Tensorflow
{
/// <summary>
/// Deprecated, will use tf.keras
/// </summary>
public static class KerasApi
{
public static KerasInterface keras { get; } = new KerasInterface();


+ 2
- 3
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -10,18 +10,17 @@ using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using System.Threading;

namespace Tensorflow.Keras
{
public class KerasInterface
public class KerasInterface : IKerasApi
{
public KerasDataset datasets { get; } = new KerasDataset();
public Initializers initializers { get; } = new Initializers();
public Regularizers regularizers { get; } = new Regularizers();
public LayersApi layers { get; } = new LayersApi();
public ILayersApi layers { get; } = new LayersApi();
public LossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();


+ 9
- 9
src/TensorFlowNET.Keras/Layers/LayersApi.Activation.cs View File

@@ -7,16 +7,16 @@ using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers {
public partial class LayersApi {
public ELU ELU ( float alpha = 0.1f )
public ILayer ELU ( float alpha = 0.1f )
=> new ELU(new ELUArgs { Alpha = alpha });
public SELU SELU ()
public ILayer SELU ()
=> new SELU(new LayerArgs { });
public Softmax Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis });
public Softplus Softplus () => new Softplus(new LayerArgs { });
public HardSigmoid HardSigmoid () => new HardSigmoid(new LayerArgs { });
public Softsign Softsign () => new Softsign(new LayerArgs { });
public Swish Swish () => new Swish(new LayerArgs { });
public Tanh Tanh () => new Tanh(new LayerArgs { });
public Exponential Exponential () => new Exponential(new LayerArgs { });
public ILayer Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis });
public ILayer Softplus () => new Softplus(new LayerArgs { });
public ILayer HardSigmoid () => new HardSigmoid(new LayerArgs { });
public ILayer Softsign () => new Softsign(new LayerArgs { });
public ILayer Swish () => new Swish(new LayerArgs { });
public ILayer Tanh () => new Tanh(new LayerArgs { });
public ILayer Exponential () => new Exponential(new LayerArgs { });
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Layers/LayersApi.Attention.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers
{
public partial class LayersApi
{
public Attention Attention(bool use_scale = false,
public ILayer Attention(bool use_scale = false,
string score_mode = "dot",
bool causal = false,
float dropout = 0f) =>
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers
causal = causal,
dropout = dropout
});
public MultiHeadAttention MultiHeadAttention(int num_heads,
public ILayer MultiHeadAttention(int num_heads,
int key_dim,
int? value_dim = null,
float dropout = 0f,


+ 3
- 3
src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers {
/// Cropping layer for 1D input
/// </summary>
/// <param name="cropping">cropping size</param>
public Cropping1D Cropping1D ( NDArray cropping )
public ILayer Cropping1D ( NDArray cropping )
=> new Cropping1D(new CroppingArgs {
cropping = cropping
});
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers {
/// <summary>
/// Cropping layer for 2D input <br/>
/// </summary>
public Cropping2D Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last )
public ILayer Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last )
=> new Cropping2D(new Cropping2DArgs {
cropping = cropping,
data_format = data_format
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers {
/// <summary>
/// Cropping layer for 3D input <br/>
/// </summary>
public Cropping3D Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last )
public ILayer Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last )
=> new Cropping3D(new Cropping3DArgs {
cropping = cropping,
data_format = data_format


+ 1
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.Merging.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers
/// </summary>
/// <param name="axis">Axis along which to concatenate.</param>
/// <returns></returns>
public Concatenate Concatenate(int axis = -1)
public ILayer Concatenate(int axis = -1)
=> new Concatenate(new MergeArgs
{
Axis = axis


+ 8
- 8
src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
/// </summary>
/// <param name="padding"></param>
/// <returns></returns>
public ZeroPadding2D ZeroPadding2D ( NDArray padding )
public ILayer ZeroPadding2D ( NDArray padding )
=> new ZeroPadding2D(new ZeroPadding2DArgs {
Padding = padding
});
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers {
/// <param name="data_format"></param>
/// <param name="interpolation"></param>
/// <returns></returns>
public UpSampling2D UpSampling2D ( Shape size = null,
public ILayer UpSampling2D ( Shape size = null,
string data_format = null,
string interpolation = "nearest" )
=> new UpSampling2D(new UpSampling2DArgs {
@@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers {
/// <summary>
/// Permutes the dimensions of the input according to a given pattern.
/// </summary>
public Permute Permute ( int[] dims )
public ILayer Permute ( int[] dims )
=> new Permute(new PermuteArgs {
dims = dims
});
@@ -44,12 +44,12 @@ namespace Tensorflow.Keras.Layers {
/// </summary>
/// <param name="target_shape"></param>
/// <returns></returns>
public Reshape Reshape ( Shape target_shape )
=> new Reshape(new ReshapeArgs {
TargetShape = target_shape
});
public ILayer Reshape ( Shape target_shape )
=> new Reshape(new ReshapeArgs {
TargetShape = target_shape
});

public Reshape Reshape ( object[] target_shape )
public ILayer Reshape ( object[] target_shape )
=> new Reshape(new ReshapeArgs {
TargetShapeObjects = target_shape
});


+ 43
- 55
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -1,16 +1,18 @@
using System;
using Tensorflow.NumPy;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Lstm;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Lstm;
using Tensorflow.Keras.Layers.Rnn;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{
public partial class LayersApi
public partial class LayersApi : ILayersApi
{
public Preprocessing preprocessing { get; } = new Preprocessing();
public IPreprocessing preprocessing { get; } = new Preprocessing();

/// <summary>
/// Layer that normalizes its inputs.
@@ -38,7 +40,7 @@ namespace Tensorflow.Keras.Layers
/// Note that momentum is still applied to get the means and variances for inference.
/// </param>
/// <returns>Tensor of the same shape as input.</returns>
public BatchNormalization BatchNormalization(int axis = -1,
public ILayer BatchNormalization(int axis = -1,
float momentum = 0.99f,
float epsilon = 0.001f,
bool center = true,
@@ -84,7 +86,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="kernel_initializer">Initializer for the kernel weights matrix (see keras.initializers).</param>
/// <param name="bias_initializer">Initializer for the bias vector (see keras.initializers).</param>
/// <returns>A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias).</returns>
public Conv1D Conv1D(int filters,
public ILayer Conv1D(int filters,
Shape kernel_size,
int strides = 1,
string padding = "valid",
@@ -131,7 +133,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="bias_regularizer">Regularizer function applied to the bias vector (see keras.regularizers).</param>
/// <param name="activity_regularizer">Regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param>
/// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns>
public Conv2D Conv2D(int filters,
public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
@@ -184,7 +186,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param>
/// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param>
/// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns>
public Conv2D Conv2D(int filters,
public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
@@ -228,7 +230,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param>
/// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param>
/// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns>
public Conv2DTranspose Conv2DTranspose(int filters,
public ILayer Conv2DTranspose(int filters,
Shape kernel_size = null,
Shape strides = null,
string output_padding = "valid",
@@ -270,7 +272,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="bias_initializer">Initializer for the bias vector.</param>
/// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param>
/// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns>
public Dense Dense(int units,
public ILayer Dense(int units,
Activation activation = null,
IInitializer kernel_initializer = null,
bool use_bias = true,
@@ -294,7 +296,7 @@ namespace Tensorflow.Keras.Layers
/// </summary>
/// <param name="units">Positive integer, dimensionality of the output space.</param>
/// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns>
public Dense Dense(int units)
public ILayer Dense(int units)
=> new Dense(new DenseArgs
{
Units = units,
@@ -312,7 +314,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="activation">Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).</param>
/// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param>
/// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns>
public Dense Dense(int units,
public ILayer Dense(int units,
string activation = null,
Shape input_shape = null)
=> new Dense(new DenseArgs
@@ -364,7 +366,7 @@ namespace Tensorflow.Keras.Layers
}


public EinsumDense EinsumDense(string equation,
public ILayer EinsumDense(string equation,
Shape output_shape,
string bias_axes,
Activation activation = null,
@@ -402,7 +404,7 @@ namespace Tensorflow.Keras.Layers
/// </param>
/// <param name="seed">An integer to use as random seed.</param>
/// <returns></returns>
public Dropout Dropout(float rate, Shape noise_shape = null, int? seed = null)
public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null)
=> new Dropout(new DropoutArgs
{
Rate = rate,
@@ -421,7 +423,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="embeddings_initializer">Initializer for the embeddings matrix (see keras.initializers).</param>
/// <param name="mask_zero"></param>
/// <returns></returns>
public Embedding Embedding(int input_dim,
public ILayer Embedding(int input_dim,
int output_dim,
IInitializer embeddings_initializer = null,
bool mask_zero = false,
@@ -446,7 +448,7 @@ namespace Tensorflow.Keras.Layers
/// If you never set it, then it will be "channels_last".
/// </param>
/// <returns></returns>
public Flatten Flatten(string data_format = null)
public ILayer Flatten(string data_format = null)
=> new Flatten(new FlattenArgs
{
DataFormat = data_format
@@ -482,7 +484,7 @@ namespace Tensorflow.Keras.Layers
return input_layer.InboundNodes[0].Outputs;
}

public InputLayer InputLayer(Shape input_shape,
public ILayer InputLayer(Shape input_shape,
string name = null,
bool sparse = false,
bool ragged = false)
@@ -502,7 +504,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="padding"></param>
/// <param name="data_format"></param>
/// <returns></returns>
public AveragePooling2D AveragePooling2D(Shape pool_size = null,
public ILayer AveragePooling2D(Shape pool_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null)
@@ -527,7 +529,7 @@ namespace Tensorflow.Keras.Layers
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
/// </param>
/// <returns></returns>
public MaxPooling1D MaxPooling1D(int? pool_size = null,
public ILayer MaxPooling1D(int? pool_size = null,
int? strides = null,
string padding = "valid",
string data_format = null)
@@ -564,7 +566,7 @@ namespace Tensorflow.Keras.Layers
/// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json.
/// If you never set it, then it will be "channels_last"</param>
/// <returns></returns>
public MaxPooling2D MaxPooling2D(Shape pool_size = null,
public ILayer MaxPooling2D(Shape pool_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null)
@@ -618,7 +620,7 @@ namespace Tensorflow.Keras.Layers
return layer.Apply(inputs);
}

public Layer LayerNormalization(Axis? axis,
public ILayer LayerNormalization(Axis? axis,
float epsilon = 1e-3f,
bool center = true,
bool scale = true,
@@ -638,45 +640,30 @@ namespace Tensorflow.Keras.Layers
/// </summary>
/// <param name="alpha">Negative slope coefficient.</param>
/// <returns></returns>
public Layer LeakyReLU(float alpha = 0.3f)
public ILayer LeakyReLU(float alpha = 0.3f)
=> new LeakyReLu(new LeakyReLuArgs
{
Alpha = alpha
});

/// <summary>
/// Fully-connected RNN where the output is to be fed back to input.
/// </summary>
/// <param name="units">Positive integer, dimensionality of the output space.</param>
/// <returns></returns>
public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh");

/// <summary>
/// Fully-connected RNN where the output is to be fed back to input.
/// </summary>
/// <param name="units">Positive integer, dimensionality of the output space.</param>
/// <param name="activation">Activation function to use. If you pass null, no activation is applied (ie. "linear" activation: a(x) = x).</param>
/// <returns></returns>
public Layer SimpleRNN(int units,
Activation activation = null)
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
Activation = activation
});

/// <summary>
///
/// </summary>
/// <param name="units">Positive integer, dimensionality of the output space.</param>
/// <param name="activation">The name of the activation function to use. Default: hyperbolic tangent (tanh)..</param>
/// <returns></returns>
public Layer SimpleRNN(int units,
string activation = "tanh")
public ILayer SimpleRNN(int units,
string activation = "tanh",
string kernel_initializer = "glorot_uniform",
string recurrent_initializer = "orthogonal",
string bias_initializer = "zeros")
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
Activation = GetActivationByName(activation)
Activation = GetActivationByName(activation),
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer= GetInitializerByName(recurrent_initializer),
BiasInitializer= GetInitializerByName(bias_initializer)
});

/// <summary>
@@ -706,7 +693,7 @@ namespace Tensorflow.Keras.Layers
/// although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
/// </param>
/// <returns></returns>
public Layer LSTM(int units,
public ILayer LSTM(int units,
Activation activation = null,
Activation recurrent_activation = null,
bool use_bias = true,
@@ -749,7 +736,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="offset"></param>
/// <param name="input_shape"></param>
/// <returns></returns>
public Rescaling Rescaling(float scale,
public ILayer Rescaling(float scale,
float offset = 0,
Shape input_shape = null)
=> new Rescaling(new RescalingArgs
@@ -763,21 +750,21 @@ namespace Tensorflow.Keras.Layers
///
/// </summary>
/// <returns></returns>
public Add Add()
public ILayer Add()
=> new Add(new MergeArgs { });

/// <summary>
///
/// </summary>
/// <returns></returns>
public Subtract Subtract()
public ILayer Subtract()
=> new Subtract(new MergeArgs { });

/// <summary>
/// Global max pooling operation for spatial data.
/// </summary>
/// <returns></returns>
public GlobalAveragePooling2D GlobalAveragePooling2D()
public ILayer GlobalAveragePooling2D()
=> new GlobalAveragePooling2D(new Pooling2DArgs { });

/// <summary>
@@ -787,7 +774,7 @@ namespace Tensorflow.Keras.Layers
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
/// </param>
/// <returns></returns>
public GlobalAveragePooling1D GlobalAveragePooling1D(string data_format = "channels_last")
public ILayer GlobalAveragePooling1D(string data_format = "channels_last")
=> new GlobalAveragePooling1D(new Pooling1DArgs { DataFormat = data_format });

/// <summary>
@@ -796,7 +783,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs.
/// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param>
/// <returns></returns>
public GlobalAveragePooling2D GlobalAveragePooling2D(string data_format = "channels_last")
public ILayer GlobalAveragePooling2D(string data_format = "channels_last")
=> new GlobalAveragePooling2D(new Pooling2DArgs { DataFormat = data_format });

/// <summary>
@@ -807,7 +794,7 @@ namespace Tensorflow.Keras.Layers
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
/// </param>
/// <returns></returns>
public GlobalMaxPooling1D GlobalMaxPooling1D(string data_format = "channels_last")
public ILayer GlobalMaxPooling1D(string data_format = "channels_last")
=> new GlobalMaxPooling1D(new Pooling1DArgs { DataFormat = data_format });

/// <summary>
@@ -816,7 +803,7 @@ namespace Tensorflow.Keras.Layers
/// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs.
/// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param>
/// <returns></returns>
public GlobalMaxPooling2D GlobalMaxPooling2D(string data_format = "channels_last")
public ILayer GlobalMaxPooling2D(string data_format = "channels_last")
=> new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format });


@@ -848,6 +835,7 @@ namespace Tensorflow.Keras.Layers
"glorot_uniform" => tf.glorot_uniform_initializer,
"zeros" => tf.zeros_initializer,
"ones" => tf.ones_initializer,
"orthogonal" => tf.orthogonal_initializer,
_ => tf.glorot_uniform_initializer
};
}


src/TensorFlowNET.Keras/Layers/LSTM.cs → src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs View File

@@ -1,8 +1,9 @@
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Lstm;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;

namespace Tensorflow.Keras.Layers
namespace Tensorflow.Keras.Layers.Lstm
{
/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.

src/TensorFlowNET.Keras/Layers/LSTMCell.cs → src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs View File

@@ -1,7 +1,7 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Lstm;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
namespace Tensorflow.Keras.Layers.Lstm
{
public class LSTMCell : Layer
{

src/TensorFlowNET.Keras/Layers/RNN.cs → src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -1,10 +1,12 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Lstm;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers
namespace Tensorflow.Keras.Layers.Rnn
{
public class RNN : Layer
{
@@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Layers
private object _states = null;
private object constants_spec = null;
private int _num_constants = 0;
protected IVariableV1 kernel;
protected IVariableV1 bias;

public RNN(RNNArgs args) : base(PreConstruct(args))
{

+ 31
- 0
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -0,0 +1,31 @@
using System.Data;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Operations.Activation;
using static HDF.PInvoke.H5Z;
using static Tensorflow.ApiDef.Types;

namespace Tensorflow.Keras.Layers.Rnn
{
public class SimpleRNN : RNN
{
SimpleRNNArgs args;
SimpleRNNCell cell;
public SimpleRNN(SimpleRNNArgs args) : base(args)
{
this.args = args;
}

protected override void build(Tensors inputs)
{
var input_shape = inputs.shape;
var input_dim = input_shape[-1];

kernel = add_weight("kernel", (input_shape[-1], args.Units),
initializer: args.KernelInitializer
//regularizer = self.kernel_regularizer,
//constraint = self.kernel_constraint,
//caching_device = default_caching_device,
);
}
}
}

+ 21
- 0
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers.Rnn
{
public class SimpleRNNCell : Layer
{
public SimpleRNNCell(SimpleRNNArgs args) : base(args)
{

}

protected override void build(Tensors inputs)
{
}
}
}

src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs → src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -2,9 +2,10 @@
using System.Collections.Generic;
using System.ComponentModel;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
namespace Tensorflow.Keras.Layers.Rnn
{
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
{

+ 0
- 14
src/TensorFlowNET.Keras/Layers/SimpleRNN.cs View File

@@ -1,14 +0,0 @@
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Layers
{
public class SimpleRNN : RNN
{

public SimpleRNN(RNNArgs args) : base(args)
{

}

}
}

+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow.Keras
/// <param name="width"></param>
/// <param name="interpolation"></param>
/// <returns></returns>
public Resizing Resizing(int height, int width, string interpolation = "bilinear")
public ILayer Resizing(int height, int width, string interpolation = "bilinear")
=> new Resizing(new ResizingArgs
{
Height = height,


+ 2
- 2
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs View File

@@ -5,7 +5,7 @@ using Tensorflow.Keras.Preprocessings;

namespace Tensorflow.Keras
{
public partial class Preprocessing
public partial class Preprocessing : IPreprocessing
{
public Sequence sequence => new Sequence();
public DatasetUtils dataset_utils => new DatasetUtils();
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras

private static TextApi _text = new TextApi();

public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null,
public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null,
string split = "whitespace",
int max_tokens = -1,
string output_mode = "int",


+ 14
- 4
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -3,11 +3,11 @@
<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<AssemblyName>Tensorflow.Keras</AssemblyName>
<LangVersion>9.0</LangVersion>
<LangVersion>10.0</LangVersion>
<Nullable>enable</Nullable>
<RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms>
<Version>0.7.0</Version>
<Version>0.10.0</Version>
<Authors>Haiping Chen</Authors>
<Product>Keras for .NET</Product>
<Copyright>Apache 2.0, Haiping Chen 2021</Copyright>
@@ -37,9 +37,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.7.0.0</AssemblyVersion>
<FileVersion>0.7.0.0</FileVersion>
<AssemblyVersion>0.10.0.0</AssemblyVersion>
<FileVersion>0.10.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<Configurations>Debug;Release;GPU</Configurations>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -47,6 +48,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
</PropertyGroup>
@@ -55,6 +61,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'">
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
<DefineConstants />
</PropertyGroup>


+ 1
- 1
src/TensorFlowNET.Keras/tf.layers.cs View File

@@ -134,7 +134,7 @@ namespace Tensorflow.Keras
/// <param name="data_format"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor max_pooling2d(Tensor inputs,
public Tensor MaxPooling2D(Tensor inputs,
int[] pool_size,
int[] strides,
string padding = "valid",


+ 16
- 0
src/python/.vscode/launch.json View File

@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false
}
]
}

+ 15
- 0
src/python/simple_rnn.py View File

@@ -0,0 +1,15 @@
import numpy as np
import tensorflow as tf

# tf.experimental.numpy
inputs = np.random.random([32, 10, 8]).astype(np.float32)
simple_rnn = tf.keras.layers.SimpleRNN(4)

output = simple_rnn(inputs) # The output has shape `[32, 4]`.

simple_rnn = tf.keras.layers.SimpleRNN(
4, return_sequences=True, return_state=True)

# whole_sequence_output has shape `[32, 10, 4]`.
# final_state has shape `[32, 4]`.
whole_sequence_output, final_state = simple_rnn(inputs)

+ 2
- 2
test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs View File

@@ -83,7 +83,7 @@ namespace TensorFlowNET.Keras.UnitTest
{ 2.5f, 2.6f, 2.7f, 2.8f },
{ 3.5f, 3.6f, 3.7f, 3.8f }
} }, dtype: np.float32);
var attention_layer = keras.layers.Attention();
var attention_layer = (Attention)keras.layers.Attention();
//attention_layer.build(((1, 2, 4), (1, 3, 4)));
var actual = attention_layer._calculate_scores(query: q, key: k);
// Expected tensor of shape [1, 2, 3].
@@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest
{ 2.5f, 2.6f, 2.7f, 2.8f },
{ 3.5f, 3.6f, 3.7f, 3.8f }
} }, dtype: np.float32);
var attention_layer = keras.layers.Attention(score_mode: "concat");
var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat");
//attention_layer.concat_score_weight = 1;
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() {
Name = "concat_score_weight",


+ 1
- 2
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -148,10 +148,9 @@ namespace TensorFlowNET.Keras.UnitTest
}

[TestMethod]
[Ignore]
public void SimpleRNN()
{
var inputs = np.random.rand(32, 10, 8).astype(np.float32);
var inputs = np.random.random((32, 10, 8)).astype(np.float32);
var simple_rnn = keras.layers.SimpleRNN(4);
var output = simple_rnn.Apply(inputs);
Assert.AreEqual((32, 4), output.shape);


+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj View File

@@ -4,7 +4,7 @@
<TargetFramework>net6.0</TargetFramework>

<IsPackable>false</IsPackable>
<LangVersion>11.0</LangVersion>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>



+ 1
- 1
test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj View File

@@ -11,7 +11,7 @@

<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>

<LangVersion>9.0</LangVersion>
<LangVersion>11.0</LangVersion>

<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>


Loading…
Cancel
Save