@@ -26,7 +26,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
#### Download pre-built package
-[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)
+[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.10.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.10.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.10.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.10.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.10.0.zip)
@@ -35,6 +35,6 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows, it will only package the Windows binaries.
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
-2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
+2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.10.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
@@ -10,6 +10,9 @@ namespace Tensorflow
            var diag = new Diagnostician();
            // diag.Diagnose(@"D:\memory.txt");
+           var rnn = new SimpleRnnTest();
+           rnn.Run();
            // this class is used to explore new features.
            var exploring = new Exploring();
            // exploring.Run();
@@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow
{
    public class SimpleRnnTest
    {
        public void Run()
        {
            tf.keras = new KerasInterface();
            var inputs = np.random.random((32, 10, 8)).astype(np.float32);
            var simple_rnn = tf.keras.layers.SimpleRNN(4);
            var output = simple_rnn.Apply(inputs);  // The output has shape `[32, 4]`.
            if (output.shape == (32, 4))
            {
            }
            /*simple_rnn = tf.keras.layers.SimpleRNN(
                4, return_sequences = True, return_state = True)
            # whole_sequence_output has shape `[32, 10, 4]`.
            # final_state has shape `[32, 4]`.
            whole_sequence_output, final_state = simple_rnn(inputs)*/
        }
    }
}
@@ -6,7 +6,7 @@
    <RootNamespace>Tensorflow</RootNamespace>
    <AssemblyName>Tensorflow</AssemblyName>
    <Platforms>AnyCPU;x64</Platforms>
-   <LangVersion>9.0</LangVersion>
+   <LangVersion>11.0</LangVersion>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -20,7 +20,7 @@
  </PropertyGroup>
  <ItemGroup>
-   <PackageReference Include="SciSharp.TensorFlow.Redist-Windows-GPU" Version="2.7.0" />
+   <PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.10.0" />
  </ItemGroup>
  <ItemGroup>
@@ -1,22 +0,0 @@
namespace Tensorflow.Keras.ArgsDefinition
{
    public class LSTMArgs : RNNArgs
    {
        public int Units { get; set; }
        public Activation Activation { get; set; }
        public Activation RecurrentActivation { get; set; }
        public IInitializer KernelInitializer { get; set; }
        public IInitializer RecurrentInitializer { get; set; }
        public IInitializer BiasInitializer { get; set; }
        public bool UnitForgetBias { get; set; }
        public float Dropout { get; set; }
        public float RecurrentDropout { get; set; }
        public int Implementation { get; set; }
        public bool ReturnSequences { get; set; }
        public bool ReturnState { get; set; }
        public bool GoBackwards { get; set; }
        public bool Stateful { get; set; }
        public bool TimeMajor { get; set; }
        public bool Unroll { get; set; }
    }
}
@@ -0,0 +1,12 @@
using Tensorflow.Keras.ArgsDefinition.Rnn;
namespace Tensorflow.Keras.ArgsDefinition.Lstm
{
    public class LSTMArgs : RNNArgs
    {
        public bool UnitForgetBias { get; set; }
        public float Dropout { get; set; }
        public float RecurrentDropout { get; set; }
        public int Implementation { get; set; }
    }
}
@@ -1,4 +1,4 @@
-namespace Tensorflow.Keras.ArgsDefinition
+namespace Tensorflow.Keras.ArgsDefinition.Lstm
{
    public class LSTMCellArgs : LayerArgs
    {
@@ -1,21 +0,0 @@
using System.Collections.Generic;
namespace Tensorflow.Keras.ArgsDefinition
{
    public class RNNArgs : LayerArgs
    {
        public interface IRnnArgCell : ILayer
        {
            object state_size { get; }
        }
        public IRnnArgCell Cell { get; set; } = null;
        public bool ReturnSequences { get; set; } = false;
        public bool ReturnState { get; set; } = false;
        public bool GoBackwards { get; set; } = false;
        public bool Stateful { get; set; } = false;
        public bool Unroll { get; set; } = false;
        public bool TimeMajor { get; set; } = false;
        public Dictionary<string, object> Kwargs { get; set; } = null;
    }
}
@@ -0,0 +1,45 @@
using System.Collections.Generic;
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
    public class RNNArgs : LayerArgs
    {
        public interface IRnnArgCell : ILayer
        {
            object state_size { get; }
        }
        public IRnnArgCell Cell { get; set; } = null;
        public bool ReturnSequences { get; set; } = false;
        public bool ReturnState { get; set; } = false;
        public bool GoBackwards { get; set; } = false;
        public bool Stateful { get; set; } = false;
        public bool Unroll { get; set; } = false;
        public bool TimeMajor { get; set; } = false;
        public Dictionary<string, object> Kwargs { get; set; } = null;
        public int Units { get; set; }
        public Activation Activation { get; set; }
        public Activation RecurrentActivation { get; set; }
        public bool UseBias { get; set; } = true;
        public IInitializer KernelInitializer { get; set; }
        public IInitializer RecurrentInitializer { get; set; }
        public IInitializer BiasInitializer { get; set; }
        // kernel_regularizer=None,
        // recurrent_regularizer=None,
        // bias_regularizer=None,
        // activity_regularizer=None,
        // kernel_constraint=None,
        // recurrent_constraint=None,
        // bias_constraint=None,
        // dropout=0.,
        // recurrent_dropout=0.,
        // return_sequences=False,
        // return_state=False,
        // go_backwards=False,
        // stateful=False,
        // unroll=False,
        // **kwargs):
    }
}
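With the split above, the common recurrent options (`Units`, `ReturnSequences`, the initializers, and so on) live on `RNNArgs`, and `LSTMArgs`/`SimpleRNNArgs` only add what is specific to them. A minimal sketch of how these argument objects could be populated directly; the concrete values are illustrative only, and in normal use they are built by the `LayersApi` factory methods shown later in this diff:

```csharp
using Tensorflow.Keras.ArgsDefinition.Lstm;
using Tensorflow.Keras.ArgsDefinition.Rnn;

// Shared recurrent options come from RNNArgs; LSTM-specific knobs stay on LSTMArgs.
var lstmArgs = new LSTMArgs
{
    Units = 4,                // inherited from RNNArgs
    ReturnSequences = true,   // inherited from RNNArgs
    Dropout = 0.1f,           // LSTM-specific
    UnitForgetBias = true     // LSTM-specific
};

// SimpleRNNArgs adds nothing of its own; everything comes from RNNArgs.
var rnnArgs = new SimpleRNNArgs
{
    Units = 4,
    UseBias = true
};
```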
@@ -0,0 +1,7 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
    public class SimpleRNNArgs : RNNArgs
    {
    }
}
@@ -1,6 +1,6 @@
using System.Collections.Generic;
-namespace Tensorflow.Keras.ArgsDefinition
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
    public class StackedRNNCellsArgs : LayerArgs
    {
@@ -1,30 +0,0 @@
namespace Tensorflow.Keras.ArgsDefinition
{
    public class SimpleRNNArgs : RNNArgs
    {
        public int Units { get; set; }
        public Activation Activation { get; set; }
        // units,
        // activation='tanh',
        // use_bias=True,
        // kernel_initializer='glorot_uniform',
        // recurrent_initializer='orthogonal',
        // bias_initializer='zeros',
        // kernel_regularizer=None,
        // recurrent_regularizer=None,
        // bias_regularizer=None,
        // activity_regularizer=None,
        // kernel_constraint=None,
        // recurrent_constraint=None,
        // bias_constraint=None,
        // dropout=0.,
        // recurrent_dropout=0.,
        // return_sequences=False,
        // return_state=False,
        // go_backwards=False,
        // stateful=False,
        // unroll=False,
        // **kwargs):
    }
}
@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Layers;
namespace Tensorflow.Keras
{
    public interface IKerasApi
    {
        public ILayersApi layers { get; }
    }
}
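`IKerasApi` is what lets the core binding expose `tf.keras` without referencing the concrete Keras types directly. A short sketch of the intended wiring, based on the `tf.keras` property and the `SimpleRnnTest` example added elsewhere in this PR; the layer call is illustrative only:

```csharp
using Tensorflow.Keras;
using Tensorflow.Keras.Layers;
using static Tensorflow.Binding;

// Attach the concrete Keras implementation to the core binding once,
// then consume it only through the new interfaces.
tf.keras = new KerasInterface();

ILayersApi layers = tf.keras.layers;
ILayer dense = layers.Dense(10, activation: "relu");
```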
@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Keras
{
    public interface IPreprocessing
    {
        public ILayer Resizing(int height, int width, string interpolation = "bilinear");
        public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null,
            string split = "whitespace",
            int max_tokens = -1,
            string output_mode = "int",
            int output_sequence_length = -1);
    }
}
@@ -0,0 +1,20 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;
using Tensorflow.Operations.Activation;
namespace Tensorflow.Keras.Layers
{
    public partial interface ILayersApi
    {
        public ILayer ELU(float alpha = 0.1f);
        public ILayer SELU();
        public ILayer Softmax(Axis axis);
        public ILayer Softplus();
        public ILayer HardSigmoid();
        public ILayer Softsign();
        public ILayer Swish();
        public ILayer Tanh();
        public ILayer Exponential();
    }
}
@@ -0,0 +1,28 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;
namespace Tensorflow.Keras.Layers
{
    public partial interface ILayersApi
    {
        public ILayer Attention(bool use_scale = false,
            string score_mode = "dot",
            bool causal = false,
            float dropout = 0f);
        public ILayer MultiHeadAttention(int num_heads,
            int key_dim,
            int? value_dim = null,
            float dropout = 0f,
            bool use_bias = true,
            Shape output_shape = null,
            Shape attention_axes = null,
            IInitializer kernel_initializer = null,
            IInitializer bias_initializer = null,
            IRegularizer kernel_regularizer = null,
            IRegularizer bias_regularizer = null,
            IRegularizer activity_regularizer = null,
            Action kernel_constraint = null,
            Action bias_constraint = null);
    }
}
@@ -0,0 +1,13 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;
namespace Tensorflow.Keras.Layers
{
    public partial interface ILayersApi
    {
        public ILayer Cropping1D(NDArray cropping);
        public ILayer Cropping2D(NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last);
        public ILayer Cropping3D(NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last);
    }
}
@@ -0,0 +1,10 @@
using System;
using Tensorflow.NumPy;
namespace Tensorflow.Keras.Layers
{
    public partial interface ILayersApi
    {
        public ILayer Concatenate(int axis = -1);
    }
}
@@ -0,0 +1,18 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.NumPy;
namespace Tensorflow.Keras.Layers
{
    public partial interface ILayersApi
    {
        public ILayer Reshape(Shape target_shape);
        public ILayer Reshape(object[] target_shape);
        public ILayer UpSampling2D(Shape size = null,
            string data_format = null,
            string interpolation = "nearest");
        public ILayer ZeroPadding2D(NDArray padding);
    }
}
@@ -0,0 +1,169 @@
using System;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
namespace Tensorflow.Keras.Layers
{
    public partial interface ILayersApi
    {
        public IPreprocessing preprocessing { get; }
        public ILayer Add();
        public ILayer AveragePooling2D(Shape pool_size = null,
            Shape strides = null,
            string padding = "valid",
            string data_format = null);
        public ILayer BatchNormalization(int axis = -1,
            float momentum = 0.99f,
            float epsilon = 0.001f,
            bool center = true,
            bool scale = true,
            IInitializer beta_initializer = null,
            IInitializer gamma_initializer = null,
            IInitializer moving_mean_initializer = null,
            IInitializer moving_variance_initializer = null,
            bool trainable = true,
            string name = null,
            bool renorm = false,
            float renorm_momentum = 0.99f);
        public ILayer Conv1D(int filters,
            Shape kernel_size,
            int strides = 1,
            string padding = "valid",
            string data_format = "channels_last",
            int dilation_rate = 1,
            int groups = 1,
            string activation = null,
            bool use_bias = true,
            string kernel_initializer = "glorot_uniform",
            string bias_initializer = "zeros");
        public ILayer Conv2D(int filters,
            Shape kernel_size = null,
            Shape strides = null,
            string padding = "valid",
            string data_format = null,
            Shape dilation_rate = null,
            int groups = 1,
            Activation activation = null,
            bool use_bias = true,
            IInitializer kernel_initializer = null,
            IInitializer bias_initializer = null,
            IRegularizer kernel_regularizer = null,
            IRegularizer bias_regularizer = null,
            IRegularizer activity_regularizer = null);
        public ILayer Conv2D(int filters,
            Shape kernel_size = null,
            Shape strides = null,
            string padding = "valid",
            string data_format = null,
            Shape dilation_rate = null,
            int groups = 1,
            string activation = null,
            bool use_bias = true,
            string kernel_initializer = "glorot_uniform",
            string bias_initializer = "zeros");
        public ILayer Dense(int units);
        public ILayer Dense(int units,
            string activation = null,
            Shape input_shape = null);
        public ILayer Dense(int units,
            Activation activation = null,
            IInitializer kernel_initializer = null,
            bool use_bias = true,
            IInitializer bias_initializer = null,
            Shape input_shape = null);
        public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null);
        public ILayer Embedding(int input_dim,
            int output_dim,
            IInitializer embeddings_initializer = null,
            bool mask_zero = false,
            Shape input_shape = null,
            int input_length = -1);
        public ILayer EinsumDense(string equation,
            Shape output_shape,
            string bias_axes,
            Activation activation = null,
            IInitializer kernel_initializer = null,
            IInitializer bias_initializer = null,
            IRegularizer kernel_regularizer = null,
            IRegularizer bias_regularizer = null,
            IRegularizer activity_regularizer = null,
            Action kernel_constraint = null,
            Action bias_constraint = null);
        public ILayer Flatten(string data_format = null);
        public ILayer GlobalAveragePooling1D(string data_format = "channels_last");
        public ILayer GlobalAveragePooling2D();
        public ILayer GlobalAveragePooling2D(string data_format = "channels_last");
        public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
        public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
        public Tensors Input(Shape shape,
            string name = null,
            bool sparse = false,
            bool ragged = false);
        public ILayer InputLayer(Shape input_shape,
            string name = null,
            bool sparse = false,
            bool ragged = false);
        public ILayer LayerNormalization(Axis? axis,
            float epsilon = 1e-3f,
            bool center = true,
            bool scale = true,
            IInitializer beta_initializer = null,
            IInitializer gamma_initializer = null);
        public ILayer LeakyReLU(float alpha = 0.3f);
        public ILayer LSTM(int units,
            Activation activation = null,
            Activation recurrent_activation = null,
            bool use_bias = true,
            IInitializer kernel_initializer = null,
            IInitializer recurrent_initializer = null,
            IInitializer bias_initializer = null,
            bool unit_forget_bias = true,
            float dropout = 0f,
            float recurrent_dropout = 0f,
            int implementation = 2,
            bool return_sequences = false,
            bool return_state = false,
            bool go_backwards = false,
            bool stateful = false,
            bool time_major = false,
            bool unroll = false);
        public ILayer MaxPooling1D(int? pool_size = null,
            int? strides = null,
            string padding = "valid",
            string data_format = null);
        public ILayer MaxPooling2D(Shape pool_size = null,
            Shape strides = null,
            string padding = "valid",
            string data_format = null);
        public ILayer Permute(int[] dims);
        public ILayer Rescaling(float scale,
            float offset = 0,
            Shape input_shape = null);
        public ILayer SimpleRNN(int units,
            string activation = "tanh",
            string kernel_initializer = "glorot_uniform",
            string recurrent_initializer = "orthogonal",
            string bias_initializer = "zeros");
        public ILayer Subtract();
    }
}
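Because every factory on `ILayersApi` now returns `ILayer`, user code can be written against the interfaces alone. A hedged sketch of chaining a few of these factories, mirroring the `Apply` pattern used by `SimpleRnnTest` above; the shapes, values, and chosen layers are illustrative only:

```csharp
using Tensorflow.Keras.Layers;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

ILayersApi layers = tf.keras.layers;

// Build layers purely through the interface; the concrete layer types stay hidden.
ILayer flatten = layers.Flatten();
ILayer dense = layers.Dense(10, activation: "relu");

// Apply them to a small random batch.
var x = np.random.random((8, 28, 28)).astype(np.float32);
var h = flatten.Apply(x);
var y = dense.Apply(h);
```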
@@ -20,11 +20,11 @@ namespace Tensorflow.NumPy
            Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize);
        }
-       public NDArray rand(params int[] shape)
-           => throw new NotImplementedException("");
+       public NDArray random(Shape size)
+           => uniform(low: 0, high: 1, size: size);
        [AutoNumPy]
-       public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType dtype = TF_DataType.TF_INT32)
+       public NDArray randint(int low, int? high = null, Shape? size = null, TF_DataType dtype = TF_DataType.TF_INT32)
        {
            if(high == null)
            {
@@ -41,11 +41,11 @@ namespace Tensorflow.NumPy
            => new NDArray(random_ops.random_normal(shape ?? Shape.Scalar));
        [AutoNumPy]
-       public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null)
+       public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape? size = null)
            => new NDArray(random_ops.random_normal(size ?? Shape.Scalar, mean: loc, stddev: scale));
        [AutoNumPy]
-       public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape size = null)
+       public NDArray uniform(float low = 0.0f, float high = 1.0f, Shape? size = null)
            => new NDArray(random_ops.random_uniform(size ?? Shape.Scalar, low, high));
    }
}
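The new `random(Shape size)` helper is a thin wrapper over `uniform`, and the other generators now take an optional (nullable) `Shape`. A short usage sketch of this NumPy-style surface, assuming the tuple-to-`Shape` conversions used by `SimpleRnnTest`; the shapes and ranges are illustrative:

```csharp
using Tensorflow.NumPy;

var a = np.random.random((32, 10, 8)).astype(np.float32);      // uniform in [0, 1)
var b = np.random.uniform(low: -1f, high: 1f, size: (3, 3));   // uniform in [-1, 1)
var c = np.random.normal(loc: 0f, scale: 1f, size: (2, 4));    // standard normal
var d = np.random.randint(0, high: 10, size: (2, 3));          // integers in [0, 10)
```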
@@ -18,6 +18,7 @@ using System;
using System.Collections.Generic;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;
using Tensorflow.Util;
@@ -5,8 +5,8 @@
  <AssemblyName>Tensorflow.Binding</AssemblyName>
  <RootNamespace>Tensorflow</RootNamespace>
  <TargetTensorFlow>2.2.0</TargetTensorFlow>
- <Version>0.70.2</Version>
- <LangVersion>9.0</LangVersion>
+ <Version>0.100.0</Version>
+ <LangVersion>10.0</LangVersion>
  <Nullable>enable</Nullable>
  <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
  <Company>SciSharp STACK</Company>
@@ -20,9 +20,9 @@
  <Description>Google's TensorFlow full binding in .NET Standard.
Building, training and inferring deep learning models.
https://tensorflownet.readthedocs.io</Description>
- <AssemblyVersion>0.70.1.0</AssemblyVersion>
+ <AssemblyVersion>0.100.0.0</AssemblyVersion>
  <PackageReleaseNotes>
-tf.net 0.70.x and above are based on tensorflow native 2.7.0
+tf.net 0.100.x and above are based on tensorflow native 2.10.0
* Eager Mode is added finally.
* tf.keras is partially working.
@@ -35,14 +35,17 @@ https://tensorflownet.readthedocs.io</Description>
tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.
tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.
-tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.</PackageReleaseNotes>
- <FileVersion>0.70.1.0</FileVersion>
+tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
+tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
+ </PackageReleaseNotes>
+ <FileVersion>0.100.0.0</FileVersion>
  <PackageLicenseFile>LICENSE</PackageLicenseFile>
  <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
  <SignAssembly>true</SignAssembly>
  <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
  <Platforms>AnyCPU;x64</Platforms>
  <PackageId>TensorFlow.NET</PackageId>
+ <Configurations>Debug;Release;GPU</Configurations>
  </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -51,6 +54,12 @@ https://tensorflownet.readthedocs.io</Description>
    <PlatformTarget>AnyCPU</PlatformTarget>
  </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'">
+   <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+   <DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE_1</DefineConstants>
+   <PlatformTarget>AnyCPU</PlatformTarget>
+ </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
    <DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants>
@@ -58,6 +67,13 @@ https://tensorflownet.readthedocs.io</Description>
    <DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
  </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'">
+   <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
+   <DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants>
+   <PlatformTarget>x64</PlatformTarget>
+   <DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
+ </PropertyGroup>
  <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
  </PropertyGroup>
@@ -20,6 +20,7 @@ using System.Threading;
using Tensorflow.Contexts;
using Tensorflow.Eager;
using Tensorflow.Gradients;
+using Tensorflow.Keras;
namespace Tensorflow
{
@@ -51,6 +52,8 @@ namespace Tensorflow
        ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner());
        public IEagerRunner Runner => _runner.Value;
+       public IKerasApi keras { get; set; }
        public tensorflow()
        {
            Logger = new LoggerConfiguration()
@@ -2,6 +2,9 @@
namespace Tensorflow
{
+    /// <summary>
+    /// Deprecated, use tf.keras instead.
+    /// </summary>
    public static class KerasApi
    {
        public static KerasInterface keras { get; } = new KerasInterface();
@@ -10,18 +10,17 @@ using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
-using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using System.Threading;
namespace Tensorflow.Keras
{
-    public class KerasInterface
+    public class KerasInterface : IKerasApi
    {
        public KerasDataset datasets { get; } = new KerasDataset();
        public Initializers initializers { get; } = new Initializers();
        public Regularizers regularizers { get; } = new Regularizers();
-       public LayersApi layers { get; } = new LayersApi();
+       public ILayersApi layers { get; } = new LayersApi();
        public LossesApi losses { get; } = new LossesApi();
        public Activations activations { get; } = new Activations();
        public Preprocessing preprocessing { get; } = new Preprocessing();
@@ -7,16 +7,16 @@ using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Layers {
    public partial class LayersApi {
-       public ELU ELU ( float alpha = 0.1f )
+       public ILayer ELU ( float alpha = 0.1f )
            => new ELU(new ELUArgs { Alpha = alpha });
-       public SELU SELU ()
+       public ILayer SELU ()
            => new SELU(new LayerArgs { });
-       public Softmax Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis });
-       public Softplus Softplus () => new Softplus(new LayerArgs { });
-       public HardSigmoid HardSigmoid () => new HardSigmoid(new LayerArgs { });
-       public Softsign Softsign () => new Softsign(new LayerArgs { });
-       public Swish Swish () => new Swish(new LayerArgs { });
-       public Tanh Tanh () => new Tanh(new LayerArgs { });
-       public Exponential Exponential () => new Exponential(new LayerArgs { });
+       public ILayer Softmax ( Axis axis ) => new Softmax(new SoftmaxArgs { axis = axis });
+       public ILayer Softplus () => new Softplus(new LayerArgs { });
+       public ILayer HardSigmoid () => new HardSigmoid(new LayerArgs { });
+       public ILayer Softsign () => new Softsign(new LayerArgs { });
+       public ILayer Swish () => new Swish(new LayerArgs { });
+       public ILayer Tanh () => new Tanh(new LayerArgs { });
+       public ILayer Exponential () => new Exponential(new LayerArgs { });
    }
}
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers
{
    public partial class LayersApi
    {
-       public Attention Attention(bool use_scale = false,
+       public ILayer Attention(bool use_scale = false,
            string score_mode = "dot",
            bool causal = false,
            float dropout = 0f) =>
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers
            causal = causal,
            dropout = dropout
        });
-       public MultiHeadAttention MultiHeadAttention(int num_heads,
+       public ILayer MultiHeadAttention(int num_heads,
            int key_dim,
            int? value_dim = null,
            float dropout = 0f,
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Layers {
        /// Cropping layer for 1D input
        /// </summary>
        /// <param name="cropping">cropping size</param>
-       public Cropping1D Cropping1D ( NDArray cropping )
+       public ILayer Cropping1D ( NDArray cropping )
            => new Cropping1D(new CroppingArgs {
                cropping = cropping
            });
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers {
        /// <summary>
        /// Cropping layer for 2D input <br/>
        /// </summary>
-       public Cropping2D Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last )
+       public ILayer Cropping2D ( NDArray cropping, Cropping2DArgs.DataFormat data_format = Cropping2DArgs.DataFormat.channels_last )
            => new Cropping2D(new Cropping2DArgs {
                cropping = cropping,
                data_format = data_format
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers {
        /// <summary>
        /// Cropping layer for 3D input <br/>
        /// </summary>
-       public Cropping3D Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last )
+       public ILayer Cropping3D ( NDArray cropping, Cropping3DArgs.DataFormat data_format = Cropping3DArgs.DataFormat.channels_last )
            => new Cropping3D(new Cropping3DArgs {
                cropping = cropping,
                data_format = data_format
@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers
        /// </summary>
        /// <param name="axis">Axis along which to concatenate.</param>
        /// <returns></returns>
-       public Concatenate Concatenate(int axis = -1)
+       public ILayer Concatenate(int axis = -1)
            => new Concatenate(new MergeArgs
            {
                Axis = axis
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
        /// </summary>
        /// <param name="padding"></param>
        /// <returns></returns>
-       public ZeroPadding2D ZeroPadding2D ( NDArray padding )
+       public ILayer ZeroPadding2D ( NDArray padding )
            => new ZeroPadding2D(new ZeroPadding2DArgs {
                Padding = padding
            });
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers {
        /// <param name="data_format"></param>
        /// <param name="interpolation"></param>
        /// <returns></returns>
-       public UpSampling2D UpSampling2D ( Shape size = null,
+       public ILayer UpSampling2D ( Shape size = null,
            string data_format = null,
            string interpolation = "nearest" )
            => new UpSampling2D(new UpSampling2DArgs {
@@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers {
        /// <summary>
        /// Permutes the dimensions of the input according to a given pattern.
        /// </summary>
-       public Permute Permute ( int[] dims )
+       public ILayer Permute ( int[] dims )
            => new Permute(new PermuteArgs {
                dims = dims
            });
@@ -44,12 +44,12 @@ namespace Tensorflow.Keras.Layers {
        /// </summary>
        /// <param name="target_shape"></param>
        /// <returns></returns>
-       public Reshape Reshape ( Shape target_shape )
-           => new Reshape(new ReshapeArgs {
-               TargetShape = target_shape
-           });
+       public ILayer Reshape ( Shape target_shape )
+           => new Reshape(new ReshapeArgs {
+               TargetShape = target_shape
+           });
-       public Reshape Reshape ( object[] target_shape )
+       public ILayer Reshape ( object[] target_shape )
            => new Reshape(new ReshapeArgs {
                TargetShapeObjects = target_shape
            });
@@ -1,16 +1,18 @@
using System;
-using Tensorflow.NumPy;
+using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Lstm;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Layers.Lstm;
+using Tensorflow.Keras.Layers.Rnn;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Layers
{
-   public partial class LayersApi
+   public partial class LayersApi : ILayersApi
    {
-       public Preprocessing preprocessing { get; } = new Preprocessing();
+       public IPreprocessing preprocessing { get; } = new Preprocessing();
        /// <summary>
        /// Layer that normalizes its inputs.
@@ -38,7 +40,7 @@ namespace Tensorflow.Keras.Layers
        /// Note that momentum is still applied to get the means and variances for inference.
        /// </param>
        /// <returns>Tensor of the same shape as input.</returns>
-       public BatchNormalization BatchNormalization(int axis = -1,
+       public ILayer BatchNormalization(int axis = -1,
            float momentum = 0.99f,
            float epsilon = 0.001f,
            bool center = true,
@@ -84,7 +86,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="kernel_initializer">Initializer for the kernel weights matrix (see keras.initializers).</param>
        /// <param name="bias_initializer">Initializer for the bias vector (see keras.initializers).</param>
        /// <returns>A tensor of rank 3 representing activation(conv1d(inputs, kernel) + bias).</returns>
-       public Conv1D Conv1D(int filters,
+       public ILayer Conv1D(int filters,
            Shape kernel_size,
            int strides = 1,
            string padding = "valid",
@@ -131,7 +133,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="bias_regularizer">Regularizer function applied to the bias vector (see keras.regularizers).</param>
        /// <param name="activity_regularizer">Regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param>
        /// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns>
-       public Conv2D Conv2D(int filters,
+       public ILayer Conv2D(int filters,
            Shape kernel_size = null,
            Shape strides = null,
            string padding = "valid",
@@ -184,7 +186,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param>
        /// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param>
        /// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns>
-       public Conv2D Conv2D(int filters,
+       public ILayer Conv2D(int filters,
            Shape kernel_size = null,
            Shape strides = null,
            string padding = "valid",
@@ -228,7 +230,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="bias_regularizer">The name of the regularizer function applied to the bias vector (see keras.regularizers).</param>
        /// <param name="activity_regularizer">The name of the regularizer function applied to the output of the layer (its "activation") (see keras.regularizers).</param>
        /// <returns>A tensor of rank 4+ representing activation(conv2d(inputs, kernel) + bias).</returns>
-       public Conv2DTranspose Conv2DTranspose(int filters,
+       public ILayer Conv2DTranspose(int filters,
            Shape kernel_size = null,
            Shape strides = null,
            string output_padding = "valid",
@@ -270,7 +272,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="bias_initializer">Initializer for the bias vector.</param>
        /// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param>
        /// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns>
-       public Dense Dense(int units,
+       public ILayer Dense(int units,
            Activation activation = null,
            IInitializer kernel_initializer = null,
            bool use_bias = true,
@@ -294,7 +296,7 @@ namespace Tensorflow.Keras.Layers
        /// </summary>
        /// <param name="units">Positive integer, dimensionality of the output space.</param>
        /// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns>
-       public Dense Dense(int units)
+       public ILayer Dense(int units)
            => new Dense(new DenseArgs
            {
                Units = units,
@@ -312,7 +314,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="activation">Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: a(x) = x).</param>
        /// <param name="input_shape">N-D tensor with shape: (batch_size, ..., input_dim). The most common situation would be a 2D input with shape (batch_size, input_dim).</param>
        /// <returns>N-D tensor with shape: (batch_size, ..., units). For instance, for a 2D input with shape (batch_size, input_dim), the output would have shape (batch_size, units).</returns>
-       public Dense Dense(int units,
+       public ILayer Dense(int units,
            string activation = null,
            Shape input_shape = null)
            => new Dense(new DenseArgs
@@ -364,7 +366,7 @@ namespace Tensorflow.Keras.Layers
        }
-       public EinsumDense EinsumDense(string equation,
+       public ILayer EinsumDense(string equation,
            Shape output_shape,
            string bias_axes,
            Activation activation = null,
@@ -402,7 +404,7 @@ namespace Tensorflow.Keras.Layers
        /// </param>
        /// <param name="seed">An integer to use as random seed.</param>
        /// <returns></returns>
-       public Dropout Dropout(float rate, Shape noise_shape = null, int? seed = null)
+       public ILayer Dropout(float rate, Shape noise_shape = null, int? seed = null)
            => new Dropout(new DropoutArgs
            {
                Rate = rate,
@@ -421,7 +423,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="embeddings_initializer">Initializer for the embeddings matrix (see keras.initializers).</param>
        /// <param name="mask_zero"></param>
        /// <returns></returns>
-       public Embedding Embedding(int input_dim,
+       public ILayer Embedding(int input_dim,
            int output_dim,
            IInitializer embeddings_initializer = null,
            bool mask_zero = false,
@@ -446,7 +448,7 @@ namespace Tensorflow.Keras.Layers
        /// If you never set it, then it will be "channels_last".
        /// </param>
        /// <returns></returns>
-       public Flatten Flatten(string data_format = null)
+       public ILayer Flatten(string data_format = null)
            => new Flatten(new FlattenArgs
            {
                DataFormat = data_format
@@ -482,7 +484,7 @@ namespace Tensorflow.Keras.Layers
            return input_layer.InboundNodes[0].Outputs;
        }
-       public InputLayer InputLayer(Shape input_shape,
+       public ILayer InputLayer(Shape input_shape,
            string name = null,
            bool sparse = false,
            bool ragged = false)
@@ -502,7 +504,7 @@ namespace Tensorflow.Keras.Layers
        /// <param name="padding"></param>
        /// <param name="data_format"></param>
        /// <returns></returns>
-       public AveragePooling2D AveragePooling2D(Shape pool_size = null,
+       public ILayer AveragePooling2D(Shape pool_size = null,
            Shape strides = null,
            string padding = "valid",
            string data_format = null)
@@ -527,7 +529,7 @@ namespace Tensorflow.Keras.Layers
        /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
        /// </param>
        /// <returns></returns>
-       public MaxPooling1D MaxPooling1D(int? pool_size = null,
+       public ILayer MaxPooling1D(int? pool_size = null,
            int? strides = null,
            string padding = "valid",
            string data_format = null)
@@ -564,7 +566,7 @@ namespace Tensorflow.Keras.Layers
        /// It defaults to the image_data_format value found in your Keras config file at ~/.keras/keras.json.
        /// If you never set it, then it will be "channels_last"</param>
        /// <returns></returns>
-       public MaxPooling2D MaxPooling2D(Shape pool_size = null,
+       public ILayer MaxPooling2D(Shape pool_size = null,
            Shape strides = null,
            string padding = "valid",
            string data_format = null)
@@ -618,7 +620,7 @@ namespace Tensorflow.Keras.Layers
            return layer.Apply(inputs);
        }
-       public Layer LayerNormalization(Axis? axis,
+       public ILayer LayerNormalization(Axis? axis,
            float epsilon = 1e-3f,
            bool center = true,
            bool scale = true,
@@ -638,45 +640,30 @@
        /// </summary>
        /// <param name="alpha">Negative slope coefficient.</param>
        /// <returns></returns>
-       public Layer LeakyReLU(float alpha = 0.3f)
+       public ILayer LeakyReLU(float alpha = 0.3f)
            => new LeakyReLu(new LeakyReLuArgs
            {
                Alpha = alpha
            });
-       /// <summary>
-       /// Fully-connected RNN where the output is to be fed back to input.
-       /// </summary>
-       /// <param name="units">Positive integer, dimensionality of the output space.</param>
-       /// <returns></returns>
-       public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh");
-       /// <summary>
-       /// Fully-connected RNN where the output is to be fed back to input.
-       /// </summary>
-       /// <param name="units">Positive integer, dimensionality of the output space.</param>
-       /// <param name="activation">Activation function to use. If you pass null, no activation is applied (ie. "linear" activation: a(x) = x).</param>
-       /// <returns></returns>
-       public Layer SimpleRNN(int units,
-           Activation activation = null)
-           => new SimpleRNN(new SimpleRNNArgs
-           {
-               Units = units,
-               Activation = activation
-           });
        /// <summary>
        ///
        /// </summary>
        /// <param name="units">Positive integer, dimensionality of the output space.</param>
        /// <param name="activation">The name of the activation function to use. Default: hyperbolic tangent (tanh)..</param>
        /// <returns></returns>
-       public Layer SimpleRNN(int units,
-           string activation = "tanh")
+       public ILayer SimpleRNN(int units,
+           string activation = "tanh",
+           string kernel_initializer = "glorot_uniform",
+           string recurrent_initializer = "orthogonal",
+           string bias_initializer = "zeros")
            => new SimpleRNN(new SimpleRNNArgs
            {
                Units = units,
-               Activation = GetActivationByName(activation)
+               Activation = GetActivationByName(activation),
+               KernelInitializer = GetInitializerByName(kernel_initializer),
+               RecurrentInitializer = GetInitializerByName(recurrent_initializer),
+               BiasInitializer = GetInitializerByName(bias_initializer)
            });
        /// <summary>
@@ -706,7 +693,7 @@
        /// although it tends to be more memory-intensive. Unrolling is only suitable for short sequences.
        /// </param>
        /// <returns></returns>
-       public Layer LSTM(int units,
+       public ILayer LSTM(int units,
            Activation activation = null,
            Activation recurrent_activation = null,
            bool use_bias = true,
@@ -749,7 +736,7 @@
        /// <param name="offset"></param>
        /// <param name="input_shape"></param>
        /// <returns></returns>
-       public Rescaling Rescaling(float scale,
+       public ILayer Rescaling(float scale,
            float offset = 0,
            Shape input_shape = null)
            => new Rescaling(new RescalingArgs
@@ -763,21 +750,21 @@
        ///
        /// </summary>
        /// <returns></returns>
-       public Add Add()
+       public ILayer Add()
            => new Add(new MergeArgs { });
        /// <summary>
        ///
        /// </summary>
        /// <returns></returns>
-       public Subtract Subtract()
+       public ILayer Subtract()
            => new Subtract(new MergeArgs { });
        /// <summary>
        /// Global max pooling operation for spatial data.
        /// </summary>
        /// <returns></returns>
-       public GlobalAveragePooling2D GlobalAveragePooling2D()
+       public ILayer GlobalAveragePooling2D()
            => new GlobalAveragePooling2D(new Pooling2DArgs { });
        /// <summary>
@@ -787,7 +774,7 @@
        /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
        /// </param>
        /// <returns></returns>
-       public GlobalAveragePooling1D GlobalAveragePooling1D(string data_format = "channels_last")
+       public ILayer GlobalAveragePooling1D(string data_format = "channels_last")
            => new GlobalAveragePooling1D(new Pooling1DArgs { DataFormat = data_format });
        /// <summary>
@@ -796,7 +783,7 @@
        /// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs.
        /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param>
        /// <returns></returns>
-       public GlobalAveragePooling2D GlobalAveragePooling2D(string data_format = "channels_last")
+       public ILayer GlobalAveragePooling2D(string data_format = "channels_last")
            => new GlobalAveragePooling2D(new Pooling2DArgs { DataFormat = data_format });
        /// <summary>
@@ -807,7 +794,7 @@
        /// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs.
        /// channels_last corresponds to inputs with shape (batch, steps, features) while channels_first corresponds to inputs with shape (batch, features, steps).
        /// </param>
        /// <returns></returns>
-       public GlobalMaxPooling1D GlobalMaxPooling1D(string data_format = "channels_last")
+       public ILayer GlobalMaxPooling1D(string data_format = "channels_last")
            => new GlobalMaxPooling1D(new Pooling1DArgs { DataFormat = data_format });
        /// <summary>
@@ -816,7 +803,7 @@
        /// <param name="data_format">A string, one of channels_last (default) or channels_first. The ordering of the dimensions in the inputs.
        /// channels_last corresponds to inputs with shape (batch, height, width, channels) while channels_first corresponds to inputs with shape (batch, channels, height, width).</param>
/// <returns></returns> | /// <returns></returns> | ||||
public GlobalMaxPooling2D GlobalMaxPooling2D(string data_format = "channels_last") | |||||
public ILayer GlobalMaxPooling2D(string data_format = "channels_last") | |||||
=> new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format }); | => new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format }); | ||||
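The `data_format` parameter in the global pooling overloads above only decides which axes count as spatial; the channel axis always survives. A hedged sketch of the resulting shapes, using the layer accessors defined in this file (input values are random and purely illustrative):

```csharp
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

// channels_last, 2-D case: (batch, height, width, channels) -> (batch, channels).
var images = np.random.random((2, 6, 6, 3)).astype(np.float32);
var pooled2d = keras.layers.GlobalAveragePooling2D("channels_last").Apply(images);   // expected: (2, 3)

// channels_last, 1-D case: (batch, steps, features) -> (batch, features).
var sequences = np.random.random((2, 10, 8)).astype(np.float32);
var pooled1d = keras.layers.GlobalMaxPooling1D("channels_last").Apply(sequences);    // expected: (2, 8)
```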
@@ -848,6 +835,7 @@ namespace Tensorflow.Keras.Layers | |||||
"glorot_uniform" => tf.glorot_uniform_initializer, | "glorot_uniform" => tf.glorot_uniform_initializer, | ||||
"zeros" => tf.zeros_initializer, | "zeros" => tf.zeros_initializer, | ||||
"ones" => tf.ones_initializer, | "ones" => tf.ones_initializer, | ||||
"orthogonal" => tf.orthogonal_initializer, | |||||
_ => tf.glorot_uniform_initializer | _ => tf.glorot_uniform_initializer | ||||
}; | }; | ||||
} | } | ||||
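Note that the initializer lookup above is a plain name-to-initializer mapping with a silent fallback: `"orthogonal"` now resolves to `tf.orthogonal_initializer`, but any unrecognized string (a typo such as `"orthagonal"`, for instance) drops through to the default arm and quietly builds the layer with Glorot-uniform weights instead of failing.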
@@ -1,8 +1,9 @@ | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.ArgsDefinition.Lstm; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers.Rnn; | |||||
namespace Tensorflow.Keras.Layers | |||||
namespace Tensorflow.Keras.Layers.Lstm | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Long Short-Term Memory layer - Hochreiter 1997. | /// Long Short-Term Memory layer - Hochreiter 1997. |
@@ -1,7 +1,7 @@ | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.ArgsDefinition.Lstm; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
namespace Tensorflow.Keras.Layers | |||||
namespace Tensorflow.Keras.Layers.Lstm | |||||
{ | { | ||||
public class LSTMCell : Layer | public class LSTMCell : Layer | ||||
{ | { |
@@ -1,10 +1,12 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers.Lstm; | |||||
// from tensorflow.python.distribute import distribution_strategy_context as ds_context; | // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | ||||
namespace Tensorflow.Keras.Layers | |||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | { | ||||
public class RNN : Layer | public class RNN : Layer | ||||
{ | { | ||||
@@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Layers | |||||
private object _states = null; | private object _states = null; | ||||
private object constants_spec = null; | private object constants_spec = null; | ||||
private int _num_constants = 0; | private int _num_constants = 0; | ||||
protected IVariableV1 kernel; | |||||
protected IVariableV1 bias; | |||||
public RNN(RNNArgs args) : base(PreConstruct(args)) | public RNN(RNNArgs args) : base(PreConstruct(args)) | ||||
{ | { |
@@ -0,0 +1,31 @@ | |||||
using System.Data; | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Operations.Activation; | |||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | |||||
public class SimpleRNN : RNN | |||||
{ | |||||
SimpleRNNArgs args; | |||||
SimpleRNNCell cell; | |||||
public SimpleRNN(SimpleRNNArgs args) : base(args) | |||||
{ | |||||
this.args = args; | |||||
} | |||||
protected override void build(Tensors inputs) | |||||
{ | |||||
var input_shape = inputs.shape; | |||||
var input_dim = input_shape[-1]; | |||||
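// inputs arrive as (batch, time_steps, features); input_dim is the size of the
// feature axis, and the kernel below projects that axis onto args.Units outputs.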
kernel = add_weight("kernel", (input_shape[-1], args.Units), | |||||
initializer: args.KernelInitializer | |||||
//regularizer = self.kernel_regularizer, | |||||
//constraint = self.kernel_constraint, | |||||
//caching_device = default_caching_device, | |||||
); | |||||
} | |||||
} | |||||
} |
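The `build` above only creates the input-projection kernel of shape `(input_dim, Units)`; a complete SimpleRNN also needs a recurrent kernel of shape `(Units, Units)` and a bias of length `Units`, which this change does not add yet. For reference, the recurrence the layer ultimately has to compute is the standard Keras SimpleRNN step, sketched below in plain C# over float arrays; this illustrates the math only and is not code from the PR:

```csharp
using System;

// One SimpleRNN step: state_t = tanh(x_t · W + state_{t-1} · U + b)
// W: (features, units) kernel, U: (units, units) recurrent kernel, b: (units) bias.
static float[] Step(float[] x, float[] state, float[,] W, float[,] U, float[] b)
{
    int units = b.Length;
    var next = new float[units];
    for (int j = 0; j < units; j++)
    {
        float sum = b[j];
        for (int i = 0; i < x.Length; i++) sum += x[i] * W[i, j];          // input projection
        for (int i = 0; i < state.Length; i++) sum += state[i] * U[i, j];  // recurrent projection
        next[j] = MathF.Tanh(sum);
    }
    return next;
}
```

Iterating `Step` over the 10 time steps of a `(32, 10, 8)` batch and keeping only the final state produces the `(32, 4)` output shape that the SimpleRNN unit test further down asserts.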
@@ -0,0 +1,21 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | |||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | |||||
public class SimpleRNNCell : Layer | |||||
{ | |||||
public SimpleRNNCell(SimpleRNNArgs args) : base(args) | |||||
{ | |||||
} | |||||
protected override void build(Tensors inputs) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -2,9 +2,10 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.ComponentModel; | using System.ComponentModel; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
namespace Tensorflow.Keras.Layers | |||||
namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | { | ||||
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | ||||
{ | { |
@@ -1,14 +0,0 @@ | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
public class SimpleRNN : RNN | |||||
{ | |||||
public SimpleRNN(RNNArgs args) : base(args) | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras | |||||
/// <param name="width"></param> | /// <param name="width"></param> | ||||
/// <param name="interpolation"></param> | /// <param name="interpolation"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Resizing Resizing(int height, int width, string interpolation = "bilinear") | |||||
public ILayer Resizing(int height, int width, string interpolation = "bilinear") | |||||
=> new Resizing(new ResizingArgs | => new Resizing(new ResizingArgs | ||||
{ | { | ||||
Height = height, | Height = height, | ||||
@@ -5,7 +5,7 @@ using Tensorflow.Keras.Preprocessings; | |||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
{ | { | ||||
public partial class Preprocessing | |||||
public partial class Preprocessing : IPreprocessing | |||||
{ | { | ||||
public Sequence sequence => new Sequence(); | public Sequence sequence => new Sequence(); | ||||
public DatasetUtils dataset_utils => new DatasetUtils(); | public DatasetUtils dataset_utils => new DatasetUtils(); | ||||
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras | |||||
private static TextApi _text = new TextApi(); | private static TextApi _text = new TextApi(); | ||||
public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null, | |||||
public ILayer TextVectorization(Func<Tensor, Tensor> standardize = null, | |||||
string split = "whitespace", | string split = "whitespace", | ||||
int max_tokens = -1, | int max_tokens = -1, | ||||
string output_mode = "int", | string output_mode = "int", | ||||
@@ -3,11 +3,11 @@ | |||||
<PropertyGroup> | <PropertyGroup> | ||||
<TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
<AssemblyName>Tensorflow.Keras</AssemblyName> | <AssemblyName>Tensorflow.Keras</AssemblyName> | ||||
<LangVersion>9.0</LangVersion> | |||||
<LangVersion>10.0</LangVersion> | |||||
<Nullable>enable</Nullable> | <Nullable>enable</Nullable> | ||||
<RootNamespace>Tensorflow.Keras</RootNamespace> | <RootNamespace>Tensorflow.Keras</RootNamespace> | ||||
<Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
<Version>0.7.0</Version> | |||||
<Version>0.10.0</Version> | |||||
<Authors>Haiping Chen</Authors> | <Authors>Haiping Chen</Authors> | ||||
<Product>Keras for .NET</Product> | <Product>Keras for .NET</Product> | ||||
<Copyright>Apache 2.0, Haiping Chen 2021</Copyright> | <Copyright>Apache 2.0, Haiping Chen 2021</Copyright> | ||||
@@ -37,9 +37,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<RepositoryType>Git</RepositoryType> | <RepositoryType>Git</RepositoryType> | ||||
<SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | ||||
<AssemblyVersion>0.7.0.0</AssemblyVersion> | |||||
<FileVersion>0.7.0.0</FileVersion> | |||||
<AssemblyVersion>0.10.0.0</AssemblyVersion> | |||||
<FileVersion>0.10.0.0</FileVersion> | |||||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
<Configurations>Debug;Release;GPU</Configurations> | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | ||||
@@ -47,6 +48,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<AllowUnsafeBlocks>false</AllowUnsafeBlocks> | <AllowUnsafeBlocks>false</AllowUnsafeBlocks> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|AnyCPU'"> | |||||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||||
<AllowUnsafeBlocks>false</AllowUnsafeBlocks> | |||||
</PropertyGroup> | |||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | ||||
<AllowUnsafeBlocks>false</AllowUnsafeBlocks> | <AllowUnsafeBlocks>false</AllowUnsafeBlocks> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -55,6 +61,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile> | <DocumentationFile>Tensorflow.Keras.xml</DocumentationFile> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='GPU|x64'"> | |||||
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile> | |||||
</PropertyGroup> | |||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> | ||||
<DefineConstants /> | <DefineConstants /> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -134,7 +134,7 @@ namespace Tensorflow.Keras | |||||
/// <param name="data_format"></param> | /// <param name="data_format"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor max_pooling2d(Tensor inputs, | |||||
public Tensor MaxPooling2D(Tensor inputs, | |||||
int[] pool_size, | int[] pool_size, | ||||
int[] strides, | int[] strides, | ||||
string padding = "valid", | string padding = "valid", | ||||
@@ -0,0 +1,16 @@ | |||||
{ | |||||
// Use IntelliSense to learn about possible attributes. | |||||
// Hover to view descriptions of existing attributes. | |||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | |||||
"version": "0.2.0", | |||||
"configurations": [ | |||||
{ | |||||
"name": "Python: Current File", | |||||
"type": "python", | |||||
"request": "launch", | |||||
"program": "${file}", | |||||
"console": "integratedTerminal", | |||||
"justMyCode": false | |||||
} | |||||
] | |||||
} |
@@ -0,0 +1,15 @@ | |||||
import numpy as np | |||||
import tensorflow as tf | |||||
# tf.experimental.numpy | |||||
inputs = np.random.random([32, 10, 8]).astype(np.float32) | |||||
simple_rnn = tf.keras.layers.SimpleRNN(4) | |||||
output = simple_rnn(inputs) # The output has shape `[32, 4]`. | |||||
simple_rnn = tf.keras.layers.SimpleRNN( | |||||
4, return_sequences=True, return_state=True) | |||||
# whole_sequence_output has shape `[32, 10, 4]`. | |||||
# final_state has shape `[32, 4]`. | |||||
whole_sequence_output, final_state = simple_rnn(inputs) |
@@ -83,7 +83,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
{ 2.5f, 2.6f, 2.7f, 2.8f }, | { 2.5f, 2.6f, 2.7f, 2.8f }, | ||||
{ 3.5f, 3.6f, 3.7f, 3.8f } | { 3.5f, 3.6f, 3.7f, 3.8f } | ||||
} }, dtype: np.float32); | } }, dtype: np.float32); | ||||
var attention_layer = keras.layers.Attention(); | |||||
var attention_layer = (Attention)keras.layers.Attention(); | |||||
//attention_layer.build(((1, 2, 4), (1, 3, 4))); | //attention_layer.build(((1, 2, 4), (1, 3, 4))); | ||||
var actual = attention_layer._calculate_scores(query: q, key: k); | var actual = attention_layer._calculate_scores(query: q, key: k); | ||||
// Expected tensor of shape [1, 2, 3]. | // Expected tensor of shape [1, 2, 3]. | ||||
@@ -116,7 +116,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
{ 2.5f, 2.6f, 2.7f, 2.8f }, | { 2.5f, 2.6f, 2.7f, 2.8f }, | ||||
{ 3.5f, 3.6f, 3.7f, 3.8f } | { 3.5f, 3.6f, 3.7f, 3.8f } | ||||
} }, dtype: np.float32); | } }, dtype: np.float32); | ||||
var attention_layer = keras.layers.Attention(score_mode: "concat"); | |||||
var attention_layer = (Attention)keras.layers.Attention(score_mode: "concat"); | |||||
//attention_layer.concat_score_weight = 1; | //attention_layer.concat_score_weight = 1; | ||||
attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() { | attention_layer.concat_score_weight = base_layer_utils.make_variable(new VariableArgs() { | ||||
Name = "concat_score_weight", | Name = "concat_score_weight", | ||||
@@ -148,10 +148,9 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
[Ignore] | |||||
public void SimpleRNN() | public void SimpleRNN() | ||||
{ | { | ||||
var inputs = np.random.rand(32, 10, 8).astype(np.float32); | |||||
var inputs = np.random.random((32, 10, 8)).astype(np.float32); | |||||
var simple_rnn = keras.layers.SimpleRNN(4); | var simple_rnn = keras.layers.SimpleRNN(4); | ||||
var output = simple_rnn.Apply(inputs); | var output = simple_rnn.Apply(inputs); | ||||
Assert.AreEqual((32, 4), output.shape); | Assert.AreEqual((32, 4), output.shape); | ||||
@@ -4,7 +4,7 @@ | |||||
<TargetFramework>net6.0</TargetFramework> | <TargetFramework>net6.0</TargetFramework> | ||||
<IsPackable>false</IsPackable> | <IsPackable>false</IsPackable> | ||||
<LangVersion>11.0</LangVersion> | |||||
<Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -11,7 +11,7 @@ | |||||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | ||||
<LangVersion>9.0</LangVersion> | |||||
<LangVersion>11.0</LangVersion> | |||||
<Platforms>AnyCPU;x64</Platforms> | <Platforms>AnyCPU;x64</Platforms> | ||||
</PropertyGroup> | </PropertyGroup> | ||||