Browse Source

[update]对最近几个issue的更新

tags/v0.5.5
Yige Xu 5 years ago
parent
commit
28ddc0c17e
3 changed files with 7 additions and 2 deletions
  1. +5
    -0
      fastNLP/core/dataset.py
  2. +1
    -1
      fastNLP/core/vocabulary.py
  3. +1
    -1
      fastNLP/modules/encoder/attention.py

+ 5
- 0
fastNLP/core/dataset.py View File

@@ -908,12 +908,17 @@ class DataSet(object):
     :param bool shuffle: 在split前是否shuffle一下
     :return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ]
     """
+    assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.'
     assert isinstance(ratio, float)
     assert 0 < ratio < 1
     all_indices = [_ for _ in range(len(self))]
     if shuffle:
         np.random.shuffle(all_indices)
     split = int(ratio * len(self))
+    if split == 0:
+        error_msg = f'Dev DataSet has {split} instance after split.'
+        logger.error(error_msg)
+        raise IndexError(error_msg)
     dev_indices = all_indices[:split]
     train_indices = all_indices[split:]
     dev_set = DataSet()


+ 1
- 1
fastNLP/core/vocabulary.py View File

@@ -209,7 +209,7 @@ class Vocabulary(object):
     self._word2idx = {}
     if self.padding is not None:
         self._word2idx[self.padding] = len(self._word2idx)
-    if self.unknown is not None:
+    if (self.unknown is not None) and (self.unknown != self.padding):
         self._word2idx[self.unknown] = len(self._word2idx)
     max_size = min(self.max_size, len(self.word_count)) if self.max_size else None


+ 1
- 1
fastNLP/modules/encoder/attention.py View File

@@ -26,7 +26,7 @@ class DotAttention(nn.Module):
     self.value_size = value_size
     self.scale = math.sqrt(key_size)
     self.drop = nn.Dropout(dropout)
-    self.softmax = nn.Softmax(dim=2)
+    self.softmax = nn.Softmax(dim=-1)

 def forward(self, Q, K, V, mask_out=None):
     """


Loading…
Cancel
Save