From c07cbd72f696564a77eb3e888b4d14618fe5b526 Mon Sep 17 00:00:00 2001
From: xxliu
Date: Sat, 6 Jul 2019 18:28:46 +0800
Subject: [PATCH] Coreference resolution source code
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../coreference_resolution/__init__.py        |   0
 .../data_load/__init__.py                     |   0
 .../data_load/cr_loader.py                    |  68 +++
 .../coreference_resolution/model/__init__.py  |   0
 .../coreference_resolution/model/config.py    |  54 ++
 .../coreference_resolution/model/metric.py    | 163 +++++
 .../coreference_resolution/model/model_re.py  | 576 ++++++++++++++++++
 .../model/preprocess.py                       | 225 +++++++
 .../model/softmax_loss.py                     |  32 +
 .../coreference_resolution/model/util.py      | 101 +++
 reproduction/coreference_resolution/readme.md |  49 ++
 .../coreference_resolution/test/__init__.py   |   0
 .../test/test_dataloader.py                   |  14 +
 reproduction/coreference_resolution/train.py  |  69 +++
 reproduction/coreference_resolution/valid.py  |  24 +
 15 files changed, 1375 insertions(+)
 create mode 100644 reproduction/coreference_resolution/__init__.py
 create mode 100644 reproduction/coreference_resolution/data_load/__init__.py
 create mode 100644 reproduction/coreference_resolution/data_load/cr_loader.py
 create mode 100644 reproduction/coreference_resolution/model/__init__.py
 create mode 100644 reproduction/coreference_resolution/model/config.py
 create mode 100644 reproduction/coreference_resolution/model/metric.py
 create mode 100644 reproduction/coreference_resolution/model/model_re.py
 create mode 100644 reproduction/coreference_resolution/model/preprocess.py
 create mode 100644 reproduction/coreference_resolution/model/softmax_loss.py
 create mode 100644 reproduction/coreference_resolution/model/util.py
 create mode 100644 reproduction/coreference_resolution/readme.md
 create mode 100644 reproduction/coreference_resolution/test/__init__.py
 create mode 100644 reproduction/coreference_resolution/test/test_dataloader.py
 create mode 100644 reproduction/coreference_resolution/train.py
 create mode 100644 reproduction/coreference_resolution/valid.py

diff --git a/reproduction/coreference_resolution/__init__.py b/reproduction/coreference_resolution/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/data_load/__init__.py b/reproduction/coreference_resolution/data_load/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/data_load/cr_loader.py b/reproduction/coreference_resolution/data_load/cr_loader.py
new file mode 100644
index 00000000..986afcd5
--- /dev/null
+++ b/reproduction/coreference_resolution/data_load/cr_loader.py
@@ -0,0 +1,68 @@
+from fastNLP.io.dataset_loader import JsonLoader, DataSet, Instance
+from fastNLP.io.file_reader import _read_json
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.io.base_loader import DataInfo
+from reproduction.coreference_resolution.model.config import Config
+import reproduction.coreference_resolution.model.preprocess as preprocess
+
+
+class CRLoader(JsonLoader):
+    def __init__(self, fields=None, dropna=False):
+        super().__init__(fields, dropna)
+
+    def _load(self, path):
+        """
+        Load the data from a jsonlines file.
+        :param path: path to the data file
+        :return: a DataSet with one Instance per document
+        """
+        dataset = DataSet()
+        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
+            if self.fields:
+                ins = {self.fields[k]: v for k, v in d.items()}
+            else:
+                ins = d
+            dataset.append(Instance(**ins))
+        return dataset
+
+    def process(self, paths, **kwargs):
+        data_info = DataInfo()
+        for name in ['train', 'test', 'dev']:
+            data_info.datasets[name] = self.load(paths[name])
+
+        config = Config()
+        vocab = Vocabulary().from_dataset(*data_info.datasets.values(), field_name='sentences')
+        vocab.build_vocab()
+        word2id = vocab.word2idx
+
+        char_dict = preprocess.get_char_dict(config.char_path)
+        data_info.vocabs = vocab
+
+        genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
+
+        for name, ds in data_info.datasets.items():
+            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
+                                                    config.max_sentences, is_train=name=='train')[0],
+                     new_field_name='doc_np')
+            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
+                                                    config.max_sentences, is_train=name=='train')[1],
+                     new_field_name='char_index')
+            ds.apply(lambda x: preprocess.doc2numpy(x['sentences'], word2id, char_dict, max(config.filter),
+                                                    config.max_sentences, is_train=name=='train')[2],
+                     new_field_name='seq_len')
+            ds.apply(lambda x: preprocess.speaker2numpy(x["speakers"], config.max_sentences, is_train=name=='train'),
+                     new_field_name='speaker_ids_np')
+            ds.apply(lambda x: genres[x["doc_key"][:2]], new_field_name='genre')
+
+            ds.set_ignore_type('clusters')
+            ds.set_padder('clusters', None)
+            ds.set_input("sentences", "doc_np", "speaker_ids_np", "genre", "char_index", "seq_len")
+            ds.set_target("clusters")
+
+        # train_dev, test = self.ds.split(348 / (2802 + 343 + 348), shuffle=False)
+        # train, dev = train_dev.split(343 / (2802 + 343), shuffle=False)
+
+        return data_info
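+
+# Note: process() leaves each DataSet with the input fields
+# (sentences, doc_np, speaker_ids_np, genre, char_index, seq_len) and the
+# target field (clusters); train.py and valid.py consume it in this form.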
diff --git a/reproduction/coreference_resolution/model/__init__.py b/reproduction/coreference_resolution/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/model/config.py b/reproduction/coreference_resolution/model/config.py
new file mode 100644
index 00000000..6011257b
--- /dev/null
+++ b/reproduction/coreference_resolution/model/config.py
@@ -0,0 +1,54 @@
+class Config():
+    def __init__(self):
+        self.is_training = True
+        # paths
+        self.glove = 'data/glove.840B.300d.txt.filtered'
+        self.turian = 'data/turian.50d.txt'
+        self.train_path = "data/train.english.jsonlines"
+        self.dev_path = "data/dev.english.jsonlines"
+        self.test_path = "data/test.english.jsonlines"
+        self.char_path = "data/char_vocab.english.txt"
+
+        self.cuda = "0"
+        self.max_word = 1500
+        self.epoch = 200
+
+        # config
+        # self.use_glove = True
+        # self.use_turian = True  #No
+        self.use_elmo = False
+        self.use_CNN = True
+        self.model_heads = True  #Yes
+        self.use_width = True  #Yes
+        self.use_distance = True  #Yes
+        self.use_metadata = True  #Yes
+
+        self.mention_ratio = 0.4
+        self.max_sentences = 50
+        self.span_width = 10
+        self.feature_size = 20  # embedding size of the span-width feature
+        self.lr = 0.001
+        self.lr_decay = 1e-3
+        self.max_antecedents = 100  # not used during mention detection
+        self.atten_hidden_size = 150
+        self.mention_hidden_size = 150
+        self.sa_hidden_size = 150
+
+        self.char_emb_size = 8
+        self.filter = [3, 4, 5]
+
+
+        # decay = 1e-5
+
+    def __str__(self):
+        d = self.__dict__
+        out = 'config==============\n'
+        for i in list(d):
+            out += i + ":"
+            out += str(d[i]) + "\n"
+        out += "config==============\n"
+        return out
+
+if __name__ == "__main__":
+    config = Config()
+    print(config)
diff --git a/reproduction/coreference_resolution/model/metric.py b/reproduction/coreference_resolution/model/metric.py
new file mode 100644
index 00000000..2c924660
--- /dev/null
+++ b/reproduction/coreference_resolution/model/metric.py
@@ -0,0 +1,163 @@
+from fastNLP.core.metrics import MetricBase
+
+import numpy as np
+
+from collections import Counter
+from sklearn.utils.linear_assignment_ import linear_assignment
+
+"""
+Mostly borrowed from https://github.com/clarkkev/deep-coref/blob/master/evaluation.py
+"""
+
+
+class CRMetric(MetricBase):
+    def __init__(self):
+        super().__init__()
+        self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]
+
+    # TODO renamed to evaluate; the inputs were adapted to fastNLP as well
+    def evaluate(self, predicted, mention_to_predicted, clusters):
+        for e in self.evaluators:
+            e.update(predicted, mention_to_predicted, clusters)
+
+    def get_f1(self):
+        return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators)
+
+    def get_recall(self):
+        return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators)
+
+    def get_precision(self):
+        return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators)
+
+    # TODO originally named get_prf
+    def get_metric(self, reset=False):
+        res = {"pre": self.get_precision(), "rec": self.get_recall(), "f": self.get_f1()}
+        self.evaluators = [Evaluator(m) for m in (muc, b_cubed, ceafe)]
+        return res
+
+
+class Evaluator():
+    def __init__(self, metric, beta=1):
+        self.p_num = 0
+        self.p_den = 0
+        self.r_num = 0
+        self.r_den = 0
+        self.metric = metric
+        self.beta = beta
+
+    def update(self, predicted, mention_to_predicted, gold):
+        gold = gold[0].tolist()
+        gold = [tuple(tuple(m) for m in gc) for gc in gold]
+        mention_to_gold = {}
+        for gc in gold:
+            for mention in gc:
+                mention_to_gold[mention] = gc
+
+        if self.metric == ceafe:
+            pn, pd, rn, rd = self.metric(predicted, gold)
+        else:
+            pn, pd = self.metric(predicted, mention_to_gold)
+            rn, rd = self.metric(gold, mention_to_predicted)
+        self.p_num += pn
+        self.p_den += pd
+        self.r_num += rn
+        self.r_den += rd
+
+    def get_f1(self):
+        return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta)
+
+    def get_recall(self):
+        return 0 if self.r_num == 0 else self.r_num / float(self.r_den)
+
+    def get_precision(self):
+        return 0 if self.p_num == 0 else self.p_num / float(self.p_den)
+
+    def get_prf(self):
+        return self.get_precision(), self.get_recall(), self.get_f1()
+
+    def get_counts(self):
+        return self.p_num, self.p_den, self.r_num, self.r_den
+
+
+def b_cubed(clusters, mention_to_gold):
+    num, dem = 0, 0
+
+    for c in clusters:
+        if len(c) == 1:
+            continue
+
+        gold_counts = Counter()
+        correct = 0
+        for m in c:
+            if m in mention_to_gold:
+                gold_counts[tuple(mention_to_gold[m])] += 1
+        for c2, count in gold_counts.items():
+            if len(c2) != 1:
+                correct += count * count
+
+        num += correct / float(len(c))
+        dem += len(c)
+
+    return num, dem
+
+
+def muc(clusters, mention_to_gold):
+    tp, p = 0, 0
+    for c in clusters:
+        p += len(c) - 1
+        tp += len(c)
+        linked = set()
+        for m in c:
+            if m in mention_to_gold:
+                linked.add(mention_to_gold[m])
+            else:
+                tp -= 1
+        tp -= len(linked)
+    return tp, p
+
+
+def phi4(c1, c2):
+    return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2))
+
+
+def ceafe(clusters, gold_clusters):
+    clusters = [c for c in clusters if len(c) != 1]
+    scores = np.zeros((len(gold_clusters), len(clusters)))
+    for i in range(len(gold_clusters)):
+        for j in range(len(clusters)):
+            scores[i, j] = phi4(gold_clusters[i], clusters[j])
+    matching = linear_assignment(-scores)
+    similarity = sum(scores[matching[:, 0], matching[:, 1]])
+    return similarity, len(clusters), similarity, len(gold_clusters)
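+
+# Note: ceafe aligns predicted and gold clusters with the Hungarian algorithm.
+# sklearn.utils.linear_assignment_ has since been removed from scikit-learn;
+# scipy.optimize.linear_sum_assignment is the usual replacement (it returns a
+# pair of index arrays rather than a single [n, 2] matrix, so the indexing
+# above would need a small adjustment).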
+
+
+def lea(clusters, mention_to_gold):
+    num, dem = 0, 0
+
+    for c in clusters:
+        if len(c) == 1:
+            continue
+
+        common_links = 0
+        all_links = len(c) * (len(c) - 1) / 2.0
+        for i, m in enumerate(c):
+            if m in mention_to_gold:
+                for m2 in c[i + 1:]:
+                    if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]:
+                        common_links += 1
+
+        num += len(c) * common_links / float(all_links)
+        dem += len(c)
+
+    return num, dem
+
+def f1(p_num, p_den, r_num, r_den, beta=1):
+    p = 0 if p_den == 0 else p_num / float(p_den)
+    r = 0 if r_den == 0 else r_num / float(r_den)
+    return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r)
diff --git a/reproduction/coreference_resolution/model/model_re.py b/reproduction/coreference_resolution/model/model_re.py
new file mode 100644
index 00000000..9dd90ec4
--- /dev/null
+++ b/reproduction/coreference_resolution/model/model_re.py
@@ -0,0 +1,576 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from allennlp.commands.elmo import ElmoEmbedder
+from fastNLP.models.base_model import BaseModel
+from fastNLP.modules.encoder.variational_rnn import VarLSTM
+from reproduction.coreference_resolution.model import preprocess
+from fastNLP.io.embed_loader import EmbedLoader
+import random
+
+# fix random seeds for reproducibility
+torch.manual_seed(0)  # cpu
+torch.cuda.manual_seed(0)  # gpu
+np.random.seed(0)  # numpy
+random.seed(0)
+
+
+class ffnn(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size):
+        super(ffnn, self).__init__()
+
+        self.f = nn.Sequential(
+            # two hidden layers plus an output projection
+            nn.Linear(input_size, hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Dropout(p=0.2),
+            nn.Linear(hidden_size, hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Dropout(p=0.2),
+            nn.Linear(hidden_size, output_size)
+        )
+        self.reset_param()
+
+    def reset_param(self):
+        for name, param in self.named_parameters():
+            if param.dim() > 1:
+                nn.init.xavier_normal_(param)
+                # param.data = torch.tensor(np.random.randn(*param.shape)).float()
+            else:
+                nn.init.zeros_(param)
+
+    def forward(self, input):
+        return self.f(input).squeeze()
+
+
+class Model(BaseModel):
+    def __init__(self, vocab, config):
+        word2id = vocab.word2idx
+        super(Model, self).__init__()
+        vocab_num = len(word2id)
+        self.word2id = word2id
+        self.config = config
+        self.char_dict = preprocess.get_char_dict('data/char_vocab.english.txt')
+        self.genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])}
+        self.device = torch.device("cuda:" + config.cuda)
+
+        self.emb = nn.Embedding(vocab_num, 350)
+
+        emb1 = EmbedLoader().load_with_vocab(config.glove, vocab, normalize=False)
+        emb2 = EmbedLoader().load_with_vocab(config.turian, vocab, normalize=False)
+        pre_emb = np.concatenate((emb1, emb2), axis=1)
+        pre_emb /= (np.linalg.norm(pre_emb, axis=1, keepdims=True) + 1e-12)
+
+        if pre_emb is not None:
+            self.emb.weight = nn.Parameter(torch.from_numpy(pre_emb).float())
+            for param in self.emb.parameters():
+                param.requires_grad = False
+        self.emb_dropout = nn.Dropout(inplace=True)
+
+
+        if config.use_elmo:
+            self.elmo = ElmoEmbedder(options_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_options.json',
+                                     weight_file='data/elmo/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5',
+                                     cuda_device=int(config.cuda))
+            print("elmo load over.")
+            self.elmo_args = torch.randn((3), requires_grad=True).to(self.device)
+
+        self.char_emb = nn.Embedding(len(self.char_dict), config.char_emb_size)
+        self.conv1 = nn.Conv1d(config.char_emb_size, 50, 3)
+        self.conv2 = nn.Conv1d(config.char_emb_size, 50, 4)
+        self.conv3 = nn.Conv1d(config.char_emb_size, 50, 5)
+
+        self.feature_emb = nn.Embedding(config.span_width, config.feature_size)
+        self.feature_emb_dropout = nn.Dropout(p=0.2, inplace=True)
+
+        self.mention_distance_emb = nn.Embedding(10, config.feature_size)
+        self.distance_drop = nn.Dropout(p=0.2, inplace=True)
+
+        self.genre_emb = nn.Embedding(7, config.feature_size)
+        self.speaker_emb = nn.Embedding(2, config.feature_size)
+
+        self.bilstm = VarLSTM(input_size=350 + 150 * config.use_CNN + config.use_elmo * 1024,
+                              hidden_size=200, bidirectional=True, batch_first=True,
+                              hidden_dropout=0.2)
+        # self.bilstm = nn.LSTM(input_size=500, hidden_size=200, bidirectional=True, batch_first=True)
+        self.h0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device)
+        self.c0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device)
+        self.bilstm_drop = nn.Dropout(p=0.2, inplace=True)
+
+        self.atten = ffnn(input_size=400, hidden_size=config.atten_hidden_size, output_size=1)
+        self.mention_score = ffnn(input_size=1320, hidden_size=config.mention_hidden_size, output_size=1)
+        self.sa = ffnn(input_size=3980 + 40 * config.use_metadata, hidden_size=config.sa_hidden_size, output_size=1)
+        self.mention_start_np = None
+        self.mention_end_np = None
+
+    def _reorder_lstm(self, word_emb, seq_lens):
+        sort_ind = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
+        seq_lens_re = [seq_lens[i] for i in sort_ind]
+        emb_seq = self.reorder_sequence(word_emb, sort_ind, batch_first=True)
+        packed_seq = nn.utils.rnn.pack_padded_sequence(emb_seq, seq_lens_re, batch_first=True)
+
+        h0 = self.h0.repeat(1, len(seq_lens), 1)
+        c0 = self.c0.repeat(1, len(seq_lens), 1)
+        packed_out, final_states = self.bilstm(packed_seq, (h0, c0))
+
+        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
+        back_map = {ind: i for i, ind in enumerate(sort_ind)}
+        reorder_ind = [back_map[i] for i in range(len(seq_lens_re))]
+        lstm_out = self.reorder_sequence(lstm_out, reorder_ind, batch_first=True)
+        return lstm_out
+
+    def reorder_sequence(self, sequence_emb, order, batch_first=True):
+        """
+        sequence_emb: [T, B, D] if not batch_first
+        order: list of indices giving the new batch order
+        """
+        batch_dim = 0 if batch_first else 1
+        assert len(order) == sequence_emb.size()[batch_dim]
+
+        order = torch.LongTensor(order)
+        order = order.to(sequence_emb).long()
+
+        sorted_ = sequence_emb.index_select(index=order, dim=batch_dim)
+
+        del order
+        return sorted_
+
+    def flat_lstm(self, lstm_out, seq_lens):
+        batch = lstm_out.shape[0]
+        seq = lstm_out.shape[1]
+        dim = lstm_out.shape[2]
+        l = [j + i * seq for i, seq_len in enumerate(seq_lens) for j in range(seq_len)]
+        flatted = torch.index_select(lstm_out.view(batch * seq, dim), 0, torch.LongTensor(l).to(self.device))
+        return flatted
+
+    def potential_mention_index(self, word_index, max_sent_len):
+        # enumerate candidate spans: e.g. with sentence lengths [3, 2] (the first
+        # sentence has 3 words, the second has 2) and max width 2,
+        # [0,0,0,1,1] --> [[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [3, 3], [3, 4], [4, 4]]
+        potential_mention = []
+        for i in range(len(word_index)):
+            for j in range(i, i + max_sent_len):
+                if (j < len(word_index) and word_index[i] == word_index[j]):
+                    potential_mention.append([i, j])
+        return potential_mention
+
+    def get_mention_start_end(self, seq_lens):
+        # convert sequence lengths into a word -> sentence index
+        # [3,2] --> [0,0,0,1,1]
+        word_index = [0] * sum(seq_lens)
+        sent_index = 0
+        index = 0
+        for length in seq_lens:
+            for l in range(length):
+                word_index[index] = sent_index
+                index += 1
+            sent_index += 1
+
+        # [0,0,0,1,1] --> [[0,0],[0,1],[0,2]....]
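+        # enumerate all within-sentence spans of width <= config.span_width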
+        mention_id = self.potential_mention_index(word_index, self.config.span_width)
+        mention_start = np.array(mention_id, dtype=int)[:, 0]
+        mention_end = np.array(mention_id, dtype=int)[:, 1]
+        return mention_start, mention_end
+
+    def get_mention_emb(self, flatten_lstm, mention_start, mention_end):
+        mention_start_tensor = torch.from_numpy(mention_start).to(self.device)
+        mention_end_tensor = torch.from_numpy(mention_end).to(self.device)
+        emb_start = flatten_lstm.index_select(dim=0, index=mention_start_tensor)  # [mention_num,embed]
+        emb_end = flatten_lstm.index_select(dim=0, index=mention_end_tensor)  # [mention_num,embed]
+        return emb_start, emb_end
+
+    def get_mask(self, mention_start, mention_end):
+        # attention mask over the words of each span
+        mention_num = mention_start.shape[0]
+        mask = np.zeros((mention_num, self.config.span_width))  # [mention_num,span_width]
+        for i in range(mention_num):
+            start = mention_start[i]
+            end = mention_end[i]
+            # this is effectively the span width
+            for j in range(end - start + 1):
+                mask[i][j] = 1
+        mask = torch.from_numpy(mask)  # [mention_num,max_mention]
+        # 0 --> -inf, 1 --> 0
+        log_mask = torch.log(mask)
+        return log_mask
+
+    def get_mention_index(self, mention_start, max_mention):
+        # TODO may need changes later
+        assert len(mention_start.shape) == 1
+        mention_start_tensor = torch.from_numpy(mention_start)
+        num_mention = mention_start_tensor.shape[0]
+        mention_index = mention_start_tensor.expand(max_mention, num_mention).transpose(0,
+                                                                                        1)  # [num_mention,max_mention]
+        assert mention_index.shape[0] == num_mention
+        assert mention_index.shape[1] == max_mention
+        range_add = torch.arange(0, max_mention).expand(num_mention, max_mention).long()  # [num_mention,max_mention]
+        mention_index = mention_index + range_add
+        mention_index = torch.min(mention_index, torch.LongTensor([mention_start[-1]]).expand(num_mention, max_mention))
+        return mention_index.to(self.device)
+
+    def sort_mention(self, mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_lens):
+        # sort by score, highest-scoring spans first
+        mention_score, mention_ids = torch.sort(candidate_mention_score, descending=True)
+        preserve_mention_num = int(self.config.mention_ratio * sum(seq_lens))
+        mention_ids = mention_ids[0:preserve_mention_num]
+        mention_score = mention_score[0:preserve_mention_num]
+
+        mention_start_tensor = torch.from_numpy(mention_start).to(self.device).index_select(dim=0,
+                                                                                            index=mention_ids)  # [lamda*word_num]
+        mention_end_tensor = torch.from_numpy(mention_end).to(self.device).index_select(dim=0,
+                                                                                        index=mention_ids)  # [lamda*word_num]
+        mention_emb = candidate_mention_emb.index_select(index=mention_ids, dim=0)  # [lamda*word_num,emb]
+        assert mention_score.shape[0] == preserve_mention_num
+        assert mention_start_tensor.shape[0] == preserve_mention_num
+        assert mention_end_tensor.shape[0] == preserve_mention_num
+        assert mention_emb.shape[0] == preserve_mention_num
+        # TODO crossing mentions are not handled
+
+        # re-sort by start position, earlier positions first
+        # TODO only start is considered here, end is ignored
+        mention_start_tensor, temp_index = torch.sort(mention_start_tensor)
+        mention_end_tensor = mention_end_tensor.index_select(dim=0, index=temp_index)
+        mention_emb = mention_emb.index_select(dim=0, index=temp_index)
+        mention_score = mention_score.index_select(dim=0, index=temp_index)
+        return mention_start_tensor, mention_end_tensor, mention_score, mention_emb
+
+    def get_antecedents(self, mention_starts, max_antecedents):
+        num_mention = mention_starts.shape[0]
+        max_antecedents = min(max_antecedents, num_mention)
+        # map each mention to the indices of its candidate antecedents
+        antecedents = np.zeros((num_mention, max_antecedents), dtype=int)  # [num_mention,max_an]
+        # record how many real antecedents each mention has
+        antecedents_len = [0] * num_mention
+        for i in range(num_mention):
+            ante_count = 0
+            for j in range(max(0, i - max_antecedents), i):
+                antecedents[i, ante_count] = j
+                ante_count += 1
+            # pad the remaining slots
+            for j in range(ante_count, max_antecedents):
+                antecedents[i, j] = 0
+            antecedents_len[i] = ante_count
+        assert antecedents.shape[1] == max_antecedents
+        return antecedents, antecedents_len
+
+    def get_antecedents_score(self, span_represent, mention_score, antecedents, antecedents_len, mention_speakers_ids,
+                              genre):
+        num_mention = mention_score.shape[0]
+        max_antecedent = antecedents.shape[1]
+
+        pair_emb = self.get_pair_emb(span_represent, antecedents, mention_speakers_ids, genre)  # [span_num,max_ant,emb]
+        antecedent_scores = self.sa(pair_emb)
+        mask01 = self.sequence_mask(antecedents_len, max_antecedent)
+        maskinf = torch.log(mask01).to(self.device)
+        assert maskinf.shape[1] <= max_antecedent
+        assert antecedent_scores.shape[0] == num_mention
+        antecedent_scores = antecedent_scores + maskinf
+        antecedents = torch.from_numpy(antecedents).to(self.device)
+        mention_scoreij = mention_score.unsqueeze(1) + torch.gather(
+            mention_score.unsqueeze(0).expand(num_mention, num_mention), dim=1, index=antecedents)
+        antecedent_scores += mention_scoreij
+
+        antecedent_scores = torch.cat([torch.zeros([mention_score.shape[0], 1]).to(self.device), antecedent_scores],
+                                      1)  # [num_mentions, max_ant + 1]
+        return antecedent_scores
+
+    ##############################
+    def distance_bin(self, mention_distance):
+        bins = torch.zeros(mention_distance.size()).byte().to(self.device)
+        rg = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 7], [8, 15], [16, 31], [32, 63], [64, 300]]
+        for t, k in enumerate(rg):
+            i, j = k[0], k[1]
+            b = torch.LongTensor([i]).unsqueeze(-1).expand(mention_distance.size()).to(self.device)
+            m1 = torch.ge(mention_distance, b)
+            e = torch.LongTensor([j]).unsqueeze(-1).expand(mention_distance.size()).to(self.device)
+            m2 = torch.le(mention_distance, e)
+            bins = bins + (t + 1) * (m1 & m2)
+        return bins.long()
+
+    def get_distance_emb(self, antecedents_tensor):
+        num_mention = antecedents_tensor.shape[0]
+        max_ant = antecedents_tensor.shape[1]
+
+        assert max_ant <= self.config.max_antecedents
+        source = torch.arange(0, num_mention).expand(max_ant, num_mention).transpose(0, 1).to(self.device)  # [num_mention,max_ant]
+        mention_distance = source - antecedents_tensor
+        mention_distance_bin = self.distance_bin(mention_distance)
+        distance_emb = self.mention_distance_emb(mention_distance_bin)
+        distance_emb = self.distance_drop(distance_emb)
+        return distance_emb
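+
+    # distance_bin buckets antecedent distances into the 10 bins
+    # 0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63 and 64+, matching the
+    # nn.Embedding(10, feature_size) used for mention_distance_emb above.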
+
+    def get_pair_emb(self, span_emb, antecedents, mention_speakers_ids, genre):
+        emb_dim = span_emb.shape[1]
+        num_span = span_emb.shape[0]
+        max_ant = antecedents.shape[1]
+        assert span_emb.shape[0] == antecedents.shape[0]
+        antecedents = torch.from_numpy(antecedents).to(self.device)
+
+        # [num_span,max_ant,emb]
+        antecedent_emb = torch.gather(span_emb.unsqueeze(0).expand(num_span, num_span, emb_dim), dim=1,
+                                      index=antecedents.unsqueeze(2).expand(num_span, max_ant, emb_dim))
+        # [num_span,max_ant,emb]
+        target_emb_tiled = span_emb.expand((max_ant, num_span, emb_dim))
+        target_emb_tiled = target_emb_tiled.transpose(0, 1)
+
+        similarity_emb = antecedent_emb * target_emb_tiled
+
+        pair_emb_list = [target_emb_tiled, antecedent_emb, similarity_emb]
+
+        # speaker and genre features
+        if self.config.use_metadata:
+            antecedent_speaker_ids = mention_speakers_ids.unsqueeze(0).expand(num_span, num_span).gather(dim=1,
+                                                                                                         index=antecedents)
+            same_speaker = torch.eq(mention_speakers_ids.unsqueeze(1).expand(num_span, max_ant),
+                                    antecedent_speaker_ids)  # [num_mention,max_ant]
+            speaker_embedding = self.speaker_emb(same_speaker.long().to(self.device))  # [mention_num,max_ant,emb]
+            genre_embedding = self.genre_emb(
+                torch.LongTensor([genre]).expand(num_span, max_ant).to(self.device))  # [mention_num,max_ant,emb]
+            pair_emb_list.append(speaker_embedding)
+            pair_emb_list.append(genre_embedding)
+
+        # distance feature
+        if self.config.use_distance:
+            distance_emb = self.get_distance_emb(antecedents)
+            pair_emb_list.append(distance_emb)
+
+        pair_emb = torch.cat(pair_emb_list, 2)
+        return pair_emb
+
+    def sequence_mask(self, len_list, max_len):
+        x = np.zeros((len(len_list), max_len))
+        for i in range(len(len_list)):
+            l = len_list[i]
+            for j in range(l):
+                x[i][j] = 1
+        return torch.from_numpy(x).float()
+
+    def logsumexp(self, value, dim=None, keepdim=False):
+        """Numerically stable implementation of the operation
+
+        value.exp().sum(dim, keepdim).log()
+        """
+        # TODO: torch.max(value, dim=None) threw an error at time of writing
+        if dim is not None:
+            m, _ = torch.max(value, dim=dim, keepdim=True)
+            value0 = value - m
+            if keepdim is False:
+                m = m.squeeze(dim)
+            return m + torch.log(torch.sum(torch.exp(value0),
+                                           dim=dim, keepdim=keepdim))
+        else:
+            m = torch.max(value)
+            sum_exp = torch.sum(torch.exp(value - m))
+
+            return m + torch.log(sum_exp)
+
+    def softmax_loss(self, antecedent_scores, antecedent_labels):
+        antecedent_labels = torch.from_numpy(antecedent_labels * 1).to(self.device)
+        gold_scores = antecedent_scores + torch.log(antecedent_labels.float())  # [num_mentions, max_ant + 1]
+        marginalized_gold_scores = self.logsumexp(gold_scores, 1)  # [num_mentions]
+        log_norm = self.logsumexp(antecedent_scores, 1)  # [num_mentions]
+        return torch.sum(log_norm - marginalized_gold_scores)  # scalar
+
+    def get_predicted_antecedents(self, antecedents, antecedent_scores):
+        predicted_antecedents = []
+        for i, index in enumerate(np.argmax(antecedent_scores.detach(), axis=1) - 1):
+            if index < 0:
+                predicted_antecedents.append(-1)
+            else:
+                predicted_antecedents.append(antecedents[i, index])
+        return predicted_antecedents
+
+    def get_predicted_clusters(self, mention_starts, mention_ends, predicted_antecedents):
+        mention_to_predicted = {}
+        predicted_clusters = []
+        for i, predicted_index in enumerate(predicted_antecedents):
+            if predicted_index < 0:
+                continue
+            assert i > predicted_index
+            predicted_antecedent = (int(mention_starts[predicted_index]), int(mention_ends[predicted_index]))
+            if predicted_antecedent in mention_to_predicted:
+                predicted_cluster = mention_to_predicted[predicted_antecedent]
+            else:
+                predicted_cluster = len(predicted_clusters)
+                predicted_clusters.append([predicted_antecedent])
+                mention_to_predicted[predicted_antecedent] = predicted_cluster
+
+            mention = (int(mention_starts[i]), int(mention_ends[i]))
+            predicted_clusters[predicted_cluster].append(mention)
+            mention_to_predicted[mention] = predicted_cluster
+
+        predicted_clusters = [tuple(pc) for pc in predicted_clusters]
+        mention_to_predicted = {m: predicted_clusters[i] for m, i in mention_to_predicted.items()}
+
+        return predicted_clusters, mention_to_predicted
+
+    def evaluate_coref(self, mention_starts, mention_ends, predicted_antecedents, gold_clusters, evaluator):
+        gold_clusters = [tuple(tuple(m) for m in gc) for gc in gold_clusters]
+        mention_to_gold = {}
+        for gc in gold_clusters:
+            for mention in gc:
+                mention_to_gold[mention] = gc
+        predicted_clusters, mention_to_predicted = self.get_predicted_clusters(mention_starts, mention_ends,
+                                                                               predicted_antecedents)
+        evaluator.update(predicted_clusters, gold_clusters, mention_to_predicted, mention_to_gold)
+        return predicted_clusters
+
+
+    def forward(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
+        """
+        All inputs arrive as tensors (or numpy arrays) produced by fastNLP.
+        :param sentences: the sentences, converted to numpy by fastNLP
+        :param doc_np: converted to a Tensor by fastNLP
+        :param speaker_ids_np: converted to a Tensor by fastNLP
+        :param genre: converted to a Tensor by fastNLP
+        :param char_index: converted to a Tensor by fastNLP
+        :param seq_len: converted to a Tensor by fastNLP
+        :return:
+        """
+        # change for fastNLP: unwrap the batch dimension (batch_size is 1)
+        sentences = sentences[0].tolist()
+        doc_tensor = doc_np[0]
+        speakers_tensor = speaker_ids_np[0]
+        genre = genre[0].item()
+        char_index = char_index[0]
+        seq_len = seq_len[0].cpu().numpy()
+
+        # doc_tensor = torch.from_numpy(doc_np).to(self.device)
+        # speakers_tensor = torch.from_numpy(speaker_ids_np).to(self.device)
+        mention_emb_list = []
+
+        word_emb = self.emb(doc_tensor)
+        word_emb_list = [word_emb]
+        if self.config.use_CNN:
+            # [batch, length, char_length, char_dim]
+            char = self.char_emb(char_index)
+            char_size = char.size()
+            # first transform to [batch * length, char_length, char_dim]
+            # then transpose to [batch * length, char_dim, char_length]
+            char = char.view(char_size[0] * char_size[1], char_size[2], char_size[3]).transpose(1, 2)
+
+            # put into cnn [batch*length, char_filters, char_length]
+            # then put into maxpooling [batch * length, char_filters]
+            char_over_cnn, _ = self.conv1(char).max(dim=2)
+            # reshape to [batch, length, char_filters]
+            char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1)
+            word_emb_list.append(char_over_cnn)
+
+            char_over_cnn, _ = self.conv2(char).max(dim=2)
+            char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1)
+            word_emb_list.append(char_over_cnn)
+
+            char_over_cnn, _ = self.conv3(char).max(dim=2)
+            char_over_cnn = torch.tanh(char_over_cnn).view(char_size[0], char_size[1], -1)
+            word_emb_list.append(char_over_cnn)
+
+        # word_emb = torch.cat(word_emb_list, dim=2)
+
+        # use elmo or not
+        if self.config.use_elmo:
+            # if the document was actually truncated
+            if doc_tensor.shape[0] == 50 and len(sentences) > 50:
+                sentences = sentences[0:50]
+            elmo_embedding, elmo_mask = self.elmo.batch_to_embeddings(sentences)
+            elmo_embedding = elmo_embedding.to(
+                self.device)  # [sentence_num,max_sent_len,3,1024]--[sentence_num,max_sent,1024]
+            elmo_embedding = elmo_embedding[:, 0, :, :] * self.elmo_args[0] + elmo_embedding[:, 1, :, :] * \
+                             self.elmo_args[1] + elmo_embedding[:, 2, :, :] * self.elmo_args[2]
+            word_emb_list.append(elmo_embedding)
+
+        word_emb = torch.cat(word_emb_list, dim=2)
+
+        word_emb = self.emb_dropout(word_emb)
+        # word_emb_elmo = self.emb_dropout(word_emb_elmo)
+        lstm_out = self._reorder_lstm(word_emb, seq_len)
+        flatten_lstm = self.flat_lstm(lstm_out, seq_len)  # [word_num,emb]
+        flatten_lstm = self.bilstm_drop(flatten_lstm)
+        # TODO not implemented exactly as in the paper
+        flatten_word_emb = self.flat_lstm(word_emb, seq_len)  # [word_num,emb]
+
+        mention_start, mention_end = self.get_mention_start_end(seq_len)  # [mention_num]
+        self.mention_start_np = mention_start  # [mention_num] np
+        self.mention_end_np = mention_end
+        mention_num = mention_start.shape[0]
+        emb_start, emb_end = self.get_mention_emb(flatten_lstm, mention_start, mention_end)  # [mention_num,emb]
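+
+        # Span representation as in Lee et al. (2017): the concatenation of the
+        # boundary LSTM states, a span-width embedding and an attention-weighted
+        # head embedding, collected in mention_emb_list below.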
+        mention_emb_list.append(emb_start)
+        mention_emb_list.append(emb_end)
+
+        if self.config.use_width:
+            mention_width_index = mention_end - mention_start
+            mention_width_tensor = torch.from_numpy(mention_width_index).to(self.device)  # [mention_num]
+            mention_width_emb = self.feature_emb(mention_width_tensor)
+            mention_width_emb = self.feature_emb_dropout(mention_width_emb)
+            mention_emb_list.append(mention_width_emb)
+
+        if self.config.model_heads:
+            mention_index = self.get_mention_index(mention_start, self.config.span_width)  # [mention_num,max_mention]
+            log_mask_tensor = self.get_mask(mention_start, mention_end).float().to(
+                self.device)  # [mention_num,max_mention]
+            alpha = self.atten(flatten_lstm).to(self.device)  # [word_num]
+
+            # attention over the words of each span
+            mention_head_score = torch.gather(alpha.expand(mention_num, -1), 1,
+                                              mention_index).float().to(self.device)  # [mention_num,max_mention]
+            mention_attention = F.softmax(mention_head_score + log_mask_tensor, dim=1)  # [mention_num,max_mention]
+
+            # TODO flatten lstm
+            word_num = flatten_lstm.shape[0]
+            lstm_emb = flatten_lstm.shape[1]
+            emb_num = flatten_word_emb.shape[1]
+
+            # [num_mentions, max_mention_width, emb]
+            mention_text_emb = torch.gather(
+                flatten_word_emb.unsqueeze(1).expand(word_num, self.config.span_width, emb_num),
+                0, mention_index.unsqueeze(2).expand(mention_num, self.config.span_width,
+                                                     emb_num))
+            # [mention_num,emb]
+            mention_head_emb = torch.sum(
+                mention_attention.unsqueeze(2).expand(mention_num, self.config.span_width, emb_num) * mention_text_emb,
+                dim=1)
+            mention_emb_list.append(mention_head_emb)
+
+        candidate_mention_emb = torch.cat(mention_emb_list, 1)  # [candidate_mention_num,emb]
+        candidate_mention_score = self.mention_score(candidate_mention_emb)  # [candidate_mention_num]
+
+        antecedent_scores, antecedents, mention_start_tensor, mention_end_tensor = (None, None, None, None)
+        mention_start_tensor, mention_end_tensor, mention_score, mention_emb = \
+            self.sort_mention(mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_len)
+        mention_speakers_ids = speakers_tensor.index_select(dim=0, index=mention_start_tensor)  # num_mention
+
+        antecedents, antecedents_len = self.get_antecedents(mention_start_tensor, self.config.max_antecedents)
+        antecedent_scores = self.get_antecedents_score(mention_emb, mention_score, antecedents, antecedents_len,
+                                                       mention_speakers_ids, genre)
+
+        ans = {"candidate_mention_score": candidate_mention_score, "antecedent_scores": antecedent_scores,
+               "antecedents": antecedents, "mention_start_tensor": mention_start_tensor,
+               "mention_end_tensor": mention_end_tensor}
+
+        return ans
+
+    def predict(self, sentences, doc_np, speaker_ids_np, genre, char_index, seq_len):
+        ans = self(sentences,
+                   doc_np,
+                   speaker_ids_np,
+                   genre,
+                   char_index,
+                   seq_len)
+
+        predicted_antecedents = self.get_predicted_antecedents(ans["antecedents"], ans["antecedent_scores"])
+        predicted_clusters, mention_to_predicted = self.get_predicted_clusters(ans["mention_start_tensor"],
+                                                                               ans["mention_end_tensor"],
+                                                                               predicted_antecedents)
+
+        return {'predicted': predicted_clusters, "mention_to_predicted": mention_to_predicted}
+
+
+if __name__ == '__main__':
+    pass
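+    # Standalone usage sketch (hypothetical names; in practice the Trainer in
+    # train.py and the Tester in valid.py drive the model):
+    #   model = Model(vocab, Config())
+    #   out = model.predict(sentences, doc_np, speaker_ids_np, genre, char_index, seq_len)
+    #   print(out['predicted'])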
diff --git a/reproduction/coreference_resolution/model/preprocess.py b/reproduction/coreference_resolution/model/preprocess.py
new file mode 100644
index 00000000..d97fcb4d
--- /dev/null
+++ b/reproduction/coreference_resolution/model/preprocess.py
@@ -0,0 +1,225 @@
+import json
+import numpy as np
+from . import util
+import collections
+
+def load(path):
+    """
+    load the file from jsonlines
+    :param path:
+    :return: a list of example dicts: {"clusters":[[[mention],[mention]],[another cluster]],
+             "doc_key":"str","speakers":[[,,,],[]...],"sentences":[[][]]}
+    """
+    with open(path) as f:
+        train_examples = [json.loads(jsonline) for jsonline in f.readlines()]
+    return train_examples
+
+def get_vocab():
+    """
+    Build the final vocabulary from all sentences; called from main.
+    Uses not only train but also dev and test.
+    :return: word2id & id2word
+    """
+    word2id = {'PAD': 0, 'UNK': 1}
+    id2word = {0: 'PAD', 1: 'UNK'}
+    index = 2
+    data = [load("../data/train.english.jsonlines"), load("../data/dev.english.jsonlines"),
+            load("../data/test.english.jsonlines")]
+    for examples in data:
+        for example in examples:
+            for sent in example["sentences"]:
+                for word in sent:
+                    if (word not in word2id):
+                        word2id[word] = index
+                        id2word[index] = word
+                        index += 1
+    return word2id, id2word
+
+def normalize(v):
+    norm = np.linalg.norm(v)
+    if norm > 0:
+        return v / norm
+    else:
+        return v
+
+# load GloVe and Turian embeddings
+def get_emb(id2word, embedding_size):
+    glove_oov = 0
+    turian_oov = 0
+    both = 0
+    glove_emb_path = "../data/glove.840B.300d.txt.filtered"
+    turian_emb_path = "../data/turian.50d.txt"
+    word_num = len(id2word)
+    emb = np.zeros((word_num, embedding_size))
+    glove_emb_dict = util.load_embedding_dict(glove_emb_path, 300, "txt")
+    turian_emb_dict = util.load_embedding_dict(turian_emb_path, 50, "txt")
+    for i in range(word_num):
+        if id2word[i] in glove_emb_dict:
+            word_embedding = glove_emb_dict.get(id2word[i])
+            emb[i][0:300] = np.array(word_embedding)
+        else:
+            # print(id2word[i])
+            glove_oov += 1
+        if id2word[i] in turian_emb_dict:
+            word_embedding = turian_emb_dict.get(id2word[i])
+            emb[i][300:350] = np.array(word_embedding)
+        else:
+            # print(id2word[i])
+            turian_oov += 1
+        if id2word[i] not in glove_emb_dict and id2word[i] not in turian_emb_dict:
+            both += 1
+        emb[i] = normalize(emb[i])
+    print("embedding num:" + str(word_num))
+    print("glove oov num:" + str(glove_oov))
+    print("glove oov rate:" + str(glove_oov / word_num))
+    print("turian oov num:" + str(turian_oov))
+    print("turian oov rate:" + str(turian_oov / word_num))
+    print("oov in both num:" + str(both))
+    return emb
+
+
+def _doc2vec(doc, word2id, char_dict, max_filter, max_sentences, is_train):
+    max_len = 0
+    max_word_length = 0
+    docvex = []
+    length = []
+    if is_train:
+        sent_num = min(max_sentences, len(doc))
+    else:
+        sent_num = len(doc)
+
+    for i in range(sent_num):
+        sent = doc[i]
+        length.append(len(sent))
+        if (len(sent) > max_len):
+            max_len = len(sent)
+        sent_vec = []
+        for j, word in enumerate(sent):
+            if len(word) > max_word_length:
+                max_word_length = len(word)
+            if word in word2id:
+                sent_vec.append(word2id[word])
+            else:
+                sent_vec.append(word2id["UNK"])
+        docvex.append(sent_vec)
+
+    char_index = np.zeros((sent_num, max_len, max_word_length), dtype=int)
+    for i in range(sent_num):
+        sent = doc[i]
+        for j, word in enumerate(sent):
+            char_index[i, j, :len(word)] = [char_dict[c] for c in word]
+
+    return docvex, char_index, length, max_len
+
+# TODO interface changed; make sure every call site has been updated
+def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train):
+    docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train)
+    assert max(length) == max_len
+    assert char_index.shape[0] == len(length)
+    assert char_index.shape[1] == max_len
+    doc_np = np.zeros((len(docvec), max_len), int)
+    for i in range(len(docvec)):
+        for j in range(len(docvec[i])):
+            doc_np[i][j] = docvec[i][j]
+    return doc_np, char_index, length
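+
+# doc2numpy returns (doc_np: [sent_num, max_len] word ids,
+# char_index: [sent_num, max_len, max_word_len] character ids,
+# length: the per-sentence lengths); these become the doc_np, char_index and
+# seq_len fields built in cr_loader.py.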
+
+# TODO not tested yet
+def speaker2numpy(speakers_raw, max_sentences, is_train):
+    if is_train and len(speakers_raw) > max_sentences:
+        speakers_raw = speakers_raw[0:max_sentences]
+    speakers = flatten(speakers_raw)
+    speaker_dict = {s: i for i, s in enumerate(set(speakers))}
+    speaker_ids = np.array([speaker_dict[s] for s in speakers])
+    return speaker_ids
+
+
+def flat_cluster(clusters):
+    flatted = []
+    for cluster in clusters:
+        for item in cluster:
+            flatted.append(item)
+    return flatted
+
+def get_right_mention(clusters, mention_start_np, mention_end_np):
+    flatted = flat_cluster(clusters)
+    cluster_num = len(flatted)
+    mention_num = mention_start_np.shape[0]
+    right_mention = np.zeros(mention_num, dtype=int)
+    for i in range(mention_num):
+        if [mention_start_np[i], mention_end_np[i]] in flatted:
+            right_mention[i] = 1
+    return right_mention, cluster_num
+
+def handle_cluster(clusters):
+    gold_mentions = sorted(tuple(m) for m in flatten(clusters))
+    gold_mention_map = {m: i for i, m in enumerate(gold_mentions)}
+    cluster_ids = np.zeros(len(gold_mentions), dtype=int)
+    for cluster_id, cluster in enumerate(clusters):
+        for mention in cluster:
+            cluster_ids[gold_mention_map[tuple(mention)]] = cluster_id
+    gold_starts, gold_ends = tensorize_mentions(gold_mentions)
+    return cluster_ids, gold_starts, gold_ends
+
+# flatten a nested list
+def flatten(l):
+    return [item for sublist in l for item in sublist]
+
+# split mentions into start and end arrays
+def tensorize_mentions(mentions):
+    if len(mentions) > 0:
+        starts, ends = zip(*mentions)
+    else:
+        starts, ends = [], []
+    return np.array(starts), np.array(ends)
+
+def get_char_dict(path):
+    vocab = [""]
+    with open(path) as f:
+        vocab.extend(c.strip() for c in f.readlines())
+    char_dict = collections.defaultdict(int)
+    char_dict.update({c: i for i, c in enumerate(vocab)})
+    return char_dict
+
+def get_labels(clusters, mention_starts, mention_ends, max_antecedents):
+    cluster_ids, gold_starts, gold_ends = handle_cluster(clusters)
+    num_mention = mention_starts.shape[0]
+    num_gold = gold_starts.shape[0]
+    max_antecedents = min(max_antecedents, num_mention)
+    mention_indices = {}
+
+    for i in range(num_mention):
+        mention_indices[(mention_starts[i].detach().item(), mention_ends[i].detach().item())] = i
+    # record which mentions are correct: -1 means wrong, a non-negative value is
+    # the id of the gold cluster this mention actually belongs to
+    mention_cluster_ids = [-1] * num_mention
+    # test
+    right_mention_count = 0
+    for i in range(num_gold):
+        right_mention = mention_indices.get((gold_starts[i], gold_ends[i]))
+        if (right_mention != None):
+            right_mention_count += 1
+            mention_cluster_ids[right_mention] = cluster_ids[i]
+
+    # whether mentions i and j belong to the same cluster
+    labels = np.zeros((num_mention, max_antecedents + 1), dtype=bool)  # [num_mention,max_an+1]
+    for i in range(num_mention):
+        ante_count = 0
+        null_label = True
+        for j in range(max(0, i - max_antecedents), i):
+            if (mention_cluster_ids[i] >= 0 and mention_cluster_ids[i] == mention_cluster_ids[j]):
+                labels[i, ante_count + 1] = True
+                null_label = False
+            else:
+                labels[i, ante_count + 1] = False
+            ante_count += 1
+        for j in range(ante_count, max_antecedents):
+            labels[i, j + 1] = False
+        labels[i, 0] = null_label
+    return labels
+
+
+if __name__ == "__main__":
+    word2id, id2word = get_vocab()
+    get_emb(id2word, 350)
diff --git a/reproduction/coreference_resolution/model/softmax_loss.py b/reproduction/coreference_resolution/model/softmax_loss.py
new file mode 100644
index 00000000..c75a31d6
--- /dev/null
+++ b/reproduction/coreference_resolution/model/softmax_loss.py
@@ -0,0 +1,32 @@
+from fastNLP.core.losses import LossBase
+
+from reproduction.coreference_resolution.model.preprocess import get_labels
+from reproduction.coreference_resolution.model.config import Config
+import torch
+
+
+class SoftmaxLoss(LossBase):
+    """
+    Cross-entropy style loss that allows several correct labels
+    (marginalizes over all gold antecedents of a mention).
+    """
+
+    def __init__(self, antecedent_scores=None, clusters=None, mention_start_tensor=None, mention_end_tensor=None):
+        """
+        :param antecedent_scores: scores of the candidate antecedents
+        :param clusters: gold clusters
+        :param mention_start_tensor: start indices of the kept mentions
+        :param mention_end_tensor: end indices of the kept mentions
+        """
+        super().__init__()
+        self._init_param_map(antecedent_scores=antecedent_scores, clusters=clusters,
+                             mention_start_tensor=mention_start_tensor, mention_end_tensor=mention_end_tensor)
+
+    def get_loss(self, antecedent_scores, clusters, mention_start_tensor, mention_end_tensor):
+        antecedent_labels = get_labels(clusters[0], mention_start_tensor, mention_end_tensor,
+                                       Config().max_antecedents)
+
+        antecedent_labels = torch.from_numpy(antecedent_labels * 1).to(torch.device("cuda:" + Config().cuda))
+        gold_scores = antecedent_scores + torch.log(antecedent_labels.float()).to(
+            torch.device("cuda:" + Config().cuda))  # [num_mentions, max_ant + 1]
+        marginalized_gold_scores = gold_scores.logsumexp(dim=1)  # [num_mentions]
+        log_norm = antecedent_scores.logsumexp(dim=1)  # [num_mentions]
+        return torch.sum(log_norm - marginalized_gold_scores)
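+
+# This is the marginal log-likelihood objective of Lee et al. (2017): for each
+# mention, log-sum-exp over the scores of its *gold* antecedents minus
+# log-sum-exp over *all* candidate antecedents (column 0 is the dummy
+# "no antecedent" option), summed over mentions and minimized.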
diff --git a/reproduction/coreference_resolution/model/util.py b/reproduction/coreference_resolution/model/util.py
new file mode 100644
index 00000000..42cd09fe
--- /dev/null
+++ b/reproduction/coreference_resolution/model/util.py
@@ -0,0 +1,101 @@
+import os
+import errno
+import collections
+import torch
+import numpy as np
+import pyhocon
+
+
+# flatten the list
+def flatten(l):
+    return [item for sublist in l for item in sublist]
+
+
+def get_config(filename):
+    return pyhocon.ConfigFactory.parse_file(filename)
+
+
+# safely make directories
+def mkdirs(path):
+    try:
+        os.makedirs(path)
+    except OSError as exception:
+        if exception.errno != errno.EEXIST:
+            raise
+    return path
+
+
+def load_char_dict(char_vocab_path):
+    vocab = [""]
+    with open(char_vocab_path) as f:
+        vocab.extend(c.strip() for c in f.readlines())
+    char_dict = collections.defaultdict(int)
+    char_dict.update({c: i for i, c in enumerate(vocab)})
+    return char_dict
+
+# load pretrained embeddings into a dict
+def load_embedding_dict(embedding_path, embedding_size, embedding_format):
+    print("Loading word embeddings from {}...".format(embedding_path))
+    default_embedding = np.zeros(embedding_size)
+    embedding_dict = collections.defaultdict(lambda: default_embedding)
+    skip_first = embedding_format == "vec"
+    with open(embedding_path) as f:
+        for i, line in enumerate(f.readlines()):
+            if skip_first and i == 0:
+                continue
+            splits = line.split()
+            assert len(splits) == embedding_size + 1
+            word = splits[0]
+            embedding = np.array([float(s) for s in splits[1:]])
+            embedding_dict[word] = embedding
+    print("Done loading word embeddings.")
+    return embedding_dict
+
+
+# safe divide
+def maybe_divide(x, y):
+    return 0 if y == 0 else x / float(y)
+
+
+# kept from the original TensorFlow implementation; for torch tensors
+# x.shape[dim] is all that is needed
+def shape(x, dim):
+    return x.shape[dim]
+
+
+def normalize(v):
+    norm = np.linalg.norm(v)
+    if norm > 0:
+        return v / norm
+    else:
+        return v
+
+
+class RetrievalEvaluator(object):
+    def __init__(self):
+        self._num_correct = 0
+        self._num_gold = 0
+        self._num_predicted = 0
+
+    def update(self, gold_set, predicted_set):
+        self._num_correct += len(gold_set & predicted_set)
+        self._num_gold += len(gold_set)
+        self._num_predicted += len(predicted_set)
+
+    def recall(self):
+        return maybe_divide(self._num_correct, self._num_gold)
+
+    def precision(self):
+        return maybe_divide(self._num_correct, self._num_predicted)
+
+    def metrics(self):
+        recall = self.recall()
+        precision = self.precision()
+        f1 = maybe_divide(2 * recall * precision, precision + recall)
+        return recall, precision, f1
+
+
+if __name__ == "__main__":
+    print(load_char_dict("../data/char_vocab.english.txt"))
+    embedding_dict = load_embedding_dict("../data/glove.840B.300d.txt.filtered", 300, "txt")
+    print("hello")
diff --git a/reproduction/coreference_resolution/readme.md b/reproduction/coreference_resolution/readme.md
new file mode 100644
index 00000000..67d8cdc7
--- /dev/null
+++ b/reproduction/coreference_resolution/readme.md
@@ -0,0 +1,49 @@
+# Coreference Resolution Reproduction
+## Introduction
+Coreference resolution is the task of finding all expressions in a text that
+refer to the same real-world entity. It is an important step for many
+higher-level NLP tasks that involve natural language understanding, such as
+document summarization, question answering and information extraction.
+The implementation is mainly based on
+[End-to-End Coreference Resolution (Lee et al., 2017)](https://arxiv.org/pdf/1707.07045).
+
+
+## Data acquisition and preprocessing
+The paper achieved the state-of-the-art result of its time on the
+[OntoNotes 5.0](https://allennlp.org/models) dataset.
+For copyright reasons we cannot offer a download of the dataset; please obtain
+it yourself. The raw dataset is in CoNLL format; see the official page of the
+dataset for a detailed description.
+
+The code uses the preprocessing of the paper's author (Lee); see
+[this script](https://github.com/kentonl/e2e-coref/blob/e2e/setup_training.sh)
+for the details. The preprocessed dataset is in json format, for example:
+```
+{
+  "clusters": [],
+  "doc_key": "nw",
+  "sentences": [["This", "is", "the", "first", "sentence", "."], ["This", "is", "the", "second", "."]],
+  "speakers": [["spk1", "spk1", "spk1", "spk1", "spk1", "spk1"], ["spk2", "spk2", "spk2", "spk2", "spk2"]]
+}
+```
+
+### Embedding downloads
+[Turian embeddings](https://lil.cs.washington.edu/coref/turian.50d.txt)
+
+[GloVe embeddings](https://nlp.stanford.edu/data/glove.840B.300d.zip)
+
+
+
+## Running
+```bash
+# training
+CUDA_VISIBLE_DEVICES=0 python train.py
+# evaluation
+CUDA_VISIBLE_DEVICES=0 python valid.py
+```
+
+## Results
+The original authors report an F1 of 67.2% on the test set; the AllenNLP
+reproduction reaches [63.0%](https://allennlp.org/models). AllenNLP trains
+without speaker information and without variational dropout, and uses only 100
+antecedents instead of 250.
+With the same hyperparameters and configuration as AllenNLP, this reproduction
+reaches an F1 of 63.6%.
+
+
+## Questions
+If you have any questions or feedback, please open an issue or contact me by
+email: yexu_i@qq.com
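+
+## Loading the data in code
+A minimal sketch (the paths are assumptions; point them at your own
+preprocessed files):
+```python
+from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+
+data_info = CRLoader().process({'train': 'data/train.english.jsonlines',
+                                'dev': 'data/dev.english.jsonlines',
+                                'test': 'data/test.english.jsonlines'})
+print(data_info.datasets['train'][0])
+```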
diff --git a/reproduction/coreference_resolution/test/__init__.py b/reproduction/coreference_resolution/test/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/reproduction/coreference_resolution/test/test_dataloader.py b/reproduction/coreference_resolution/test/test_dataloader.py
new file mode 100644
index 00000000..0d9dae52
--- /dev/null
+++ b/reproduction/coreference_resolution/test/test_dataloader.py
@@ -0,0 +1,14 @@
+import unittest
+from ..data_load.cr_loader import CRLoader
+
+class Test_CRLoader(unittest.TestCase):
+    def test_cr_loader(self):
+        train_path = 'data/train.english.jsonlines.mini'
+        dev_path = 'data/dev.english.jsonlines.minid'
+        test_path = 'data/test.english.jsonlines'
+        cr = CRLoader()
+        data_info = cr.process({'train': train_path, 'dev': dev_path, 'test': test_path})
+
+        print(data_info.datasets['train'][0])
+        print(data_info.datasets['dev'][0])
+        print(data_info.datasets['test'][0])
diff --git a/reproduction/coreference_resolution/train.py b/reproduction/coreference_resolution/train.py
new file mode 100644
index 00000000..a231a575
--- /dev/null
+++ b/reproduction/coreference_resolution/train.py
@@ -0,0 +1,69 @@
+import sys
+sys.path.append('../..')
+
+import torch
+from torch.optim import Adam
+
+from fastNLP.core.callback import Callback, GradientClipCallback
+from fastNLP.core.trainer import Trainer
+
+from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from reproduction.coreference_resolution.model.config import Config
+from reproduction.coreference_resolution.model.model_re import Model
+from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
+from reproduction.coreference_resolution.model.metric import CRMetric
+from fastNLP import SequentialSampler
+from fastNLP import cache_results
+
+
+# torch.backends.cudnn.benchmark = False
+# torch.backends.cudnn.deterministic = True
+
+class LRCallback(Callback):
+    def __init__(self, parameters, decay_rate=1e-3):
+        super().__init__()
+        self.paras = parameters
+        self.decay_rate = decay_rate
+
+    def on_step_end(self):
+        if self.step % 100 == 0:
+            for para in self.paras:
+                para['lr'] = para['lr'] * (1 - self.decay_rate)
+
+
+if __name__ == "__main__":
+    config = Config()
+
+    print(config)
+
+    @cache_results('cache.pkl')
+    def cache():
+        cr_train_dev_test = CRLoader()
+
+        data_info = cr_train_dev_test.process({'train': config.train_path, 'dev': config.dev_path,
+                                               'test': config.test_path})
+        return data_info
+    data_info = cache()
+    print("Dataset split:\ntrain:", str(len(data_info.datasets["train"])),
+          "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
+    # print(data_info)
+    model = Model(data_info.vocabs, config)
+    print(model)
+
+    loss = SoftmaxLoss()
+
+    metric = CRMetric()
+
+    optim = Adam(model.parameters(), lr=config.lr)
+
+    lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)
+
+    trainer = Trainer(model=model, train_data=data_info.datasets["train"], dev_data=data_info.datasets["dev"],
+                      loss=loss, metrics=metric, check_code_level=-1, sampler=None,
+                      batch_size=1, device=torch.device("cuda:" + config.cuda), metric_key='f', n_epochs=config.epoch,
+                      optimizer=optim,
+                      save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
+                      callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
+    print()
+
+    trainer.train()
diff --git a/reproduction/coreference_resolution/valid.py b/reproduction/coreference_resolution/valid.py
new file mode 100644
index 00000000..826332c6
--- /dev/null
+++ b/reproduction/coreference_resolution/valid.py
@@ -0,0 +1,24 @@
+import torch
+from reproduction.coreference_resolution.model.config import Config
+from reproduction.coreference_resolution.model.metric import CRMetric
+from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
+from fastNLP import Tester
+import argparse
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path')
+    args = parser.parse_args()
+
+    cr_loader = CRLoader()
+    config = Config()
+    data_info = cr_loader.process({'train': config.train_path, 'dev': config.dev_path,
+                                   'test': config.test_path})
+    metric = CRMetric()
+    model = torch.load(args.path)
+    tester = Tester(data_info.datasets['test'], model, metric, batch_size=1, device="cuda:0")
+    tester.test()
+    print('test over')
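+
+# Example invocation (the path is an assumption; pass the model file that
+# train.py saved under its save_path):
+#   CUDA_VISIBLE_DEVICES=0 python valid.py --path <path/to/saved/model>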