
BasicLSTMCell

tags/v0.13
Oceania2018 5 years ago
parent · commit a70077bbb4
12 changed files with 260 additions and 7 deletions
 1. +1  -1  src/TensorFlowNET.Core/APIs/tf.math.cs
 2. +10 -0  src/TensorFlowNET.Core/Framework/tensor_shape.cs
 3. +57 -0  src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
 4. +2  -1  src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
 5. +41 -0  src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs
 6. +0  -0  src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs
 7. +1  -1  src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
 8. +94 -1  src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
 9. +45 -0  src/TensorFlowNET.Core/Operations/clip_ops.cs
10. +1  -1  src/TensorFlowNET.Core/TensorFlow.Binding.csproj
11. +6  -0  src/TensorFlowNET.Core/Tensors/Dimension.cs
12. +2  -2  src/TensorFlowNET.Core/Tensors/Tensor.cs

+1 -1  src/TensorFlowNET.Core/APIs/tf.math.cs

@@ -251,7 +251,7 @@ namespace Tensorflow
/// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>.
/// </remarks>
public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue")
- => gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name);
+ => clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name);

public Tensor sub(Tensor a, Tensor b)
=> gen_math_ops.sub(a, b);


+10 -0  src/TensorFlowNET.Core/Framework/tensor_shape.cs

@@ -24,6 +24,16 @@ namespace Tensorflow.Framework
}
}

public static Dimension dimension_at_index(TensorShape shape, int index)
{
return shape.rank < 0 ?
new Dimension(-1) :
new Dimension(shape.dims[index]);
}

public static int dimension_value(Dimension dimension)
=> dimension.value;

public static TensorShape as_shape(this Shape shape)
=> new TensorShape(shape.Dimensions);
}
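
A quick sketch of how the new helpers are meant to be used (not part of the commit; it assumes TensorShape exposes a params-int constructor, and the sizes are illustrative):

// Hypothetical usage of dimension_at_index / dimension_value.
var shape = new TensorShape(32, 128);                      // e.g. [batch, features]
Dimension d0 = tensor_shape.dimension_at_index(shape, 0);
int batch = tensor_shape.dimension_value(d0);              // 32
// For a shape with rank < 0 (unknown rank), dimension_at_index returns Dimension(-1).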


+57 -0  src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs

@@ -0,0 +1,57 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;
using Tensorflow.Operations.Activation;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;

namespace Tensorflow
{
/// <summary>
/// Basic LSTM recurrent network cell.
/// The implementation is based on: http://arxiv.org/abs/1409.2329.
/// </summary>
public class BasicLSTMCell : LayerRnnCell
{
int _num_units;
float _forget_bias;
bool _state_is_tuple;
IActivation _activation;

/// <summary>
/// Initialize the basic LSTM cell.
/// </summary>
/// <param name="num_units">The number of units in the LSTM cell.</param>
/// <param name="forget_bias"></param>
/// <param name="state_is_tuple"></param>
/// <param name="activation"></param>
/// <param name="reuse"></param>
/// <param name="name"></param>
/// <param name="dtype"></param>
public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true,
IActivation activation = null, bool? reuse = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype)
{
input_spec = new InputSpec(ndim: 2);
_num_units = num_units;
_forget_bias = forget_bias;
_state_is_tuple = state_is_tuple;
_activation = activation;
if (_activation == null)
_activation = tf.nn.tanh();
}

public LSTMStateTuple state_size
{
get
{
return _state_is_tuple ?
new LSTMStateTuple(_num_units, _num_units) :
(LSTMStateTuple)(2 * _num_units);
}
}
}
}
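
A usage sketch for the new cell (not part of the commit; 128 is an arbitrary size):

// Hypothetical construction; the remaining arguments keep the defaults above.
var cell = new BasicLSTMCell(num_units: 128, forget_bias: 1.0f, state_is_tuple: true);
var size = cell.state_size;   // LSTMStateTuple(c: 128, h: 128)
// With state_is_tuple: false the same property yields (LSTMStateTuple)(2 * 128),
// i.e. a tuple built from the single int 256 via the implicit conversion
// added in LSTMStateTuple.cs below.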

src/TensorFlowNET.Core/Operations/BasicRNNCell.cs → src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs

@@ -16,6 +16,7 @@

using System;
using Tensorflow.Keras.Engine;
using Tensorflow.Operations;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -25,7 +26,7 @@ namespace Tensorflow
int _num_units;
Func<Tensor, string, Tensor> _activation;

- public override int state_size => _num_units;
+ public override LSTMStateTuple state_size => _num_units;
public override int output_size => _num_units;
public VariableV1 _kernel;
string _WEIGHTS_VARIABLE_NAME = "kernel";

+41 -0  src/TensorFlowNET.Core/Operations/NnOps/LSTMStateTuple.cs

@@ -0,0 +1,41 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations
{
/// <summary>
/// Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.
///
/// Stores two elements, `(c, h)`, in that order, where `c` is the hidden state
/// and `h` is the output.
///
/// Only used when `state_is_tuple=True`.
/// </summary>
public class LSTMStateTuple
{
int c;
int h;

public LSTMStateTuple(int c)
{
this.c = c;
}

public LSTMStateTuple(int c, int h)
{
this.c = c;
this.h = h;
}

public static implicit operator int(LSTMStateTuple tuple)
{
return tuple.c;
}

public static implicit operator LSTMStateTuple(int c)
{
return new LSTMStateTuple(c);
}
}
}
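
The two implicit operators are what let an LSTMStateTuple stand in for a plain int and vice versa; a minimal sketch (values are illustrative):

LSTMStateTuple state_size = 128;           // int -> LSTMStateTuple, sets c = 128
int c = state_size;                        // LSTMStateTuple -> int, reads c back
var pair = new LSTMStateTuple(128, 128);   // explicit (c, h) pair, as BasicLSTMCell builds it

This is also what keeps the int-returning overrides below (BasicRNNCell's state_size, for example) compiling after the base property's type is widened to LSTMStateTuple.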

src/TensorFlowNET.Core/Operations/LayerRNNCell.cs → src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs


src/TensorFlowNET.Core/Operations/RNNCell.cs → src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

@@ -49,7 +49,7 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell.
/// </summary>
protected bool _is_tf_rnn_cell = false;
- public virtual int state_size { get; }
+ public virtual LSTMStateTuple state_size { get; }

public virtual int output_size { get; }


+94 -1  src/TensorFlowNET.Core/Operations/NnOps/rnn.cs

@@ -18,13 +18,106 @@ using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
- internal class rnn
+ public class rnn
{
/// <summary>
/// Creates a bidirectional recurrent neural network.
/// </summary>
public static void static_bidirectional_rnn(BasicLSTMCell cell_fw,
BasicLSTMCell cell_bw,
Tensor[] inputs,
Tensor initial_state_fw = null,
Tensor initial_state_bw = null,
TF_DataType dtype = TF_DataType.DtInvalid,
Tensor sequence_length = null,
string scope = null)
{
if (inputs == null || inputs.Length == 0)
throw new ValueError("inputs must not be empty");

tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate
{
// Forward direction
tf_with(tf.variable_scope("fw"), fw_scope =>
{
static_rnn(
cell_fw,
inputs,
initial_state_fw,
dtype,
sequence_length,
scope: fw_scope);
});
});
}

public static void static_rnn(BasicLSTMCell cell,
Tensor[] inputs,
Tensor initial_state,
TF_DataType dtype = TF_DataType.DtInvalid,
Tensor sequence_length = null,
VariableScope scope = null)
{
// Create a new scope in which the caching device is either
// determined by the parent scope, or is set to place the cached
// Variable using the same placement as for the rest of the RNN.
if (scope == null)
tf_with(tf.variable_scope("rnn"), varscope =>
{
throw new NotImplementedException("static_rnn");
});
else
tf_with(tf.variable_scope(scope), varscope =>
{
Dimension fixed_batch_size = null;
Dimension batch_size = null;
Tensor batch_size_tensor = null;

// Obtain the first sequence of the input
var first_input = inputs[0];
if (first_input.TensorShape.rank != 1)
{
var input_shape = first_input.TensorShape.with_rank_at_least(2);
fixed_batch_size = input_shape.dims[0];
var flat_inputs = nest.flatten2(inputs);
foreach (var flat_input in flat_inputs)
{
input_shape = flat_input.TensorShape.with_rank_at_least(2);
batch_size = tensor_shape.dimension_at_index(input_shape, 0);
var input_size = input_shape[1];
fixed_batch_size.merge_with(batch_size);
foreach (var (i, size) in enumerate(input_size.dims))
{
if (size < 0)
throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " +
"shape inference, but saw value None.");
}
}
}
else
fixed_batch_size = first_input.TensorShape.with_rank_at_least(1).dims[0];

if (tensor_shape.dimension_value(fixed_batch_size) >= 0)
batch_size = tensor_shape.dimension_value(fixed_batch_size);
else
batch_size_tensor = array_ops.shape(first_input)[0];

Tensor state = null;
if (initial_state != null)
state = initial_state;
else
{
cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
}
});
}

public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor,
Tensor sequence_length = null, Tensor initial_state = null,
TF_DataType dtype = TF_DataType.DtInvalid,


+45 -0  src/TensorFlowNET.Core/Operations/clip_ops.cs

@@ -0,0 +1,45 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class clip_ops
{
public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null)
{
return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate
{
var values = ops.convert_to_tensor(t, name: "t");
// Go through the list of tensors; clip each value in each tensor to the given range.
var t_min = math_ops.minimum(values, clip_value_max);
// Assert that the shape is compatible with the initial shape,
// to prevent unintentional broadcasting.
_ = values.TensorShape.merge_with(t_min.shape);
var t_max = math_ops.maximum(t_min, clip_value_min, name: name);
_ = values.TensorShape.merge_with(t_max.shape);

return t_max;
});
}
}
}
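
A minimal usage sketch (not part of the commit; it assumes tf.constant accepts a float[] and scalar floats). The public tf.clip_by_value wrapper above now routes here:

// Clamp every element of t into [0, 1].
var t = tf.constant(new float[] { -2.0f, 0.5f, 3.0f });
var clipped = tf.clip_by_value(t, tf.constant(0.0f), tf.constant(1.0f));
// clipped evaluates to [0, 0.5, 1] once run in a session.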

+1 -1  src/TensorFlowNET.Core/TensorFlow.Binding.csproj

@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
- <TargetFrameworks>net472;netstandard2.0</TargetFrameworks>
+ <TargetFramework>netstandard2.0</TargetFramework>
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>1.14.1</TargetTensorFlow>


+6 -0  src/TensorFlowNET.Core/Tensors/Dimension.cs

@@ -22,6 +22,12 @@ namespace Tensorflow
return new Dimension(_value);
}

public static implicit operator Dimension(int value)
=> new Dimension(value);

public static implicit operator int(Dimension dimension)
=> dimension.value;

public override string ToString() => $"Dimension({_value})";
}
}
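
A minimal sketch of the new conversions (values are illustrative):

Dimension d = 32;   // int -> Dimension
int n = d;          // Dimension -> int, reads the underlying value
// This is what lets rnn.cs assign shape.dims[0] (an int) straight to a Dimension
// and treat tensor_shape.dimension_value(...) results as plain ints.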

+2 -2  src/TensorFlowNET.Core/Tensors/Tensor.cs

@@ -162,9 +162,9 @@ namespace Tensorflow
using (var status = new Status())
{
if (value == null)
- c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
+ c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status);
else
- c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
+ c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);

status.Check(true);
}
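
For context, a hedged usage sketch (not part of the commit; it assumes this hunk sits inside the setter of Tensor.shape, whose name is not visible in the diff, and that the setter accepts an int[] or null):

// Hypothetical: pin the static shape of a placeholder, then reset it to unknown rank.
var x = tf.placeholder(tf.float32);
x.shape = new int[] { 28, 28, 1 };   // forwards to TF_GraphSetTensorShape
x.shape = null;                      // the null branch above records rank -1 (unknown)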

