diff --git a/reproduction/joint_cws_parse/__init__.py b/reproduction/joint_cws_parse/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/joint_cws_parse/data/__init__.py b/reproduction/joint_cws_parse/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/joint_cws_parse/data/data_loader.py b/reproduction/joint_cws_parse/data/data_loader.py new file mode 100644 index 00000000..7802ea09 --- /dev/null +++ b/reproduction/joint_cws_parse/data/data_loader.py @@ -0,0 +1,284 @@ + + +from fastNLP.io.base_loader import DataSetLoader, DataInfo +from fastNLP.io.dataset_loader import ConllLoader +import numpy as np + +from itertools import chain +from fastNLP import DataSet, Vocabulary +from functools import partial +import os +from typing import Union, Dict +from reproduction.utils import check_dataloader_paths + + +class CTBxJointLoader(DataSetLoader): + """ + 文件夹下应该具有以下的文件结构 + -train.conllx + -dev.conllx + -test.conllx + 每个文件中的内容如下(空格隔开不同的句子, 共有) + 1 费孝通 _ NR NR _ 3 nsubjpass _ _ + 2 被 _ SB SB _ 3 pass _ _ + 3 授予 _ VV VV _ 0 root _ _ + 4 麦格赛赛 _ NR NR _ 5 nn _ _ + 5 奖 _ NN NN _ 3 dobj _ _ + + 1 新华社 _ NR NR _ 7 dep _ _ + 2 马尼拉 _ NR NR _ 7 dep _ _ + 3 8月 _ NT NT _ 7 dep _ _ + 4 31日 _ NT NT _ 7 dep _ _ + ... + + """ + def __init__(self): + self._loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexes=[1, 3, 6, 7]) + + def load(self, path:str): + """ + 给定一个文件路径,将数据读取为DataSet格式。DataSet中包含以下的内容 + words: list[str] + pos_tags: list[str] + heads: list[int] + labels: list[str] + + :param path: + :return: + """ + dataset = self._loader.load(path) + dataset.heads.int() + return dataset + + def process(self, paths): + """ + + :param paths: + :return: + Dataset包含以下的field + chars: + bigrams: + trigrams: + pre_chars: + pre_bigrams: + pre_trigrams: + seg_targets: + seg_masks: + seq_lens: + char_labels: + char_heads: + gold_word_pairs: + seg_targets: + seg_masks: + char_labels: + char_heads: + pun_masks: + gold_label_word_pairs: + """ + paths = check_dataloader_paths(paths) + data = DataInfo() + + for name, path in paths.items(): + dataset = self.load(path) + data.datasets[name] = dataset + + char_labels_vocab = Vocabulary(padding=None, unknown=None) + + def process(dataset, char_label_vocab): + dataset.apply(add_word_lst, new_field_name='word_lst') + dataset.apply(lambda x: list(chain(*x['word_lst'])), new_field_name='chars') + dataset.apply(add_bigram, field_name='chars', new_field_name='bigrams') + dataset.apply(add_trigram, field_name='chars', new_field_name='trigrams') + dataset.apply(add_char_heads, new_field_name='char_heads') + dataset.apply(add_char_labels, new_field_name='char_labels') + dataset.apply(add_segs, new_field_name='seg_targets') + dataset.apply(add_mask, new_field_name='seg_masks') + dataset.add_seq_len('chars', new_field_name='seq_lens') + dataset.apply(add_pun_masks, new_field_name='pun_masks') + if len(char_label_vocab.word_count)==0: + char_label_vocab.from_dataset(dataset, field_name='char_labels') + char_label_vocab.index_dataset(dataset, field_name='char_labels') + new_dataset = add_root(dataset) + new_dataset.apply(add_word_pairs, new_field_name='gold_word_pairs', ignore_type=True) + global add_label_word_pairs + add_label_word_pairs = partial(add_label_word_pairs, label_vocab=char_label_vocab) + new_dataset.apply(add_label_word_pairs, new_field_name='gold_label_word_pairs', ignore_type=True) + + new_dataset.set_pad_val('char_labels', -1) + new_dataset.set_pad_val('char_heads', -1) + + return 
new_dataset + + for name in list(paths.keys()): + dataset = data.datasets[name] + dataset = process(dataset, char_labels_vocab) + data.datasets[name] = dataset + + data.vocabs['char_labels'] = char_labels_vocab + + char_vocab = Vocabulary(min_freq=2).from_dataset(data.datasets['train'], field_name='chars') + bigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='bigrams') + trigram_vocab = Vocabulary(min_freq=5).from_dataset(data.datasets['train'], field_name='trigrams') + + for name in ['chars', 'bigrams', 'trigrams']: + vocab = Vocabulary().from_dataset(field_name=name, no_create_entry_dataset=list(data.datasets.values())) + vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name='pre_' + name) + data.vocabs['pre_{}'.format(name)] = vocab + + for name, vocab in zip(['chars', 'bigrams', 'trigrams'], + [char_vocab, bigram_vocab, trigram_vocab]): + vocab.index_dataset(*data.datasets.values(), field_name=name, new_field_name=name) + data.vocabs[name] = vocab + + for name, dataset in data.datasets.items(): + dataset.set_input('chars', 'bigrams', 'trigrams', 'seq_lens', 'char_labels', 'char_heads', 'pre_chars', + 'pre_bigrams', 'pre_trigrams') + dataset.set_target('gold_word_pairs', 'seq_lens', 'seg_targets', 'seg_masks', 'char_labels', + 'char_heads', + 'pun_masks', 'gold_label_word_pairs') + + return data + + +def add_label_word_pairs(instance, label_vocab): + # List[List[((head_start, head_end], (dep_start, dep_end]), ...]] + word_end_indexes = np.array(list(map(len, instance['word_lst']))) + word_end_indexes = np.cumsum(word_end_indexes).tolist() + word_end_indexes.insert(0, 0) + word_pairs = [] + labels = instance['labels'] + pos_tags = instance['pos_tags'] + for idx, head in enumerate(instance['heads']): + if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 + continue + label = label_vocab.to_index(labels[idx]) + if head==0: + word_pairs.append((('root', label, (word_end_indexes[idx], word_end_indexes[idx+1])))) + else: + word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), label, + (word_end_indexes[idx], word_end_indexes[idx + 1]))) + return word_pairs + +def add_word_pairs(instance): + # List[List[((head_start, head_end], (dep_start, dep_end]), ...]] + word_end_indexes = np.array(list(map(len, instance['word_lst']))) + word_end_indexes = np.cumsum(word_end_indexes).tolist() + word_end_indexes.insert(0, 0) + word_pairs = [] + pos_tags = instance['pos_tags'] + for idx, head in enumerate(instance['heads']): + if pos_tags[idx]=='PU': # 如果是标点符号,就不记录 + continue + if head==0: + word_pairs.append((('root', (word_end_indexes[idx], word_end_indexes[idx+1])))) + else: + word_pairs.append(((word_end_indexes[head-1], word_end_indexes[head]), + (word_end_indexes[idx], word_end_indexes[idx + 1]))) + return word_pairs + +def add_root(dataset): + new_dataset = DataSet() + for sample in dataset: + chars = ['char_root'] + sample['chars'] + bigrams = ['bigram_root'] + sample['bigrams'] + trigrams = ['trigram_root'] + sample['trigrams'] + seq_lens = sample['seq_lens']+1 + char_labels = [0] + sample['char_labels'] + char_heads = [0] + sample['char_heads'] + sample['chars'] = chars + sample['bigrams'] = bigrams + sample['trigrams'] = trigrams + sample['seq_lens'] = seq_lens + sample['char_labels'] = char_labels + sample['char_heads'] = char_heads + new_dataset.append(sample) + return new_dataset + +def add_pun_masks(instance): + tags = instance['pos_tags'] + pun_masks = [] + for word, tag in zip(instance['words'], tags): + if tag=='PU': + 
pun_masks.extend([1]*len(word)) + else: + pun_masks.extend([0]*len(word)) + return pun_masks + +def add_word_lst(instance): + words = instance['words'] + word_lst = [list(word) for word in words] + return word_lst + +def add_bigram(instance): + chars = instance['chars'] + length = len(chars) + chars = chars + [''] + bigrams = [] + for i in range(length): + bigrams.append(''.join(chars[i:i + 2])) + return bigrams + +def add_trigram(instance): + chars = instance['chars'] + length = len(chars) + chars = chars + [''] * 2 + trigrams = [] + for i in range(length): + trigrams.append(''.join(chars[i:i + 3])) + return trigrams + +def add_char_heads(instance): + words = instance['word_lst'] + heads = instance['heads'] + char_heads = [] + char_index = 1 # 因此存在root节点所以需要从1开始 + head_end_indexes = np.cumsum(list(map(len, words))).tolist() + [0] # 因为root是0,0-1=-1 + for word, head in zip(words, heads): + char_head = [] + if len(word)>1: + char_head.append(char_index+1) + char_index += 1 + for _ in range(len(word)-2): + char_index += 1 + char_head.append(char_index) + char_index += 1 + char_head.append(head_end_indexes[head-1]) + char_heads.extend(char_head) + return char_heads + +def add_char_labels(instance): + """ + 将word_lst中的数据按照下面的方式设置label + 比如"复旦大学 位于 ", 对应的分词是"B M M E B E", 则对应的dependency是"复(dep)->旦(head)", "旦(dep)->大(head)".. + 对应的label是'app', 'app', 'app', , 而学的label就是复旦大学这个词的dependency label + :param instance: + :return: + """ + words = instance['word_lst'] + labels = instance['labels'] + char_labels = [] + for word, label in zip(words, labels): + for _ in range(len(word)-1): + char_labels.append('APP') + char_labels.append(label) + return char_labels + +# add seg_targets +def add_segs(instance): + words = instance['word_lst'] + segs = [0]*len(instance['chars']) + index = 0 + for word in words: + index = index + len(word) - 1 + segs[index] = len(word)-1 + index = index + 1 + return segs + +# add target_masks +def add_mask(instance): + words = instance['word_lst'] + mask = [] + for word in words: + mask.extend([0] * (len(word) - 1)) + mask.append(1) + return mask diff --git a/reproduction/joint_cws_parse/models/CharParser.py b/reproduction/joint_cws_parse/models/CharParser.py new file mode 100644 index 00000000..1ed5ea2d --- /dev/null +++ b/reproduction/joint_cws_parse/models/CharParser.py @@ -0,0 +1,311 @@ + + + +from fastNLP.models.biaffine_parser import BiaffineParser +from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from fastNLP.modules.dropout import TimestepDropout +from fastNLP.modules.encoder.variational_rnn import VarLSTM +from fastNLP import seq_len_to_mask +from fastNLP.modules import Embedding + + +def drop_input_independent(word_embeddings, dropout_emb): + batch_size, seq_length, _ = word_embeddings.size() + word_masks = word_embeddings.new(batch_size, seq_length).fill_(1 - dropout_emb) + word_masks = torch.bernoulli(word_masks) + word_masks = word_masks.unsqueeze(dim=2) + word_embeddings = word_embeddings * word_masks + + return word_embeddings + + +class CharBiaffineParser(BiaffineParser): + def __init__(self, char_vocab_size, + emb_dim, + bigram_vocab_size, + trigram_vocab_size, + num_label, + rnn_layers=3, + rnn_hidden_size=800, #单向的数量 + arc_mlp_size=500, + label_mlp_size=100, + dropout=0.3, + encoder='lstm', + use_greedy_infer=False, + app_index = 0, + pre_chars_embed=None, + pre_bigrams_embed=None, + pre_trigrams_embed=None): + + + super(BiaffineParser, 
self).__init__() + rnn_out_size = 2 * rnn_hidden_size + self.char_embed = Embedding((char_vocab_size, emb_dim)) + self.bigram_embed = Embedding((bigram_vocab_size, emb_dim)) + self.trigram_embed = Embedding((trigram_vocab_size, emb_dim)) + if pre_chars_embed: + self.pre_char_embed = Embedding(pre_chars_embed) + self.pre_char_embed.requires_grad = False + if pre_bigrams_embed: + self.pre_bigram_embed = Embedding(pre_bigrams_embed) + self.pre_bigram_embed.requires_grad = False + if pre_trigrams_embed: + self.pre_trigram_embed = Embedding(pre_trigrams_embed) + self.pre_trigram_embed.requires_grad = False + self.timestep_drop = TimestepDropout(dropout) + self.encoder_name = encoder + + if encoder == 'var-lstm': + self.encoder = VarLSTM(input_size=emb_dim*3, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + input_dropout=dropout, + hidden_dropout=dropout, + bidirectional=True) + elif encoder == 'lstm': + self.encoder = nn.LSTM(input_size=emb_dim*3, + hidden_size=rnn_hidden_size, + num_layers=rnn_layers, + bias=True, + batch_first=True, + dropout=dropout, + bidirectional=True) + + else: + raise ValueError('unsupported encoder type: {}'.format(encoder)) + + self.mlp = nn.Sequential(nn.Linear(rnn_out_size, arc_mlp_size * 2 + label_mlp_size * 2), + nn.LeakyReLU(0.1), + TimestepDropout(p=dropout),) + self.arc_mlp_size = arc_mlp_size + self.label_mlp_size = label_mlp_size + self.arc_predictor = ArcBiaffine(arc_mlp_size, bias=True) + self.label_predictor = LabelBilinear(label_mlp_size, label_mlp_size, num_label, bias=True) + self.use_greedy_infer = use_greedy_infer + self.reset_parameters() + self.dropout = dropout + + self.app_index = app_index + self.num_label = num_label + if self.app_index != 0: + raise ValueError("现在app_index必须等于0") + + def reset_parameters(self): + for name, m in self.named_modules(): + if 'embed' in name: + pass + elif hasattr(m, 'reset_parameters') or hasattr(m, 'init_param'): + pass + else: + for p in m.parameters(): + if len(p.size())>1: + nn.init.xavier_normal_(p, gain=0.1) + else: + nn.init.uniform_(p, -0.1, 0.1) + + def forward(self, chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=None, pre_bigrams=None, + pre_trigrams=None): + """ + max_len是包含root的 + :param chars: batch_size x max_len + :param ngrams: batch_size x max_len*ngram_per_char + :param seq_lens: batch_size + :param gold_heads: batch_size x max_len + :param pre_chars: batch_size x max_len + :param pre_ngrams: batch_size x max_len*ngram_per_char + :return dict: parsing results + arc_pred: [batch_size, seq_len, seq_len] + label_pred: [batch_size, seq_len, seq_len] + mask: [batch_size, seq_len] + head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads + """ + # prepare embeddings + batch_size, seq_len = chars.shape + # print('forward {} {}'.format(batch_size, seq_len)) + + # get sequence mask + mask = seq_len_to_mask(seq_lens).long() + + chars = self.char_embed(chars) # [N,L] -> [N,L,C_0] + bigrams = self.bigram_embed(bigrams) # [N,L] -> [N,L,C_1] + trigrams = self.trigram_embed(trigrams) + + if pre_chars is not None: + pre_chars = self.pre_char_embed(pre_chars) + # pre_chars = self.pre_char_fc(pre_chars) + chars = pre_chars + chars + if pre_bigrams is not None: + pre_bigrams = self.pre_bigram_embed(pre_bigrams) + # pre_bigrams = self.pre_bigram_fc(pre_bigrams) + bigrams = bigrams + pre_bigrams + if pre_trigrams is not None: + pre_trigrams = self.pre_trigram_embed(pre_trigrams) + # pre_trigrams = self.pre_trigram_fc(pre_trigrams) + 
trigrams = trigrams + pre_trigrams + + x = torch.cat([chars, bigrams, trigrams], dim=2) # -> [N,L,C] + + # encoder, extract features + if self.training: + x = drop_input_independent(x, self.dropout) + sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) + x = x[sort_idx] + x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) + feat, _ = self.encoder(x) # -> [N,L,C] + feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True) + _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) + feat = feat[unsort_idx] + feat = self.timestep_drop(feat) + + # for arc biaffine + # mlp, reduce dim + feat = self.mlp(feat) + arc_sz, label_sz = self.arc_mlp_size, self.label_mlp_size + arc_dep, arc_head = feat[:,:,:arc_sz], feat[:,:,arc_sz:2*arc_sz] + label_dep, label_head = feat[:,:,2*arc_sz:2*arc_sz+label_sz], feat[:,:,2*arc_sz+label_sz:] + + # biaffine arc classifier + arc_pred = self.arc_predictor(arc_head, arc_dep) # [N, L, L] + + # use gold or predicted arc to predict label + if gold_heads is None or not self.training: + # use greedy decoding in training + if self.training or self.use_greedy_infer: + heads = self.greedy_decoder(arc_pred, mask) + else: + heads = self.mst_decoder(arc_pred, mask) + head_pred = heads + else: + assert self.training # must be training mode + if gold_heads is None: + heads = self.greedy_decoder(arc_pred, mask) + head_pred = heads + else: + head_pred = None + heads = gold_heads + # heads: batch_size x max_len + + batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=chars.device).unsqueeze(1) + label_head = label_head[batch_range, heads].contiguous() + label_pred = self.label_predictor(label_head, label_dep) # [N, max_len, num_label] + # 这里限制一下,只有当head为下一个时,才能预测app这个label + arange_index = torch.arange(1, seq_len+1, dtype=torch.long, device=chars.device).unsqueeze(0)\ + .repeat(batch_size, 1) # batch_size x max_len + app_masks = heads.ne(arange_index) # batch_size x max_len, 为1的位置不可以预测app + app_masks = app_masks.unsqueeze(2).repeat(1, 1, self.num_label) + app_masks[:, :, 1:] = 0 + label_pred = label_pred.masked_fill(app_masks, -np.inf) + + res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} + if head_pred is not None: + res_dict['head_pred'] = head_pred + return res_dict + + @staticmethod + def loss(arc_pred, label_pred, arc_true, label_true, mask): + """ + Compute loss. 
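+        Arc and label cross-entropy losses are summed. Padded positions are masked
+        out of arc_pred before the softmax, and the root position (index 0) is
+        excluded from both terms by filling its targets with -1 and passing
+        ignore_index=-1 to F.cross_entropy.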
+ + :param arc_pred: [batch_size, seq_len, seq_len] + :param label_pred: [batch_size, seq_len, n_tags] + :param arc_true: [batch_size, seq_len] + :param label_true: [batch_size, seq_len] + :param mask: [batch_size, seq_len] + :return: loss value + """ + + batch_size, seq_len, _ = arc_pred.shape + flip_mask = (mask == 0) + _arc_pred = arc_pred.clone() + _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) + + arc_true[:, 0].fill_(-1) + label_true[:, 0].fill_(-1) + + arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1) + label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1) + + return arc_nll + label_nll + + def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars, pre_bigrams, pre_trigrams): + """ + + max_len是包含root的 + + :param chars: batch_size x max_len + :param ngrams: batch_size x max_len*ngram_per_char + :param seq_lens: batch_size + :param pre_chars: batch_size x max_len + :param pre_ngrams: batch_size x max_len*ngram_per_cha + :return: + """ + res = self(chars, bigrams, trigrams, seq_lens, pre_chars=pre_chars, pre_bigrams=pre_bigrams, + pre_trigrams=pre_trigrams, gold_heads=None) + output = {} + output['arc_pred'] = res.pop('head_pred') + _, label_pred = res.pop('label_pred').max(2) + output['label_pred'] = label_pred + return output + +class CharParser(nn.Module): + def __init__(self, char_vocab_size, + emb_dim, + bigram_vocab_size, + trigram_vocab_size, + num_label, + rnn_layers=3, + rnn_hidden_size=400, #单向的数量 + arc_mlp_size=500, + label_mlp_size=100, + dropout=0.3, + encoder='var-lstm', + use_greedy_infer=False, + app_index = 0, + pre_chars_embed=None, + pre_bigrams_embed=None, + pre_trigrams_embed=None): + super().__init__() + + self.parser = CharBiaffineParser(char_vocab_size, + emb_dim, + bigram_vocab_size, + trigram_vocab_size, + num_label, + rnn_layers, + rnn_hidden_size, #单向的数量 + arc_mlp_size, + label_mlp_size, + dropout, + encoder, + use_greedy_infer, + app_index, + pre_chars_embed=pre_chars_embed, + pre_bigrams_embed=pre_bigrams_embed, + pre_trigrams_embed=pre_trigrams_embed) + + def forward(self, chars, bigrams, trigrams, seq_lens, char_heads, char_labels, pre_chars=None, pre_bigrams=None, + pre_trigrams=None): + res_dict = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=char_heads, pre_chars=pre_chars, + pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) + arc_pred = res_dict['arc_pred'] + label_pred = res_dict['label_pred'] + masks = res_dict['mask'] + loss = self.parser.loss(arc_pred, label_pred, char_heads, char_labels, masks) + return {'loss': loss} + + def predict(self, chars, bigrams, trigrams, seq_lens, pre_chars=None, pre_bigrams=None, pre_trigrams=None): + res = self.parser(chars, bigrams, trigrams, seq_lens, gold_heads=None, pre_chars=pre_chars, + pre_bigrams=pre_bigrams, pre_trigrams=pre_trigrams) + output = {} + output['head_preds'] = res.pop('head_pred') + _, label_pred = res.pop('label_pred').max(2) + output['label_preds'] = label_pred + return output diff --git a/reproduction/joint_cws_parse/models/__init__.py b/reproduction/joint_cws_parse/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/joint_cws_parse/models/callbacks.py b/reproduction/joint_cws_parse/models/callbacks.py new file mode 100644 index 00000000..8de01109 --- /dev/null +++ b/reproduction/joint_cws_parse/models/callbacks.py @@ -0,0 +1,65 @@ + +from fastNLP.core.callback import Callback +import torch +from torch import nn + +class 
OptimizerCallback(Callback): + def __init__(self, optimizer, scheduler, update_every=4): + super().__init__() + + self._optimizer = optimizer + self.scheduler = scheduler + self._update_every = update_every + + def on_backward_end(self): + if self.step % self._update_every==0: + # nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), 5) + # self._optimizer.step() + self.scheduler.step() + # self.model.zero_grad() + + +class DevCallback(Callback): + def __init__(self, tester, metric_key='u_f1'): + super().__init__() + self.tester = tester + setattr(tester, 'verbose', 0) + + self.metric_key = metric_key + + self.record_best = False + self.best_eval_value = 0 + self.best_eval_res = None + + self.best_dev_res = None # 存取dev的表现 + + def on_valid_begin(self): + eval_res = self.tester.test() + metric_name = self.tester.metrics[0].__class__.__name__ + metric_value = eval_res[metric_name][self.metric_key] + if metric_value>self.best_eval_value: + self.best_eval_value = metric_value + self.best_epoch = self.trainer.epoch + self.record_best = True + self.best_eval_res = eval_res + self.test_eval_res = eval_res + eval_str = "Epoch {}/{}. \n".format(self.trainer.epoch, self.n_epochs) + \ + self.tester._format_eval_results(eval_res) + self.pbar.write(eval_str) + + def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): + if self.record_best: + self.best_dev_res = eval_result + self.record_best = False + if is_better_eval: + self.best_dev_res_on_dev = eval_result + self.best_test_res_on_dev = self.test_eval_res + self.dev_epoch = self.epoch + + def on_train_end(self): + print("Got best test performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.best_epoch, + self.tester._format_eval_results(self.best_eval_res), + self.tester._format_eval_results(self.best_dev_res))) + print("Got best dev performance in epoch:{}\n Test: {}\n Dev:{}\n".format(self.dev_epoch, + self.tester._format_eval_results(self.best_test_res_on_dev), + self.tester._format_eval_results(self.best_dev_res_on_dev))) \ No newline at end of file diff --git a/reproduction/joint_cws_parse/models/metrics.py b/reproduction/joint_cws_parse/models/metrics.py new file mode 100644 index 00000000..bf0f0622 --- /dev/null +++ b/reproduction/joint_cws_parse/models/metrics.py @@ -0,0 +1,184 @@ +from fastNLP.core.metrics import MetricBase +from fastNLP.core.utils import seq_len_to_mask +import torch + + +class SegAppCharParseF1Metric(MetricBase): + # + def __init__(self, app_index): + super().__init__() + self.app_index = app_index + + self.parse_head_tp = 0 + self.parse_label_tp = 0 + self.rec_tol = 0 + self.pre_tol = 0 + + def evaluate(self, gold_word_pairs, gold_label_word_pairs, head_preds, label_preds, seq_lens, + pun_masks): + """ + + max_len是不包含root的character的长度 + :param gold_word_pairs: List[List[((head_start, head_end), (dep_start, dep_end)), ...]], batch_size + :param gold_label_word_pairs: List[List[((head_start, head_end), label, (dep_start, dep_end)), ...]], batch_size + :param head_preds: batch_size x max_len + :param label_preds: batch_size x max_len + :param seq_lens: + :param pun_masks: batch_size x + :return: + """ + # 去掉root + head_preds = head_preds[:, 1:].tolist() + label_preds = label_preds[:, 1:].tolist() + seq_lens = (seq_lens - 1).tolist() + + # 先解码出words,POS,heads, labels, 对应的character范围 + for b in range(len(head_preds)): + seq_len = seq_lens[b] + head_pred = head_preds[b][:seq_len] + label_pred = label_preds[b][:seq_len] + + words = [] # 存放[word_start, word_end),相对起始位置,不考虑root + heads = [] + labels = 
[] + ranges = [] # 对应该char是第几个word,长度是seq_len+1 + word_idx = 0 + word_start_idx = 0 + for idx, (label, head) in enumerate(zip(label_pred, head_pred)): + ranges.append(word_idx) + if label == self.app_index: + pass + else: + labels.append(label) + heads.append(head) + words.append((word_start_idx, idx+1)) + word_start_idx = idx+1 + word_idx += 1 + + head_dep_tuple = [] # head在前面 + head_label_dep_tuple = [] + for idx, head in enumerate(heads): + span = words[idx] + if span[0]==span[1]-1 and pun_masks[b, span[0]]: + continue # exclude punctuations + if head == 0: + head_dep_tuple.append((('root', words[idx]))) + head_label_dep_tuple.append(('root', labels[idx], words[idx])) + else: + head_word_idx = ranges[head-1] + head_word_span = words[head_word_idx] + head_dep_tuple.append(((head_word_span, words[idx]))) + head_label_dep_tuple.append((head_word_span, labels[idx], words[idx])) + + gold_head_dep_tuple = set(gold_word_pairs[b]) + gold_head_label_dep_tuple = set(gold_label_word_pairs[b]) + + for head_dep, head_label_dep in zip(head_dep_tuple, head_label_dep_tuple): + if head_dep in gold_head_dep_tuple: + self.parse_head_tp += 1 + if head_label_dep in gold_head_label_dep_tuple: + self.parse_label_tp += 1 + self.pre_tol += len(head_dep_tuple) + self.rec_tol += len(gold_head_dep_tuple) + + def get_metric(self, reset=True): + u_p = self.parse_head_tp / self.pre_tol + u_r = self.parse_head_tp / self.rec_tol + u_f = 2*u_p*u_r/(1e-6 + u_p + u_r) + l_p = self.parse_label_tp / self.pre_tol + l_r = self.parse_label_tp / self.rec_tol + l_f = 2*l_p*l_r/(1e-6 + l_p + l_r) + + if reset: + self.parse_head_tp = 0 + self.parse_label_tp = 0 + self.rec_tol = 0 + self.pre_tol = 0 + + return {'u_f1': round(u_f, 4), 'u_p': round(u_p, 4), 'u_r/uas':round(u_r, 4), + 'l_f1': round(l_f, 4), 'l_p': round(l_p, 4), 'l_r/las': round(l_r, 4)} + + +class CWSMetric(MetricBase): + def __init__(self, app_index): + super().__init__() + self.app_index = app_index + self.pre = 0 + self.rec = 0 + self.tp = 0 + + def evaluate(self, seg_targets, seg_masks, label_preds, seq_lens): + """ + + :param seg_targets: batch_size x max_len, 每个位置预测的是该word的长度-1,在word结束的地方。 + :param seg_masks: batch_size x max_len,只有在word结束的地方为1 + :param label_preds: batch_size x max_len + :param seq_lens: batch_size + :return: + """ + + pred_masks = torch.zeros_like(seg_masks) + pred_segs = torch.zeros_like(seg_targets) + + seq_lens = (seq_lens - 1).tolist() + for idx, label_pred in enumerate(label_preds[:, 1:].tolist()): + seq_len = seq_lens[idx] + label_pred = label_pred[:seq_len] + word_len = 0 + for l_i, label in enumerate(label_pred): + if label==self.app_index and l_i!=len(label_pred)-1: + word_len += 1 + else: + pred_segs[idx, l_i] = word_len # 这个词的长度为word_len + pred_masks[idx, l_i] = 1 + word_len = 0 + + right_mask = seg_targets.eq(pred_segs) # 对长度的预测一致 + self.rec += seg_masks.sum().item() + self.pre += pred_masks.sum().item() + # 且pred和target在同一个地方有值 + self.tp += (right_mask.__and__(pred_masks.byte().__and__(seg_masks.byte()))).sum().item() + + def get_metric(self, reset=True): + res = {} + res['rec'] = round(self.tp/(self.rec+1e-6), 4) + res['pre'] = round(self.tp/(self.pre+1e-6), 4) + res['f1'] = round(2*res['rec']*res['pre']/(res['pre'] + res['rec'] + 1e-6), 4) + + if reset: + self.pre = 0 + self.rec = 0 + self.tp = 0 + + return res + + +class ParserMetric(MetricBase): + def __init__(self, ): + super().__init__() + self.num_arc = 0 + self.num_label = 0 + self.num_sample = 0 + + def get_metric(self, reset=True): + res = {'UAS': round(self.num_arc*1.0 
/ self.num_sample, 4), + 'LAS': round(self.num_label*1.0 / self.num_sample, 4)} + if reset: + self.num_sample = self.num_label = self.num_arc = 0 + return res + + def evaluate(self, head_preds, label_preds, heads, labels, seq_lens=None): + """Evaluate the performance of prediction. + """ + if seq_lens is None: + seq_mask = head_preds.new_ones(head_preds.size(), dtype=torch.byte) + else: + seq_mask = seq_len_to_mask(seq_lens.long(), float=False) + # mask out tag + seq_mask[:, 0] = 0 + head_pred_correct = (head_preds == heads).__and__(seq_mask) + label_pred_correct = (label_preds == labels).__and__(head_pred_correct) + self.num_arc += head_pred_correct.float().sum().item() + self.num_label += label_pred_correct.float().sum().item() + self.num_sample += seq_mask.sum().item() + diff --git a/reproduction/joint_cws_parse/readme.md b/reproduction/joint_cws_parse/readme.md new file mode 100644 index 00000000..7fe77b47 --- /dev/null +++ b/reproduction/joint_cws_parse/readme.md @@ -0,0 +1,16 @@ +Code for paper [A Unified Model for Chinese Word Segmentation and Dependency Parsing](https://arxiv.org/abs/1904.04697) + +### 准备数据 +1. 数据应该为conll格式,1, 3, 6, 7列应该对应为'words', 'pos_tags', 'heads', 'labels'. +2. 将train, dev, test放在同一个folder下,并将该folder路径填入train.py中的data_folder变量里。 +3. 从[百度云](https://pan.baidu.com/s/1uXnAZpYecYJITCiqgAjjjA)(提取:ua53)下载预训练vector,放到同一个folder下,并将train.py中vector_folder变量正确设置。 + + +### 运行代码 +``` +python train.py +``` + +### 其它 +ctb5上跑出论文中报道的结果使用以上的默认参数应该就可以了(应该会更高一些); ctb7上使用默认参数会低0.1%左右,需要调节 +learning rate scheduler. \ No newline at end of file diff --git a/reproduction/joint_cws_parse/train.py b/reproduction/joint_cws_parse/train.py new file mode 100644 index 00000000..2f8b0d04 --- /dev/null +++ b/reproduction/joint_cws_parse/train.py @@ -0,0 +1,124 @@ +import sys +sys.path.append('../..') + +from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader +from fastNLP.modules.encoder.embedding import StaticEmbedding +from torch import nn +from functools import partial +from reproduction.joint_cws_parse.models.CharParser import CharParser +from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric +from fastNLP import cache_results, BucketSampler, Trainer +from torch import optim +from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback +from torch.optim.lr_scheduler import LambdaLR, StepLR +from fastNLP import Tester +from fastNLP import GradientClipCallback, LRScheduler +import os + +def set_random_seed(random_seed=666): + import random, numpy, torch + random.seed(random_seed) + numpy.random.seed(random_seed) + torch.cuda.manual_seed(random_seed) + torch.random.manual_seed(random_seed) + +uniform_init = partial(nn.init.normal_, std=0.02) + +################################################### +# 需要变动的超参放到这里 +lr = 0.002 # 0.01~0.001 +dropout = 0.33 # 0.3~0.6 +weight_decay = 0 # 1e-5, 1e-6, 0 +arc_mlp_size = 500 # 200, 300 +rnn_hidden_size = 400 # 200, 300, 400 +rnn_layers = 3 # 2, 3 +encoder = 'var-lstm' # var-lstm, lstm +emb_size = 100 # 64 , 100 +label_mlp_size = 100 + +batch_size = 32 +update_every = 4 +n_epochs = 100 +data_folder = '' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件 +vector_folder = '' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt +#################################################### + +set_random_seed(1234) +device = 0 + +# @cache_results('caches/{}.pkl'.format(data_name)) +# def get_data(): +data = CTBxJointLoader().process(data_folder) + 
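+# `data` is a DataInfo object: its train/dev/test datasets now carry the char-level fields
+# listed in CTBxJointLoader.process, and its vocabs hold chars/bigrams/trigrams, their
+# pre_* counterparts and char_labels. To avoid re-processing on every run, the commented-out
+# get_data() scaffold above can be re-enabled with fastNLP's cache_results; an illustrative
+# sketch only (the cache path here is arbitrary):
+#
+#   @cache_results('caches/ctb.pkl')
+#   def get_data():
+#       return CTBxJointLoader().process(data_folder)
+#   data = get_data()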
+char_labels_vocab = data.vocabs['char_labels'] + +pre_chars_vocab = data.vocabs['pre_chars'] +pre_bigrams_vocab = data.vocabs['pre_bigrams'] +pre_trigrams_vocab = data.vocabs['pre_trigrams'] + +chars_vocab = data.vocabs['chars'] +bigrams_vocab = data.vocabs['bigrams'] +trigrams_vocab = data.vocabs['trigrams'] + +pre_chars_embed = StaticEmbedding(pre_chars_vocab, + model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) +pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std() +pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab, + model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) +pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std() +pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab, + model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'), + init_method=uniform_init, normalize=False) +pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std() + + # return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data + +# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data() + +print(data) +model = CharParser(char_vocab_size=len(chars_vocab), + emb_dim=emb_size, + bigram_vocab_size=len(bigrams_vocab), + trigram_vocab_size=len(trigrams_vocab), + num_label=len(char_labels_vocab), + rnn_layers=rnn_layers, + rnn_hidden_size=rnn_hidden_size, + arc_mlp_size=arc_mlp_size, + label_mlp_size=label_mlp_size, + dropout=dropout, + encoder=encoder, + use_greedy_infer=False, + app_index=char_labels_vocab['APP'], + pre_chars_embed=pre_chars_embed, + pre_bigrams_embed=pre_bigrams_embed, + pre_trigrams_embed=pre_trigrams_embed) + +metric1 = SegAppCharParseF1Metric(char_labels_vocab['APP']) +metric2 = CWSMetric(char_labels_vocab['APP']) +metrics = [metric1, metric2] + +optimizer = optim.Adam([param for param in model.parameters() if param.requires_grad], lr=lr, + weight_decay=weight_decay, betas=[0.9, 0.9]) + +sampler = BucketSampler(seq_len_field_name='seq_lens') +callbacks = [] +# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000)) +scheduler = StepLR(optimizer, step_size=18, gamma=0.75) +# optim_callback = OptimizerCallback(optimizer, scheduler, update_every) +# callbacks.append(optim_callback) +scheduler_callback = LRScheduler(scheduler) +callbacks.append(scheduler_callback) +callbacks.append(GradientClipCallback(clip_type='value', clip_value=5)) + +tester = Tester(data=data.datasets['test'], model=model, metrics=metrics, + batch_size=64, device=device, verbose=0) +dev_callback = DevCallback(tester) +callbacks.append(dev_callback) + +trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3, + validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer, + check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True, + device=device, callbacks=callbacks, update_every=update_every) +trainer.train() \ No newline at end of file
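
Note on the char-level conversion: `add_char_heads`/`add_char_labels` in `data_loader.py` encode the word-level tree at character level: every non-final character of a word points to the next character with the label `APP`, while the final character of a word points to the last character of its head word and keeps the word's dependency label (`add_root` later prepends an extra root position). The sketch below restates that logic in simplified form and runs it on the first sentence of the conllx example in the `CTBxJointLoader` docstring; the expected outputs in the trailing comments follow from that logic.

```
# Simplified restatement of add_char_heads / add_char_labels (illustration only).
import numpy as np

words  = ['费孝通', '被', '授予', '麦格赛赛', '奖']      # from the docstring example
heads  = [3, 3, 0, 5, 3]                                 # word-level heads (1-based, 0 = root)
labels = ['nsubjpass', 'pass', 'root', 'nn', 'dobj']     # word-level dependency labels

# last char index (1-based; index 0 is the artificial root) of each word;
# head == 0 maps to the extra trailing 0, i.e. the root itself
head_end_indexes = np.cumsum([len(w) for w in words]).tolist() + [0]

char_heads, char_labels = [], []
char_index = 1
for word, head, label in zip(words, heads, labels):
    for _ in range(len(word) - 1):                   # intra-word chars point to the next char
        char_heads.append(char_index + 1)
        char_labels.append('APP')
        char_index += 1
    char_heads.append(head_end_indexes[head - 1])    # word-final char -> last char of head word
    char_labels.append(label)
    char_index += 1

print(char_heads)   # [2, 3, 6, 6, 6, 0, 8, 9, 10, 11, 6]
print(char_labels)  # ['APP', 'APP', 'nsubjpass', 'pass', 'APP', 'root', 'APP', 'APP', 'APP', 'nn', 'dobj']
```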