@@ -122,7 +122,7 @@ class SamplerAdapter(torch.utils.data.Sampler): | |||
class BatchIter: | |||
""" | |||
Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), num_batches(), __iter__()方法以及属性。 | |||
Trainer用于迭代数据的类。继承该类,并实现get_num_batches(), get_batch_indices(), num_batches(), __iter__()方法以及dataset属性。 | |||
""" | |||
def __init__(self, dataset, batch_size=1, sampler=None, | |||
@@ -259,8 +259,8 @@ class DataSetIter(BatchIter): | |||
class TorchLoaderIter(BatchIter): | |||
""" | |||
与DataSetIter类似,但可以用于pytorch的DataSet对象。可以通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 | |||
或者也可以传入任何实现了类似以下方法的对象 | |||
与DataSetIter类似,但可以用于非fastNLP的数据容器对象,然后将其传入到Trainer中。 | |||
只需要保证数据容器实现了实现了以下的方法 | |||
Example:: | |||
@@ -287,7 +287,7 @@ class TorchLoaderIter(BatchIter): | |||
x, y = l | |||
xs.append(x) | |||
ys.append(y) | |||
# 不需要转移到gpu,Trainer和Tester会将其转移到model所在的device | |||
# 不需要转移到gpu,Trainer或Tester会将其转移到model所在的device | |||
x,y = torch.FloatTensor(xs), torch.FloatTensor(ys) | |||
return {'x':x, 'y':y}, {'y':y} | |||
@@ -620,7 +620,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||
:param args: | |||
:return: | |||
""" | |||
if not torch.cuda.is_available(): | |||
if not torch.cuda.is_available() or device is None: | |||
return | |||
if not isinstance(device, torch.device): | |||
@@ -56,7 +56,7 @@ class BertEmbedding(ContextualEmbedding): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | |||
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | |||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False): | |||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False, **kwargs): | |||
""" | |||
:param ~fastNLP.Vocabulary vocab: 词表 | |||
@@ -77,6 +77,9 @@ class BertEmbedding(ContextualEmbedding): | |||
:param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个 | |||
word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | |||
来进行分类的任务将auto_truncate置为True。 | |||
:param kwargs: | |||
bool only_use_pretrain_bpe: 仅使用出现在pretrain词表中的bpe,如果该词没法tokenize则使用unk。如果embedding不需要更新 | |||
建议设置为True。 | |||
""" | |||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
@@ -93,10 +96,13 @@ class BertEmbedding(ContextualEmbedding): | |||
self._word_sep_index = None | |||
if '[SEP]' in vocab: | |||
self._word_sep_index = vocab['[SEP]'] | |||
only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False) | |||
self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2, | |||
only_use_pretrain_bpe=only_use_pretrain_bpe) | |||
self._sep_index = self.model._sep_index | |||
self._cls_index = self.model._cls_index | |||
self.requires_grad = requires_grad | |||
@@ -254,7 +260,8 @@ class BertWordPieceEncoder(nn.Module): | |||
class _WordBertModel(nn.Module): | |||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): | |||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2, | |||
only_use_pretrain_bpe=False): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | |||
@@ -302,7 +309,7 @@ class _WordBertModel(nn.Module): | |||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||
if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面 | |||
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||
word): # 出现次数大于这个次数才新增 | |||
word) and not only_use_pretrain_bpe: # 出现次数大于这个次数才新增 | |||
word_piece_dict[word] = 1 # 新增一个值 | |||
continue | |||
for word_piece in word_pieces: | |||
@@ -88,8 +88,8 @@ class StaticEmbedding(TokenEmbedding): | |||
:param dict kwargs: | |||
bool only_train_min_freq: 仅对train中的词语使用min_freq筛选; | |||
bool only_norm_found_vector: 是否仅对在预训练中找到的词语使用normalize; | |||
bool only_use_pretrain_word: 仅使用出现在pretrain词表中的词语。如果该词没有在预训练的词表中出现则为unk。如果词表 | |||
不需要更新设置为True。 | |||
bool only_use_pretrain_word: 仅使用出现在pretrain词表中的词,如果该词没有在预训练的词表中出现则为unk。如果 | |||
embedding不需要更新建议设置为True。 | |||
""" | |||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
if embedding_dim > 0: | |||
@@ -605,7 +605,7 @@ class BertModel(nn.Module): | |||
logger.warning("Weights of {} not initialized from pretrained model: {}".format( | |||
model.__class__.__name__, missing_keys)) | |||
if len(unexpected_keys) > 0: | |||
logger.warning("Weights from pretrained model not used in {}: {}".format( | |||
logger.debug("Weights from pretrained model not used in {}: {}".format( | |||
model.__class__.__name__, unexpected_keys)) | |||
logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") | |||
@@ -29,7 +29,7 @@ class TestDownload(unittest.TestCase): | |||
class TestBertEmbedding(unittest.TestCase): | |||
def test_bert_embedding_1(self): | |||
vocab = Vocabulary().add_word_lst("this is a test . [SEP]".split()) | |||
vocab = Vocabulary().add_word_lst("this is a test . [SEP] NotInBERT".split()) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1) | |||
requires_grad = embed.requires_grad | |||
embed.requires_grad = not requires_grad | |||
@@ -38,6 +38,13 @@ class TestBertEmbedding(unittest.TestCase): | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) | |||
embed = BertEmbedding(vocab, model_dir_or_name='test/data_for_tests/embedding/small_bert', word_dropout=0.1, | |||
only_use_pretrain_bpe=True) | |||
embed.eval() | |||
words = torch.LongTensor([[2, 3, 4, 0]]) | |||
result = embed(words) | |||
self.assertEqual(result.size(), (1, 4, 16)) | |||
class TestBertWordPieceEncoder(unittest.TestCase): | |||
def test_bert_word_piece_encoder(self): | |||