From 2a1d5dc2a422b0066851d104b666adf0bcad00cc Mon Sep 17 00:00:00 2001 From: lionel <13307130318@fudan.edu.cn> Date: Sun, 7 Jul 2019 15:24:48 +0800 Subject: [PATCH] Add cntn model for matching. --- reproduction/matching/matching_cntn.py | 105 ++++++++++++++++++++++ reproduction/matching/model/cntn.py | 120 +++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 reproduction/matching/matching_cntn.py create mode 100644 reproduction/matching/model/cntn.py diff --git a/reproduction/matching/matching_cntn.py b/reproduction/matching/matching_cntn.py new file mode 100644 index 00000000..d813164d --- /dev/null +++ b/reproduction/matching/matching_cntn.py @@ -0,0 +1,105 @@ +import argparse +import torch +import os + +from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const +from fastNLP.modules.encoder.embedding import StaticEmbedding + +from reproduction.matching.data.MatchingDataLoader import QNLILoader, RTELoader, SNLILoader, MNLILoader +from reproduction.matching.model.cntn import CNTNModel + +# define hyper-parameters +argument = argparse.ArgumentParser() +argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glove') +argument.add_argument('--batch-size-per-gpu', type=int, default=256) +argument.add_argument('--n-epochs', type=int, default=200) +argument.add_argument('--lr', type=float, default=1e-5) +argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='mask') +argument.add_argument('--save-dir', type=str, default=None) +argument.add_argument('--cntn-depth', type=int, default=1) +argument.add_argument('--cntn-ns', type=int, default=200) +argument.add_argument('--cntn-k-top', type=int, default=10) +argument.add_argument('--cntn-r', type=int, default=5) +argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli') +argument.add_argument('--max-len', type=int, default=50) +arg = argument.parse_args() + +# dataset dict +dev_dict = { + 'qnli': 'dev', + 'rte': 'dev', + 'snli': 'dev', + 'mnli': 'dev_matched', +} + +test_dict = { + 'qnli': 'dev', + 'rte': 'dev', + 'snli': 'test', + 'mnli': 'dev_matched', +} + +# set num_labels +if arg.dataset == 'qnli' or arg.dataset == 'rte': + num_labels = 2 +else: + num_labels = 3 + +# load data set +if arg.dataset == 'qnli': + data_info = QNLILoader().process( + paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, + get_index=True, concat=False, auto_pad_length=arg.max_len) +elif arg.dataset == 'rte': + data_info = RTELoader().process( + paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, + get_index=True, concat=False, auto_pad_length=arg.max_len) +elif arg.dataset == 'snli': + data_info = SNLILoader().process( + paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, + get_index=True, concat=False, auto_pad_length=arg.max_len) +elif arg.dataset == 'mnli': + data_info = MNLILoader().process( + paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, + get_index=True, concat=False, auto_pad_length=arg.max_len) +else: + raise ValueError(f'now we only support [qnli,rte,snli,mnli] dataset for cntn model!') + +# load embedding +if arg.embedding == 'word2vec': + embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-word2vec-300', requires_grad=True) +elif arg.embedding == 'glove': + embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-glove-840b-300', + requires_grad=True) +else: + raise ValueError(f'now we only support word2vec or glove embedding for cntn model!') + +# define model +model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=num_labels, depth=arg.cntn_depth, + r=arg.cntn_r) +print(model) + +# define trainer +trainer = Trainer(train_data=data_info.datasets['train'], model=model, + optimizer=Adam(lr=arg.lr, model_params=model.parameters()), + batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, + n_epochs=arg.n_epochs, print_every=-1, + dev_data=data_info.datasets[dev_dict[arg.dataset]], + metrics=AccuracyMetric(), metric_key='acc', + device=[i for i in range(torch.cuda.device_count())], + check_code_level=-1) + +# train model +trainer.train(load_best_model=True) + +# define tester +tester = Tester( + data=data_info.datasets[test_dict[arg.dataset]], + model=model, + metrics=AccuracyMetric(), + batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, + device=[i for i in range(torch.cuda.device_count())] +) + +# test model +tester.test() diff --git a/reproduction/matching/model/cntn.py b/reproduction/matching/model/cntn.py new file mode 100644 index 00000000..0b4803fa --- /dev/null +++ b/reproduction/matching/model/cntn.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from torch.nn import CrossEntropyLoss + +from fastNLP.models import BaseModel +from fastNLP.modules.encoder.embedding import TokenEmbedding +from fastNLP.core.const import Const + + +class DynamicKMaxPooling(nn.Module): + """ + :param k_top: Fixed number of pooling output features for the topmost convolutional layer. + :param l: Number of convolutional layers. + """ + + def __init__(self, k_top, l): + super(DynamicKMaxPooling, self).__init__() + self.k_top = k_top + self.L = l + + def forward(self, x, l): + """ + :param x: Input sequence. + :param l: Current convolutional layers. + """ + s = x.size()[3] + k_ll = ((self.L - l) / self.L) * s + k_l = int(round(max(self.k_top, np.ceil(k_ll)))) + out = F.adaptive_max_pool2d(x, (x.size()[2], k_l)) + return out + + +class CNTNModel(BaseModel): + """ + 使用CNN进行问答匹配的模型 + 'Qiu, Xipeng, and Xuanjing Huang. + Convolutional neural tensor network architecture for community-based question answering. + Twenty-Fourth International Joint Conference on Artificial Intelligence. 2015.' + + :param init_embedding: Embedding. + :param ns: Sentence embedding size. + :param k_top: Fixed number of pooling output features for the topmost convolutional layer. + :param num_labels: Number of labels. + :param depth: Number of convolutional layers. + :param r: Number of weight tensor slices. + :param drop_rate: Dropout rate. + """ + + def __init__(self, init_embedding: TokenEmbedding, ns=200, k_top=10, num_labels=2, depth=2, r=5, + dropout_rate=0.3): + super(CNTNModel, self).__init__() + self.embedding = init_embedding + self.depth = depth + self.kmaxpooling = DynamicKMaxPooling(k_top, depth) + self.conv_q = nn.ModuleList() + self.conv_a = nn.ModuleList() + width = self.embedding.embed_size + for i in range(depth): + self.conv_q.append(nn.Sequential( + nn.Dropout(p=dropout_rate), + nn.Conv2d( + in_channels=1, + out_channels=width // 2, + kernel_size=(width, 3), + padding=(0, 2)) + )) + self.conv_a.append(nn.Sequential( + nn.Dropout(p=dropout_rate), + nn.Conv2d( + in_channels=1, + out_channels=width // 2, + kernel_size=(width, 3), + padding=(0, 2)) + )) + width = width // 2 + + self.fc_q = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns)) + self.fc_a = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns)) + self.weight_M = nn.Bilinear(ns, ns, r) + self.weight_V = nn.Linear(2 * ns, r) + self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels)) + + def forward(self, words1, words2, seq_len1, seq_len2, target=None): + """ + :param words1: [batch, seq_len, emb_size] Question. + :param words2: [batch, seq_len, emb_size] Answer. + :param seq_len1: [batch] + :param seq_len2: [batch] + :param target: [batch] Glod labels. + :return: + """ + in_q = self.embedding(words1) + in_a = self.embedding(words2) + in_q = in_q.permute(0, 2, 1).unsqueeze(1) + in_a = in_a.permute(0, 2, 1).unsqueeze(1) + + for i in range(self.depth): + in_q = F.relu(self.conv_q[i](in_q)) + in_q = in_q.squeeze().unsqueeze(1) + in_q = self.kmaxpooling(in_q, i + 1) + in_a = F.relu(self.conv_a[i](in_a)) + in_a = in_a.squeeze().unsqueeze(1) + in_a = self.kmaxpooling(in_a, i + 1) + + in_q = self.fc_q(in_q.view(in_q.size(0), -1)) + in_a = self.fc_q(in_a.view(in_a.size(0), -1)) + score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1)))) + + if target is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(score, target) + return {Const.LOSS: loss, Const.OUTPUT: score} + else: + return {Const.OUTPUT: score} + + def predict(self, **kwargs): + return self.forward(**kwargs)