
Added the MWAN model, and slightly modified the matching dataloader

tags/v0.4.10
FFTYYY 5 years ago
commit 97c7ba313d
3 changed files with 622 additions and 2 deletions
  1. +22  -2   fastNLP/io/data_loader/matching.py
  2. +145 -0   reproduction/matching/matching_mwan.py
  3. +455 -0   reproduction/matching/model/mwan.py

+22 -2   fastNLP/io/data_loader/matching.py

@@ -1,6 +1,6 @@
import os

-from typing import Union, Dict
+from typing import Union, Dict, List

from ...core.const import Const
from ...core.vocabulary import Vocabulary
@@ -33,7 +33,8 @@ class MatchingLoader(DataSetLoader):
                to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                cut_text: int = None, get_index=True, auto_pad_length: int=None,
                auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
-               set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+               set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
+               extra_split: List[str]=['-'], ) -> DataInfo:
        """
        :param paths: str or Dict[str, str]. If a str, it is either the folder containing the data set or a full
            file path: for a folder, the data set names and file names are looked up in self.paths; for a Dict,
            it maps data set names (e.g. train, dev, test) and
@@ -56,6 +57,7 @@ class MatchingLoader(DataSetLoader):
        :param concat: whether to concatenate the two sentences. If False, they are not concatenated. If True, a
            <sep> token is inserted between the two sentences. If a list of length 4 is passed, its elements are
            the markers inserted before the first sentence, after the first sentence, before the second sentence,
            and after the second sentence. If the string ``bert`` is passed, BERT-style concatenation is used,
            equivalent to ['[CLS]', '[SEP]', '', '[SEP]'].
+        :param extra_split: extra separators, i.e. characters other than whitespace used to split words.
        :return:
        """
        if isinstance(set_input, str):
@@ -89,6 +91,24 @@ class MatchingLoader(DataSetLoader):
            if Const.TARGET in data_set.get_field_names():
                data_set.set_target(Const.TARGET)

+        if extra_split:
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+
+                for s in extra_split:
+                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(0))
+                    data_set.apply(lambda x: x[Const.INPUTS(1)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(1))
+
+                # re-split on spaces and drop the empty strings left by consecutive separators
+                data_set.apply(lambda x: [w for w in x[Const.INPUTS(0)].split(' ') if w],
+                               new_field_name=Const.INPUTS(0), is_input=auto_set_input)
+                data_set.apply(lambda x: [w for w in x[Const.INPUTS(1)].split(' ') if w],
+                               new_field_name=Const.INPUTS(1), is_input=auto_set_input)

        if to_lower:
            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
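For reference, a minimal standalone sketch of what extra_split does to an already-tokenized field (plain Python; apply_extra_split is an illustrative helper, not part of fastNLP):

def apply_extra_split(tokens, extra_split=['-']):
    # join the tokens back into one string, pad every extra separator with spaces,
    # then re-split and drop the empty strings produced by the doubled spaces
    text = ' '.join(tokens)
    for s in extra_split:
        text = text.replace(s, ' ' + s + ' ')
    return [w for w in text.split(' ') if w]

print(apply_extra_split(['state-of-the-art', 'model']))
# ['state', '-', 'of', '-', 'the', '-', 'art', 'model']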


+145 -0   reproduction/matching/matching_mwan.py

@@ -0,0 +1,145 @@
import sys

import os
import random

import numpy as np
import torch
from torch.optim import Adadelta, SGD
from torch.optim.lr_scheduler import StepLR

from tqdm import tqdm

from fastNLP import CrossEntropyLoss
from fastNLP import cache_results
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.core.predictor import Predictor
from fastNLP.core.callback import GradientClipCallback, LRScheduler, FitlogCallback
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding

from fastNLP.io.data_loader import MNLILoader, QNLILoader, QuoraLoader, SNLILoader, RTELoader
from model.mwan import MwanModel

import fitlog
fitlog.debug()

import argparse


argument = argparse.ArgumentParser()
argument.add_argument('--task', choices=['snli', 'rte', 'qnli', 'mnli'], default='snli')
argument.add_argument('--batch-size', type=int, default=128)
argument.add_argument('--n-epochs', type=int, default=50)
argument.add_argument('--lr', type=float, default=1)
argument.add_argument('--testset-name', type=str, default='test')
argument.add_argument('--devset-name', type=str, default='dev')
argument.add_argument('--seed', type=int, default=42)
argument.add_argument('--hidden-size', type=int, default=150)
argument.add_argument('--dropout', type=float, default=0.3)
arg = argument.parse_args()

random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
    torch.cuda.manual_seed_all(arg.seed)
print(n_gpu)

for k in arg.__dict__:
    print(k, arg.__dict__[k], type(arg.__dict__[k]))

# load data set
if arg.task == 'snli':
    @cache_results('snli_mwan.pkl')
    def read_snli():
        data_info = SNLILoader().process(
            paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
            get_index=True, concat=False, extra_split=['/', '%', '-'],
        )
        return data_info
    data_info = read_snli()
elif arg.task == 'rte':
    @cache_results('rte_mwan.pkl')
    def read_rte():
        data_info = RTELoader().process(
            paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
            get_index=True, concat=False, extra_split=['/', '%', '-'],
        )
        return data_info
    data_info = read_rte()
elif arg.task == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
        get_index=True, concat=False, cut_text=512, extra_split=['/', '%', '-'],
    )
elif arg.task == 'mnli':
    @cache_results('mnli_v0.9_mwan.pkl')
    def read_mnli():
        data_info = MNLILoader().process(
            paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
            get_index=True, concat=False, extra_split=['/', '%', '-'],
        )
        return data_info
    data_info = read_mnli()
else:
    raise RuntimeError(f'task {arg.task} is not supported yet!')

print(data_info)
print(len(data_info.vocabs['words']))


model = MwanModel(
    num_class=len(data_info.vocabs[Const.TARGET]),
    EmbLayer=StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False),
    ElmoLayer=None,
    args_of_imm={
        "input_size": 300,
        "hidden_size": arg.hidden_size,
        "dropout": arg.dropout,
        "use_allennlp": False,
    },
)

optimizer = Adadelta(lr=arg.lr, params=model.parameters())
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

callbacks = [
    LRScheduler(scheduler),
]

if arg.task in ['snli']:
    callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1))
elif arg.task == 'mnli':
    callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'],
                                     'dev_mismatched': data_info.datasets['dev_mismatched']},
                                    verbose=1))

trainer = Trainer(
    train_data=data_info.datasets['train'],
    model=model,
    optimizer=optimizer,
    num_workers=0,
    batch_size=arg.batch_size,
    n_epochs=arg.n_epochs,
    print_every=-1,
    dev_data=data_info.datasets[arg.devset_name],
    metrics=AccuracyMetric(pred="pred", target="target"),
    metric_key='acc',
    device=[i for i in range(torch.cuda.device_count())],
    check_code_level=-1,
    callbacks=callbacks,
    loss=CrossEntropyLoss(pred="pred", target="target"),
)
trainer.train(load_best_model=True)

tester = Tester(
    data=data_info.datasets[arg.testset_name],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=arg.batch_size,
    device=[i for i in range(torch.cuda.device_count())],
)
tester.test()
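For reference, a typical invocation of this script once the hard-coded dataset paths above point at real data (the flags simply restate the declared defaults):

# python matching_mwan.py --task snli --batch-size 128 --n-epochs 50 --lr 1 --hidden-size 150 --dropout 0.3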

+455 -0   reproduction/matching/model/mwan.py

@@ -0,0 +1,455 @@
import torch as tc
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import math
from fastNLP.core.const import Const

class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bidrect, dropout):
        super(RNNModel, self).__init__()

        if num_layers <= 1:
            dropout = 0.0  # nn.GRU only applies dropout between layers
        self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
                          batch_first=True, dropout=dropout, bidirectional=bidrect)

        self.number = (2 if bidrect else 1) * num_layers

    def forward(self, x, mask):
        '''
        x: (batch_size, seq_len, input_size)
        mask: (batch_size, seq_len)
        '''
        # sort by true length so the batch can be packed, then undo the sort afterwards
        lens = mask.long().sum(dim=1)
        lens, idx_sort = tc.sort(lens, descending=True)
        _, idx_unsort = tc.sort(idx_sort)

        x = x[idx_sort]

        x = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first=True)
        self.rnn.flatten_parameters()
        y, h = self.rnn(x)
        y, lens = nn.utils.rnn.pad_packed_sequence(y, batch_first=True)

        h = h.transpose(0, 1).contiguous()  # make batch the first dimension

        y = y[idx_unsort]  # (batch_size, seq_len, bid * hid_size)
        h = h[idx_unsort]  # (batch_size, number, hid_size)

        return y, h

class Contexualizer(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.3):
        super(Contexualizer, self).__init__()

        self.rnn = RNNModel(input_size, hidden_size, num_layers, True, dropout)
        self.output_size = hidden_size * 2

        self.reset_parameters()

    def reset_parameters(self):
        weights = self.rnn.rnn.all_weights
        for w1 in weights:
            for w2 in w1:
                if len(list(w2.size())) <= 1:
                    w2.data.fill_(0)
                else:
                    nn.init.xavier_normal_(w2.data, gain=1.414)

    def forward(self, s, mask):
        y = self.rnn(s, mask)[0]  # (batch_size, seq_len, 2 * hidden_size)

        return y

class ConcatAttention_Param(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(ConcatAttention_Param, self).__init__()
        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.vq = nn.Parameter(tc.rand(hidden_size))
        self.drop = nn.Dropout(dropout)

        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def forward(self, h, mask):
        '''
        h: (batch_size, len, input_size)
        mask: (batch_size, len)
        '''
        vq = self.vq.view(1, 1, -1).expand(h.size(0), h.size(1), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([h, vq], dim=-1)))).squeeze(-1)  # (batch_size, len)
        s = s - ((mask == 0).float() * 10000)  # push padded positions towards -inf before softmax
        a = tc.softmax(s, dim=1)

        r = a.unsqueeze(-1) * h  # (batch_size, len, input_size)
        r = tc.sum(r, dim=1)     # (batch_size, input_size)

        return self.drop(r)

def get_2dmask(mask_hq, mask_hp, siz=None):
    if siz is None:
        siz = (mask_hq.size(0), mask_hq.size(1), mask_hp.size(1))

    mask_mat = 1
    if mask_hq is not None:
        mask_mat = mask_mat * mask_hq.unsqueeze(2).expand(siz)
    if mask_hp is not None:
        mask_mat = mask_mat * mask_hp.unsqueeze(1).expand(siz)
    return mask_mat

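# Note on Attention below: my_method scores every (hq position j, hp position t) pair,
# giving s of shape (batch_size, len_q, len_p); masked pairs are pushed towards -inf so
# they vanish under softmax. The softmax runs over dim=1 (the hq positions), so each
# output q[:, t] is a mask-aware weighted sum of hq aligned to position t of hp.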
def Attention(hq, hp, mask_hq, mask_hp, my_method):
    standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
    mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])

    hq_mat = hq.unsqueeze(2).expand(standard_size)
    hp_mat = hp.unsqueeze(1).expand(standard_size)

    s = my_method(hq_mat, hp_mat)  # (batch_size, len_q, len_p)

    s = s - ((mask_mat == 0).float() * 10000)
    a = tc.softmax(s, dim=1)

    q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
    q = tc.sum(q, dim=1)          # (batch_size, len_p, input_size)

    return q

class ConcatAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
        super(ConcatAttention, self).__init__()

        if input_size_2 < 0:
            input_size_2 = input_size
        self.ln = nn.Linear(input_size + input_size_2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.drop = nn.Dropout(dropout)

        self.output_size = input_size

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = tc.cat([hq_mat, hp_mat], dim=-1)
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p)
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        '''
        hq: (batch_size, len_q, input_size)
        hp: (batch_size, len_p, input_size)
        mask_hq: (batch_size, len_q)
        mask_hp: (batch_size, len_p)
        '''
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))

class MinusAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(MinusAttention, self).__init__()
        self.ln = nn.Linear(input_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = hq_mat - hp_mat  # element-wise difference
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p)
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))

class DotProductAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(DotProductAttention, self).__init__()
        self.ln = nn.Linear(input_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq_mat, hp_mat):
        s = hq_mat * hp_mat  # element-wise product
        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p)
        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))

class BiLinearAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
        super(BiLinearAttention, self).__init__()

        input_size_2 = input_size if input_size_2 < 0 else input_size_2

        self.ln = nn.Linear(input_size_2, input_size)
        self.drop = nn.Dropout(dropout)
        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)

    def my_method(self, hq, hp, mask_p):
        # hq, hp: (batch_size, len, input_size)
        hp = self.ln(hp)
        hp = hp * mask_p.unsqueeze(-1)
        s = tc.matmul(hq, hp.transpose(-1, -2))

        return s

    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
        standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
        mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])

        s = self.my_method(hq, hp, mask_hp)  # (batch_size, len_q, len_p)

        s = s - ((mask_mat == 0).float() * 10000)
        a = tc.softmax(s, dim=1)

        hq_mat = hq.unsqueeze(2).expand(standard_size)
        q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
        q = tc.sum(q, dim=1)          # (batch_size, len_p, input_size)

        return self.drop(q)


class AggAttention(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.2):
        super(AggAttention, self).__init__()
        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)
        self.vq = nn.Parameter(tc.rand(hidden_size, 1))
        self.drop = nn.Dropout(dropout)

        self.output_size = input_size
        self.reset_parameters()

    def reset_parameters(self):
        # xavier init needs a 2-d tensor, so vq starts as (hidden_size, 1)
        # and is flattened to (hidden_size,) afterwards
        nn.init.xavier_uniform_(self.vq.data)
        nn.init.xavier_uniform_(self.v.weight.data)
        nn.init.xavier_uniform_(self.ln.weight.data)
        self.ln.bias.data.fill_(0)
        self.vq.data = self.vq.data[:, 0]

    def forward(self, hs, mask):
        '''
        hs: [(batch_size, len_q, input_size), ...]
        mask: (batch_size, len_q)
        '''
        hs = tc.cat([h.unsqueeze(0) for h in hs], dim=0)  # (4, batch_size, len_q, input_size)

        vq = self.vq.view(1, 1, 1, -1).expand(hs.size(0), hs.size(1), hs.size(2), self.vq.size(0))

        s = self.v(tc.tanh(self.ln(tc.cat([hs, vq], dim=-1)))).squeeze(-1)  # (4, batch_size, len_q)

        s = s - ((mask.unsqueeze(0) == 0).float() * 10000)
        a = tc.softmax(s, dim=0)

        x = a.unsqueeze(-1) * hs
        x = tc.sum(x, dim=0)  # (batch_size, len_q, input_size)

        return self.drop(x)

class Aggragator(nn.Module):
    def __init__(self, input_size, hidden_size, dropout=0.3):
        super(Aggragator, self).__init__()

        now_size = input_size
        self.ln = nn.Linear(2 * input_size, 2 * input_size)

        now_size = 2 * input_size
        self.rnn = Contexualizer(now_size, hidden_size, 2, dropout)

        now_size = self.rnn.output_size
        self.agg_att = AggAttention(now_size, now_size, dropout)

        now_size = self.agg_att.output_size
        self.agg_rnn = Contexualizer(now_size, hidden_size, 2, dropout)

        self.drop = nn.Dropout(dropout)

        self.output_size = self.agg_rnn.output_size

    def forward(self, qs, hp, mask):
        '''
        qs: [(batch_size, len_p, input_size), ...]
        hp: (batch_size, len_p, input_size)
        mask is the same as hp's mask
        '''
        hs = [0 for _ in range(len(qs))]

        for i in range(len(qs)):
            q = qs[i]
            x = tc.cat([q, hp], dim=-1)
            g = tc.sigmoid(self.ln(x))  # gated fusion: x_star = x * sigmoid(W x)
            x_star = x * g
            h = self.rnn(x_star, mask)

            hs[i] = h

        x = self.agg_att(hs, mask)  # (batch_size, len_p, output_size)
        h = self.agg_rnn(x, mask)   # (batch_size, len_p, output_size)
        return self.drop(h)


class Mwan_Imm(nn.Module):
    '''
    Inner MWAN network: encode both sentences with BiGRUs, match them with four
    attention variants (concat, bilinear, dot-product, minus), aggregate the matched
    representations with the gated Aggragator, then predict the class.
    '''
    def __init__(self, input_size, hidden_size, num_class=3, dropout=0.2, use_allennlp=False):
        super(Mwan_Imm, self).__init__()

        now_size = input_size
        self.enc_s1 = Contexualizer(now_size, hidden_size, 2, dropout)
        self.enc_s2 = Contexualizer(now_size, hidden_size, 2, dropout)

        now_size = self.enc_s1.output_size
        self.att_c = ConcatAttention(now_size, hidden_size, dropout)
        self.att_b = BiLinearAttention(now_size, hidden_size, dropout)
        self.att_d = DotProductAttention(now_size, hidden_size, dropout)
        self.att_m = MinusAttention(now_size, hidden_size, dropout)

        now_size = self.att_c.output_size
        self.agg = Aggragator(now_size, hidden_size, dropout)

        now_size = self.enc_s1.output_size
        self.pred_1 = ConcatAttention_Param(now_size, hidden_size, dropout)
        now_size = self.agg.output_size
        self.pred_2 = ConcatAttention(now_size, hidden_size, dropout,
                                      input_size_2=self.pred_1.output_size)

        now_size = self.pred_2.output_size
        self.ln1 = nn.Linear(now_size, hidden_size)
        self.ln2 = nn.Linear(hidden_size, num_class)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.ln1.weight.data)
        nn.init.xavier_uniform_(self.ln2.weight.data)
        self.ln1.bias.data.fill_(0)
        self.ln2.bias.data.fill_(0)

    def forward(self, s1, s2, mas_s1, mas_s2):
        hq = self.enc_s1(s1, mas_s1)  # (batch_size, len_q, output_size)
        hp = self.enc_s1(s2, mas_s2)  # NOTE: enc_s1 encodes both sentences, so enc_s2 above is effectively unused

        # pad_packed_sequence trims to the longest real length, so truncate the masks to match
        mas_s1 = mas_s1[:, :hq.size(1)]
        mas_s2 = mas_s2[:, :hp.size(1)]
        mas_q, mas_p = mas_s1, mas_s2

        qc = self.att_c(hq, hp, mas_s1, mas_s2)  # (batch_size, len_p, output_size)
        qb = self.att_b(hq, hp, mas_s1, mas_s2)
        qd = self.att_d(hq, hp, mas_s1, mas_s2)
        qm = self.att_m(hq, hp, mas_s1, mas_s2)

        ho = self.agg([qc, qb, qd, qm], hp, mas_s2)  # (batch_size, len_p, output_size)

        rq = self.pred_1(hq, mas_q)                   # (batch_size, output_size)
        rp = self.pred_2(ho, rq.unsqueeze(1), mas_p)  # (batch_size, 1, output_size)
        rp = rp.squeeze(1)                            # (batch_size, output_size)

        rp = F.relu(self.ln1(rp))
        rp = self.ln2(rp)

        return rp

class MwanModel(nn.Module):
    def __init__(self, num_class, EmbLayer, args_of_imm={}, ElmoLayer=None):
        super(MwanModel, self).__init__()

        self.emb = EmbLayer

        if ElmoLayer is not None:
            self.elmo = ElmoLayer
            self.elmo_preln = nn.Linear(3 * self.elmo.emb_size, self.elmo.emb_size)
            self.elmo_ln = nn.Linear(args_of_imm["input_size"] +
                                     self.elmo.emb_size, args_of_imm["input_size"])
        else:
            self.elmo = None

        self.imm = Mwan_Imm(num_class=num_class, **args_of_imm)
        self.drop = nn.Dropout(args_of_imm["dropout"])

    def forward(self, words1, words2, str_s1=None, str_s2=None, *pargs, **kwargs):
        '''
        str_s1 / str_s2 are inputs for the ELMo layer; they are ignored when ElmoLayer is None
        str_s: (batch_size, seq_len, word_len)
        '''
        s1, s2 = words1, words2

        mas_s1 = (s1 != 0).float()  # mask: (batch_size, seq_len)
        mas_s2 = (s2 != 0).float()  # mask: (batch_size, seq_len)

        mas_s1.requires_grad = False
        mas_s2.requires_grad = False

        s1_emb = self.emb(s1)
        s2_emb = self.emb(s2)

        if self.elmo is not None:
            s1_elmo = self.elmo(str_s1)
            s2_elmo = self.elmo(str_s2)

            s1_elmo = tc.tanh(self.elmo_preln(tc.cat(s1_elmo, dim=-1)))
            s2_elmo = tc.tanh(self.elmo_preln(tc.cat(s2_elmo, dim=-1)))

            s1_emb = tc.cat([s1_emb, s1_elmo], dim=-1)
            s2_emb = tc.cat([s2_emb, s2_elmo], dim=-1)

            s1_emb = tc.tanh(self.elmo_ln(s1_emb))
            s2_emb = tc.tanh(self.elmo_ln(s2_emb))

        s1_emb = self.drop(s1_emb)
        s2_emb = self.drop(s2_emb)

        y = self.imm(s1_emb, s2_emb, mas_s1, mas_s2)

        return {
            Const.OUTPUT: y,
        }
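For a quick sanity check of the model's input/output shapes, a minimal smoke test (illustrative only: a plain nn.Embedding stands in for fastNLP's StaticEmbedding, and all sizes are arbitrary):

if __name__ == '__main__':
    # assumption: nn.Embedding as a stand-in for StaticEmbedding; vocab of 100, dim 300
    emb = nn.Embedding(100, 300, padding_idx=0)
    model = MwanModel(num_class=3, EmbLayer=emb, ElmoLayer=None,
                      args_of_imm={'input_size': 300, 'hidden_size': 150,
                                   'dropout': 0.3, 'use_allennlp': False})
    w1 = tc.randint(1, 100, (2, 7))  # premise token ids; 0 is reserved for padding
    w2 = tc.randint(1, 100, (2, 5))  # hypothesis token ids
    out = model(w1, w2)
    print(out[Const.OUTPUT].shape)   # expected: torch.Size([2, 3])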
