From 28ddc0c17edce74ece3f6e2deee1db074c74c59e Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Thu, 12 Dec 2019 17:00:48 +0800
Subject: [PATCH] [update] Updates for several recent issues
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataset.py              | 5 +++++
 fastNLP/core/vocabulary.py           | 2 +-
 fastNLP/modules/encoder/attention.py | 2 +-
 3 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 12963616..0a24ab22 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -908,12 +908,17 @@ class DataSet(object):
         :param bool shuffle: whether to shuffle before splitting
         :return: [ :class:`~fastNLP.DataSet` , :class:`~fastNLP.DataSet` ]
         """
+        assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.'
         assert isinstance(ratio, float)
         assert 0 < ratio < 1
         all_indices = [_ for _ in range(len(self))]
         if shuffle:
             np.random.shuffle(all_indices)
         split = int(ratio * len(self))
+        if split == 0:
+            error_msg = f'Dev DataSet has {split} instance after split.'
+            logger.error(error_msg)
+            raise IndexError(error_msg)
         dev_indices = all_indices[:split]
         train_indices = all_indices[split:]
         dev_set = DataSet()
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 6d530eb6..3456061f 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -209,7 +209,7 @@ class Vocabulary(object):
         self._word2idx = {}
         if self.padding is not None:
             self._word2idx[self.padding] = len(self._word2idx)
-        if self.unknown is not None:
+        if (self.unknown is not None) and (self.unknown != self.padding):
             self._word2idx[self.unknown] = len(self._word2idx)
 
         max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
diff --git a/fastNLP/modules/encoder/attention.py b/fastNLP/modules/encoder/attention.py
index b48be579..49a860bd 100644
--- a/fastNLP/modules/encoder/attention.py
+++ b/fastNLP/modules/encoder/attention.py
@@ -26,7 +26,7 @@ class DotAttention(nn.Module):
         self.value_size = value_size
         self.scale = math.sqrt(key_size)
         self.drop = nn.Dropout(dropout)
-        self.softmax = nn.Softmax(dim=2)
+        self.softmax = nn.Softmax(dim=-1)
 
     def forward(self, Q, K, V, mask_out=None):
         """
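
The dataset.py change makes DataSet.split() fail fast instead of silently producing an empty dev set. A minimal sketch of the new behaviour, assuming fastNLP's public DataSet/Instance API (the raw_words field name is illustrative):

    from fastNLP import DataSet, Instance

    ds = DataSet()
    for i in range(3):
        ds.append(Instance(raw_words=f"sentence {i}"))

    # int(0.1 * 3) == 0, so the dev split would be empty;
    # the patched code raises instead of returning an empty DataSet.
    try:
        train_set, dev_set = ds.split(0.1)
    except IndexError as e:
        print(e)  # Dev DataSet has 0 instance after split.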
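The vocabulary.py change handles a Vocabulary whose unknown token is the same string as its padding token. Previously both branches wrote the token into _word2idx, overwriting its index with 1 so that index 0 went unused and the first real word collided with it; now the shared token is registered exactly once. A short sketch of the fixed behaviour (the shared token name is illustrative):

    from fastNLP import Vocabulary

    # One token doubling as both padding and unknown.
    vocab = Vocabulary(padding='<pad>', unknown='<pad>')
    vocab.add_word_lst(['apple', 'banana'])
    vocab.build_vocab()

    print(vocab.to_index('<pad>'))  # 0 -- registered exactly once
    print(vocab.to_index('apple'))  # 1 -- no longer collides with '<pad>'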
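The attention.py change moves DotAttention's softmax to the last axis. For 3-D score tensors of shape (batch, q_len, k_len), dim=2 and dim=-1 coincide, but dim=-1 also stays correct for higher-rank inputs such as per-head multi-head scores, where dim=2 would normalize over the wrong axis. A small self-contained illustration in plain PyTorch:

    import torch
    import torch.nn as nn

    scores = torch.randn(2, 4, 5, 7)   # (batch, heads, q_len, k_len)

    attn = nn.Softmax(dim=-1)(scores)  # normalizes over the 7 keys
    print(attn.sum(dim=-1))            # all ones, as attention weights should be

    wrong = nn.Softmax(dim=2)(scores)  # normalizes over q_len instead
    print(wrong.sum(dim=-1))           # not ones -- the weights are broken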