@@ -251,7 +251,7 @@ namespace Tensorflow | |||||
/// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>. | /// greater than <c>clip_value_max</c> are set to <c>clip_value_max</c>. | ||||
/// </remarks> | /// </remarks> | ||||
public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") | public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") | ||||
=> gen_ops.clip_by_value(t, clip_value_min, clip_value_max, name); | |||||
=> clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); | |||||
public Tensor sub(Tensor a, Tensor b) | public Tensor sub(Tensor a, Tensor b) | ||||
=> gen_math_ops.sub(a, b); | => gen_math_ops.sub(a, b); | ||||
@@ -24,6 +24,16 @@ namespace Tensorflow.Framework | |||||
} | } | ||||
} | } | ||||
public static Dimension dimension_at_index(TensorShape shape, int index) | |||||
{ | |||||
return shape.rank < 0 ? | |||||
new Dimension(-1) : | |||||
new Dimension(shape.dims[index]); | |||||
} | |||||
public static int dimension_value(Dimension dimension) | |||||
=> dimension.value; | |||||
public static TensorShape as_shape(this Shape shape) | public static TensorShape as_shape(this Shape shape) | ||||
=> new TensorShape(shape.Dimensions); | => new TensorShape(shape.Dimensions); | ||||
} | } | ||||
@@ -0,0 +1,57 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow.Operations.Activation; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | |||||
{ | |||||
/// <summary> | |||||
/// Basic LSTM recurrent network cell. | |||||
/// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||||
/// </summary> | |||||
public class BasicLSTMCell : LayerRnnCell | |||||
{ | |||||
int _num_units; | |||||
float _forget_bias; | |||||
bool _state_is_tuple; | |||||
IActivation _activation; | |||||
/// <summary> | |||||
/// Initialize the basic LSTM cell. | |||||
/// </summary> | |||||
/// <param name="num_units">The number of units in the LSTM cell.</param> | |||||
/// <param name="forget_bias"></param> | |||||
/// <param name="state_is_tuple"></param> | |||||
/// <param name="activation"></param> | |||||
/// <param name="reuse"></param> | |||||
/// <param name="name"></param> | |||||
/// <param name="dtype"></param> | |||||
public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, | |||||
IActivation activation = null, bool? reuse = null, string name = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype) | |||||
{ | |||||
input_spec = new InputSpec(ndim: 2); | |||||
_num_units = num_units; | |||||
_forget_bias = forget_bias; | |||||
_state_is_tuple = state_is_tuple; | |||||
_activation = activation; | |||||
if (_activation == null) | |||||
_activation = tf.nn.tanh(); | |||||
} | |||||
public LSTMStateTuple state_size | |||||
{ | |||||
get | |||||
{ | |||||
return _state_is_tuple ? | |||||
new LSTMStateTuple(_num_units, _num_units) : | |||||
(LSTMStateTuple)(2 * _num_units); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -16,6 +16,7 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Operations; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -25,7 +26,7 @@ namespace Tensorflow | |||||
int _num_units; | int _num_units; | ||||
Func<Tensor, string, Tensor> _activation; | Func<Tensor, string, Tensor> _activation; | ||||
public override int state_size => _num_units; | |||||
public override LSTMStateTuple state_size => _num_units; | |||||
public override int output_size => _num_units; | public override int output_size => _num_units; | ||||
public VariableV1 _kernel; | public VariableV1 _kernel; | ||||
string _WEIGHTS_VARIABLE_NAME = "kernel"; | string _WEIGHTS_VARIABLE_NAME = "kernel"; |
@@ -0,0 +1,41 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
/// <summary> | |||||
/// Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state. | |||||
/// | |||||
/// Stores two elements: `(c, h)`, in that order. Where `c` is the hidden state | |||||
/// and `h` is the output. | |||||
/// | |||||
/// Only used when `state_is_tuple=True`. | |||||
/// </summary> | |||||
public class LSTMStateTuple | |||||
{ | |||||
int c; | |||||
int h; | |||||
public LSTMStateTuple(int c) | |||||
{ | |||||
this.c = c; | |||||
} | |||||
public LSTMStateTuple(int c, int h) | |||||
{ | |||||
this.c = c; | |||||
this.h = h; | |||||
} | |||||
public static implicit operator int(LSTMStateTuple tuple) | |||||
{ | |||||
return tuple.c; | |||||
} | |||||
public static implicit operator LSTMStateTuple(int c) | |||||
{ | |||||
return new LSTMStateTuple(c); | |||||
} | |||||
} | |||||
} |
@@ -49,7 +49,7 @@ namespace Tensorflow | |||||
/// difference between TF and Keras RNN cell. | /// difference between TF and Keras RNN cell. | ||||
/// </summary> | /// </summary> | ||||
protected bool _is_tf_rnn_cell = false; | protected bool _is_tf_rnn_cell = false; | ||||
public virtual int state_size { get; } | |||||
public virtual LSTMStateTuple state_size { get; } | |||||
public virtual int output_size { get; } | public virtual int output_size { get; } | ||||
@@ -18,13 +18,106 @@ using NumSharp; | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Framework; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Operations | namespace Tensorflow.Operations | ||||
{ | { | ||||
internal class rnn | |||||
public class rnn | |||||
{ | { | ||||
/// <summary> | |||||
/// Creates a bidirectional recurrent neural network. | |||||
/// </summary> | |||||
public static void static_bidirectional_rnn(BasicLSTMCell cell_fw, | |||||
BasicLSTMCell cell_bw, | |||||
Tensor[] inputs, | |||||
Tensor initial_state_fw = null, | |||||
Tensor initial_state_bw = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
Tensor sequence_length = null, | |||||
string scope = null) | |||||
{ | |||||
if (inputs == null || inputs.Length == 0) | |||||
throw new ValueError("inputs must not be empty"); | |||||
tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate | |||||
{ | |||||
// Forward direction | |||||
tf_with(tf.variable_scope("fw"), fw_scope => | |||||
{ | |||||
static_rnn( | |||||
cell_fw, | |||||
inputs, | |||||
initial_state_fw, | |||||
dtype, | |||||
sequence_length, | |||||
scope: fw_scope); | |||||
}); | |||||
}); | |||||
} | |||||
public static void static_rnn(BasicLSTMCell cell, | |||||
Tensor[] inputs, | |||||
Tensor initial_state, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
Tensor sequence_length = null, | |||||
VariableScope scope = null) | |||||
{ | |||||
// Create a new scope in which the caching device is either | |||||
// determined by the parent scope, or is set to place the cached | |||||
// Variable using the same placement as for the rest of the RNN. | |||||
if (scope == null) | |||||
tf_with(tf.variable_scope("rnn"), varscope => | |||||
{ | |||||
throw new NotImplementedException("static_rnn"); | |||||
}); | |||||
else | |||||
tf_with(tf.variable_scope(scope), varscope => | |||||
{ | |||||
Dimension fixed_batch_size = null; | |||||
Dimension batch_size = null; | |||||
Tensor batch_size_tensor = null; | |||||
// Obtain the first sequence of the input | |||||
var first_input = inputs[0]; | |||||
if (first_input.TensorShape.rank != 1) | |||||
{ | |||||
var input_shape = first_input.TensorShape.with_rank_at_least(2); | |||||
fixed_batch_size = input_shape.dims[0]; | |||||
var flat_inputs = nest.flatten2(inputs); | |||||
foreach (var flat_input in flat_inputs) | |||||
{ | |||||
input_shape = flat_input.TensorShape.with_rank_at_least(2); | |||||
batch_size = tensor_shape.dimension_at_index(input_shape, 0); | |||||
var input_size = input_shape[1]; | |||||
fixed_batch_size.merge_with(batch_size); | |||||
foreach (var (i, size) in enumerate(input_size.dims)) | |||||
{ | |||||
if (size < 0) | |||||
throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " + | |||||
"shape inference, but saw value None."); | |||||
} | |||||
} | |||||
} | |||||
else | |||||
fixed_batch_size = first_input.TensorShape.with_rank_at_least(1).dims[0]; | |||||
if (tensor_shape.dimension_value(fixed_batch_size) >= 0) | |||||
batch_size = tensor_shape.dimension_value(fixed_batch_size); | |||||
else | |||||
batch_size_tensor = array_ops.shape(first_input)[0]; | |||||
Tensor state = null; | |||||
if (initial_state != null) | |||||
state = initial_state; | |||||
else | |||||
{ | |||||
cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); | |||||
} | |||||
}); | |||||
} | |||||
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, | public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, | ||||
Tensor sequence_length = null, Tensor initial_state = null, | Tensor sequence_length = null, Tensor initial_state = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
@@ -0,0 +1,45 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
public class clip_ops | |||||
{ | |||||
public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate | |||||
{ | |||||
var values = ops.convert_to_tensor(t, name: "t"); | |||||
// Go through list of tensors, for each value in each tensor clip | |||||
var t_min = math_ops.minimum(values, clip_value_max); | |||||
// Assert that the shape is compatible with the initial shape, | |||||
// to prevent unintentional broadcasting. | |||||
_ = values.TensorShape.merge_with(t_min.shape); | |||||
var t_max = math_ops.maximum(t_min, clip_value_min, name: name); | |||||
_ = values.TensorShape.merge_with(t_max.shape); | |||||
return t_max; | |||||
}); | |||||
} | |||||
} | |||||
} |
@@ -1,7 +1,7 @@ | |||||
<Project Sdk="Microsoft.NET.Sdk"> | <Project Sdk="Microsoft.NET.Sdk"> | ||||
<PropertyGroup> | <PropertyGroup> | ||||
<TargetFrameworks>net472;netstandard2.0</TargetFrameworks> | |||||
<TargetFramework>netstandard2.0</TargetFramework> | |||||
<AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
<RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
<TargetTensorFlow>1.14.1</TargetTensorFlow> | <TargetTensorFlow>1.14.1</TargetTensorFlow> | ||||
@@ -22,6 +22,12 @@ namespace Tensorflow | |||||
return new Dimension(_value); | return new Dimension(_value); | ||||
} | } | ||||
public static implicit operator Dimension(int value) | |||||
=> new Dimension(value); | |||||
public static implicit operator int(Dimension dimension) | |||||
=> dimension.value; | |||||
public override string ToString() => $"Dimension({_value})"; | public override string ToString() => $"Dimension({_value})"; | ||||
} | } | ||||
} | } |
@@ -162,9 +162,9 @@ namespace Tensorflow | |||||
using (var status = new Status()) | using (var status = new Status()) | ||||
{ | { | ||||
if (value == null) | if (value == null) | ||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), null, -1, status); | |||||
else | else | ||||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
c_api.TF_GraphSetTensorShape(graph, _as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||||
status.Check(true); | status.Check(true); | ||||
} | } | ||||