@@ -71,15 +71,15 @@ namespace Tensorflow
         public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
             => array_ops.split(
                 value: value,
-                num_split: num_split,
+                num_or_size_splits: num_split,
                 axis: axis,
                 name: name);

         public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
             => array_ops.split(
                 value: value,
-                num_split: num_split,
-                axis: axis,
+                num_or_size_splits: num_split,
+                axis: ops.convert_to_tensor(axis),
                 name: name);

         public Tensor ensure_shape(Tensor x, Shape shape, string name = null)
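
Note: both `tf.split` overloads above now forward to the renamed `num_or_size_splits` parameter, and an `int` axis is converted to a tensor before the call. A minimal usage sketch (hypothetical shapes, assuming the `tf` entry point from `using static Tensorflow.Binding;`):

    using static Tensorflow.Binding;

    var x = tf.ones((6, 4));
    var parts = tf.split(x, num_split: 3, axis: 0);               // three (2, 4) tensors
    var parts2 = tf.split(x, num_split: 2, axis: tf.constant(1)); // axis may be a tensor
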
@@ -197,25 +197,11 @@ namespace Tensorflow.Common.Types
             }
             else if(NestType is NestType.List)
             {
-                foreach(var item in ListValue!)
-                {
-                    if(item.NestType is NestType.List or NestType.Dictionary)
-                    {
-                        return true;
-                    }
-                }
-                return false;
+                return ListValue!.Count > 0;
             }
             else
             {
-                foreach (var item in DictValue!.Values)
-                {
-                    if (item.NestType is NestType.List or NestType.Dictionary)
-                    {
-                        return true;
-                    }
-                }
-                return false;
+                return DictValue!.Count > 0;
             }
         }
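
Note: this changes the meaning of the nestedness check. A list or dictionary node now counts as nested whenever it is non-empty, not only when one of its children is itself a list or dictionary. The two rules distilled into plain predicates (a sketch, not library API; names are illustrative):

    using System.Collections.Generic;
    using System.Linq;
    using Tensorflow.Common.Types; // NestType, assumed from this hunk

    static bool IsNestedOld(IEnumerable<NestType> childTypes)
        => childTypes.Any(t => t is NestType.List or NestType.Dictionary);

    static bool IsNestedNew(int childCount) => childCount > 0;
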
@@ -352,7 +352,11 @@ namespace Tensorflow.Eager
                     c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
                     break;
                 case TF_AttrType.TF_ATTR_SHAPE:
-                    var dims = (value as long[]).ToArray();
+                    long[] dims;
+                    if (value is Shape shape) dims = shape.dims.ToArray();
+                    else if (value is long[] longs) dims = longs.ToArray();
+                    else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
+                    else dims = ((long[])value).ToArray();
                     c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
                     status.Check(true);
                     break;
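
The shape attribute setter now accepts a `Shape`, `long[]`, or `int[]` instead of silently assuming `long[]` (the old `as` cast produced null for any other type and threw on `ToArray`). The normalization logic in isolation, as a plain C# helper (a sketch; `Shape` stands in for the library's shape type, whose `dims` is assumed to be a `long[]`):

    using System.Linq;
    using Tensorflow;

    static long[] NormalizeDims(object value) => value switch
    {
        Shape shape => shape.dims.ToArray(),               // shape attribute object
        long[] longs => longs.ToArray(),                   // defensive copy
        int[] ints => ints.Select(x => (long)x).ToArray(), // widen to long
        _ => ((long[])value).ToArray()                     // last resort: hard cast
    };
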
@@ -137,7 +137,6 @@ namespace Tensorflow.Eager
             {
                 dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
             }
-            Shape tensor_shape = new(dims);

             if(status.Code != TF_Code.TF_OK)
             {
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager
             }
             else
             {
+                Shape tensor_shape = new(dims);
                 return new TapeTensor(id, dtype, tensor_shape);
             }
         }
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager
             return dtype == dtypes.variant || dtype == dtypes.resource;
         }

-        bool ListContainNone(long[] list)
+        bool ListContainNone(long[]? list)
         {
+            if(list is null)
+            {
+                return true;
+            }
             int len = list.Length;
             if(len == 0)
             {
@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients
                     ? input_values[0].rank + dim_int
                     : dim_int % input_values[0].rank;
                 var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
-                var sizes_tensor = constant_op.constant(sizes);
-                out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
+                out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList();
             }
             else if (constant_op.is_constant(concat_dim))
             {
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients
                     new Tensor[] { non_neg_concat_dim, tf.constant(0) },
                     new Tensor[] { tf.constant(1), tf.constant(-1) });
                 var squeeze_sizes = array_ops.squeeze(slice);
-                out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
+                out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList();
             }
             else
             {
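
Both branches of the concat gradient now route through the reworked `array_ops.split`: statically known sizes go in as an `int[]` (lowered to `SplitV`) and the axis is passed as a tensor. Conceptually, the gradient of `concat` is just the incoming gradient cut back into the original input widths; a sketch with concrete shapes (assuming the public `array_ops` and `ops` classes used in this hunk):

    var grad = tf.ones((4, 5));                 // gradient of a (4,2)+(4,3) concat on axis 1
    var pieces = array_ops.split(grad,
        new[] { 2, 3 },                         // per-input sizes along the concat axis
        ops.convert_to_tensor(1));              // axis as a tensor, per the new signature
    // pieces[0].shape == (4, 2), pieces[1].shape == (4, 3)
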
@@ -4,8 +4,6 @@
     {
         // TODO: maybe change the `RNNArgs` and implement this class.
         public bool UnitForgetBias { get; set; }
-        public float Dropout { get; set; }
-        public float RecurrentDropout { get; set; }
         public int Implementation { get; set; }
     }
 }
@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
         [JsonProperty("unit_forget_bias")]
         public bool UnitForgetBias { get; set; } = true;
         [JsonProperty("implementation")]
-        public int Implementation { get; set; } = 2;
+        public int Implementation { get; set; } = 1;
     }
 }
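
The default `implementation` drops from 2 to 1 here and in the layer factories below. In Keras semantics the two modes are numerically equivalent and differ only in how the gate computations are batched (mode 1 issues more, smaller dot products; mode 2 fuses them into fewer, larger ones), so this is a performance/compatibility default change, not a behavioral one. Opting back in explicitly (a sketch via the factory signature updated later in this diff, assuming `using static Tensorflow.KerasApi;`):

    var cell = keras.layers.LSTMCell(4, implementation: 2);
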
@@ -7,12 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
     // TODO(Rinne): add regularizers.
     public class RNNArgs : AutoSerializeLayerArgs
     {
-        [JsonProperty("cell")]
-        // TODO: the cell should be serialized with `serialize_keras_object`.
-        public IRnnCell Cell { get; set; } = null;
-        [JsonProperty("cells")]
-        public IList<IRnnCell> Cells { get; set; } = null;
-
         [JsonProperty("return_sequences")]
         public bool ReturnSequences { get; set; } = false;
         [JsonProperty("return_state")]
@@ -25,8 +19,10 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
         public bool Unroll { get; set; } = false;
         [JsonProperty("time_major")]
         public bool TimeMajor { get; set; } = false;
+        public int? InputDim { get; set; }
+        public int? InputLength { get; set; }
         // TODO: Add `num_constants` and `zero_output_for_mask`.
-        public Dictionary<string, object> Kwargs { get; set; } = null;

         public int Units { get; set; }
         public Activation Activation { get; set; }
@@ -38,21 +34,5 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
         public float Dropout { get; set; } = .0f;
         public bool ZeroOutputForMask { get; set; } = false;
         public float RecurrentDropout { get; set; } = .0f;
-
-        // kernel_regularizer=None,
-        // recurrent_regularizer=None,
-        // bias_regularizer=None,
-        // activity_regularizer=None,
-        // kernel_constraint=None,
-        // recurrent_constraint=None,
-        // bias_constraint=None,
-        // dropout=0.,
-        // recurrent_dropout=0.,
-        // return_sequences=False,
-        // return_state=False,
-        // go_backwards=False,
-        // stateful=False,
-        // unroll=False,
-        // **kwargs):
     }
 }
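
`Cell`/`Cells` move out of the args (the cell becomes a constructor argument of `RNN`, per the later hunks) and the untyped `Kwargs` bag is replaced by the typed `InputDim`/`InputLength` properties. Configuration that used to travel through dictionary keys is now expressed directly (sketch):

    var args = new RNNArgs
    {
        Units = 4,
        ReturnSequences = true,
        InputDim = 8,     // previously Kwargs["input_dim"]
        InputLength = 10, // previously Kwargs["input_length"]
    };
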
@@ -5,7 +5,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
 {
     public class StackedRNNCellsArgs : LayerArgs
     {
-        public IList<IRnnCell> Cells { get; set; }
-        public Dictionary<string, object> Kwargs { get; set; } = null;
+        public bool ReverseStateOrder = false;
     }
 }
@@ -182,7 +182,7 @@ namespace Tensorflow.Keras.Layers
             bool unit_forget_bias = true,
             float dropout = 0f,
             float recurrent_dropout = 0f,
-            int implementation = 2,
+            int implementation = 1,
             bool return_sequences = false,
             bool return_state = false,
             bool go_backwards = false,
@@ -89,7 +89,7 @@ namespace Tensorflow
             gate_inputs = nn_ops.bias_add(gate_inputs, _bias);

             // i = input_gate, j = new_input, f = forget_gate, o = output_gate
-            var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
+            var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
             var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

             var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);
@@ -389,9 +389,13 @@ namespace Tensorflow
                 case "list(type)":
                     attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
                     break;
+                case "list(float)":
+                    if (value != null)
+                        attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray());
+                    break;
                 case "list(int)":
                     if (value != null)
-                        attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
+                        attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x)));
                     break;
                 case "bool":
                     attr_value.B = (bool)value;
@@ -428,6 +432,9 @@ namespace Tensorflow
                 case "list(func)":
                     attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name));
                     break;
+                case "list(string)":
+                    attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x)));
+                    break;
                 default:
                     throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
             }
@@ -390,7 +390,8 @@ namespace Tensorflow.Operations
             int ta_size;
             if(!_dynamic_size && (_size is not null))
             {
-                ta_size = (int)tensor_util.constant_value(_size);
+                var size_tensor = tensor_util.constant_value(_size);
+                ta_size = size_tensor is null ? -1 : (int)size_tensor;
             }
             else
             {
@@ -1014,38 +1014,27 @@ namespace Tensorflow
             });
         }

-        public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
+        public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis = null,
             string name = "split")
         {
-            if (num == -1)
-                num = (int)size_splits.shape[0];
-
-            return gen_array_ops.split_v(value, size_splits, tf.convert_to_tensor(axis), num, name: name);
+            return gen_array_ops.split(split_dim: axis, value: value, num_split: num_or_size_splits, name);
         }

-        public static Tensor[] split<T>(Tensor value, int num_split, T axis,
+        public static Tensor[] split(Tensor value, int[] num_or_size_splits, Tensor axis = null, int num = -1,
             string name = "split")
         {
-            var size_splits = ops.convert_to_tensor(num_split);
-            if (tf.Context.executing_eagerly())
+            if(num_or_size_splits.Length == 0)
             {
-                return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.Context);
+                throw new ValueError("Rank-0 tensors are not supported as the num_or_size_splits argument to split.");
             }
+            var size_splits = ops.convert_to_tensor(num_or_size_splits);

-            var _op = tf.OpDefLib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
-            return _op.outputs;
-        }
-
-        private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null)
-        {
-            var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { value });
-            var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32);
-            var _inputs_flat = new List<Tensor> { axis_tensor };
-            _inputs_flat.AddRange(input);
-            var _attrs = new object[] { "num_split", num_split, "T", _attr_T };
-
-            return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
+            if(num == -1)
+            {
+                num = (int)size_splits.shape[0];
+            }
+
+            return gen_array_ops.split_v(value: value, size_splits: size_splits, split_dim: axis, num_split: num, name: name);
         }

         public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)
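
The two overloads now map one-to-one onto the underlying ops: an `int` count lowers to `Split` (even division, so the axis dimension must be divisible by the count) and an `int[]` of sizes lowers to `SplitV`, with `num` inferred from the size vector when left at -1. A usage sketch of both paths:

    var x = tf.ones((6, 4));
    var even = array_ops.split(x, 3, ops.convert_to_tensor(0));                  // Split: three (2, 4) pieces
    var sized = array_ops.split(x, new[] { 1, 2, 3 }, ops.convert_to_tensor(0)); // SplitV: (1,4), (2,4), (3,4)
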
@@ -1778,10 +1778,10 @@ new_height, new_width");
             {
                 // a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3]
                 var a_xy_minmax = array_ops.split(
-                    value: boxes_a, num_split: 4, axis: 2);
+                    value: boxes_a, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
                 // b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3]
                 var b_xy_minmax = array_ops.split(
-                    value: boxes_b, num_split: 4, axis: 2);
+                    value: boxes_b, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));

                 var i_xmin = math_ops.maximum(
                     a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 }));
@@ -1943,7 +1943,7 @@ new_height, new_width");
             using (ops.name_scope("canonicalize_coordinates"))
             {
                 // y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3]
-                var yx = array_ops.split(value: boxes, num_split: 4, axis: 2);
+                var yx = array_ops.split(value: boxes, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
                 var y_1_is_min = math_ops.reduce_all(
                     gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0]));
                 var y_minmax = control_flow_ops.cond(
@@ -86,7 +86,7 @@ namespace Tensorflow.Operations
                 }
             }

-            var cond_graph = FuncGraph.func_graph_from_func("cond", wrapped_cond, null,
+            var cond_graph = FuncGraph.func_graph_from_func(cond_name, wrapped_cond, null,
                 null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies);

             bool stateful_parallelism = false;
@@ -111,7 +111,7 @@ namespace Tensorflow.Operations
                 return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray();
             }

-            var body_graph = FuncGraph.func_graph_from_func("body", wrapped_body, null, null, func_graph_signature,
+            var body_graph = FuncGraph.func_graph_from_func(body_name, wrapped_body, null, null, func_graph_signature,
                 add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism);

             // TODO(Rinne): possible wrong implementation here.
@@ -170,11 +170,28 @@ namespace Tensorflow
         public Tensor value()
             => GraphElement ?? _read_variable_op();

-        protected Tensor _read_variable_op()
+        protected Tensor _read_variable_op(bool no_copy = false)
         {
             variable_accessed(this);
-            var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
-            resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);
+
+            Tensor read_and_set_handle(bool no_copy)
+            {
+                if (no_copy)
+                {
+                    gen_resource_variable_ops.disable_copy_on_read(handle);
+                }
+                var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
+                resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);
+                return result;
+            }
+
+            // TODO(Rinne): deal with caching device.
+            var result = read_and_set_handle(no_copy);
+            if (!tf.Context.executing_eagerly())
+            {
+                tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle },
+                    backward_function: (x, _) => x);
+            }

             // have to set shape when converting to substituent placeholder
             if (result.shape.ndim == -1)
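
Reading a resource variable can now disable copy-on-read first, and in graph mode the read records a `ReadVariableOp` on the active tapes with an identity backward function, so gradients reach the variable handle. A hedged end-to-end sketch of the behavior this supports (assuming TF.NET's `tf.Variable`/`tf.GradientTape` surface; the `gradient` overload for a variable source is assumed):

    var v = tf.Variable(3.0f);
    using var tape = tf.GradientTape();
    var y = v.AsTensor() * v.AsTensor(); // each read goes through _read_variable_op
    var dy = tape.gradient(y, v);        // d(v^2)/dv = 2v = 6
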
@@ -38,6 +38,8 @@ namespace Tensorflow.Keras.Engine
                 _handle_activity_regularization(inputs, outputs);
                 _set_mask_metadata(inputs, outputs, null);

+                // TODO(Rinne): set save spec if null
+
                 scope.__exit__();

                 return outputs;
@@ -709,10 +709,7 @@ namespace Tensorflow.Keras.Layers
         public IRnnCell StackedRNNCells(
             IEnumerable<IRnnCell> cells)
-            => new StackedRNNCells(new StackedRNNCellsArgs
-            {
-                Cells = cells.ToList()
-            });
+            => new StackedRNNCells(cells.ToList(), new StackedRNNCellsArgs());

         /// <summary>
         ///
@@ -757,9 +754,8 @@ namespace Tensorflow.Keras.Layers
             bool stateful = false,
             bool unroll = false,
             bool time_major = false)
-            => new RNN(new RNNArgs
+            => new RNN(cell, new RNNArgs
             {
-                Cell = cell,
                 ReturnSequences = return_sequences,
                 ReturnState = return_state,
                 GoBackwards = go_backwards,
@@ -776,9 +772,8 @@ namespace Tensorflow.Keras.Layers
             bool stateful = false,
             bool unroll = false,
             bool time_major = false)
-            => new RNN(new RNNArgs
+            => new RNN(cell, new RNNArgs
             {
-                Cells = cell.ToList(),
                 ReturnSequences = return_sequences,
                 ReturnState = return_state,
                 GoBackwards = go_backwards,
@@ -798,7 +793,7 @@ namespace Tensorflow.Keras.Layers
             bool unit_forget_bias = true,
             float dropout = 0f,
             float recurrent_dropout = 0f,
-            int implementation = 2)
+            int implementation = 1)
             => new LSTMCell(new LSTMCellArgs
             {
                 Units = uints,
@@ -851,7 +846,7 @@ namespace Tensorflow.Keras.Layers
             bool unit_forget_bias = true,
             float dropout = 0f,
             float recurrent_dropout = 0f,
-            int implementation = 2,
+            int implementation = 1,
             bool return_sequences = false,
             bool return_state = false,
             bool go_backwards = false,
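
With the cell now a constructor argument, the factories read `new RNN(cell, args)` / `new RNN(cells, args)` and `new StackedRNNCells(cells, args)`. Typical construction through the public API (sketch, assuming `using static Tensorflow.KerasApi;`):

    var cell = keras.layers.SimpleRNNCell(4);
    var rnn = keras.layers.RNN(cell, return_sequences: true);

    var stacked = keras.layers.StackedRNNCells(new IRnnCell[]
    {
        keras.layers.SimpleRNNCell(4),
        keras.layers.SimpleRNNCell(4),
    });
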
@@ -2,6 +2,7 @@
 using Tensorflow.Keras.ArgsDefinition.Rnn;
 using Tensorflow.Keras.Engine;
 using Tensorflow.Common.Types;
+using Tensorflow.Common.Extensions;

 namespace Tensorflow.Keras.Layers.Rnn
 {
@@ -14,22 +15,105 @@ namespace Tensorflow.Keras.Layers.Rnn
     public class LSTM : RNN
     {
         LSTMArgs args;
-        InputSpec[] state_spec;
-        int units => args.Units;
+        InputSpec[] _state_spec;
+        InputSpec _input_spec;
+        bool _could_use_gpu_kernel;

         public LSTM(LSTMArgs args) :
-            base(args)
+            base(CreateCell(args), args)
         {
             this.args = args;
-            state_spec = new[] { units, units }
-                .Select(dim => new InputSpec(shape: (-1, dim)))
-                .ToArray();
+            _input_spec = new InputSpec(ndim: 3);
+            _state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray();
+            _could_use_gpu_kernel = args.Activation == keras.activations.Tanh
+                && args.RecurrentActivation == keras.activations.Sigmoid
+                && args.RecurrentDropout == 0 && !args.Unroll && args.UseBias
+                && ops.executing_eagerly_outside_functions();
+        }
+
+        private static IRnnCell CreateCell(LSTMArgs lstmArgs)
+        {
+            return new LSTMCell(new LSTMCellArgs()
+            {
+                Units = lstmArgs.Units,
+                Activation = lstmArgs.Activation,
+                RecurrentActivation = lstmArgs.RecurrentActivation,
+                UseBias = lstmArgs.UseBias,
+                KernelInitializer = lstmArgs.KernelInitializer,
+                RecurrentInitializer = lstmArgs.RecurrentInitializer,
+                UnitForgetBias = lstmArgs.UnitForgetBias,
+                BiasInitializer = lstmArgs.BiasInitializer,
+                // TODO(Rinne): kernel_regularizer
+                // TODO(Rinne): recurrent_regularizer
+                // TODO(Rinne): bias_regularizer
+                // TODO(Rinne): kernel_constriant
+                // TODO(Rinne): recurrent_constriant
+                // TODO(Rinne): bias_constriant
+                Dropout = lstmArgs.Dropout,
+                RecurrentDropout = lstmArgs.RecurrentDropout,
+                Implementation = lstmArgs.Implementation,
+                DType = lstmArgs.DType,
+                Trainable = lstmArgs.Trainable
+            });
         }

-        protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
+        protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
         {
-            return base.Call(inputs, initial_state: state, training: training);
+            // skip the condition of ragged input
+            (inputs, initial_state, _) = _process_inputs(inputs, initial_state, null);
+
+            Tensor mask = null;
+            if(optional_args is RnnOptionalArgs rnnArgs)
+            {
+                mask = rnnArgs.Mask;
+            }
+
+            var single_input = inputs.Single;
+            var input_shape = single_input.shape;
+            var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];
+
+            _maybe_reset_cell_dropout_mask(Cell);
+
+            Func<Tensors, Tensors, (Tensors, Tensors)> step = (inputs, states) =>
+            {
+                var res = Cell.Apply(inputs, states, training is null ? true : training.Value);
+                var (output, state) = res;
+                return (output, state);
+            };
+
+            var (last_output, outputs, states) = keras.backend.rnn(
+                step,
+                inputs,
+                initial_state,
+                constants: null,
+                go_backwards: args.GoBackwards,
+                mask: mask,
+                unroll: args.Unroll,
+                input_length: ops.convert_to_tensor(timesteps),
+                time_major: args.TimeMajor,
+                zero_output_for_mask: args.ZeroOutputForMask,
+                return_all_outputs: args.ReturnSequences
+            );
+
+            Tensor output;
+            if (args.ReturnSequences)
+            {
+                output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards);
+            }
+            else
+            {
+                output = last_output;
+            }
+
+            if (args.ReturnState)
+            {
+                return new Tensor[] { output }.Concat(states).ToArray().ToTensors();
+            }
+            else
+            {
+                return output;
+            }
         }
     }
 }
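
`LSTM` no longer defers wholesale to `RNN.Call`: it builds its own `LSTMCell` up front, normalizes inputs and state, resets cached dropout masks, then drives `keras.backend.rnn` with a per-timestep step function, finally picking the full sequence or the last output and optionally appending the states. Exercising the layer end to end (sketch mirroring the unit test added near the bottom of this diff):

    var inputs = tf.ones((2, 10, 8));                       // (batch, time, features)
    var lstm = keras.layers.LSTM(4, return_sequences: true);
    var outputs = lstm.Apply(inputs);
    // outputs.shape == (2, 10, 4); with return_state: true the two states follow
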
@@ -1,5 +1,6 @@
 using Serilog.Core;
 using System.Diagnostics;
+using Tensorflow.Common.Extensions;
 using Tensorflow.Common.Types;
 using Tensorflow.Keras.ArgsDefinition.Rnn;
 using Tensorflow.Keras.Engine;
@@ -81,7 +82,7 @@ namespace Tensorflow.Keras.Layers.Rnn
                     _bias_initializer = _args.BiasInitializer;
                 }
                 _bias = add_weight("bias", (_args.Units * 4),
-                    initializer: _args.BiasInitializer);
+                    initializer: _bias_initializer);
             }
             built = true;
         }
@@ -94,7 +95,6 @@ namespace Tensorflow.Keras.Layers.Rnn
             var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
                 h_tm1, training.Value, count: 4);
-
             Tensor c;
             Tensor o;
             if (_args.Implementation == 1)
@@ -123,7 +123,7 @@ namespace Tensorflow.Keras.Layers.Rnn
                 var x_f = math_ops.matmul(inputs_f, k_f);
                 var x_c = math_ops.matmul(inputs_c, k_c);
                 var x_o = math_ops.matmul(inputs_o, k_o);
-                if(_args.UseBias)
+                if (_args.UseBias)
                 {
                     var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
                     Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
@@ -170,7 +170,7 @@ namespace Tensorflow.Keras.Layers.Rnn
             }
             var h = o * _args.Activation.Apply(c);
             // note: the Tensors constructor packs every element after the first into a single array
-            return new Tensors(h, h, c);
+            return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new NestList<Tensor>(h, c) }).ToTensors();
         }

         /// <summary>
@@ -188,22 +188,21 @@ namespace Tensorflow.Keras.Layers.Rnn
                 h_tm1_o = h_tm1[3];

                 var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
-                var startIndex = _recurrent_kernel_tensor.shape[0];
-                var endIndex = _recurrent_kernel_tensor.shape[1];
+                int startIndex = (int)_recurrent_kernel_tensor.shape[0];
                 var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                     new[] { 0, 0 }, new[] { startIndex, _args.Units });
                 var i = _args.RecurrentActivation.Apply(
                     x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
                 _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
-                    new[] { 0, _args.Units }, new[] { startIndex, _args.Units * 2 });
+                    new[] { 0, _args.Units }, new[] { startIndex, _args.Units });
                 var f = _args.RecurrentActivation.Apply(
                     x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
                 _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
-                    new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units * 3 });
+                    new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units });
                 var c = f * c_tm1 + i * _args.Activation.Apply(
                     x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
                 _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
-                    new[] { 0, _args.Units * 3 }, new[] { startIndex, endIndex });
+                    new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
                 var o = _args.RecurrentActivation.Apply(
                     x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));
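
The substance of this fix: `tf.slice` takes begin indices and sizes, not end indices. The recurrent kernel is (units, units * 4) and each gate needs a (units, units) slab, so every slice size must be `_args.Units`; the old code passed `Units * 2`, `Units * 3`, and the kernel's full second dimension as sizes, reading past the intended gate block. Concretely, for units = 2 (sketch):

    var kernel = tf.ones((2, 8));          // (units, units * 4): gate blocks i|f|c|o
    var f_slab = tf.slice(kernel,
        new[] { 0, 2 },                    // begin at column units * 1 (the f block)
        new[] { 2, 2 });                   // size = (units, units), NOT an end index
    // f_slab.shape == (2, 2)
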
@@ -45,23 +45,25 @@ namespace Tensorflow.Keras.Layers.Rnn
             }
         }

-        public RNN(RNNArgs args) : base(PreConstruct(args))
+        public RNN(IRnnCell cell, RNNArgs args) : base(PreConstruct(args))
         {
             _args = args;
             SupportsMasking = true;

-            // if is StackedRnncell
-            if (args.Cells != null)
-            {
-                Cell = new StackedRNNCells(new StackedRNNCellsArgs
-                {
-                    Cells = args.Cells
-                });
-            }
-            else
-            {
-                Cell = args.Cell;
-            }
+            Cell = cell;
+
+            // get input_shape
+            _args = PreConstruct(args);
+
+            _num_constants = 0;
+        }
+
+        public RNN(IEnumerable<IRnnCell> cells, RNNArgs args) : base(PreConstruct(args))
+        {
+            _args = args;
+            SupportsMasking = true;
+
+            Cell = new StackedRNNCells(cells, new StackedRNNCellsArgs());

             // get input_shape
             _args = PreConstruct(args);
@@ -330,7 +332,7 @@ namespace Tensorflow.Keras.Layers.Rnn
                     states = new Tensors(states.SkipLast(_num_constants).ToArray());
                     states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
                     var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
-                    return (output, new_states.Single);
+                    return (output, new_states);
                 };
             }
             else
@@ -382,6 +384,11 @@ namespace Tensorflow.Keras.Layers.Rnn
             }
             else
             {
+                //var tapeSet = tf.GetTapeSet();
+                //foreach(var tape in tapeSet)
+                //{
+                //    tape.Watch(output);
+                //}
                 return output;
             }
         }
@@ -405,7 +412,7 @@ namespace Tensorflow.Keras.Layers.Rnn
             throw new NotImplementedException();
         }

-        private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
+        protected (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
         {
             if (inputs.Length > 1)
             {
@@ -484,7 +491,7 @@ namespace Tensorflow.Keras.Layers.Rnn
             }
         }

-        void _maybe_reset_cell_dropout_mask(ILayer cell)
+        protected void _maybe_reset_cell_dropout_mask(ILayer cell)
         {
             if (cell is DropoutRNNCellMixin CellDRCMixin)
             {
@@ -495,26 +502,21 @@ namespace Tensorflow.Keras.Layers.Rnn
         private static RNNArgs PreConstruct(RNNArgs args)
         {
-            if (args.Kwargs == null)
-            {
-                args.Kwargs = new Dictionary<string, object>();
-            }
-
             // If true, the output for masked timestep will be zeros, whereas in the
             // false case, output from previous timestep is returned for masked timestep.
-            var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
+            var zeroOutputForMask = args.ZeroOutputForMask;

             Shape input_shape;
-            var propIS = (Shape)args.Kwargs.Get("input_shape", null);
-            var propID = (int?)args.Kwargs.Get("input_dim", null);
-            var propIL = (int?)args.Kwargs.Get("input_length", null);
+            var propIS = args.InputShape;
+            var propID = args.InputDim;
+            var propIL = args.InputLength;

             if (propIS == null && (propID != null || propIL != null))
             {
                 input_shape = new Shape(
                     propIL ?? -1,
                     propID ?? -1);
-                args.Kwargs["input_shape"] = input_shape;
+                args.InputShape = input_shape;
             }

             return args;
@@ -10,14 +10,14 @@ namespace Tensorflow.Keras.Layers.Rnn
     public class SimpleRNN : RNN
     {
         SimpleRNNArgs args;
-        public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args))
+        public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args), args)
         {
             this.args = args;
         }

-        private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args)
+        private static SimpleRNNCell CreateCellForArgs(SimpleRNNArgs args)
         {
-            args.Cell = new SimpleRNNCell(new SimpleRNNCellArgs()
+            return new SimpleRNNCell(new SimpleRNNCellArgs()
             {
                 Units = args.Units,
                 Activation = args.Activation,
@@ -30,7 +30,6 @@ namespace Tensorflow.Keras.Layers.Rnn
                 DType = args.DType,
                 Trainable = args.Trainable,
             });
-            return args;
         }
     }
 }
@@ -115,10 +115,5 @@ namespace Tensorflow.Keras.Layers.Rnn
                 return new Tensors(output, output);
             }
         }
-
-        public Tensors get_initial_state(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
-        {
-            return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
-        }
     }
 }
@@ -15,15 +15,11 @@ namespace Tensorflow.Keras.Layers.Rnn
         public IList<IRnnCell> Cells { get; set; }
         public bool _reverse_state_order;

-        public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
+        public StackedRNNCells(IEnumerable<IRnnCell> cells, StackedRNNCellsArgs args) : base(args)
         {
-            if (args.Kwargs == null)
-            {
-                args.Kwargs = new Dictionary<string, object>();
-            }
-            Cells = args.Cells;
-
-            _reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false);
+            Cells = cells.ToList();
+
+            _reverse_state_order = args.ReverseStateOrder;

             if (_reverse_state_order)
             {
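
Cells arrive through the constructor, and `reverse_state_order` is read from the typed `ReverseStateOrder` flag rather than a Kwargs lookup. Direct construction now looks like (sketch, assuming the usual `System.Collections.Generic` and Keras usings):

    var cells = new List<IRnnCell>
    {
        keras.layers.SimpleRNNCell(4),
        keras.layers.SimpleRNNCell(8),
    };
    var stacked = new StackedRNNCells(cells, new StackedRNNCellsArgs { ReverseStateOrder = false });
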
@@ -55,30 +55,56 @@ namespace Tensorflow.Keras.UnitTest.Layers
             Assert.AreEqual((2, 4), new_states[0].shape);
         }

+        [TestMethod]
+        public void TrainLSTMWithMnist()
+        {
+            var input = keras.Input((784));
+            var x = keras.layers.Reshape((28, 28)).Apply(input);
+            //x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
+            //x = keras.layers.LSTM(100, return_sequences: true).Apply(x);
+            //x = keras.layers.LSTM(150, return_sequences: true).Apply(x);
+            x = keras.layers.LSTM(4, implementation: 2).Apply(x);
+            //x = keras.layers.Dense(100).Apply(x);
+            var output = keras.layers.Dense(10, activation: "softmax").Apply(x);
+
+            var model = keras.Model(input, output);
+            model.summary();
+            model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
+
+            var data_loader = new MnistModelLoader();
+            var dataset = data_loader.LoadAsync(new ModelLoadSetting
+            {
+                TrainDir = "mnist",
+                OneHot = false,
+                ValidationSize = 58000,
+            }).Result;
+
+            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 30);
+        }
+
         [TestMethod]
         public void SimpleRNN()
         {
-            //var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
-            ///*var simple_rnn = keras.layers.SimpleRNN(4);
-            //var output = simple_rnn.Apply(inputs);
-            //Assert.AreEqual((32, 4), output.shape);*/
-            //var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
-            //var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
-            //Assert.AreEqual((6, 10, 4), whole_sequence_output.shape);
-            //Assert.AreEqual((6, 4), final_state.shape);
+            var input = keras.Input((784));
+            var x = keras.layers.Reshape((28, 28)).Apply(input);
+            x = keras.layers.SimpleRNN(10).Apply(x);
+            var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

-            var inputs = keras.Input(shape: (10, 8));
-            var x = keras.layers.SimpleRNN(4).Apply(inputs);
-            var output = keras.layers.Dense(10).Apply(x);
-            var model = keras.Model(inputs, output);
+            var model = keras.Model(input, output);
             model.summary();
-            model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });
-            var datax = np.ones((16, 10, 8), dtype: dtypes.float32);
-            var datay = np.ones((16));
-            model.fit(datax, datay, epochs: 20);
+            model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy());
+
+            var data_loader = new MnistModelLoader();
+            var dataset = data_loader.LoadAsync(new ModelLoadSetting
+            {
+                TrainDir = "mnist",
+                OneHot = false,
+                ValidationSize = 58000,
+            }).Result;
+
+            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 10);
         }

         [TestMethod]
         public void RNNForSimpleRNNCell()
         {
@@ -109,19 +135,5 @@ namespace Tensorflow.Keras.UnitTest.Layers
             Console.WriteLine($"output: {output}");
             Assert.AreEqual((5, 4), output.shape);
         }
-
-        [TestMethod]
-        public void MyTest()
-        {
-            var a = tf.zeros((2, 3));
-            var b = tf.ones_like(a);
-            var c = tf.ones((3,4));
-            var d = new Tensors { a, b, c };
-            var (A, BC) = d;
-            Console.WriteLine($"A:{A}");
-            Console.WriteLine($"BC:{BC}");
-        }
     }
 }
@@ -9,7 +9,7 @@ namespace Tensorflow.CodeGen
 {
     public class OpClassifier
     {
-        private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$";
+        private static readonly string _filenamePattern = @"^gen_[a-z_]*_ops.py$";
         private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):";
         private Dictionary<string, HashSet<string>> _opSet = new();
        public Dictionary<string, HashSet<string>> OpSet => _opSet;
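
The old pattern `^gen_[a-z]*_ops.py$` cannot match generated files whose middle segment itself contains underscores; adding `_` to the character class fixes that (the unescaped `.` still matches any character, which is harmless here). A quick check with plain .NET regex:

    using System;
    using System.Text.RegularExpressions;

    var pattern = @"^gen_[a-z_]*_ops.py$";
    Console.WriteLine(Regex.IsMatch("gen_math_ops.py", pattern));         // True (old pattern matched too)
    Console.WriteLine(Regex.IsMatch("gen_ragged_array_ops.py", pattern)); // True (old pattern missed this)
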
@@ -178,10 +178,25 @@ namespace Tensorflow.CodeGen
             else if (attr.Type == "list(shape)")
             {
                 res.Add((attr.Name, "Shape[]", "NOVALUE"));
+                if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
+                {
+                    List<string> exps = new();
+                    foreach (var value in attr.DefaultValue.List.Shape)
+                    {
+                        exps.Add($"new Shape({string.Join(", ", value.Dim.Select(x => x.Size))})");
+                    }
+                    string expression = "new Shape[]{" + $"{string.Join(", ", exps)}" + "}";
+                    dynamicDefaultValues[attr.Name] = expression;
+                    res.Add((attr.Name, "string[]", $"null"));
+                }
+                else
+                {
+                    res.Add((attr.Name, "string[]", "NOVALUE"));
+                }
             }
             else if (attr.Type == "list(string)")
             {
-                if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
+                if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
                 {
                     List<string> values = new();
                     foreach (var value in attr.DefaultValue.List.S)