@@ -27,7 +27,8 @@ namespace Tensorflow | |||
public IInitializer zeros_initializer => new Zeros(); | |||
public IInitializer ones_initializer => new Ones(); | |||
public IInitializer glorot_uniform_initializer => new GlorotUniform(); | |||
public IInitializer uniform_initializer => new RandomUniform(); | |||
public IInitializer random_uniform_initializer => new RandomUniform(); | |||
public IInitializer orthogonal_initializer => new Orthogonal(); | |||
public variable_scope variable_scope(string name, | |||
string default_name = null, | |||
@@ -20,7 +20,9 @@ namespace Tensorflow.Keras | |||
return results[0]; | |||
} | |||
throw new NotImplementedException(""); | |||
var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features }); | |||
return _op.output; | |||
}; | |||
} | |||
} |
@@ -0,0 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras | |||
{ | |||
public partial class Activations | |||
{ | |||
public Activation Sigmoid = (features, name) => | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Sigmoid", name, | |||
null, | |||
features); | |||
return results[0]; | |||
} | |||
throw new NotImplementedException(""); | |||
}; | |||
} | |||
} |
@@ -0,0 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras | |||
{ | |||
public partial class Activations | |||
{ | |||
public Activation Tanh = (features, name) => | |||
{ | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Tanh", name, | |||
null, | |||
features); | |||
return results[0]; | |||
} | |||
throw new NotImplementedException(""); | |||
}; | |||
} | |||
} |
@@ -0,0 +1,15 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class EmbeddingArgs : LayerArgs | |||
{ | |||
public int InputDim { get; set; } | |||
public int OutputDim { get; set; } | |||
public bool MaskZero { get; set; } | |||
public int InputLength { get; set; } = -1; | |||
public IInitializer EmbeddingsInitializer { get; set; } | |||
} | |||
} |
@@ -0,0 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class LSTMArgs : RNNArgs | |||
{ | |||
public int Units { get; set; } | |||
public Activation Activation { get; set; } | |||
public Activation RecurrentActivation { get; set; } | |||
public IInitializer KernelInitializer { get; set; } | |||
public IInitializer RecurrentInitializer { get; set; } | |||
public IInitializer BiasInitializer { get; set; } | |||
public bool UnitForgetBias { get; set; } | |||
public float Dropout { get; set; } | |||
public float RecurrentDropout { get; set; } | |||
public int Implementation { get; set; } | |||
public bool ReturnSequences { get; set; } | |||
public bool ReturnState { get; set; } | |||
public bool GoBackwards { get; set; } | |||
public bool Stateful { get; set; } | |||
public bool TimeMajor { get; set; } | |||
public bool Unroll { get; set; } | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class LSTMCellArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class RNNArgs : LayerArgs | |||
{ | |||
} | |||
} |
@@ -26,17 +26,22 @@ namespace Tensorflow.Keras.Engine | |||
public int? ndim; | |||
public int? min_ndim; | |||
Dictionary<int, int> axes; | |||
TensorShape shape; | |||
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | |||
int? ndim = null, | |||
int? min_ndim = null, | |||
Dictionary<int, int> axes = null) | |||
Dictionary<int, int> axes = null, | |||
TensorShape shape = null) | |||
{ | |||
this.ndim = ndim; | |||
if (axes == null) | |||
axes = new Dictionary<int, int>(); | |||
this.axes = axes; | |||
this.min_ndim = min_ndim; | |||
this.shape = shape; | |||
if (ndim == null && shape != null) | |||
this.ndim = shape.ndim; | |||
} | |||
} | |||
} |
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Operations.Activation; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
@@ -100,5 +101,46 @@ namespace Tensorflow.Keras.Engine | |||
_layers.Add(layer); | |||
return layer; | |||
} | |||
protected Layer LSTM(int units, | |||
Activation activation = null, | |||
Activation recurrent_activation = null, | |||
bool use_bias = true, | |||
IInitializer kernel_initializer = null, | |||
IInitializer recurrent_initializer = null, | |||
IInitializer bias_initializer = null, | |||
bool unit_forget_bias = true, | |||
float dropout = 0f, | |||
float recurrent_dropout = 0f, | |||
int implementation = 2, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool time_major = false, | |||
bool unroll = false) | |||
{ | |||
var layer = new LSTM(new LSTMArgs | |||
{ | |||
Units = units, | |||
Activation = activation ?? tf.keras.activations.Tanh, | |||
RecurrentActivation = recurrent_activation ?? tf.keras.activations.Sigmoid, | |||
KernelInitializer = kernel_initializer ?? tf.glorot_uniform_initializer, | |||
RecurrentInitializer = recurrent_initializer ?? tf.orthogonal_initializer, | |||
BiasInitializer = bias_initializer ?? tf.zeros_initializer, | |||
Dropout = dropout, | |||
RecurrentDropout = recurrent_dropout, | |||
Implementation = implementation, | |||
ReturnSequences = return_sequences, | |||
ReturnState = return_state, | |||
GoBackwards = go_backwards, | |||
Stateful = stateful, | |||
TimeMajor = time_major, | |||
Unroll = unroll | |||
}); | |||
_layers.Add(layer); | |||
return layer; | |||
} | |||
} | |||
} |
@@ -144,7 +144,9 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
// using var graph = tf.keras.backend.get_graph().as_default(); | |||
if (!inputs.IsEagerTensor) | |||
tf.Context.graph_mode(); | |||
tf_with(ops.name_scope(nameScope), scope => | |||
{ | |||
if (!built) | |||
@@ -157,6 +159,8 @@ namespace Tensorflow.Keras.Engine | |||
_set_mask_metadata(inputs, outputs, null); | |||
}); | |||
tf.Context.eager_mode(); | |||
return outputs; | |||
} | |||
@@ -17,7 +17,7 @@ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Layers; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -56,6 +56,8 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
// Set metadata on outputs. | |||
var node_index = layer.InboundNodes.Count - 1; | |||
args.Outputs.KerasHistory.Add(layer); | |||
} | |||
} | |||
} |
@@ -14,17 +14,22 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Layers; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class Sequential : Model, ITensorFlowObject | |||
/// <summary> | |||
/// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. | |||
/// `Sequential` provides training and inference features on this model. | |||
/// </summary> | |||
public class Sequential | |||
{ | |||
#pragma warning disable CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false | |||
bool _is_graph_network; | |||
#pragma warning restore CS0649 // Field 'Sequential._is_graph_network' is never assigned to, and will always have its default value false | |||
Tensor inputs; | |||
Tensor outputs; | |||
bool computeOutputAndMaskJointly; | |||
@@ -32,26 +37,24 @@ namespace Tensorflow.Keras.Engine | |||
TensorShape inferredInputShape; | |||
bool hasExplicitInputShape; | |||
TF_DataType inputDType; | |||
Layer[] layers; | |||
List<Layer> layers; | |||
public TensorShape output_shape => outputs.TensorShape; | |||
bool built = false; | |||
public Sequential(Layer[] layers = null, string name = null) | |||
: base(new ModelArgs { Name = name}) | |||
{ | |||
this.layers = layers ?? new Layer[0]; | |||
SupportsMasking = true; | |||
this.layers = layers == null ? new List<Layer>() : layers.ToList(); | |||
// SupportsMasking = true; | |||
computeOutputAndMaskJointly = true; | |||
autoTrackSubLayers = false; | |||
hasExplicitInputShape = false; | |||
_is_graph_network = false; | |||
} | |||
public void __enter__() | |||
public void add(Tensor tensor) | |||
{ | |||
} | |||
public void add(Tensor layer) | |||
{ | |||
var layer = tensor.KerasHistory[0]; | |||
add(layer); | |||
} | |||
/// <summary> | |||
@@ -62,9 +65,9 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
built = false; | |||
var set_inputs = false; | |||
if(layers.Length == 0) | |||
if (layers.Count == 0) | |||
{ | |||
if(layer is InputLayer) | |||
if (layer is InputLayer) | |||
{ | |||
set_inputs = true; | |||
} | |||
@@ -93,31 +96,33 @@ namespace Tensorflow.Keras.Engine | |||
} | |||
} | |||
else if (outputs != null) | |||
{ | |||
outputs = layer.Apply(outputs); | |||
} | |||
if (set_inputs || _is_graph_network) | |||
{ | |||
_init_graph_network(inputs, outputs); | |||
} | |||
} | |||
public void __exit__() | |||
{ | |||
} | |||
public void Dispose() | |||
{ | |||
else | |||
{ | |||
} | |||
} | |||
public void __init__() | |||
void _init_graph_network(Tensor inputs, Tensor outputs) | |||
{ | |||
_is_graph_network = true; | |||
this.inputs = inputs; | |||
this.outputs = outputs; | |||
built = true; | |||
_map_graph_network(inputs, outputs); | |||
} | |||
public void __del__() | |||
void _map_graph_network(Tensor inputs, Tensor outputs) | |||
{ | |||
layers.add(outputs.KerasHistory[0]); | |||
} | |||
} | |||
} |
@@ -63,18 +63,9 @@ namespace Tensorflow | |||
var layer = new InputLayer(args); | |||
return layer.InboundNodes[0].Outputs[0]; | |||
return layer.InboundNodes[0].Outputs; | |||
} | |||
public static Embedding Embedding(int input_dim, | |||
int output_dim, | |||
IInitializer embeddings_initializer = null, | |||
bool mask_zero = false) | |||
=> new Embedding(input_dim, | |||
output_dim, | |||
embeddings_initializer, | |||
mask_zero); | |||
public class LayersApi | |||
{ | |||
public Layer Dense(int units, | |||
@@ -86,6 +77,30 @@ namespace Tensorflow | |||
Activation = activation ?? tf.keras.activations.Linear, | |||
InputShape = input_shape | |||
}); | |||
/// <summary> | |||
/// Turns positive integers (indexes) into dense vectors of fixed size. | |||
/// </summary> | |||
/// <param name="input_dim"></param> | |||
/// <param name="output_dim"></param> | |||
/// <param name="embeddings_initializer"></param> | |||
/// <param name="mask_zero"></param> | |||
/// <returns></returns> | |||
public Embedding Embedding(int input_dim, | |||
int output_dim, | |||
IInitializer embeddings_initializer = null, | |||
bool mask_zero = false, | |||
TensorShape input_shape = null, | |||
int input_length = -1) | |||
=> new Embedding(new EmbeddingArgs | |||
{ | |||
InputDim = input_dim, | |||
OutputDim = output_dim, | |||
MaskZero = mask_zero, | |||
InputShape = input_shape ?? input_length, | |||
InputLength = input_length, | |||
EmbeddingsInitializer = embeddings_initializer | |||
}); | |||
} | |||
} | |||
} |
@@ -20,40 +20,37 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
/// <summary> | |||
/// Turns positive integers (indexes) into dense vectors of fixed size. | |||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding | |||
/// </summary> | |||
public class Embedding : Layer | |||
{ | |||
private int input_dim; | |||
private int output_dim; | |||
private bool mask_zero; | |||
public IVariableV1 embeddings; | |||
public IInitializer embeddings_initializer; | |||
int input_length; | |||
EmbeddingArgs args; | |||
int input_dim => args.InputDim; | |||
int output_dim => args.OutputDim; | |||
bool mask_zero => args.MaskZero; | |||
IVariableV1 embeddings; | |||
IInitializer embeddings_initializer; | |||
public Embedding(int input_dim, int output_dim, | |||
IInitializer embeddings_initializer = null, | |||
bool mask_zero = false, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
int[] input_shape = null, | |||
int input_length = -1) : | |||
base(new LayerArgs | |||
{ | |||
DType = dtype, | |||
InputShape = input_shape ?? new[] { input_length } | |||
}) | |||
public Embedding(EmbeddingArgs args) | |||
: base(args) | |||
{ | |||
this.input_dim = input_dim; | |||
this.output_dim = output_dim; | |||
this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer; | |||
this.mask_zero = mask_zero; | |||
this.args = args; | |||
if(args.InputShape == null) | |||
args.InputShape = args.InputLength; | |||
embeddings_initializer = embeddings_initializer ?? tf.random_uniform_initializer; | |||
SupportsMasking = mask_zero; | |||
this.input_length = input_length; | |||
} | |||
protected override void build(TensorShape input_shape) | |||
{ | |||
embeddings = add_weight(shape: new int[] { input_dim, output_dim }, | |||
tf.Context.eager_mode(); | |||
embeddings = add_weight(shape: (input_dim, output_dim), | |||
initializer: embeddings_initializer, | |||
name: "embeddings"); | |||
tf.Context.graph_mode(); | |||
built = true; | |||
} | |||
@@ -63,7 +60,7 @@ namespace Tensorflow.Keras.Layers | |||
if (dtype != tf.int32 && dtype != tf.int64) | |||
inputs = math_ops.cast(inputs, tf.int32); | |||
var outputs = embedding_ops.embedding_lookup(embeddings, inputs[0]); | |||
var outputs = embedding_ops.embedding_lookup(embeddings, inputs); | |||
return outputs; | |||
} | |||
} | |||
@@ -53,6 +53,11 @@ namespace Tensorflow.Keras.Layers | |||
args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix); | |||
} | |||
if(args.DType == TF_DataType.DtInvalid) | |||
{ | |||
args.DType = args.InputTensor == null ? tf.float32 : args.InputTensor.dtype; | |||
} | |||
if (args.InputTensor == null) | |||
{ | |||
if(args.InputShape != null) | |||
@@ -72,7 +77,8 @@ namespace Tensorflow.Keras.Layers | |||
shape: BatchInputShape, | |||
dtype: DType, | |||
name: Name, | |||
sparse: args.Sparse); | |||
sparse: args.Sparse, | |||
ragged: args.Ragged); | |||
tf.Context.eager_mode(); | |||
isPlaceholder = true; | |||
@@ -0,0 +1,37 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
/// <summary> | |||
/// Long Short-Term Memory layer - Hochreiter 1997. | |||
/// | |||
/// See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn) | |||
/// for details about the usage of RNN API. | |||
/// </summary> | |||
public class LSTM : RNN | |||
{ | |||
LSTMArgs args; | |||
InputSpec[] state_spec; | |||
int units => args.Units; | |||
public LSTM(LSTMArgs args) : | |||
base(args) | |||
{ | |||
this.args = args; | |||
state_spec = new[] { units, units } | |||
.Select(dim => new InputSpec(shape: (-1, dim))) | |||
.ToArray(); | |||
} | |||
protected override Tensor call(Tensor inputs, bool is_training = false, Tensor state = null) | |||
{ | |||
return base.call(inputs, is_training, state); | |||
} | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Runtime.CompilerServices; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class LSTMCell : Layer | |||
{ | |||
LSTMCellArgs args; | |||
public LSTMCell(LSTMCellArgs args) | |||
: base(args) | |||
{ | |||
this.args = args; | |||
} | |||
} | |||
} |
@@ -0,0 +1,27 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class RNN : Layer | |||
{ | |||
public RNN(RNNArgs args) | |||
: base(args) | |||
{ | |||
} | |||
protected Tensor get_initial_state(Tensor inputs) | |||
{ | |||
return _generate_zero_filled_state_for_cell(null, null); | |||
} | |||
Tensor _generate_zero_filled_state_for_cell(LSTMCell cell, Tensor batch_size) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations.Initializers | |||
{ | |||
public class Orthogonal : IInitializer | |||
{ | |||
public Tensor Apply(InitializerArgs args) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
} | |||
} |
@@ -29,7 +29,7 @@ namespace Tensorflow.Operations.Initializers | |||
#pragma warning restore CS0649 // Field 'RandomUniform.maxval' is never assigned to, and will always have its default value 0 | |||
private TF_DataType dtype; | |||
public RandomUniform(TF_DataType dtype = TF_DataType.DtInvalid) | |||
public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
this.dtype = dtype; | |||
} | |||
@@ -37,7 +37,7 @@ namespace Tensorflow.Operations.Initializers | |||
public Tensor Apply(InitializerArgs args) | |||
{ | |||
if (args.DType == TF_DataType.DtInvalid) | |||
args.DType = this.dtype; | |||
args.DType = dtype; | |||
return random_ops.random_uniform(args.Shape, | |||
minval: minval, | |||
@@ -193,7 +193,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
public static Tensor identity(Tensor input, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Identity", name, | |||
@@ -140,7 +140,7 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ReadVariableOp", name, | |||
@@ -5,7 +5,7 @@ | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | |||
<Version>0.20.0-preview3</Version> | |||
<Version>0.20.0-preview4</Version> | |||
<LangVersion>8.0</LangVersion> | |||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
<Company>SciSharp STACK</Company> | |||
@@ -25,6 +25,7 @@ using System.Text; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow | |||
{ | |||
@@ -97,6 +98,8 @@ namespace Tensorflow | |||
/// </summary> | |||
public SafeTensorHandleHandle EagerTensorHandle { get; set; } | |||
public bool IsEagerTensor => this is EagerTensor; | |||
/// <summary> | |||
/// Returns the shape of a tensor. | |||
/// </summary> | |||
@@ -138,6 +141,11 @@ namespace Tensorflow | |||
public TensorShape TensorShape => rank < 0 ? new TensorShape() : tensor_util.to_shape(shape); | |||
/// <summary> | |||
/// Keras History: (Layer, (node_index, tensor_index)) | |||
/// </summary> | |||
public List<Layer> KerasHistory = new List<Layer>(); | |||
/// <summary> | |||
/// Updates the shape of this tensor. | |||
/// </summary> | |||
@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest.Keras | |||
/// <summary> | |||
/// https://www.tensorflow.org/versions/r2.3/api_docs/python/tf/keras/layers | |||
/// </summary> | |||
[TestClass, Ignore] | |||
[TestClass] | |||
public class LayersTest : GraphModeTestBase | |||
{ | |||
[TestMethod] | |||
@@ -23,11 +23,15 @@ namespace TensorFlowNET.UnitTest.Keras | |||
model.add(tf.keras.Input(shape: 16)); | |||
} | |||
[TestMethod] | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding | |||
/// </summary> | |||
[TestMethod, Ignore] | |||
public void Embedding() | |||
{ | |||
var model = new Sequential(); | |||
model.add(new Embedding(1000, 64, input_length: 10)); | |||
var layer = tf.keras.layers.Embedding(1000, 64, input_length: 10); | |||
model.add(layer); | |||
// the model will take as input an integer matrix of size (batch, | |||
// input_length). | |||
// the largest integer (i.e. word index) in the input should be no larger | |||
@@ -35,15 +39,32 @@ namespace TensorFlowNET.UnitTest.Keras | |||
// now model.output_shape == (None, 10, 64), where None is the batch | |||
// dimension. | |||
var input_array = np.random.randint(1000, size: (32, 10)); | |||
model.compile("rmsprop", "mse"); | |||
// model.compile("rmsprop", "mse"); | |||
// output_array = model.predict(input_array) | |||
} | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense | |||
/// </summary> | |||
[TestMethod] | |||
public void Dense() | |||
{ | |||
// Create a `Sequential` model and add a Dense layer as the first layer. | |||
var model = tf.keras.Sequential(); | |||
var dense_layer = tf.keras.layers.Dense(5, input_shape: 3); | |||
model.add(dense_layer); | |||
model.add(tf.keras.Input(shape: 16)); | |||
model.add(tf.keras.layers.Dense(32, activation: tf.keras.activations.Relu)); | |||
// Now the model will take as input arrays of shape (None, 16) | |||
// and output arrays of shape (None, 32). | |||
// Note that after the first layer, you don't need to specify | |||
// the size of the input anymore: | |||
model.add(tf.keras.layers.Dense(32)); | |||
Assert.AreEqual((-1, 32), model.output_shape); | |||
} | |||
[TestMethod] | |||
public void SimpleRNN() | |||
{ | |||
} | |||
} | |||
} |