@@ -1,6 +1,21 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class RNNArgs : LayerArgs | |||
{ | |||
public interface IRnnArgCell : ILayer | |||
{ | |||
object state_size { get; } | |||
} | |||
public IRnnArgCell Cell { get; set; } = null; | |||
public bool ReturnSequences { get; set; } = false; | |||
public bool ReturnState { get; set; } = false; | |||
public bool GoBackwards { get; set; } = false; | |||
public bool Stateful { get; set; } = false; | |||
public bool Unroll { get; set; } = false; | |||
public bool TimeMajor { get; set; } = false; | |||
public Dictionary<string, object> Kwargs { get; set; } = null; | |||
} | |||
} |
@@ -0,0 +1,30 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class SimpleRNNArgs : RNNArgs | |||
{ | |||
public int Units { get; set; } | |||
public Activation Activation { get; set; } | |||
// units, | |||
// activation='tanh', | |||
// use_bias=True, | |||
// kernel_initializer='glorot_uniform', | |||
// recurrent_initializer='orthogonal', | |||
// bias_initializer='zeros', | |||
// kernel_regularizer=None, | |||
// recurrent_regularizer=None, | |||
// bias_regularizer=None, | |||
// activity_regularizer=None, | |||
// kernel_constraint=None, | |||
// recurrent_constraint=None, | |||
// bias_constraint=None, | |||
// dropout=0., | |||
// recurrent_dropout=0., | |||
// return_sequences=False, | |||
// return_state=False, | |||
// go_backwards=False, | |||
// stateful=False, | |||
// unroll=False, | |||
// **kwargs): | |||
} | |||
} |
@@ -0,0 +1,9 @@ | |||
using System.Collections.Generic; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class StackedRNNCellsArgs : LayerArgs | |||
{ | |||
public IList<RnnCell> Cells { get; set; } | |||
} | |||
} |
@@ -46,7 +46,7 @@ namespace Tensorflow | |||
/// matching structure of Tensors having shape `[batch_size].concatenate(s)` | |||
/// for each `s` in `self.batch_size`. | |||
/// </summary> | |||
public abstract class RnnCell : ILayer | |||
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||
{ | |||
/// <summary> | |||
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight | |||
@@ -1,4 +1,5 @@ | |||
using NumSharp; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
@@ -327,6 +328,24 @@ namespace Tensorflow.Keras.Layers | |||
Alpha = alpha | |||
}); | |||
public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh"); | |||
public Layer SimpleRNN(int units, | |||
Activation activation = null) | |||
=> new SimpleRNN(new SimpleRNNArgs | |||
{ | |||
Units = units, | |||
Activation = activation | |||
}); | |||
public Layer SimpleRNN(int units, | |||
string activation = "tanh") | |||
=> new SimpleRNN(new SimpleRNNArgs | |||
{ | |||
Units = units, | |||
Activation = GetActivationByName(activation) | |||
}); | |||
public Layer LSTM(int units, | |||
Activation activation = null, | |||
Activation recurrent_activation = null, | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
@@ -6,12 +7,93 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
public class RNN : Layer | |||
{ | |||
public RNN(RNNArgs args) | |||
: base(args) | |||
private RNNArgs args; | |||
public RNN(RNNArgs args) : base(PreConstruct(args)) | |||
{ | |||
this.args = args; | |||
SupportsMasking = true; | |||
// The input shape is unknown yet, it could have nested tensor inputs, and | |||
// the input spec will be the list of specs for nested inputs, the structure | |||
// of the input_spec will be the same as the input. | |||
//self.input_spec = None | |||
//self.state_spec = None | |||
//self._states = None | |||
//self.constants_spec = None | |||
//self._num_constants = 0 | |||
//if stateful: | |||
// if ds_context.has_strategy(): | |||
// raise ValueError('RNNs with stateful=True not yet supported with ' | |||
// 'tf.distribute.Strategy.') | |||
} | |||
private static RNNArgs PreConstruct(RNNArgs args) | |||
{ | |||
if (args.Kwargs == null) | |||
{ | |||
args.Kwargs = new Dictionary<string, object>(); | |||
} | |||
// If true, the output for masked timestep will be zeros, whereas in the | |||
// false case, output from previous timestep is returned for masked timestep. | |||
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false); | |||
object input_shape; | |||
var propIS = args.Kwargs.Get("input_shape", null); | |||
var propID = args.Kwargs.Get("input_dim", null); | |||
var propIL = args.Kwargs.Get("input_length", null); | |||
if (propIS == null && (propID != null || propIL != null)) | |||
{ | |||
input_shape = ( | |||
propIL ?? new NoneValue(), // maybe null is needed here | |||
propID ?? new NoneValue()); // and here | |||
args.Kwargs["input_shape"] = input_shape; | |||
} | |||
return args; | |||
} | |||
public RNN New(LayerRnnCell cell, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool unroll = false, | |||
bool time_major = false) | |||
=> new RNN(new RNNArgs | |||
{ | |||
Cell = cell, | |||
ReturnSequences = return_sequences, | |||
ReturnState = return_state, | |||
GoBackwards = go_backwards, | |||
Stateful = stateful, | |||
Unroll = unroll, | |||
TimeMajor = time_major | |||
}); | |||
public RNN New(IList<RnnCell> cell, | |||
bool return_sequences = false, | |||
bool return_state = false, | |||
bool go_backwards = false, | |||
bool stateful = false, | |||
bool unroll = false, | |||
bool time_major = false) | |||
=> new RNN(new RNNArgs | |||
{ | |||
Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }), | |||
ReturnSequences = return_sequences, | |||
ReturnState = return_state, | |||
GoBackwards = go_backwards, | |||
Stateful = stateful, | |||
Unroll = unroll, | |||
TimeMajor = time_major | |||
}); | |||
protected Tensor get_initial_state(Tensor inputs) | |||
{ | |||
return _generate_zero_filled_state_for_cell(null, null); | |||
@@ -0,0 +1,14 @@ | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class SimpleRNN : RNN | |||
{ | |||
public SimpleRNN(RNNArgs args) : base(args) | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,125 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | |||
{ | |||
public IList<RnnCell> Cells { get; set; } | |||
public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | |||
{ | |||
Cells = args.Cells; | |||
//Cells.reverse_state_order = kwargs.pop('reverse_state_order', False); | |||
// self.reverse_state_order = kwargs.pop('reverse_state_order', False) | |||
// if self.reverse_state_order: | |||
// logging.warning('reverse_state_order=True in StackedRNNCells will soon ' | |||
// 'be deprecated. Please update the code to work with the ' | |||
// 'natural order of states if you rely on the RNN states, ' | |||
// 'eg RNN(return_state=True).') | |||
// super(StackedRNNCells, self).__init__(**kwargs) | |||
throw new NotImplementedException(""); | |||
} | |||
public object state_size | |||
{ | |||
get => throw new NotImplementedException(); | |||
} | |||
//@property | |||
//def state_size(self) : | |||
// return tuple(c.state_size for c in | |||
// (self.cells[::- 1] if self.reverse_state_order else self.cells)) | |||
// @property | |||
// def output_size(self) : | |||
// if getattr(self.cells[-1], 'output_size', None) is not None: | |||
// return self.cells[-1].output_size | |||
// elif _is_multiple_state(self.cells[-1].state_size) : | |||
// return self.cells[-1].state_size[0] | |||
// else: | |||
// return self.cells[-1].state_size | |||
// def get_initial_state(self, inputs= None, batch_size= None, dtype= None) : | |||
// initial_states = [] | |||
// for cell in self.cells[::- 1] if self.reverse_state_order else self.cells: | |||
// get_initial_state_fn = getattr(cell, 'get_initial_state', None) | |||
// if get_initial_state_fn: | |||
// initial_states.append(get_initial_state_fn( | |||
// inputs=inputs, batch_size=batch_size, dtype=dtype)) | |||
// else: | |||
// initial_states.append(_generate_zero_filled_state_for_cell( | |||
// cell, inputs, batch_size, dtype)) | |||
// return tuple(initial_states) | |||
// def call(self, inputs, states, constants= None, training= None, ** kwargs): | |||
// # Recover per-cell states. | |||
// state_size = (self.state_size[::- 1] | |||
// if self.reverse_state_order else self.state_size) | |||
// nested_states = nest.pack_sequence_as(state_size, nest.flatten(states)) | |||
// # Call the cells in order and store the returned states. | |||
// new_nested_states = [] | |||
// for cell, states in zip(self.cells, nested_states) : | |||
// states = states if nest.is_nested(states) else [states] | |||
//# TF cell does not wrap the state into list when there is only one state. | |||
// is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None | |||
// states = states[0] if len(states) == 1 and is_tf_rnn_cell else states | |||
// if generic_utils.has_arg(cell.call, 'training'): | |||
// kwargs['training'] = training | |||
// else: | |||
// kwargs.pop('training', None) | |||
// # Use the __call__ function for callable objects, eg layers, so that it | |||
// # will have the proper name scopes for the ops, etc. | |||
// cell_call_fn = cell.__call__ if callable(cell) else cell.call | |||
// if generic_utils.has_arg(cell.call, 'constants'): | |||
// inputs, states = cell_call_fn(inputs, states, | |||
// constants= constants, ** kwargs) | |||
// else: | |||
// inputs, states = cell_call_fn(inputs, states, ** kwargs) | |||
// new_nested_states.append(states) | |||
// return inputs, nest.pack_sequence_as(state_size, | |||
// nest.flatten(new_nested_states)) | |||
// @tf_utils.shape_type_conversion | |||
// def build(self, input_shape) : | |||
// if isinstance(input_shape, list) : | |||
// input_shape = input_shape[0] | |||
// for cell in self.cells: | |||
// if isinstance(cell, Layer) and not cell.built: | |||
// with K.name_scope(cell.name): | |||
// cell.build(input_shape) | |||
// cell.built = True | |||
// if getattr(cell, 'output_size', None) is not None: | |||
// output_dim = cell.output_size | |||
// elif _is_multiple_state(cell.state_size) : | |||
// output_dim = cell.state_size[0] | |||
// else: | |||
// output_dim = cell.state_size | |||
// input_shape = tuple([input_shape[0]] + | |||
// tensor_shape.TensorShape(output_dim).as_list()) | |||
// self.built = True | |||
// def get_config(self) : | |||
// cells = [] | |||
// for cell in self.cells: | |||
// cells.append(generic_utils.serialize_keras_object(cell)) | |||
// config = {'cells': cells | |||
//} | |||
//base_config = super(StackedRNNCells, self).get_config() | |||
// return dict(list(base_config.items()) + list(config.items())) | |||
// @classmethod | |||
// def from_config(cls, config, custom_objects = None): | |||
// from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top | |||
// cells = [] | |||
// for cell_config in config.pop('cells'): | |||
// cells.append( | |||
// deserialize_layer(cell_config, custom_objects = custom_objects)) | |||
// return cls(cells, **config) | |||
} | |||
} |
@@ -36,7 +36,7 @@ namespace TensorFlowNET.Keras.UnitTest | |||
var model = keras.Model(inputs, outputs, name: "mnist_model"); | |||
model.summary(); | |||
} | |||
/// <summary> | |||
/// Custom layer test, used in Dueling DQN | |||
/// </summary> | |||
@@ -45,10 +45,10 @@ namespace TensorFlowNET.Keras.UnitTest | |||
{ | |||
var layers = keras.layers; | |||
var inputs = layers.Input(shape: 24); | |||
var x = layers.Dense(128, activation:"relu").Apply(inputs); | |||
var x = layers.Dense(128, activation: "relu").Apply(inputs); | |||
var value = layers.Dense(24).Apply(x); | |||
var adv = layers.Dense(1).Apply(x); | |||
var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); | |||
adv = layers.Subtract().Apply((adv, mean)); | |||
var outputs = layers.Add().Apply((value, adv)); | |||
@@ -105,9 +105,14 @@ namespace TensorFlowNET.Keras.UnitTest | |||
} | |||
[TestMethod] | |||
[Ignore] | |||
public void SimpleRNN() | |||
{ | |||
var inputs = np.random.rand(32, 10, 8).astype(np.float32); | |||
var simple_rnn = keras.layers.SimpleRNN(4); | |||
var output = simple_rnn.Apply(inputs); | |||
Assert.AreEqual((32, 4), output.shape); | |||
} | |||
} | |||
} |