Browse Source

1.修复BertEmbedding中的bug; 2. 修复Batch, Field在进行类型转换时的bug

tags/v0.4.10
yh_cc 6 years ago
parent
commit
4b0c26d338
6 changed files with 49 additions and 41 deletions
  1. +11
    -5
      fastNLP/core/batch.py
  2. +25
    -24
      fastNLP/core/field.py
  3. +1
    -1
      fastNLP/core/tester.py
  4. +4
    -3
      fastNLP/modules/encoder/_bert.py
  5. +1
    -2
      fastNLP/modules/encoder/bert.py
  6. +7
    -6
      fastNLP/modules/encoder/embedding.py

+ 11
- 5
fastNLP/core/batch.py View File

@@ -68,7 +68,11 @@ class DataSetGetter:
else: else:
data = f.pad(vlist) data = f.pad(vlist)
if not self.as_numpy: if not self.as_numpy:
data, flag = _to_tensor(data, f.dtype)
try:
data, flag = _to_tensor(data, f.dtype)
except TypeError as e:
print(f"Field {n} cannot be converted to torch.tensor.")
raise e
batch_dict[n] = data batch_dict[n] = data
return batch_dict return batch_dict


@@ -173,15 +177,17 @@ class OnlineDataIter(BatchIter):


def _to_tensor(batch, field_dtype): def _to_tensor(batch, field_dtype):
try: try:
if field_dtype is not None \
if field_dtype is not None and isinstance(field_dtype, type)\
and issubclass(field_dtype, Number) \ and issubclass(field_dtype, Number) \
and not isinstance(batch, torch.Tensor): and not isinstance(batch, torch.Tensor):
if issubclass(batch.dtype.type, np.floating): if issubclass(batch.dtype.type, np.floating):
new_batch = torch.as_tensor(batch).float() # 默认使用float32 new_batch = torch.as_tensor(batch).float() # 默认使用float32
elif issubclass(batch.dtype.type, np.integer):
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制
else: else:
new_batch = torch.as_tensor(batch) # 复用内存地址,避免复制
new_batch = torch.as_tensor(batch)
return new_batch, True return new_batch, True
else: else:
return batch, False return batch, False
except:
return batch, False
except Exception as e:
raise e

+ 25
- 24
fastNLP/core/field.py View File

@@ -395,6 +395,8 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
:return: :return:
""" """
if isinstance(cell, (str, Number, np.bool_)): if isinstance(cell, (str, Number, np.bool_)):
if hasattr(cell, 'dtype'):
return cell.dtype.type, dim
return type(cell), dim return type(cell), dim
elif isinstance(cell, list): elif isinstance(cell, list):
dim += 1 dim += 1
@@ -412,7 +414,7 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0 return cell.dtype, cell.dim() + dim # 如果是torch.mean的结果是0
elif isinstance(cell, np.ndarray): elif isinstance(cell, np.ndarray):
if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了
return cell.dtype.type, cell.ndim + dim
return cell.dtype.type, cell.ndim + dim # dtype.type返回的会是np.int32, np.float等
# 否则需要继续往下iterate # 否则需要继续往下iterate
dim += 1 dim += 1
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
@@ -537,31 +539,30 @@ class AutoPadder(Padder):
if field_ele_dtype: if field_ele_dtype:
if dim>3: if dim>3:
return np.array(contents) return np.array(contents)
if isinstance(field_ele_dtype, np.dtype) or field_ele_dtype in (float, int, bool, str):
if isinstance(field_ele_dtype, np.number) or field_ele_dtype in (float, int, bool):
if dim==0:
if isinstance(field_ele_dtype, type) and \
(issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)):
if dim==0:
array = np.array(contents, dtype=field_ele_dtype)
elif dim==1:
max_len = max(map(len, contents))
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
array[i, :len(content_i)] = content_i
elif dim==2:
max_len = max(map(len, contents))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in contents])
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
for j, content_ii in enumerate(content_i):
array[i, j, :len(content_ii)] = content_ii
else:
shape = np.shape(contents)
if len(shape)==4: # 说明各dimension是相同的大小
array = np.array(contents, dtype=field_ele_dtype) array = np.array(contents, dtype=field_ele_dtype)
elif dim==1:
max_len = max(map(len, contents))
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
array[i, :len(content_i)] = content_i
elif dim==2:
max_len = max(map(len, contents))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in contents])
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
for j, content_ii in enumerate(content_i):
array[i, j, :len(content_ii)] = content_ii
else: else:
shape = np.shape(contents)
if len(shape)==4: # 说明各dimension是相同的大小
array = np.array(contents, dtype=field_ele_dtype)
else:
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return array
return np.array(contents)
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return array
elif str(field_ele_dtype).startswith('torch'): elif str(field_ele_dtype).startswith('torch'):
if dim==0: if dim==0:
tensor = torch.tensor(contents).to(field_ele_dtype) tensor = torch.tensor(contents).to(field_ele_dtype)


+ 1
- 1
fastNLP/core/tester.py View File

@@ -99,7 +99,7 @@ class Tester(object):


if isinstance(data, DataSet): if isinstance(data, DataSet):
self.data_iterator = DataSetIter( self.data_iterator = DataSetIter(
dataset=data, batch_size=batch_size, num_workers=num_workers)
dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler())
elif isinstance(data, BatchIter): elif isinstance(data, BatchIter):
self.data_iterator = data self.data_iterator = data
else: else:


+ 4
- 3
fastNLP/modules/encoder/_bert.py View File

@@ -831,7 +831,8 @@ class _WordBertModel(nn.Module):
# +2是由于需要加入[CLS]与[SEP] # +2是由于需要加入[CLS]与[SEP]
word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index)
word_pieces[:, 0].fill_(self._cls_index) word_pieces[:, 0].fill_(self._cls_index)
word_pieces[torch.arange(batch_size).to(words), word_pieces_lengths+1] = self._sep_index
batch_indexes = torch.arange(batch_size).to(words)
word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
attn_masks = torch.zeros_like(word_pieces) attn_masks = torch.zeros_like(word_pieces)
# 1. 获取words的word_pieces的id,以及对应的span范围 # 1. 获取words的word_pieces的id,以及对应的span范围
word_indexes = words.tolist() word_indexes = words.tolist()
@@ -879,8 +880,8 @@ class _WordBertModel(nn.Module):
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1]
outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
if self.include_cls_sep: if self.include_cls_sep:
outputs[:, :, 0] = output_layer[:, 0]
outputs[:, :, seq_len+s_shift] = output_layer[:, seq_len+s_shift]
outputs[l_index, :, 0] = output_layer[:, 0]
outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift]
# 3. 最终的embedding结果 # 3. 最终的embedding结果
return outputs return outputs




+ 1
- 2
fastNLP/modules/encoder/bert.py View File

@@ -73,12 +73,11 @@ class BertWordPieceEncoder(nn.Module):
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 [CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。


:param datasets: DataSet对象 :param datasets: DataSet对象
:param field_name: str基于哪一列index
:param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
:return: :return:
""" """
self.model.index_dataset(*datasets, field_name=field_name) self.model.index_dataset(*datasets, field_name=field_name)



def forward(self, word_pieces, token_type_ids=None): def forward(self, word_pieces, token_type_ids=None):
""" """
计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。 计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。


+ 7
- 6
fastNLP/modules/encoder/embedding.py View File

@@ -51,7 +51,7 @@ class Embedding(nn.Module):
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
if not isinstance(self.embed, TokenEmbedding): if not isinstance(self.embed, TokenEmbedding):
self._embed_size = self.embed.weight.size(1) self._embed_size = self.embed.weight.size(1)
if dropout_word>0 and isinstance(unk_index, int):
if dropout_word>0 and not isinstance(unk_index, int):
raise ValueError("When drop word is set, you need to pass in the unk_index.") raise ValueError("When drop word is set, you need to pass in the unk_index.")
else: else:
self._embed_size = self.embed.embed_size self._embed_size = self.embed.embed_size
@@ -512,7 +512,8 @@ class BertEmbedding(ContextualEmbedding):
""" """
别名::class:`fastNLP.modules.BertEmbedding` :class:`fastNLP.modules.encoder.embedding.BertEmbedding` 别名::class:`fastNLP.modules.BertEmbedding` :class:`fastNLP.modules.encoder.embedding.BertEmbedding`


使用BERT对words进行encode的Embedding。
使用BERT对words进行encode的Embedding。建议将输入的words长度限制在450以内,而不要使用512。这是由于预训练的bert模型长
度限制为512个token,而因为输入的word是未进行word piece分割的,在分割之后长度可能会超过最大长度限制。


Example:: Example::


@@ -523,7 +524,7 @@ class BertEmbedding(ContextualEmbedding):
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased``
:param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层
:param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
中计算得到他对应的表示。支持``last``, ``first``, ``avg``, ``max``.
中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。
:param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样
会使得word embedding的结果比输入的结果长两个token。在使用 :class::StackEmbedding 可能会遇到问题。 会使得word embedding的结果比输入的结果长两个token。在使用 :class::StackEmbedding 可能会遇到问题。
:param bool requires_grad: 是否需要gradient。 :param bool requires_grad: 是否需要gradient。
@@ -673,8 +674,8 @@ class CNNCharEmbedding(TokenEmbedding):
self.char_pad_index = self.char_vocab.padding_idx self.char_pad_index = self.char_vocab.padding_idx
print(f"In total, there are {len(self.char_vocab)} distinct characters.") print(f"In total, there are {len(self.char_vocab)} distinct characters.")
# 对vocab进行index # 对vocab进行index
self.max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
max_word_len = max(map(lambda x: len(x[0]), vocab))
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
fill_value=self.char_pad_index, dtype=torch.long), fill_value=self.char_pad_index, dtype=torch.long),
requires_grad=False) requires_grad=False)
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
@@ -707,7 +708,7 @@ class CNNCharEmbedding(TokenEmbedding):
# 为1的地方为mask # 为1的地方为mask
chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
chars = self.dropout(chars)
self.dropout(chars)
reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1) reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1)
reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M
conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)


Loading…
Cancel
Save