@@ -908,12 +908,17 @@ class DataSet(object): | |||||
:param bool shuffle: 在split前是否shuffle一下 | :param bool shuffle: 在split前是否shuffle一下 | ||||
:return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] | :return: [ :class:`~fastNLP.读取后的DataSet` , :class:`~fastNLP.读取后的DataSet` ] | ||||
""" | """ | ||||
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' | |||||
assert isinstance(ratio, float) | assert isinstance(ratio, float) | ||||
assert 0 < ratio < 1 | assert 0 < ratio < 1 | ||||
all_indices = [_ for _ in range(len(self))] | all_indices = [_ for _ in range(len(self))] | ||||
if shuffle: | if shuffle: | ||||
np.random.shuffle(all_indices) | np.random.shuffle(all_indices) | ||||
split = int(ratio * len(self)) | split = int(ratio * len(self)) | ||||
if split == 0: | |||||
error_msg = f'Dev DataSet has {split} instance after split.' | |||||
logger.error(error_msg) | |||||
raise IndexError(error_msg) | |||||
dev_indices = all_indices[:split] | dev_indices = all_indices[:split] | ||||
train_indices = all_indices[split:] | train_indices = all_indices[split:] | ||||
dev_set = DataSet() | dev_set = DataSet() | ||||
@@ -209,7 +209,7 @@ class Vocabulary(object): | |||||
self._word2idx = {} | self._word2idx = {} | ||||
if self.padding is not None: | if self.padding is not None: | ||||
self._word2idx[self.padding] = len(self._word2idx) | self._word2idx[self.padding] = len(self._word2idx) | ||||
if self.unknown is not None: | |||||
if (self.unknown is not None) and (self.unknown != self.padding): | |||||
self._word2idx[self.unknown] = len(self._word2idx) | self._word2idx[self.unknown] = len(self._word2idx) | ||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | ||||
@@ -26,7 +26,7 @@ class DotAttention(nn.Module): | |||||
self.value_size = value_size | self.value_size = value_size | ||||
self.scale = math.sqrt(key_size) | self.scale = math.sqrt(key_size) | ||||
self.drop = nn.Dropout(dropout) | self.drop = nn.Dropout(dropout) | ||||
self.softmax = nn.Softmax(dim=2) | |||||
self.softmax = nn.Softmax(dim=-1) | |||||
def forward(self, Q, K, V, mask_out=None): | def forward(self, Q, K, V, mask_out=None): | ||||
""" | """ | ||||