@@ -27,20 +27,6 @@ namespace Tensorflow.Keras.Engine | |||
Dictionary<int, int> tensor_usage_count; | |||
public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | |||
public override List<IVariableV1> trainable_variables | |||
{ | |||
get | |||
{ | |||
var variables = new List<IVariableV1>(); | |||
foreach(var layer in _layers) | |||
{ | |||
if (layer.Trainable) | |||
variables.AddRange(layer.trainable_variables); | |||
} | |||
return variables; | |||
} | |||
} | |||
public Functional(Tensors inputs, Tensors outputs, string name = null) | |||
: base(new ModelArgs | |||
{ | |||
@@ -12,136 +12,10 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
protected List<Layer> _layers = new List<Layer>(); | |||
public List<Layer> Layers => _layers; | |||
protected Layer Dense(int units, | |||
Activation activation = null, | |||
TensorShape input_shape = null) | |||
{ | |||
var layer = new Dense(new DenseArgs | |||
{ | |||
Units = units, | |||
Activation = activation ?? tf.keras.activations.Linear, | |||
InputShape = input_shape | |||
}); | |||
_layers.Add(layer); | |||
return layer; | |||
} | |||
protected Layer Conv2D(int filters, | |||
int kernel_size, | |||
TensorShape strides = null, | |||
string padding = "valid", | |||
string data_format = null, | |||
TensorShape dilation_rate = null, | |||
int groups = 1, | |||
Activation activation = null, | |||
bool use_bias = true, | |||
IInitializer kernel_initializer = null, | |||
IInitializer bias_initializer = null, | |||
bool trainable = true, | |||
string name = null) | |||
{ | |||
var layer = new Conv2D(new Conv2DArgs | |||
{ | |||
Filters = filters, | |||
KernelSize = kernel_size, | |||
Strides = strides ?? (1, 1), | |||
Padding = padding, | |||
DataFormat = data_format, | |||
DilationRate = dilation_rate ?? (1, 1), | |||
Groups = groups, | |||
Activation = activation, | |||
UseBias = use_bias, | |||
KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||
BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||
Trainable = trainable, | |||
Name = name | |||
}); | |||
_layers.Add(layer); | |||
return layer; | |||
} | |||
protected Layer MaxPooling2D(TensorShape pool_size, | |||
TensorShape strides, | |||
string padding = "valid", | |||
string data_format = null, | |||
string name = null) | |||
{ | |||
var layer = new MaxPooling2D(new MaxPooling2DArgs | |||
{ | |||
PoolSize = pool_size, | |||
Strides = strides, | |||
Padding = padding, | |||
DataFormat = data_format, | |||
Name = name | |||
}); | |||
_layers.Add(layer); | |||
return layer; | |||
} | |||
protected Layer Dropout(float rate, TensorShape noise_shape = null, int? seed = null) | |||
{ | |||
var layer = new Dropout(new DropoutArgs | |||
{ | |||
Rate = rate, | |||
NoiseShape = noise_shape, | |||
Seed = seed | |||
}); | |||
_layers.Add(layer); | |||
return layer; | |||
} | |||
protected Layer Flatten() | |||
protected void StackLayers(params Layer[] layers) | |||
{ | |||
var layer = new Flatten(new FlattenArgs()); | |||
_layers.Add(layer); | |||
return layer; | |||
} | |||
protected Layer LSTM(int units, | |||
Activation activation = null, | |||
Activation recurrent_activation = null, | |||
bool use_bias = true, | |||
IInitializer kernel_initializer = null, | |||
IInitializer recurrent_initializer = null, | |||
IInitializer bias_initializer = null, | |||
bool unit_forget_bias = true, | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f, | |||
int implementation = 2, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool time_major = false, | |||
bool unroll = false) | |||
{ | |||
var layer = new LSTM(new LSTMArgs | |||
{ | |||
Units = units, | |||
Activation = activation ?? tf.keras.activations.Tanh, | |||
RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, | |||
KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||
RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, | |||
BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||
Dropout = dropout, | |||
RecurrentDropout = recurrent_dropout, | |||
Implementation = implementation, | |||
ReturnSequences = return_sequences, | |||
ReturnState = return_state, | |||
GoBackwards = go_backwards, | |||
Stateful = stateful, | |||
TimeMajor = time_major, | |||
Unroll = unroll | |||
}); | |||
_layers.Add(layer); | |||
return layer; | |||
_layers.AddRange(layers); | |||
} | |||
} | |||
} |
@@ -0,0 +1,53 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Optimizers; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Model | |||
{ | |||
public void compile(string optimizerName, ILossFunc lossName) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||
{ | |||
this.optimizer = optimizer; | |||
var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||
var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||
int experimental_steps_per_execution = 1; | |||
_configure_steps_per_execution(experimental_steps_per_execution); | |||
// Initialize cache attrs. | |||
_reset_compile_cache(); | |||
_is_compiled = true; | |||
this.loss = loss; | |||
} | |||
public void compile(string optimizerName, string lossName) | |||
{ | |||
switch (optimizerName) | |||
{ | |||
case "rmsprop": | |||
optimizer = new RMSprop(new RMSpropArgs | |||
{ | |||
}); | |||
break; | |||
} | |||
int experimental_steps_per_execution = 1; | |||
_configure_steps_per_execution(experimental_steps_per_execution); | |||
_reset_compile_cache(); | |||
_is_compiled = true; | |||
} | |||
} | |||
} |
@@ -0,0 +1,56 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine.DataAdapters; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Model | |||
{ | |||
/// <summary> | |||
/// Trains the model for a fixed number of epochs (iterations on a dataset). | |||
/// </summary> | |||
/// <param name="x"></param> | |||
/// <param name="y"></param> | |||
/// <param name="batch_size"></param> | |||
/// <param name="epochs"></param> | |||
/// <param name="verbose"></param> | |||
/// <param name="validation_split"></param> | |||
/// <param name="shuffle"></param> | |||
public void fit(NDArray x, NDArray y, | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false) | |||
{ | |||
int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split)); | |||
var train_x = x[new Slice(0, train_count)]; | |||
var train_y = y[new Slice(0, train_count)]; | |||
var val_x = x[new Slice(train_count)]; | |||
var val_y = y[new Slice(train_count)]; | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = train_x, | |||
Y = train_y, | |||
BatchSize = batch_size, | |||
InitialEpoch = initial_epoch, | |||
Epochs = epochs, | |||
Shuffle = shuffle, | |||
MaxQueueSize = max_queue_size, | |||
Workers = workers, | |||
UseMultiprocessing = use_multiprocessing, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
} | |||
} | |||
} |
@@ -0,0 +1,51 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine.DataAdapters; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Model | |||
{ | |||
/// <summary> | |||
/// Generates output predictions for the input samples. | |||
/// </summary> | |||
/// <param name="x">Input samples</param> | |||
/// <param name="batch_size">Number of samples per batch</param> | |||
/// <param name="verbose">Verbosity mode</param> | |||
/// <param name="steps"> | |||
/// Total number of steps (batches of samples) | |||
/// before declaring the prediction round finished. | |||
/// </param> | |||
/// <param name="max_queue_size"></param> | |||
/// <param name="workers"></param> | |||
/// <param name="use_multiprocessing"></param> | |||
/// <returns></returns> | |||
public Tensor predict(Tensor x, | |||
int batch_size = 32, | |||
int verbose = 0, | |||
int steps = -1, | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false) | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = x, | |||
BatchSize = batch_size, | |||
StepsPerEpoch = steps, | |||
InitialEpoch = 0, | |||
Epochs = 1, | |||
MaxQueueSize = max_queue_size, | |||
Workers = workers, | |||
UseMultiprocessing = use_multiprocessing, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} |
@@ -5,6 +5,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Optimizers; | |||
using NumSharp; | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -39,84 +40,6 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||
{ | |||
this.optimizer = optimizer; | |||
var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||
var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||
int experimental_steps_per_execution = 1; | |||
_configure_steps_per_execution(experimental_steps_per_execution); | |||
// Initialize cache attrs. | |||
_reset_compile_cache(); | |||
_is_compiled = true; | |||
this.loss = loss; | |||
} | |||
public void compile(string optimizerName, string lossName) | |||
{ | |||
switch (optimizerName) | |||
{ | |||
case "rmsprop": | |||
optimizer = new RMSprop(new RMSpropArgs | |||
{ | |||
}); | |||
break; | |||
} | |||
int experimental_steps_per_execution = 1; | |||
_configure_steps_per_execution(experimental_steps_per_execution); | |||
_reset_compile_cache(); | |||
_is_compiled = true; | |||
} | |||
/// <summary> | |||
/// Trains the model for a fixed number of epochs (iterations on a dataset). | |||
/// </summary> | |||
/// <param name="x"></param> | |||
/// <param name="y"></param> | |||
/// <param name="batch_size"></param> | |||
/// <param name="epochs"></param> | |||
/// <param name="verbose"></param> | |||
/// <param name="validation_split"></param> | |||
/// <param name="shuffle"></param> | |||
public void fit(NDArray x, NDArray y, | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false) | |||
{ | |||
int train_count = Convert.ToInt32(x.shape[0] * (1 - validation_split)); | |||
var train_x = x[new Slice(0, train_count)]; | |||
var train_y = y[new Slice(0, train_count)]; | |||
var val_x = x[new Slice(train_count)]; | |||
var val_y = y[new Slice(train_count)]; | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = train_x, | |||
Y = train_y, | |||
BatchSize = batch_size, | |||
InitialEpoch = initial_epoch, | |||
Epochs = epochs, | |||
Shuffle = shuffle, | |||
MaxQueueSize = max_queue_size, | |||
Workers = workers, | |||
UseMultiprocessing = use_multiprocessing, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
} | |||
void _configure_steps_per_execution(int steps_per_execution) | |||
{ | |||
_steps_per_execution = tf.Variable(steps_per_execution, | |||
@@ -145,48 +68,18 @@ namespace Tensorflow.Keras.Engine | |||
aggregation: VariableAggregation.OnlyFirstReplica); | |||
} | |||
public void compile(string optimizerName, ILossFunc lossName) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
/// <summary> | |||
/// Generates output predictions for the input samples. | |||
/// </summary> | |||
/// <param name="x">Input samples</param> | |||
/// <param name="batch_size">Number of samples per batch</param> | |||
/// <param name="verbose">Verbosity mode</param> | |||
/// <param name="steps"> | |||
/// Total number of steps (batches of samples) | |||
/// before declaring the prediction round finished. | |||
/// </param> | |||
/// <param name="max_queue_size"></param> | |||
/// <param name="workers"></param> | |||
/// <param name="use_multiprocessing"></param> | |||
/// <returns></returns> | |||
public Tensor predict(Tensor x, | |||
int batch_size = 32, | |||
int verbose = 0, | |||
int steps = -1, | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false) | |||
public override List<IVariableV1> trainable_variables | |||
{ | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
get | |||
{ | |||
X = x, | |||
BatchSize = batch_size, | |||
StepsPerEpoch = steps, | |||
InitialEpoch = 0, | |||
Epochs = 1, | |||
MaxQueueSize = max_queue_size, | |||
Workers = workers, | |||
UseMultiprocessing = use_multiprocessing, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
throw new NotImplementedException(""); | |||
var variables = new List<IVariableV1>(); | |||
foreach (var layer in _layers) | |||
{ | |||
if (layer.Trainable) | |||
variables.AddRange(layer.trainable_variables); | |||
} | |||
return variables; | |||
} | |||
} | |||
} | |||
} |
@@ -10,6 +10,24 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
public class LayersApi | |||
{ | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="filters"></param> | |||
/// <param name="kernel_size"></param> | |||
/// <param name="strides"></param> | |||
/// <param name="padding"></param> | |||
/// <param name="data_format"></param> | |||
/// <param name="dilation_rate"></param> | |||
/// <param name="groups"></param> | |||
/// <param name="activation">tf.keras.activations</param> | |||
/// <param name="use_bias"></param> | |||
/// <param name="kernel_initializer"></param> | |||
/// <param name="bias_initializer"></param> | |||
/// <param name="kernel_regularizer"></param> | |||
/// <param name="bias_regularizer"></param> | |||
/// <param name="activity_regularizer"></param> | |||
/// <returns></returns> | |||
public Conv2D Conv2D(int filters, | |||
TensorShape kernel_size = null, | |||
TensorShape strides = null, | |||
@@ -17,7 +35,7 @@ namespace Tensorflow.Keras.Layers | |||
string data_format = null, | |||
TensorShape dilation_rate = null, | |||
int groups = 1, | |||
string activation = null, | |||
Activation activation = null, | |||
bool use_bias = true, | |||
IInitializer kernel_initializer = null, | |||
IInitializer bias_initializer = null, | |||
@@ -40,20 +58,27 @@ namespace Tensorflow.Keras.Layers | |||
BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, | |||
BiasRegularizer = bias_regularizer, | |||
ActivityRegularizer = activity_regularizer, | |||
Activation = GetActivationByName(activation) | |||
Activation = activation ?? tf.keras.activations.Linear | |||
}); | |||
public Dense Dense(int units, | |||
string activation = "linear", | |||
Activation activation = null, | |||
TensorShape input_shape = null) | |||
=> new Dense(new DenseArgs | |||
{ | |||
Units = units, | |||
Activation = GetActivationByName(activation), | |||
Activation = activation ?? tf.keras.activations.Linear, | |||
InputShape = input_shape | |||
}); | |||
public Dropout Dropout(float rate, TensorShape noise_shape = null, int? seed = null) | |||
=> new Dropout(new DropoutArgs | |||
{ | |||
Rate = rate, | |||
NoiseShape = noise_shape, | |||
Seed = seed | |||
}); | |||
/// <summary> | |||
/// Turns positive integers (indexes) into dense vectors of fixed size. | |||
/// This layer can only be used as the first layer in a model. | |||
@@ -121,6 +146,42 @@ namespace Tensorflow.Keras.Layers | |||
Padding = padding | |||
}); | |||
public Layer LSTM(int units, | |||
Activation activation = null, | |||
Activation recurrent_activation = null, | |||
bool use_bias = true, | |||
IInitializer kernel_initializer = null, | |||
IInitializer recurrent_initializer = null, | |||
IInitializer bias_initializer = null, | |||
bool unit_forget_bias = true, | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f, | |||
int implementation = 2, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool time_major = false, | |||
bool unroll = false) | |||
=> new LSTM(new LSTMArgs | |||
{ | |||
Units = units, | |||
Activation = activation ?? tf.keras.activations.Tanh, | |||
RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, | |||
KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||
RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, | |||
BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||
Dropout = dropout, | |||
RecurrentDropout = recurrent_dropout, | |||
Implementation = implementation, | |||
ReturnSequences = return_sequences, | |||
ReturnState = return_state, | |||
GoBackwards = go_backwards, | |||
Stateful = stateful, | |||
TimeMajor = time_major, | |||
Unroll = unroll | |||
}); | |||
public Rescaling Rescaling(float scale, | |||
float offset = 0, | |||
TensorShape input_shape = null) | |||