feat: add Bidirectional layertags/v0.110.4-Transformer-Model
@@ -0,0 +1,20 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.NumPy; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class BidirectionalArgs : AutoSerializeLayerArgs | |||||
{ | |||||
[JsonProperty("layer")] | |||||
public ILayer Layer { get; set; } | |||||
[JsonProperty("merge_mode")] | |||||
public string? MergeMode { get; set; } | |||||
[JsonProperty("backward_layer")] | |||||
public ILayer BackwardLayer { get; set; } | |||||
public NDArray Weights { get; set; } | |||||
} | |||||
} |
@@ -5,5 +5,10 @@ | |||||
// TODO: maybe change the `RNNArgs` and implement this class. | // TODO: maybe change the `RNNArgs` and implement this class. | ||||
public bool UnitForgetBias { get; set; } | public bool UnitForgetBias { get; set; } | ||||
public int Implementation { get; set; } | public int Implementation { get; set; } | ||||
public LSTMArgs Clone() | |||||
{ | |||||
return (LSTMArgs)MemberwiseClone(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -40,5 +40,10 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
public bool ZeroOutputForMask { get; set; } = false; | public bool ZeroOutputForMask { get; set; } = false; | ||||
[JsonProperty("recurrent_dropout")] | [JsonProperty("recurrent_dropout")] | ||||
public float RecurrentDropout { get; set; } = .0f; | public float RecurrentDropout { get; set; } = .0f; | ||||
public RNNArgs Clone() | |||||
{ | |||||
return (RNNArgs)MemberwiseClone(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,24 @@ | |||||
using Newtonsoft.Json; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.CompilerServices; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class WrapperArgs : AutoSerializeLayerArgs | |||||
{ | |||||
[JsonProperty("layer")] | |||||
public ILayer Layer { get; set; } | |||||
public WrapperArgs(ILayer layer) | |||||
{ | |||||
Layer = layer; | |||||
} | |||||
public static implicit operator WrapperArgs(BidirectionalArgs args) | |||||
=> new WrapperArgs(args.Layer); | |||||
} | |||||
} |
@@ -258,7 +258,19 @@ namespace Tensorflow.Keras.Layers | |||||
float dropout = 0f, | float dropout = 0f, | ||||
float recurrent_dropout = 0f, | float recurrent_dropout = 0f, | ||||
bool reset_after = true); | bool reset_after = true); | ||||
/// <summary> | |||||
/// Bidirectional wrapper for RNNs. | |||||
/// </summary> | |||||
/// <param name="layer">`keras.layers.RNN` instance, such as `keras.layers.LSTM` or `keras.layers.GRU`</param> | |||||
/// automatically.</param> | |||||
/// <returns></returns> | |||||
public ILayer Bidirectional( | |||||
ILayer layer, | |||||
string merge_mode = "concat", | |||||
NDArray weights = null, | |||||
ILayer backward_layer = null); | |||||
public ILayer Subtract(); | public ILayer Subtract(); | ||||
} | } | ||||
} | } |
@@ -908,6 +908,20 @@ namespace Tensorflow.Keras.Layers | |||||
ResetAfter = reset_after | ResetAfter = reset_after | ||||
}); | }); | ||||
public ILayer Bidirectional( | |||||
ILayer layer, | |||||
string merge_mode = "concat", | |||||
NDArray weights = null, | |||||
ILayer backward_layer = null) | |||||
=> new Bidirectional(new BidirectionalArgs | |||||
{ | |||||
Layer = layer, | |||||
MergeMode = merge_mode, | |||||
Weights = weights, | |||||
BackwardLayer = backward_layer | |||||
}); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
/// </summary> | /// </summary> | ||||
@@ -0,0 +1,33 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Diagnostics; | |||||
using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
/// <summary> | |||||
/// Abstract wrapper base class. Wrappers take another layer and augment it in various ways. | |||||
/// Do not use this class as a layer, it is only an abstract base class. | |||||
/// Two usable wrappers are the `TimeDistributed` and `Bidirectional` wrappers. | |||||
/// </summary> | |||||
public abstract class Wrapper: Layer | |||||
{ | |||||
public ILayer _layer; | |||||
public Wrapper(WrapperArgs args):base(args) | |||||
{ | |||||
_layer = args.Layer; | |||||
} | |||||
public virtual void Build(KerasShapesWrapper input_shape) | |||||
{ | |||||
if (!_layer.Built) | |||||
{ | |||||
_layer.build(input_shape); | |||||
} | |||||
built = true; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,276 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
/// <summary> | |||||
/// Bidirectional wrapper for RNNs. | |||||
/// </summary> | |||||
public class Bidirectional: Wrapper | |||||
{ | |||||
BidirectionalArgs _args; | |||||
RNN _forward_layer; | |||||
RNN _backward_layer; | |||||
RNN _layer; | |||||
bool _support_masking = true; | |||||
int _num_constants = 0; | |||||
bool _return_state; | |||||
bool _stateful; | |||||
bool _return_sequences; | |||||
InputSpec _input_spec; | |||||
RNNArgs _layer_args_copy; | |||||
public Bidirectional(BidirectionalArgs args):base(args) | |||||
{ | |||||
_args = args; | |||||
if (_args.Layer is not ILayer) | |||||
throw new ValueError( | |||||
"Please initialize `Bidirectional` layer with a " + | |||||
$"`tf.keras.layers.Layer` instance. Received: {_args.Layer}"); | |||||
if (_args.BackwardLayer is not null && _args.BackwardLayer is not ILayer) | |||||
throw new ValueError( | |||||
"`backward_layer` need to be a `tf.keras.layers.Layer` " + | |||||
$"instance. Received: {_args.BackwardLayer}"); | |||||
if (!new List<string> { "sum", "mul", "ave", "concat", null }.Contains(_args.MergeMode)) | |||||
{ | |||||
throw new ValueError( | |||||
$"Invalid merge mode. Received: {_args.MergeMode}. " + | |||||
"Merge mode should be one of " + | |||||
"{\"sum\", \"mul\", \"ave\", \"concat\", null}" | |||||
); | |||||
} | |||||
if (_args.Layer is RNN) | |||||
{ | |||||
_layer = _args.Layer as RNN; | |||||
} | |||||
else | |||||
{ | |||||
throw new ValueError( | |||||
"Bidirectional only support RNN instance such as LSTM or GRU"); | |||||
} | |||||
_return_state = _layer.Args.ReturnState; | |||||
_return_sequences = _layer.Args.ReturnSequences; | |||||
_stateful = _layer.Args.Stateful; | |||||
_layer_args_copy = _layer.Args.Clone(); | |||||
// We don't want to track `layer` since we're already tracking the two | |||||
// copies of it we actually run. | |||||
// TODO(Wanglongzhi2001), since the feature of setattr_tracking has not been implemented. | |||||
// _setattr_tracking = false; | |||||
// super().__init__(layer, **kwargs) | |||||
// _setattr_tracking = true; | |||||
// Recreate the forward layer from the original layer config, so that it | |||||
// will not carry over any state from the layer. | |||||
var actualType = _layer.GetType(); | |||||
if (actualType == typeof(LSTM)) | |||||
{ | |||||
var arg = _layer_args_copy as LSTMArgs; | |||||
_forward_layer = new LSTM(arg); | |||||
} | |||||
// TODO(Wanglongzhi2001), add GRU if case. | |||||
else | |||||
{ | |||||
_forward_layer = new RNN(_layer.Cell, _layer_args_copy); | |||||
} | |||||
//_forward_layer = _recreate_layer_from_config(_layer); | |||||
if (_args.BackwardLayer is null) | |||||
{ | |||||
_backward_layer = _recreate_layer_from_config(_layer, go_backwards:true); | |||||
} | |||||
else | |||||
{ | |||||
_backward_layer = _args.BackwardLayer as RNN; | |||||
} | |||||
_forward_layer.Name = "forward_" + _forward_layer.Name; | |||||
_backward_layer.Name = "backward_" + _backward_layer.Name; | |||||
_verify_layer_config(); | |||||
void force_zero_output_for_mask(RNN layer) | |||||
{ | |||||
layer.Args.ZeroOutputForMask = layer.Args.ReturnSequences; | |||||
} | |||||
force_zero_output_for_mask(_forward_layer); | |||||
force_zero_output_for_mask(_backward_layer); | |||||
if (_args.Weights is not null) | |||||
{ | |||||
var nw = len(_args.Weights); | |||||
_forward_layer.set_weights(_args.Weights[$":,{nw / 2}"]); | |||||
_backward_layer.set_weights(_args.Weights[$"{nw / 2},:"]); | |||||
} | |||||
_input_spec = _layer.InputSpec; | |||||
} | |||||
private void _verify_layer_config() | |||||
{ | |||||
if (_forward_layer.Args.GoBackwards == _backward_layer.Args.GoBackwards) | |||||
{ | |||||
throw new ValueError( | |||||
"Forward layer and backward layer should have different " + | |||||
"`go_backwards` value." + | |||||
"forward_layer.go_backwards = " + | |||||
$"{_forward_layer.Args.GoBackwards}," + | |||||
"backward_layer.go_backwards = " + | |||||
$"{_backward_layer.Args.GoBackwards}"); | |||||
} | |||||
if (_forward_layer.Args.Stateful != _backward_layer.Args.Stateful) | |||||
{ | |||||
throw new ValueError( | |||||
"Forward layer and backward layer are expected to have "+ | |||||
$"the same value for attribute stateful, got "+ | |||||
$"{_forward_layer.Args.Stateful} for forward layer and "+ | |||||
$"{_backward_layer.Args.Stateful} for backward layer"); | |||||
} | |||||
if (_forward_layer.Args.ReturnState != _backward_layer.Args.ReturnState) | |||||
{ | |||||
throw new ValueError( | |||||
"Forward layer and backward layer are expected to have " + | |||||
$"the same value for attribute return_state, got " + | |||||
$"{_forward_layer.Args.ReturnState} for forward layer and " + | |||||
$"{_backward_layer.Args.ReturnState} for backward layer"); | |||||
} | |||||
if (_forward_layer.Args.ReturnSequences != _backward_layer.Args.ReturnSequences) | |||||
{ | |||||
throw new ValueError( | |||||
"Forward layer and backward layer are expected to have " + | |||||
$"the same value for attribute return_sequences, got " + | |||||
$"{_forward_layer.Args.ReturnSequences} for forward layer and " + | |||||
$"{_backward_layer.Args.ReturnSequences} for backward layer"); | |||||
} | |||||
} | |||||
private RNN _recreate_layer_from_config(RNN layer, bool go_backwards = false) | |||||
{ | |||||
var config = layer.get_config() as RNNArgs; | |||||
var cell = layer.Cell; | |||||
if (go_backwards) | |||||
{ | |||||
config.GoBackwards = !config.GoBackwards; | |||||
} | |||||
var actualType = layer.GetType(); | |||||
if (actualType == typeof(LSTM)) | |||||
{ | |||||
var arg = config as LSTMArgs; | |||||
return new LSTM(arg); | |||||
} | |||||
else | |||||
{ | |||||
return new RNN(cell, config); | |||||
} | |||||
} | |||||
public override void build(KerasShapesWrapper input_shape) | |||||
{ | |||||
_buildInputShape = input_shape; | |||||
tf_with(ops.name_scope(_forward_layer.Name), scope=> | |||||
{ | |||||
_forward_layer.build(input_shape); | |||||
}); | |||||
tf_with(ops.name_scope(_backward_layer.Name), scope => | |||||
{ | |||||
_backward_layer.build(input_shape); | |||||
}); | |||||
built = true; | |||||
} | |||||
protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null) | |||||
{ | |||||
// `Bidirectional.call` implements the same API as the wrapped `RNN`. | |||||
Tensors forward_inputs; | |||||
Tensors backward_inputs; | |||||
Tensors forward_state; | |||||
Tensors backward_state; | |||||
// if isinstance(inputs, list) and len(inputs) > 1: | |||||
if (inputs.Length > 1) | |||||
{ | |||||
// initial_states are keras tensors, which means they are passed | |||||
// in together with inputs as list. The initial_states need to be | |||||
// split into forward and backward section, and be feed to layers | |||||
// accordingly. | |||||
forward_inputs = new Tensors { inputs[0] }; | |||||
backward_inputs = new Tensors { inputs[0] }; | |||||
var pivot = (len(inputs) - _num_constants) / 2 + 1; | |||||
// add forward initial state | |||||
forward_inputs.Concat(new Tensors { inputs[$"1:{pivot}"] }); | |||||
if (_num_constants != 0) | |||||
// add backward initial state | |||||
backward_inputs.Concat(new Tensors { inputs[$"{pivot}:"] }); | |||||
else | |||||
{ | |||||
// add backward initial state | |||||
backward_inputs.Concat(new Tensors { inputs[$"{pivot}:{-_num_constants}"] }); | |||||
// add constants for forward and backward layers | |||||
forward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] }); | |||||
backward_inputs.Concat(new Tensors { inputs[$"{-_num_constants}:"] }); | |||||
} | |||||
forward_state = null; | |||||
backward_state = null; | |||||
} | |||||
else if (state is not null) | |||||
{ | |||||
// initial_states are not keras tensors, eg eager tensor from np | |||||
// array. They are only passed in from kwarg initial_state, and | |||||
// should be passed to forward/backward layer via kwarg | |||||
// initial_state as well. | |||||
forward_inputs = inputs; | |||||
backward_inputs = inputs; | |||||
var half = len(state) / 2; | |||||
forward_state = state[$":{half}"]; | |||||
backward_state = state[$"{half}:"]; | |||||
} | |||||
else | |||||
{ | |||||
forward_inputs = inputs; | |||||
backward_inputs = inputs; | |||||
forward_state = null; | |||||
backward_state = null; | |||||
} | |||||
var y = _forward_layer.Apply(forward_inputs, forward_state); | |||||
var y_rev = _backward_layer.Apply(backward_inputs, backward_state); | |||||
Tensors states = new(); | |||||
if (_return_state) | |||||
{ | |||||
states = y["1:"] + y_rev["1:"]; | |||||
y = y[0]; | |||||
y_rev = y_rev[0]; | |||||
} | |||||
if (_return_sequences) | |||||
{ | |||||
int time_dim = _forward_layer.Args.TimeMajor ? 0 : 1; | |||||
y_rev = keras.backend.reverse(y_rev, time_dim); | |||||
} | |||||
Tensors output; | |||||
if (_args.MergeMode == "concat") | |||||
output = keras.backend.concatenate(new Tensors { y.Single(), y_rev.Single() }); | |||||
else if (_args.MergeMode == "sum") | |||||
output = y.Single() + y_rev.Single(); | |||||
else if (_args.MergeMode == "ave") | |||||
output = (y.Single() + y_rev.Single()) / 2; | |||||
else if (_args.MergeMode == "mul") | |||||
output = y.Single() * y_rev.Single(); | |||||
else if (_args.MergeMode is null) | |||||
output = new Tensors { y.Single(), y_rev.Single() }; | |||||
else | |||||
throw new ValueError( | |||||
"Unrecognized value for `merge_mode`. " + | |||||
$"Received: {_args.MergeMode}" + | |||||
"Expected values are [\"concat\", \"sum\", \"ave\", \"mul\"]"); | |||||
if (_return_state) | |||||
{ | |||||
if (_args.MergeMode is not null) | |||||
return new Tensors { output.Single(), states.Single()}; | |||||
} | |||||
return output; | |||||
} | |||||
} | |||||
} |
@@ -3,6 +3,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Common.Extensions; | using Tensorflow.Common.Extensions; | ||||
using Tensorflow.Keras.Saving; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
@@ -14,15 +15,15 @@ namespace Tensorflow.Keras.Layers | |||||
/// </summary> | /// </summary> | ||||
public class LSTM : RNN | public class LSTM : RNN | ||||
{ | { | ||||
LSTMArgs args; | |||||
LSTMArgs _args; | |||||
InputSpec[] _state_spec; | InputSpec[] _state_spec; | ||||
InputSpec _input_spec; | InputSpec _input_spec; | ||||
bool _could_use_gpu_kernel; | bool _could_use_gpu_kernel; | ||||
public LSTMArgs Args { get => _args; } | |||||
public LSTM(LSTMArgs args) : | public LSTM(LSTMArgs args) : | ||||
base(CreateCell(args), args) | base(CreateCell(args), args) | ||||
{ | { | ||||
this.args = args; | |||||
_args = args; | |||||
_input_spec = new InputSpec(ndim: 3); | _input_spec = new InputSpec(ndim: 3); | ||||
_state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray(); | _state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray(); | ||||
_could_use_gpu_kernel = args.Activation == keras.activations.Tanh | _could_use_gpu_kernel = args.Activation == keras.activations.Tanh | ||||
@@ -71,7 +72,7 @@ namespace Tensorflow.Keras.Layers | |||||
var single_input = inputs.Single; | var single_input = inputs.Single; | ||||
var input_shape = single_input.shape; | var input_shape = single_input.shape; | ||||
var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1]; | |||||
var timesteps = _args.TimeMajor ? input_shape[0] : input_shape[1]; | |||||
_maybe_reset_cell_dropout_mask(Cell); | _maybe_reset_cell_dropout_mask(Cell); | ||||
@@ -87,26 +88,26 @@ namespace Tensorflow.Keras.Layers | |||||
inputs, | inputs, | ||||
initial_state, | initial_state, | ||||
constants: null, | constants: null, | ||||
go_backwards: args.GoBackwards, | |||||
go_backwards: _args.GoBackwards, | |||||
mask: mask, | mask: mask, | ||||
unroll: args.Unroll, | |||||
unroll: _args.Unroll, | |||||
input_length: ops.convert_to_tensor(timesteps), | input_length: ops.convert_to_tensor(timesteps), | ||||
time_major: args.TimeMajor, | |||||
zero_output_for_mask: args.ZeroOutputForMask, | |||||
return_all_outputs: args.ReturnSequences | |||||
time_major: _args.TimeMajor, | |||||
zero_output_for_mask: _args.ZeroOutputForMask, | |||||
return_all_outputs: _args.ReturnSequences | |||||
); | ); | ||||
Tensor output; | Tensor output; | ||||
if (args.ReturnSequences) | |||||
if (_args.ReturnSequences) | |||||
{ | { | ||||
output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards); | |||||
output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, _args.GoBackwards); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
output = last_output; | output = last_output; | ||||
} | } | ||||
if (args.ReturnState) | |||||
if (_args.ReturnState) | |||||
{ | { | ||||
return new Tensor[] { output }.Concat(states).ToArray().ToTensors(); | return new Tensor[] { output }.Concat(states).ToArray().ToTensors(); | ||||
} | } | ||||
@@ -115,5 +116,11 @@ namespace Tensorflow.Keras.Layers | |||||
return output; | return output; | ||||
} | } | ||||
} | } | ||||
public override IKerasConfig get_config() | |||||
{ | |||||
return _args; | |||||
} | |||||
} | } | ||||
} | } |
@@ -31,7 +31,9 @@ namespace Tensorflow.Keras.Layers | |||||
protected IVariableV1 _kernel; | protected IVariableV1 _kernel; | ||||
protected IVariableV1 _bias; | protected IVariableV1 _bias; | ||||
private IRnnCell _cell; | private IRnnCell _cell; | ||||
protected IRnnCell Cell | |||||
public RNNArgs Args { get => _args; } | |||||
public IRnnCell Cell | |||||
{ | { | ||||
get | get | ||||
{ | { | ||||
@@ -570,10 +572,13 @@ namespace Tensorflow.Keras.Layers | |||||
var input_shape = array_ops.shape(inputs); | var input_shape = array_ops.shape(inputs); | ||||
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; | ||||
var dtype = input.dtype; | var dtype = input.dtype; | ||||
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype); | Tensors init_state = Cell.GetInitialState(null, batch_size, dtype); | ||||
return init_state; | return init_state; | ||||
} | } | ||||
public override IKerasConfig get_config() | |||||
{ | |||||
return _args; | |||||
} | |||||
} | } | ||||
} | } |
@@ -5,6 +5,7 @@ using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using System.Threading.Tasks; | using System.Threading.Tasks; | ||||
using Tensorflow.Common.Types; | using Tensorflow.Common.Types; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
@@ -38,8 +39,6 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; | var cells = new IRnnCell[] { tf.keras.layers.SimpleRNNCell(4), tf.keras.layers.SimpleRNNCell(5) }; | ||||
var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); | var stackedRNNCell = tf.keras.layers.StackedRNNCells(cells); | ||||
var (output, state) = stackedRNNCell.Apply(inputs, states); | var (output, state) = stackedRNNCell.Apply(inputs, states); | ||||
Console.WriteLine(output); | |||||
Console.WriteLine(state.shape); | |||||
Assert.AreEqual((32, 5), output.shape); | Assert.AreEqual((32, 5), output.shape); | ||||
Assert.AreEqual((32, 4), state[0].shape); | Assert.AreEqual((32, 4), state[0].shape); | ||||
} | } | ||||
@@ -108,6 +107,7 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
var inputs = tf.random.normal((32, 10, 8)); | var inputs = tf.random.normal((32, 10, 8)); | ||||
var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f); | var cell = tf.keras.layers.SimpleRNNCell(10, dropout: 0.5f, recurrent_dropout: 0.5f); | ||||
var rnn = tf.keras.layers.RNN(cell: cell); | var rnn = tf.keras.layers.RNN(cell: cell); | ||||
var cgf = rnn.get_config(); | |||||
var output = rnn.Apply(inputs); | var output = rnn.Apply(inputs); | ||||
Assert.AreEqual((32, 10), output.shape); | Assert.AreEqual((32, 10), output.shape); | ||||
@@ -145,5 +145,14 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
Assert.AreEqual((32, 4), output.shape); | Assert.AreEqual((32, 4), output.shape); | ||||
} | } | ||||
[TestMethod] | |||||
public void Bidirectional() | |||||
{ | |||||
var bi = tf.keras.layers.Bidirectional(keras.layers.LSTM(10, return_sequences:true)); | |||||
var inputs = tf.random.normal((32, 10, 8)); | |||||
var outputs = bi.Apply(inputs); | |||||
Assert.AreEqual((32, 10, 20), outputs.shape); | |||||
} | |||||
} | } | ||||
} | } |