@@ -57,6 +57,21 @@ namespace Tensorflow | |||||
new[] { loop_vars }); | new[] { loop_vars }); | ||||
return results[0]; | return results[0]; | ||||
} | } | ||||
public (Tensor, List<TensorArray>, Tensors, Tensors) while_loop(Func<Tensor, Tensor> cond, | |||||
Func<Tensor, List<TensorArray>, Tensors, Tensors, (Tensor, List<TensorArray>, Tensors, Tensors)> body, | |||||
(Tensor, List<TensorArray>, Tensors, Tensors) loop_vars, | |||||
int parallel_iterations = 10) | |||||
=> control_flow_ops.while_loop(cond, | |||||
body, | |||||
loop_vars); | |||||
public (Tensor, List<TensorArray>, Tensors) while_loop(Func<Tensor, Tensor> cond, | |||||
Func<Tensor, List<TensorArray>, Tensors, (Tensor, List<TensorArray>, Tensors)> body, | |||||
(Tensor, List<TensorArray>, Tensors) loop_vars, | |||||
int parallel_iterations = 10) | |||||
=> control_flow_ops.while_loop(cond, | |||||
body, | |||||
loop_vars); | |||||
public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | public Tensor[] while_loop(Func<Tensor[], Tensor> cond, | ||||
Func<Tensor[], Tensor[]> body, | Func<Tensor[], Tensor[]> body, | ||||
@@ -1,5 +1,9 @@ | |||||
using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
using OneOf; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.Layers; | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using Tensorflow.NumPy; | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | namespace Tensorflow.Keras.ArgsDefinition.Rnn | ||||
{ | { | ||||
@@ -7,11 +11,14 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
{ | { | ||||
public interface IRnnArgCell : ILayer | public interface IRnnArgCell : ILayer | ||||
{ | { | ||||
object state_size { get; } | |||||
public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null); | |||||
public StateSizeWrapper state_size { get; set; } | |||||
public int output_size { get; set; } | |||||
} | } | ||||
[JsonProperty("cell")] | [JsonProperty("cell")] | ||||
// TODO: the cell should be serialized with `serialize_keras_object`. | // TODO: the cell should be serialized with `serialize_keras_object`. | ||||
public IRnnArgCell Cell { get; set; } = null; | |||||
public OneOf<IList<IRnnArgCell>, IRnnArgCell> Cell { get; set; } | |||||
[JsonProperty("return_sequences")] | [JsonProperty("return_sequences")] | ||||
public bool ReturnSequences { get; set; } = false; | public bool ReturnSequences { get; set; } = false; | ||||
[JsonProperty("return_state")] | [JsonProperty("return_state")] | ||||
@@ -25,6 +32,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn | |||||
[JsonProperty("time_major")] | [JsonProperty("time_major")] | ||||
public bool TimeMajor { get; set; } = false; | public bool TimeMajor { get; set; } = false; | ||||
// TODO: Add `num_constants` and `zero_output_for_mask`. | // TODO: Add `num_constants` and `zero_output_for_mask`. | ||||
public bool ZeroOutputForMask { get; set; } = false; | |||||
public Dictionary<string, object> Kwargs { get; set; } = null; | public Dictionary<string, object> Kwargs { get; set; } = null; | ||||
public int Units { get; set; } | public int Units { get; set; } | ||||
@@ -1,10 +1,11 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; | |||||
namespace Tensorflow.Keras.ArgsDefinition.Rnn | namespace Tensorflow.Keras.ArgsDefinition.Rnn | ||||
{ | { | ||||
public class StackedRNNCellsArgs : LayerArgs | public class StackedRNNCellsArgs : LayerArgs | ||||
{ | { | ||||
public IList<RnnCell> Cells { get; set; } | |||||
public IList<IRnnArgCell> Cells { get; set; } | |||||
public Dictionary<string, object> Kwargs { get; set; } = null; | public Dictionary<string, object> Kwargs { get; set; } = null; | ||||
} | } | ||||
} | } |
@@ -0,0 +1,63 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using System.Collections; | |||||
namespace Tensorflow.NumPy | |||||
{ | |||||
// Since state_size in RNN is a single integer or array of integer, so use StateSizeWrapper to hold it | |||||
public class StateSizeWrapper : IEnumerable<int> | |||||
{ | |||||
int[] _state_size; | |||||
public int[] state_size => _state_size; | |||||
public StateSizeWrapper(int state_size) | |||||
{ | |||||
_state_size = new int[] { state_size }; | |||||
} | |||||
public StateSizeWrapper(params int[] state_size) | |||||
{ | |||||
_state_size = state_size; | |||||
} | |||||
public StateSizeWrapper(IEnumerable<int> state_size) | |||||
{ | |||||
_state_size = state_size.ToArray(); | |||||
} | |||||
public static implicit operator StateSizeWrapper(int[] state_size) | |||||
=> new StateSizeWrapper(state_size); | |||||
public static implicit operator StateSizeWrapper(int state_size) | |||||
=> new StateSizeWrapper(state_size); | |||||
public static implicit operator StateSizeWrapper((int, int) state_size) | |||||
=> new StateSizeWrapper(state_size.Item1, state_size.Item2); | |||||
public static implicit operator StateSizeWrapper(List<int> v) | |||||
=> new StateSizeWrapper(v); | |||||
public override string ToString() | |||||
{ | |||||
return $"{state_size}"; | |||||
} | |||||
public int this[int n] | |||||
{ | |||||
get => n < 0 ? state_size[state_size.Length + n] : state_size[n]; | |||||
set => state_size[n] = value; | |||||
} | |||||
public IEnumerator<int> GetEnumerator() | |||||
{ | |||||
return state_size.ToList().GetEnumerator(); | |||||
} | |||||
IEnumerator IEnumerable.GetEnumerator() | |||||
{ | |||||
return GetEnumerator(); | |||||
} | |||||
} | |||||
} | |||||
@@ -26,6 +26,7 @@ using Tensorflow.Operations; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -50,7 +51,7 @@ namespace Tensorflow | |||||
/// matching structure of Tensors having shape `[batch_size].concatenate(s)` | /// matching structure of Tensors having shape `[batch_size].concatenate(s)` | ||||
/// for each `s` in `self.batch_size`. | /// for each `s` in `self.batch_size`. | ||||
/// </summary> | /// </summary> | ||||
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell | |||||
public abstract class RnnCell : ILayer | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight | /// Attribute that indicates whether the cell is a TF RNN cell, due the slight | ||||
@@ -698,6 +698,53 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static (Tensor, List<TensorArray>, Tensors, Tensors) while_loop(Func<Tensor, Tensor> cond, | |||||
Func<Tensor, List<TensorArray>, Tensors, Tensors, (Tensor, List<TensorArray>, Tensors, Tensors)> body, | |||||
(Tensor, List<TensorArray>, Tensors, Tensors) loop_vars, | |||||
int parallel_iterations = 10, | |||||
string name = null) | |||||
{ | |||||
var executing_eagerly = tf.Context.executing_eagerly(); | |||||
if (!executing_eagerly) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
return tf_with(ops.name_scope("name", "while"), delegate | |||||
{ | |||||
while ((bool)cond(loop_vars.Item1)) | |||||
{ | |||||
loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3, loop_vars.Item4); | |||||
} | |||||
return loop_vars; | |||||
}); | |||||
} | |||||
public static (Tensor, List<TensorArray>, Tensors) while_loop(Func<Tensor, Tensor> cond, | |||||
Func<Tensor, List<TensorArray>, Tensors, (Tensor, List<TensorArray>, Tensors)> body, | |||||
(Tensor, List<TensorArray>, Tensors) loop_vars, | |||||
int parallel_iterations = 10, | |||||
string name = null) | |||||
{ | |||||
var executing_eagerly = tf.Context.executing_eagerly(); | |||||
if (!executing_eagerly) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
return tf_with(ops.name_scope("name", "while"), delegate | |||||
{ | |||||
while ((bool)cond(loop_vars.Item1)) | |||||
{ | |||||
loop_vars = body(loop_vars.Item1, loop_vars.Item2, loop_vars.Item3); | |||||
} | |||||
return loop_vars; | |||||
}); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Repeat `body` while the condition `cond` is true. | /// Repeat `body` while the condition `cond` is true. | ||||
/// </summary> | /// </summary> | ||||
@@ -211,6 +211,28 @@ namespace Tensorflow.Util | |||||
=> arg is IEnumerable && !(arg is string) && !(arg is NDArray) && | => arg is IEnumerable && !(arg is string) && !(arg is NDArray) && | ||||
!(arg.GetType().IsGenericType && arg.GetType().GetGenericTypeDefinition() == typeof(HashSet<>)); | !(arg.GetType().IsGenericType && arg.GetType().GetGenericTypeDefinition() == typeof(HashSet<>)); | ||||
public static bool is_nested(object obj) | |||||
{ | |||||
// Check if the object is an IEnumerable | |||||
if (obj is IEnumerable) | |||||
{ | |||||
// If it is, check if it is a nested structure | |||||
foreach (object item in (IEnumerable)obj) | |||||
{ | |||||
if (is_nested(item)) | |||||
{ | |||||
return true; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
else | |||||
{ | |||||
// If it is not, return false | |||||
return false; | |||||
} | |||||
} | |||||
public static bool is_mapping(object arg) => arg is IDictionary; | public static bool is_mapping(object arg) => arg is IDictionary; | ||||
//# See the swig file (util.i) for documentation. | //# See the swig file (util.i) for documentation. | ||||
@@ -263,7 +285,29 @@ namespace Tensorflow.Util | |||||
} | } | ||||
} | } | ||||
public static List<T> FlattenTupple<T>(object tuple) | |||||
{ | |||||
List<T> items = new List<T>(); | |||||
var type = tuple.GetType(); | |||||
if (type.GetInterface("ITuple") == null) | |||||
throw new ArgumentException("This is not a tuple!"); | |||||
foreach (var property in type.GetProperties()) | |||||
{ | |||||
var value = property.GetValue(tuple); | |||||
if (property.PropertyType.GetInterface("ITuple") != null) | |||||
{ | |||||
var subItems = FlattenTupple<T>(value); | |||||
items.AddRange(subItems); | |||||
} | |||||
else | |||||
{ | |||||
items.Add((T)value); | |||||
} | |||||
} | |||||
return items; | |||||
} | |||||
//# See the swig file (util.i) for documentation. | //# See the swig file (util.i) for documentation. | ||||
//_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | ||||
@@ -22,6 +22,9 @@ using Tensorflow.Functions; | |||||
using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using static Tensorflow.Graphs.SubGraphUtility; | using static Tensorflow.Graphs.SubGraphUtility; | ||||
using Tensorflow.Util; | |||||
using Tensorflow.Operations; | |||||
using OneOf; | |||||
namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
{ | { | ||||
@@ -65,7 +68,7 @@ namespace Tensorflow.Keras | |||||
return; | return; | ||||
} | } | ||||
var graph = v.Graph; | var graph = v.Graph; | ||||
if(graph is null) | |||||
if (graph is null) | |||||
{ | { | ||||
graph = get_graph(); | graph = get_graph(); | ||||
} | } | ||||
@@ -95,7 +98,7 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
if (_GRAPH == null) | if (_GRAPH == null) | ||||
_GRAPH = new FuncGraph("keras_graph"); | _GRAPH = new FuncGraph("keras_graph"); | ||||
return _GRAPH; | return _GRAPH; | ||||
} | } | ||||
return ops.get_default_graph(); | return ops.get_default_graph(); | ||||
@@ -105,7 +108,7 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
if (_CURRENT_SCRATCH_GRAPH == null) | if (_CURRENT_SCRATCH_GRAPH == null) | ||||
_CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph"); | _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph"); | ||||
return _CURRENT_SCRATCH_GRAPH; | return _CURRENT_SCRATCH_GRAPH; | ||||
} | } | ||||
@@ -230,16 +233,16 @@ namespace Tensorflow.Keras | |||||
{ | { | ||||
if (outputs[0].op.type == "Const") | if (outputs[0].op.type == "Const") | ||||
return tensor_util.constant_value(outputs); | return tensor_util.constant_value(outputs); | ||||
var source_graph = outputs.graph; | var source_graph = outputs.graph; | ||||
var exec_graph = _scratch_graph(); | var exec_graph = _scratch_graph(); | ||||
var global_graph = get_graph(); | var global_graph = get_graph(); | ||||
if (source_graph == global_graph && exec_graph != global_graph) | if (source_graph == global_graph && exec_graph != global_graph) | ||||
{ | { | ||||
var lifted_map = lift_to_graph(outputs, exec_graph, | |||||
new List<Tensor>(), | |||||
add_sources: true, | |||||
handle_captures: true, | |||||
var lifted_map = lift_to_graph(outputs, exec_graph, | |||||
new List<Tensor>(), | |||||
add_sources: true, | |||||
handle_captures: true, | |||||
base_graph: source_graph); | base_graph: source_graph); | ||||
} | } | ||||
if (outputs[0].op.type == "Placeholder" | if (outputs[0].op.type == "Placeholder" | ||||
@@ -250,7 +253,7 @@ namespace Tensorflow.Keras | |||||
exec_graph.as_default(); | exec_graph.as_default(); | ||||
exec_graph.Inputs = exec_graph.internal_captures; | exec_graph.Inputs = exec_graph.internal_captures; | ||||
exec_graph.Outputs = outputs; | exec_graph.Outputs = outputs; | ||||
var graph_fn = new ConcreteFunction(exec_graph); | var graph_fn = new ConcreteFunction(exec_graph); | ||||
_CURRENT_SCRATCH_GRAPH = null; | _CURRENT_SCRATCH_GRAPH = null; | ||||
@@ -370,7 +373,7 @@ namespace Tensorflow.Keras | |||||
/// <param name="data_format"></param> | /// <param name="data_format"></param> | ||||
/// <param name="interpolation"></param> | /// <param name="interpolation"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor resize_images(Tensor x, int height_factor, int width_factor, | |||||
public Tensor resize_images(Tensor x, int height_factor, int width_factor, | |||||
string data_format, string interpolation = "nearest") | string data_format, string interpolation = "nearest") | ||||
{ | { | ||||
var (rows, cols) = (0, 0); | var (rows, cols) = (0, 0); | ||||
@@ -412,7 +415,7 @@ namespace Tensorflow.Keras | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor concatenate(Tensors tensors, int axis = -1) | public Tensor concatenate(Tensors tensors, int axis = -1) | ||||
{ | { | ||||
if(axis < 0) | |||||
if (axis < 0) | |||||
{ | { | ||||
var rank = tensors[0].ndim; | var rank = tensors[0].ndim; | ||||
if (rank > -1) | if (rank > -1) | ||||
@@ -450,5 +453,520 @@ namespace Tensorflow.Keras | |||||
return x; | return x; | ||||
} | } | ||||
public static (Tensors, Tensors) convert_inputs_if_ragged(OneOf<Tensor, RaggedTensor> inputs) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
// | |||||
public static (Tensors, Tensors, Tensors) rnn( | |||||
Func<Tensors, Tensors, (Tensors, Tensors)> step_function, // args:inputs, states, return:output, new_states | |||||
Tensors inputs, // inputs is a tuple of tensors (one per input sequence) | |||||
Tensors initial_states, | |||||
bool go_backwards = false, | |||||
Tensor? mask = null, | |||||
Tensors? constants = null, | |||||
bool unroll = false, | |||||
Tensors? input_length = null, // An integer or a 1-D Tensor,depending on whether the time dimension is fixed-length or not | |||||
bool time_major = false, | |||||
bool zero_output_for_mask = false, | |||||
bool return_all_outputs = true) | |||||
{ | |||||
Tensors swap_batch_timestep(Tensors input_t) | |||||
{ | |||||
var axes = Enumerable.Range(0, input_t.rank).ToArray(); | |||||
axes[0] = 1; | |||||
axes[1] = 0; | |||||
return tf.transpose(input_t, axes); | |||||
} | |||||
if (!time_major) | |||||
{ | |||||
inputs = nest.map_structure(swap_batch_timestep, inputs); | |||||
} | |||||
var flatted_inptus = nest.flatten(inputs); | |||||
var time_steps = flatted_inptus[0].shape[0]; | |||||
var batch = flatted_inptus[0].shape[1]; | |||||
var time_step_t = tf.shape(flatted_inptus[0])[0]; | |||||
foreach (var input_ in flatted_inptus) | |||||
{ | |||||
input_.shape.with_rank_at_least(3); | |||||
} | |||||
if (mask != null) | |||||
{ | |||||
if (mask.dtype != TF_DataType.TF_BOOL) | |||||
{ | |||||
mask = tf.cast(mask, TF_DataType.TF_BOOL); | |||||
} | |||||
if (mask.rank == 2) | |||||
{ | |||||
mask = tf.expand_dims(mask, -1); | |||||
} | |||||
if (!time_major) | |||||
{ | |||||
mask = swap_batch_timestep(mask); | |||||
} | |||||
} | |||||
if (constants == null) | |||||
{ | |||||
constants = new List<Tensor>(); | |||||
} | |||||
// tf.where needs its condition tensor to be the same shape as its two | |||||
// result tensors, but in our case the condition (mask) tensor is | |||||
// (nsamples, 1), and inputs are (nsamples, ndimensions) or even more. | |||||
// So we need to broadcast the mask to match the shape of inputs. | |||||
// That's what the tile call does, it just repeats the mask along its | |||||
// second dimension n times. | |||||
Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1) | |||||
{ | |||||
if (nest.is_nested(mask_t)) | |||||
{ | |||||
throw new ValueError($"mask_t is expected to be tensor, but got {mask_t}"); | |||||
} | |||||
if (nest.is_nested(input_t)) | |||||
{ | |||||
throw new ValueError($"input_t is expected to be tensor, but got {input_t}"); | |||||
} | |||||
var rank_diff = input_t.rank - mask_t.rank; | |||||
for (int i = 0; i < rank_diff; i++) | |||||
{ | |||||
mask_t = tf.expand_dims(mask_t, -1); | |||||
} | |||||
var multiples = Enumerable.Repeat(1, fixed_dim).ToArray().concat(input_t.shape.as_int_list().ToList().GetRange(fixed_dim, input_t.rank)); | |||||
return tf.tile(mask_t, multiples); | |||||
} | |||||
Tensors outputs = new Tensors(); | |||||
Tensors output_time_zero = new Tensors(); | |||||
Tensors last_output = new Tensors(); | |||||
Tensors new_states = new Tensors(); | |||||
if (unroll) | |||||
{ | |||||
if (time_steps == 0) | |||||
{ | |||||
throw new ValueError("Unrolling requires a fixed number of timesteps."); | |||||
} | |||||
// Process the input tensors. The input tensor need to be split on the | |||||
// time_step dim, and reverse if go_backwards is True. In the case of | |||||
// nested input, the input is flattened and then transformed | |||||
// individually. The result of this will be a tuple of lists, each of | |||||
// the item in tuple is list of the tensor with shape (batch, feature) | |||||
// TODO(Wanglongzhi2001),step_func接受的第二个参数为List,但是最后却用的tuple | |||||
//var states = Tuple.Create(initial_states); | |||||
var states = initial_states; | |||||
var successive_states = new Tensors(); | |||||
var successive_outputs = new Tensors(); | |||||
// Process the input tensors. The input tensor need to be split on the | |||||
// time_step dim, and reverse if go_backwards is True. In the case of | |||||
// nested input, the input is flattened and then transformed | |||||
// individually. The result of this will be a tuple of lists, each of | |||||
// the item in tuple is list of the tensor with shape (batch, feature) | |||||
Tensors _process_single_input_t(Tensors input_t) | |||||
{ | |||||
input_t = tf.unstack(input_t); // unstack for time_step dim | |||||
if (go_backwards) | |||||
{ | |||||
input_t.Reverse(); | |||||
} | |||||
return input_t; | |||||
} | |||||
// TODO(Wanglongzhi2001) | |||||
Tensors processed_input; | |||||
if (nest.is_nested(inputs)) | |||||
{ | |||||
processed_input = nest.map_structure(_process_single_input_t, inputs); | |||||
} | |||||
else | |||||
{ | |||||
processed_input = _process_single_input_t(inputs); | |||||
} | |||||
object _get_input_tensor(int time) | |||||
{ | |||||
List<Tensor> inp = new List<Tensor>(); | |||||
foreach (var t_ in processed_input) | |||||
{ | |||||
inp.Add(t_[time]); | |||||
} | |||||
return nest.pack_sequence_as(inputs, inp); | |||||
} | |||||
if (mask != null) | |||||
{ | |||||
var mask_list = tf.unstack(mask); | |||||
if (go_backwards) | |||||
{ | |||||
mask_list.Reverse(); | |||||
} | |||||
for (int i = 0; i < time_steps; i++) | |||||
{ | |||||
// TODO(Wanglongzhi2001),deal with _get_input_tensor | |||||
var inp = _get_input_tensor(i); | |||||
var mask_t = mask_list[i]; | |||||
// TODO | |||||
var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); | |||||
var tiled_mask_t = _expand_mask(mask_t, output); | |||||
Tensors prev_output; | |||||
if (successive_outputs == null) | |||||
{ | |||||
prev_output = tf.zeros_like(output); | |||||
} | |||||
else | |||||
{ | |||||
prev_output = successive_outputs[successive_outputs.Length - 1]; | |||||
} | |||||
output = tf.where(tiled_mask_t, output, prev_output); | |||||
//var flat_states = nest.flatten(states); | |||||
//var flat_new_states = nest.flatten(newStates); | |||||
var flat_states = states.ToList(); | |||||
var flat_new_states = newStates.ToList(); | |||||
var tiledMaskT = flat_states | |||||
.Select(s => _expand_mask(mask_t, s)) | |||||
.ToArray(); | |||||
var tuple = Tuple.Create(tiledMaskT); | |||||
List<Tensor> flat_final_states = new List<Tensor>(); | |||||
foreach (var (m, s, ps) in Enumerable.Zip(tiled_mask_t, flat_new_states, flat_states)) | |||||
{ | |||||
flat_final_states.Add(tf.where(m, s, ps)); | |||||
} | |||||
states = (Tensors)nest.pack_sequence_as(states, flat_final_states); | |||||
if (return_all_outputs) | |||||
{ | |||||
successive_outputs.Add(output); | |||||
successive_states.Add(states); | |||||
} | |||||
else | |||||
{ | |||||
successive_outputs = new Tensors { output }; | |||||
successive_states = new Tensors { states }; | |||||
} | |||||
} | |||||
last_output = successive_outputs[successive_outputs.Length - 1]; | |||||
new_states = successive_states[successive_states.Length - 1]; | |||||
outputs = tf.stack(successive_outputs); | |||||
if (zero_output_for_mask) | |||||
{ | |||||
last_output = tf.where(_expand_mask(mask_list[mask_list.Length - 1], last_output), last_output, tf.zeros_like(last_output)); | |||||
outputs = tf.where(_expand_mask(mask, outputs, fixed_dim: 2), outputs, tf.zeros_like(outputs)); | |||||
} | |||||
else // mask is null | |||||
{ | |||||
for (int i = 0; i < time_steps; i++) | |||||
{ | |||||
var inp = _get_input_tensor(i); | |||||
var (output, newStates) = step_function((Tensors)inp, new Tensors { states, constants }); | |||||
states = newStates; | |||||
if (return_all_outputs) | |||||
{ | |||||
successive_outputs.Add(output); | |||||
successive_states.Add(newStates); | |||||
} | |||||
else | |||||
{ | |||||
successive_outputs = new Tensors { output }; | |||||
successive_states = new Tensors { newStates }; | |||||
} | |||||
} | |||||
last_output = successive_outputs[successive_outputs.Length - 1]; | |||||
new_states = successive_states[successive_states.Length - 1]; | |||||
outputs = tf.stack(successive_outputs); | |||||
} | |||||
} | |||||
} | |||||
else // unroll == false | |||||
{ | |||||
var states = initial_states; | |||||
// Create input tensor array, if the inputs is nested tensors, then it | |||||
// will be flattened first, and tensor array will be created one per | |||||
// flattened tensor. | |||||
var input_ta = new List<TensorArray>(); | |||||
for (int i = 0; i < flatted_inptus.Count; i++) | |||||
{ | |||||
input_ta.Add(tf.TensorArray(dtype: flatted_inptus[i].dtype, size: time_step_t)); | |||||
} | |||||
// Get the time(0) input and compute the output for that, the output will | |||||
// be used to determine the dtype of output tensor array. Don't read from | |||||
// input_ta due to TensorArray clear_after_read default to True. | |||||
var inps = new Tensors(); | |||||
foreach (var inp in flatted_inptus) | |||||
{ | |||||
inps.Add(inp[0]); | |||||
} | |||||
var input_time_zero = nest.pack_sequence_as(inputs, inps); | |||||
// output_time_zero is used to determine the cell output shape and its | |||||
// dtype. the value is discarded. | |||||
(output_time_zero, _) = step_function((Tensor)input_time_zero, new Tensors { initial_states, constants }); | |||||
var output_ta_size = return_all_outputs ? time_step_t : tf.constant(1); | |||||
var output_ta = new List<TensorArray>(); | |||||
for (int i = 0; i < output_time_zero.ToList().Count; i++) | |||||
{ | |||||
var Out = output_time_zero.ToList()[i]; | |||||
output_ta.Add(tf.TensorArray(dtype: Out.dtype, size: output_ta_size, element_shape: Out.shape)); | |||||
} | |||||
var time = tf.constant(0, dtype: TF_DataType.TF_INT32, name: "time"); | |||||
Func<Tensor, Tensor>? masking_fn; | |||||
Func<Tensors, Tensors, Tensors, Tensors>? compute_masked_output = null; | |||||
if (mask != null) | |||||
{ | |||||
if (go_backwards) | |||||
{ | |||||
mask = tf.reverse(mask, axis: new[] { 0 }); | |||||
} | |||||
var mask_ta = tf.TensorArray(dtype: TF_DataType.TF_BOOL, size: time_step_t); | |||||
mask_ta = mask_ta.unstack(mask); | |||||
masking_fn = (time) => | |||||
{ | |||||
return mask_ta.read(time); | |||||
}; | |||||
compute_masked_output = (mask_t, flat_out, flat_mask) => | |||||
{ | |||||
var tiled_mask_t = new Tensors(); | |||||
foreach (var o in flat_out) | |||||
{ | |||||
tiled_mask_t.Add(_expand_mask(mask_t, o, fixed_dim: mask_t.rank)); | |||||
} | |||||
Tensors res = new Tensors(); | |||||
foreach (var (m, o, fm) in Enumerable.Zip(tiled_mask_t, flat_out, flat_mask)) | |||||
{ | |||||
res.Add(tf.where(m, o, fm)); | |||||
} | |||||
return res; | |||||
}; | |||||
} | |||||
// TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)? | |||||
else if (input_length is Tensor) | |||||
{ | |||||
if (go_backwards) | |||||
{ | |||||
var max_len = tf.reduce_max(input_length, axis: 0); | |||||
var rev_input_length = tf.subtract(max_len - 1, input_length); | |||||
masking_fn = (time) => | |||||
{ | |||||
return tf.less(rev_input_length, time); | |||||
}; | |||||
} | |||||
else | |||||
{ | |||||
masking_fn = (time) => | |||||
{ | |||||
return tf.greater(input_length, time); | |||||
}; | |||||
} | |||||
compute_masked_output = (mask_t, flat_out, flat_mask) => | |||||
{ | |||||
var res = new List<Tensor>(); | |||||
foreach (var (o, zo) in zip(flat_out, flat_mask)) | |||||
{ | |||||
res.Add(tf.where(mask_t, o, zo)); | |||||
} | |||||
return res; | |||||
}; | |||||
} | |||||
else | |||||
{ | |||||
masking_fn = null; | |||||
} | |||||
if (masking_fn != null) | |||||
{ | |||||
// Mask for the T output will be base on the output of T - 1. In the | |||||
// case T = 0, a zero filled tensor will be used. | |||||
var flat_zero_output = new Tensors(); | |||||
foreach (var o in nest.flatten(output_time_zero)) | |||||
{ | |||||
flat_zero_output.Add(tf.zeros_like(o)); | |||||
} | |||||
(Tensor, List<TensorArray>, Tensors, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors prev_output, Tensors states) | |||||
{ | |||||
/* | |||||
RNN step function. | |||||
Args: | |||||
time: Current timestep value. | |||||
output_ta_t: TensorArray. | |||||
prev_output: tuple of outputs from time - 1. | |||||
*states: List of states. | |||||
Returns: | |||||
Tuple(todo): `(time + 1, output_ta_t, output) + tuple(new_states)` | |||||
*/ | |||||
var current_input = input_ta.Select(x => x.read(time)).ToList(); | |||||
// maybe set shape | |||||
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||||
current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input); | |||||
var mask_t = masking_fn(time); | |||||
var (output, new_states) = step_function(current_input, new Tensors { states, constants }); | |||||
// mask output | |||||
//var flat_output = nest.flatten(output); | |||||
var flat_output = output.ToList(); | |||||
var flat_mask_output = zero_output_for_mask ? flat_zero_output : prev_output.ToList(); | |||||
// TODO(Wanglongzhi2001),deal with compute_masked_output's third parameter's type | |||||
var flat_new_output = compute_masked_output(mask_t, flat_output, flat_mask_output); | |||||
// mask states | |||||
var flat_state = states.ToList(); | |||||
var flat_new_state = new_states.ToList(); | |||||
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||||
{ | |||||
if (new_state is Tensor) | |||||
{ | |||||
new_state.set_shape(state.shape); | |||||
} | |||||
} | |||||
var flat_final_state = compute_masked_output(mask_t, flat_new_state, flat_state); | |||||
new_states = (Tensors)nest.pack_sequence_as(new_states, flat_final_state); | |||||
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||||
var Output_ta_t = new List<TensorArray>(); | |||||
// TODO(Wanglongzhi2001),deal with zip output_ta_t | |||||
foreach (var (ta, Out) in zip(output_ta_t, flat_new_output)) | |||||
{ | |||||
Output_ta_t.Add(ta.write(ta_index_to_write, Out)); | |||||
} | |||||
//new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); | |||||
return (time + 1, Output_ta_t, flat_new_output, new_states); | |||||
} | |||||
Func<Tensor, Tensor> cond = (time) => (time < time_step_t); | |||||
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, flat_zero_output, states)); | |||||
new_states = final_outputs.Item4; | |||||
output_ta = final_outputs.Item2; | |||||
} | |||||
else | |||||
{ | |||||
(Tensor, List<TensorArray>, Tensors) _step(Tensor time, List<TensorArray> output_ta_t, Tensors states) | |||||
{ | |||||
var current_input = input_ta.Select(x => x.read(time)).ToList(); | |||||
// maybe set shape | |||||
// TODO(Wanglongzhi2001),deal with nest.pack_sequence_as's return type | |||||
current_input = (List<Tensor>)nest.pack_sequence_as(inputs, current_input); | |||||
var (output, new_states) = step_function(current_input, new Tensors { states, constants }); | |||||
var flat_state = states.ToList(); | |||||
var flat_new_state = new_states.ToList(); | |||||
foreach (var (state, new_state) in zip(flat_state, flat_new_state)) | |||||
{ | |||||
if (new_state is Tensor) | |||||
{ | |||||
new_state.set_shape(state.shape); | |||||
} | |||||
} | |||||
var flat_output = output.ToList(); | |||||
var ta_index_to_write = return_all_outputs ? time : tf.constant(0); | |||||
var Output_ta_t = new List<TensorArray>(); | |||||
foreach (var (ta, out_) in zip(output_ta_t, flat_output)) | |||||
{ | |||||
Output_ta_t.Add(ta.write(ta_index_to_write, out_)); | |||||
} | |||||
new_states = (Tensors)nest.pack_sequence_as(initial_states, flat_new_state); | |||||
return (time + 1, Output_ta_t, new_states); | |||||
} | |||||
Func<Tensor, Tensor> cond = (time) => (time < time_step_t); | |||||
var final_outputs = tf.while_loop(cond: cond, body: _step, loop_vars: (time, output_ta, states)); | |||||
new_states = final_outputs.Item3; | |||||
output_ta = final_outputs.Item2; | |||||
} | |||||
//Tensors outputs = new Tensors(); | |||||
foreach (var o in output_ta) | |||||
{ | |||||
outputs.Add(o.stack()); | |||||
} | |||||
foreach (var o in outputs) | |||||
{ | |||||
last_output.Add(o[-1]); | |||||
} | |||||
outputs = (Tensors)nest.pack_sequence_as(output_time_zero, outputs); | |||||
last_output = (Tensors)nest.pack_sequence_as(output_time_zero, last_output); | |||||
} | |||||
Func<Tensor, Tensor> set_shape; | |||||
set_shape = (output_) => | |||||
{ | |||||
if (output_ is Tensor) | |||||
{ | |||||
var shape = output_.shape.as_int_list(); | |||||
if (return_all_outputs) | |||||
{ | |||||
shape[0] = (int)time_steps; | |||||
} | |||||
else | |||||
{ | |||||
shape[0] = 1; | |||||
} | |||||
shape[1] = (int)batch; | |||||
output_.set_shape(new Tensor(shape)); | |||||
} | |||||
return output_; | |||||
}; | |||||
var Outputs = (Tensors)nest.map_structure(set_shape, outputs); | |||||
if (!time_major) | |||||
{ | |||||
Outputs = nest.map_structure(swap_batch_timestep, outputs); | |||||
} | |||||
return (last_output, Outputs, new_states); | |||||
} | |||||
} | } | ||||
} | } |
@@ -332,9 +332,9 @@ namespace Tensorflow.Keras.Engine | |||||
/// <param name="state"></param> | /// <param name="state"></param> | ||||
/// <param name="training"></param> | /// <param name="training"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected virtual Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if(ReplacedCall is not null) | |||||
if (ReplacedCall is not null) | |||||
{ | { | ||||
return ReplacedCall(inputs); | return ReplacedCall(inputs); | ||||
} | } | ||||
@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers { | |||||
base.build(input_shape); | base.build(input_shape); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor output = inputs; | Tensor output = inputs; | ||||
output = tf.where(output > 0f, output, | output = tf.where(output > 0f, output, | ||||
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers { | |||||
{ | { | ||||
base.build(input_shape); | base.build(input_shape); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor output = inputs; | Tensor output = inputs; | ||||
return tf.exp(output); | return tf.exp(output); | ||||
@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { | |||||
public HardSigmoid ( LayerArgs args ) : base(args) { | public HardSigmoid ( LayerArgs args ) : base(args) { | ||||
// hard sigmoid has no arguments | // hard sigmoid has no arguments | ||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | |||||
Tensor x = inputs; | Tensor x = inputs; | ||||
return tf.clip_by_value( | return tf.clip_by_value( | ||||
tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); | tf.add(tf.multiply(x, 0.2f), 0.5f), 0f, 1f); | ||||
@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return tf.nn.leaky_relu(inputs, alpha: alpha); | return tf.nn.leaky_relu(inputs, alpha: alpha); | ||||
} | } | ||||
@@ -22,7 +22,8 @@ namespace Tensorflow.Keras.Layers { | |||||
} | } | ||||
base.build(input_shape); | base.build(input_shape); | ||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | |||||
Tensor output = inputs; | Tensor output = inputs; | ||||
return tf.where(output > 0f, | return tf.where(output > 0f, | ||||
tf.multiply(scale, output), | tf.multiply(scale, output), | ||||
@@ -11,7 +11,8 @@ namespace Tensorflow.Keras.Layers { | |||||
public Softmax ( SoftmaxArgs args ) : base(args) { | public Softmax ( SoftmaxArgs args ) : base(args) { | ||||
axis = args.axis; | axis = args.axis; | ||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | |||||
Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) | Tensor x = inputs.Length == 2 ? inputs + ((1.0 - tf.cast(inputs[1], inputs.dtype)) * 1e-9) | ||||
: inputs; | : inputs; | ||||
Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true))); | Tensor e = tf.exp(tf.sub(x, tf.reduce_max(x, axis: this.axis, keepdims: true))); | ||||
@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { | |||||
public Softplus ( LayerArgs args ) : base(args) { | public Softplus ( LayerArgs args ) : base(args) { | ||||
// Softplus has no arguments | // Softplus has no arguments | ||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | |||||
Tensor x = inputs; | Tensor x = inputs; | ||||
return tf.log( | return tf.log( | ||||
tf.add(tf.exp(x), 1f)); | tf.add(tf.exp(x), 1f)); | ||||
@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { | |||||
public Softsign ( LayerArgs args ) : base(args) { | public Softsign ( LayerArgs args ) : base(args) { | ||||
// Softsign has no arguments | // Softsign has no arguments | ||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | |||||
Tensor x = inputs; | Tensor x = inputs; | ||||
// x / (abs(x) + 1) | // x / (abs(x) + 1) | ||||
return tf.div(x, tf.add(1f, tf.abs(x))); | return tf.div(x, tf.add(1f, tf.abs(x))); | ||||
@@ -10,7 +10,8 @@ namespace Tensorflow.Keras.Layers { | |||||
public Swish ( LayerArgs args ) : base(args) { | public Swish ( LayerArgs args ) : base(args) { | ||||
// Swish has no arguments | // Swish has no arguments | ||||
} | } | ||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | |||||
Tensor x = inputs; | Tensor x = inputs; | ||||
// x / (1 + exp(-x)) | // x / (1 + exp(-x)) | ||||
@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
// Tanh has no arguments | // Tanh has no arguments | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor x = inputs; | Tensor x = inputs; | ||||
@@ -114,7 +114,7 @@ namespace Tensorflow.Keras.Layers | |||||
return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); | return (tf.linalg.einsum("bij,bjk->bik", (weights, value)), weights); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensors _inp; | Tensors _inp; | ||||
Tensors _mask = null; | Tensors _mask = null; | ||||
@@ -252,7 +252,7 @@ namespace Tensorflow.Keras.Layers | |||||
return (attention_output, attention_scores); | return (attention_output, attention_scores); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensors _inp; | Tensors _inp; | ||||
Tensor _mask = null; | Tensor _mask = null; | ||||
@@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Layers | |||||
_buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
var outputs = _convolution_op.Apply(inputs, kernel.AsTensor()); | var outputs = _convolution_op.Apply(inputs, kernel.AsTensor()); | ||||
if (use_bias) | if (use_bias) | ||||
@@ -69,7 +69,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
var rank = inputs.rank; | var rank = inputs.rank; | ||||
@@ -189,7 +189,7 @@ namespace Tensorflow.Keras.Layers | |||||
// return new dict(base_config.items().ToList() + config.items().ToList()); | // return new dict(base_config.items().ToList() + config.items().ToList()); | ||||
//} | //} | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor())); | var ret = tf.linalg.einsum(this.equation, (inputs, this.kernel.AsTensor())); | ||||
if (this.bias != null) | if (this.bias != null) | ||||
@@ -66,7 +66,7 @@ namespace Tensorflow.Keras.Layers | |||||
_buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
var dtype = inputs.dtype; | var dtype = inputs.dtype; | ||||
if (dtype != tf.int32 && dtype != tf.int64) | if (dtype != tf.int32 && dtype != tf.int64) | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers | |||||
_buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return _merge_function(inputs); | return _merge_function(inputs); | ||||
} | } | ||||
@@ -146,7 +146,7 @@ namespace Tensorflow.Keras.Layers | |||||
return false; | return false; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
var training_tensor = training == null | var training_tensor = training == null | ||||
@@ -101,7 +101,7 @@ namespace Tensorflow.Keras.Layers | |||||
return input_shape; | return input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor outputs = null; | Tensor outputs = null; | ||||
var inputs_dtype = inputs.dtype.as_base_dtype(); | var inputs_dtype = inputs.dtype.as_base_dtype(); | ||||
@@ -157,7 +157,7 @@ namespace Tensorflow.Keras.Layers | |||||
base.adapt(data, batch_size: batch_size, steps: steps); | base.adapt(data, batch_size: batch_size, steps: steps); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (_args.Invert) | if (_args.Invert) | ||||
{ | { | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (data_format == "channels_last") | if (data_format == "channels_last") | ||||
return math_ops.reduce_mean(inputs, 1, false); | return math_ops.reduce_mean(inputs, 1, false); | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (data_format == "channels_last") | if (data_format == "channels_last") | ||||
return math_ops.reduce_mean(inputs, (1, 2), false); | return math_ops.reduce_mean(inputs, (1, 2), false); | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (data_format == "channels_last") | if (data_format == "channels_last") | ||||
return math_ops.reduce_max(inputs, 1, false); | return math_ops.reduce_max(inputs, 1, false); | ||||
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (data_format == "channels_last") | if (data_format == "channels_last") | ||||
return math_ops.reduce_max(inputs, (1, 2), false); | return math_ops.reduce_max(inputs, (1, 2), false); | ||||
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||||
input_spec = new InputSpec(ndim: 3); | input_spec = new InputSpec(ndim: 3); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
int pad_axis = args.DataFormat == "channels_first" ? 2 : 3; | int pad_axis = args.DataFormat == "channels_first" ? 2 : 3; | ||||
inputs = tf.expand_dims(inputs, pad_axis); | inputs = tf.expand_dims(inputs, pad_axis); | ||||
@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||||
input_spec = new InputSpec(ndim: 4); | input_spec = new InputSpec(ndim: 4); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
int[] pool_shape; | int[] pool_shape; | ||||
int[] strides; | int[] strides; | ||||
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
var depth = args.NumTokens; | var depth = args.NumTokens; | ||||
var max_value = tf.reduce_max(inputs); | var max_value = tf.reduce_max(inputs); | ||||
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
scale = constant_op.constant(args.Scale, args.DType); | scale = constant_op.constant(args.Scale, args.DType); | ||||
offset = constant_op.constant(args.Offset, args.DType); | offset = constant_op.constant(args.Offset, args.DType); | ||||
@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation); | return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation); | ||||
} | } | ||||
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (training == null) | if (training == null) | ||||
training = false; | training = false; | ||||
@@ -27,7 +27,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||||
_buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor output = inputs; | Tensor output = inputs; | ||||
if (output.rank != 3) | if (output.rank != 3) | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||||
built = true; | built = true; | ||||
_buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor output = inputs; | Tensor output = inputs; | ||||
if (output.rank != 4) | if (output.rank != 4) | ||||
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Layers.Reshaping | |||||
_buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor output = inputs; | Tensor output = inputs; | ||||
if (output.rank != 5) | if (output.rank != 5) | ||||
@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Layers | |||||
_channels_first = args.DataFormat == "channels_first"; | _channels_first = args.DataFormat == "channels_first"; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (_channels_first) | if (_channels_first) | ||||
{ | { | ||||
@@ -28,7 +28,7 @@ namespace Tensorflow.Keras.Layers { | |||||
built = true; | built = true; | ||||
_buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
Tensor outputs = inputs; | Tensor outputs = inputs; | ||||
return tf.transpose(outputs, new Axis(permute)); | return tf.transpose(outputs, new Axis(permute)); | ||||
@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.args = args; | this.args = args; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
var shapes = new List<Tensor>(); | var shapes = new List<Tensor>(); | ||||
shapes.Add(array_ops.shape(inputs)[0]); | shapes.Add(array_ops.shape(inputs)[0]); | ||||
@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers | |||||
inputSpec = new InputSpec(ndim: 4); | inputSpec = new InputSpec(ndim: 4); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return keras.backend.resize_images(inputs, | return keras.backend.resize_images(inputs, | ||||
size[0], size[1], | size[0], size[1], | ||||
@@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Layers | |||||
this.input_spec = new InputSpec(ndim: 4); | this.input_spec = new InputSpec(ndim: 4); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return keras.backend.spatial_2d_padding(inputs, | return keras.backend.spatial_2d_padding(inputs, | ||||
padding: padding, | padding: padding, | ||||
@@ -26,9 +26,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
.ToArray(); | .ToArray(); | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return base.Call(inputs, state: state, training: training); | |||||
return base.Call(inputs, initial_state: initial_state, training: training); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,9 +1,15 @@ | |||||
using System; | using System; | ||||
using System.Collections; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using System.Reflection; | |||||
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; | |||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Util; | |||||
using OneOf; | |||||
using OneOf.Types; | |||||
using Tensorflow.Common.Extensions; | |||||
// from tensorflow.python.distribute import distribution_strategy_context as ds_context; | // from tensorflow.python.distribute import distribution_strategy_context as ds_context; | ||||
namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
@@ -19,11 +25,46 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
protected IVariableV1 kernel; | protected IVariableV1 kernel; | ||||
protected IVariableV1 bias; | protected IVariableV1 bias; | ||||
protected ILayer cell; | protected ILayer cell; | ||||
public RNN(RNNArgs args) : base(PreConstruct(args)) | public RNN(RNNArgs args) : base(PreConstruct(args)) | ||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
SupportsMasking = true; | SupportsMasking = true; | ||||
// if is StackedRnncell | |||||
if (args.Cell.IsT0) | |||||
{ | |||||
cell = new StackedRNNCells(new StackedRNNCellsArgs | |||||
{ | |||||
Cells = args.Cell.AsT0, | |||||
}); | |||||
} | |||||
else | |||||
{ | |||||
cell = args.Cell.AsT1; | |||||
} | |||||
Type type = cell.GetType(); | |||||
MethodInfo methodInfo = type.GetMethod("Call"); | |||||
if (methodInfo == null) | |||||
{ | |||||
throw new ValueError(@"Argument `cell` or `cells`should have a `call` method. "); | |||||
} | |||||
PropertyInfo propertyInfo = type.GetProperty("state_size"); | |||||
if (propertyInfo == null) | |||||
{ | |||||
throw new ValueError(@"The RNN cell should have a `state_size` attribute"); | |||||
} | |||||
// get input_shape | |||||
this.args = PreConstruct(args); | |||||
// The input shape is unknown yet, it could have nested tensor inputs, and | // The input shape is unknown yet, it could have nested tensor inputs, and | ||||
// the input spec will be the list of specs for nested inputs, the structure | // the input spec will be the list of specs for nested inputs, the structure | ||||
// of the input_spec will be the same as the input. | // of the input_spec will be the same as the input. | ||||
@@ -37,17 +78,384 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
//} | //} | ||||
} | } | ||||
// States is a tuple consist of cell states_size, like (cell1.state_size, cell2.state_size,...) | |||||
// state_size can be a single integer, can also be a list/tuple of integers, can also be TensorShape or a list/tuple of TensorShape | |||||
public object States | |||||
{ | |||||
get | |||||
{ | |||||
if (_states == null) | |||||
{ | |||||
var state = nest.map_structure(x => null, cell.state_size); | |||||
return nest.is_nested(state) ? state : new Tensors { state }; | |||||
} | |||||
return _states; | |||||
} | |||||
set { _states = value; } | |||||
} | |||||
private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape) | |||||
{ | |||||
var batch = input_shape[0]; | |||||
var time_step = input_shape[1]; | |||||
if (args.TimeMajor) | |||||
{ | |||||
(batch, time_step) = (time_step, batch); | |||||
} | |||||
// state_size is a array of ints or a positive integer | |||||
var state_size = cell.state_size; | |||||
// TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor | |||||
Func<Shape, Shape> _get_output_shape; | |||||
_get_output_shape = (flat_output_size) => | |||||
{ | |||||
var output_dim = flat_output_size.as_int_list(); | |||||
Shape output_shape; | |||||
if (args.ReturnSequences) | |||||
{ | |||||
if (args.TimeMajor) | |||||
{ | |||||
output_shape = new Shape(new int[] { (int)time_step, (int)batch }.concat(output_dim)); | |||||
} | |||||
else | |||||
{ | |||||
output_shape = new Shape(new int[] { (int)batch, (int)time_step }.concat(output_dim)); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
output_shape = new Shape(new int[] { (int)batch }.concat(output_dim)); | |||||
} | |||||
return output_shape; | |||||
}; | |||||
Shape output_shape; | |||||
if (cell.output_size != 0) | |||||
{ | |||||
output_shape = nest.map_structure(_get_output_shape, cell.output_size); | |||||
// TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 | |||||
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); | |||||
} | |||||
else | |||||
{ | |||||
output_shape = _get_output_shape(state_size[0]); | |||||
} | |||||
if (args.ReturnState) | |||||
{ | |||||
Func<Shape, Shape> _get_state_shape; | |||||
_get_state_shape = (flat_state) => | |||||
{ | |||||
var state_shape = new int[] { (int)batch }.concat(flat_state.as_int_list()); | |||||
return new Shape(state_shape); | |||||
}; | |||||
var state_shape = _get_state_shape(new Shape(state_size.ToArray())); | |||||
return new List<Shape> { output_shape, state_shape }; | |||||
} | |||||
else | |||||
{ | |||||
return output_shape; | |||||
} | |||||
} | |||||
private Tensors compute_mask(Tensors inputs, Tensors mask) | |||||
{ | |||||
// Time step masks must be the same for each input. | |||||
// This is because the mask for an RNN is of size [batch, time_steps, 1], | |||||
// and specifies which time steps should be skipped, and a time step | |||||
// must be skipped for all inputs. | |||||
mask = nest.flatten(mask)[0]; | |||||
var output_mask = args.ReturnSequences ? mask : null; | |||||
if (args.ReturnState) | |||||
{ | |||||
var state_mask = new List<Tensor>(); | |||||
for (int i = 0; i < len(States); i++) | |||||
{ | |||||
state_mask.Add(null); | |||||
} | |||||
return new List<Tensor> { output_mask }.concat(state_mask); | |||||
} | |||||
else | |||||
{ | |||||
return output_mask; | |||||
} | |||||
} | |||||
public override void build(KerasShapesWrapper input_shape) | public override void build(KerasShapesWrapper input_shape) | ||||
{ | { | ||||
object get_input_spec(Shape shape) | |||||
{ | |||||
var input_spec_shape = shape.as_int_list(); | |||||
var (batch_index, time_step_index) = args.TimeMajor ? (1, 0) : (0, 1); | |||||
if (!args.Stateful) | |||||
{ | |||||
input_spec_shape[batch_index] = -1; | |||||
} | |||||
input_spec_shape[time_step_index] = -1; | |||||
return new InputSpec(shape: input_spec_shape); | |||||
} | |||||
Shape get_step_input_shape(Shape shape) | |||||
{ | |||||
// return shape[1:] if self.time_major else (shape[0],) + shape[2:] | |||||
if (args.TimeMajor) | |||||
{ | |||||
return shape.as_int_list().ToList().GetRange(1, shape.Length - 1).ToArray(); | |||||
} | |||||
else | |||||
{ | |||||
return new int[] { shape.as_int_list()[0] }.concat(shape.as_int_list().ToList().GetRange(2, shape.Length - 2).ToArray()); | |||||
} | |||||
} | |||||
object get_state_spec(Shape shape) | |||||
{ | |||||
var state_spec_shape = shape.as_int_list(); | |||||
// append bacth dim | |||||
state_spec_shape = new int[] { -1 }.concat(state_spec_shape); | |||||
return new InputSpec(shape: state_spec_shape); | |||||
} | |||||
// Check whether the input shape contains any nested shapes. It could be | |||||
// (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from | |||||
// numpy inputs. | |||||
if (!cell.Built) | if (!cell.Built) | ||||
{ | { | ||||
cell.build(input_shape); | cell.build(input_shape); | ||||
} | } | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
// inputs: Tensors | |||||
// mask: Binary tensor of shape [batch_size, timesteps] indicating whether a given timestep should be masked | |||||
// training: bool | |||||
// initial_state: List of initial state tensors to be passed to the first call of the cell | |||||
// constants: List of constant tensors to be passed to the cell at each timestep | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return base.Call(inputs, state, training); | |||||
//var (inputs_padded, row_length) = BackendImpl.convert_inputs_if_ragged(inputs); | |||||
//bool is_ragged_input = row_length != null; | |||||
//_validate_args_if_ragged(is_ragged_input, mask); | |||||
var (inputs_processed, initial_state_processed, constants_processed) = _process_inputs(inputs, initial_state, constants); | |||||
_maybe_reset_cell_dropout_mask(cell); | |||||
if (cell is StackedRNNCells) | |||||
{ | |||||
foreach (var cell in ((StackedRNNCells)cell).Cells) | |||||
{ | |||||
_maybe_reset_cell_dropout_mask(cell); | |||||
} | |||||
} | |||||
if (mask != null) | |||||
{ | |||||
// Time step masks must be the same for each input. | |||||
//mask = nest.flatten(mask)[0]; | |||||
mask = mask[0]; | |||||
} | |||||
Shape input_shape; | |||||
if (nest.is_nested(initial_state_processed)) | |||||
{ | |||||
// In the case of nested input, use the first element for shape check | |||||
// input_shape = nest.flatten(inputs)[0].shape; | |||||
input_shape = inputs[0].shape; | |||||
} | |||||
else | |||||
{ | |||||
input_shape = inputs.shape; | |||||
} | |||||
var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1]; | |||||
if (args.Unroll && timesteps != null) | |||||
{ | |||||
throw new ValueError( | |||||
"Cannot unroll a RNN if the " + | |||||
"time dimension is undefined. \n" + | |||||
"- If using a Sequential model, " + | |||||
"specify the time dimension by passing " + | |||||
"an `input_shape` or `batch_input_shape` " + | |||||
"argument to your first layer. If your " + | |||||
"first layer is an Embedding, you can " + | |||||
"also use the `input_length` argument.\n" + | |||||
"- If using the functional API, specify " + | |||||
"the time dimension by passing a `shape` " + | |||||
"or `batch_shape` argument to your Input layer." | |||||
); | |||||
} | |||||
// cell_call_fn = (self.cell.__call__ if callable(self.cell) else self.cell.call) | |||||
var cell_call_fn = cell.Call; | |||||
Func<Tensors, Tensors, (Tensors, Tensors)> step; | |||||
if (constants != null) | |||||
{ | |||||
ParameterInfo[] parameters = cell_call_fn.GetMethodInfo().GetParameters(); | |||||
bool hasParam = parameters.Any(p => p.Name == "constants"); | |||||
if (!hasParam) | |||||
{ | |||||
throw new ValueError( | |||||
$"RNN cell {cell} does not support constants." + | |||||
$"Received: constants={constants}"); | |||||
} | |||||
step = (inputs, states) => | |||||
{ | |||||
// constants = states[-self._num_constants :] | |||||
constants = states.numpy()[new Slice(states.Length - _num_constants, states.Length)]; | |||||
// states = states[: -self._num_constants] | |||||
states = states.numpy()[new Slice(0, states.Length - _num_constants)]; | |||||
// states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) | |||||
states = states.Length == 1 ? states[0] : states; | |||||
var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); | |||||
if (!nest.is_nested(new_states)) | |||||
{ | |||||
return (output, new Tensors { new_states }); | |||||
} | |||||
return (output, new_states); | |||||
}; | |||||
} | |||||
else | |||||
{ | |||||
step = (inputs, states) => | |||||
{ | |||||
// states = (states[0] if len(states) == 1 and is_tf_rnn_cell else states) | |||||
states = states.Length == 1 ? states[0] : states; | |||||
var (output, new_states) = cell_call_fn(inputs, null, null, states, constants); | |||||
if (!nest.is_nested(new_states)) | |||||
{ | |||||
return (output, new Tensors { new_states }); | |||||
} | |||||
return (output, new_states); | |||||
}; | |||||
} | |||||
var (last_output, outputs, states) = BackendImpl.rnn(step, | |||||
inputs, | |||||
initial_state, | |||||
constants: constants, | |||||
go_backwards: args.GoBackwards, | |||||
mask: mask, | |||||
unroll: args.Unroll, | |||||
input_length: row_length != null ? row_length : new Tensor(timesteps), | |||||
time_major: args.TimeMajor, | |||||
zero_output_for_mask: args.ZeroOutputForMask, | |||||
return_all_outputs: args.ReturnSequences); | |||||
if (args.Stateful) | |||||
{ | |||||
throw new NotImplementedException("this argument havn't been developed!"); | |||||
} | |||||
Tensors output = new Tensors(); | |||||
if (args.ReturnSequences) | |||||
{ | |||||
throw new NotImplementedException("this argument havn't been developed!"); | |||||
} | |||||
else | |||||
{ | |||||
output = last_output; | |||||
} | |||||
if (args.ReturnState) | |||||
{ | |||||
foreach (var state in states) | |||||
{ | |||||
output.Add(state); | |||||
} | |||||
return output; | |||||
} | |||||
else | |||||
{ | |||||
return output; | |||||
} | |||||
} | |||||
private (Tensors, Tensors, Tensors) _process_inputs(Tensor inputs, Tensors initial_state, Tensors constants) | |||||
{ | |||||
bool IsSequence(object obj) | |||||
{ | |||||
// Check if the object is an IEnumerable | |||||
if (obj is IEnumerable) | |||||
{ | |||||
// If it is, check if it is a tuple | |||||
if (!(obj is Tuple)) | |||||
{ | |||||
return true; | |||||
} | |||||
} | |||||
// If it is not, return false | |||||
return false; | |||||
} | |||||
if (IsSequence(input)) | |||||
{ | |||||
if (_num_constants != 0) | |||||
{ | |||||
initial_state = inputs[new Slice(1, len(inputs))]; | |||||
} | |||||
else | |||||
{ | |||||
initial_state = inputs[new Slice(1, len(inputs) - _num_constants)]; | |||||
} | |||||
if (len(initial_state) == 0) | |||||
initial_state = null; | |||||
inputs = inputs[0]; | |||||
} | |||||
if (args.Stateful) | |||||
{ | |||||
throw new NotImplementedException("argument stateful has not been implemented!"); | |||||
} | |||||
return (inputs, initial_state, constants); | |||||
} | |||||
private void _validate_args_if_ragged(bool is_ragged_input, Tensors mask) | |||||
{ | |||||
if (is_ragged_input) | |||||
{ | |||||
if (args.Unroll) | |||||
{ | |||||
throw new ValueError("The input received contains RaggedTensors and does " + | |||||
"not support unrolling. Disable unrolling by passing " + | |||||
"`unroll=False` in the RNN Layer constructor."); | |||||
} | |||||
if (mask != null) | |||||
{ | |||||
throw new ValueError($"The mask that was passed in was {mask}, which " + | |||||
"cannot be applied to RaggedTensor inputs. Please " + | |||||
"make sure that there is no mask injected by upstream " + | |||||
"layers."); | |||||
} | |||||
} | |||||
} | |||||
void _maybe_reset_cell_dropout_mask(ILayer cell) | |||||
{ | |||||
//if (cell is DropoutRNNCellMixin) | |||||
//{ | |||||
// cell.reset_dropout_mask(); | |||||
// cell.reset_recurrent_dropout_mask(); | |||||
//} | |||||
} | } | ||||
private static RNNArgs PreConstruct(RNNArgs args) | private static RNNArgs PreConstruct(RNNArgs args) | ||||
@@ -77,6 +485,10 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
return args; | return args; | ||||
} | } | ||||
public Tensors __call__(Tensors inputs, Tensor state = null, Tensor training = null) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public RNN New(LayerRnnCell cell, | public RNN New(LayerRnnCell cell, | ||||
bool return_sequences = false, | bool return_sequences = false, | ||||
bool return_state = false, | bool return_state = false, | ||||
@@ -95,7 +507,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
TimeMajor = time_major | TimeMajor = time_major | ||||
}); | }); | ||||
public RNN New(IList<RnnCell> cell, | |||||
public RNN New(IList<IRnnArgCell> cell, | |||||
bool return_sequences = false, | bool return_sequences = false, | ||||
bool return_state = false, | bool return_state = false, | ||||
bool go_backwards = false, | bool go_backwards = false, | ||||
@@ -125,7 +537,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
// Check whether the state_size contains multiple states. | // Check whether the state_size contains multiple states. | ||||
public static bool _is_multiple_state(object state_size) | |||||
public static bool is_multiple_state(object state_size) | |||||
{ | { | ||||
var myIndexerProperty = state_size.GetType().GetProperty("Item"); | var myIndexerProperty = state_size.GetType().GetProperty("Item"); | ||||
return myIndexerProperty != null | return myIndexerProperty != null | ||||
@@ -42,9 +42,9 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
return base.Call(inputs, state, training); | |||||
return base.Call(inputs, initial_state, training); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -2,15 +2,16 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.ComponentModel; | using System.ComponentModel; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
using static Tensorflow.Keras.ArgsDefinition.Rnn.RNNArgs; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
using Tensorflow.Keras.ArgsDefinition.Rnn; | |||||
namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
{ | { | ||||
public class StackedRNNCells : Layer, RNNArgs.IRnnArgCell | |||||
public class StackedRNNCells : Layer | |||||
{ | { | ||||
public IList<RnnCell> Cells { get; set; } | |||||
public IList<IRnnArgCell> Cells { get; set; } | |||||
public bool reverse_state_order; | public bool reverse_state_order; | ||||
public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | public StackedRNNCells(StackedRNNCellsArgs args) : base(args) | ||||
@@ -51,7 +52,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
{ | { | ||||
return lastCell.output_size; | return lastCell.output_size; | ||||
} | } | ||||
else if (RNN._is_multiple_state(lastCell.state_size)) | |||||
else if (RNN.is_multiple_state(lastCell.state_size)) | |||||
{ | { | ||||
// return ((dynamic)Cells[-1].state_size)[0]; | // return ((dynamic)Cells[-1].state_size)[0]; | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
@@ -63,6 +64,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
} | } | ||||
} | } | ||||
public object get_initial_state() | public object get_initial_state() | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
@@ -80,7 +82,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
// return tuple(initial_states) | // return tuple(initial_states) | ||||
} | } | ||||
public object call() | |||||
public Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
// def call(self, inputs, states, constants= None, training= None, ** kwargs): | // def call(self, inputs, states, constants= None, training= None, ** kwargs): | ||||
@@ -34,7 +34,7 @@ namespace Tensorflow.Keras.Layers | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
protected override Tensors Call(Tensors inputs, Tensor mask = null, bool? training = null, Tensors initial_state = null, Tensors constants = null) | |||||
{ | { | ||||
if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
return DeFunCall(inputs); | return DeFunCall(inputs); | ||||