
1. Enhance BertEmbedding so that it can automatically infer token_type_ids; 2. Add an error in CrossEntropyLoss for predictions with the wrong label dimension

tags/v0.4.10
yh 6 years ago
parent commit 7a21c2a587
5 changed files with 72 additions and 44 deletions
  1. fastNLP/core/losses.py (+3, -1)
  2. fastNLP/core/trainer.py (+1, -1)
  3. fastNLP/embeddings/bert_embedding.py (+15, -2)
  4. reproduction/joint_cws_parse/models/CharParser.py (+4, -4)
  5. reproduction/joint_cws_parse/train.py (+49, -36)

fastNLP/core/losses.py (+3, -1)

@@ -28,6 +28,7 @@ from .utils import _check_arg_dict_list
 from .utils import _check_function_or_method
 from .utils import _get_func_signature
 from .utils import seq_len_to_mask
+import warnings
 
 
 class LossBase(object):
@@ -226,7 +227,8 @@ class CrossEntropyLoss(LossBase):
     def get_loss(self, pred, target, seq_len=None):
         if pred.dim() > 2:
             if pred.size(1) != target.size(1):  # the dimensions may have been swapped
-                pred = pred.transpose(1, 2)
+                raise RuntimeError("It seems like that your prediction's shape is (batch_size, num_labels, max_len)."
+                                   " It should be (batch_size, max_len, num_labels).")
         pred = pred.reshape(-1, pred.size(-1))
         target = target.reshape(-1)
         if seq_len is not None:
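
Instead of silently transposing, CrossEntropyLoss.get_loss now raises when a 3-d prediction comes in as (batch_size, num_labels, max_len). A minimal sketch of the new behaviour, calling the loss object directly with illustrative shapes:

    import torch
    from fastNLP import CrossEntropyLoss

    batch_size, max_len, num_labels = 4, 7, 5
    loss_fn = CrossEntropyLoss()

    target = torch.randint(0, num_labels, (batch_size, max_len))
    good_pred = torch.randn(batch_size, max_len, num_labels)   # expected layout
    bad_pred = torch.randn(batch_size, num_labels, max_len)    # labels on dim 1

    loss = loss_fn.get_loss(good_pred, target)                 # works as before
    # loss_fn.get_loss(bad_pred, target)   # now raises RuntimeError; the caller
    #                                      # must transpose to (batch, max_len, num_labels)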


fastNLP/core/trainer.py (+1, -1)

@@ -942,7 +942,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL
     if dev_data is not None:
         tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
-                        batch_size=batch_size, verbose=-1)
+                        batch_size=batch_size, verbose=-1, use_tqdm=False)
         evaluate_results = tester.test()
         _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)



fastNLP/embeddings/bert_embedding.py (+15, -2)

@@ -11,7 +11,7 @@ from ..core.vocabulary import Vocabulary
 from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
 from .contextual_embedding import ContextualEmbedding
+import warnings
 
 class BertEmbedding(ContextualEmbedding):
     """
@@ -229,6 +229,10 @@ class _WordBertModel(nn.Module):
         # Step 1: collect the required word pieces, then build the new embed and word_piece_vocab and fill in the values
         word_piece_dict = {'[CLS]':1, '[SEP]':1}  # word pieces that are used, plus newly added ones
         found_count = 0
+        self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_type_ids need to be generated for the input
+        if "[CLS]" in vocab:
+            warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the beginning "
+                          "and end of the sentence automatically.")
         for word, index in vocab:
             if index == vocab.padding_idx:  # pad is a special symbol
                 word = '[PAD]'
@@ -316,9 +320,18 @@ class _WordBertModel(nn.Module):
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
         word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
+        if self._has_sep_in_vocab:  # token_type_ids should only be needed when [SEP] occurs in the vocab
+            with torch.no_grad():
+                sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
+                sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+                token_type_ids = sep_mask_cumsum.fmod(2)
+                if token_type_ids[0, 0].item():  # if the first position is odd, flip the result, because the first position must be 0
+                    token_type_ids = token_type_ids.eq(0).float()
+        else:
+            token_type_ids = torch.zeros_like(word_pieces)
         # 2. get the hidden states and run the corresponding pooling over word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
-        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
+        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
         # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
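
The token_type_ids block above derives BERT segment ids purely from the positions of [SEP] in the word-piece tensor: a reversed cumulative sum counts how many [SEP] tokens sit at or after each position, and its parity separates the two segments. A minimal standalone sketch of the same trick, with illustrative token ids that are not taken from the diff (101/102 are the usual BERT ids for [CLS]/[SEP], the rest are made up):

    import torch

    # word-piece ids for one sentence pair: [CLS] a b [SEP] c d [SEP] [PAD]
    sep_index = 102
    word_pieces = torch.tensor([[101, 7592, 2088, 102, 2769, 3221, 102, 0]])

    sep_mask = word_pieces.eq(sep_index).long()      # 1 at every [SEP] position
    # reversed cumsum: how many [SEP] tokens lie at or after each position
    sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
    token_type_ids = sep_mask_cumsum.fmod(2)         # parity -> segment id
    if token_type_ids[0, 0].item():                  # first position must be segment 0
        token_type_ids = token_type_ids.eq(0).long()
    print(token_type_ids)  # tensor([[0, 0, 0, 0, 1, 1, 1, 0]])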



reproduction/joint_cws_parse/models/CharParser.py (+4, -4)

@@ -224,11 +224,11 @@ class CharBiaffineParser(BiaffineParser):
 
         batch_size, seq_len, _ = arc_pred.shape
         flip_mask = (mask == 0)
-        _arc_pred = arc_pred.clone()
-        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
+        # _arc_pred = arc_pred.clone()
+        _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
 
-        arc_true[:, 0].fill_(-1)
-        label_true[:, 0].fill_(-1)
+        arc_true.data[:, 0].fill_(-1)
+        label_true.data[:, 0].fill_(-1)
 
         arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
         label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)
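
The change replaces the clone-then-masked_fill_ pattern with PyTorch's out-of-place Tensor.masked_fill, which allocates a new tensor and leaves arc_pred itself untouched, and routes the fill of the first gold column through .data so that in-place write stays outside autograd tracking. A minimal sketch of the out-of-place call, with purely illustrative shapes:

    import torch

    arc_pred = torch.randn(2, 3, 3)                      # batch_size x seq_len x seq_len
    flip_mask = torch.tensor([[False, False, True],
                              [False, True,  True]])     # True marks padded positions

    # masked_fill returns a fresh tensor; arc_pred is not modified
    _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
    assert not torch.isinf(arc_pred).any() and torch.isinf(_arc_pred).any()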


reproduction/joint_cws_parse/train.py (+49, -36)

@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import StepLR
 from fastNLP import Tester
 from fastNLP import GradientClipCallback, LRScheduler
 import os
+from fastNLP import cache_results
 
 def set_random_seed(random_seed=666):
     import random, numpy, torch
@@ -39,43 +40,42 @@ label_mlp_size = 100
 batch_size = 32
 update_every = 4
 n_epochs = 100
-data_folder = ''  # folder where the data lives; it should contain the three files train, dev and test
-vector_folder = ''  # pretrained vectors; the folder should contain three files: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
+data_name = 'new_ctb7'
 ####################################################
+data_folder = f'/remote-home/hyan01/exps/JointCwsPosParser/data/{data_name}/output'  # folder where the data lives; it should contain the three files train, dev and test
+vector_folder = '/remote-home/hyan01/exps/CWS/pretrain/vectors'  # pretrained vectors; the folder should contain three files: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
 
 set_random_seed(1234)
 device = 0
 
-# @cache_results('caches/{}.pkl'.format(data_name))
-# def get_data():
-data = CTBxJointLoader().process(data_folder)
-
-char_labels_vocab = data.vocabs['char_labels']
-
-pre_chars_vocab = data.vocabs['pre_chars']
-pre_bigrams_vocab = data.vocabs['pre_bigrams']
-pre_trigrams_vocab = data.vocabs['pre_trigrams']
-
-chars_vocab = data.vocabs['chars']
-bigrams_vocab = data.vocabs['bigrams']
-trigrams_vocab = data.vocabs['trigrams']
-
-pre_chars_embed = StaticEmbedding(pre_chars_vocab,
-                                  model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
-                                  init_method=uniform_init, normalize=False)
-pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std()
-pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
-                                    model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
-                                    init_method=uniform_init, normalize=False)
-pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std()
-pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
-                                     model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
-                                     init_method=uniform_init, normalize=False)
-pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std()
-
-# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
-
-# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
+@cache_results('caches/{}.pkl'.format(data_name))
+def get_data():
+    data = CTBxJointLoader().process(data_folder)
+    char_labels_vocab = data.vocabs['char_labels']
+
+    pre_chars_vocab = data.vocabs['pre_chars']
+    pre_bigrams_vocab = data.vocabs['pre_bigrams']
+    pre_trigrams_vocab = data.vocabs['pre_trigrams']
+
+    chars_vocab = data.vocabs['chars']
+    bigrams_vocab = data.vocabs['bigrams']
+    trigrams_vocab = data.vocabs['trigrams']
+    pre_chars_embed = StaticEmbedding(pre_chars_vocab,
+                                      model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
+                                      init_method=uniform_init, normalize=False)
+    pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std()
+    pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
+                                        model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
+                                        init_method=uniform_init, normalize=False)
+    pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std()
+    pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
+                                         model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
+                                         init_method=uniform_init, normalize=False)
+    pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std()
+
+    return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
+
+chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
 
 print(data)
 model = CharParser(char_vocab_size=len(chars_vocab),
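
The hunk above wraps the data loading and embedding construction in a get_data() function decorated with fastNLP's cache_results. As I understand that utility, it pickles the decorated function's return value to the given path on the first call and reloads it on later calls, so the slow CTBxJointLoader().process step runs only once per data_name. A hedged sketch with a made-up function name and cache path:

    import os
    from fastNLP import cache_results

    os.makedirs('caches', exist_ok=True)      # make sure the cache folder exists

    @cache_results('caches/demo.pkl')         # path is illustrative
    def load_something():
        print('expensive preprocessing, runs only on the first call')
        return {'vocab_size': 12345}

    d1 = load_something()   # runs the body and writes caches/demo.pkl
    d2 = load_something()   # later calls read the pickle instead of recomputing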
@@ -104,11 +104,24 @@ optimizer = optim.Adam([param for param in model.parameters() if param.requires_
 
 sampler = BucketSampler(seq_len_field_name='seq_lens')
 callbacks = []
 
+from fastNLP.core.callback import Callback
+from torch.optim.lr_scheduler import LambdaLR
+class SchedulerCallback(Callback):
+    def __init__(self, scheduler):
+        super().__init__()
+        self.scheduler = scheduler
+
+    def on_backward_end(self):
+        if self.step % self.update_every==0:
+            self.scheduler.step()
+
-# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
+scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
-scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
+# scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
 # optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
+scheduler_callback = SchedulerCallback(scheduler)
 # callbacks.append(optim_callback)
-scheduler_callback = LRScheduler(scheduler)
+# scheduler_callback = LRScheduler(scheduler)
 callbacks.append(scheduler_callback)
 callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))

@@ -119,6 +132,6 @@ callbacks.append(dev_callback)
 
 trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3,
                   validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
-                  check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
+                  check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True,
                   device=device, callbacks=callbacks, update_every=update_every)
 trainer.train()
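
With the SchedulerCallback added above, scheduler.step() runs once every update_every backward passes, i.e. once per parameter update, so the LambdaLR schedule multiplies the learning rate by 0.75 after every 5000 updates. A small sketch of that decay with an illustrative base learning rate (the real value comes from the Adam optimizer in train.py):

    import torch
    from torch.optim.lr_scheduler import LambdaLR

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=2e-3)          # lr is illustrative
    scheduler = LambdaLR(optimizer, lr_lambda=lambda step: (0.75) ** (step // 5000))

    for _ in range(12000):                                  # 12000 parameter updates
        optimizer.step()
        scheduler.step()
    print(optimizer.param_groups[0]['lr'])                  # 2e-3 * 0.75**2 = 1.125e-03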
