Browse Source

fix: error when training SimpleRNN.

tags/v0.110.0-LSTM-Model
Yaohui Liu 2 years ago
parent
commit
07ea656833
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
8 changed files with 78 additions and 35 deletions
  1. +19
    -0
      src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs
  2. +10
    -1
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  4. +2
    -1
      src/TensorFlowNET.Core/Status/Status.cs
  5. +4
    -0
      src/TensorFlowNET.Keras/IsExternalInit.cs
  6. +36
    -18
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  7. +0
    -14
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  8. +5
    -0
      test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs

+ 19
- 0
src/TensorFlowNET.Core/Exceptions/NotOkStatusException.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Exceptions
{
    /// <summary>
    /// A <see cref="TensorflowException"/> that signals a TensorFlow status
    /// which is not OK and has no more specific exception type.
    /// </summary>
    public class NotOkStatusException : TensorflowException
    {
        /// <summary>
        /// Creates the exception without a message.
        /// </summary>
        public NotOkStatusException()
        {
        }

        /// <summary>
        /// Creates the exception with the given message.
        /// </summary>
        /// <param name="message">Text forwarded to the base exception.</param>
        public NotOkStatusException(string message)
            : base(message)
        {
        }
    }
}

+ 10
- 1
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -186,7 +186,16 @@ namespace Tensorflow
}

/// <summary>
/// Reads the attribute <paramref name="name"/> through the non-generic
/// <c>get_attr</c> overload and returns it as <typeparamref name="T"/>.
/// </summary>
/// <typeparam name="T">Type the caller expects the attribute value to have.</typeparam>
/// <param name="name">Name of the attribute to read.</param>
/// <returns>The attribute value, converted to <typeparamref name="T"/> when necessary.</returns>
/// <exception cref="InvalidCastException">
/// The value cannot be converted to <typeparamref name="T"/>.
/// </exception>
public virtual T get_attr<T>(string name)
{
    var value = get_attr(name);

    // Fast path: the boxed value already has the requested type, so no
    // conversion is needed.
    if (value is T typed)
    {
        return typed;
    }

    if (typeof(T).IsValueType)
    {
        // Convert.ChangeType cannot target Nullable<U>; convert to U instead
        // and let the final unboxing cast wrap it.
        var target = Nullable.GetUnderlyingType(typeof(T)) ?? typeof(T);

        if (target.IsEnum)
        {
            // Convert.ChangeType cannot turn an integral value into an enum,
            // so go through the enum's underlying integral type first.
            var raw = Convert.ChangeType(value, Enum.GetUnderlyingType(target));
            return (T)Enum.ToObject(target, raw);
        }

        // Numeric widening/narrowing (e.g. a boxed long requested as int).
        return (T)Convert.ChangeType(value, target);
    }

    return (T)value;
}

internal unsafe TF_DataType _get_attr_type(string name)
{


+ 2
- 1
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -4633,8 +4633,9 @@ public static class gen_math_ops
var _fast_path_result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(_ctx, "MatMul", name) { args = new object[] { a, b }, attrs = new Dictionary<string, object>() { ["transpose_a"] = transpose_a, ["transpose_b"] = transpose_b } });
return _fast_path_result[0];
}
catch (Exception)
catch (Exception ex)
{
Console.WriteLine();
}
try
{


+ 2
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using Tensorflow.Exceptions;
using Tensorflow.Util;
using static Tensorflow.c_api;

@@ -88,7 +89,7 @@ namespace Tensorflow
case TF_Code.TF_INVALID_ARGUMENT:
throw new InvalidArgumentError(message);
default:
throw new TensorflowException(message);
throw new NotOkStatusException(message);
}
}
}


+ 4
- 0
src/TensorFlowNET.Keras/IsExternalInit.cs View File

@@ -0,0 +1,4 @@
namespace System.Runtime.CompilerServices
{
    /// <summary>
    /// Marker type that the C# compiler references when it emits
    /// <c>init</c> accessors. It carries no members.
    /// </summary>
    /// <remarks>
    /// NOTE(review): presumably declared here because a target framework of
    /// this project does not ship the type itself - confirm against the
    /// project's target frameworks.
    /// </remarks>
    internal static class IsExternalInit
    {
    }
}

+ 36
- 18
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -11,6 +11,7 @@ using Tensorflow.Common.Extensions;
using System.Linq.Expressions;
using Tensorflow.Keras.Utils;
using Tensorflow.Common.Types;
using System.Runtime.CompilerServices;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers.Rnn
@@ -30,7 +31,19 @@ namespace Tensorflow.Keras.Layers.Rnn
private int _num_constants;
protected IVariableV1 _kernel;
protected IVariableV1 _bias;
protected IRnnCell _cell;
private IRnnCell _cell;
/// <summary>
/// The cell used by this layer, backed by <c>_cell</c>.
/// </summary>
/// <remarks>
/// The value can only be assigned during initialization. Assigning it also
/// appends the cell to <c>_self_tracked_trackables</c>.
/// </remarks>
protected IRnnCell Cell
{
    get => _cell;
    init
    {
        _cell = value;
        // Register the cell alongside the layer's other tracked objects.
        _self_tracked_trackables.Add(value);
    }
}

public RNN(RNNArgs args) : base(PreConstruct(args))
{
@@ -40,14 +53,14 @@ namespace Tensorflow.Keras.Layers.Rnn
// if is StackedRnncell
if (args.Cells != null)
{
_cell = new StackedRNNCells(new StackedRNNCellsArgs
Cell = new StackedRNNCells(new StackedRNNCellsArgs
{
Cells = args.Cells
});
}
else
{
_cell = args.Cell;
Cell = args.Cell;
}

// get input_shape
@@ -65,7 +78,7 @@ namespace Tensorflow.Keras.Layers.Rnn
if (_states == null)
{
// CHECK(Rinne): check if this is correct.
var nested = _cell.StateSize.MapStructure<Tensor?>(x => null);
var nested = Cell.StateSize.MapStructure<Tensor?>(x => null);
_states = nested.AsNest().ToTensors();
}
return _states;
@@ -83,7 +96,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}

// state_size is a array of ints or a positive integer
var state_size = _cell.StateSize.ToSingleShape();
var state_size = Cell.StateSize.ToSingleShape();

// TODO(wanglongzhi2001): what type should flat_output_size be, Shape or Tensor?
Func<Shape, Shape> _get_output_shape;
@@ -110,12 +123,12 @@ namespace Tensorflow.Keras.Layers.Rnn
return output_shape;
};

Type type = _cell.GetType();
Type type = Cell.GetType();
PropertyInfo output_size_info = type.GetProperty("output_size");
Shape output_shape;
if (output_size_info != null)
{
output_shape = nest.map_structure(_get_output_shape, _cell.OutputSize.ToSingleShape());
output_shape = nest.map_structure(_get_output_shape, Cell.OutputSize.ToSingleShape());
// TODO(wanglongzhi2001): should output_shape simply be a tuple, or a Shape?
output_shape = (output_shape.Length == 1 ? (int)output_shape[0] : output_shape);
}
@@ -171,7 +184,9 @@ namespace Tensorflow.Keras.Layers.Rnn

public override void build(KerasShapesWrapper input_shape)
{
object get_input_spec(Shape shape)
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);

InputSpec get_input_spec(Shape shape)
{
var input_spec_shape = shape.as_int_list();

@@ -213,10 +228,13 @@ namespace Tensorflow.Keras.Layers.Rnn
// numpy inputs.


if (!_cell.Built)
if (Cell is Layer layer && !layer.Built)
{
_cell.build(input_shape);
layer.build(input_shape);
layer.Built = true;
}

this.built = true;
}

/// <summary>
@@ -247,10 +265,10 @@ namespace Tensorflow.Keras.Layers.Rnn

(inputs, initial_state, constants) = _process_inputs(inputs, initial_state, constants);

_maybe_reset_cell_dropout_mask(_cell);
if (_cell is StackedRNNCells)
_maybe_reset_cell_dropout_mask(Cell);
if (Cell is StackedRNNCells)
{
var stack_cell = _cell as StackedRNNCells;
var stack_cell = Cell as StackedRNNCells;
foreach (IRnnCell cell in stack_cell.Cells)
{
_maybe_reset_cell_dropout_mask(cell);
@@ -300,10 +318,10 @@ namespace Tensorflow.Keras.Layers.Rnn
bool is_tf_rnn_cell = false;
if (constants is not null)
{
if (!_cell.SupportOptionalArgs)
if (!Cell.SupportOptionalArgs)
{
throw new ValueError(
$"RNN cell {_cell} does not support constants." +
$"RNN cell {Cell} does not support constants." +
$"Received: constants={constants}");
}

@@ -312,7 +330,7 @@ namespace Tensorflow.Keras.Layers.Rnn
constants = new Tensors(states.TakeLast(_num_constants).ToArray());
states = new Tensors(states.SkipLast(_num_constants).ToArray());
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states[0]) : states;
var (output, new_states) = _cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
var (output, new_states) = Cell.Apply(inputs, states, optional_args: new RnnOptionalArgs() { Constants = constants });
return (output, new_states.Single);
};
}
@@ -321,7 +339,7 @@ namespace Tensorflow.Keras.Layers.Rnn
step = (inputs, states) =>
{
states = len(states) == 1 && is_tf_rnn_cell ? new Tensors(states.First()) : states;
var (output, new_states) = _cell.Apply(inputs, states);
var (output, new_states) = Cell.Apply(inputs, states);
return (output, new_states);
};
}
@@ -562,7 +580,7 @@ namespace Tensorflow.Keras.Layers.Rnn
var batch_size = _args.TimeMajor ? input_shape[1] : input_shape[0];
var dtype = input.dtype;

Tensors init_state = _cell.GetInitialState(null, batch_size, dtype);
Tensors init_state = Cell.GetInitialState(null, batch_size, dtype);

return init_state;
}


+ 0
- 14
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -32,19 +32,5 @@ namespace Tensorflow.Keras.Layers.Rnn
});
return args;
}

/// <summary>
/// Stores <paramref name="input_shape"/> in <c>_buildInputShape</c> and
/// creates the "kernel" weight, sized by the last dimension of the input
/// shape and <c>args.Units</c>.
/// </summary>
/// <param name="input_shape">Shape wrapper describing the layer input.</param>
public override void build(KerasShapesWrapper input_shape)
{
    var shape = input_shape.ToSingleShape();
    var input_dim = shape[-1];
    _buildInputShape = input_shape;

    // Only the initializer is passed on; regularizer, constraint and
    // caching device are not wired up here.
    _kernel = add_weight(
        "kernel",
        (input_dim, args.Units),
        initializer: args.KernelInitializer);
}
}
}

+ 5
- 0
test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs View File

@@ -77,6 +77,11 @@ namespace Tensorflow.Keras.UnitTest.Layers
var output = keras.layers.Dense(10).Apply(x);
var model = keras.Model(inputs, output);
model.summary();

model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy());
var datax = np.ones((16, 10, 8), dtype: dtypes.float32);
var datay = np.ones((16));
model.fit(datax, datay, epochs: 20);
}
[TestMethod]
public void RNNForSimpleRNNCell()


Loading…
Cancel
Save