@@ -65,7 +65,9 @@ namespace Tensorflow.Layers | |||
variable_scope scope_context_manager = null; | |||
if (built) | |||
{ | |||
scope_context_manager = tf.variable_scope(_scope, | |||
reuse: true, | |||
auxiliary_name_scope: false); | |||
} | |||
else | |||
{ | |||
@@ -181,7 +183,7 @@ namespace Tensorflow.Layers | |||
return _current_scope.original_name_scope; | |||
} | |||
private void _set_scope(VariableScope scope = null) | |||
protected void _set_scope(VariableScope scope = null) | |||
{ | |||
if (_scope == null) | |||
{ | |||
@@ -14,12 +14,17 @@ namespace Tensorflow | |||
/// Basic LSTM recurrent network cell. | |||
/// The implementation is based on: http://arxiv.org/abs/1409.2329. | |||
/// </summary> | |||
public class BasicLSTMCell : LayerRnnCell | |||
public class BasicLstmCell : LayerRnnCell | |||
{ | |||
int _num_units; | |||
float _forget_bias; | |||
bool _state_is_tuple; | |||
IActivation _activation; | |||
LSTMStateTuple _state; | |||
VariableV1 _kernel; | |||
VariableV1 _bias; | |||
string _WEIGHTS_VARIABLE_NAME = "kernel"; | |||
string _BIAS_VARIABLE_NAME = "bias"; | |||
/// <summary> | |||
/// Initialize the basic LSTM cell. | |||
@@ -31,7 +36,7 @@ namespace Tensorflow | |||
/// <param name="reuse"></param> | |||
/// <param name="name"></param> | |||
/// <param name="dtype"></param> | |||
public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, | |||
public BasicLstmCell(int num_units, float forget_bias = 1.0f, bool state_is_tuple = true, | |||
IActivation activation = null, bool? reuse = null, string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid) : base(_reuse: reuse, name: name, dtype: dtype) | |||
{ | |||
@@ -44,13 +49,123 @@ namespace Tensorflow | |||
_activation = tf.nn.tanh(); | |||
} | |||
public LSTMStateTuple state_size | |||
protected override void build(TensorShape input_shape) | |||
{ | |||
var input_depth = input_shape.dims.Last(); | |||
var h_depth = _num_units; | |||
_kernel = add_weight(_WEIGHTS_VARIABLE_NAME, | |||
shape: new[] { input_depth + h_depth, 4 * _num_units }); | |||
_bias = add_weight(_BIAS_VARIABLE_NAME, | |||
shape: new[] { 4 * _num_units }, | |||
initializer: tf.zeros_initializer); | |||
built = true; | |||
} | |||
public Tensor[] __call__(Tensor inputs, LSTMStateTuple state) | |||
{ | |||
_state = state; | |||
return base.__call__(inputs); | |||
} | |||
/// <summary> | |||
/// Long short-term memory cell (LSTM). | |||
/// </summary> | |||
/// <param name="inputs"></param> | |||
/// <param name="training"></param> | |||
/// <param name="state"></param> | |||
/// <returns></returns> | |||
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null) | |||
{ | |||
var one = constant_op.constant(1, dtype: dtypes.int32); | |||
// Parameters of gates are concatenated into one multiply for efficiency. | |||
Tensor c = null; | |||
Tensor h = null; | |||
if(_state_is_tuple) | |||
(c, h) = ((Tensor)_state.c, (Tensor)_state.h); | |||
else | |||
{ | |||
// array_ops.split(value: state, num_or_size_splits: 2, axis: one); | |||
throw new NotImplementedException("BasicLstmCell call"); | |||
} | |||
var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable); | |||
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); | |||
// i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | |||
var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | |||
// Note that using `add` and `multiply` instead of `+` and `*` gives a | |||
// performance improvement. So using those at the cost of readability. | |||
var new_c = gen_math_ops.add( | |||
math_ops.multiply(c, math_ops.sigmoid(gen_math_ops.add(f, forget_bias_tensor))), | |||
math_ops.multiply(math_ops.sigmoid(i), _activation.Activate(j))); | |||
var new_h = math_ops.multiply(_activation.Activate(new_c), math_ops.sigmoid(o)); | |||
if (_state_is_tuple) | |||
return new[] { new_c, new_h }; | |||
else | |||
return new[] { array_ops.concat(new[] { new_c, new_h }, 1) }; | |||
} | |||
public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
if (inputs != null) | |||
throw new NotImplementedException("get_initial_state input is not null"); | |||
return zero_state(batch_size, dtype); | |||
} | |||
/// <summary> | |||
/// Return zero-filled state tensor(s). | |||
/// </summary> | |||
/// <param name="batch_size"></param> | |||
/// <param name="dtype"></param> | |||
/// <returns></returns> | |||
private LSTMStateTuple zero_state(Tensor batch_size, TF_DataType dtype) | |||
{ | |||
LSTMStateTuple output = null; | |||
tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate | |||
{ | |||
output = _zero_state_tensors(state_size, batch_size, dtype); | |||
}); | |||
return output; | |||
} | |||
private LSTMStateTuple _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype) | |||
{ | |||
if (state_size is LSTMStateTuple state_size_tuple) | |||
{ | |||
var outputs = state_size_tuple.Flatten() | |||
.Select(x => (int)x) | |||
.Select(s => | |||
{ | |||
var c = rnn_cell_impl._concat(batch_size, s); | |||
var size = array_ops.zeros(c, dtype: dtype); | |||
var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); | |||
size.set_shape(c_static); | |||
return size; | |||
}).ToArray(); | |||
return new LSTMStateTuple(outputs[0], outputs[1]); | |||
} | |||
throw new NotImplementedException("_zero_state_tensors"); | |||
} | |||
public override object state_size | |||
{ | |||
get | |||
{ | |||
return _state_is_tuple ? | |||
new LSTMStateTuple(_num_units, _num_units) : | |||
(LSTMStateTuple)(2 * _num_units); | |||
if (_state_is_tuple) | |||
return new LSTMStateTuple(_num_units, _num_units); | |||
else | |||
return 2 * _num_units; | |||
} | |||
} | |||
} | |||
@@ -26,7 +26,7 @@ namespace Tensorflow | |||
int _num_units; | |||
Func<Tensor, string, Tensor> _activation; | |||
public override LSTMStateTuple state_size => _num_units; | |||
public override object state_size => _num_units; | |||
public override int output_size => _num_units; | |||
public VariableV1 _kernel; | |||
string _WEIGHTS_VARIABLE_NAME = "kernel"; | |||
@@ -12,15 +12,10 @@ namespace Tensorflow.Operations | |||
/// | |||
/// Only used when `state_is_tuple=True`. | |||
/// </summary> | |||
public class LSTMStateTuple | |||
public class LSTMStateTuple : ICanBeFlattened | |||
{ | |||
int c; | |||
int h; | |||
public LSTMStateTuple(int c) | |||
{ | |||
this.c = c; | |||
} | |||
public object c; | |||
public object h; | |||
public LSTMStateTuple(int c, int h) | |||
{ | |||
@@ -28,14 +23,13 @@ namespace Tensorflow.Operations | |||
this.h = h; | |||
} | |||
public static implicit operator int(LSTMStateTuple tuple) | |||
public LSTMStateTuple(Tensor c, Tensor h) | |||
{ | |||
return tuple.c; | |||
this.c = c; | |||
this.h = h; | |||
} | |||
public static implicit operator LSTMStateTuple(int c) | |||
{ | |||
return new LSTMStateTuple(c); | |||
} | |||
public object[] Flatten() | |||
=> new[] { c, h }; | |||
} | |||
} |
@@ -49,7 +49,7 @@ namespace Tensorflow | |||
/// difference between TF and Keras RNN cell. | |||
/// </summary> | |||
protected bool _is_tf_rnn_cell = false; | |||
public virtual LSTMStateTuple state_size { get; } | |||
public virtual object state_size { get; } | |||
public virtual int output_size { get; } | |||
@@ -64,7 +64,7 @@ namespace Tensorflow | |||
_is_tf_rnn_cell = true; | |||
} | |||
public virtual Tensor get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
public virtual object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
if (inputs != null) | |||
throw new NotImplementedException("get_initial_state input is not null"); | |||
@@ -78,11 +78,10 @@ namespace Tensorflow | |||
/// <param name="batch_size"></param> | |||
/// <param name="dtype"></param> | |||
/// <returns></returns> | |||
public Tensor zero_state(Tensor batch_size, TF_DataType dtype) | |||
private Tensor zero_state(Tensor batch_size, TF_DataType dtype) | |||
{ | |||
Tensor output = null; | |||
var state_size = this.state_size; | |||
tf_with(ops.name_scope($"{this.GetType().Name}ZeroState", values: new { batch_size }), delegate | |||
tf_with(ops.name_scope($"{GetType().Name}ZeroState", values: new { batch_size }), delegate | |||
{ | |||
output = _zero_state_tensors(state_size, batch_size, dtype); | |||
}); | |||
@@ -90,20 +89,25 @@ namespace Tensorflow | |||
return output; | |||
} | |||
private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype) | |||
private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype) | |||
{ | |||
var output = nest.map_structure(s => | |||
if(state_size is int state_size_int) | |||
{ | |||
var c = rnn_cell_impl._concat(batch_size, s); | |||
var size = array_ops.zeros(c, dtype: dtype); | |||
var output = nest.map_structure(s => | |||
{ | |||
var c = rnn_cell_impl._concat(batch_size, s); | |||
var size = array_ops.zeros(c, dtype: dtype); | |||
var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); | |||
size.set_shape(c_static); | |||
var c_static = rnn_cell_impl._concat(batch_size, s, @static: true); | |||
size.set_shape(c_static); | |||
return size; | |||
}, state_size); | |||
return size; | |||
}, state_size_int); | |||
return output; | |||
return output; | |||
} | |||
throw new NotImplementedException("_zero_state_tensors"); | |||
} | |||
} | |||
} |
@@ -29,8 +29,8 @@ namespace Tensorflow.Operations | |||
/// <summary> | |||
/// Creates a bidirectional recurrent neural network. | |||
/// </summary> | |||
public static void static_bidirectional_rnn(BasicLSTMCell cell_fw, | |||
BasicLSTMCell cell_bw, | |||
public static (Tensor[], LSTMStateTuple, LSTMStateTuple) static_bidirectional_rnn(BasicLstmCell cell_fw, | |||
BasicLstmCell cell_bw, | |||
Tensor[] inputs, | |||
Tensor initial_state_fw = null, | |||
Tensor initial_state_bw = null, | |||
@@ -41,12 +41,17 @@ namespace Tensorflow.Operations | |||
if (inputs == null || inputs.Length == 0) | |||
throw new ValueError("inputs must not be empty"); | |||
Tensor[] output_fw = null; | |||
Tensor[] output_bw = null; | |||
LSTMStateTuple output_state_fw = null; | |||
LSTMStateTuple output_state_bw = null; | |||
tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate | |||
{ | |||
// Forward direction | |||
tf_with(tf.variable_scope("fw"), fw_scope => | |||
{ | |||
static_rnn( | |||
(output_fw, output_state_fw) = static_rnn( | |||
cell_fw, | |||
inputs, | |||
initial_state_fw, | |||
@@ -54,16 +59,48 @@ namespace Tensorflow.Operations | |||
sequence_length, | |||
scope: fw_scope); | |||
}); | |||
// backward direction | |||
tf_with(tf.variable_scope("bw"), bw_scope => | |||
{ | |||
var reversed_inputs = _reverse_seq(inputs, sequence_length); | |||
(output_bw, output_state_bw) = static_rnn( | |||
cell_bw, | |||
reversed_inputs, | |||
initial_state_bw, | |||
dtype, | |||
sequence_length, | |||
scope: bw_scope); | |||
}); | |||
}); | |||
output_bw = _reverse_seq(output_bw, sequence_length); | |||
var flat_outputs = zip(output_fw, output_bw) | |||
.Select(x => array_ops.concat(new[] { x.Item1, x.Item2 }, 1)) | |||
.ToArray(); | |||
return (flat_outputs, output_state_fw, output_state_bw); | |||
} | |||
public static void static_rnn(BasicLSTMCell cell, | |||
private static Tensor[] _reverse_seq(Tensor[] input_seq, Tensor lengths) | |||
{ | |||
if (lengths == null) | |||
return input_seq.Reverse().ToArray(); | |||
throw new NotImplementedException("_reverse_seq"); | |||
} | |||
public static (Tensor[], LSTMStateTuple) static_rnn(BasicLstmCell cell, | |||
Tensor[] inputs, | |||
Tensor initial_state, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
Tensor sequence_length = null, | |||
VariableScope scope = null) | |||
{ | |||
List<Tensor> outputs = new List<Tensor>(); | |||
object state = null; | |||
// Create a new scope in which the caching device is either | |||
// determined by the parent scope, or is set to place the cached | |||
// Variable using the same placement as for the rest of the RNN. | |||
@@ -73,12 +110,12 @@ namespace Tensorflow.Operations | |||
throw new NotImplementedException("static_rnn"); | |||
}); | |||
else | |||
tf_with(tf.variable_scope(scope), varscope => | |||
tf_with(tf.variable_scope(scope), scope1 => | |||
{ | |||
Dimension fixed_batch_size = null; | |||
Dimension batch_size = null; | |||
Tensor batch_size_tensor = null; | |||
VariableScope varscope = scope1; | |||
// Obtain the first sequence of the input | |||
var first_input = inputs[0]; | |||
if (first_input.TensorShape.rank != 1) | |||
@@ -108,14 +145,31 @@ namespace Tensorflow.Operations | |||
else | |||
batch_size_tensor = array_ops.shape(first_input)[0]; | |||
Tensor state = null; | |||
if (initial_state != null) | |||
state = initial_state; | |||
else | |||
{ | |||
cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); | |||
state = cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype); | |||
} | |||
Tensor output = null; | |||
if (state is LSTMStateTuple state_tuple) | |||
{ | |||
foreach (var (time, input_) in enumerate(inputs)) | |||
{ | |||
if (time > 0) | |||
varscope.reuse_variables(); | |||
if (sequence_length != null) | |||
throw new NotImplementedException("static_rnn"); | |||
var results = cell.__call__(input_, state_tuple); | |||
(output, state_tuple) = (results[1], new LSTMStateTuple(results[0], results[1])); | |||
outputs.Add(output); | |||
} | |||
} | |||
}); | |||
return (outputs.ToArray(), state as LSTMStateTuple); | |||
} | |||
public static (Tensor, Tensor) dynamic_rnn(RnnCell cell, Tensor inputs_tensor, | |||
@@ -145,7 +199,7 @@ namespace Tensorflow.Operations | |||
if (initial_state != null) | |||
state = initial_state; | |||
else | |||
state = cell.get_initial_state(batch_size: batch_size, dtype: dtype); | |||
state = cell.get_initial_state(batch_size: batch_size, dtype: dtype) as Tensor; | |||
var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input); | |||
@@ -604,6 +604,11 @@ namespace Tensorflow | |||
return gen_array_ops.concat_v2(values, axis, name: name); | |||
} | |||
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") | |||
{ | |||
return gen_array_ops.concat_v2(values, axis, name: name); | |||
} | |||
public static Tensor concat(object[] values, int axis, string name = "concat") | |||
{ | |||
return gen_array_ops.concat_v2(values, axis, name: name); | |||
@@ -629,6 +634,16 @@ namespace Tensorflow | |||
}); | |||
} | |||
public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis, | |||
string name = "split") | |||
{ | |||
var size_splits = ops.convert_to_tensor(num_or_size_splits); | |||
return gen_array_ops.split(axis: axis, | |||
num_split: num_or_size_splits, | |||
value: value, | |||
name: name); | |||
} | |||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||
=> gen_array_ops.slice(input, begin, size, name: name); | |||
@@ -47,7 +47,7 @@ namespace Tensorflow | |||
/// <param name="axis"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor concat_v2<T>(T[] values, int axis, string name = null) | |||
public static Tensor concat_v2<T, Ta>(T[] values, Ta axis, string name = null) | |||
{ | |||
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); | |||
@@ -5,7 +5,7 @@ | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>1.14.1</TargetTensorFlow> | |||
<Version>0.12.1</Version> | |||
<Version>0.13.0</Version> | |||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
<Company>SciSharp STACK</Company> | |||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
@@ -18,14 +18,16 @@ | |||
<Description>Google's TensorFlow full binding in .NET Standard. | |||
Building, training and infering deep learning models. | |||
https://tensorflownet.readthedocs.io</Description> | |||
<AssemblyVersion>0.12.1.0</AssemblyVersion> | |||
<PackageReleaseNotes>Changes since v0.11.0: | |||
<AssemblyVersion>0.13.0.0</AssemblyVersion> | |||
<PackageReleaseNotes>Changes since v0.12.0: | |||
1: Add ICanBeFlattened for nest.flatten2. | |||
2: Complete the WhileContext. | |||
3: Add tf.nn.rnn_cell.BasicRNNCell and tf.nn.dynamic_rnn. | |||
4: Add EstimatorSpec.</PackageReleaseNotes> | |||
4: Add EstimatorSpec. | |||
5: Add rnn.static_rnn. | |||
6: Add array_grad._SplitGrad().</PackageReleaseNotes> | |||
<LangVersion>7.3</LangVersion> | |||
<FileVersion>0.12.1.0</FileVersion> | |||
<FileVersion>0.13.0.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<SignAssembly>true</SignAssembly> | |||
@@ -7,20 +7,6 @@ namespace Tensorflow | |||
{ | |||
public partial class Tensor | |||
{ | |||
/// <summary> | |||
/// Issue unresolved, will cause name_scope problem. | |||
/// </summary> | |||
/// <param name="scalar"></param> | |||
/*public static implicit operator Tensor(double scalar) | |||
{ | |||
return constant_op.constant(scalar); | |||
}*/ | |||
/*public static implicit operator Tensor(int scalar) | |||
{ | |||
return constant_op.constant(scalar); | |||
}*/ | |||
public static implicit operator IntPtr(Tensor tensor) | |||
{ | |||
if (tensor._handle == IntPtr.Zero) | |||
@@ -526,14 +526,6 @@ namespace Tensorflow.Util | |||
return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
} | |||
public static Tensor map_structure2<T>(Func<T, Tensor> func, T structure) | |||
{ | |||
var flat_structure = flatten(structure); | |||
var mapped_flat_structure = flat_structure.Select(func).ToList(); | |||
return pack_sequence_as(structure, mapped_flat_structure) as Tensor; | |||
} | |||
/// <summary> | |||
/// Same as map_structure, but with only one structure (no combining of multiple structures) | |||
/// </summary> | |||
@@ -74,5 +74,10 @@ namespace Tensorflow | |||
aggregation: aggregation) as RefVariable; | |||
}); | |||
} | |||
public void reuse_variables() | |||
{ | |||
_reuse = _ReuseMode.AUTO_REUSE; | |||
} | |||
} | |||
} |
@@ -5,6 +5,7 @@ | |||
/// </summary> | |||
public enum _ReuseMode | |||
{ | |||
NOT_REUSE = 0, | |||
// Indicates that variables are to be fetched if they already exist or | |||
// otherwise created. | |||
AUTO_REUSE = 1 | |||