Browse Source

fix: fix the bug when loading an LSTM model

tags/v0.110.4-Transformer-Model
“Wanglongzhi2001” 2 years ago
parent
commit
b27ccca84f
28 changed files with 71 additions and 40 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  4. +9
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs
  12. +1
    -2
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  13. +3
    -1
      src/TensorFlowNET.Core/ops.cs
  14. +1
    -2
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  15. +1
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs
  16. +1
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs
  17. +2
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
  18. +2
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
  19. +2
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
  20. +1
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs
  21. +2
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  22. +2
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
  23. +2
    -2
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  24. +0
    -1
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  25. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs
  26. +1
    -1
      src/TensorFlowNET.Keras/Utils/RnnUtils.cs
  27. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
  28. +28
    -1
      test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUCellArgs.cs View File

@@ -3,7 +3,7 @@ using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUCellArgs : AutoSerializeLayerArgs
{


+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs View File

@@ -1,4 +1,4 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class LSTMArgs : RNNArgs
{


+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs View File

@@ -1,7 +1,7 @@
using Newtonsoft.Json;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
// TODO: complete the implementation
public class LSTMCellArgs : AutoSerializeLayerArgs


+ 9
- 3
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs View File

@@ -1,8 +1,8 @@
using Newtonsoft.Json;
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
// TODO(Rinne): add regularizers.
public class RNNArgs : AutoSerializeLayerArgs
@@ -23,16 +23,22 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
public int? InputDim { get; set; }
public int? InputLength { get; set; }
// TODO: Add `num_constants` and `zero_output_for_mask`.
[JsonProperty("units")]
public int Units { get; set; }
[JsonProperty("activation")]
public Activation Activation { get; set; }
[JsonProperty("recurrent_activation")]
public Activation RecurrentActivation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
public IInitializer KernelInitializer { get; set; }
public IInitializer RecurrentInitializer { get; set; }
public IInitializer BiasInitializer { get; set; }
[JsonProperty("dropout")]
public float Dropout { get; set; } = .0f;
[JsonProperty("zero_output_for_mask")]
public bool ZeroOutputForMask { get; set; } = false;
[JsonProperty("recurrent_dropout")]
public float RecurrentDropout { get; set; } = .0f;
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RnnOptionalArgs.cs View File

@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class RnnOptionalArgs: IOptionalArgs
{


+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNArgs.cs View File

@@ -1,4 +1,4 @@
namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class SimpleRNNArgs : RNNArgs
{


+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNCellArgs.cs View File

@@ -1,6 +1,6 @@
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class SimpleRNNCellArgs: AutoSerializeLayerArgs
{


+ 2
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/StackedRNNCellsArgs.cs View File

@@ -1,7 +1,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
namespace Tensorflow.Keras.ArgsDefinition
{
public class StackedRNNCellsArgs : LayerArgs
{


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -1,7 +1,7 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;



+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rnn/IRnnCell.cs View File

@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Common.Types;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
public interface IRnnCell: ILayer
{


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rnn/IStackedRnnCells.cs View File

@@ -2,7 +2,7 @@
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
public interface IStackedRnnCells : IRnnCell
{


+ 1
- 2
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -19,9 +19,8 @@ using System.Collections.Generic;
using Tensorflow.Common.Types;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Operations;


+ 3
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -571,7 +571,9 @@ namespace Tensorflow
if (tf.Context.executing_eagerly())
return true;
else
throw new NotImplementedException("");
// TODO(Wanglongzhi2001), implement the false case
return true;
//throw new NotImplementedException("");
}

public static bool inside_function()


+ 1
- 2
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -2,9 +2,8 @@
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Rnn/DropoutRNNCellMixin.cs View File

@@ -6,7 +6,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
public abstract class DropoutRNNCellMixin: Layer, IRnnCell
{


+ 1
- 2
src/TensorFlowNET.Keras/Layers/Rnn/GRUCell.cs View File

@@ -3,12 +3,11 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Cell class for the GRU layer.


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs View File

@@ -1,10 +1,10 @@
using System.Linq;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Common.Types;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Long Short-Term Memory layer - Hochreiter 1997.


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs View File

@@ -3,12 +3,12 @@ using Serilog.Core;
using System.Diagnostics;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Cell class for the LSTM layer.


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs View File

@@ -3,7 +3,6 @@ using System;
using System.Collections.Generic;
using System.Reflection;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Util;
@@ -14,7 +13,7 @@ using Tensorflow.Common.Types;
using System.Runtime.CompilerServices;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Base class for recurrent layers.
@@ -185,6 +184,7 @@ namespace Tensorflow.Keras.Layers.Rnn

public override void build(KerasShapesWrapper input_shape)
{
_buildInputShape = input_shape;
input_shape = new KerasShapesWrapper(input_shape.Shapes[0]);

InputSpec get_input_spec(Shape shape)


+ 1
- 1
src/TensorFlowNET.Keras/Layers/Rnn/RnnBase.cs View File

@@ -4,7 +4,7 @@ using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
public abstract class RnnBase: Layer
{


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -1,11 +1,11 @@
using System.Data;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Operations.Activation;
using static HDF.PInvoke.H5Z;
using static Tensorflow.ApiDef.Types;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
public class SimpleRNN : RNN
{


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs View File

@@ -1,7 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
@@ -9,7 +9,7 @@ using Tensorflow.Common.Extensions;
using Tensorflow.Keras.Utils;
using Tensorflow.Graphs;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Cell class for SimpleRNN.


+ 2
- 2
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -3,12 +3,12 @@ using System.ComponentModel;
using System.Linq;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Layers.Rnn
namespace Tensorflow.Keras.Layers
{
public class StackedRNNCells : Layer, IRnnCell
{


+ 0
- 1
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -13,7 +13,6 @@ using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving.SavedModel;


+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/serialized_attributes.cs View File

@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Metrics;
using Tensorflow.Train;



+ 1
- 1
src/TensorFlowNET.Keras/Utils/RnnUtils.cs View File

@@ -3,7 +3,7 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;
using Tensorflow.Common.Extensions;

namespace Tensorflow.Keras.Utils


+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs View File

@@ -6,7 +6,7 @@ using System.Text;
using System.Threading.Tasks;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers.Rnn;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Train;


+ 28
- 1
test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs View File

@@ -1,5 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
@@ -79,6 +81,31 @@ public class ModelLoadTest
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}

/// <summary>
/// Round-trip smoke test for a model containing an LSTM layer: builds a
/// Sequential model (Input -> LSTM(32) -> Dense(1, Sigmoid)), fits it on data
/// produced by np.random.randn, saves it to "LSTM_Random", and loads it back.
/// </summary>
/// <remarks>
/// The method contains no Assert calls; it only exercises the
/// build / fit / save / load / summary sequence.
/// </remarks>
[TestMethod]
public void LSTMLoad()
{
// Training data from np.random.randn: inputs requested as (10, 5, 3), targets as (10, 1).
var inputs = np.random.randn(10, 5, 3);
var outputs = np.random.randn(10, 1);
var model = keras.Sequential();
// The input shape (5, 3) matches the trailing dimensions of `inputs`.
model.add(keras.Input(shape: (5, 3)));
var lstm = keras.layers.LSTM(32);

model.add(lstm);

model.add(keras.layers.Dense(1, keras.activations.Sigmoid));

model.compile(optimizer: keras.optimizers.Adam(),
loss: keras.losses.MeanSquaredError(),
new[] { "accuracy" });

// NOTE(review): `result` is never read afterwards.
var result = model.fit(inputs.numpy(), outputs.numpy(), batch_size: 10, epochs: 3, workers: 16, use_multiprocessing: true);

// NOTE(review): "LSTM_Random" is a relative path and this method does not remove it afterwards.
model.save("LSTM_Random");

// Load the model that was just saved and print its summary.
var model_loaded = keras.models.load_model("LSTM_Random");
model_loaded.summary();
}

[Ignore]
[TestMethod]
public void VGG19()


Loading…
Cancel
Save