Browse Source

Blank SimpleRNN and test for it

tags/yolov3
MPnoy Haiping 4 years ago
parent
commit
d89609a4e3
9 changed files with 307 additions and 8 deletions
  1. +16
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs
  2. +30
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs
  3. +9
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  5. +19
    -0
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  6. +84
    -2
      src/TensorFlowNET.Keras/Layers/RNN.cs
  7. +14
    -0
      src/TensorFlowNET.Keras/Layers/SimpleRNN.cs
  8. +125
    -0
      src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs
  9. +9
    -4
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+ 16
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/RNNArgs.cs View File

@@ -1,6 +1,21 @@
namespace Tensorflow.Keras.ArgsDefinition
using System.Collections.Generic;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class RNNArgs : LayerArgs public class RNNArgs : LayerArgs
{ {
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}

public IRnnArgCell Cell { get; set; } = null;
public bool ReturnSequences { get; set; } = false;
public bool ReturnState { get; set; } = false;
public bool GoBackwards { get; set; } = false;
public bool Stateful { get; set; } = false;
public bool Unroll { get; set; } = false;
public bool TimeMajor { get; set; } = false;
public Dictionary<string, object> Kwargs { get; set; } = null;
} }
} }

+ 30
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/SimpleRNNArgs.cs View File

@@ -0,0 +1,30 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class SimpleRNNArgs : RNNArgs
{
public int Units { get; set; }
public Activation Activation { get; set; }
// units,
// activation='tanh',
// use_bias=True,
// kernel_initializer='glorot_uniform',
// recurrent_initializer='orthogonal',
// bias_initializer='zeros',
// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
}
}

+ 9
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/StackedRNNCellsArgs.cs View File

@@ -0,0 +1,9 @@
using System.Collections.Generic;

namespace Tensorflow.Keras.ArgsDefinition
{
public class StackedRNNCellsArgs : LayerArgs
{
public IList<RnnCell> Cells { get; set; }
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow
/// matching structure of Tensors having shape `[batch_size].concatenate(s)` /// matching structure of Tensors having shape `[batch_size].concatenate(s)`
/// for each `s` in `self.batch_size`. /// for each `s` in `self.batch_size`.
/// </summary> /// </summary>
public abstract class RnnCell : ILayer
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
{ {
/// <summary> /// <summary>
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight /// Attribute that indicates whether the cell is a TF RNN cell, due the slight


+ 19
- 0
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -1,4 +1,5 @@
using NumSharp; using NumSharp;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -327,6 +328,24 @@ namespace Tensorflow.Keras.Layers
Alpha = alpha Alpha = alpha
}); });


public Layer SimpleRNN(int units) => SimpleRNN(units, "tanh");

public Layer SimpleRNN(int units,
Activation activation = null)
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
Activation = activation
});

public Layer SimpleRNN(int units,
string activation = "tanh")
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
Activation = GetActivationByName(activation)
});

public Layer LSTM(int units, public Layer LSTM(int units,
Activation activation = null, Activation activation = null,
Activation recurrent_activation = null, Activation recurrent_activation = null,


+ 84
- 2
src/TensorFlowNET.Keras/Layers/RNN.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;


@@ -6,12 +7,93 @@ namespace Tensorflow.Keras.Layers
{ {
public class RNN : Layer public class RNN : Layer
{ {
public RNN(RNNArgs args)
: base(args)
private RNNArgs args;

public RNN(RNNArgs args) : base(PreConstruct(args))
{ {
this.args = args;
SupportsMasking = true;

// The input shape is unknown yet, it could have nested tensor inputs, and
// the input spec will be the list of specs for nested inputs, the structure
// of the input_spec will be the same as the input.


//self.input_spec = None
//self.state_spec = None
//self._states = None
//self.constants_spec = None
//self._num_constants = 0

//if stateful:
// if ds_context.has_strategy():
// raise ValueError('RNNs with stateful=True not yet supported with '
// 'tf.distribute.Strategy.')
} }


private static RNNArgs PreConstruct(RNNArgs args)
{
if (args.Kwargs == null)
{
args.Kwargs = new Dictionary<string, object>();
}

// If true, the output for masked timestep will be zeros, whereas in the
// false case, output from previous timestep is returned for masked timestep.
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);

object input_shape;
var propIS = args.Kwargs.Get("input_shape", null);
var propID = args.Kwargs.Get("input_dim", null);
var propIL = args.Kwargs.Get("input_length", null);

if (propIS == null && (propID != null || propIL != null))
{
input_shape = (
propIL ?? new NoneValue(), // maybe null is needed here
propID ?? new NoneValue()); // and here
args.Kwargs["input_shape"] = input_shape;
}

return args;
}

public RNN New(LayerRnnCell cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = cell,
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});

public RNN New(IList<RnnCell> cell,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
bool stateful = false,
bool unroll = false,
bool time_major = false)
=> new RNN(new RNNArgs
{
Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = cell }),
ReturnSequences = return_sequences,
ReturnState = return_state,
GoBackwards = go_backwards,
Stateful = stateful,
Unroll = unroll,
TimeMajor = time_major
});


protected Tensor get_initial_state(Tensor inputs) protected Tensor get_initial_state(Tensor inputs)
{ {
return _generate_zero_filled_state_for_cell(null, null); return _generate_zero_filled_state_for_cell(null, null);


+ 14
- 0
src/TensorFlowNET.Keras/Layers/SimpleRNN.cs View File

@@ -0,0 +1,14 @@
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Layers
{
public class SimpleRNN : RNN
{

public SimpleRNN(RNNArgs args) : base(args)
{

}

}
}

+ 125
- 0
src/TensorFlowNET.Keras/Layers/StackedRNNCells.cs View File

@@ -0,0 +1,125 @@
using System;
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell
{
public IList<RnnCell> Cells { get; set; }

public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
{
Cells = args.Cells;
//Cells.reverse_state_order = kwargs.pop('reverse_state_order', False);
// self.reverse_state_order = kwargs.pop('reverse_state_order', False)
// if self.reverse_state_order:
// logging.warning('reverse_state_order=True in StackedRNNCells will soon '
// 'be deprecated. Please update the code to work with the '
// 'natural order of states if you rely on the RNN states, '
// 'eg RNN(return_state=True).')
// super(StackedRNNCells, self).__init__(**kwargs)
throw new NotImplementedException("");
}

public object state_size
{
get => throw new NotImplementedException();
}

//@property
//def state_size(self) :
// return tuple(c.state_size for c in
// (self.cells[::- 1] if self.reverse_state_order else self.cells))

// @property
// def output_size(self) :
// if getattr(self.cells[-1], 'output_size', None) is not None:
// return self.cells[-1].output_size
// elif _is_multiple_state(self.cells[-1].state_size) :
// return self.cells[-1].state_size[0]
// else:
// return self.cells[-1].state_size

// def get_initial_state(self, inputs= None, batch_size= None, dtype= None) :
// initial_states = []
// for cell in self.cells[::- 1] if self.reverse_state_order else self.cells:
// get_initial_state_fn = getattr(cell, 'get_initial_state', None)
// if get_initial_state_fn:
// initial_states.append(get_initial_state_fn(
// inputs=inputs, batch_size=batch_size, dtype=dtype))
// else:
// initial_states.append(_generate_zero_filled_state_for_cell(
// cell, inputs, batch_size, dtype))

// return tuple(initial_states)

// def call(self, inputs, states, constants= None, training= None, ** kwargs):
// # Recover per-cell states.
// state_size = (self.state_size[::- 1]
// if self.reverse_state_order else self.state_size)
// nested_states = nest.pack_sequence_as(state_size, nest.flatten(states))

// # Call the cells in order and store the returned states.
// new_nested_states = []
// for cell, states in zip(self.cells, nested_states) :
// states = states if nest.is_nested(states) else [states]
//# TF cell does not wrap the state into list when there is only one state.
// is_tf_rnn_cell = getattr(cell, '_is_tf_rnn_cell', None) is not None
// states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
// if generic_utils.has_arg(cell.call, 'training'):
// kwargs['training'] = training
// else:
// kwargs.pop('training', None)
// # Use the __call__ function for callable objects, eg layers, so that it
// # will have the proper name scopes for the ops, etc.
// cell_call_fn = cell.__call__ if callable(cell) else cell.call
// if generic_utils.has_arg(cell.call, 'constants'):
// inputs, states = cell_call_fn(inputs, states,
// constants= constants, ** kwargs)
// else:
// inputs, states = cell_call_fn(inputs, states, ** kwargs)
// new_nested_states.append(states)

// return inputs, nest.pack_sequence_as(state_size,
// nest.flatten(new_nested_states))

// @tf_utils.shape_type_conversion
// def build(self, input_shape) :
// if isinstance(input_shape, list) :
// input_shape = input_shape[0]
// for cell in self.cells:
// if isinstance(cell, Layer) and not cell.built:
// with K.name_scope(cell.name):
// cell.build(input_shape)
// cell.built = True
// if getattr(cell, 'output_size', None) is not None:
// output_dim = cell.output_size
// elif _is_multiple_state(cell.state_size) :
// output_dim = cell.state_size[0]
// else:
// output_dim = cell.state_size
// input_shape = tuple([input_shape[0]] +
// tensor_shape.TensorShape(output_dim).as_list())
// self.built = True

// def get_config(self) :
// cells = []
// for cell in self.cells:
// cells.append(generic_utils.serialize_keras_object(cell))
// config = {'cells': cells
//}
//base_config = super(StackedRNNCells, self).get_config()
// return dict(list(base_config.items()) + list(config.items()))

// @classmethod
// def from_config(cls, config, custom_objects = None):
// from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
// cells = []
// for cell_config in config.pop('cells'):
// cells.append(
// deserialize_layer(cell_config, custom_objects = custom_objects))
// return cls(cells, **config)
}
}

+ 9
- 4
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -36,7 +36,7 @@ namespace TensorFlowNET.Keras.UnitTest
var model = keras.Model(inputs, outputs, name: "mnist_model"); var model = keras.Model(inputs, outputs, name: "mnist_model");
model.summary(); model.summary();
} }
/// <summary> /// <summary>
/// Custom layer test, used in Dueling DQN /// Custom layer test, used in Dueling DQN
/// </summary> /// </summary>
@@ -45,10 +45,10 @@ namespace TensorFlowNET.Keras.UnitTest
{ {
var layers = keras.layers; var layers = keras.layers;
var inputs = layers.Input(shape: 24); var inputs = layers.Input(shape: 24);
var x = layers.Dense(128, activation:"relu").Apply(inputs);
var x = layers.Dense(128, activation: "relu").Apply(inputs);
var value = layers.Dense(24).Apply(x); var value = layers.Dense(24).Apply(x);
var adv = layers.Dense(1).Apply(x); var adv = layers.Dense(1).Apply(x);
var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true); var mean = adv - tf.reduce_mean(adv, axis: 1, keepdims: true);
adv = layers.Subtract().Apply((adv, mean)); adv = layers.Subtract().Apply((adv, mean));
var outputs = layers.Add().Apply((value, adv)); var outputs = layers.Add().Apply((value, adv));
@@ -105,9 +105,14 @@ namespace TensorFlowNET.Keras.UnitTest
} }


[TestMethod] [TestMethod]
[Ignore]
public void SimpleRNN() public void SimpleRNN()
{ {

var inputs = np.random.rand(32, 10, 8).astype(np.float32);
var simple_rnn = keras.layers.SimpleRNN(4);
var output = simple_rnn.Apply(inputs);
Assert.AreEqual((32, 4), output.shape);
} }

} }
} }

Loading…
Cancel
Save