diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 89b55a25..ca48a8e1 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -68,7 +68,11 @@ class DataSetGetter: else: data = f.pad(vlist) if not self.as_numpy: - data, flag = _to_tensor(data, f.dtype) + try: + data, flag = _to_tensor(data, f.dtype) + except TypeError as e: + print(f"Field {n} cannot be converted to torch.tensor.") + raise e batch_dict[n] = data return batch_dict @@ -173,15 +177,17 @@ class OnlineDataIter(BatchIter): def _to_tensor(batch, field_dtype): try: - if field_dtype is not None \ + if field_dtype is not None and isinstance(field_dtype, type)\ and issubclass(field_dtype, Number) \ and not isinstance(batch, torch.Tensor): if issubclass(batch.dtype.type, np.floating): new_batch = torch.as_tensor(batch).float() # 默认使用float32 + elif issubclass(batch.dtype.type, np.integer): + new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 else: - new_batch = torch.as_tensor(batch) # 复用内存地址,避免复制 + new_batch = torch.as_tensor(batch) return new_batch, True else: return batch, False - except: - return batch, False + except Exception as e: + raise e diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 1c0ad235..65eb0194 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -395,6 +395,8 @@ def _get_ele_type_and_dim(cell:Any, dim=0): :return: """ if isinstance(cell, (str, Number, np.bool_)): + if hasattr(cell, 'dtype'): + return cell.dtype.type, dim return type(cell), dim elif isinstance(cell, list): dim += 1 @@ -412,7 +414,7 @@ def _get_ele_type_and_dim(cell:Any, dim=0): return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0 elif isinstance(cell, np.ndarray): if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 - return cell.dtype.type, cell.ndim + dim + return cell.dtype.type, cell.ndim + dim # dtype.type返回的会是np.int32, np.float等 # 否则需要继续往下iterate dim += 1 res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] @@ -537,31 +539,30 @@ class AutoPadder(Padder): if field_ele_dtype: if dim>3: return np.array(contents) - if isinstance(field_ele_dtype, np.dtype) or field_ele_dtype in (float, int, bool, str): - if isinstance(field_ele_dtype, np.number) or field_ele_dtype in (float, int, bool): - if dim==0: + if isinstance(field_ele_dtype, type) and \ + (issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)): + if dim==0: + array = np.array(contents, dtype=field_ele_dtype) + elif dim==1: + max_len = max(map(len, contents)) + array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + array[i, :len(content_i)] = content_i + elif dim==2: + max_len = max(map(len, contents)) + max_word_len = max([max([len(content_ii) for content_ii in content_i]) for + content_i in contents]) + array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) + for i, content_i in enumerate(contents): + for j, content_ii in enumerate(content_i): + array[i, j, :len(content_ii)] = content_ii + else: + shape = np.shape(contents) + if len(shape)==4: # 说明各dimension是相同的大小 array = np.array(contents, dtype=field_ele_dtype) - elif dim==1: - max_len = max(map(len, contents)) - array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) - for i, content_i in enumerate(contents): - array[i, :len(content_i)] = content_i - elif dim==2: - max_len = max(map(len, contents)) - max_word_len = max([max([len(content_ii) for content_ii in content_i]) for - content_i in contents]) - array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) - for i, content_i in enumerate(contents): - for j, content_ii in enumerate(content_i): - array[i, j, :len(content_ii)] = content_ii else: - shape = np.shape(contents) - if len(shape)==4: # 说明各dimension是相同的大小 - array = np.array(contents, dtype=field_ele_dtype) - else: - raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") - return array - return np.array(contents) + raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") + return array elif str(field_ele_dtype).startswith('torch'): if dim==0: tensor = torch.tensor(contents).to(field_ele_dtype) diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 398afe6b..536279de 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -99,7 +99,7 @@ class Tester(object): if isinstance(data, DataSet): self.data_iterator = DataSetIter( - dataset=data, batch_size=batch_size, num_workers=num_workers) + dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler()) elif isinstance(data, BatchIter): self.data_iterator = data else: diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py index a0353279..254917e5 100644 --- a/fastNLP/modules/encoder/_bert.py +++ b/fastNLP/modules/encoder/_bert.py @@ -831,7 +831,8 @@ class _WordBertModel(nn.Module): # +2是由于需要加入[CLS]与[SEP] word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) word_pieces[:, 0].fill_(self._cls_index) - word_pieces[torch.arange(batch_size).to(words), word_pieces_lengths+1] = self._sep_index + batch_indexes = torch.arange(batch_size).to(words) + word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index attn_masks = torch.zeros_like(word_pieces) # 1. 获取words的word_pieces的id,以及对应的span范围 word_indexes = words.tolist() @@ -879,8 +880,8 @@ class _WordBertModel(nn.Module): start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) if self.include_cls_sep: - outputs[:, :, 0] = output_layer[:, 0] - outputs[:, :, seq_len+s_shift] = output_layer[:, seq_len+s_shift] + outputs[l_index, :, 0] = output_layer[:, 0] + outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift] # 3. 最终的embedding结果 return outputs diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 2ddb37ff..757973fe 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -73,12 +73,11 @@ class BertWordPieceEncoder(nn.Module): [CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 :param datasets: DataSet对象 - :param field_name: str基于哪一列index + :param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 :return: """ self.model.index_dataset(*datasets, field_name=field_name) - def forward(self, word_pieces, token_type_ids=None): """ 计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。 diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index 46e393b1..810b909f 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -51,7 +51,7 @@ class Embedding(nn.Module): self.dropout = nn.Dropout(dropout) if not isinstance(self.embed, TokenEmbedding): self._embed_size = self.embed.weight.size(1) - if dropout_word>0 and isinstance(unk_index, int): + if dropout_word>0 and not isinstance(unk_index, int): raise ValueError("When drop word is set, you need to pass in the unk_index.") else: self._embed_size = self.embed.embed_size @@ -512,7 +512,8 @@ class BertEmbedding(ContextualEmbedding): """ 别名::class:`fastNLP.modules.BertEmbedding` :class:`fastNLP.modules.encoder.embedding.BertEmbedding` - 使用BERT对words进行encode的Embedding。 + 使用BERT对words进行encode的Embedding。建议将输入的words长度限制在450以内,而不要使用512。这是由于预训练的bert模型长 + 度限制为512个token,而因为输入的word是未进行word piece分割的,在分割之后长度可能会超过最大长度限制。 Example:: @@ -523,7 +524,7 @@ class BertEmbedding(ContextualEmbedding): :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces - 中计算得到他对应的表示。支持``last``, ``first``, ``avg``, ``max``. + 中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。 :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 会使得word embedding的结果比输入的结果长两个token。在使用 :class::StackEmbedding 可能会遇到问题。 :param bool requires_grad: 是否需要gradient。 @@ -673,8 +674,8 @@ class CNNCharEmbedding(TokenEmbedding): self.char_pad_index = self.char_vocab.padding_idx print(f"In total, there are {len(self.char_vocab)} distinct characters.") # 对vocab进行index - self.max_word_len = max(map(lambda x: len(x[0]), vocab)) - self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len), + max_word_len = max(map(lambda x: len(x[0]), vocab)) + self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len), fill_value=self.char_pad_index, dtype=torch.long), requires_grad=False) self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) @@ -707,7 +708,7 @@ class CNNCharEmbedding(TokenEmbedding): # 为1的地方为mask chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size - chars = self.dropout(chars) + self.dropout(chars) reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1) reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)