@@ -251,7 +251,7 @@ namespace Tensorflow | |||
/// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>. | |||
/// </remarks> | |||
public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") | |||
=> gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name); | |||
=> clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); | |||
public Tensor sub(Tensor a, Tensor b) | |||
=> gen_math_ops.sub(a, b); | |||
@@ -24,6 +24,16 @@ namespace Tensorflow.Framework | |||
} | |||
} | |||
public static Dimension dimension_at_index(TensorShape shape, int index) | |||
{ | |||
return shape.rank < 0 ? | |||
new Dimension(-1) : | |||
new Dimension(shape.dims[index]); | |||
} | |||
public static int dimension_value(Dimension dimension) | |||
=> dimension.value; | |||
public static TensorShape as_shape(this Shape shape) | |||
=> new TensorShape(shape.Dimensions); | |||
} | |||
@@ -0,0 +1,57 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Operations.Activation; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Operations; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Basic LSTM recurrent network cell. | |||
/// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||
/// </summary> | |||
public class BasicLSTMCell : LayerRnnCell | |||
{ | |||
int _num_units; | |||
float _forget_bias; | |||
bool _state_is_tuple; | |||
IActivation _activation; | |||
/// <summary> | |||
/// Initialize the basic LSTM cell. | |||
/// </summary> | |||
/// <param name="num_units">The number of units in the LSTM cell.</param> | |||
/// <param name="forget_bias"></param> | |||
/// <param name="state_is_tuple"></param> | |||
/// <param name="activation"></param> | |||
/// <param name="reuse"></param> | |||
/// <param name="name"></param> | |||
/// <param name="dtype"></param> | |||
public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, | |||
IActivation activation = null, bool? reuse = null, string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype) | |||
{ | |||
input_spec = new InputSpec(ndim: 2); | |||
_num_units = num_units; | |||
_forget_bias = forget_bias; | |||
_state_is_tuple = state_is_tuple; | |||
_activation = activation; | |||
if (_activation == null) | |||
_activation = tf.nn.tanh(); | |||
} | |||
public LSTMStateTuple state_size | |||
{ | |||
get | |||
{ | |||
return _state_is_tuple ? | |||
new LSTMStateTuple(_num_units, _num_units) : | |||
(LSTMStateTuple)(2 * _num_units); | |||
} | |||
} | |||
} | |||
} |
@@ -16,6 +16,7 @@ | |||
using System; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -25,7 +26,7 @@ namespace Tensorflow | |||
int _num_units; | |||
Func<Tensor, string, Tensor> _activation; | |||
public override int state_size => _num_units; | |||
public override LSTMStateTuple state_size => _num_units; | |||
public override int output_size => _num_units; | |||
public VariableV1 _kernel; | |||
string _WEIGHTS_VARIABLE_NAME = "kernel"; |
@@ -0,0 +1,41 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
{ | |||
/// <summary> | |||
/// Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. | |||
/// | |||
/// Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state | |||
/// and `h` is the output. | |||
/// | |||
/// Only used when `state_is_tuple=True`. | |||
/// </summary> | |||
public class LSTMStateTuple | |||
{ | |||
int c; | |||
int h; | |||
public LSTMStateTuple(int c) | |||
{ | |||
this.c = c; | |||
} | |||
public LSTMStateTuple(int c, int h) | |||
{ | |||
this.c = c; | |||
this.h = h; | |||
} | |||
public static implicit operator int(LSTMStateTuple tuple) | |||
{ | |||
return tuple.c; | |||
} | |||
public static implicit operator LSTMStateTuple(int c) | |||
{ | |||
return new LSTMStateTuple(c); | |||
} | |||
} | |||
} |
@@ -49,7 +49,7 @@ namespace Tensorflow | |||
/// difference between TF and Keras RNN cell. | |||
/// </summary> | |||
protected bool _is_tf_rnn_cell = false; | |||
public virtual int state_size { get; } | |||
public virtual LSTMStateTuple state_size { get; } | |||
public virtual int output_size { get; } | |||
@@ -18,13 +18,106 @@ using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Operations | |||
{ | |||
internal class rnn | |||
public class rnn | |||
{ | |||
/// <summary> | |||
/// Creates a bidirectional recurrent neural network. | |||
/// </summary> | |||
public static void static_bidirectional_rnn(BasicLSTMCell cell_fw, | |||
BasicLSTMCell cell_bw, | |||
Tensor[] inputs, | |||
Tensor initial_state_fw = null, | |||
Tensor initial_state_bw = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
Tensor sequence_length = null, | |||
string scope = null) | |||
{ | |||
if (inputs == null || inputs.Length == 0) | |||
throw new ValueError("inputs must not be empty"); | |||
tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate | |||
{ | |||
// Forward direction | |||
tf_with(tf.variable_scope("fw"), fw_scope => | |||
{ | |||
static_rnn( | |||
cell_fw, | |||
inputs, | |||
initial_state_fw, | |||
dtype, | |||
sequence_length, | |||
scope: fw_scope); | |||
}); | |||
}); | |||
} | |||
public static void static_rnn(BasicLSTMCell cell, | |||
Tensor[] inputs, | |||
Tensor initial_state, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
Tensor sequence_length = null, | |||
VariableScope scope = null) | |||
{ | |||
// Create a new scope in which the caching device is either | |||
// determined by the parent scope, or is set to place the cached | |||
// Variable using the same placement as for the rest of the RNN. | |||
if (scope == null) | |||
tf_with(tf.variable_scope("rnn"), varscope => | |||
{ | |||
throw new NotImplementedException("static_rnn"); | |||
}); | |||
else | |||
tf_with(tf.variable_scope(scope), varscope => | |||
{ | |||
Dimension fixed_batch_size = null; | |||
Dimension batch_size = null; | |||
Tensor batch_size_tensor = null; | |||
// Obtain the first sequence of the input | |||
var first_input = inputs[0]; | |||
if (first_input.TensorShape.rank != 1) | |||
{ | |||
var input_shape = first_input.TensorShape.with_rank_at_least(2); | |||
fixed_batch_size = input_shape.dims[0]; | |||
var flat_inputs = nest.flatten2(inputs); | |||
foreach (var flat_input in flat_inputs) | |||
{ | |||
input_shape = flat_input.TensorShape.with_rank_at_least(2); | |||
batch_size = tensor_shape.dimension_at_index(input_shape, 0); | |||
var input_size = input_shape[1]; | |||
fixed_batch_size.merge_with(batch_size); | |||
foreach (var (i, size) in enumerate(input_size.dims)) | |||
{ | |||
if (size < 0) | |||
throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " + | |||
"shape inference, but saw value None."); | |||
} | |||
} | |||
} | |||
else | |||
fixed_batch_size = first_input.TensorShape.with_rank_at_least(1).dims[0]; | |||
if (tensor_shape.dimension_value(fixed_batch_size) >= 0) | |||
batch_size = tensor_shape.dimension_value(fixed_batch_size); | |||
else | |||
batch_size_tensor = array_ops.shape(first_input)[0]; | |||
Tensor state = null; | |||
if (initial_state != null) | |||
state = initial_state; | |||
else | |||
{ | |||
cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); | |||
} | |||
}); | |||
} | |||
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, | |||
Tensor sequence_length = null, Tensor initial_state = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
@@ -0,0 +1,45 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class clip_ops | |||
{ | |||
public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate | |||
{ | |||
var values = ops.convert_to_tensor(t, name: "t"); | |||
// Go through list of tensors, for each value in each tensor clip | |||
var t_min = math_ops.minimum(values, clip_value_max); | |||
// Assert that the shape is compatible with the initial shape, | |||
// to prevent unintentional broadcasting. | |||
_ = values.TensorShape.merge_with(t_min.shape); | |||
var t_max = math_ops.maximum(t_min, clip_value_min, name: name); | |||
_ = values.TensorShape.merge_with(t_max.shape); | |||
return t_max; | |||
}); | |||
} | |||
} | |||
} |
@@ -1,7 +1,7 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFrameworks>net472;netstandard2.0</TargetFrameworks> | |||
<TargetFramework>netstandard2.0</TargetFramework> | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>1.14.1</TargetTensorFlow> | |||
@@ -22,6 +22,12 @@ namespace Tensorflow | |||
return new Dimension(_value); | |||
} | |||
public static implicit operator Dimension(int value) | |||
=> new Dimension(value); | |||
public static implicit operator int(Dimension dimension) | |||
=> dimension.value; | |||
public override string ToString() => $"Dimension({_value})"; | |||
} | |||
} |
@@ -162,9 +162,9 @@ namespace Tensorflow | |||
using (var status = new Status()) | |||
{ | |||
if (value == null) | |||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status); | |||
else | |||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||
status.Check(true); | |||
} | |||