@@ -0,0 +1,284 @@
from fastNLP.io.base_loader import DataSetLoader, DataInfo
from fastNLP.io.dataset_loader import ConllLoader
import numpy as np
from itertools import chain
from fastNLP import DataSet, Vocabulary
from functools import partial
import os
from typing import Union, Dict
from reproduction.utils import check_dataloader_paths


class CTBxJointLoader(DataSetLoader):
    """
    The data folder is expected to contain the following files:
        - train.conllx
        - dev.conllx
        - test.conllx
    Each file has content like the following (different sentences are separated by blank lines):

        1 费孝通 _ NR NR _ 3 nsubjpass _ _
        2 被 _ SB SB _ 3 pass _ _
        3 授予 _ VV VV _ 0 root _ _
        4 麦格赛赛 _ NR NR _ 5 nn _ _
        5 奖 _ NN NN _ 3 dobj _ _

        1 新华社 _ NR NR _ 7 dep _ _
        2 马尼拉 _ NR NR _ 7 dep _ _
        3 8月 _ NT NT _ 7 dep _ _
        4 31日 _ NT NT _ 7 dep _ _
        ...
    """
    def __init__(self):
        self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7])

    def load(self, path: str):
        """
        Read the file at the given path into a DataSet with the following fields:
            words: list[str]
            pos_tags: list[str]
            heads: list[int]
            labels: list[str]

        :param path: path to a single conllx file
        :return: DataSet
        """
        dataset = self._loader.load(path)
        dataset.heads.int()
        return dataset
    def process(self, paths):
        """
        :param paths: a folder path or a dict mapping split names to file paths
        :return: DataInfo whose DataSets contain the following fields:
            chars:
            bigrams:
            trigrams:
            pre_chars:
            pre_bigrams:
            pre_trigrams:
            seg_targets:
            seg_masks:
            seq_lens:
            char_labels:
            char_heads:
            pun_masks:
            gold_word_pairs:
            gold_label_word_pairs:
        """
        paths = check_dataloader_paths(paths)
        data = DataInfo()

        for name, path in paths.items():
            dataset = self.load(path)
            data.datasets[name] = dataset

        char_labels_vocab = Vocabulary(padding=None, unknown=None)

        def process(dataset, char_label_vocab):
            dataset.apply(add_word_lst, new_field_name='word_lst')
            dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars')
            dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams')
            dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams')
            dataset.apply(add_char_heads, new_field_name='char_heads')
            dataset.apply(add_char_labels, new_field_name='char_labels')
            dataset.apply(add_segs, new_field_name='seg_targets')
            dataset.apply(add_mask, new_field_name='seg_masks')
            dataset.add_seq_len('chars', new_field_name='seq_lens')
            dataset.apply(add_pun_masks, new_field_name='pun_masks')
            if len(char_label_vocab.word_count) == 0:
                char_label_vocab.from_dataset(dataset, field_name='char_labels')
            char_label_vocab.index_dataset(dataset, field_name='char_labels')
            new_dataset = add_root(dataset)
            new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True)
            new_dataset.apply(partial(add_label_word_pairs, label_vocab=char_label_vocab),
                              new_field_name='gold_label_word_pairs', ignore_type=True)

            new_dataset.set_pad_val('char_labels', -1)
            new_dataset.set_pad_val('char_heads', -1)

            return new_dataset

        for name in list(paths.keys()):
            dataset = data.datasets[name]
            dataset = process(dataset, char_labels_vocab)
            data.datasets[name] = dataset

        data.vocabs['char_labels'] = char_labels_vocab

        char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars')
        bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams')
        trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams')

        for name in ['chars', 'bigrams', 'trigrams']:
            vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values()))
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name)
            data.vocabs['pre_{}'.format(name)] = vocab

        for name, vocab in zip(['chars', 'bigrams', 'trigrams'],
                               [char_vocab, bigram_vocab, trigram_vocab]):
            vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name)
            data.vocabs[name] = vocab

        for name, dataset in data.datasets.items():
            dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars',
                              'pre_bigrams', 'pre_trigrams')
            dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels',
                               'char_heads', 'pun_masks', 'gold_label_word_pairs')

        return data


def add_label_word_pairs(instance, label_vocab):
    # List[((head_start, head_end), label, (dep_start, dep_end)), ...], spans are half-open [start, end)
    word_end_indexes = np.array(list(map(len, instance['word_lst'])))
    word_end_indexes = np.cumsum(word_end_indexes).tolist()
    word_end_indexes.insert(0, 0)
    word_pairs = []
    labels = instance['labels']
    pos_tags = instance['pos_tags']
    for idx, head in enumerate(instance['heads']):
        if pos_tags[idx] == 'PU':  # skip punctuation
            continue
        label = label_vocab.to_index(labels[idx])
        if head == 0:
            word_pairs.append(('root', label, (word_end_indexes[idx], word_end_indexes[idx + 1])))
        else:
            word_pairs.append(((word_end_indexes[head - 1], word_end_indexes[head]), label,
                               (word_end_indexes[idx], word_end_indexes[idx + 1])))
    return word_pairs


def add_word_pairs(instance):
    # List[((head_start, head_end), (dep_start, dep_end)), ...], spans are half-open [start, end)
    word_end_indexes = np.array(list(map(len, instance['word_lst'])))
    word_end_indexes = np.cumsum(word_end_indexes).tolist()
    word_end_indexes.insert(0, 0)
    word_pairs = []
    pos_tags = instance['pos_tags']
    for idx, head in enumerate(instance['heads']):
        if pos_tags[idx] == 'PU':  # skip punctuation
            continue
        if head == 0:
            word_pairs.append(('root', (word_end_indexes[idx], word_end_indexes[idx + 1])))
        else:
            word_pairs.append(((word_end_indexes[head - 1], word_end_indexes[head]),
                               (word_end_indexes[idx], word_end_indexes[idx + 1])))
    return word_pairs


def add_root(dataset):
    new_dataset = DataSet()
    for sample in dataset:
        chars = ['char_root'] + sample['chars']
        bigrams = ['bigram_root'] + sample['bigrams']
        trigrams = ['trigram_root'] + sample['trigrams']
        seq_lens = sample['seq_lens'] + 1
        char_labels = [0] + sample['char_labels']
        char_heads = [0] + sample['char_heads']
        sample['chars'] = chars
        sample['bigrams'] = bigrams
        sample['trigrams'] = trigrams
        sample['seq_lens'] = seq_lens
        sample['char_labels'] = char_labels
        sample['char_heads'] = char_heads
        new_dataset.append(sample)
    return new_dataset


def add_pun_masks(instance):
    tags = instance['pos_tags']
    pun_masks = []
    for word, tag in zip(instance['words'], tags):
        if tag == 'PU':
            pun_masks.extend([1] * len(word))
        else:
            pun_masks.extend([0] * len(word))
    return pun_masks


def add_word_lst(instance):
    words = instance['words']
    word_lst = [list(word) for word in words]
    return word_lst


def add_bigram(instance):
    chars = instance['chars']
    length = len(chars)
    chars = chars + ['<eos>']
    bigrams = []
    for i in range(length):
        bigrams.append(''.join(chars[i:i + 2]))
    return bigrams


def add_trigram(instance):
    chars = instance['chars']
    length = len(chars)
    chars = chars + ['<eos>'] * 2
    trigrams = []
    for i in range(length):
        trigrams.append(''.join(chars[i:i + 3]))
    return trigrams


def add_char_heads(instance):
    words = instance['word_lst']
    heads = instance['heads']
    char_heads = []
    char_index = 1  # start from 1 because index 0 is the root node
    head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0]  # head==0 means root; 0 - 1 = -1 then selects this trailing 0
    for word, head in zip(words, heads):
        char_head = []
        if len(word) > 1:
            char_head.append(char_index + 1)
            char_index += 1
            for _ in range(len(word) - 2):
                char_index += 1
                char_head.append(char_index)
        char_index += 1  # move past the last character of this word
        char_head.append(head_end_indexes[head - 1])
        char_heads.extend(char_head)
    return char_heads


def add_char_labels(instance):
    """
    Assign a label to every character according to word_lst.
    For example, for "复旦大学 位于" the segmentation tags are "B M M E B E", so the intra-word
    dependencies are "复(dep)->旦(head)", "旦(dep)->大(head)", ..., all labelled 'APP',
    while the last character 学 carries the dependency label of the whole word 复旦大学.

    :param instance:
    :return:
    """
    words = instance['word_lst']
    labels = instance['labels']
    char_labels = []
    for word, label in zip(words, labels):
        for _ in range(len(word) - 1):
            char_labels.append('APP')
        char_labels.append(label)
    return char_labels


# add seg_targets
def add_segs(instance):
    words = instance['word_lst']
    segs = [0] * len(instance['chars'])
    index = 0
    for word in words:
        index = index + len(word) - 1
        segs[index] = len(word) - 1
        index = index + 1
    return segs


# add target_masks
def add_mask(instance):
    words = instance['word_lst']
    mask = []
    for word in words:
        mask.extend([0] * (len(word) - 1))
        mask.append(1)
    return mask
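
For orientation before the model code, here is a minimal usage sketch of the loader defined above. The folder path is hypothetical and must contain the train/dev/test conllx files described in the class docstring; the module path is the one used by train.py further below.

```
from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader

loader = CTBxJointLoader()
data = loader.process('/path/to/ctb_folder')   # hypothetical path to the data folder
print(data.datasets['train'])                  # fields: chars, bigrams, trigrams, char_heads, char_labels, seg_targets, ...
print(sorted(data.vocabs.keys()))              # char_labels, chars, bigrams, trigrams, pre_chars, pre_bigrams, pre_trigrams
```
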
@@ -0,0 +1,311 @@
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP import seq_len_to_mask
from fastNLP.modules import Embedding


def drop_input_independent(word_embeddings, dropout_emb):
    batch_size, seq_length, _ = word_embeddings.size()
    word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb)
    word_masks = torch.bernoulli(word_masks)
    word_masks = word_masks.unsqueeze(dim=2)
    word_embeddings = word_embeddings * word_masks
    return word_embeddings


class CharBiaffineParser(BiaffineParser):
    def __init__(self, char_vocab_size,
                 emb_dim,
                 bigram_vocab_size,
                 trigram_vocab_size,
                 num_label,
                 rnn_layers=3,
                 rnn_hidden_size=800,  # hidden size per direction
                 arc_mlp_size=500,
                 label_mlp_size=100,
                 dropout=0.3,
                 encoder='lstm',
                 use_greedy_infer=False,
                 app_index=0,
                 pre_chars_embed=None,
                 pre_bigrams_embed=None,
                 pre_trigrams_embed=None):
        # skip BiaffineParser.__init__ and initialise nn.Module directly
        super(BiaffineParser, self).__init__()
        rnn_out_size = 2 * rnn_hidden_size
        self.char_embed = Embedding((char_vocab_size, emb_dim))
        self.bigram_embed = Embedding((bigram_vocab_size, emb_dim))
        self.trigram_embed = Embedding((trigram_vocab_size, emb_dim))
        if pre_chars_embed:
            self.pre_char_embed = Embedding(pre_chars_embed)
            self.pre_char_embed.requires_grad = False
        if pre_bigrams_embed:
            self.pre_bigram_embed = Embedding(pre_bigrams_embed)
            self.pre_bigram_embed.requires_grad = False
        if pre_trigrams_embed:
            self.pre_trigram_embed = Embedding(pre_trigrams_embed)
            self.pre_trigram_embed.requires_grad = False
        self.timestep_drop = TimestepDropout(dropout)
        self.encoder_name = encoder

        if encoder == 'var-lstm':
            self.encoder = VarLSTM(input_size=emb_dim * 3,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   input_dropout=dropout,
                                   hidden_dropout=dropout,
                                   bidirectional=True)
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(input_size=emb_dim * 3,
                                   hidden_size=rnn_hidden_size,
                                   num_layers=rnn_layers,
                                   bias=True,
                                   batch_first=True,
                                   dropout=dropout,
                                   bidirectional=True)
        else:
            raise ValueError('unsupported encoder type: {}'.format(encoder))

        self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
                                 nn.LeakyReLU(0.1),
                                 TimestepDropout(p=dropout),)
        self.arc_mlp_size = arc_mlp_size
        self.label_mlp_size = label_mlp_size
        self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
        self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
        self.use_greedy_infer = use_greedy_infer
        self.reset_parameters()
        self.dropout = dropout

        self.app_index = app_index
        self.num_label = num_label
        if self.app_index != 0:
            raise ValueError("app_index must be 0 for now")

    def reset_parameters(self):
        for name, m in self.named_modules():
            if 'embed' in name:
                pass
            elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'):
                pass
            else:
                for p in m.parameters():
                    if len(p.size()) > 1:
                        nn.init.xavier_normal_(p, gain=0.1)
                    else:
                        nn.init.uniform_(p, -0.1, 0.1)

    def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None,
                pre_trigrams=None):
        """
        max_len includes the root position.

        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len
        :param trigrams: batch_size x max_len
        :param seq_lens: batch_size
        :param gold_heads: batch_size x max_len
        :param pre_chars: batch_size x max_len
        :param pre_bigrams: batch_size x max_len
        :param pre_trigrams: batch_size x max_len
        :return dict: parsing results
            arc_pred: [batch_size, seq_len, seq_len]
            label_pred: [batch_size, seq_len, num_label]
            mask: [batch_size, seq_len]
            head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
        """
        # prepare embeddings
        batch_size, seq_len = chars.shape
        # print('forward {} {}'.format(batch_size, seq_len))

        # get sequence mask
        mask = seq_len_to_mask(seq_lens).long()

        chars = self.char_embed(chars)  # [N,L] -> [N,L,C_0]
        bigrams = self.bigram_embed(bigrams)  # [N,L] -> [N,L,C_1]
        trigrams = self.trigram_embed(trigrams)

        if pre_chars is not None:
            pre_chars = self.pre_char_embed(pre_chars)
            # pre_chars = self.pre_char_fc(pre_chars)
            chars = pre_chars + chars
        if pre_bigrams is not None:
            pre_bigrams = self.pre_bigram_embed(pre_bigrams)
            # pre_bigrams = self.pre_bigram_fc(pre_bigrams)
            bigrams = bigrams + pre_bigrams
        if pre_trigrams is not None:
            pre_trigrams = self.pre_trigram_embed(pre_trigrams)
            # pre_trigrams = self.pre_trigram_fc(pre_trigrams)
            trigrams = trigrams + pre_trigrams

        x = torch.cat([chars, bigrams, trigrams], dim=2)  # -> [N,L,C]

        # encoder, extract features
        if self.training:
            x = drop_input_independent(x, self.dropout)
        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
        x = x[sort_idx]
        x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
        feat, _ = self.encoder(x)  # -> [N,L,C]
        feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
        _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
        feat = feat[unsort_idx]
        feat = self.timestep_drop(feat)

        # for arc biaffine
        # mlp, reduce dim
        feat = self.mlp(feat)
        arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
        arc_dep, arc_head = feat[:, :, :arc_sz], feat[:, :, arc_sz:2 * arc_sz]
        label_dep, label_head = feat[:, :, 2 * arc_sz:2 * arc_sz + label_sz], feat[:, :, 2 * arc_sz + label_sz:]

        # biaffine arc classifier
        arc_pred = self.arc_predictor(arc_head, arc_dep)  # [N, L, L]

        # use gold or predicted arc to predict label
        if gold_heads is None or not self.training:
            # use greedy decoding in training
            if self.training or self.use_greedy_infer:
                heads = self.greedy_decoder(arc_pred, mask)
            else:
                heads = self.mst_decoder(arc_pred, mask)
            head_pred = heads
        else:
            assert self.training  # must be training mode
            if gold_heads is None:
                heads = self.greedy_decoder(arc_pred, mask)
                head_pred = heads
            else:
                head_pred = None
                heads = gold_heads
        # heads: batch_size x max_len

        batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1)
        label_head = label_head[batch_range, heads].contiguous()
        label_pred = self.label_predictor(label_head, label_dep)  # [N, max_len, num_label]

        # restrict the APP label: it may only be predicted when the head is the next character
        arange_index = torch.arange(1, seq_len + 1, dtype=torch.long, device=chars.device).unsqueeze(0) \
            .repeat(batch_size, 1)  # batch_size x max_len
        app_masks = heads.ne(arange_index)  # batch_size x max_len, positions with value 1 may not predict APP
        app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
        app_masks[:, :, 1:] = 0
        label_pred = label_pred.masked_fill(app_masks, -np.inf)

        res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
        if head_pred is not None:
            res_dict['head_pred'] = head_pred
        return res_dict

    @staticmethod
    def loss(arc_pred, label_pred, arc_true, label_true, mask):
        """
        Compute loss.

        :param arc_pred: [batch_size, seq_len, seq_len]
        :param label_pred: [batch_size, seq_len, n_tags]
        :param arc_true: [batch_size, seq_len]
        :param label_true: [batch_size, seq_len]
        :param mask: [batch_size, seq_len]
        :return: loss value
        """
        batch_size, seq_len, _ = arc_pred.shape
        flip_mask = (mask == 0)
        _arc_pred = arc_pred.clone()
        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))

        arc_true[:, 0].fill_(-1)
        label_true[:, 0].fill_(-1)

        arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
        label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)

        return arc_nll + label_nll

    def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams):
        """
        max_len includes the root position.

        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len
        :param trigrams: batch_size x max_len
        :param seq_lens: batch_size
        :param pre_chars: batch_size x max_len
        :param pre_bigrams: batch_size x max_len
        :param pre_trigrams: batch_size x max_len
        :return:
        """
        res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams,
                   pre_trigrams=pre_trigrams, gold_heads=None)
        output = {}
        output['arc_pred'] = res.pop('head_pred')
        _, label_pred = res.pop('label_pred').max(2)
        output['label_pred'] = label_pred
        return output


class CharParser(nn.Module):
    def __init__(self, char_vocab_size,
                 emb_dim,
                 bigram_vocab_size,
                 trigram_vocab_size,
                 num_label,
                 rnn_layers=3,
                 rnn_hidden_size=400,  # hidden size per direction
                 arc_mlp_size=500,
                 label_mlp_size=100,
                 dropout=0.3,
                 encoder='var-lstm',
                 use_greedy_infer=False,
                 app_index=0,
                 pre_chars_embed=None,
                 pre_bigrams_embed=None,
                 pre_trigrams_embed=None):
        super().__init__()

        self.parser = CharBiaffineParser(char_vocab_size,
                                         emb_dim,
                                         bigram_vocab_size,
                                         trigram_vocab_size,
                                         num_label,
                                         rnn_layers,
                                         rnn_hidden_size,
                                         arc_mlp_size,
                                         label_mlp_size,
                                         dropout,
                                         encoder,
                                         use_greedy_infer,
                                         app_index,
                                         pre_chars_embed=pre_chars_embed,
                                         pre_bigrams_embed=pre_bigrams_embed,
                                         pre_trigrams_embed=pre_trigrams_embed)

    def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None,
                pre_trigrams=None):
        res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars,
                               pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
        arc_pred = res_dict['arc_pred']
        label_pred = res_dict['label_pred']
        masks = res_dict['mask']
        loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks)
        return {'loss': loss}

    def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None):
        res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars,
                          pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
        output = {}
        output['head_preds'] = res.pop('head_pred')
        _, label_pred = res.pop('label_pred').max(2)
        output['label_preds'] = label_pred
        return output
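
As a quick sanity check of the interface, here is a hypothetical smoke test of CharParser with made-up vocabulary sizes and random inputs. The shapes follow the forward/predict docstrings above (position 0 is reserved for the root), and the sketch assumes the fastNLP version this repo targets.

```
import torch
from reproduction.joint_cws_parse.models.CharParser import CharParser

# made-up vocabulary sizes and a plain LSTM encoder, purely for a shape check
model = CharParser(char_vocab_size=100, emb_dim=50, bigram_vocab_size=200,
                   trigram_vocab_size=300, num_label=10, rnn_hidden_size=100,
                   encoder='lstm')
batch_size, max_len = 2, 7                              # max_len includes the root position
chars = torch.randint(1, 100, (batch_size, max_len))
bigrams = torch.randint(1, 200, (batch_size, max_len))
trigrams = torch.randint(1, 300, (batch_size, max_len))
seq_lens = torch.tensor([7, 5])

model.eval()
with torch.no_grad():
    out = model.predict(chars, bigrams, trigrams, seq_lens)
print(out['head_preds'].shape, out['label_preds'].shape)  # both: batch_size x max_len
```
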
@@ -0,0 +1,65 @@
from fastNLP.core.callback import Callback
import torch
from torch import nn


class OptimizerCallback(Callback):
    def __init__(self, optimizer, scheduler, update_every=4):
        super().__init__()

        self._optimizer = optimizer
        self.scheduler = scheduler
        self._update_every = update_every

    def on_backward_end(self):
        if self.step % self._update_every == 0:
            # nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5)
            # self._optimizer.step()
            self.scheduler.step()
            # self.model.zero_grad()


class DevCallback(Callback):
    def __init__(self, tester, metric_key='u_f1'):
        super().__init__()
        self.tester = tester
        setattr(tester, 'verbose', 0)

        self.metric_key = metric_key

        self.record_best = False
        self.best_eval_value = 0
        self.best_eval_res = None

        self.best_dev_res = None  # dev performance at the epoch where test was best

    def on_valid_begin(self):
        eval_res = self.tester.test()
        metric_name = self.tester.metrics[0].__class__.__name__
        metric_value = eval_res[metric_name][self.metric_key]
        if metric_value > self.best_eval_value:
            self.best_eval_value = metric_value
            self.best_epoch = self.trainer.epoch
            self.record_best = True
            self.best_eval_res = eval_res
        self.test_eval_res = eval_res
        eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \
                   self.tester._format_eval_results(eval_res)
        self.pbar.write(eval_str)

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        if self.record_best:
            self.best_dev_res = eval_result
            self.record_best = False
        if is_better_eval:
            self.best_dev_res_on_dev = eval_result
            self.best_test_res_on_dev = self.test_eval_res
            self.dev_epoch = self.epoch

    def on_train_end(self):
        print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch,
              self.tester._format_eval_results(self.best_eval_res),
              self.tester._format_eval_results(self.best_dev_res)))
        print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch,
              self.tester._format_eval_results(self.best_test_res_on_dev),
              self.tester._format_eval_results(self.best_dev_res_on_dev)))
@@ -0,0 +1,184 @@
from fastNLP.core.metrics import MetricBase
from fastNLP.core.utils import seq_len_to_mask
import torch


class SegAppCharParseF1Metric(MetricBase):
    def __init__(self, app_index):
        super().__init__()
        self.app_index = app_index

        self.parse_head_tp = 0
        self.parse_label_tp = 0
        self.rec_tol = 0
        self.pre_tol = 0

    def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens,
                 pun_masks):
        """
        max_len is the number of characters, excluding the root.

        :param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size
        :param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size
        :param head_preds: batch_size x max_len
        :param label_preds: batch_size x max_len
        :param seq_lens:
        :param pun_masks: batch_size x max_len
        :return:
        """
        # strip the root position
        head_preds = head_preds[:, 1:].tolist()
        label_preds = label_preds[:, 1:].tolist()
        seq_lens = (seq_lens - 1).tolist()

        # first decode the words, heads, labels and the character span of each word
        for b in range(len(head_preds)):
            seq_len = seq_lens[b]
            head_pred = head_preds[b][:seq_len]
            label_pred = label_preds[b][:seq_len]

            words = []  # stores [word_start, word_end), relative to the sentence start, root excluded
            heads = []
            labels = []
            ranges = []  # which word each character belongs to; length is seq_len+1
            word_idx = 0
            word_start_idx = 0
            for idx, (label, head) in enumerate(zip(label_pred, head_pred)):
                ranges.append(word_idx)
                if label == self.app_index:
                    pass
                else:
                    labels.append(label)
                    heads.append(head)
                    words.append((word_start_idx, idx + 1))
                    word_start_idx = idx + 1
                    word_idx += 1

            head_dep_tuple = []  # the head span comes first
            head_label_dep_tuple = []
            for idx, head in enumerate(heads):
                span = words[idx]
                if span[0] == span[1] - 1 and pun_masks[b, span[0]]:
                    continue  # exclude punctuation
                if head == 0:
                    head_dep_tuple.append(('root', words[idx]))
                    head_label_dep_tuple.append(('root', labels[idx], words[idx]))
                else:
                    head_word_idx = ranges[head - 1]
                    head_word_span = words[head_word_idx]
                    head_dep_tuple.append((head_word_span, words[idx]))
                    head_label_dep_tuple.append((head_word_span, labels[idx], words[idx]))

            gold_head_dep_tuple = set(gold_word_pairs[b])
            gold_head_label_dep_tuple = set(gold_label_word_pairs[b])

            for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple):
                if head_dep in gold_head_dep_tuple:
                    self.parse_head_tp += 1
                if head_label_dep in gold_head_label_dep_tuple:
                    self.parse_label_tp += 1
            self.pre_tol += len(head_dep_tuple)
            self.rec_tol += len(gold_head_dep_tuple)

    def get_metric(self, reset=True):
        u_p = self.parse_head_tp / self.pre_tol
        u_r = self.parse_head_tp / self.rec_tol
        u_f = 2 * u_p * u_r / (1e-6 + u_p + u_r)
        l_p = self.parse_label_tp / self.pre_tol
        l_r = self.parse_label_tp / self.rec_tol
        l_f = 2 * l_p * l_r / (1e-6 + l_p + l_r)

        if reset:
            self.parse_head_tp = 0
            self.parse_label_tp = 0
            self.rec_tol = 0
            self.pre_tol = 0

        return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas': round(u_r, 4),
                'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)}


class CWSMetric(MetricBase):
    def __init__(self, app_index):
        super().__init__()
        self.app_index = app_index
        self.pre = 0
        self.rec = 0
        self.tp = 0

    def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens):
        """
        :param seg_targets: batch_size x max_len, at the last character of each word the value is that word's length - 1
        :param seg_masks: batch_size x max_len, 1 only at positions where a word ends
        :param label_preds: batch_size x max_len
        :param seq_lens: batch_size
        :return:
        """
        pred_masks = torch.zeros_like(seg_masks)
        pred_segs = torch.zeros_like(seg_targets)

        seq_lens = (seq_lens - 1).tolist()
        for idx, label_pred in enumerate(label_preds[:, 1:].tolist()):
            seq_len = seq_lens[idx]
            label_pred = label_pred[:seq_len]
            word_len = 0
            for l_i, label in enumerate(label_pred):
                if label == self.app_index and l_i != len(label_pred) - 1:
                    word_len += 1
                else:
                    pred_segs[idx, l_i] = word_len  # predicted length of this word
                    pred_masks[idx, l_i] = 1
                    word_len = 0

        right_mask = seg_targets.eq(pred_segs)  # the predicted word length matches the target
        self.rec += seg_masks.sum().item()
        self.pre += pred_masks.sum().item()
        # and both prediction and target mark a word boundary at the same position
        self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item()

    def get_metric(self, reset=True):
        res = {}
        res['rec'] = round(self.tp / (self.rec + 1e-6), 4)
        res['pre'] = round(self.tp / (self.pre + 1e-6), 4)
        res['f1'] = round(2 * res['rec'] * res['pre'] / (res['pre'] + res['rec'] + 1e-6), 4)

        if reset:
            self.pre = 0
            self.rec = 0
            self.tp = 0

        return res


class ParserMetric(MetricBase):
    def __init__(self):
        super().__init__()
        self.num_arc = 0
        self.num_label = 0
        self.num_sample = 0

    def get_metric(self, reset=True):
        res = {'UAS': round(self.num_arc * 1.0 / self.num_sample, 4),
               'LAS': round(self.num_label * 1.0 / self.num_sample, 4)}
        if reset:
            self.num_sample = self.num_label = self.num_arc = 0
        return res

    def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None):
        """Evaluate the performance of prediction.
        """
        if seq_lens is None:
            seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.uint8)
        else:
            seq_mask = seq_len_to_mask(seq_lens.long(), float=False)
        # mask out <root> tag
        seq_mask[:, 0] = 0
        head_pred_correct = (head_preds == heads).__and__(seq_mask)
        label_pred_correct = (label_preds == labels).__and__(head_pred_correct)
        self.num_arc += head_pred_correct.float().sum().item()
        self.num_label += label_pred_correct.float().sum().item()
        self.num_sample += seq_mask.sum().item()
@@ -0,0 +1,16 @@
Code for the paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697)

### Data preparation
1. The data should be in conll format; columns 1, 3, 6 and 7 correspond to 'words', 'pos_tags', 'heads' and 'labels'.
2. Put the train, dev and test files in the same folder and set the data_folder variable in train.py to that folder's path (see the layout sketch below).
3. Download the pretrained vectors from [百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA) (extraction code: ua53), put them in one folder, and set the vector_folder variable in train.py accordingly.
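
For reference, the loader and train.py assume a layout like the following (the file names are the ones referenced in the CTBxJointLoader docstring and in the vector_folder comment of train.py):

```
data_folder/
    train.conllx
    dev.conllx
    test.conllx
vector_folder/
    1grams_t3_m50_corpus.txt
    2grams_t3_m50_corpus.txt
    3grams_t3_m50_corpus.txt
```
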
### Running the code
```
python train.py
```

### Notes
The default hyper-parameters above should be enough to reproduce (or slightly exceed) the results reported in the paper on CTB5; on CTB7 the defaults come out roughly 0.1% lower, and the learning rate scheduler needs to be tuned.
@@ -0,0 +1,124 @@
import sys
sys.path.append('../..')

from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader
from fastNLP.modules.encoder.embedding import StaticEmbedding
from torch import nn
from functools import partial
from reproduction.joint_cws_parse.models.CharParser import CharParser
from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric
from fastNLP import cache_results, BucketSampler, Trainer
from torch import optim
from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback
from torch.optim.lr_scheduler import LambdaLR, StepLR
from fastNLP import Tester
from fastNLP import GradientClipCallback, LRScheduler
import os


def set_random_seed(random_seed=666):
    import random, numpy, torch
    random.seed(random_seed)
    numpy.random.seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.random.manual_seed(random_seed)


uniform_init = partial(nn.init.normal_, std=0.02)

###################################################
# hyper-parameters that may need tuning
lr = 0.002  # 0.01~0.001
dropout = 0.33  # 0.3~0.6
weight_decay = 0  # 1e-5, 1e-6, 0
arc_mlp_size = 500  # 200, 300
rnn_hidden_size = 400  # 200, 300, 400
rnn_layers = 3  # 2, 3
encoder = 'var-lstm'  # var-lstm, lstm
emb_size = 100  # 64, 100
label_mlp_size = 100

batch_size = 32
update_every = 4
n_epochs = 100
data_folder = ''  # folder containing the data; it must contain the train, dev and test files
vector_folder = ''  # folder with the pretrained vectors: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
####################################################

set_random_seed(1234)
device = 0

# @cache_results('caches/{}.pkl'.format(data_name))
# def get_data():
data = CTBxJointLoader().process(data_folder)

char_labels_vocab = data.vocabs['char_labels']

pre_chars_vocab = data.vocabs['pre_chars']
pre_bigrams_vocab = data.vocabs['pre_bigrams']
pre_trigrams_vocab = data.vocabs['pre_trigrams']

chars_vocab = data.vocabs['chars']
bigrams_vocab = data.vocabs['bigrams']
trigrams_vocab = data.vocabs['trigrams']

pre_chars_embed = StaticEmbedding(pre_chars_vocab,
                                  model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
                                  init_method=uniform_init, normalize=False)
pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std()
pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
                                    model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
                                    init_method=uniform_init, normalize=False)
pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std()
pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
                                     model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
                                     init_method=uniform_init, normalize=False)
pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std()

# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()

print(data)

model = CharParser(char_vocab_size=len(chars_vocab),
                   emb_dim=emb_size,
                   bigram_vocab_size=len(bigrams_vocab),
                   trigram_vocab_size=len(trigrams_vocab),
                   num_label=len(char_labels_vocab),
                   rnn_layers=rnn_layers,
                   rnn_hidden_size=rnn_hidden_size,
                   arc_mlp_size=arc_mlp_size,
                   label_mlp_size=label_mlp_size,
                   dropout=dropout,
                   encoder=encoder,
                   use_greedy_infer=False,
                   app_index=char_labels_vocab['APP'],
                   pre_chars_embed=pre_chars_embed,
                   pre_bigrams_embed=pre_bigrams_embed,
                   pre_trigrams_embed=pre_trigrams_embed)

metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP'])
metric2 = CWSMetric(char_labels_vocab['APP'])
metrics = [metric1, metric2]

optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr,
                       weight_decay=weight_decay, betas=[0.9, 0.9])

sampler = BucketSampler(seq_len_field_name='seq_lens')
callbacks = []
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step: (0.75) ** (step // 5000))
scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
# optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
# callbacks.append(optim_callback)
scheduler_callback = LRScheduler(scheduler)
callbacks.append(scheduler_callback)
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))

tester = Tester(data=data.datasets['test'], model=model, metrics=metrics,
                batch_size=64, device=device, verbose=0)
dev_callback = DevCallback(tester)
callbacks.append(dev_callback)

trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3,
                  validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
                  check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
                  device=device, callbacks=callbacks, update_every=update_every)
trainer.train()