Browse Source

fix: none gradient error when training LSTM.

tags/v0.110.0-LSTM-Model
Yaohui Liu 2 years ago
parent
commit
675b93a9d7
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
29 changed files with 1743 additions and 295 deletions
  1. +3
    -3
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  2. +2
    -16
      src/TensorFlowNET.Core/Common/Types/Nest.cs
  3. +5
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  4. +6
    -2
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  5. +2
    -3
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  6. +0
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  8. +3
    -23
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  9. +1
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  12. +8
    -1
      src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
  13. +2
    -1
      src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
  14. +11
    -22
      src/TensorFlowNET.Core/Operations/array_ops.cs
  15. +1469
    -104
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  16. +3
    -3
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  17. +2
    -2
      src/TensorFlowNET.Core/Operations/while_v2.cs
  18. +20
    -3
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  19. +2
    -0
      src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  20. +5
    -10
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  21. +93
    -9
      src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
  22. +8
    -9
      src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
  23. +28
    -26
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  24. +3
    -4
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  25. +0
    -5
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  26. +4
    -8
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  27. +43
    -31
      test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
  28. +1
    -1
      tools/Tensorflow.CodeGen/OpClassifier.cs
  29. +16
    -1
      tools/Tensorflow.CodeGen/Utils.cs

+ 3
- 3
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -71,15 +71,15 @@ namespace Tensorflow
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
=> array_ops.split( => array_ops.split(
value: value, value: value,
num_split: num_split,
num_or_size_splits: num_split,
axis: axis, axis: axis,
name: name); name: name);


public Tensor[] split(Tensor value, int num_split, int axis, string name = null) public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
=> array_ops.split( => array_ops.split(
value: value, value: value,
num_split: num_split,
axis: axis,
num_or_size_splits: num_split,
axis: ops.convert_to_tensor(axis),
name: name); name: name);


public Tensor ensure_shape(Tensor x, Shape shape, string name = null) public Tensor ensure_shape(Tensor x, Shape shape, string name = null)


+ 2
- 16
src/TensorFlowNET.Core/Common/Types/Nest.cs View File

@@ -197,25 +197,11 @@ namespace Tensorflow.Common.Types
} }
else if(NestType is NestType.List) else if(NestType is NestType.List)
{ {
foreach(var item in ListValue!)
{
if(item.NestType is NestType.List or NestType.Dictionary)
{
return true;
}
}
return false;
return ListValue!.Count > 0;
} }
else else
{ {
foreach (var item in DictValue!.Values)
{
if (item.NestType is NestType.List or NestType.Dictionary)
{
return true;
}
}
return false;
return DictValue!.Count > 0;
} }
} }




+ 5
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -352,7 +352,11 @@ namespace Tensorflow.Eager
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value)); c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
break; break;
case TF_AttrType.TF_ATTR_SHAPE: case TF_AttrType.TF_ATTR_SHAPE:
var dims = (value as long[]).ToArray();
long[] dims;
if (value is Shape shape) dims = shape.dims.ToArray();
else if (value is long[] longs) dims = longs.ToArray();
else if (value is int[] ints) dims = ints.Select(x => (long)x).ToArray();
else dims = ((long[])value).ToArray();
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status); c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status);
status.Check(true); status.Check(true);
break; break;


+ 6
- 2
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs View File

@@ -137,7 +137,6 @@ namespace Tensorflow.Eager
{ {
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status); dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
} }
Shape tensor_shape = new(dims);


if(status.Code != TF_Code.TF_OK) if(status.Code != TF_Code.TF_OK)
{ {
@@ -145,6 +144,7 @@ namespace Tensorflow.Eager
} }
else else
{ {
Shape tensor_shape = new(dims);
return new TapeTensor(id, dtype, tensor_shape); return new TapeTensor(id, dtype, tensor_shape);
} }
} }
@@ -173,8 +173,12 @@ namespace Tensorflow.Eager
return dtype == dtypes.variant || dtype == dtypes.resource; return dtype == dtypes.variant || dtype == dtypes.resource;
} }


bool ListContainNone(long[] list)
bool ListContainNone(long[]? list)
{ {
if(list is null)
{
return true;
}
int len = list.Length; int len = list.Length;
if(len == 0) if(len == 0)
{ {


+ 2
- 3
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -90,8 +90,7 @@ namespace Tensorflow.Gradients
? input_values[0].rank + dim_int ? input_values[0].rank + dim_int
: dim_int % input_values[0].rank; : dim_int % input_values[0].rank;
var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray(); var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
var sizes_tensor = constant_op.constant(sizes);
out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
out_grads = array_ops.split(grad, sizes.Select(x => (int)x).ToArray(), ops.convert_to_tensor(non_neg_concat_dim)).ToList();
} }
else if (constant_op.is_constant(concat_dim)) else if (constant_op.is_constant(concat_dim))
{ {
@@ -127,7 +126,7 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) }, new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) }); new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice); var squeeze_sizes = array_ops.squeeze(slice);
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_or_size_splits: (int)non_neg_concat_dim).ToList();
} }
else else
{ {


+ 0
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs View File

@@ -4,8 +4,6 @@
{ {
// TODO: maybe change the `RNNArgs` and implement this class. // TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; } public bool UnitForgetBias { get; set; }
public float Dropout { get; set; }
public float RecurrentDropout { get; set; }
public int Implementation { get; set; } public int Implementation { get; set; }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
[JsonProperty("unit_forget_bias")] [JsonProperty("unit_forget_bias")]
public bool UnitForgetBias { get; set; } = true; public bool UnitForgetBias { get; set; } = true;
[JsonProperty("implementation")] [JsonProperty("implementation")]
public int Implementation { get; set; } = 2;
public int Implementation { get; set; } = 1;


} }
} }

+ 3
- 23
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -7,12 +7,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
// TODO(Rinne): add regularizers. // TODO(Rinne): add regularizers.
public class RNNArgs : AutoSerializeLayerArgs public class RNNArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("cell")]
// TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnCell Cell { get; set; } = null;
[JsonProperty("cells")]
public IList<IRnnCell> Cells { get; set; } = null;

[JsonProperty("return_sequences")] [JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false; public bool ReturnSequences { get; set; } = false;
[JsonProperty("return_state")] [JsonProperty("return_state")]
@@ -25,8 +19,10 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public bool Unroll { get; set; } = false; public bool Unroll { get; set; } = false;
[JsonProperty("time_major")] [JsonProperty("time_major")]
public bool TimeMajor { get; set; } = false; public bool TimeMajor { get; set; } = false;

public int? InputDim { get; set; }
public int? InputLength { get; set; }
// TODO: Add `num_constants` and `zero_output_for_mask`. // TODO: Add `num_constants` and `zero_output_for_mask`.
public Dictionary<string, object> Kwargs { get; set; } = null;


public int Units { get; set; } public int Units { get; set; }
public Activation Activation { get; set; } public Activation Activation { get; set; }
@@ -38,21 +34,5 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public float Dropout { get; set; } = .0f; public float Dropout { get; set; } = .0f;
public bool ZeroOutputForMask { get; set; } = false; public bool ZeroOutputForMask { get; set; } = false;
public float RecurrentDropout { get; set; } = .0f; public float RecurrentDropout { get; set; } = .0f;

// kernel_regularizer=None,
// recurrent_regularizer=None,
// bias_regularizer=None,
// activity_regularizer=None,
// kernel_constraint=None,
// recurrent_constraint=None,
// bias_constraint=None,
// dropout=0.,
// recurrent_dropout=0.,
// return_sequences=False,
// return_state=False,
// go_backwards=False,
// stateful=False,
// unroll=False,
// **kwargs):
} }
} }

+ 1
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs View File

@@ -5,7 +5,6 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
{ {
public class StackedRNNCellsArgs : LayerArgs public class StackedRNNCellsArgs : LayerArgs
{ {
public IList<IRnnCell> Cells { get; set; }
public Dictionary<string, object> Kwargs { get; set; } = null;
public bool ReverseStateOrder = false;
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -182,7 +182,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true, bool unit_forget_bias = true,
float dropout = 0f, float dropout = 0f,
float recurrent_dropout = 0f, float recurrent_dropout = 0f,
int implementation = 2,
int implementation = 1,
bool return_sequences = false, bool return_sequences = false,
bool return_state = false, bool return_state = false,
bool go_backwards = false, bool go_backwards = false,


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -89,7 +89,7 @@ namespace Tensorflow
gate_inputs = nn_ops.bias_add(gate_inputs, _bias); gate_inputs = nn_ops.bias_add(gate_inputs, _bias);


// i = input_gate, j = new_input, f = forget_gate, o = output_gate // i = input_gate, j = new_input, f = forget_gate, o = output_gate
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);


var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);


+ 8
- 1
src/TensorFlowNET.Core/Operations/OpDefLibrary.cs View File

@@ -389,9 +389,13 @@ namespace Tensorflow
case "list(type)": case "list(type)":
attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def))); attr_value.List.Type.AddRange((value as IList<TF_DataType>).Select(x => _MakeType(x, attr_def)));
break; break;
case "list(float)":
if (value != null)
attr_value.List.F.AddRange((value as IEnumerable<float>).ToArray());
break;
case "list(int)": case "list(int)":
if (value != null) if (value != null)
attr_value.List.I.AddRange((value as int[]).Select(x => Convert.ToInt64(x)));
attr_value.List.I.AddRange((value as IEnumerable<int>).Select(x => Convert.ToInt64(x)));
break; break;
case "bool": case "bool":
attr_value.B = (bool)value; attr_value.B = (bool)value;
@@ -428,6 +432,9 @@ namespace Tensorflow
case "list(func)": case "list(func)":
attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name)); attr_value.List.Func.AddRange(_MakeFuncList(value, attr_def.Name));
break; break;
case "list(string)":
attr_value.List.S.AddRange((value as IEnumerable<string>).Select(x => ByteString.CopyFromUtf8(x)));
break;
default: default:
throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos."); throw new TypeError($"SetAttrValue: can't not convert attr_def.Type '{attr_def.Type}' to protos.");
} }


+ 2
- 1
src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs View File

@@ -390,7 +390,8 @@ namespace Tensorflow.Operations
int ta_size; int ta_size;
if(!_dynamic_size && (_size is not null)) if(!_dynamic_size && (_size is not null))
{ {
ta_size = (int)tensor_util.constant_value(_size);
var size_tensor = tensor_util.constant_value(_size);
ta_size = size_tensor is null ? -1 : (int)size_tensor;
} }
else else
{ {


+ 11
- 22
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -1014,38 +1014,27 @@ namespace Tensorflow
}); });
} }


public static Tensor[] split(Tensor value, Tensor size_splits, int axis, int num = -1,
public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis = null,
string name = "split") string name = "split")
{ {
if (num == -1)
num = (int)size_splits.shape[0];

return gen_array_ops.split_v(value, size_splits, tf.convert_to_tensor(axis), num, name: name);
return gen_array_ops.split(split_dim: axis, value: value, num_split: num_or_size_splits, name);
} }


public static Tensor[] split<T>(Tensor value, int num_split, T axis,
public static Tensor[] split(Tensor value, int[] num_or_size_splits, Tensor axis = null, int num = -1,
string name = "split") string name = "split")
{ {
var size_splits = ops.convert_to_tensor(num_split);

if (tf.Context.executing_eagerly())
if(num_or_size_splits.Length == 0)
{ {
return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.Context);
throw new ValueError("Rank-0 tensors are not supported as the num_or_size_splits argument to split.");
} }
var size_splits = ops.convert_to_tensor(num_or_size_splits);


var _op = tf.OpDefLib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
return _op.outputs;
}

private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null)
{
var (_attr_T, input) = tf.Runner.ArgsToMatchingEager(ctx, args: new object[] { value });
var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32);
var _inputs_flat = new List<Tensor> { axis_tensor };
_inputs_flat.AddRange(input);
var _attrs = new object[] { "num_split", num_split, "T", _attr_T };
if(num == -1)
{
num = (int)size_splits.shape[0];
}


return tf.Runner.Execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
return gen_array_ops.split_v(value: value, size_splits: size_splits, split_dim: axis, num_split: num, name: name);
} }


public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)


+ 1469
- 104
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
File diff suppressed because it is too large
View File


+ 3
- 3
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -1778,10 +1778,10 @@ new_height, new_width");
{ {
// a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3] // a_y_min: [0], a_x_min: [1], a_y_max: [2], a_x_max[3]
var a_xy_minmax = array_ops.split( var a_xy_minmax = array_ops.split(
value: boxes_a, num_split: 4, axis: 2);
value: boxes_a, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
// b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3] // b_y_min: [0], b_x_min: [1], b_y_max: [2], b_x_max[3]
var b_xy_minmax = array_ops.split( var b_xy_minmax = array_ops.split(
value: boxes_b, num_split: 4, axis: 2);
value: boxes_b, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));


var i_xmin = math_ops.maximum( var i_xmin = math_ops.maximum(
a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 })); a_xy_minmax[1], array_ops.transpose(b_xy_minmax[1], new[] { 0, 2, 1 }));
@@ -1943,7 +1943,7 @@ new_height, new_width");
using (ops.name_scope("canonicalize_coordinates")) using (ops.name_scope("canonicalize_coordinates"))
{ {
// y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3] // y_1 = [0], x_1 = [1], y_2 = [2], x_2 = [3]
var yx = array_ops.split(value: boxes, num_split: 4, axis: 2);
var yx = array_ops.split(value: boxes, num_or_size_splits: 4, axis: ops.convert_to_tensor(2));
var y_1_is_min = math_ops.reduce_all( var y_1_is_min = math_ops.reduce_all(
gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0])); gen_math_ops.less_equal(yx[0][0, 0, 0], yx[2][0, 0, 0]));
var y_minmax = control_flow_ops.cond( var y_minmax = control_flow_ops.cond(


+ 2
- 2
src/TensorFlowNET.Core/Operations/while_v2.cs View File

@@ -86,7 +86,7 @@ namespace Tensorflow.Operations
} }
} }


var cond_graph = FuncGraph.func_graph_from_func("cond", wrapped_cond, null,
var cond_graph = FuncGraph.func_graph_from_func(cond_name, wrapped_cond, null,
null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies); null, signature: func_graph_signature, add_control_dependencies: add_control_dependencies);


bool stateful_parallelism = false; bool stateful_parallelism = false;
@@ -111,7 +111,7 @@ namespace Tensorflow.Operations
return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray(); return new object[] { loop_counter + 1, maximum_iterations_arg }.Concat(outputs).ToArray();
} }


var body_graph = FuncGraph.func_graph_from_func("body", wrapped_body, null, null, func_graph_signature,
var body_graph = FuncGraph.func_graph_from_func(body_name, wrapped_body, null, null, func_graph_signature,
add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism); add_control_dependencies: add_control_dependencies, acd_record_initial_resource_uses: stateful_parallelism);


// TODO(Rinne): possible wrong implementation here. // TODO(Rinne): possible wrong implementation here.


+ 20
- 3
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -170,11 +170,28 @@ namespace Tensorflow
public Tensor value() public Tensor value()
=> GraphElement ?? _read_variable_op(); => GraphElement ?? _read_variable_op();


protected Tensor _read_variable_op()
protected Tensor _read_variable_op(bool no_copy = false)
{ {
variable_accessed(this); variable_accessed(this);
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);

Tensor read_and_set_handle(bool no_copy)
{
if (no_copy)
{
gen_resource_variable_ops.disable_copy_on_read(handle);
}
var result = gen_resource_variable_ops.read_variable_op(handle, _dtype);
resource_variable_ops._maybe_set_handle_data(_dtype, handle, result);
return result;
}

// TODO(Rinne): deal with caching device.
var result = read_and_set_handle(no_copy);
if (!tf.Context.executing_eagerly())
{
tf.Runner.TFE_TapeSetRecordOperation("ReadVariableOp", new Tensor[] { result }, new Tensor[] { handle },
backward_function: (x, _) => x);
}


// have to set shape when converting to substituent placeholder // have to set shape when converting to substituent placeholder
if (result.shape.ndim == -1) if (result.shape.ndim == -1)


+ 2
- 0
src/TensorFlowNET.Keras/Engine/Layer.Apply.cs View File

@@ -38,6 +38,8 @@ namespace Tensorflow.Keras.Engine
_handle_activity_regularization(inputs, outputs); _handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null); _set_mask_metadata(inputs, outputs, null);


// TODO(Rinne): set save spec if null

scope.__exit__(); scope.__exit__();


return outputs; return outputs;


+ 5
- 10
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -709,10 +709,7 @@ namespace Tensorflow.Keras.Layers


public IRnnCell StackedRNNCells( public IRnnCell StackedRNNCells(
IEnumerable<IRnnCell> cells) IEnumerable<IRnnCell> cells)
=> new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = cells.ToList()
});
=> new StackedRNNCells(cells.ToList(), new StackedRNNCellsArgs());


/// <summary> /// <summary>
/// ///
@@ -757,9 +754,8 @@ namespace Tensorflow.Keras.Layers
bool stateful = false, bool stateful = false,
bool unroll = false, bool unroll = false,
bool time_major = false) bool time_major = false)
=> new RNN(new RNNArgs
=> new RNN(cell, new RNNArgs
{ {
Cell = cell,
ReturnSequences = return_sequences, ReturnSequences = return_sequences,
ReturnState = return_state, ReturnState = return_state,
GoBackwards = go_backwards, GoBackwards = go_backwards,
@@ -776,9 +772,8 @@ namespace Tensorflow.Keras.Layers
bool stateful = false, bool stateful = false,
bool unroll = false, bool unroll = false,
bool time_major = false) bool time_major = false)
=> new RNN(new RNNArgs
=> new RNN(cell, new RNNArgs
{ {
Cells = cell.ToList(),
ReturnSequences = return_sequences, ReturnSequences = return_sequences,
ReturnState = return_state, ReturnState = return_state,
GoBackwards = go_backwards, GoBackwards = go_backwards,
@@ -798,7 +793,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true, bool unit_forget_bias = true,
float dropout = 0f, float dropout = 0f,
float recurrent_dropout = 0f, float recurrent_dropout = 0f,
int implementation = 2)
int implementation = 1)
=> new LSTMCell(new LSTMCellArgs => new LSTMCell(new LSTMCellArgs
{ {
Units = uints, Units = uints,
@@ -851,7 +846,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true, bool unit_forget_bias = true,
float dropout = 0f, float dropout = 0f,
float recurrent_dropout = 0f, float recurrent_dropout = 0f,
int implementation = 2,
int implementation = 1,
bool return_sequences = false, bool return_sequences = false,
bool return_state = false, bool return_state = false,
bool go_backwards = false, bool go_backwards = false,


+ 93
- 9
src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs View File

@@ -2,6 +2,7 @@
using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types; using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;


namespace Tensorflow.Keras.Layers.Rnn namespace Tensorflow.Keras.Layers.Rnn
{ {
@@ -14,22 +15,105 @@ namespace Tensorflow.Keras.Layers.Rnn
public class LSTM : RNN public class LSTM : RNN
{ {
LSTMArgs args; LSTMArgs args;
InputSpec[] state_spec;
int units => args.Units;
InputSpec[] _state_spec;
InputSpec _input_spec;
bool _could_use_gpu_kernel;


public LSTM(LSTMArgs args) : public LSTM(LSTMArgs args) :
base(args)
base(CreateCell(args), args)
{ {
this.args = args; this.args = args;
state_spec = new[] { units, units }
.Select(dim => new InputSpec(shape: (-1, dim)))
.ToArray();
_input_spec = new InputSpec(ndim: 3);
_state_spec = new[] { args.Units, args.Units }.Select(dim => new InputSpec(shape: (-1, dim))).ToArray();
_could_use_gpu_kernel = args.Activation == keras.activations.Tanh
&& args.RecurrentActivation == keras.activations.Sigmoid
&& args.RecurrentDropout == 0 && !args.Unroll && args.UseBias
&& ops.executing_eagerly_outside_functions();
}

private static IRnnCell CreateCell(LSTMArgs lstmArgs)
{
return new LSTMCell(new LSTMCellArgs()
{
Units = lstmArgs.Units,
Activation = lstmArgs.Activation,
RecurrentActivation = lstmArgs.RecurrentActivation,
UseBias = lstmArgs.UseBias,
KernelInitializer = lstmArgs.KernelInitializer,
RecurrentInitializer = lstmArgs.RecurrentInitializer,
UnitForgetBias = lstmArgs.UnitForgetBias,
BiasInitializer = lstmArgs.BiasInitializer,
// TODO(Rinne): kernel_regularizer
// TODO(Rinne): recurrent_regularizer
// TODO(Rinne): bias_regularizer
// TODO(Rinne): kernel_constriant
// TODO(Rinne): recurrent_constriant
// TODO(Rinne): bias_constriant
Dropout = lstmArgs.Dropout,
RecurrentDropout = lstmArgs.RecurrentDropout,
Implementation = lstmArgs.Implementation,
DType = lstmArgs.DType,
Trainable = lstmArgs.Trainable
});
} }


protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bool? training = null, IOptionalArgs? optional_args = null)
{ {
return base.Call(inputs, initial_state: state, training: training);
// skip the condition of ragged input

(inputs, initial_state, _) = _process_inputs(inputs, initial_state, null);

Tensor mask = null;
if(optional_args is RnnOptionalArgs rnnArgs)
{
mask = rnnArgs.Mask;
}

var single_input = inputs.Single;
var input_shape = single_input.shape;
var timesteps = args.TimeMajor ? input_shape[0] : input_shape[1];

_maybe_reset_cell_dropout_mask(Cell);

Func<Tensors, Tensors, (Tensors, Tensors)> step = (inputs, states) =>
{
var res = Cell.Apply(inputs, states, training is null ? true : training.Value);
var (output, state) = res;
return (output, state);
};

var (last_output, outputs, states) = keras.backend.rnn(
step,
inputs,
initial_state,
constants: null,
go_backwards: args.GoBackwards,
mask: mask,
unroll: args.Unroll,
input_length: ops.convert_to_tensor(timesteps),
time_major: args.TimeMajor,
zero_output_for_mask: args.ZeroOutputForMask,
return_all_outputs: args.ReturnSequences
);

Tensor output;
if (args.ReturnSequences)
{
output = keras.backend.maybe_convert_to_ragged(false, outputs, (int)timesteps, args.GoBackwards);
}
else
{
output = last_output;
}

if (args.ReturnState)
{
return new Tensor[] { output }.Concat(states).ToArray().ToTensors();
}
else
{
return output;
}
} }
} }
} }

+ 8
- 9
src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs View File

@@ -1,5 +1,6 @@
using Serilog.Core; using Serilog.Core;
using System.Diagnostics; using System.Diagnostics;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types; using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
@@ -81,7 +82,7 @@ namespace Tensorflow.Keras.Layers.Rnn
_bias_initializer = _args.BiasInitializer; _bias_initializer = _args.BiasInitializer;
} }
_bias = add_weight("bias", (_args.Units * 4), _bias = add_weight("bias", (_args.Units * 4),
initializer: _args.BiasInitializer);
initializer: _bias_initializer);
} }
built = true; built = true;
} }
@@ -94,7 +95,6 @@ namespace Tensorflow.Keras.Layers.Rnn
var rec_dp_mask = get_recurrent_dropout_mask_for_cell( var rec_dp_mask = get_recurrent_dropout_mask_for_cell(
h_tm1, training.Value, count: 4); h_tm1, training.Value, count: 4);



Tensor c; Tensor c;
Tensor o; Tensor o;
if (_args.Implementation == 1) if (_args.Implementation == 1)
@@ -123,7 +123,7 @@ namespace Tensorflow.Keras.Layers.Rnn
var x_f = math_ops.matmul(inputs_f, k_f); var x_f = math_ops.matmul(inputs_f, k_f);
var x_c = math_ops.matmul(inputs_c, k_c); var x_c = math_ops.matmul(inputs_c, k_c);
var x_o = math_ops.matmul(inputs_o, k_o); var x_o = math_ops.matmul(inputs_o, k_o);
if(_args.UseBias)
if (_args.UseBias)
{ {
var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0); var b = tf.split(_bias.AsTensor(), num_split: 4, axis: 0);
Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3]; Tensor b_i = b[0], b_f = b[1], b_c = b[2], b_o = b[3];
@@ -170,7 +170,7 @@ namespace Tensorflow.Keras.Layers.Rnn
} }
var h = o * _args.Activation.Apply(c); var h = o * _args.Activation.Apply(c);
// 这里是因为 Tensors 类初始化的时候会把第一个元素之后的元素打包成一个数组 // 这里是因为 Tensors 类初始化的时候会把第一个元素之后的元素打包成一个数组
return new Tensors(h, h, c);
return new Nest<Tensor>(new INestStructure<Tensor>[] { new NestNode<Tensor>(h), new NestList<Tensor>(h, c) }).ToTensors();
} }


/// <summary> /// <summary>
@@ -188,22 +188,21 @@ namespace Tensorflow.Keras.Layers.Rnn
h_tm1_o = h_tm1[3]; h_tm1_o = h_tm1[3];


var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor(); var _recurrent_kernel_tensor = _recurrent_kernel.AsTensor();
var startIndex = _recurrent_kernel_tensor.shape[0];
var endIndex = _recurrent_kernel_tensor.shape[1];
int startIndex = (int)_recurrent_kernel_tensor.shape[0];
var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, var _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, 0 }, new[] { startIndex, _args.Units }); new[] { 0, 0 }, new[] { startIndex, _args.Units });
var i = _args.RecurrentActivation.Apply( var i = _args.RecurrentActivation.Apply(
x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice)); x_i + math_ops.matmul(h_tm1_i, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units }, new[] { startIndex, _args.Units * 2});
new[] { 0, _args.Units }, new[] { startIndex, _args.Units});
var f = _args.RecurrentActivation.Apply( var f = _args.RecurrentActivation.Apply(
x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice)); x_f + math_ops.matmul(h_tm1_f, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units * 3 });
new[] { 0, _args.Units * 2 }, new[] { startIndex, _args.Units });
var c = f * c_tm1 + i * _args.Activation.Apply( var c = f * c_tm1 + i * _args.Activation.Apply(
x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice)); x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 3 }, new[] { startIndex, endIndex });
new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
var o = _args.RecurrentActivation.Apply( var o = _args.RecurrentActivation.Apply(
x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice)); x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));




+ 28
- 26
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -45,23 +45,25 @@ namespace Tensorflow.Keras.Layers.Rnn
} }
} }


public RNN(RNNArgs args) : base(PreConstruct(args))
public RNN(IRnnCell cell, RNNArgs args) : base(PreConstruct(args))
{ {
_args = args; _args = args;
SupportsMasking = true; SupportsMasking = true;


// if is StackedRnncell
if (args.Cells != null)
{
Cell = new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = args.Cells
});
}
else
{
Cell = args.Cell;
}
Cell = cell;

// get input_shape
_args = PreConstruct(args);

_num_constants = 0;
}

public RNN(IEnumerable<IRnnCell> cells, RNNArgs args) : base(PreConstruct(args))
{
_args = args;
SupportsMasking = true;

Cell = new StackedRNNCells(cells, new StackedRNNCellsArgs());


// get input_shape // get input_shape
_args = PreConstruct(args); _args = PreConstruct(args);
@@ -330,7 +332,7 @@ namespace Tensorflow.Keras.Layers.Rnn
states = new Tensors(states.SkipLast(_num_constants).ToArray()); states = new Tensors(states.SkipLast(_num_constants).ToArray());
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
return (output, new_states.Single);
return (output, new_states);
}; };
} }
else else
@@ -382,6 +384,11 @@ namespace Tensorflow.Keras.Layers.Rnn
} }
else else
{ {
//var tapeSet = tf.GetTapeSet();
//foreach(var tape in tapeSet)
//{
// tape.Watch(output);
//}
return output; return output;
} }
} }
@@ -405,7 +412,7 @@ namespace Tensorflow.Keras.Layers.Rnn
throw new NotImplementedException(); throw new NotImplementedException();
} }


private (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
protected (Tensors inputs, Tensors initial_state, Tensors constants) _process_inputs(Tensors inputs, Tensors initial_state, Tensors constants)
{ {
if (inputs.Length > 1) if (inputs.Length > 1)
{ {
@@ -484,7 +491,7 @@ namespace Tensorflow.Keras.Layers.Rnn


} }


void _maybe_reset_cell_dropout_mask(ILayer cell)
protected void _maybe_reset_cell_dropout_mask(ILayer cell)
{ {
if (cell is DropoutRNNCellMixin CellDRCMixin) if (cell is DropoutRNNCellMixin CellDRCMixin)
{ {
@@ -495,26 +502,21 @@ namespace Tensorflow.Keras.Layers.Rnn


private static RNNArgs PreConstruct(RNNArgs args) private static RNNArgs PreConstruct(RNNArgs args)
{ {
if (args.Kwargs == null)
{
args.Kwargs = new Dictionary<string, object>();
}

// If true, the output for masked timestep will be zeros, whereas in the // If true, the output for masked timestep will be zeros, whereas in the
// false case, output from previous timestep is returned for masked timestep. // false case, output from previous timestep is returned for masked timestep.
var zeroOutputForMask = (bool)args.Kwargs.Get("zero_output_for_mask", false);
var zeroOutputForMask = args.ZeroOutputForMask;


Shape input_shape; Shape input_shape;
var propIS = (Shape)args.Kwargs.Get("input_shape", null);
var propID = (int?)args.Kwargs.Get("input_dim", null);
var propIL = (int?)args.Kwargs.Get("input_length", null);
var propIS = args.InputShape;
var propID = args.InputDim;
var propIL = args.InputLength;


if (propIS == null && (propID != null || propIL != null)) if (propIS == null && (propID != null || propIL != null))
{ {
input_shape = new Shape( input_shape = new Shape(
propIL ?? -1, propIL ?? -1,
propID ?? -1); propID ?? -1);
args.Kwargs["input_shape"] = input_shape;
args.InputShape = input_shape;
} }


return args; return args;


+ 3
- 4
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -10,14 +10,14 @@ namespace Tensorflow.Keras.Layers.Rnn
public class SimpleRNN : RNN public class SimpleRNN : RNN
{ {
SimpleRNNArgs args; SimpleRNNArgs args;
public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args))
public SimpleRNN(SimpleRNNArgs args) : base(CreateCellForArgs(args), args)
{ {
this.args = args; this.args = args;
} }


private static SimpleRNNArgs CreateCellForArgs(SimpleRNNArgs args)
private static SimpleRNNCell CreateCellForArgs(SimpleRNNArgs args)
{ {
args.Cell = new SimpleRNNCell(new SimpleRNNCellArgs()
return new SimpleRNNCell(new SimpleRNNCellArgs()
{ {
Units = args.Units, Units = args.Units,
Activation = args.Activation, Activation = args.Activation,
@@ -30,7 +30,6 @@ namespace Tensorflow.Keras.Layers.Rnn
DType = args.DType, DType = args.DType,
Trainable = args.Trainable, Trainable = args.Trainable,
}); });
return args;
} }
} }
} }

+ 0
- 5
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -115,10 +115,5 @@ namespace Tensorflow.Keras.Layers.Rnn
return new Tensors(output, output); return new Tensors(output, output);
} }
} }

public Tensors get_initial_state(Tensors inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
return RnnUtils.generate_zero_filled_state_for_cell(this, inputs, batch_size, dtype);
}
} }
} }

+ 4
- 8
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -15,15 +15,11 @@ namespace Tensorflow.Keras.Layers.Rnn
public IList<IRnnCell> Cells { get; set; } public IList<IRnnCell> Cells { get; set; }
public bool _reverse_state_order; public bool _reverse_state_order;


public StackedRNNCells(StackedRNNCellsArgs args) : base(args)
public StackedRNNCells(IEnumerable<IRnnCell> cells, StackedRNNCellsArgs args) : base(args)
{ {
if (args.Kwargs == null)
{
args.Kwargs = new Dictionary<string, object>();
}
Cells = args.Cells;
_reverse_state_order = (bool)args.Kwargs.Get("reverse_state_order", false);
Cells = cells.ToList();

_reverse_state_order = args.ReverseStateOrder;


if (_reverse_state_order) if (_reverse_state_order)
{ {


+ 43
- 31
test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs View File

@@ -55,30 +55,56 @@ namespace Tensorflow.Keras.UnitTest.Layers
Assert.AreEqual((2, 4), new_states[0].shape); Assert.AreEqual((2, 4), new_states[0].shape);
} }


[TestMethod]
public void TrainLSTMWithMnist()
{
var input = keras.Input((784));
var x = keras.layers.Reshape((28, 28)).Apply(input);
//x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
//x = keras.layers.LSTM(100, return_sequences: true).Apply(x);
//x = keras.layers.LSTM(150, return_sequences: true).Apply(x);
x = keras.layers.LSTM(4, implementation: 2).Apply(x);
//x = keras.layers.Dense(100).Apply(x);
var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

var model = keras.Model(input, output);
model.summary();
model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 30);
}

[TestMethod] [TestMethod]
public void SimpleRNN() public void SimpleRNN()
{ {
//var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
///*var simple_rnn = keras.layers.SimpleRNN(4);
//var output = simple_rnn.Apply(inputs);
//Assert.AreEqual((32, 4), output.shape);*/

//var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
//var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
//Assert.AreEqual((6, 10, 4), whole_sequence_output.shape);
//Assert.AreEqual((6, 4), final_state.shape);
var input = keras.Input((784));
var x = keras.layers.Reshape((28, 28)).Apply(input);
x = keras.layers.SimpleRNN(10).Apply(x);
var output = keras.layers.Dense(10, activation: "softmax").Apply(x);


var inputs = keras.Input(shape: (10, 8));
var x = keras.layers.SimpleRNN(4).Apply(inputs);
var output = keras.layers.Dense(10).Apply(x);
var model = keras.Model(inputs, output);
var model = keras.Model(input, output);
model.summary(); model.summary();
model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });


model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy());
var datax = np.ones((16, 10, 8), dtype: dtypes.float32);
var datay = np.ones((16));
model.fit(datax, datay, epochs: 20);
var data_loader = new MnistModelLoader();
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
}).Result;

model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 10);
} }

[TestMethod] [TestMethod]
public void RNNForSimpleRNNCell() public void RNNForSimpleRNNCell()
{ {
@@ -109,19 +135,5 @@ namespace Tensorflow.Keras.UnitTest.Layers
Console.WriteLine($"output: {output}"); Console.WriteLine($"output: {output}");
Assert.AreEqual((5, 4), output.shape); Assert.AreEqual((5, 4), output.shape);
} }

[TestMethod]
public void MyTest()
{
var a = tf.zeros((2, 3));
var b = tf.ones_like(a);
var c = tf.ones((3,4));

var d = new Tensors { a, b, c };
var (A, BC) = d;
Console.WriteLine($"A:{A}");
Console.WriteLine($"BC:{BC}");
}

} }
} }

+ 1
- 1
tools/Tensorflow.CodeGen/OpClassifier.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow.CodeGen
{ {
public class OpClassifier public class OpClassifier
{ {
private static readonly string _filenamePattern = @"^gen_[a-z]*_ops.py$";
private static readonly string _filenamePattern = @"^gen_[a-z_]*_ops.py$";
private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):"; private static readonly string _pythonFunctionPattern = @"def\s+(\w+\d*\w*)\((?:\s*\w+\s*(?:=\s*[\S]*)*,\s*)*\s*name=None\):";
private Dictionary<string, HashSet<string>> _opSet = new(); private Dictionary<string, HashSet<string>> _opSet = new();
public Dictionary<string, HashSet<string>> OpSet => _opSet; public Dictionary<string, HashSet<string>> OpSet => _opSet;


+ 16
- 1
tools/Tensorflow.CodeGen/Utils.cs View File

@@ -178,10 +178,25 @@ namespace Tensorflow.CodeGen
else if (attr.Type == "list(shape)") else if (attr.Type == "list(shape)")
{ {
res.Add((attr.Name, "Shape[]", "NOVALUE")); res.Add((attr.Name, "Shape[]", "NOVALUE"));
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{
List<string> exps = new();
foreach (var value in attr.DefaultValue.List.Shape)
{
exps.Add($"new Shape({string.Join(", ", value.Dim.Select(x => x.Size))})");
}
string expression = "new Shape[]{" + $"{string.Join(", ", exps)}" + "}";
dynamicDefaultValues[attr.Name] = expression;
res.Add((attr.Name, "string[]", $"null"));
}
else
{
res.Add((attr.Name, "string[]", "NOVALUE"));
}
} }
else if (attr.Type == "list(string)") else if (attr.Type == "list(string)")
{ {
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.S)
if (attr.DefaultValue is not null && attr.DefaultValue.ValueCase == AttrValue.ValueOneofCase.List)
{ {
List<string> values = new(); List<string> values = new();
foreach (var value in attr.DefaultValue.List.S) foreach (var value in attr.DefaultValue.List.S)


Loading…
Cancel
Save