@@ -68,7 +68,11 @@ class DataSetGetter:
                 else:
                     data = f.pad(vlist)
                 if not self.as_numpy:
-                    data, flag = _to_tensor(data, f.dtype)
+                    try:
+                        data, flag = _to_tensor(data, f.dtype)
+                    except TypeError as e:
+                        print(f"Field {n} cannot be converted to torch.tensor.")
+                        raise e
                 batch_dict[n] = data
         return batch_dict
@@ -173,15 +177,17 @@ class OnlineDataIter(BatchIter):
 def _to_tensor(batch, field_dtype):
     try:
-        if field_dtype is not None \
+        if field_dtype is not None and isinstance(field_dtype, type) \
                 and issubclass(field_dtype, Number) \
                 and not isinstance(batch, torch.Tensor):
             if issubclass(batch.dtype.type, np.floating):
                 new_batch = torch.as_tensor(batch).float()  # use float32 by default
+            elif issubclass(batch.dtype.type, np.integer):
+                new_batch = torch.as_tensor(batch).long()  # reuse the memory address, avoid copying
             else:
-                new_batch = torch.as_tensor(batch)  # reuse the memory address, avoid copying
+                new_batch = torch.as_tensor(batch)
             return new_batch, True
         else:
             return batch, False
-    except:
-        return batch, False
+    except Exception as e:
+        raise e
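
A minimal standalone sketch of the conversion rule the patched _to_tensor now applies (not the fastNLP function itself): numpy floating arrays become float32 tensors, integer arrays become int64 tensors, and anything else is handed to torch.as_tensor unchanged.

    import numpy as np
    import torch

    def to_tensor_sketch(batch):
        # numpy floating arrays -> float32 tensors (the library's default float type)
        if issubclass(batch.dtype.type, np.floating):
            return torch.as_tensor(batch).float()
        # numpy integer arrays -> int64 tensors, suitable for embedding/index lookups
        if issubclass(batch.dtype.type, np.integer):
            return torch.as_tensor(batch).long()
        # everything else is wrapped as-is; torch.as_tensor avoids a copy when it can
        return torch.as_tensor(batch)

    print(to_tensor_sketch(np.array([[1.0, 2.0]])).dtype)  # torch.float32
    print(to_tensor_sketch(np.array([[1, 2]])).dtype)      # torch.int64
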
@@ -395,6 +395,8 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
     :return:
     """
     if isinstance(cell, (str, Number, np.bool_)):
+        if hasattr(cell, 'dtype'):
+            return cell.dtype.type, dim
         return type(cell), dim
     elif isinstance(cell, list):
         dim += 1
@@ -412,7 +414,7 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
         return cell.dtype, cell.dim() + dim  # e.g. the result of torch.mean is a 0-dim tensor
     elif isinstance(cell, np.ndarray):
         if cell.dtype != np.dtype('O'):  # if the dtype is not object, the array is already well-formatted
-            return cell.dtype.type, cell.ndim + dim
+            return cell.dtype.type, cell.ndim + dim  # dtype.type returns np.int32, np.float, etc.
         # otherwise keep iterating into the elements
         dim += 1
         res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
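
The added hasattr(cell, 'dtype') branch matters for numpy scalars: they are instances of Number, but reporting their element type through .dtype.type keeps it consistent with what np.ndarray cells report. A small standalone sketch, not the library function:

    import numpy as np

    def ele_type(cell):
        if hasattr(cell, 'dtype'):   # numpy scalars such as np.float64(1.0) or np.int32(3)
            return cell.dtype.type
        return type(cell)            # plain Python int/float/str

    print(ele_type(np.float64(1.0)))  # <class 'numpy.float64'>
    print(ele_type(3))                # <class 'int'>
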
@@ -537,31 +539,30 @@ class AutoPadder(Padder):
         if field_ele_dtype:
             if dim>3:
                 return np.array(contents)
-            if isinstance(field_ele_dtype, np.dtype) or field_ele_dtype in (float, int, bool, str):
-                if isinstance(field_ele_dtype, np.number) or field_ele_dtype in (float, int, bool):
-                    if dim==0:
+            if isinstance(field_ele_dtype, type) and \
+                    (issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)):
+                if dim==0:
+                    array = np.array(contents, dtype=field_ele_dtype)
+                elif dim==1:
+                    max_len = max(map(len, contents))
+                    array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
+                    for i, content_i in enumerate(contents):
+                        array[i, :len(content_i)] = content_i
+                elif dim==2:
+                    max_len = max(map(len, contents))
+                    max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
+                                        content_i in contents])
+                    array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
+                    for i, content_i in enumerate(contents):
+                        for j, content_ii in enumerate(content_i):
+                            array[i, j, :len(content_ii)] = content_ii
+                else:
+                    shape = np.shape(contents)
+                    if len(shape)==4:  # i.e. every dimension has the same size
                         array = np.array(contents, dtype=field_ele_dtype)
-                    elif dim==1:
-                        max_len = max(map(len, contents))
-                        array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
-                        for i, content_i in enumerate(contents):
-                            array[i, :len(content_i)] = content_i
-                    elif dim==2:
-                        max_len = max(map(len, contents))
-                        max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
-                                            content_i in contents])
-                        array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
-                        for i, content_i in enumerate(contents):
-                            for j, content_ii in enumerate(content_i):
-                                array[i, j, :len(content_ii)] = content_ii
                     else:
-                        shape = np.shape(contents)
-                        if len(shape)==4:  # i.e. every dimension has the same size
-                            array = np.array(contents, dtype=field_ele_dtype)
-                        else:
-                            raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
-                    return array
-                return np.array(contents)
+                        raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
+                return array
         elif str(field_ele_dtype).startswith('torch'):
             if dim==0:
                 tensor = torch.tensor(contents).to(field_ele_dtype)
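
A toy run of the dim==1 branch above, with an assumed pad value of 0: variable-length rows are copied into a pre-filled (batch, max_len) array.

    import numpy as np

    contents = [[1, 2, 3], [4, 5]]
    pad_val = 0
    max_len = max(map(len, contents))
    array = np.full((len(contents), max_len), pad_val, dtype=np.int64)
    for i, content_i in enumerate(contents):
        array[i, :len(content_i)] = content_i
    print(array)  # [[1 2 3]
                  #  [4 5 0]]
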
@@ -99,7 +99,7 @@ class Tester(object):
         if isinstance(data, DataSet):
             self.data_iterator = DataSetIter(
-                dataset=data, batch_size=batch_size, num_workers=num_workers)
+                dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler())
         elif isinstance(data, BatchIter):
             self.data_iterator = data
         else:
@@ -831,7 +831,8 @@ class _WordBertModel(nn.Module):
         # +2 because [CLS] and [SEP] need to be added
         word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
         word_pieces[:, 0].fill_(self._cls_index)
-        word_pieces[torch.arange(batch_size).to(words), word_pieces_lengths+1] = self._sep_index
+        batch_indexes = torch.arange(batch_size).to(words)
+        word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
         attn_masks = torch.zeros_like(word_pieces)
         # 1. get the word_piece ids of the words, and the corresponding span ranges
         word_indexes = words.tolist()
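
A toy illustration of the advanced indexing used above (sizes and indices invented): pairing each row index with the position right after its last word piece writes [SEP] for the whole batch in one assignment.

    import torch

    sep_index = 102
    word_pieces = torch.zeros(3, 7, dtype=torch.long)       # batch_size x (max_word_piece_length + 2)
    word_pieces_lengths = torch.tensor([3, 5, 2])            # word pieces per sentence
    batch_indexes = torch.arange(3)
    word_pieces[batch_indexes, word_pieces_lengths + 1] = sep_index
    print(word_pieces[0])  # tensor([  0,   0,   0,   0, 102,   0,   0])
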
@@ -879,8 +880,8 @@ class _WordBertModel(nn.Module):
                     start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1]
                     outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
             if self.include_cls_sep:
-                outputs[:, :, 0] = output_layer[:, 0]
-                outputs[:, :, seq_len+s_shift] = output_layer[:, seq_len+s_shift]
+                outputs[l_index, :, 0] = output_layer[:, 0]
+                outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift]
         # 3. the final embedding result
         return outputs
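
A toy check of why the l_index is needed in the two assignments above (shapes invented): outputs holds one slice per requested layer, so writing with outputs[:, :, 0] would copy the current layer's [CLS] vector into every layer's slice instead of only layer l_index.

    import torch

    n_layers, batch_size, seq_len, hidden = 2, 3, 4, 5
    outputs = torch.zeros(n_layers, batch_size, seq_len, hidden)
    output_layer = torch.randn(batch_size, seq_len, hidden)
    l_index = 1
    outputs[l_index, :, 0] = output_layer[:, 0]   # only layer l_index receives the [CLS] vectors
    print(outputs[0].abs().sum().item())          # 0.0 -- layer 0 is untouched
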
@@ -73,12 +73,11 @@ class BertWordPieceEncoder(nn.Module):
         [CLS] and [SEP] will additionally be inserted at the beginning and the end, and the pad value of the word_pieces column is set to bert's pad value.

         :param datasets: DataSet objects
-        :param field_name: str, which column to index
+        :param field_name: the column whose content is used to generate the word_pieces column; each entry in this column should be a List[str].
         :return:
         """
         self.model.index_dataset(*datasets, field_name=field_name)

     def forward(self, word_pieces, token_type_ids=None):
         """
         Compute the bert embedding of the words. The passed-in words should already contain the [CLS] and [SEP] tags.
@@ -51,7 +51,7 @@ class Embedding(nn.Module):
         self.dropout = nn.Dropout(dropout)
         if not isinstance(self.embed, TokenEmbedding):
             self._embed_size = self.embed.weight.size(1)
-            if dropout_word>0 and isinstance(unk_index, int):
+            if dropout_word>0 and not isinstance(unk_index, int):
                 raise ValueError("When drop word is set, you need to pass in the unk_index.")
         else:
             self._embed_size = self.embed.embed_size
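
A minimal sketch of the corrected guard (the function name is illustrative): the error is now raised only when word dropout is requested without an unk_index, rather than the other way around.

    def check_word_dropout(dropout_word: float, unk_index=None):
        if dropout_word > 0 and not isinstance(unk_index, int):
            raise ValueError("When drop word is set, you need to pass in the unk_index.")

    check_word_dropout(0.1, unk_index=1)   # ok: word dropout requested and unk_index given
    check_word_dropout(0.0)                # ok: no word dropout, unk_index not needed
    # check_word_dropout(0.1)              # raises ValueError
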
@@ -512,7 +512,8 @@ class BertEmbedding(ContextualEmbedding):
     """
     Alias: :class:`fastNLP.modules.BertEmbedding`  :class:`fastNLP.modules.encoder.embedding.BertEmbedding`

-    An Embedding that uses BERT to encode words.
+    An Embedding that uses BERT to encode words. It is recommended to keep the length of the input words within 450 rather than
+    using the full 512, because the pretrained bert model is limited to 512 tokens and the input words are not yet split into word pieces, so the length may exceed the limit after splitting.

     Example::
@@ -523,7 +524,7 @@ class BertEmbedding(ContextualEmbedding):
     :param str model_dir_or_name: the directory containing the model, or the model name. Defaults to ``en-base-uncased``
     :param str layers: which layers form the final representation. Separate layer indices with ','; negative numbers index from the last layers
     :param str pool_method: in bert each word is represented by several word pieces; this controls how a word's representation
-        is computed from his word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
+        is computed from its word pieces. Supports ``last``, ``first``, ``avg``, ``max``.
     :param bool include_cls_sep: when bert computes the sentence representation, [CLS] and [SEP] need to be added at the front and the end; whether to keep these two tokens in the result. This
         makes the word embedding output two tokens longer than the input, which may cause problems when used with :class::StackEmbedding.
     :param bool requires_grad: whether gradient is required.
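
A hedged sketch of the four pool_method options over one word's word pieces (toy tensors, not the BertEmbedding internals):

    import torch

    pieces = torch.randn(3, 8)            # one word split into 3 word pieces, hidden size 8

    first = pieces[0]                     # pool_method='first'
    last = pieces[-1]                     # pool_method='last'
    avg = pieces.mean(dim=0)              # pool_method='avg'
    mx = pieces.max(dim=0).values         # pool_method='max'
    print(first.shape, last.shape, avg.shape, mx.shape)  # four vectors of size 8
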
@@ -673,8 +674,8 @@ class CNNCharEmbedding(TokenEmbedding):
         self.char_pad_index = self.char_vocab.padding_idx
         print(f"In total, there are {len(self.char_vocab)} distinct characters.")
         # index the vocab
-        self.max_word_len = max(map(lambda x: len(x[0]), vocab))
-        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
+        max_word_len = max(map(lambda x: len(x[0]), vocab))
+        self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
                                                                 fill_value=self.char_pad_index, dtype=torch.long),
                                                      requires_grad=False)
         self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
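
A toy version of the lookup table built above (the vocabulary and pad index are invented): each row holds the character ids of one word, padded with char_pad_index up to the longest word.

    import torch

    vocab = ['cat', 'horse']
    char2id = {c: i + 1 for i, c in enumerate(sorted(set(''.join(vocab))))}  # 0 is reserved for padding
    char_pad_index = 0
    max_word_len = max(map(len, vocab))
    words_to_chars = torch.full((len(vocab), max_word_len), char_pad_index, dtype=torch.long)
    for w_idx, word in enumerate(vocab):
        words_to_chars[w_idx, :len(word)] = torch.tensor([char2id[c] for c in word])
    print(words_to_chars)  # shape (2, 5); row 0 ends with two pad ids
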
@@ -707,7 +708,7 @@ class CNNCharEmbedding(TokenEmbedding):
         # positions with value 1 are the mask
         chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len; if it is 0, the position is padding
         chars = self.char_embedding(chars)  # batch_size x max_len x max_word_len x embed_size
-        chars = self.dropout(chars)
+        self.dropout(chars)
         reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1)
         reshaped_chars = reshaped_chars.transpose(1, 2)  # B' x E x M
         conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
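
A sketch of the reshape/transpose pattern in the lines above (sizes invented): all words in the batch are flattened so that a Conv1d can run over the character dimension, then the output is folded back into batch_size x max_len x max_word_len x filters.

    import torch
    import torch.nn as nn

    batch_size, max_len, max_word_len, embed_size = 2, 4, 5, 16
    chars = torch.randn(batch_size, max_len, max_word_len, embed_size)
    reshaped = chars.reshape(batch_size * max_len, max_word_len, -1).transpose(1, 2)  # B' x E x M
    conv = nn.Conv1d(embed_size, 30, kernel_size=3, padding=1)
    out = conv(reshaped).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
    print(out.shape)  # torch.Size([2, 4, 5, 30])
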