From 07ea65683362cc2a633e9de0a7e0b550794d2474 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Fri, 16 Jun 2023 16:15:01 +0800 Subject: [PATCH] fix: error when training SimpleRNN. --- .../Exceptions/NotOkStatusException.cs | 19 +++++++ .../Operations/Operation.cs | 11 +++- .../Operations/gen_math_ops.cs | 3 +- src/TensorFlowNET.Core/Status/Status.cs | 3 +- src/TensorFlowNET.Keras/IsExternalInit.cs | 4 ++ src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs | 54 ++++++++++++------- .../Layers/Rnn/SimpleRNN.cs | 14 ----- .../Layers/Rnn.Test.cs | 5 ++ 8 files changed, 78 insertions(+), 35 deletions(-) create mode 100644 src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs create mode 100644 src/TensorFlowNET.Keras/IsExternalInit.cs diff --git a/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs b/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs new file mode 100644 index 00000000..c283c1a4 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Exceptions +{ + public class NotOkStatusException : TensorflowException + { + public NotOkStatusException() : base() + { + + } + + public NotOkStatusException(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 5e689c65..d31b26d4 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -186,7 +186,16 @@ namespace Tensorflow } public virtual T get_attr(string name) - => (T)get_attr(name); + { + if (typeof(T).IsValueType) + { + return (T)Convert.ChangeType(get_attr(name), typeof(T)); + } + else + { + return (T)get_attr(name); + } + } internal unsafe TF_DataType _get_attr_type(string name) { diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 3456d9b3..6eb7a411 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -4633,8 +4633,9 @@ public static class gen_math_ops var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } }); return _fast_path_result[0]; } - catch (Exception) + catch (Exception ex) { + Console.WriteLine(); } try { diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index a890c2ae..12b6fba2 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -17,6 +17,7 @@ using System; using System.Diagnostics; using System.Runtime.CompilerServices; +using Tensorflow.Exceptions; using Tensorflow.Util; using static Tensorflow.c_api; @@ -88,7 +89,7 @@ namespace Tensorflow case TF_Code.TF_INVALID_ARGUMENT: throw new InvalidArgumentError(message); default: - throw new TensorflowException(message); + throw new NotOkStatusException(message); } } } diff --git a/src/TensorFlowNET.Keras/IsExternalInit.cs b/src/TensorFlowNET.Keras/IsExternalInit.cs new file mode 100644 index 00000000..11f062fa --- /dev/null +++ b/src/TensorFlowNET.Keras/IsExternalInit.cs @@ -0,0 +1,4 @@ +namespace System.Runtime.CompilerServices +{ + internal static class IsExternalInit { } +} diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs index 77f7d927..f99bc23a 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs @@ -11,6 +11,7 @@ using Tensorflow.Common.Extensions; using System.Linq.Expressions; using Tensorflow.Keras.Utils; using Tensorflow.Common.Types; +using System.Runtime.CompilerServices; // from tensorflow.python.distribute import distribution_strategy_context as ds_context; namespace Tensorflow.Keras.Layers.Rnn @@ -30,7 +31,19 @@ namespace Tensorflow.Keras.Layers.Rnn private int _num_constants; protected IVariableV1 _kernel; protected IVariableV1 _bias; - protected IRnnCell _cell; + private IRnnCell _cell; + protected IRnnCell Cell + { + get + { + return _cell; + } + init + { + _cell = value; + _self_tracked_trackables.Add(_cell); + } + } public RNN(RNNArgs args) : base(PreConstruct(args)) { @@ -40,14 +53,14 @@ namespace Tensorflow.Keras.Layers.Rnn // if is StackedRnncell if (args.Cells != null) { - _cell = new StackedRNNCells(new StackedRNNCellsArgs + Cell = new StackedRNNCells(new StackedRNNCellsArgs { Cells = args.Cells }); } else { - _cell = args.Cell; + Cell = args.Cell; } // get input_shape @@ -65,7 +78,7 @@ namespace Tensorflow.Keras.Layers.Rnn if (_states == null) { // CHECK(Rinne): check if this is correct. - var nested = _cell.StateSize.MapStructure(x => null); + var nested = Cell.StateSize.MapStructure(x => null); _states = nested.AsNest().ToTensors(); } return _states; @@ -83,7 +96,7 @@ namespace Tensorflow.Keras.Layers.Rnn } // state_size is a array of ints or a positive integer - var state_size = _cell.StateSize.ToSingleShape(); + var state_size = Cell.StateSize.ToSingleShape(); // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor Func _get_output_shape; @@ -110,12 +123,12 @@ namespace Tensorflow.Keras.Layers.Rnn return output_shape; }; - Type type = _cell.GetType(); + Type type = Cell.GetType(); PropertyInfo output_size_info = type.GetProperty("output_size"); Shape output_shape; if (output_size_info != null) { - output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape()); + output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape()); // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型 output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape); } @@ -171,7 +184,9 @@ namespace Tensorflow.Keras.Layers.Rnn public override void build(KerasShapesWrapper input_shape) { - object get_input_spec(Shape shape) + input_shape = new KerasShapesWrapper(input_shape.Shapes[0]); + + InputSpec get_input_spec(Shape shape) { var input_spec_shape = shape.as_int_list(); @@ -213,10 +228,13 @@ namespace Tensorflow.Keras.Layers.Rnn // numpy inputs. - if (!_cell.Built) + if (Cell is Layer layer && !layer.Built) { - _cell.build(input_shape); + layer.build(input_shape); + layer.Built = true; } + + this.built = true; } /// @@ -247,10 +265,10 @@ namespace Tensorflow.Keras.Layers.Rnn (inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants); - _maybe_reset_cell_dropout_mask(_cell); - if (_cell is StackedRNNCells) + _maybe_reset_cell_dropout_mask(Cell); + if (Cell is StackedRNNCells) { - var stack_cell = _cell as StackedRNNCells; + var stack_cell = Cell as StackedRNNCells; foreach (IRnnCell cell in stack_cell.Cells) { _maybe_reset_cell_dropout_mask(cell); @@ -300,10 +318,10 @@ namespace Tensorflow.Keras.Layers.Rnn bool is_tf_rnn_cell = false; if (constants is not null) { - if (!_cell.SupportOptionalArgs) + if (!Cell.SupportOptionalArgs) { throw new ValueError( - $"RNN cell {_cell} does not support constants." + + $"RNN cell {Cell} does not support constants." + $"Received: constants={constants}"); } @@ -312,7 +330,7 @@ namespace Tensorflow.Keras.Layers.Rnn constants = new Tensors(states.TakeLast(_num_constants).ToArray()); states = new Tensors(states.SkipLast(_num_constants).ToArray()); states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states; - var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); + var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants }); return (output, new_states.Single); }; } @@ -321,7 +339,7 @@ namespace Tensorflow.Keras.Layers.Rnn step = (inputs, states) => { states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states; - var (output, new_states) = _cell.Apply(inputs, states); + var (output, new_states) = Cell.Apply(inputs, states); return (output, new_states); }; } @@ -562,7 +580,7 @@ namespace Tensorflow.Keras.Layers.Rnn var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0]; var dtype = input.dtype; - Tensors init_state = _cell.GetInitialState(null, batch_size, dtype); + Tensors init_state = Cell.GetInitialState(null, batch_size, dtype); return init_state; } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index 22d0e277..551c20cd 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -32,19 +32,5 @@ namespace Tensorflow.Keras.Layers.Rnn }); return args; } - - public override void build(KerasShapesWrapper input_shape) - { - var single_shape = input_shape.ToSingleShape(); - var input_dim = single_shape[-1]; - _buildInputShape = input_shape; - - _kernel = add_weight("kernel", (single_shape[-1], args.Units), - initializer: args.KernelInitializer - //regularizer = self.kernel_regularizer, - //constraint = self.kernel_constraint, - //caching_device = default_caching_device, - ); - } } } \ No newline at end of file diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs index 28a16ad4..fcb9ad1d 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs @@ -77,6 +77,11 @@ namespace Tensorflow.Keras.UnitTest.Layers var output = keras.layers.Dense(10).Apply(x); var model = keras.Model(inputs, output); model.summary(); + + model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy()); + var datax = np.ones((16, 10, 8), dtype: dtypes.float32); + var datay = np.ones((16)); + model.fit(datax, datay, epochs: 20); } [TestMethod] public void RNNForSimpleRNNCell()