@@ -26,7 +26,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5 | |||
#### Download pre-build package | |||
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip) | |||
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.10.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.10.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.10.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.10.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.10.0.zip) | |||
@@ -35,6 +35,6 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5 | |||
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. | |||
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. | |||
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.10.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||
@@ -10,6 +10,9 @@ namespace Tensorflow | |||
var diag = new Diagnostician(); | |||
// diag.Diagnose(@"D:\memory.txt"); | |||
var rnn = new SimpleRnnTest(); | |||
rnn.Run(); | |||
// this class is used explor new features. | |||
var exploring = new Exploring(); | |||
// exploring.Run(); | |||
@@ -0,0 +1,31 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras; | |||
using Tensorflow.NumPy; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow | |||
{ | |||
public class SimpleRnnTest | |||
{ | |||
public void Run() | |||
{ | |||
tf.keras = new KerasInterface(); | |||
var inputs = np.random.random((32, 10, 8)).astype(np.float32); | |||
var simple_rnn = tf.keras.layers.SimpleRNN(4); | |||
var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`. | |||
if (output.shape == (32, 4)) | |||
{ | |||
} | |||
/*simple_rnn = tf.keras.layers.SimpleRNN( | |||
4, return_sequences = True, return_state = True) | |||
# whole_sequence_output has shape `[32, 10, 4]`. | |||
# final_state has shape `[32, 4]`. | |||
whole_sequence_output, final_state = simple_rnn(inputs)*/ | |||
} | |||
} | |||
} |
@@ -6,7 +6,7 @@ | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<AssemblyName>Tensorflow</AssemblyName> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
<LangVersion>9.0</LangVersion> | |||
<LangVersion>11.0</LangVersion> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
@@ -20,7 +20,7 @@ | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.7.0" /> | |||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.10.0" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
@@ -1,22 +0,0 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class LSTMArgs : RNNArgs | |||
{ | |||
public int Units { get; set; } | |||
public Activation Activation { get; set; } | |||
public Activation RecurrentActivation { get; set; } | |||
public IInitializer KernelInitializer { get; set; } | |||
public IInitializer RecurrentInitializer { get; set; } | |||
public IInitializer BiasInitializer { get; set; } | |||
public bool UnitForgetBias { get; set; } | |||
public float Dropout { get; set; } | |||
public float RecurrentDropout { get; set; } | |||
public int Implementation { get; set; } | |||
public bool ReturnSequences { get; set; } | |||
public bool ReturnState { get; set; } | |||
public bool GoBackwards { get; set; } | |||
public bool Stateful { get; set; } | |||
public bool TimeMajor { get; set; } | |||
public bool Unroll { get; set; } | |||
} | |||
} |
@@ -0,0 +1,12 @@ | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
namespace Tensorflow.Keras.ArgsDefinition.Lstm | |||
{ | |||
public class LSTMArgs : RNNArgs | |||
{ | |||
public bool UnitForgetBias { get; set; } | |||
public float Dropout { get; set; } | |||
public float RecurrentDropout { get; set; } | |||
public int Implementation { get; set; } | |||
} | |||
} |
@@ -1,4 +1,4 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
namespace Tensorflow.Keras.ArgsDefinition.Lstm | |||
{ | |||
public class LSTMCellArgs : LayerArgs | |||
{ |
@@ -1,21 +0,0 @@ | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class RNNArgs : LayerArgs | |||
{ | |||
public interface IRnnArgCell : ILayer | |||
{ | |||
object state_size { get; } | |||
} | |||
public IRnnArgCell Cell { get; set; } = null; | |||
public bool ReturnSequences { get; set; } = false; | |||
public bool ReturnState { get; set; } = false; | |||
public bool GoBackwards { get; set; } = false; | |||
public bool Stateful { get; set; } = false; | |||
public bool Unroll { get; set; } = false; | |||
public bool TimeMajor { get; set; } = false; | |||
public Dictionary<string, object> Kwargs { get; set; } = null; | |||
} | |||
} |
@@ -0,0 +1,45 @@ | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
{ | |||
public class RNNArgs : LayerArgs | |||
{ | |||
public interface IRnnArgCell : ILayer | |||
{ | |||
object state_size { get; } | |||
} | |||
public IRnnArgCell Cell { get; set; } = null; | |||
public bool ReturnSequences { get; set; } = false; | |||
public bool ReturnState { get; set; } = false; | |||
public bool GoBackwards { get; set; } = false; | |||
public bool Stateful { get; set; } = false; | |||
public bool Unroll { get; set; } = false; | |||
public bool TimeMajor { get; set; } = false; | |||
public Dictionary<string, object> Kwargs { get; set; } = null; | |||
public int Units { get; set; } | |||
public Activation Activation { get; set; } | |||
public Activation RecurrentActivation { get; set; } | |||
public bool UseBias { get; set; } = true; | |||
public IInitializer KernelInitializer { get; set; } | |||
public IInitializer RecurrentInitializer { get; set; } | |||
public IInitializer BiasInitializer { get; set; } | |||
// kernel_regularizer=None, | |||
// recurrent_regularizer=None, | |||
// bias_regularizer=None, | |||
// activity_regularizer=None, | |||
// kernel_constraint=None, | |||
// recurrent_constraint=None, | |||
// bias_constraint=None, | |||
// dropout=0., | |||
// recurrent_dropout=0., | |||
// return_sequences=False, | |||
// return_state=False, | |||
// go_backwards=False, | |||
// stateful=False, | |||
// unroll=False, | |||
// **kwargs): | |||
} | |||
} |
@@ -0,0 +1,7 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
{ | |||
public class SimpleRNNArgs : RNNArgs | |||
{ | |||
} | |||
} |
@@ -1,6 +1,6 @@ | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||
{ | |||
public class StackedRNNCellsArgs : LayerArgs | |||
{ |
@@ -1,30 +0,0 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SimpleRNNArgs : RNNArgs | |||
{ | |||
public int Units { get; set; } | |||
public Activation Activation { get; set; } | |||
// units, | |||
// activation='tanh', | |||
// use_bias=True, | |||
// kernel_initializer='glorot_uniform', | |||
// recurrent_initializer='orthogonal', | |||
// bias_initializer='zeros', | |||
// kernel_regularizer=None, | |||
// recurrent_regularizer=None, | |||
// bias_regularizer=None, | |||
// activity_regularizer=None, | |||
// kernel_constraint=None, | |||
// recurrent_constraint=None, | |||
// bias_constraint=None, | |||
// dropout=0., | |||
// recurrent_dropout=0., | |||
// return_sequences=False, | |||
// return_state=False, | |||
// go_backwards=False, | |||
// stateful=False, | |||
// unroll=False, | |||
// **kwargs): | |||
} | |||
} |
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Layers; | |||
namespace Tensorflow.Keras | |||
{ | |||
public interface IKerasApi | |||
{ | |||
public ILayersApi layers { get; } | |||
} | |||
} |
@@ -0,0 +1,16 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras | |||
{ | |||
public interface IPreprocessing | |||
{ | |||
public ILayer Resizing(int height, int width, string interpolation = "bilinear"); | |||
public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null, | |||
string split = "whitespace", | |||
int max_tokens = -1, | |||
string output_mode = "int", | |||
int output_sequence_length = -1); | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
using System; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.Operations.Activation; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial interface ILayersApi | |||
{ | |||
public ILayer ELU(float alpha = 0.1f); | |||
public ILayer SELU(); | |||
public ILayer Softmax(Axis axis); | |||
public ILayer Softplus(); | |||
public ILayer HardSigmoid(); | |||
public ILayer Softsign(); | |||
public ILayer Swish(); | |||
public ILayer Tanh(); | |||
public ILayer Exponential(); | |||
} | |||
} |
@@ -0,0 +1,28 @@ | |||
using System; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial interface ILayersApi | |||
{ | |||
public ILayer Attention(bool use_scale = false, | |||
string score_mode = "dot", | |||
bool causal = false, | |||
float dropout = 0f); | |||
public ILayer MultiHeadAttention(int num_heads, | |||
int key_dim, | |||
int? value_dim = null, | |||
float dropout = 0f, | |||
bool use_bias = true, | |||
Shape output_shape = null, | |||
Shape attention_axes = null, | |||
IInitializer kernel_initializer = null, | |||
IInitializer bias_initializer = null, | |||
IRegularizer kernel_regularizer = null, | |||
IRegularizer bias_regularizer = null, | |||
IRegularizer activity_regularizer = null, | |||
Action kernel_constraint = null, | |||
Action bias_constraint = null); | |||
} | |||
} |
@@ -0,0 +1,13 @@ | |||
using System; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial interface ILayersApi | |||
{ | |||
public ILayer Cropping1D(NDArray cropping); | |||
public ILayer Cropping2D(NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last); | |||
public ILayer Cropping3D(NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last); | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial interface ILayersApi | |||
{ | |||
public ILayer Concatenate(int axis = -1); | |||
} | |||
} |
@@ -0,0 +1,18 @@ | |||
using System; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.NumPy; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial interface ILayersApi | |||
{ | |||
public ILayer Reshape(Shape target_shape); | |||
public ILayer Reshape(object[] target_shape); | |||
public ILayer UpSampling2D(Shape size = null, | |||
string data_format = null, | |||
string interpolation = "nearest"); | |||
public ILayer ZeroPadding2D(NDArray padding); | |||
} | |||
} |
@@ -0,0 +1,169 @@ | |||
using System; | |||
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial interface ILayersApi | |||
{ | |||
public IPreprocessing preprocessing { get; } | |||
public ILayer Add(); | |||
public ILayer AveragePooling2D(Shape pool_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
string data_format = null); | |||
public ILayer BatchNormalization(int axis = -1, | |||
float momentum = 0.99f, | |||
float epsilon = 0.001f, | |||
bool center = true, | |||
bool scale = true, | |||
IInitializer beta_initializer = null, | |||
IInitializer gamma_initializer = null, | |||
IInitializer moving_mean_initializer = null, | |||
IInitializer moving_variance_initializer = null, | |||
bool trainable = true, | |||
string name = null, | |||
bool renorm = false, | |||
float renorm_momentum = 0.99f); | |||
public ILayer Conv1D(int filters, | |||
Shape kernel_size, | |||
int strides = 1, | |||
string padding = "valid", | |||
string data_format = "channels_last", | |||
int dilation_rate = 1, | |||
int groups = 1, | |||
string activation = null, | |||
bool use_bias = true, | |||
string kernel_initializer = "glorot_uniform", | |||
string bias_initializer = "zeros"); | |||
public ILayer Conv2D(int filters, | |||
Shape kernel_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
string data_format = null, | |||
Shape dilation_rate = null, | |||
int groups = 1, | |||
Activation activation = null, | |||
bool use_bias = true, | |||
IInitializer kernel_initializer = null, | |||
IInitializer bias_initializer = null, | |||
IRegularizer kernel_regularizer = null, | |||
IRegularizer bias_regularizer = null, | |||
IRegularizer activity_regularizer = null); | |||
public ILayer Conv2D(int filters, | |||
Shape kernel_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
string data_format = null, | |||
Shape dilation_rate = null, | |||
int groups = 1, | |||
string activation = null, | |||
bool use_bias = true, | |||
string kernel_initializer = "glorot_uniform", | |||
string bias_initializer = "zeros"); | |||
public ILayer Dense(int units); | |||
public ILayer Dense(int units, | |||
string activation = null, | |||
Shape input_shape = null); | |||
public ILayer Dense(int units, | |||
Activation activation = null, | |||
IInitializer kernel_initializer = null, | |||
bool use_bias = true, | |||
IInitializer bias_initializer = null, | |||
Shape input_shape = null); | |||
public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null); | |||
public ILayer Embedding(int input_dim, | |||
int output_dim, | |||
IInitializer embeddings_initializer = null, | |||
bool mask_zero = false, | |||
Shape input_shape = null, | |||
int input_length = -1); | |||
public ILayer EinsumDense(string equation, | |||
Shape output_shape, | |||
string bias_axes, | |||
Activation activation = null, | |||
IInitializer kernel_initializer = null, | |||
IInitializer bias_initializer = null, | |||
IRegularizer kernel_regularizer = null, | |||
IRegularizer bias_regularizer = null, | |||
IRegularizer activity_regularizer = null, | |||
Action kernel_constraint = null, | |||
Action bias_constraint = null); | |||
public ILayer Flatten(string data_format = null); | |||
public ILayer GlobalAveragePooling1D(string data_format = "channels_last"); | |||
public ILayer GlobalAveragePooling2D(); | |||
public ILayer GlobalAveragePooling2D(string data_format = "channels_last"); | |||
public ILayer GlobalMaxPooling1D(string data_format = "channels_last"); | |||
public ILayer GlobalMaxPooling2D(string data_format = "channels_last"); | |||
public Tensors Input(Shape shape, | |||
string name = null, | |||
bool sparse = false, | |||
bool ragged = false); | |||
public ILayer InputLayer(Shape input_shape, | |||
string name = null, | |||
bool sparse = false, | |||
bool ragged = false); | |||
public ILayer LayerNormalization(Axis? axis, | |||
float epsilon = 1e-3f, | |||
bool center = true, | |||
bool scale = true, | |||
IInitializer beta_initializer = null, | |||
IInitializer gamma_initializer = null); | |||
public ILayer LeakyReLU(float alpha = 0.3f); | |||
public ILayer LSTM(int units, | |||
Activation activation = null, | |||
Activation recurrent_activation = null, | |||
bool use_bias = true, | |||
IInitializer kernel_initializer = null, | |||
IInitializer recurrent_initializer = null, | |||
IInitializer bias_initializer = null, | |||
bool unit_forget_bias = true, | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f, | |||
int implementation = 2, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool time_major = false, | |||
bool unroll = false); | |||
public ILayer MaxPooling1D(int? pool_size = null, | |||
int? strides = null, | |||
string padding = "valid", | |||
string data_format = null); | |||
public ILayer MaxPooling2D(Shape pool_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
string data_format = null); | |||
public ILayer Permute(int[] dims); | |||
public ILayer Rescaling(float scale, | |||
float offset = 0, | |||
Shape input_shape = null); | |||
public ILayer SimpleRNN(int units, | |||
string activation = "tanh", | |||
string kernel_initializer = "glorot_uniform", | |||
string recurrent_initializer = "orthogonal", | |||
string bias_initializer = "zeros"); | |||
public ILayer Subtract(); | |||
} | |||
} |
@@ -20,11 +20,11 @@ namespace Tensorflow.NumPy | |||
Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize); | |||
} | |||
public NDArray rand(params int[] shape) | |||
=> throw new NotImplementedException(""); | |||
public NDArray random(Shape size) | |||
=> uniform(low: 0, high: 1, size: size); | |||
[AutoNumPy] | |||
public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType dtype = TF_DataType.TF_INT32) | |||
public NDArray randint(int low, int? high = null, Shape? size = null, TF_DataType dtype = TF_DataType.TF_INT32) | |||
{ | |||
if(high == null) | |||
{ | |||
@@ -41,11 +41,11 @@ namespace Tensorflow.NumPy | |||
=> new NDArray(random_ops.random_normal(shape ?? Shape.Scalar)); | |||
[AutoNumPy] | |||
public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null) | |||
public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape? size = null) | |||
=> new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale)); | |||
[AutoNumPy] | |||
public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape size = null) | |||
public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape? size = null) | |||
=> new NDArray(random_ops.random_uniform(size ?? Shape.Scalar, low, high)); | |||
} | |||
} |
@@ -18,6 +18,7 @@ using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Util; | |||
@@ -5,8 +5,8 @@ | |||
<AssemblyName>Tensorflow.Binding</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | |||
<Version>0.70.2</Version> | |||
<LangVersion>9.0</LangVersion> | |||
<Version>0.100.0</Version> | |||
<LangVersion>10.0</LangVersion> | |||
<Nullable>enable</Nullable> | |||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
<Company>SciSharp STACK</Company> | |||
@@ -20,9 +20,9 @@ | |||
<Description>Google's TensorFlow full binding in .NET Standard. | |||
Building, training and infering deep learning models. | |||
https://tensorflownet.readthedocs.io</Description> | |||
<AssemblyVersion>0.70.1.0</AssemblyVersion> | |||
<AssemblyVersion>0.100.0.0</AssemblyVersion> | |||
<PackageReleaseNotes> | |||
tf.net 0.70.x and above are based on tensorflow native 2.7.0 | |||
tf.net 0.100.x and above are based on tensorflow native 2.10.0 | |||
* Eager Mode is added finally. | |||
* tf.keras is partially working. | |||
@@ -35,14 +35,17 @@ https://tensorflownet.readthedocs.io</Description> | |||
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library. | |||
tf.net 0.6x.x aligns with TensorFlow v2.6.x native library. | |||
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.</PackageReleaseNotes> | |||
<FileVersion>0.70.1.0</FileVersion> | |||
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library. | |||
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library. | |||
</PackageReleaseNotes> | |||
<FileVersion>0.100.0.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<SignAssembly>true</SignAssembly> | |||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
<PackageId>TensorFlow.NET</PackageId> | |||
<Configurations>Debug;Release;GPU</Configurations> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
@@ -51,6 +54,12 @@ https://tensorflownet.readthedocs.io</Description> | |||
<PlatformTarget>AnyCPU</PlatformTarget> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE_1</DefineConstants> | |||
<PlatformTarget>AnyCPU</PlatformTarget> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants> | |||
@@ -58,6 +67,13 @@ https://tensorflownet.readthedocs.io</Description> | |||
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants> | |||
<PlatformTarget>x64</PlatformTarget> | |||
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
@@ -20,6 +20,7 @@ using System.Threading; | |||
using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Keras; | |||
namespace Tensorflow | |||
{ | |||
@@ -51,6 +52,8 @@ namespace Tensorflow | |||
ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner()); | |||
public IEagerRunner Runner => _runner.Value; | |||
public IKerasApi keras { get; set; } | |||
public tensorflow() | |||
{ | |||
Logger = new LoggerConfiguration() | |||
@@ -2,6 +2,9 @@ | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Deprecated, will use tf.keras | |||
/// </summary> | |||
public static class KerasApi | |||
{ | |||
public static KerasInterface keras { get; } = new KerasInterface(); | |||
@@ -10,18 +10,17 @@ using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.Keras.Models; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.Saving; | |||
using Tensorflow.Keras.Utils; | |||
using System.Threading; | |||
namespace Tensorflow.Keras | |||
{ | |||
public class KerasInterface | |||
public class KerasInterface : IKerasApi | |||
{ | |||
public KerasDataset datasets { get; } = new KerasDataset(); | |||
public Initializers initializers { get; } = new Initializers(); | |||
public Regularizers regularizers { get; } = new Regularizers(); | |||
public LayersApi layers { get; } = new LayersApi(); | |||
public ILayersApi layers { get; } = new LayersApi(); | |||
public LossesApi losses { get; } = new LossesApi(); | |||
public Activations activations { get; } = new Activations(); | |||
public Preprocessing preprocessing { get; } = new Preprocessing(); | |||
@@ -7,16 +7,16 @@ using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Layers { | |||
public partial class LayersApi { | |||
public ELU ELU ( float alpha = 0.1f ) | |||
public ILayer ELU ( float alpha = 0.1f ) | |||
=> new ELU(new ELUArgs { Alpha = alpha }); | |||
public SELU SELU () | |||
public ILayer SELU () | |||
=> new SELU(new LayerArgs { }); | |||
public Softmax Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis }); | |||
public Softplus Softplus () => new Softplus(new LayerArgs { }); | |||
public HardSigmoid HardSigmoid () => new HardSigmoid(new LayerArgs { }); | |||
public Softsign Softsign () => new Softsign(new LayerArgs { }); | |||
public Swish Swish () => new Swish(new LayerArgs { }); | |||
public Tanh Tanh () => new Tanh(new LayerArgs { }); | |||
public Exponential Exponential () => new Exponential(new LayerArgs { }); | |||
public ILayer Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis }); | |||
public ILayer Softplus () => new Softplus(new LayerArgs { }); | |||
public ILayer HardSigmoid () => new HardSigmoid(new LayerArgs { }); | |||
public ILayer Softsign () => new Softsign(new LayerArgs { }); | |||
public ILayer Swish () => new Swish(new LayerArgs { }); | |||
public ILayer Tanh () => new Tanh(new LayerArgs { }); | |||
public ILayer Exponential () => new Exponential(new LayerArgs { }); | |||
} | |||
} |
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial class LayersApi | |||
{ | |||
public Attention Attention(bool use_scale = false, | |||
public ILayer Attention(bool use_scale = false, | |||
string score_mode = "dot", | |||
bool causal = false, | |||
float dropout = 0f) => | |||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers | |||
causal = causal, | |||
dropout = dropout | |||
}); | |||
public MultiHeadAttention MultiHeadAttention(int num_heads, | |||
public ILayer MultiHeadAttention(int num_heads, | |||
int key_dim, | |||
int? value_dim = null, | |||
float dropout = 0f, | |||
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers { | |||
/// Cropping layer for 1D input | |||
/// </summary> | |||
/// <param name="cropping">cropping size</param> | |||
public Cropping1D Cropping1D ( NDArray cropping ) | |||
public ILayer Cropping1D ( NDArray cropping ) | |||
=> new Cropping1D(new CroppingArgs { | |||
cropping = cropping | |||
}); | |||
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers { | |||
/// <summary> | |||
/// Cropping layer for 2D input <br/> | |||
/// </summary> | |||
public Cropping2D Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last ) | |||
public ILayer Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last ) | |||
=> new Cropping2D(new Cropping2DArgs { | |||
cropping = cropping, | |||
data_format = data_format | |||
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers { | |||
/// <summary> | |||
/// Cropping layer for 3D input <br/> | |||
/// </summary> | |||
public Cropping3D Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last ) | |||
public ILayer Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last ) | |||
=> new Cropping3D(new Cropping3DArgs { | |||
cropping = cropping, | |||
data_format = data_format | |||
@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers | |||
/// </summary> | |||
/// <param name="axis">Axis along which to concatenate.</param> | |||
/// <returns></returns> | |||
public Concatenate Concatenate(int axis = -1) | |||
public ILayer Concatenate(int axis = -1) | |||
=> new Concatenate(new MergeArgs | |||
{ | |||
Axis = axis | |||
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Layers { | |||
/// </summary> | |||
/// <param name="padding"></param> | |||
/// <returns></returns> | |||
public ZeroPadding2D ZeroPadding2D ( NDArray padding ) | |||
public ILayer ZeroPadding2D ( NDArray padding ) | |||
=> new ZeroPadding2D(new ZeroPadding2DArgs { | |||
Padding = padding | |||
}); | |||
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers { | |||
/// <param name="data_format"></param> | |||
/// <param name="interpolation"></param> | |||
/// <returns></returns> | |||
public UpSampling2D UpSampling2D ( Shape size = null, | |||
public ILayer UpSampling2D ( Shape size = null, | |||
string data_format = null, | |||
string interpolation = "nearest" ) | |||
=> new UpSampling2D(new UpSampling2DArgs { | |||
@@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers { | |||
/// <summary> | |||
/// Permutes the dimensions of the input according to a given pattern. | |||
/// </summary> | |||
public Permute Permute ( int[] dims ) | |||
public ILayer Permute ( int[] dims ) | |||
=> new Permute(new PermuteArgs { | |||
dims = dims | |||
}); | |||
@@ -44,12 +44,12 @@ namespace Tensorflow.Keras.Layers { | |||
/// </summary> | |||
/// <param name="target_shape"></param> | |||
/// <returns></returns> | |||
public Reshape Reshape ( Shape target_shape ) | |||
=> new Reshape(new ReshapeArgs { | |||
TargetShape = target_shape | |||
}); | |||
public ILayer Reshape ( Shape target_shape ) | |||
=> new Reshape(new ReshapeArgs { | |||
TargetShape = target_shape | |||
}); | |||
public Reshape Reshape ( object[] target_shape ) | |||
public ILayer Reshape ( object[] target_shape ) | |||
=> new Reshape(new ReshapeArgs { | |||
TargetShapeObjects = target_shape | |||
}); | |||
@@ -1,16 +1,18 @@ | |||
using System; | |||
using Tensorflow.NumPy; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Lstm; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers.Lstm; | |||
using Tensorflow.Keras.Layers.Rnn; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public partial class LayersApi | |||
public partial class LayersApi : ILayersApi | |||
{ | |||
public Preprocessing preprocessing { get; } = new Preprocessing(); | |||
public IPreprocessing preprocessing { get; } = new Preprocessing(); | |||
/// <summary> | |||
/// Layer that normalizes its inputs. | |||
@@ -38,7 +40,7 @@ namespace Tensorflow.Keras.Layers | |||
/// Note that momentum is still applied to get the means and variances for inference. | |||
/// </param> | |||
/// <returns>Tensor of the same shape as input.</returns> | |||
public BatchNormalization BatchNormalization(int axis = -1, | |||
public ILayer BatchNormalization(int axis = -1, | |||
float momentum = 0.99f, | |||
float epsilon = 0.001f, | |||
bool center = true, | |||
@@ -84,7 +86,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="kernel_initializer">Initializer for the kernel weights matrix (see keras.initializers).</param> | |||
/// <param name="bias_initializer">Initializer for the bias vector (see keras.initializers).</param> | |||
/// <returns>A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias).</returns> | |||
public Conv1D Conv1D(int filters, | |||
public ILayer Conv1D(int filters, | |||
Shape kernel_size, | |||
int strides = 1, | |||
string padding = "valid", | |||
@@ -131,7 +133,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="bias_regularizer">Regularizer function applied to the bias vector (see keras.regularizers).</param> | |||
/// <param name="activity_regularizer">Regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param> | |||
/// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns> | |||
public Conv2D Conv2D(int filters, | |||
public ILayer Conv2D(int filters, | |||
Shape kernel_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
@@ -184,7 +186,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param> | |||
/// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param> | |||
/// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns> | |||
public Conv2D Conv2D(int filters, | |||
public ILayer Conv2D(int filters, | |||
Shape kernel_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
@@ -228,7 +230,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param> | |||
/// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param> | |||
/// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns> | |||
public Conv2DTranspose Conv2DTranspose(int filters, | |||
public ILayer Conv2DTranspose(int filters, | |||
Shape kernel_size = null, | |||
Shape strides = null, | |||
string output_padding = "valid", | |||
@@ -270,7 +272,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="bias_initializer">Initializer for the bias vector.</param> | |||
/// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param> | |||
/// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns> | |||
public Dense Dense(int units, | |||
public ILayer Dense(int units, | |||
Activation activation = null, | |||
IInitializer kernel_initializer = null, | |||
bool use_bias = true, | |||
@@ -294,7 +296,7 @@ namespace Tensorflow.Keras.Layers | |||
/// </summary> | |||
/// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
/// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns> | |||
public Dense Dense(int units) | |||
public ILayer Dense(int units) | |||
=> new Dense(new DenseArgs | |||
{ | |||
Units = units, | |||
@@ -312,7 +314,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="activation">Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).</param> | |||
/// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param> | |||
/// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns> | |||
public Dense Dense(int units, | |||
public ILayer Dense(int units, | |||
string activation = null, | |||
Shape input_shape = null) | |||
=> new Dense(new DenseArgs | |||
@@ -364,7 +366,7 @@ namespace Tensorflow.Keras.Layers | |||
} | |||
public EinsumDense EinsumDense(string equation, | |||
public ILayer EinsumDense(string equation, | |||
Shape output_shape, | |||
string bias_axes, | |||
Activation activation = null, | |||
@@ -402,7 +404,7 @@ namespace Tensorflow.Keras.Layers | |||
/// </param> | |||
/// <param name="seed">An integer to use as random seed.</param> | |||
/// <returns></returns> | |||
public Dropout Dropout(float rate, Shape noise_shape = null, int? seed = null) | |||
public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null) | |||
=> new Dropout(new DropoutArgs | |||
{ | |||
Rate = rate, | |||
@@ -421,7 +423,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="embeddings_initializer">Initializer for the embeddings matrix (see keras.initializers).</param> | |||
/// <param name="mask_zero"></param> | |||
/// <returns></returns> | |||
public Embedding Embedding(int input_dim, | |||
public ILayer Embedding(int input_dim, | |||
int output_dim, | |||
IInitializer embeddings_initializer = null, | |||
bool mask_zero = false, | |||
@@ -446,7 +448,7 @@ namespace Tensorflow.Keras.Layers | |||
/// If you never set it, then it will be "channels_last". | |||
/// </param> | |||
/// <returns></returns> | |||
public Flatten Flatten(string data_format = null) | |||
public ILayer Flatten(string data_format = null) | |||
=> new Flatten(new FlattenArgs | |||
{ | |||
DataFormat = data_format | |||
@@ -482,7 +484,7 @@ namespace Tensorflow.Keras.Layers | |||
return input_layer.InboundNodes[0].Outputs; | |||
} | |||
public InputLayer InputLayer(Shape input_shape, | |||
public ILayer InputLayer(Shape input_shape, | |||
string name = null, | |||
bool sparse = false, | |||
bool ragged = false) | |||
@@ -502,7 +504,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="padding"></param> | |||
/// <param name="data_format"></param> | |||
/// <returns></returns> | |||
public AveragePooling2D AveragePooling2D(Shape pool_size = null, | |||
public ILayer AveragePooling2D(Shape pool_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
string data_format = null) | |||
@@ -527,7 +529,7 @@ namespace Tensorflow.Keras.Layers | |||
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). | |||
/// </param> | |||
/// <returns></returns> | |||
public MaxPooling1D MaxPooling1D(int? pool_size = null, | |||
public ILayer MaxPooling1D(int? pool_size = null, | |||
int? strides = null, | |||
string padding = "valid", | |||
string data_format = null) | |||
@@ -564,7 +566,7 @@ namespace Tensorflow.Keras.Layers | |||
/// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json. | |||
/// If you never set it, then it will be "channels_last"</param> | |||
/// <returns></returns> | |||
public MaxPooling2D MaxPooling2D(Shape pool_size = null, | |||
public ILayer MaxPooling2D(Shape pool_size = null, | |||
Shape strides = null, | |||
string padding = "valid", | |||
string data_format = null) | |||
@@ -618,7 +620,7 @@ namespace Tensorflow.Keras.Layers | |||
return layer.Apply(inputs); | |||
} | |||
public Layer LayerNormalization(Axis? axis, | |||
public ILayer LayerNormalization(Axis? axis, | |||
float epsilon = 1e-3f, | |||
bool center = true, | |||
bool scale = true, | |||
@@ -638,45 +640,30 @@ namespace Tensorflow.Keras.Layers | |||
/// </summary> | |||
/// <param name="alpha">Negative slope coefficient.</param> | |||
/// <returns></returns> | |||
public Layer LeakyReLU(float alpha = 0.3f) | |||
public ILayer LeakyReLU(float alpha = 0.3f) | |||
=> new LeakyReLu(new LeakyReLuArgs | |||
{ | |||
Alpha = alpha | |||
}); | |||
/// <summary> | |||
/// Fully-connected RNN where the output is to be fed back to input. | |||
/// </summary> | |||
/// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
/// <returns></returns> | |||
public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh"); | |||
/// <summary> | |||
/// Fully-connected RNN where the output is to be fed back to input. | |||
/// </summary> | |||
/// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
/// <param name="activation">Activation function to use. If you pass null, no activation is applied (ie. "linear" activation: a(x) = x).</param> | |||
/// <returns></returns> | |||
public Layer SimpleRNN(int units, | |||
Activation activation = null) | |||
=> new SimpleRNN(new SimpleRNNArgs | |||
{ | |||
Units = units, | |||
Activation = activation | |||
}); | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="units">Positive integer, dimensionality of the output space.</param> | |||
/// <param name="activation">The name of the activation function to use. Default: hyperbolic tangent (tanh)..</param> | |||
/// <returns></returns> | |||
public Layer SimpleRNN(int units, | |||
string activation = "tanh") | |||
public ILayer SimpleRNN(int units, | |||
string activation = "tanh", | |||
string kernel_initializer = "glorot_uniform", | |||
string recurrent_initializer = "orthogonal", | |||
string bias_initializer = "zeros") | |||
=> new SimpleRNN(new SimpleRNNArgs | |||
{ | |||
Units = units, | |||
Activation = GetActivationByName(activation) | |||
Activation = GetActivationByName(activation), | |||
KernelInitializer = GetInitializerByName(kernel_initializer), | |||
RecurrentInitializer= GetInitializerByName(recurrent_initializer), | |||
BiasInitializer= GetInitializerByName(bias_initializer) | |||
}); | |||
/// <summary> | |||
@@ -706,7 +693,7 @@ namespace Tensorflow.Keras.Layers | |||
/// although it tends to be more memory-intensive. Unrolling is only suitable for short sequences. | |||
/// </param> | |||
/// <returns></returns> | |||
public Layer LSTM(int units, | |||
public ILayer LSTM(int units, | |||
Activation activation = null, | |||
Activation recurrent_activation = null, | |||
bool use_bias = true, | |||
@@ -749,7 +736,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="offset"></param> | |||
/// <param name="input_shape"></param> | |||
/// <returns></returns> | |||
public Rescaling Rescaling(float scale, | |||
public ILayer Rescaling(float scale, | |||
float offset = 0, | |||
Shape input_shape = null) | |||
=> new Rescaling(new RescalingArgs | |||
@@ -763,21 +750,21 @@ namespace Tensorflow.Keras.Layers | |||
/// | |||
/// </summary> | |||
/// <returns></returns> | |||
public Add Add() | |||
public ILayer Add() | |||
=> new Add(new MergeArgs { }); | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <returns></returns> | |||
public Subtract Subtract() | |||
public ILayer Subtract() | |||
=> new Subtract(new MergeArgs { }); | |||
/// <summary> | |||
/// Global max pooling operation for spatial data. | |||
/// </summary> | |||
/// <returns></returns> | |||
public GlobalAveragePooling2D GlobalAveragePooling2D() | |||
public ILayer GlobalAveragePooling2D() | |||
=> new GlobalAveragePooling2D(new Pooling2DArgs { }); | |||
/// <summary> | |||
@@ -787,7 +774,7 @@ namespace Tensorflow.Keras.Layers | |||
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). | |||
/// </param> | |||
/// <returns></returns> | |||
public GlobalAveragePooling1D GlobalAveragePooling1D(string data_format = "channels_last") | |||
public ILayer GlobalAveragePooling1D(string data_format = "channels_last") | |||
=> new GlobalAveragePooling1D(new Pooling1DArgs { DataFormat = data_format }); | |||
/// <summary> | |||
@@ -796,7 +783,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. | |||
/// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param> | |||
/// <returns></returns> | |||
public GlobalAveragePooling2D GlobalAveragePooling2D(string data_format = "channels_last") | |||
public ILayer GlobalAveragePooling2D(string data_format = "channels_last") | |||
=> new GlobalAveragePooling2D(new Pooling2DArgs { DataFormat = data_format }); | |||
/// <summary> | |||
@@ -807,7 +794,7 @@ namespace Tensorflow.Keras.Layers | |||
/// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps). | |||
/// </param> | |||
/// <returns></returns> | |||
public GlobalMaxPooling1D GlobalMaxPooling1D(string data_format = "channels_last") | |||
public ILayer GlobalMaxPooling1D(string data_format = "channels_last") | |||
=> new GlobalMaxPooling1D(new Pooling1DArgs { DataFormat = data_format }); | |||
/// <summary> | |||
@@ -816,7 +803,7 @@ namespace Tensorflow.Keras.Layers | |||
/// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs. | |||
/// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param> | |||
/// <returns></returns> | |||
public GlobalMaxPooling2D GlobalMaxPooling2D(string data_format = "channels_last") | |||
public ILayer GlobalMaxPooling2D(string data_format = "channels_last") | |||
=> new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format }); | |||
@@ -848,6 +835,7 @@ namespace Tensorflow.Keras.Layers | |||
"glorot_uniform" => tf.glorot_uniform_initializer, | |||
"zeros" => tf.zeros_initializer, | |||
"ones" => tf.ones_initializer, | |||
"orthogonal" => tf.orthogonal_initializer, | |||
_ => tf.glorot_uniform_initializer | |||
}; | |||
} | |||
@@ -1,8 +1,9 @@ | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Lstm; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers.Rnn; | |||
namespace Tensorflow.Keras.Layers | |||
namespace Tensorflow.Keras.Layers.Lstm | |||
{ | |||
/// <summary> | |||
/// Long Short-Term Memory layer - Hochreiter 1997. |
@@ -1,7 +1,7 @@ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Lstm; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Layers | |||
namespace Tensorflow.Keras.Layers.Lstm | |||
{ | |||
public class LSTMCell : Layer | |||
{ |
@@ -1,10 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Layers.Lstm; | |||
// from tensorflow.python.distribute import distribution_strategy_context as ds_context; | |||
namespace Tensorflow.Keras.Layers | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
public class RNN : Layer | |||
{ | |||
@@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Layers | |||
private object _states = null; | |||
private object constants_spec = null; | |||
private int _num_constants = 0; | |||
protected IVariableV1 kernel; | |||
protected IVariableV1 bias; | |||
public RNN(RNNArgs args) : base(PreConstruct(args)) | |||
{ |
@@ -0,0 +1,31 @@ | |||
using System.Data; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Operations.Activation; | |||
using static HDF.PInvoke.H5Z; | |||
using static Tensorflow.ApiDef.Types; | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
public class SimpleRNN : RNN | |||
{ | |||
SimpleRNNArgs args; | |||
SimpleRNNCell cell; | |||
public SimpleRNN(SimpleRNNArgs args) : base(args) | |||
{ | |||
this.args = args; | |||
} | |||
protected override void build(Tensors inputs) | |||
{ | |||
var input_shape = inputs.shape; | |||
var input_dim = input_shape[-1]; | |||
kernel = add_weight("kernel", (input_shape[-1], args.Units), | |||
initializer: args.KernelInitializer | |||
//regularizer = self.kernel_regularizer, | |||
//constraint = self.kernel_constraint, | |||
//caching_device = default_caching_device, | |||
); | |||
} | |||
} | |||
} |
@@ -0,0 +1,21 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
public class SimpleRNNCell : Layer | |||
{ | |||
public SimpleRNNCell(SimpleRNNArgs args) : base(args) | |||
{ | |||
} | |||
protected override void build(Tensors inputs) | |||
{ | |||
} | |||
} | |||
} |
@@ -2,9 +2,10 @@ | |||
using System.Collections.Generic; | |||
using System.ComponentModel; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Layers | |||
namespace Tensorflow.Keras.Layers.Rnn | |||
{ | |||
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | |||
{ |
@@ -1,14 +0,0 @@ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class SimpleRNN : RNN | |||
{ | |||
public SimpleRNN(RNNArgs args) : base(args) | |||
{ | |||
} | |||
} | |||
} |
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras | |||
/// <param name="width"></param> | |||
/// <param name="interpolation"></param> | |||
/// <returns></returns> | |||
public Resizing Resizing(int height, int width, string interpolation = "bilinear") | |||
public ILayer Resizing(int height, int width, string interpolation = "bilinear") | |||
=> new Resizing(new ResizingArgs | |||
{ | |||
Height = height, | |||
@@ -5,7 +5,7 @@ using Tensorflow.Keras.Preprocessings; | |||
namespace Tensorflow.Keras | |||
{ | |||
public partial class Preprocessing | |||
public partial class Preprocessing : IPreprocessing | |||
{ | |||
public Sequence sequence => new Sequence(); | |||
public DatasetUtils dataset_utils => new DatasetUtils(); | |||
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras | |||
private static TextApi _text = new TextApi(); | |||
public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null, | |||
public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null, | |||
string split = "whitespace", | |||
int max_tokens = -1, | |||
string output_mode = "int", | |||
@@ -3,11 +3,11 @@ | |||
<PropertyGroup> | |||
<TargetFramework>netstandard2.0</TargetFramework> | |||
<AssemblyName>Tensorflow.Keras</AssemblyName> | |||
<LangVersion>9.0</LangVersion> | |||
<LangVersion>10.0</LangVersion> | |||
<Nullable>enable</Nullable> | |||
<RootNamespace>Tensorflow.Keras</RootNamespace> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
<Version>0.7.0</Version> | |||
<Version>0.10.0</Version> | |||
<Authors>Haiping Chen</Authors> | |||
<Product>Keras for .NET</Product> | |||
<Copyright>Apache 2.0, Haiping Chen 2021</Copyright> | |||
@@ -37,9 +37,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<RepositoryType>Git</RepositoryType> | |||
<SignAssembly>true</SignAssembly> | |||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
<AssemblyVersion>0.7.0.0</AssemblyVersion> | |||
<FileVersion>0.7.0.0</FileVersion> | |||
<AssemblyVersion>0.10.0.0</AssemblyVersion> | |||
<FileVersion>0.10.0.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
<Configurations>Debug;Release;GPU</Configurations> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
@@ -47,6 +48,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<AllowUnsafeBlocks>false</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'"> | |||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||
<AllowUnsafeBlocks>false</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
<AllowUnsafeBlocks>false</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
@@ -55,6 +61,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'"> | |||
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | |||
<DefineConstants /> | |||
</PropertyGroup> | |||
@@ -134,7 +134,7 @@ namespace Tensorflow.Keras | |||
/// <param name="data_format"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor max_pooling2d(Tensor inputs, | |||
public Tensor MaxPooling2D(Tensor inputs, | |||
int[] pool_size, | |||
int[] strides, | |||
string padding = "valid", | |||
@@ -0,0 +1,16 @@ | |||
{ | |||
// Use IntelliSense to learn about possible attributes. | |||
// Hover to view descriptions of existing attributes. | |||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | |||
"version": "0.2.0", | |||
"configurations": [ | |||
{ | |||
"name": "Python: Current File", | |||
"type": "python", | |||
"request": "launch", | |||
"program": "${file}", | |||
"console": "integratedTerminal", | |||
"justMyCode": false | |||
} | |||
] | |||
} |
@@ -0,0 +1,15 @@ | |||
import numpy as np | |||
import tensorflow as tf | |||
# tf.experimental.numpy | |||
inputs = np.random.random([32, 10, 8]).astype(np.float32) | |||
simple_rnn = tf.keras.layers.SimpleRNN(4) | |||
output = simple_rnn(inputs) # The output has shape `[32, 4]`. | |||
simple_rnn = tf.keras.layers.SimpleRNN( | |||
4, return_sequences=True, return_state=True) | |||
# whole_sequence_output has shape `[32, 10, 4]`. | |||
# final_state has shape `[32, 4]`. | |||
whole_sequence_output, final_state = simple_rnn(inputs) |
@@ -83,7 +83,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||
{ 2.5f, 2.6f, 2.7f, 2.8f }, | |||
{ 3.5f, 3.6f, 3.7f, 3.8f } | |||
} }, dtype: np.float32); | |||
var attention_layer = keras.layers.Attention(); | |||
var attention_layer = (Attention)keras.layers.Attention(); | |||
//attention_layer.build(((1, 2, 4), (1, 3, 4))); | |||
var actual = attention_layer._calculate_scores(query: q, key: k); | |||
// Expected tensor of shape [1, 2, 3]. | |||
@@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||
{ 2.5f, 2.6f, 2.7f, 2.8f }, | |||
{ 3.5f, 3.6f, 3.7f, 3.8f } | |||
} }, dtype: np.float32); | |||
var attention_layer = keras.layers.Attention(score_mode: "concat"); | |||
var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat"); | |||
//attention_layer.concat_score_weight = 1; | |||
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() { | |||
Name = "concat_score_weight", | |||
@@ -148,10 +148,9 @@ namespace TensorFlowNET.Keras.UnitTest | |||
} | |||
[TestMethod] | |||
[Ignore] | |||
public void SimpleRNN() | |||
{ | |||
var inputs = np.random.rand(32, 10, 8).astype(np.float32); | |||
var inputs = np.random.random((32, 10, 8)).astype(np.float32); | |||
var simple_rnn = keras.layers.SimpleRNN(4); | |||
var output = simple_rnn.Apply(inputs); | |||
Assert.AreEqual((32, 4), output.shape); | |||
@@ -4,7 +4,7 @@ | |||
<TargetFramework>net6.0</TargetFramework> | |||
<IsPackable>false</IsPackable> | |||
<LangVersion>11.0</LangVersion> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
@@ -11,7 +11,7 @@ | |||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
<LangVersion>9.0</LangVersion> | |||
<LangVersion>11.0</LangVersion> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||