
Add joint_cws_parse code

tags/v0.4.10
yh_cc committed 6 years ago · commit 999f8ac33f
9 changed files with 984 additions and 0 deletions
  1. +0   -0  reproduction/joint_cws_parse/__init__.py
  2. +0   -0  reproduction/joint_cws_parse/data/__init__.py
  3. +284 -0  reproduction/joint_cws_parse/data/data_loader.py
  4. +311 -0  reproduction/joint_cws_parse/models/CharParser.py
  5. +0   -0  reproduction/joint_cws_parse/models/__init__.py
  6. +65  -0  reproduction/joint_cws_parse/models/callbacks.py
  7. +184 -0  reproduction/joint_cws_parse/models/metrics.py
  8. +16  -0  reproduction/joint_cws_parse/readme.md
  9. +124 -0  reproduction/joint_cws_parse/train.py

+ 0 - 0  reproduction/joint_cws_parse/__init__.py


+ 0 - 0  reproduction/joint_cws_parse/data/__init__.py


+ 284 - 0  reproduction/joint_cws_parse/data/data_loader.py

@@ -0,0 +1,284 @@


from fastNLP.io.base_loader import DataSetLoader, DataInfo
from fastNLP.io.dataset_loader import ConllLoader
import numpy as np

from itertools import chain
from fastNLP import DataSet, Vocabulary
from functools import partial
import os
from typing import Union, Dict
from reproduction.utils import check_dataloader_paths


class CTBxJointLoader(DataSetLoader):
"""
文件夹下应该具有以下的文件结构
-train.conllx
-dev.conllx
-test.conllx
每个文件中的内容如下(空格隔开不同的句子, 共有)
1 费孝通 _ NR NR _ 3 nsubjpass _ _
2 被 _ SB SB _ 3 pass _ _
3 授予 _ VV VV _ 0 root _ _
4 麦格赛赛 _ NR NR _ 5 nn _ _
5 奖 _ NN NN _ 3 dobj _ _

1 新华社 _ NR NR _ 7 dep _ _
2 马尼拉 _ NR NR _ 7 dep _ _
3 8月 _ NT NT _ 7 dep _ _
4 31日 _ NT NT _ 7 dep _ _
...

"""
def __init__(self):
self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7])

def load(self, path:str):
"""
Given a file path, read the data into a DataSet with the following fields:
words: list[str]
pos_tags: list[str]
heads: list[int]
labels: list[str]

:param path: path to a single .conllx file
:return: DataSet with the fields listed above
"""
dataset = self._loader.load(path)
dataset.heads.int()
return dataset
def process(self, paths):
"""
:param paths:
:return:
Dataset包含以下的field
chars:
bigrams:
trigrams:
pre_chars:
pre_bigrams:
pre_trigrams:
seg_targets:
seg_masks:
seq_lens:
char_labels:
char_heads:
gold_word_pairs:
seg_targets:
seg_masks:
char_labels:
char_heads:
pun_masks:
gold_label_word_pairs:
"""
paths = check_dataloader_paths(paths)
data = DataInfo()

for name, path in paths.items():
dataset = self.load(path)
data.datasets[name] = dataset

char_labels_vocab = Vocabulary(padding=None, unknown=None)

def process(dataset, char_label_vocab):
dataset.apply(add_word_lst, new_field_name='word_lst')
dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars')
dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams')
dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams')
dataset.apply(add_char_heads, new_field_name='char_heads')
dataset.apply(add_char_labels, new_field_name='char_labels')
dataset.apply(add_segs, new_field_name='seg_targets')
dataset.apply(add_mask, new_field_name='seg_masks')
dataset.add_seq_len('chars', new_field_name='seq_lens')
dataset.apply(add_pun_masks, new_field_name='pun_masks')
if len(char_label_vocab.word_count)==0:
char_label_vocab.from_dataset(dataset, field_name='char_labels')
char_label_vocab.index_dataset(dataset, field_name='char_labels')
new_dataset = add_root(dataset)
new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True)
# bind the label vocabulary via a local partial instead of rebinding the module-level function
apply_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab)
new_dataset.apply(apply_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True)

new_dataset.set_pad_val('char_labels', -1)
new_dataset.set_pad_val('char_heads', -1)

return new_dataset

for name in list(paths.keys()):
dataset = data.datasets[name]
dataset = process(dataset, char_labels_vocab)
data.datasets[name] = dataset

data.vocabs['char_labels'] = char_labels_vocab

char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars')
bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams')
trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams')

for name in ['chars', 'bigrams', 'trigrams']:
vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values()))
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name)
data.vocabs['pre_{}'.format(name)] = vocab

for name, vocab in zip(['chars', 'bigrams', 'trigrams'],
[char_vocab, bigram_vocab, trigram_vocab]):
vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name)
data.vocabs[name] = vocab

for name, dataset in data.datasets.items():
dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars',
'pre_bigrams', 'pre_trigrams')
dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels',
'char_heads',
'pun_masks', 'gold_label_word_pairs')

return data


def add_label_word_pairs(instance, label_vocab):
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]]
word_end_indexes = np.array(list(map(len, instance['word_lst'])))
word_end_indexes = np.cumsum(word_end_indexes).tolist()
word_end_indexes.insert(0, 0)
word_pairs = []
labels = instance['labels']
pos_tags = instance['pos_tags']
for idx, head in enumerate(instance['heads']):
if pos_tags[idx]=='PU': # skip punctuation
continue
label = label_vocab.to_index(labels[idx])
if head==0:
word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1]))))
else:
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label,
(word_end_indexes[idx], word_end_indexes[idx + 1])))
return word_pairs

def add_word_pairs(instance):
# List[List[((head_start, head_end], (dep_start, dep_end]), ...]]
word_end_indexes = np.array(list(map(len, instance['word_lst'])))
word_end_indexes = np.cumsum(word_end_indexes).tolist()
word_end_indexes.insert(0, 0)
word_pairs = []
pos_tags = instance['pos_tags']
for idx, head in enumerate(instance['heads']):
if pos_tags[idx]=='PU': # skip punctuation
continue
if head==0:
word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1]))))
else:
word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]),
(word_end_indexes[idx], word_end_indexes[idx + 1])))
return word_pairs

def add_root(dataset):
new_dataset = DataSet()
for sample in dataset:
chars = ['char_root'] + sample['chars']
bigrams = ['bigram_root'] + sample['bigrams']
trigrams = ['trigram_root'] + sample['trigrams']
seq_lens = sample['seq_lens']+1
char_labels = [0] + sample['char_labels']
char_heads = [0] + sample['char_heads']
sample['chars'] = chars
sample['bigrams'] = bigrams
sample['trigrams'] = trigrams
sample['seq_lens'] = seq_lens
sample['char_labels'] = char_labels
sample['char_heads'] = char_heads
new_dataset.append(sample)
return new_dataset

def add_pun_masks(instance):
tags = instance['pos_tags']
pun_masks = []
for word, tag in zip(instance['words'], tags):
if tag=='PU':
pun_masks.extend([1]*len(word))
else:
pun_masks.extend([0]*len(word))
return pun_masks

def add_word_lst(instance):
words = instance['words']
word_lst = [list(word) for word in words]
return word_lst

def add_bigram(instance):
chars = instance['chars']
length = len(chars)
chars = chars + ['<eos>']
bigrams = []
for i in range(length):
bigrams.append(''.join(chars[i:i + 2]))
return bigrams

def add_trigram(instance):
chars = instance['chars']
length = len(chars)
chars = chars + ['<eos>'] * 2
trigrams = []
for i in range(length):
trigrams.append(''.join(chars[i:i + 3]))
return trigrams

def add_char_heads(instance):
words = instance['word_lst']
heads = instance['heads']
char_heads = []
char_index = 1  # char positions start at 1 because position 0 is the root node
head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0]  # head 0 means root, so head-1 == -1 indexes the appended 0
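# Illustrative example: for word_lst [['复','旦','大','学'], ['位','于']] with word-level heads [2, 0],
# every non-final char points to the next char of its word and the final char points to the last char
# of its head word (0 for root), giving char_heads = [2, 3, 4, 6, 6, 0] (char positions start at 1).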
for word, head in zip(words, heads):
char_head = []
if len(word)>1:
char_head.append(char_index+1)
char_index += 1
for _ in range(len(word)-2):
char_index += 1
char_head.append(char_index)
char_index += 1
char_head.append(head_end_indexes[head-1])
char_heads.extend(char_head)
return char_heads

def add_char_labels(instance):
"""
Assign a label to every character in word_lst in the following way.
For example, for "复旦大学 位于" the segmentation tags are "B M M E B E" and the intra-word dependencies are "复(dep)->旦(head)", "旦(dep)->大(head)", ...
Each of these intra-word arcs gets the label 'APP', while the last character of a word ("学") carries the dependency label of the whole word "复旦大学".
:param instance:
:return:
"""
words = instance['word_lst']
labels = instance['labels']
char_labels = []
for word, label in zip(words, labels):
for _ in range(len(word)-1):
char_labels.append('APP')
char_labels.append(label)
return char_labels

# add seg_targets
def add_segs(instance):
words = instance['word_lst']
segs = [0]*len(instance['chars'])
index = 0
for word in words:
index = index + len(word) - 1
segs[index] = len(word)-1
index = index + 1
return segs

# add target_masks
def add_mask(instance):
words = instance['word_lst']
mask = []
for word in words:
mask.extend([0] * (len(word) - 1))
mask.append(1)
return mask

+ 311 - 0  reproduction/joint_cws_parse/models/CharParser.py

@@ -0,0 +1,311 @@



from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from fastNLP.modules.dropout import TimestepDropout
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP import seq_len_to_mask
from fastNLP.modules import Embedding


def drop_input_independent(word_embeddings, dropout_emb):
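# Word-level dropout: with probability dropout_emb an entire token embedding is zeroed out
# (all dimensions at once); note that no 1/(1-p) rescaling is applied.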
batch_size, seq_length, _ = word_embeddings.size()
word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb)
word_masks = torch.bernoulli(word_masks)
word_masks = word_masks.unsqueeze(dim=2)
word_embeddings = word_embeddings * word_masks

return word_embeddings


class CharBiaffineParser(BiaffineParser):
def __init__(self, char_vocab_size,
emb_dim,
bigram_vocab_size,
trigram_vocab_size,
num_label,
rnn_layers=3,
rnn_hidden_size=800,  # hidden size per direction
arc_mlp_size=500,
label_mlp_size=100,
dropout=0.3,
encoder='lstm',
use_greedy_infer=False,
app_index = 0,
pre_chars_embed=None,
pre_bigrams_embed=None,
pre_trigrams_embed=None):


super(BiaffineParser, self).__init__()
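# note: super(BiaffineParser, self) deliberately skips BiaffineParser.__init__ and calls its
# parent's __init__ instead, since this class builds its own embeddings, encoder and MLPs below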
rnn_out_size = 2 * rnn_hidden_size
self.char_embed = Embedding((char_vocab_size, emb_dim))
self.bigram_embed = Embedding((bigram_vocab_size, emb_dim))
self.trigram_embed = Embedding((trigram_vocab_size, emb_dim))
if pre_chars_embed:
self.pre_char_embed = Embedding(pre_chars_embed)
self.pre_char_embed.requires_grad = False
if pre_bigrams_embed:
self.pre_bigram_embed = Embedding(pre_bigrams_embed)
self.pre_bigram_embed.requires_grad = False
if pre_trigrams_embed:
self.pre_trigram_embed = Embedding(pre_trigrams_embed)
self.pre_trigram_embed.requires_grad = False
self.timestep_drop = TimestepDropout(dropout)
self.encoder_name = encoder

if encoder == 'var-lstm':
self.encoder = VarLSTM(input_size=emb_dim*3,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
input_dropout=dropout,
hidden_dropout=dropout,
bidirectional=True)
elif encoder == 'lstm':
self.encoder = nn.LSTM(input_size=emb_dim*3,
hidden_size=rnn_hidden_size,
num_layers=rnn_layers,
bias=True,
batch_first=True,
dropout=dropout,
bidirectional=True)

else:
raise ValueError('unsupported encoder type: {}'.format(encoder))

self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2),
nn.LeakyReLU(0.1),
TimestepDropout(p=dropout),)
self.arc_mlp_size = arc_mlp_size
self.label_mlp_size = label_mlp_size
self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True)
self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True)
self.use_greedy_infer = use_greedy_infer
self.reset_parameters()
self.dropout = dropout

self.app_index = app_index
self.num_label = num_label
if self.app_index != 0:
raise ValueError("app_index must currently be 0")

def reset_parameters(self):
for name, m in self.named_modules():
if 'embed' in name:
pass
elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'):
pass
else:
for p in m.parameters():
if len(p.size())>1:
nn.init.xavier_normal_(p, gain=0.1)
else:
nn.init.uniform_(p, -0.1, 0.1)

def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None,
pre_trigrams=None):
"""
max_len includes the root token
:param chars: batch_size x max_len
:param ngrams: batch_size x max_len*ngram_per_char
:param seq_lens: batch_size
:param gold_heads: batch_size x max_len
:param pre_chars: batch_size x max_len
:param pre_ngrams: batch_size x max_len*ngram_per_char
:return dict: parsing results
arc_pred: [batch_size, seq_len, seq_len]
label_pred: [batch_size, seq_len, seq_len]
mask: [batch_size, seq_len]
head_pred: [batch_size, seq_len], returned only when heads are predicted (gold_heads not provided or eval mode)
"""
# prepare embeddings
batch_size, seq_len = chars.shape
# print('forward {} {}'.format(batch_size, seq_len))

# get sequence mask
mask = seq_len_to_mask(seq_lens).long()

chars = self.char_embed(chars) # [N,L] -> [N,L,C_0]
bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1]
trigrams = self.trigram_embed(trigrams)

if pre_chars is not None:
pre_chars = self.pre_char_embed(pre_chars)
# pre_chars = self.pre_char_fc(pre_chars)
chars = pre_chars + chars
if pre_bigrams is not None:
pre_bigrams = self.pre_bigram_embed(pre_bigrams)
# pre_bigrams = self.pre_bigram_fc(pre_bigrams)
bigrams = bigrams + pre_bigrams
if pre_trigrams is not None:
pre_trigrams = self.pre_trigram_embed(pre_trigrams)
# pre_trigrams = self.pre_trigram_fc(pre_trigrams)
trigrams = trigrams + pre_trigrams

x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C]

# encoder, extract features
if self.training:
x = drop_input_independent(x, self.dropout)
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
x = x[sort_idx]
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
feat, _ = self.encoder(x) # -> [N,L,C]
feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
feat = feat[unsort_idx]
feat = self.timestep_drop(feat)

# for arc biaffine
# mlp, reduce dim
feat = self.mlp(feat)
arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size
arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz]
label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:]

# biaffine arc classifier
arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L]

# use gold or predicted arc to predict label
if gold_heads is None or not self.training:
# greedy decoding during training (or when use_greedy_infer is set); MST decoding otherwise
if self.training or self.use_greedy_infer:
heads = self.greedy_decoder(arc_pred, mask)
else:
heads = self.mst_decoder(arc_pred, mask)
head_pred = heads
else:
assert self.training # must be training mode
if gold_heads is None:
heads = self.greedy_decoder(arc_pred, mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads
# heads: batch_size x max_len

batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label]
# restriction: the 'app' label may only be predicted when a char's head is the next character
arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\
.repeat(batch_size, 1) # batch_size x max_len
app_masks = heads.ne(arange_index)  # batch_size x max_len; positions marked 1 must not predict 'app'
app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label)
app_masks[:, :, 1:] = 0
label_pred = label_pred.masked_fill(app_masks, -np.inf)

res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
if head_pred is not None:
res_dict['head_pred'] = head_pred
return res_dict

@staticmethod
def loss(arc_pred, label_pred, arc_true, label_true, mask):
"""
Compute loss.

:param arc_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, n_tags]
:param arc_true: [batch_size, seq_len]
:param label_true: [batch_size, seq_len]
:param mask: [batch_size, seq_len]
:return: loss value
"""

batch_size, seq_len, _ = arc_pred.shape
flip_mask = (mask == 0)
_arc_pred = arc_pred.clone()
_arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))

arc_true[:, 0].fill_(-1)
label_true[:, 0].fill_(-1)

arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)

return arc_nll + label_nll

def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams):
"""

max_len includes the root token

:param chars: batch_size x max_len
:param ngrams: batch_size x max_len*ngram_per_char
:param seq_lens: batch_size
:param pre_chars: batch_size x max_len
:param pre_ngrams: batch_size x max_len*ngram_per_char
:return:
"""
res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams,
pre_trigrams=pre_trigrams, gold_heads=None)
output = {}
output['arc_pred'] = res.pop('head_pred')
_, label_pred = res.pop('label_pred').max(2)
output['label_pred'] = label_pred
return output

class CharParser(nn.Module):
def __init__(self, char_vocab_size,
emb_dim,
bigram_vocab_size,
trigram_vocab_size,
num_label,
rnn_layers=3,
rnn_hidden_size=400,  # hidden size per direction
arc_mlp_size=500,
label_mlp_size=100,
dropout=0.3,
encoder='var-lstm',
use_greedy_infer=False,
app_index = 0,
pre_chars_embed=None,
pre_bigrams_embed=None,
pre_trigrams_embed=None):
super().__init__()

self.parser = CharBiaffineParser(char_vocab_size,
emb_dim,
bigram_vocab_size,
trigram_vocab_size,
num_label,
rnn_layers,
rnn_hidden_size,  # hidden size per direction
arc_mlp_size,
label_mlp_size,
dropout,
encoder,
use_greedy_infer,
app_index,
pre_chars_embed=pre_chars_embed,
pre_bigrams_embed=pre_bigrams_embed,
pre_trigrams_embed=pre_trigrams_embed)

def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None,
pre_trigrams=None):
res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars,
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
arc_pred = res_dict['arc_pred']
label_pred = res_dict['label_pred']
masks = res_dict['mask']
loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks)
return {'loss': loss}

def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None):
res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars,
pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams)
output = {}
output['head_preds'] = res.pop('head_pred')
_, label_pred = res.pop('label_pred').max(2)
output['label_preds'] = label_pred
return output

+ 0 - 0  reproduction/joint_cws_parse/models/__init__.py


+ 65 - 0  reproduction/joint_cws_parse/models/callbacks.py

@@ -0,0 +1,65 @@

from fastNLP.core.callback import Callback
import torch
from torch import nn

class OptimizerCallback(Callback):
def __init__(self, optimizer, scheduler, update_every=4):
super().__init__()

self._optimizer = optimizer
self.scheduler = scheduler
self._update_every = update_every

def on_backward_end(self):
if self.step % self._update_every==0:
# nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5)
# self._optimizer.step()
self.scheduler.step()
# self.model.zero_grad()


class DevCallback(Callback):
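# Runs the given tester (the test set in train.py) at every validation step, tracks the best test
# result together with the dev result observed alongside it, and prints both at the end of training.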
def __init__(self, tester, metric_key='u_f1'):
super().__init__()
self.tester = tester
setattr(tester, 'verbose', 0)

self.metric_key = metric_key

self.record_best = False
self.best_eval_value = 0
self.best_eval_res = None

self.best_dev_res = None  # dev-set performance

def on_valid_begin(self):
eval_res = self.tester.test()
metric_name = self.tester.metrics[0].__class__.__name__
metric_value = eval_res[metric_name][self.metric_key]
if metric_value>self.best_eval_value:
self.best_eval_value = metric_value
self.best_epoch = self.trainer.epoch
self.record_best = True
self.best_eval_res = eval_res
self.test_eval_res = eval_res
eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \
self.tester._format_eval_results(eval_res)
self.pbar.write(eval_str)

def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
if self.record_best:
self.best_dev_res = eval_result
self.record_best = False
if is_better_eval:
self.best_dev_res_on_dev = eval_result
self.best_test_res_on_dev = self.test_eval_res
self.dev_epoch = self.epoch

def on_train_end(self):
print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch,
self.tester._format_eval_results(self.best_eval_res),
self.tester._format_eval_results(self.best_dev_res)))
print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch,
self.tester._format_eval_results(self.best_test_res_on_dev),
self.tester._format_eval_results(self.best_dev_res_on_dev)))

+ 184 - 0  reproduction/joint_cws_parse/models/metrics.py

@@ -0,0 +1,184 @@
from fastNLP.core.metrics import MetricBase
from fastNLP.core.utils import seq_len_to_mask
import torch


class SegAppCharParseF1Metric(MetricBase):
#
def __init__(self, app_index):
super().__init__()
self.app_index = app_index

self.parse_head_tp = 0
self.parse_label_tp = 0
self.rec_tol = 0
self.pre_tol = 0

def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens,
pun_masks):
"""

max_len is the character length excluding the root token
:param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size
:param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size
:param head_preds: batch_size x max_len
:param label_preds: batch_size x max_len
:param seq_lens:
:param pun_masks: batch_size x max_len
:return:
"""
# strip the root position
head_preds = head_preds[:, 1:].tolist()
label_preds = label_preds[:, 1:].tolist()
seq_lens = (seq_lens - 1).tolist()

# first decode words, heads, labels and the character span of each word
for b in range(len(head_preds)):
seq_len = seq_lens[b]
head_pred = head_preds[b][:seq_len]
label_pred = label_preds[b][:seq_len]

words = []  # holds [word_start, word_end) spans, offsets relative to the sentence start, root excluded
heads = []
labels = []
ranges = []  # word index that each char belongs to; length is seq_len
word_idx = 0
word_start_idx = 0
for idx, (label, head) in enumerate(zip(label_pred, head_pred)):
ranges.append(word_idx)
if label == self.app_index:
pass
else:
labels.append(label)
heads.append(head)
words.append((word_start_idx, idx+1))
word_start_idx = idx+1
word_idx += 1

head_dep_tuple = []  # head span comes first
head_label_dep_tuple = []
for idx, head in enumerate(heads):
span = words[idx]
if span[0]==span[1]-1 and pun_masks[b, span[0]]:
continue # exclude punctuations
if head == 0:
head_dep_tuple.append((('root', words[idx])))
head_label_dep_tuple.append(('root', labels[idx], words[idx]))
else:
head_word_idx = ranges[head-1]
head_word_span = words[head_word_idx]
head_dep_tuple.append(((head_word_span, words[idx])))
head_label_dep_tuple.append((head_word_span, labels[idx], words[idx]))

gold_head_dep_tuple = set(gold_word_pairs[b])
gold_head_label_dep_tuple = set(gold_label_word_pairs[b])

for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple):
if head_dep in gold_head_dep_tuple:
self.parse_head_tp += 1
if head_label_dep in gold_head_label_dep_tuple:
self.parse_label_tp += 1
self.pre_tol += len(head_dep_tuple)
self.rec_tol += len(gold_head_dep_tuple)

def get_metric(self, reset=True):
u_p = self.parse_head_tp / self.pre_tol
u_r = self.parse_head_tp / self.rec_tol
u_f = 2*u_p*u_r/(1e-6 + u_p + u_r)
l_p = self.parse_label_tp / self.pre_tol
l_r = self.parse_label_tp / self.rec_tol
l_f = 2*l_p*l_r/(1e-6 + l_p + l_r)

if reset:
self.parse_head_tp = 0
self.parse_label_tp = 0
self.rec_tol = 0
self.pre_tol = 0

return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4),
'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)}


class CWSMetric(MetricBase):
def __init__(self, app_index):
super().__init__()
self.app_index = app_index
self.pre = 0
self.rec = 0
self.tp = 0

def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens):
"""

:param seg_targets: batch_size x max_len; at each word-final position the value is word length - 1
:param seg_masks: batch_size x max_len; 1 only at word-final positions
:param label_preds: batch_size x max_len
:param seq_lens: batch_size
:return:
"""

pred_masks = torch.zeros_like(seg_masks)
pred_segs = torch.zeros_like(seg_targets)

seq_lens = (seq_lens - 1).tolist()
for idx, label_pred in enumerate(label_preds[:, 1:].tolist()):
seq_len = seq_lens[idx]
label_pred = label_pred[:seq_len]
word_len = 0
for l_i, label in enumerate(label_pred):
if label==self.app_index and l_i!=len(label_pred)-1:
word_len += 1
else:
pred_segs[idx, l_i] = word_len  # record the predicted word length - 1 at the word-final position, mirroring seg_targets
pred_masks[idx, l_i] = 1
word_len = 0

right_mask = seg_targets.eq(pred_segs)  # predicted length agrees with the target
self.rec += seg_masks.sum().item()
self.pre += pred_masks.sum().item()
# and both prediction and target mark a word boundary at the same position
self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item()

def get_metric(self, reset=True):
res = {}
res['rec'] = round(self.tp/(self.rec+1e-6), 4)
res['pre'] = round(self.tp/(self.pre+1e-6), 4)
res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4)

if reset:
self.pre = 0
self.rec = 0
self.tp = 0

return res


class ParserMetric(MetricBase):
def __init__(self, ):
super().__init__()
self.num_arc = 0
self.num_label = 0
self.num_sample = 0

def get_metric(self, reset=True):
res = {'UAS': round(self.num_arc*1.0 / self.num_sample, 4),
'LAS': round(self.num_label*1.0 / self.num_sample, 4)}
if reset:
self.num_sample = self.num_label = self.num_arc = 0
return res

def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None):
"""Evaluate the performance of prediction.
"""
if seq_lens is None:
seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.byte)
else:
seq_mask = seq_len_to_mask(seq_lens.long(), float=False)
# mask out <root> tag
seq_mask[:, 0] = 0
head_pred_correct = (head_preds == heads).__and__(seq_mask)
label_pred_correct = (label_preds == labels).__and__(head_pred_correct)
self.num_arc += head_pred_correct.float().sum().item()
self.num_label += label_pred_correct.float().sum().item()
self.num_sample += seq_mask.sum().item()


+ 16 - 0  reproduction/joint_cws_parse/readme.md

@@ -0,0 +1,16 @@
Code for the paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697)

### Preparing the data
1. The data should be in CoNLL format; columns 1, 3, 6 and 7 correspond to 'words', 'pos_tags', 'heads' and 'labels'.
2. Put train, dev and test in one folder and set the data_folder variable in train.py to that folder's path (an illustrative layout is sketched below).
3. Download the pretrained vectors from [Baidu Pan](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA) (extraction code: ua53), put them in one folder, and set the vector_folder variable in train.py accordingly.
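
The two folders might look like this (an illustrative sketch; the file names are the ones expected by the CTBxJointLoader docstring and by train.py):
```
data_folder/
    train.conllx
    dev.conllx
    test.conllx
vector_folder/
    1grams_t3_m50_corpus.txt
    2grams_t3_m50_corpus.txt
    3grams_t3_m50_corpus.txt
```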


### Running the code
```
python train.py
```

### Other notes
With the default hyperparameters above, the results reported in the paper should be reproducible on ctb5 (they may come out slightly higher); on ctb7 the defaults land roughly 0.1% lower and the learning rate scheduler needs tuning.
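
By default train.py uses `StepLR(optimizer, step_size=18, gamma=0.75)`; a commented-out alternative in the same file decays the learning rate every 5000 optimizer steps instead, which may be a reasonable starting point when tuning for ctb7 (an untested assumption, not a verified recipe):
```
from torch.optim.lr_scheduler import LambdaLR

# in train.py, replace the StepLR scheduler with a per-step exponential decay
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.75 ** (step // 5000))
```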

+ 124 - 0  reproduction/joint_cws_parse/train.py

@@ -0,0 +1,124 @@
import sys
sys.path.append('../..')

from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader
from fastNLP.modules.encoder.embedding import StaticEmbedding
from torch import nn
from functools import partial
from reproduction.joint_cws_parse.models.CharParser import CharParser
from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric
from fastNLP import cache_results, BucketSampler, Trainer
from torch import optim
from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback
from torch.optim.lr_scheduler import LambdaLR, StepLR
from fastNLP import Tester
from fastNLP import GradientClipCallback, LRScheduler
import os

def set_random_seed(random_seed=666):
import random, numpy, torch
random.seed(random_seed)
numpy.random.seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.random.manual_seed(random_seed)

uniform_init = partial(nn.init.normal_, std=0.02)
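# note: despite its name, uniform_init draws from a normal distribution with std 0.02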

###################################################
# hyperparameters that are likely to need tuning
lr = 0.002 # 0.01~0.001
dropout = 0.33 # 0.3~0.6
weight_decay = 0 # 1e-5, 1e-6, 0
arc_mlp_size = 500 # 200, 300
rnn_hidden_size = 400 # 200, 300, 400
rnn_layers = 3 # 2, 3
encoder = 'var-lstm' # var-lstm, lstm
emb_size = 100 # 64 , 100
label_mlp_size = 100

batch_size = 32
update_every = 4
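# fastNLP's Trainer accumulates gradients over update_every steps, so the effective batch size is 32 * 4 = 128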
n_epochs = 100
data_folder = ''  # path to the data folder; it should contain the train, dev and test files
vector_folder = ''  # folder with the pretrained vectors; it should contain 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt and 3grams_t3_m50_corpus.txt
####################################################

set_random_seed(1234)
device = 0

# @cache_results('caches/{}.pkl'.format(data_name))
# def get_data():
data = CTBxJointLoader().process(data_folder)

char_labels_vocab = data.vocabs['char_labels']

pre_chars_vocab = data.vocabs['pre_chars']
pre_bigrams_vocab = data.vocabs['pre_bigrams']
pre_trigrams_vocab = data.vocabs['pre_trigrams']

chars_vocab = data.vocabs['chars']
bigrams_vocab = data.vocabs['bigrams']
trigrams_vocab = data.vocabs['trigrams']

pre_chars_embed = StaticEmbedding(pre_chars_vocab,
model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
init_method=uniform_init, normalize=False)
pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std()
pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
init_method=uniform_init, normalize=False)
pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std()
pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
init_method=uniform_init, normalize=False)
pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std()

# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data

# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()

print(data)
model = CharParser(char_vocab_size=len(chars_vocab),
emb_dim=emb_size,
bigram_vocab_size=len(bigrams_vocab),
trigram_vocab_size=len(trigrams_vocab),
num_label=len(char_labels_vocab),
rnn_layers=rnn_layers,
rnn_hidden_size=rnn_hidden_size,
arc_mlp_size=arc_mlp_size,
label_mlp_size=label_mlp_size,
dropout=dropout,
encoder=encoder,
use_greedy_infer=False,
app_index=char_labels_vocab['APP'],
pre_chars_embed=pre_chars_embed,
pre_bigrams_embed=pre_bigrams_embed,
pre_trigrams_embed=pre_trigrams_embed)

metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP'])
metric2 = CWSMetric(char_labels_vocab['APP'])
metrics = [metric1, metric2]

optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr,
weight_decay=weight_decay, betas=[0.9, 0.9])

sampler = BucketSampler(seq_len_field_name='seq_lens')
callbacks = []
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
# optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
# callbacks.append(optim_callback)
scheduler_callback = LRScheduler(scheduler)
callbacks.append(scheduler_callback)
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))

tester = Tester(data=data.datasets['test'], model=model, metrics=metrics,
batch_size=64, device=device, verbose=0)
dev_callback = DevCallback(tester)
callbacks.append(dev_callback)

trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3,
validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
device=device, callbacks=callbacks, update_every=update_every)
trainer.train()
