diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py index 5816d2af..901f2dd4 100644 --- a/fastNLP/models/snli.py +++ b/fastNLP/models/snli.py @@ -110,5 +110,5 @@ class ESIM(BaseModel): def predict(self, words1, words2, seq_len1, seq_len2): prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred'] - return torch.argmax(prediction, dim=-1) + return {'pred': torch.argmax(prediction, dim=-1)} diff --git a/test/test_tutorials.py b/test/test_tutorials.py index bc0b5d2b..600699a3 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -379,6 +379,14 @@ class TestTutorial(unittest.TestCase): dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') train_data_2[-1], dev_data_2[-1] + for data in [train_data, dev_data, test_data]: + data.rename_field('premise', 'words1') + data.rename_field('hypothesis', 'words2') + data.rename_field('premise_len', 'seq_len1') + data.rename_field('hypothesis_len', 'seq_len2') + data.set_input('words1', 'words2', 'seq_len1', 'seq_len2') + + # step 1:加载模型参数(非必选) from fastNLP.io.config_io import ConfigSection, ConfigLoader args = ConfigSection()