From a70077bbb479c24fdb27f9ae5d2131ad0d3782b6 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 23 Nov 2019 09:21:12 -0600 Subject: [PATCH] BasicLSTMCell --- src/TensorFlowNET.Core/APIs/tf.math.cs | 2 +- .../Framework/tensor_shape.cs | 10 ++ .../Operations/NnOps/BasicLSTMCell.cs | 57 +++++++++++ .../Operations/{ => NnOps}/BasicRNNCell.cs | 3 +- .../Operations/NnOps/LSTMStateTuple.cs | 41 ++++++++ .../Operations/{ => NnOps}/LayerRNNCell.cs | 0 .../Operations/{ => NnOps}/RNNCell.cs | 2 +- .../Operations/NnOps/rnn.cs | 95 ++++++++++++++++++- src/TensorFlowNET.Core/Operations/clip_ops.cs | 45 +++++++++ .../TensorFlow.Binding.csproj | 2 +- src/TensorFlowNET.Core/Tensors/Dimension.cs | 6 ++ src/TensorFlowNET.Core/Tensors/Tensor.cs | 4 +- 12 files changed, 260 insertions(+), 7 deletions(-) create mode 100644 src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs rename src/TensorFlowNET.Core/Operations/{ => NnOps}/BasicRNNCell.cs (96%) create mode 100644 src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs rename src/TensorFlowNET.Core/Operations/{ => NnOps}/LayerRNNCell.cs (100%) rename src/TensorFlowNET.Core/Operations/{ => NnOps}/RNNCell.cs (98%) create mode 100644 src/TensorFlowNET.Core/Operations/clip_ops.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 790e391e..9f2b493c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -251,7 +251,7 @@ namespace Tensorflow /// greater than clip_value_max are set to clip_value_max. /// public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") - => gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name); + => clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); public Tensor sub(Tensor a, Tensor b) => gen_math_ops.sub(a, b); diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs index d4e2f6cd..06d80972 100644 --- a/src/TensorFlowNET.Core/Framework/tensor_shape.cs +++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs @@ -24,6 +24,16 @@ namespace Tensorflow.Framework } } + public static Dimension dimension_at_index(TensorShape shape, int index) + { + return shape.rank < 0 ? + new Dimension(-1) : + new Dimension(shape.dims[index]); + } + + public static int dimension_value(Dimension dimension) + => dimension.value; + public static TensorShape as_shape(this Shape shape) => new TensorShape(shape.Dimensions); } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs new file mode 100644 index 00000000..ab19a271 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static Tensorflow.Binding; +using Tensorflow.Operations.Activation; +using Tensorflow.Keras.Engine; +using Tensorflow.Operations; + +namespace Tensorflow +{ + /// + /// Basic LSTM recurrent network cell. + /// The implementation is based on: http://arxiv.org/abs/1409.2329. + /// + public class BasicLSTMCell : LayerRnnCell + { + int _num_units; + float _forget_bias; + bool _state_is_tuple; + IActivation _activation; + + /// + /// Initialize the basic LSTM cell. + /// + /// The number of units in the LSTM cell. + /// + /// + /// + /// + /// + /// + public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, + IActivation activation = null, bool? reuse = null, string name = null, + TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype) + { + input_spec = new InputSpec(ndim: 2); + _num_units = num_units; + _forget_bias = forget_bias; + _state_is_tuple = state_is_tuple; + _activation = activation; + if (_activation == null) + _activation = tf.nn.tanh(); + } + + public LSTMStateTuple state_size + { + get + { + return _state_is_tuple ? + new LSTMStateTuple(_num_units, _num_units) : + (LSTMStateTuple)(2 * _num_units); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs similarity index 96% rename from src/TensorFlowNET.Core/Operations/BasicRNNCell.cs rename to src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs index 69f86349..da528982 100644 --- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs @@ -16,6 +16,7 @@ using System; using Tensorflow.Keras.Engine; +using Tensorflow.Operations; using static Tensorflow.Binding; namespace Tensorflow @@ -25,7 +26,7 @@ namespace Tensorflow int _num_units; Func _activation; - public override int state_size => _num_units; + public override LSTMStateTuple state_size => _num_units; public override int output_size => _num_units; public VariableV1 _kernel; string _WEIGHTS_VARIABLE_NAME = "kernel"; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs new file mode 100644 index 00000000..7539021b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations +{ + /// + /// Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. + /// + /// Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state + /// and `h` is the output. + /// + /// Only used when `state_is_tuple=True`. + /// + public class LSTMStateTuple + { + int c; + int h; + + public LSTMStateTuple(int c) + { + this.c = c; + } + + public LSTMStateTuple(int c, int h) + { + this.c = c; + this.h = h; + } + + public static implicit operator int(LSTMStateTuple tuple) + { + return tuple.c; + } + + public static implicit operator LSTMStateTuple(int c) + { + return new LSTMStateTuple(c); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/LayerRNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs similarity index 100% rename from src/TensorFlowNET.Core/Operations/LayerRNNCell.cs rename to src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs similarity index 98% rename from src/TensorFlowNET.Core/Operations/RNNCell.cs rename to src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 9902cd41..4d277082 100644 --- a/src/TensorFlowNET.Core/Operations/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -49,7 +49,7 @@ namespace Tensorflow /// difference between TF and Keras RNN cell. /// protected bool _is_tf_rnn_cell = false; - public virtual int state_size { get; } + public virtual LSTMStateTuple state_size { get; } public virtual int output_size { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 48af7d58..a71d035a 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -18,13 +18,106 @@ using NumSharp; using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Framework; using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow.Operations { - internal class rnn + public class rnn { + /// + /// Creates a bidirectional recurrent neural network. + /// + public static void static_bidirectional_rnn(BasicLSTMCell cell_fw, + BasicLSTMCell cell_bw, + Tensor[] inputs, + Tensor initial_state_fw = null, + Tensor initial_state_bw = null, + TF_DataType dtype = TF_DataType.DtInvalid, + Tensor sequence_length = null, + string scope = null) + { + if (inputs == null || inputs.Length == 0) + throw new ValueError("inputs must not be empty"); + + tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate + { + // Forward direction + tf_with(tf.variable_scope("fw"), fw_scope => + { + static_rnn( + cell_fw, + inputs, + initial_state_fw, + dtype, + sequence_length, + scope: fw_scope); + }); + }); + } + + public static void static_rnn(BasicLSTMCell cell, + Tensor[] inputs, + Tensor initial_state, + TF_DataType dtype = TF_DataType.DtInvalid, + Tensor sequence_length = null, + VariableScope scope = null) + { + // Create a new scope in which the caching device is either + // determined by the parent scope, or is set to place the cached + // Variable using the same placement as for the rest of the RNN. + if (scope == null) + tf_with(tf.variable_scope("rnn"), varscope => + { + throw new NotImplementedException("static_rnn"); + }); + else + tf_with(tf.variable_scope(scope), varscope => + { + Dimension fixed_batch_size = null; + Dimension batch_size = null; + Tensor batch_size_tensor = null; + + // Obtain the first sequence of the input + var first_input = inputs[0]; + if (first_input.TensorShape.rank != 1) + { + var input_shape = first_input.TensorShape.with_rank_at_least(2); + fixed_batch_size = input_shape.dims[0]; + var flat_inputs = nest.flatten2(inputs); + foreach (var flat_input in flat_inputs) + { + input_shape = flat_input.TensorShape.with_rank_at_least(2); + batch_size = tensor_shape.dimension_at_index(input_shape, 0); + var input_size = input_shape[1]; + fixed_batch_size.merge_with(batch_size); + foreach (var (i, size) in enumerate(input_size.dims)) + { + if (size < 0) + throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " + + "shape inference, but saw value None."); + } + } + } + else + fixed_batch_size = first_input.TensorShape.with_rank_at_least(1).dims[0]; + + if (tensor_shape.dimension_value(fixed_batch_size) >= 0) + batch_size = tensor_shape.dimension_value(fixed_batch_size); + else + batch_size_tensor = array_ops.shape(first_input)[0]; + + Tensor state = null; + if (initial_state != null) + state = initial_state; + else + { + cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); + } + }); + } + public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, Tensor sequence_length = null, Tensor initial_state = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Core/Operations/clip_ops.cs b/src/TensorFlowNET.Core/Operations/clip_ops.cs new file mode 100644 index 00000000..701664f4 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/clip_ops.cs @@ -0,0 +1,45 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class clip_ops + { + public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) + { + return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate + { + var values = ops.convert_to_tensor(t, name: "t"); + // Go through list of tensors, for each value in each tensor clip + var t_min = math_ops.minimum(values, clip_value_max); + // Assert that the shape is compatible with the initial shape, + // to prevent unintentional broadcasting. + _ = values.TensorShape.merge_with(t_min.shape); + var t_max = math_ops.maximum(t_min, clip_value_min, name: name); + _ = values.TensorShape.merge_with(t_max.shape); + + return t_max; + }); + } + } +} diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 39279808..bf508a78 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -1,7 +1,7 @@  - net472;netstandard2.0 + netstandard2.0 TensorFlow.NET Tensorflow 1.14.1 diff --git a/src/TensorFlowNET.Core/Tensors/Dimension.cs b/src/TensorFlowNET.Core/Tensors/Dimension.cs index 58520270..878ba5ae 100644 --- a/src/TensorFlowNET.Core/Tensors/Dimension.cs +++ b/src/TensorFlowNET.Core/Tensors/Dimension.cs @@ -22,6 +22,12 @@ namespace Tensorflow return new Dimension(_value); } + public static implicit operator Dimension(int value) + => new Dimension(value); + + public static implicit operator int(Dimension dimension) + => dimension.value; + public override string ToString() => $"Dimension({_value})"; } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 67474eb9..99fba404 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -162,9 +162,9 @@ namespace Tensorflow using (var status = new Status()) { if (value == null) - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); + c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status); else - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); + c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); status.Check(true); }