Browse Source

fix some bugs in test

tags/v0.4.10
xuyige 5 years ago
parent
commit
1a4c3c2d20
4 changed files with 7 additions and 7 deletions
  1. +1
    -1
      fastNLP/models/bert.py
  2. +3
    -2
      test/core/test_vocabulary.py
  3. +1
    -2
      test/models/test_cnn_text_classification.py
  4. +2
    -2
      test/test_tutorials.py

+ 1
- 1
fastNLP/models/bert.py View File

@@ -30,7 +30,7 @@ class BertConfig:
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate = intermediate_size
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob


+ 3
- 2
test/core/test_vocabulary.py View File

@@ -100,13 +100,14 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
def test_iteration(self):
vocab = Vocabulary()
vocab = Vocabulary(padding=None, unknown=None)
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
"works", "well", "in", "most", "cases", "scales", "well"]
vocab.update(text)
text = set(text)
for word in vocab:
for word, idx in vocab:
self.assertTrue(word in text)
self.assertTrue(idx < len(vocab))


class TestOther(unittest.TestCase):


+ 1
- 2
test/models/test_cnn_text_classification.py View File

@@ -12,7 +12,6 @@ class TestCNNText(unittest.TestCase):
model = CNNText(init_emb,
NUM_CLS,
kernel_nums=(1, 3, 5),
kernel_sizes=(2, 2, 2),
padding=0,
kernel_sizes=(1, 3, 5),
dropout=0.5)
RUNNER.run_model_with_task(TEXT_CLS, model)

+ 2
- 2
test/test_tutorials.py View File

@@ -70,7 +70,7 @@ class TestTutorial(unittest.TestCase):
break

from fastNLP.models import CNNText
model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1)
model = CNNText((len(vocab), 50), num_classes=5, dropout=0.1)

from fastNLP import Trainer
from copy import deepcopy
@@ -143,7 +143,7 @@ class TestTutorial(unittest.TestCase):
is_input=True)

from fastNLP.models import CNNText
model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1)
model = CNNText((len(vocab), 50), num_classes=5, dropout=0.1)

from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam



Loading…
Cancel
Save