# reproduction/joint_cws_parse/data/data_loader.py
from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||
from fastNLP.io.dataset_loader import ConllLoader | |||
import numpy as np | |||
from itertools import chain | |||
from fastNLP import DataSet, Vocabulary | |||
from functools import partial | |||
import os | |||
from typing import Union, Dict | |||
from reproduction.utils import check_dataloader_paths | |||
class CTBxJointLoader(DataSetLoader): | |||
""" | |||
文件夹下应该具有以下的文件结构 | |||
-train.conllx | |||
-dev.conllx | |||
-test.conllx | |||
每个文件中的内容如下(空格隔开不同的句子, 共有) | |||
1 费孝通 _ NR NR _ 3 nsubjpass _ _ | |||
2 被 _ SB SB _ 3 pass _ _ | |||
3 授予 _ VV VV _ 0 root _ _ | |||
4 麦格赛赛 _ NR NR _ 5 nn _ _ | |||
5 奖 _ NN NN _ 3 dobj _ _ | |||
1 新华社 _ NR NR _ 7 dep _ _ | |||
2 马尼拉 _ NR NR _ 7 dep _ _ | |||
3 8月 _ NT NT _ 7 dep _ _ | |||
4 31日 _ NT NT _ 7 dep _ _ | |||
... | |||
""" | |||
def __init__(self): | |||
self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7]) | |||
def load(self, path:str): | |||
""" | |||
给定一个文件路径,将数据读取为DataSet格式。DataSet中包含以下的内容 | |||
words: list[str] | |||
pos_tags: list[str] | |||
heads: list[int] | |||
labels: list[str] | |||
:param path: | |||
:return: | |||
""" | |||
dataset = self._loader.load(path) | |||
dataset.heads.int() | |||
return dataset | |||
def process(self, paths): | |||
""" | |||
:param paths: | |||
:return: | |||
Dataset包含以下的field | |||
chars: | |||
bigrams: | |||
trigrams: | |||
pre_chars: | |||
pre_bigrams: | |||
pre_trigrams: | |||
seg_targets: | |||
seg_masks: | |||
seq_lens: | |||
char_labels: | |||
char_heads: | |||
gold_word_pairs: | |||
seg_targets: | |||
seg_masks: | |||
char_labels: | |||
char_heads: | |||
pun_masks: | |||
gold_label_word_pairs: | |||
""" | |||
paths = check_dataloader_paths(paths) | |||
data = DataInfo() | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
data.datasets[name] = dataset | |||
char_labels_vocab = Vocabulary(padding=None, unknown=None) | |||
def process(dataset, char_label_vocab): | |||
dataset.apply(add_word_lst, new_field_name='word_lst') | |||
dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars') | |||
dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams') | |||
dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams') | |||
dataset.apply(add_char_heads, new_field_name='char_heads') | |||
dataset.apply(add_char_labels, new_field_name='char_labels') | |||
dataset.apply(add_segs, new_field_name='seg_targets') | |||
dataset.apply(add_mask, new_field_name='seg_masks') | |||
dataset.add_seq_len('chars', new_field_name='seq_lens') | |||
dataset.apply(add_pun_masks, new_field_name='pun_masks') | |||
if len(char_label_vocab.word_count)==0: | |||
char_label_vocab.from_dataset(dataset, field_name='char_labels') | |||
char_label_vocab.index_dataset(dataset, field_name='char_labels') | |||
new_dataset = add_root(dataset) | |||
new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True) | |||
            label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab)
            new_dataset.apply(label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True)
new_dataset.set_pad_val('char_labels', -1) | |||
new_dataset.set_pad_val('char_heads', -1) | |||
return new_dataset | |||
for name in list(paths.keys()): | |||
dataset = data.datasets[name] | |||
dataset = process(dataset, char_labels_vocab) | |||
data.datasets[name] = dataset | |||
data.vocabs['char_labels'] = char_labels_vocab | |||
char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars') | |||
bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams') | |||
trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams') | |||
for name in ['chars', 'bigrams', 'trigrams']: | |||
vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values())) | |||
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name) | |||
data.vocabs['pre_{}'.format(name)] = vocab | |||
for name, vocab in zip(['chars', 'bigrams', 'trigrams'], | |||
[char_vocab, bigram_vocab, trigram_vocab]): | |||
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name) | |||
data.vocabs[name] = vocab | |||
for name, dataset in data.datasets.items(): | |||
dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars', | |||
'pre_bigrams', 'pre_trigrams') | |||
dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels', | |||
'char_heads', | |||
'pun_masks', 'gold_label_word_pairs') | |||
return data | |||
def add_label_word_pairs(instance, label_vocab): | |||
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]] | |||
word_end_indexes = np.array(list(map(len, instance['word_lst']))) | |||
word_end_indexes = np.cumsum(word_end_indexes).tolist() | |||
word_end_indexes.insert(0, 0) | |||
word_pairs = [] | |||
labels = instance['labels'] | |||
pos_tags = instance['pos_tags'] | |||
for idx, head in enumerate(instance['heads']): | |||
        if pos_tags[idx] == 'PU':  # skip punctuation
continue | |||
label = label_vocab.to_index(labels[idx]) | |||
if head==0: | |||
word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1])))) | |||
else: | |||
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label, | |||
(word_end_indexes[idx], word_end_indexes[idx + 1]))) | |||
return word_pairs | |||
def add_word_pairs(instance): | |||
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]] | |||
word_end_indexes = np.array(list(map(len, instance['word_lst']))) | |||
word_end_indexes = np.cumsum(word_end_indexes).tolist() | |||
word_end_indexes.insert(0, 0) | |||
word_pairs = [] | |||
pos_tags = instance['pos_tags'] | |||
for idx, head in enumerate(instance['heads']): | |||
        if pos_tags[idx] == 'PU':  # skip punctuation
continue | |||
if head==0: | |||
word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1])))) | |||
else: | |||
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), | |||
(word_end_indexes[idx], word_end_indexes[idx + 1]))) | |||
return word_pairs | |||
def add_root(dataset): | |||
new_dataset = DataSet() | |||
for sample in dataset: | |||
chars = ['char_root'] + sample['chars'] | |||
bigrams = ['bigram_root'] + sample['bigrams'] | |||
trigrams = ['trigram_root'] + sample['trigrams'] | |||
seq_lens = sample['seq_lens']+1 | |||
char_labels = [0] + sample['char_labels'] | |||
char_heads = [0] + sample['char_heads'] | |||
sample['chars'] = chars | |||
sample['bigrams'] = bigrams | |||
sample['trigrams'] = trigrams | |||
sample['seq_lens'] = seq_lens | |||
sample['char_labels'] = char_labels | |||
sample['char_heads'] = char_heads | |||
new_dataset.append(sample) | |||
return new_dataset | |||
def add_pun_masks(instance): | |||
tags = instance['pos_tags'] | |||
pun_masks = [] | |||
for word, tag in zip(instance['words'], tags): | |||
if tag=='PU': | |||
pun_masks.extend([1]*len(word)) | |||
else: | |||
pun_masks.extend([0]*len(word)) | |||
return pun_masks | |||
def add_word_lst(instance): | |||
words = instance['words'] | |||
word_lst = [list(word) for word in words] | |||
return word_lst | |||
def add_bigram(instance): | |||
chars = instance['chars'] | |||
length = len(chars) | |||
chars = chars + ['<eos>'] | |||
bigrams = [] | |||
for i in range(length): | |||
bigrams.append(''.join(chars[i:i + 2])) | |||
return bigrams | |||
def add_trigram(instance): | |||
chars = instance['chars'] | |||
length = len(chars) | |||
chars = chars + ['<eos>'] * 2 | |||
trigrams = [] | |||
for i in range(length): | |||
trigrams.append(''.join(chars[i:i + 3])) | |||
return trigrams | |||
def add_char_heads(instance): | |||
words = instance['word_lst'] | |||
heads = instance['heads'] | |||
char_heads = [] | |||
    char_index = 1  # start from 1 because position 0 is the root node
    head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0]  # head 0 (root) maps to index -1, i.e. the appended 0
for word, head in zip(words, heads): | |||
char_head = [] | |||
if len(word)>1: | |||
char_head.append(char_index+1) | |||
char_index += 1 | |||
for _ in range(len(word)-2): | |||
char_index += 1 | |||
char_head.append(char_index) | |||
char_index += 1 | |||
char_head.append(head_end_indexes[head-1]) | |||
char_heads.extend(char_head) | |||
return char_heads | |||
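# A worked example for add_char_heads (hypothetical heads, consistent with the docstring sentence below):
#   word_lst = [['复', '旦', '大', '学'], ['位', '于']]
#   heads    = [2, 0]                      # "复旦大学" depends on "位于"; "位于" depends on root
#   add_char_heads -> [2, 3, 4, 6, 6, 0]   # positions are 1-based because index 0 is reserved for root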
def add_char_labels(instance): | |||
""" | |||
将word_lst中的数据按照下面的方式设置label | |||
比如"复旦大学 位于 ", 对应的分词是"B M M E B E", 则对应的dependency是"复(dep)->旦(head)", "旦(dep)->大(head)".. | |||
对应的label是'app', 'app', 'app', , 而学的label就是复旦大学这个词的dependency label | |||
:param instance: | |||
:return: | |||
""" | |||
words = instance['word_lst'] | |||
labels = instance['labels'] | |||
char_labels = [] | |||
for word, label in zip(words, labels): | |||
for _ in range(len(word)-1): | |||
char_labels.append('APP') | |||
char_labels.append(label) | |||
return char_labels | |||
# add seg_targets | |||
def add_segs(instance): | |||
words = instance['word_lst'] | |||
segs = [0]*len(instance['chars']) | |||
index = 0 | |||
for word in words: | |||
index = index + len(word) - 1 | |||
segs[index] = len(word)-1 | |||
index = index + 1 | |||
return segs | |||
# add target_masks | |||
def add_mask(instance): | |||
words = instance['word_lst'] | |||
mask = [] | |||
for word in words: | |||
mask.extend([0] * (len(word) - 1)) | |||
mask.append(1) | |||
return mask |
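# A minimal usage sketch of this loader (the path below is a placeholder, not part of the repo):
#   loader = CTBxJointLoader()
#   data = loader.process('/path/to/folder_with_train_dev_test_conllx')
#   train_set = data.datasets['train']
#   char_label_vocab = data.vocabs['char_labels']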
# reproduction/joint_cws_parse/models/CharParser.py
from fastNLP.models.biaffine_parser import BiaffineParser | |||
from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear | |||
import numpy as np | |||
import torch | |||
from torch import nn | |||
from torch.nn import functional as F | |||
from fastNLP.modules.dropout import TimestepDropout | |||
from fastNLP.modules.encoder.variational_rnn import VarLSTM | |||
from fastNLP import seq_len_to_mask | |||
from fastNLP.modules import Embedding | |||
def drop_input_independent(word_embeddings, dropout_emb): | |||
batch_size, seq_length, _ = word_embeddings.size() | |||
word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb) | |||
word_masks = torch.bernoulli(word_masks) | |||
word_masks = word_masks.unsqueeze(dim=2) | |||
word_embeddings = word_embeddings * word_masks | |||
return word_embeddings | |||
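# Note: drop_input_independent zeroes whole token embeddings with probability dropout_emb and,
# as written above, does not rescale the kept embeddings by 1 / (1 - dropout_emb).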
class CharBiaffineParser(BiaffineParser): | |||
def __init__(self, char_vocab_size, | |||
emb_dim, | |||
bigram_vocab_size, | |||
trigram_vocab_size, | |||
num_label, | |||
rnn_layers=3, | |||
                 rnn_hidden_size=800,  # hidden size of each direction
arc_mlp_size=500, | |||
label_mlp_size=100, | |||
dropout=0.3, | |||
encoder='lstm', | |||
use_greedy_infer=False, | |||
app_index = 0, | |||
pre_chars_embed=None, | |||
pre_bigrams_embed=None, | |||
pre_trigrams_embed=None): | |||
        super(BiaffineParser, self).__init__()  # deliberately skip BiaffineParser.__init__ (i.e. call nn.Module.__init__) and build all sub-modules here
rnn_out_size = 2 * rnn_hidden_size | |||
self.char_embed = Embedding((char_vocab_size, emb_dim)) | |||
self.bigram_embed = Embedding((bigram_vocab_size, emb_dim)) | |||
self.trigram_embed = Embedding((trigram_vocab_size, emb_dim)) | |||
if pre_chars_embed: | |||
self.pre_char_embed = Embedding(pre_chars_embed) | |||
self.pre_char_embed.requires_grad = False | |||
if pre_bigrams_embed: | |||
self.pre_bigram_embed = Embedding(pre_bigrams_embed) | |||
self.pre_bigram_embed.requires_grad = False | |||
if pre_trigrams_embed: | |||
self.pre_trigram_embed = Embedding(pre_trigrams_embed) | |||
self.pre_trigram_embed.requires_grad = False | |||
self.timestep_drop = TimestepDropout(dropout) | |||
self.encoder_name = encoder | |||
if encoder == 'var-lstm': | |||
self.encoder = VarLSTM(input_size=emb_dim*3, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
batch_first=True, | |||
input_dropout=dropout, | |||
hidden_dropout=dropout, | |||
bidirectional=True) | |||
elif encoder == 'lstm': | |||
self.encoder = nn.LSTM(input_size=emb_dim*3, | |||
hidden_size=rnn_hidden_size, | |||
num_layers=rnn_layers, | |||
bias=True, | |||
batch_first=True, | |||
dropout=dropout, | |||
bidirectional=True) | |||
else: | |||
raise ValueError('unsupported encoder type: {}'.format(encoder)) | |||
self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), | |||
nn.LeakyReLU(0.1), | |||
TimestepDropout(p=dropout),) | |||
self.arc_mlp_size = arc_mlp_size | |||
self.label_mlp_size = label_mlp_size | |||
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) | |||
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) | |||
self.use_greedy_infer = use_greedy_infer | |||
self.reset_parameters() | |||
self.dropout = dropout | |||
self.app_index = app_index | |||
self.num_label = num_label | |||
if self.app_index != 0: | |||
raise ValueError("现在app_index必须等于0") | |||
def reset_parameters(self): | |||
for name, m in self.named_modules(): | |||
if 'embed' in name: | |||
pass | |||
elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'): | |||
pass | |||
else: | |||
for p in m.parameters(): | |||
if len(p.size())>1: | |||
nn.init.xavier_normal_(p, gain=0.1) | |||
else: | |||
nn.init.uniform_(p, -0.1, 0.1) | |||
def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None, | |||
pre_trigrams=None): | |||
""" | |||
        max_len includes the root token.
        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len
        :param trigrams: batch_size x max_len
        :param seq_lens: batch_size
        :param gold_heads: batch_size x max_len
        :param pre_chars: batch_size x max_len
        :param pre_bigrams: batch_size x max_len
        :param pre_trigrams: batch_size x max_len
        :return dict: parsing results
            arc_pred: [batch_size, seq_len, seq_len]
            label_pred: [batch_size, seq_len, num_label]
            mask: [batch_size, seq_len]
            head_pred: [batch_size, seq_len], the predicted heads, returned when gold_heads is not provided
""" | |||
# prepare embeddings | |||
batch_size, seq_len = chars.shape | |||
# print('forward {} {}'.format(batch_size, seq_len)) | |||
# get sequence mask | |||
mask = seq_len_to_mask(seq_lens).long() | |||
chars = self.char_embed(chars) # [N,L] -> [N,L,C_0] | |||
bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1] | |||
trigrams = self.trigram_embed(trigrams) | |||
if pre_chars is not None: | |||
pre_chars = self.pre_char_embed(pre_chars) | |||
# pre_chars = self.pre_char_fc(pre_chars) | |||
chars = pre_chars + chars | |||
if pre_bigrams is not None: | |||
pre_bigrams = self.pre_bigram_embed(pre_bigrams) | |||
# pre_bigrams = self.pre_bigram_fc(pre_bigrams) | |||
bigrams = bigrams + pre_bigrams | |||
if pre_trigrams is not None: | |||
pre_trigrams = self.pre_trigram_embed(pre_trigrams) | |||
# pre_trigrams = self.pre_trigram_fc(pre_trigrams) | |||
trigrams = trigrams + pre_trigrams | |||
x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C] | |||
# encoder, extract features | |||
if self.training: | |||
x = drop_input_independent(x, self.dropout) | |||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||
x = x[sort_idx] | |||
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) | |||
feat, _ = self.encoder(x) # -> [N,L,C] | |||
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
feat = feat[unsort_idx] | |||
feat = self.timestep_drop(feat) | |||
# for arc biaffine | |||
# mlp, reduce dim | |||
feat = self.mlp(feat) | |||
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size | |||
arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] | |||
label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] | |||
# biaffine arc classifier | |||
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] | |||
# use gold or predicted arc to predict label | |||
if gold_heads is None or not self.training: | |||
# use greedy decoding in training | |||
if self.training or self.use_greedy_infer: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
else: | |||
heads = self.mst_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
assert self.training # must be training mode | |||
if gold_heads is None: | |||
heads = self.greedy_decoder(arc_pred, mask) | |||
head_pred = heads | |||
else: | |||
head_pred = None | |||
heads = gold_heads | |||
# heads: batch_size x max_len | |||
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1) | |||
label_head = label_head[batch_range, heads].contiguous() | |||
label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label] | |||
        # Restrict the 'app' label: it can only be predicted when the head is the next character.
        arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\
            .repeat(batch_size, 1)  # batch_size x max_len
        app_masks = heads.ne(arange_index)  # batch_size x max_len; positions marked 1 must not predict 'app'
        app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
        app_masks[:, :, 1:] = 0
label_pred = label_pred.masked_fill(app_masks, -np.inf) | |||
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} | |||
if head_pred is not None: | |||
res_dict['head_pred'] = head_pred | |||
return res_dict | |||
@staticmethod | |||
def loss(arc_pred, label_pred, arc_true, label_true, mask): | |||
""" | |||
Compute loss. | |||
:param arc_pred: [batch_size, seq_len, seq_len] | |||
:param label_pred: [batch_size, seq_len, n_tags] | |||
:param arc_true: [batch_size, seq_len] | |||
:param label_true: [batch_size, seq_len] | |||
:param mask: [batch_size, seq_len] | |||
:return: loss value | |||
""" | |||
batch_size, seq_len, _ = arc_pred.shape | |||
flip_mask = (mask == 0) | |||
_arc_pred = arc_pred.clone() | |||
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) | |||
arc_true[:, 0].fill_(-1) | |||
label_true[:, 0].fill_(-1) | |||
arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) | |||
label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) | |||
return arc_nll + label_nll | |||
def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams): | |||
""" | |||
        max_len includes the root token.
        :param chars: batch_size x max_len
        :param bigrams: batch_size x max_len
        :param trigrams: batch_size x max_len
        :param seq_lens: batch_size
        :param pre_chars: batch_size x max_len
        :param pre_bigrams: batch_size x max_len
        :param pre_trigrams: batch_size x max_len
        :return:
""" | |||
res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams, | |||
pre_trigrams=pre_trigrams, gold_heads=None) | |||
output = {} | |||
        output['arc_pred'] = res.pop('head_pred')  # note: despite the key name, these are the predicted head indices
_, label_pred = res.pop('label_pred').max(2) | |||
output['label_pred'] = label_pred | |||
return output | |||
class CharParser(nn.Module): | |||
def __init__(self, char_vocab_size, | |||
emb_dim, | |||
bigram_vocab_size, | |||
trigram_vocab_size, | |||
num_label, | |||
rnn_layers=3, | |||
                 rnn_hidden_size=400,  # hidden size of each direction
arc_mlp_size=500, | |||
label_mlp_size=100, | |||
dropout=0.3, | |||
encoder='var-lstm', | |||
use_greedy_infer=False, | |||
app_index = 0, | |||
pre_chars_embed=None, | |||
pre_bigrams_embed=None, | |||
pre_trigrams_embed=None): | |||
super().__init__() | |||
self.parser = CharBiaffineParser(char_vocab_size, | |||
emb_dim, | |||
bigram_vocab_size, | |||
trigram_vocab_size, | |||
num_label, | |||
rnn_layers, | |||
                                         rnn_hidden_size,  # hidden size of each direction
arc_mlp_size, | |||
label_mlp_size, | |||
dropout, | |||
encoder, | |||
use_greedy_infer, | |||
app_index, | |||
pre_chars_embed=pre_chars_embed, | |||
pre_bigrams_embed=pre_bigrams_embed, | |||
pre_trigrams_embed=pre_trigrams_embed) | |||
def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None, | |||
pre_trigrams=None): | |||
res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars, | |||
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) | |||
arc_pred = res_dict['arc_pred'] | |||
label_pred = res_dict['label_pred'] | |||
masks = res_dict['mask'] | |||
loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks) | |||
return {'loss': loss} | |||
def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None): | |||
res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars, | |||
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) | |||
output = {} | |||
output['head_preds'] = res.pop('head_pred') | |||
_, label_pred = res.pop('label_pred').max(2) | |||
output['label_preds'] = label_pred | |||
return output |
# reproduction/joint_cws_parse/models/callbacks.py
from fastNLP.core.callback import Callback | |||
import torch | |||
from torch import nn | |||
class OptimizerCallback(Callback): | |||
def __init__(self, optimizer, scheduler, update_every=4): | |||
super().__init__() | |||
self._optimizer = optimizer | |||
self.scheduler = scheduler | |||
self._update_every = update_every | |||
def on_backward_end(self): | |||
if self.step % self._update_every==0: | |||
# nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5) | |||
# self._optimizer.step() | |||
self.scheduler.step() | |||
# self.model.zero_grad() | |||
class DevCallback(Callback): | |||
def __init__(self, tester, metric_key='u_f1'): | |||
super().__init__() | |||
self.tester = tester | |||
setattr(tester, 'verbose', 0) | |||
self.metric_key = metric_key | |||
self.record_best = False | |||
self.best_eval_value = 0 | |||
self.best_eval_res = None | |||
        self.best_dev_res = None  # stores the dev-set performance
def on_valid_begin(self): | |||
eval_res = self.tester.test() | |||
metric_name = self.tester.metrics[0].__class__.__name__ | |||
metric_value = eval_res[metric_name][self.metric_key] | |||
if metric_value>self.best_eval_value: | |||
self.best_eval_value = metric_value | |||
self.best_epoch = self.trainer.epoch | |||
self.record_best = True | |||
self.best_eval_res = eval_res | |||
self.test_eval_res = eval_res | |||
eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \ | |||
self.tester._format_eval_results(eval_res) | |||
self.pbar.write(eval_str) | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
if self.record_best: | |||
self.best_dev_res = eval_result | |||
self.record_best = False | |||
if is_better_eval: | |||
self.best_dev_res_on_dev = eval_result | |||
self.best_test_res_on_dev = self.test_eval_res | |||
self.dev_epoch = self.epoch | |||
def on_train_end(self): | |||
print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch, | |||
self.tester._format_eval_results(self.best_eval_res), | |||
self.tester._format_eval_results(self.best_dev_res))) | |||
print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch, | |||
self.tester._format_eval_results(self.best_test_res_on_dev), | |||
self.tester._format_eval_results(self.best_dev_res_on_dev))) |
# reproduction/joint_cws_parse/models/metrics.py
from fastNLP.core.metrics import MetricBase | |||
from fastNLP.core.utils import seq_len_to_mask | |||
import torch | |||
class SegAppCharParseF1Metric(MetricBase): | |||
# | |||
def __init__(self, app_index): | |||
super().__init__() | |||
self.app_index = app_index | |||
self.parse_head_tp = 0 | |||
self.parse_label_tp = 0 | |||
self.rec_tol = 0 | |||
self.pre_tol = 0 | |||
def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens, | |||
pun_masks): | |||
""" | |||
        max_len is the number of characters, excluding the root.
        :param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size
        :param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size
        :param head_preds: batch_size x max_len
        :param label_preds: batch_size x max_len
        :param seq_lens: batch_size
        :param pun_masks: batch_size x max_len
        :return:
""" | |||
        # drop the root position
head_preds = head_preds[:, 1:].tolist() | |||
label_preds = label_preds[:, 1:].tolist() | |||
seq_lens = (seq_lens - 1).tolist() | |||
        # first decode the words, heads and labels, and the character span of each word
for b in range(len(head_preds)): | |||
seq_len = seq_lens[b] | |||
head_pred = head_preds[b][:seq_len] | |||
label_pred = label_preds[b][:seq_len] | |||
            words = []  # [word_start, word_end) spans, relative to the first character (root excluded)
            heads = []
            labels = []
            ranges = []  # for each character, the index of the word it belongs to
word_idx = 0 | |||
word_start_idx = 0 | |||
for idx, (label, head) in enumerate(zip(label_pred, head_pred)): | |||
ranges.append(word_idx) | |||
if label == self.app_index: | |||
pass | |||
else: | |||
labels.append(label) | |||
heads.append(head) | |||
words.append((word_start_idx, idx+1)) | |||
word_start_idx = idx+1 | |||
word_idx += 1 | |||
            head_dep_tuple = []  # the head span comes first in each tuple
head_label_dep_tuple = [] | |||
for idx, head in enumerate(heads): | |||
span = words[idx] | |||
if span[0]==span[1]-1 and pun_masks[b, span[0]]: | |||
continue # exclude punctuations | |||
if head == 0: | |||
head_dep_tuple.append((('root', words[idx]))) | |||
head_label_dep_tuple.append(('root', labels[idx], words[idx])) | |||
else: | |||
head_word_idx = ranges[head-1] | |||
head_word_span = words[head_word_idx] | |||
head_dep_tuple.append(((head_word_span, words[idx]))) | |||
head_label_dep_tuple.append((head_word_span, labels[idx], words[idx])) | |||
gold_head_dep_tuple = set(gold_word_pairs[b]) | |||
gold_head_label_dep_tuple = set(gold_label_word_pairs[b]) | |||
for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple): | |||
if head_dep in gold_head_dep_tuple: | |||
self.parse_head_tp += 1 | |||
if head_label_dep in gold_head_label_dep_tuple: | |||
self.parse_label_tp += 1 | |||
self.pre_tol += len(head_dep_tuple) | |||
self.rec_tol += len(gold_head_dep_tuple) | |||
def get_metric(self, reset=True): | |||
u_p = self.parse_head_tp / self.pre_tol | |||
u_r = self.parse_head_tp / self.rec_tol | |||
u_f = 2*u_p*u_r/(1e-6 + u_p + u_r) | |||
l_p = self.parse_label_tp / self.pre_tol | |||
l_r = self.parse_label_tp / self.rec_tol | |||
l_f = 2*l_p*l_r/(1e-6 + l_p + l_r) | |||
if reset: | |||
self.parse_head_tp = 0 | |||
self.parse_label_tp = 0 | |||
self.rec_tol = 0 | |||
self.pre_tol = 0 | |||
return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4), | |||
'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)} | |||
class CWSMetric(MetricBase): | |||
def __init__(self, app_index): | |||
super().__init__() | |||
self.app_index = app_index | |||
self.pre = 0 | |||
self.rec = 0 | |||
self.tp = 0 | |||
def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens): | |||
""" | |||
        :param seg_targets: batch_size x max_len; at each word-final position, the target is the word length - 1.
        :param seg_masks: batch_size x max_len; 1 only at word-final positions.
:param label_preds: batch_size x max_len | |||
:param seq_lens: batch_size | |||
:return: | |||
""" | |||
pred_masks = torch.zeros_like(seg_masks) | |||
pred_segs = torch.zeros_like(seg_targets) | |||
seq_lens = (seq_lens - 1).tolist() | |||
for idx, label_pred in enumerate(label_preds[:, 1:].tolist()): | |||
seq_len = seq_lens[idx] | |||
label_pred = label_pred[:seq_len] | |||
word_len = 0 | |||
for l_i, label in enumerate(label_pred): | |||
if label==self.app_index and l_i!=len(label_pred)-1: | |||
word_len += 1 | |||
else: | |||
                    pred_segs[idx, l_i] = word_len  # record the length (minus one) of the word ending here
pred_masks[idx, l_i] = 1 | |||
word_len = 0 | |||
        right_mask = seg_targets.eq(pred_segs)  # the predicted word length matches the target
self.rec += seg_masks.sum().item() | |||
self.pre += pred_masks.sum().item() | |||
        # and both the prediction and the target mark a word boundary at the same position
self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item() | |||
def get_metric(self, reset=True): | |||
res = {} | |||
res['rec'] = round(self.tp/(self.rec+1e-6), 4) | |||
res['pre'] = round(self.tp/(self.pre+1e-6), 4) | |||
res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4) | |||
if reset: | |||
self.pre = 0 | |||
self.rec = 0 | |||
self.tp = 0 | |||
return res | |||
class ParserMetric(MetricBase): | |||
def __init__(self, ): | |||
super().__init__() | |||
self.num_arc = 0 | |||
self.num_label = 0 | |||
self.num_sample = 0 | |||
def get_metric(self, reset=True): | |||
res = {'UAS': round(self.num_arc*1.0 / self.num_sample, 4), | |||
'LAS': round(self.num_label*1.0 / self.num_sample, 4)} | |||
if reset: | |||
self.num_sample = self.num_label = self.num_arc = 0 | |||
return res | |||
def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None): | |||
"""Evaluate the performance of prediction. | |||
""" | |||
if seq_lens is None: | |||
            seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.uint8)
else: | |||
seq_mask = seq_len_to_mask(seq_lens.long(), float=False) | |||
# mask out <root> tag | |||
seq_mask[:, 0] = 0 | |||
head_pred_correct = (head_preds == heads).__and__(seq_mask) | |||
label_pred_correct = (label_preds == labels).__and__(head_pred_correct) | |||
self.num_arc += head_pred_correct.float().sum().item() | |||
self.num_label += label_pred_correct.float().sum().item() | |||
self.num_sample += seq_mask.sum().item() | |||
<!-- reproduction/joint_cws_parse/README.md -->
Code for the paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697)

### Preparing the data
1. The data should be in CoNLL format; columns 1, 3, 6 and 7 (0-based, as read by `ConllLoader`) correspond to 'words', 'pos_tags', 'heads' and 'labels'.
2. Put the train, dev and test files in the same folder and set the `data_folder` variable in train.py to that folder (an example layout is shown after this list).
3. Download the pretrained vectors from [百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA) (extraction code: ua53), put them in one folder, and set the `vector_folder` variable in train.py accordingly.
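
The `data_folder` is expected to look roughly like the following (file names follow the docstring in `data/data_loader.py`; adjust if yours differ):

```
data_folder/
    train.conllx
    dev.conllx
    test.conllx
```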
### Running the code
``` | |||
python train.py | |||
``` | |||
### Notes
With the default hyper-parameters above you should be able to reproduce (or slightly exceed) the results reported in the paper on CTB5; on CTB7 the defaults come out about 0.1% lower, and the learning rate scheduler needs tuning.
# reproduction/joint_cws_parse/train.py
import sys | |||
sys.path.append('../..') | |||
from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader | |||
from fastNLP.modules.encoder.embedding import StaticEmbedding | |||
from torch import nn | |||
from functools import partial | |||
from reproduction.joint_cws_parse.models.CharParser import CharParser | |||
from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric | |||
from fastNLP import cache_results, BucketSampler, Trainer | |||
from torch import optim | |||
from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback | |||
from torch.optim.lr_scheduler import LambdaLR, StepLR | |||
from fastNLP import Tester | |||
from fastNLP import GradientClipCallback, LRScheduler | |||
import os | |||
def set_random_seed(random_seed=666): | |||
import random, numpy, torch | |||
random.seed(random_seed) | |||
numpy.random.seed(random_seed) | |||
torch.cuda.manual_seed(random_seed) | |||
torch.random.manual_seed(random_seed) | |||
uniform_init = partial(nn.init.normal_, std=0.02)  # note: despite the name, this is a normal(0, 0.02) initializer
################################################### | |||
# Hyper-parameters that you may need to change are collected here
lr = 0.002 # 0.01~0.001 | |||
dropout = 0.33 # 0.3~0.6 | |||
weight_decay = 0 # 1e-5, 1e-6, 0 | |||
arc_mlp_size = 500 # 200, 300 | |||
rnn_hidden_size = 400 # 200, 300, 400 | |||
rnn_layers = 3 # 2, 3 | |||
encoder = 'var-lstm' # var-lstm, lstm | |||
emb_size = 100 # 64 , 100 | |||
label_mlp_size = 100 | |||
batch_size = 32 | |||
update_every = 4 | |||
n_epochs = 100 | |||
data_folder = ''  # folder containing the data; it should hold the train, dev and test files
vector_folder = ''  # folder with the pretrained vectors: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
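# Example values (hypothetical local paths, adjust to your setup):
#   data_folder = '/path/to/ctb_conllx'            # contains train.conllx, dev.conllx, test.conllx
#   vector_folder = '/path/to/pretrained_ngrams'   # contains the three *_t3_m50_corpus.txt files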
#################################################### | |||
set_random_seed(1234) | |||
device = 0 | |||
# @cache_results('caches/{}.pkl'.format(data_name)) | |||
# def get_data(): | |||
data = CTBxJointLoader().process(data_folder) | |||
char_labels_vocab = data.vocabs['char_labels'] | |||
pre_chars_vocab = data.vocabs['pre_chars'] | |||
pre_bigrams_vocab = data.vocabs['pre_bigrams'] | |||
pre_trigrams_vocab = data.vocabs['pre_trigrams'] | |||
chars_vocab = data.vocabs['chars'] | |||
bigrams_vocab = data.vocabs['bigrams'] | |||
trigrams_vocab = data.vocabs['trigrams'] | |||
pre_chars_embed = StaticEmbedding(pre_chars_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std() | |||
pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std() | |||
pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, | |||
model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), | |||
init_method=uniform_init, normalize=False) | |||
pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std() | |||
# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data | |||
# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() | |||
print(data) | |||
model = CharParser(char_vocab_size=len(chars_vocab), | |||
emb_dim=emb_size, | |||
bigram_vocab_size=len(bigrams_vocab), | |||
trigram_vocab_size=len(trigrams_vocab), | |||
num_label=len(char_labels_vocab), | |||
rnn_layers=rnn_layers, | |||
rnn_hidden_size=rnn_hidden_size, | |||
arc_mlp_size=arc_mlp_size, | |||
label_mlp_size=label_mlp_size, | |||
dropout=dropout, | |||
encoder=encoder, | |||
use_greedy_infer=False, | |||
app_index=char_labels_vocab['APP'], | |||
pre_chars_embed=pre_chars_embed, | |||
pre_bigrams_embed=pre_bigrams_embed, | |||
pre_trigrams_embed=pre_trigrams_embed) | |||
metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP']) | |||
metric2 = CWSMetric(char_labels_vocab['APP']) | |||
metrics = [metric1, metric2] | |||
optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr, | |||
weight_decay=weight_decay, betas=[0.9, 0.9]) | |||
sampler = BucketSampler(seq_len_field_name='seq_lens') | |||
callbacks = [] | |||
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) | |||
scheduler = StepLR(optimizer, step_size=18, gamma=0.75) | |||
# optim_callback = OptimizerCallback(optimizer, scheduler, update_every) | |||
# callbacks.append(optim_callback) | |||
scheduler_callback = LRScheduler(scheduler) | |||
callbacks.append(scheduler_callback) | |||
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) | |||
tester = Tester(data=data.datasets['test'], model=model, metrics=metrics, | |||
batch_size=64, device=device, verbose=0) | |||
dev_callback = DevCallback(tester) | |||
callbacks.append(dev_callback) | |||
trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3, | |||
validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, | |||
check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True, | |||
device=device, callbacks=callbacks, update_every=update_every) | |||
trainer.train() |