
Add cntn model for matching.

tags/v0.4.10
lionel, 6 years ago
commit 2a1d5dc2a4
2 changed files with 225 additions and 0 deletions:
  1. reproduction/matching/matching_cntn.py   +105  -0
  2. reproduction/matching/model/cntn.py      +120  -0

reproduction/matching/matching_cntn.py  (+105, -0)

@@ -0,0 +1,105 @@
import argparse
import torch
import os

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.modules.encoder.embedding import StaticEmbedding

from reproduction.matching.data.MatchingDataLoader import QNLILoader, RTELoader, SNLILoader, MNLILoader
from reproduction.matching.model.cntn import CNTNModel

# define hyper-parameters
argument = argparse.ArgumentParser()
argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glove')
argument.add_argument('--batch-size-per-gpu', type=int, default=256)
argument.add_argument('--n-epochs', type=int, default=200)
argument.add_argument('--lr', type=float, default=1e-5)
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='mask')
argument.add_argument('--save-dir', type=str, default=None)
argument.add_argument('--cntn-depth', type=int, default=1)
argument.add_argument('--cntn-ns', type=int, default=200)
argument.add_argument('--cntn-k-top', type=int, default=10)
argument.add_argument('--cntn-r', type=int, default=5)
argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli')
argument.add_argument('--max-len', type=int, default=50)
arg = argument.parse_args()

# dataset dict
dev_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'dev',
    'mnli': 'dev_matched',
}

test_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'test',
    'mnli': 'dev_matched',
}

# set num_labels
if arg.dataset == 'qnli' or arg.dataset == 'rte':
    num_labels = 2
else:
    num_labels = 3

# load data set
if arg.dataset == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
elif arg.dataset == 'rte':
    data_info = RTELoader().process(
        paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
elif arg.dataset == 'snli':
    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
elif arg.dataset == 'mnli':
    data_info = MNLILoader().process(
        paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
else:
    raise ValueError('now we only support the [qnli, rte, snli, mnli] datasets for the cntn model!')

# load embedding
if arg.embedding == 'word2vec':
    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-word2vec-300',
                                requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-glove-840b-300',
                                requires_grad=True)
else:
    raise ValueError('now we only support word2vec or glove embeddings for the cntn model!')

# define model
model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=num_labels, depth=arg.cntn_depth,
                  r=arg.cntn_r)
print(model)

# define trainer
trainer = Trainer(train_data=data_info.datasets['train'], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets[dev_dict[arg.dataset]],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets[test_dict[arg.dataset]],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
    device=[i for i in range(torch.cuda.device_count())]
)

# test model
tester.test()
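
Note: the batch_size passed to the Trainer and Tester above is torch.cuda.device_count() * arg.batch_size_per_gpu, which evaluates to 0 on a machine with no CUDA device. A minimal sketch of one way to guard against that is shown below; it is not part of the commit, the variable names are illustrative, and it assumes that passing device=None leaves fastNLP running on the CPU.

# Illustrative guard, not part of the committed script.
n_gpu = torch.cuda.device_count()
batch_size = max(1, n_gpu) * arg.batch_size_per_gpu   # avoid batch_size == 0 on CPU-only machines
device = list(range(n_gpu)) if n_gpu > 0 else None    # assumption: device=None keeps fastNLP on the CPU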

reproduction/matching/model/cntn.py  (+120, -0)

@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.nn import CrossEntropyLoss

from fastNLP.models import BaseModel
from fastNLP.modules.encoder.embedding import TokenEmbedding
from fastNLP.core.const import Const


class DynamicKMaxPooling(nn.Module):
    """
    Dynamic k-max pooling: the pooling size for layer l is
    k_l = max(k_top, ceil((L - l) / L * s)), where s is the input length.

    :param k_top: Fixed number of pooling output features for the topmost convolutional layer.
    :param l: Number of convolutional layers.
    """

    def __init__(self, k_top, l):
        super(DynamicKMaxPooling, self).__init__()
        self.k_top = k_top
        self.L = l

    def forward(self, x, l):
        """
        :param x: Input sequence.
        :param l: Index of the current convolutional layer.
        """
        s = x.size()[3]
        k_ll = ((self.L - l) / self.L) * s
        k_l = int(round(max(self.k_top, np.ceil(k_ll))))
        out = F.adaptive_max_pool2d(x, (x.size()[2], k_l))
        return out
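
# Worked example of the dynamic pooling size (illustrative comment, not part of the commit):
# with L = 3 layers, k_top = 10 and an input length s = 40,
#   layer 1: k = max(10, ceil(2/3 * 40)) = 27
#   layer 2: k = max(10, ceil(1/3 * 40)) = 14
#   layer 3: k = k_top = 10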


class CNTNModel(BaseModel):
    """
    A CNN-based model for question-answer matching.
    'Qiu, Xipeng, and Xuanjing Huang.
    Convolutional neural tensor network architecture for community-based question answering.
    Twenty-Fourth International Joint Conference on Artificial Intelligence. 2015.'

    :param init_embedding: Embedding.
    :param ns: Sentence embedding size.
    :param k_top: Fixed number of pooling output features for the topmost convolutional layer.
    :param num_labels: Number of labels.
    :param depth: Number of convolutional layers.
    :param r: Number of weight tensor slices.
    :param dropout_rate: Dropout rate.
    """

    def __init__(self, init_embedding: TokenEmbedding, ns=200, k_top=10, num_labels=2, depth=2, r=5,
                 dropout_rate=0.3):
        super(CNTNModel, self).__init__()
        self.embedding = init_embedding
        self.depth = depth
        self.kmaxpooling = DynamicKMaxPooling(k_top, depth)
        self.conv_q = nn.ModuleList()
        self.conv_a = nn.ModuleList()
        width = self.embedding.embed_size
        for i in range(depth):
            self.conv_q.append(nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Conv2d(
                    in_channels=1,
                    out_channels=width // 2,
                    kernel_size=(width, 3),
                    padding=(0, 2))
            ))
            self.conv_a.append(nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Conv2d(
                    in_channels=1,
                    out_channels=width // 2,
                    kernel_size=(width, 3),
                    padding=(0, 2))
            ))
            width = width // 2

        self.fc_q = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns))
        self.fc_a = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns))
        self.weight_M = nn.Bilinear(ns, ns, r)
        self.weight_V = nn.Linear(2 * ns, r)
        self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels))

    def forward(self, words1, words2, seq_len1, seq_len2, target=None):
        """
        :param words1: [batch, seq_len] Question token indices.
        :param words2: [batch, seq_len] Answer token indices.
        :param seq_len1: [batch]
        :param seq_len2: [batch]
        :param target: [batch] Gold labels.
        :return:
        """
        in_q = self.embedding(words1)
        in_a = self.embedding(words2)
        in_q = in_q.permute(0, 2, 1).unsqueeze(1)
        in_a = in_a.permute(0, 2, 1).unsqueeze(1)

        for i in range(self.depth):
            in_q = F.relu(self.conv_q[i](in_q))
            in_q = in_q.squeeze().unsqueeze(1)
            in_q = self.kmaxpooling(in_q, i + 1)
            in_a = F.relu(self.conv_a[i](in_a))
            in_a = in_a.squeeze().unsqueeze(1)
            in_a = self.kmaxpooling(in_a, i + 1)

        in_q = self.fc_q(in_q.view(in_q.size(0), -1))
        in_a = self.fc_a(in_a.view(in_a.size(0), -1))  # fc_a (not fc_q) projects the answer branch
        score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1))))

        if target is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(score, target)
            return {Const.LOSS: loss, Const.OUTPUT: score}
        else:
            return {Const.OUTPUT: score}

    def predict(self, **kwargs):
        return self.forward(**kwargs)
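
A quick standalone shape check for the new model is sketched below. It is not part of the commit: the toy vocabulary and embedding size are made up, and it assumes that StaticEmbedding accepts model_dir_or_name=None together with an explicit embedding_dim to randomly initialise vectors (as in later fastNLP releases).

# Illustrative shape check, NOT part of the commit; see the assumptions above.
import torch
from fastNLP import Vocabulary
from fastNLP.core.const import Const
from fastNLP.modules.encoder.embedding import StaticEmbedding
from reproduction.matching.model.cntn import CNTNModel

vocab = Vocabulary()
vocab.add_word_lst('a tiny toy vocabulary for the shape check'.split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)  # assumed signature
model = CNTNModel(embed, ns=200, k_top=10, num_labels=2, depth=1, r=5)

words = torch.randint(2, len(vocab), (4, 20))      # [batch=4, seq_len=20] token indices
seq_len = torch.full((4,), 20, dtype=torch.long)   # [batch]
out = model(words, words, seq_len, seq_len)        # no target, so only Const.OUTPUT is returned
print(out[Const.OUTPUT].shape)                     # expected: torch.Size([4, 2])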
