Browse Source

- fix test_tutorial

tags/v0.4.10
yunfan 6 years ago
parent
commit
abff8d9daa
2 changed files with 9 additions and 1 deletions
  1. +1
    -1
      fastNLP/models/snli.py
  2. +8
    -0
      test/test_tutorials.py

+ 1
- 1
fastNLP/models/snli.py View File

@@ -110,5 +110,5 @@ class ESIM(BaseModel):

def predict(self, words1, words2, seq_len1, seq_len2):
prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred']
return torch.argmax(prediction, dim=-1)
return {'pred': torch.argmax(prediction, dim=-1)}


+ 8
- 0
test/test_tutorials.py View File

@@ -379,6 +379,14 @@ class TestTutorial(unittest.TestCase):
dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')
train_data_2[-1], dev_data_2[-1]

for data in [train_data, dev_data, test_data]:
data.rename_field('premise', 'words1')
data.rename_field('hypothesis', 'words2')
data.rename_field('premise_len', 'seq_len1')
data.rename_field('hypothesis_len', 'seq_len2')
data.set_input('words1', 'words2', 'seq_len1', 'seq_len2')


# step 1:加载模型参数(非必选)
from fastNLP.io.config_io import ConfigSection, ConfigLoader
args = ConfigSection()


Loading…
Cancel
Save