Browse Source

fix: make the initialization of the layer's name correct

tags/v0.110.4-Transformer-Model
“Wanglongzhi2001” 2 years ago
parent
commit
3bef87aefc
2 changed files with 42 additions and 5 deletions
  1. +9
    -5
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  2. +33
    -0
      test/TensorFlowNET.Keras.UnitTest/InitLayerNameTest.cs

+ 9
- 5
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -29,6 +29,7 @@ using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Saving;
using Tensorflow.Train;
using System.Text.RegularExpressions;

namespace Tensorflow.Keras.Utils
{
@@ -126,12 +127,15 @@ namespace Tensorflow.Keras.Utils

public static string to_snake_case(string name)
{
return string.Concat(name.Select((x, i) =>
string intermediate = Regex.Replace(name, "(.)([A-Z][a-z0-9]+)", "$1_$2");
string insecure = Regex.Replace(intermediate, "([a-z])([A-Z])", "$1_$2").ToLower();

if (insecure[0] != '_')
{
return i > 0 && char.IsUpper(x) && !Char.IsDigit(name[i - 1]) ?
"_" + x.ToString() :
x.ToString();
})).ToLower();
return insecure;
}
return "private" + insecure;
}

/// <summary>


+ 33
- 0
test/TensorFlowNET.Keras.UnitTest/InitLayerNameTest.cs View File

@@ -0,0 +1,33 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.Layers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class InitLayerNameTest
{
[TestMethod]
public void RNNLayerNameTest()
{
var simpleRnnCell = keras.layers.SimpleRNNCell(1);
Assert.AreEqual("simple_rnn_cell", simpleRnnCell.Name);
var simpleRnn = keras.layers.SimpleRNN(2);
Assert.AreEqual("simple_rnn", simpleRnn.Name);
var lstmCell = keras.layers.LSTMCell(2);
Assert.AreEqual("lstm_cell", lstmCell.Name);
var lstm = keras.layers.LSTM(3);
Assert.AreEqual("lstm", lstm.Name);
}

[TestMethod]
public void ConvLayerNameTest()
{
var conv2d = keras.layers.Conv2D(8, activation: "linear");
Assert.AreEqual("conv2d", conv2d.Name);
var conv2dTranspose = keras.layers.Conv2DTranspose(8);
Assert.AreEqual("conv2d_transpose", conv2dTranspose.Name);
}
}
}

Loading…
Cancel
Save