diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py
index 771f2748..ce9c280b 100644
--- a/fastNLP/io/data_loader/matching.py
+++ b/fastNLP/io/data_loader/matching.py
@@ -1,6 +1,6 @@
 import os
 
-from typing import Union, Dict
+from typing import Union, Dict, List
 
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
@@ -33,7 +33,8 @@ class MatchingLoader(DataSetLoader):
                 to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, cut_text: int = None,
                 get_index=True, auto_pad_length: int=None, auto_pad_token: str='', set_input: Union[list, str, bool]=True,
-                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
+                set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None,
+                extra_split: List[str]=['-'], ) -> DataInfo:
         """
         :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
             则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
@@ -56,6 +57,7 @@ class MatchingLoader(DataSetLoader):
         :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个。
             如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果
             传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]'].
+        :param extra_split: 额外的分隔符,即除了空格之外的用于分词的字符。
         :return:
         """
         if isinstance(set_input, str):
@@ -89,6 +91,22 @@ class MatchingLoader(DataSetLoader):
             if Const.TARGET in data_set.get_field_names():
                 data_set.set_target(Const.TARGET)
 
+        if extra_split:
+            for data_name, data_set in data_info.datasets.items():
+                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
+                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
+
+                for s in extra_split:
+                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(0))
+                    data_set.apply(lambda x: x[Const.INPUTS(1)].replace(s, ' ' + s + ' '),
+                                   new_field_name=Const.INPUTS(1))
+
+                data_set.apply(lambda x: list(filter(None, x[Const.INPUTS(0)].split(' '))),
+                               new_field_name=Const.INPUTS(0), is_input=auto_set_input)
+                data_set.apply(lambda x: list(filter(None, x[Const.INPUTS(1)].split(' '))),
+                               new_field_name=Const.INPUTS(1), is_input=auto_set_input)
+
         if to_lower:
             for data_name, data_set in data_info.datasets.items():
                 data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
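Aside (not part of the patch): the standalone sketch below mirrors what the new extra_split option does to one already-tokenized field, as implemented above: join the tokens back into a string, pad every extra separator with spaces so it becomes its own token, then re-split and drop empty pieces. The helper name is made up for the example.

    def apply_extra_split(tokens, extra_split=('-',)):
        text = ' '.join(tokens)
        for s in extra_split:
            # pad the separator with spaces so it splits off as its own token
            text = text.replace(s, ' ' + s + ' ')
        return [w for w in text.split(' ') if w]

    print(apply_extra_split(['a', 'well-known', 'example']))
    # ['a', 'well', '-', 'known', 'example']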
diff --git a/reproduction/matching/matching_mwan.py b/reproduction/matching/matching_mwan.py
new file mode 100644
index 00000000..d2d3033f
--- /dev/null
+++ b/reproduction/matching/matching_mwan.py
@@ -0,0 +1,145 @@
+import sys
+
+import os
+import random
+
+import numpy as np
+import torch
+from torch.optim import Adadelta, SGD
+from torch.optim.lr_scheduler import StepLR
+
+from tqdm import tqdm
+
+from fastNLP import CrossEntropyLoss
+from fastNLP import cache_results
+from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
+from fastNLP.core.predictor import Predictor
+from fastNLP.core.callback import GradientClipCallback, LRScheduler, FitlogCallback
+from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding
+
+from fastNLP.io.data_loader import MNLILoader, QNLILoader, QuoraLoader, SNLILoader, RTELoader
+from model.mwan import MwanModel
+
+import fitlog
+fitlog.debug()
+
+import argparse
+
+
+argument = argparse.ArgumentParser()
+argument.add_argument('--task', choices=['snli', 'rte', 'qnli', 'mnli'], default='snli')
+argument.add_argument('--batch-size', type=int, default=128)
+argument.add_argument('--n-epochs', type=int, default=50)
+argument.add_argument('--lr', type=float, default=1)
+argument.add_argument('--testset-name', type=str, default='test')
+argument.add_argument('--devset-name', type=str, default='dev')
+argument.add_argument('--seed', type=int, default=42)
+argument.add_argument('--hidden-size', type=int, default=150)
+argument.add_argument('--dropout', type=float, default=0.3)
+arg = argument.parse_args()
+
+random.seed(arg.seed)
+np.random.seed(arg.seed)
+torch.manual_seed(arg.seed)
+
+n_gpu = torch.cuda.device_count()
+if n_gpu > 0:
+    torch.cuda.manual_seed_all(arg.seed)
+print(n_gpu)
+
+for k in arg.__dict__:
+    print(k, arg.__dict__[k], type(arg.__dict__[k]))
+
+# load data set
+if arg.task == 'snli':
+    @cache_results('snli_mwan.pkl')
+    def read_snli():
+        data_info = SNLILoader().process(
+            paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
+            get_index=True, concat=False, extra_split=['/', '%', '-'],
+        )
+        return data_info
+    data_info = read_snli()
+elif arg.task == 'rte':
+    @cache_results('rte_mwan.pkl')
+    def read_rte():
+        data_info = RTELoader().process(
+            paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
+            get_index=True, concat=False, extra_split=['/', '%', '-'],
+        )
+        return data_info
+    data_info = read_rte()
+elif arg.task == 'qnli':
+    data_info = QNLILoader().process(
+        paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
+        get_index=True, concat=False, cut_text=512, extra_split=['/', '%', '-'],
+    )
+elif arg.task == 'mnli':
+    @cache_results('mnli_v0.9_mwan.pkl')
+    def read_mnli():
+        data_info = MNLILoader().process(
+            paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
+            get_index=True, concat=False, extra_split=['/', '%', '-'],
+        )
+        return data_info
+    data_info = read_mnli()
+else:
+    raise RuntimeError(f'NOT support {arg.task} task yet!')
+
+print(data_info)
+print(len(data_info.vocabs['words']))
+
+
+model = MwanModel(
+    num_class=len(data_info.vocabs[Const.TARGET]),
+    EmbLayer=StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False),
+    ElmoLayer=None,
+    args_of_imm={
+        "input_size": 300,
+        "hidden_size": arg.hidden_size,
+        "dropout": arg.dropout,
+        "use_allennlp": False,
+    },
+)
+
+
+optimizer = Adadelta(lr=arg.lr, params=model.parameters())
+scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
+
+callbacks = [
+    LRScheduler(scheduler),
+]
+
+if arg.task in ['snli']:
+    callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1))
+elif arg.task == 'mnli':
+    callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'],
+                                     'dev_mismatched': data_info.datasets['dev_mismatched']},
+                                    verbose=1))
+
+trainer = Trainer(
+    train_data=data_info.datasets['train'],
+    model=model,
+    optimizer=optimizer,
+    num_workers=0,
+    batch_size=arg.batch_size,
+    n_epochs=arg.n_epochs,
+    print_every=-1,
+    dev_data=data_info.datasets[arg.devset_name],
+    metrics=AccuracyMetric(pred="pred", target="target"),
+    metric_key='acc',
+    device=[i for i in range(torch.cuda.device_count())],
+    check_code_level=-1,
+    callbacks=callbacks,
+    loss=CrossEntropyLoss(pred="pred", target="target")
+)
+trainer.train(load_best_model=True)
+
+tester = Tester(
+    data=data_info.datasets[arg.testset_name],
+    model=model,
+    metrics=AccuracyMetric(),
+    batch_size=arg.batch_size,
+    device=[i for i in range(torch.cuda.device_count())],
+)
+tester.test()
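Aside (not part of the patch): for reference before the model code, the four matching functions implemented in model/mwan.py below can be written as follows (notation ours, reconstructed from the code; biases, dropout and masking omitted). Each one scores premise position j against hypothesis position t, and the premise states are then pooled with a softmax over j:

    s^c_{jt} = v_c^\top \tanh\big(W_c\,[h^q_j ; h^p_t]\big)        % ConcatAttention
    s^b_{jt} = (h^q_j)^\top W_b\, h^p_t                            % BiLinearAttention
    s^d_{jt} = v_d^\top \tanh\big(W_d\,(h^q_j \odot h^p_t)\big)    % DotProductAttention
    s^m_{jt} = v_m^\top \tanh\big(W_m\,(h^q_j - h^p_t)\big)        % MinusAttention

    q^x_t = \sum_j \operatorname{softmax}_j\big(s^x_{jt}\big)\, h^q_j, \qquad x \in \{c, b, d, m\}

The Aggragator below then gates the concatenation x_t = [q^x_t ; h^p_t] elementwise with \sigma(W_g x_t + b_g), runs it through a BiGRU, and AggAttention mixes the four resulting views with a learned query vector.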
diff --git a/reproduction/matching/model/mwan.py b/reproduction/matching/model/mwan.py
new file mode 100644
index 00000000..7ca6df3b
--- /dev/null
+++ b/reproduction/matching/model/mwan.py
@@ -0,0 +1,455 @@
+import torch as tc
+import torch.nn as nn
+import torch.nn.functional as F
+import sys
+import os
+import math
+from fastNLP.core.const import Const
+
+
+class RNNModel(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers, bidrect, dropout):
+        super(RNNModel, self).__init__()
+
+        if num_layers <= 1:
+            dropout = 0.0
+
+        self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers,
+                          batch_first=True, dropout=dropout, bidirectional=bidrect)
+
+        self.number = (2 if bidrect else 1) * num_layers
+
+    def forward(self, x, mask):
+        '''
+        mask: (batch_size, seq_len)
+        x: (batch_size, seq_len, input_size)
+        '''
+        lens = (mask).long().sum(dim=1)
+        lens, idx_sort = tc.sort(lens, descending=True)
+        _, idx_unsort = tc.sort(idx_sort)
+
+        x = x[idx_sort]
+
+        x = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first=True)
+        self.rnn.flatten_parameters()
+        y, h = self.rnn(x)
+        y, lens = nn.utils.rnn.pad_packed_sequence(y, batch_first=True)
+
+        h = h.transpose(0, 1).contiguous()  # make batch size first
+
+        y = y[idx_unsort]  # (batch_size, seq_len, bid * hid_size)
+        h = h[idx_unsort]  # (batch_size, number, hid_size)
+
+        return y, h
+
+
+class Contexualizer(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.3):
+        super(Contexualizer, self).__init__()
+
+        self.rnn = RNNModel(input_size, hidden_size, num_layers, True, dropout)
+        self.output_size = hidden_size * 2
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        weights = self.rnn.rnn.all_weights
+        for w1 in weights:
+            for w2 in w1:
+                if len(list(w2.size())) <= 1:
+                    w2.data.fill_(0)
+                else: nn.init.xavier_normal_(w2.data, gain=1.414)
+
+    def forward(self, s, mask):
+        y = self.rnn(s, mask)[0]  # (batch_size, seq_len, 2 * hidden_size)
+
+        return y
+
+
+class ConcatAttention_Param(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout=0.2):
+        super(ConcatAttention_Param, self).__init__()
+        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
+        self.v = nn.Linear(hidden_size, 1, bias=False)
+        self.vq = nn.Parameter(tc.rand(hidden_size))
+        self.drop = nn.Dropout(dropout)
+
+        self.output_size = input_size
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.v.weight.data)
+        nn.init.xavier_uniform_(self.ln.weight.data)
+        self.ln.bias.data.fill_(0)
+
+    def forward(self, h, mask):
+        '''
+        h: (batch_size, len, input_size)
+        mask: (batch_size, len)
+        '''
+        vq = self.vq.view(1, 1, -1).expand(h.size(0), h.size(1), self.vq.size(0))
+
+        s = self.v(tc.tanh(self.ln(tc.cat([h, vq], -1)))).squeeze(-1)  # (batch_size, len)
+
+        s = s - ((mask == 0).float() * 10000)
+        a = tc.softmax(s, dim=1)
+
+        r = a.unsqueeze(-1) * h  # (batch_size, len, input_size)
+        r = tc.sum(r, dim=1)     # (batch_size, input_size)
+
+        return self.drop(r)
+
+
+def get_2dmask(mask_hq, mask_hp, siz=None):
+    if siz is None:
+        siz = (mask_hq.size(0), mask_hq.size(1), mask_hp.size(1))
+
+    mask_mat = 1
+    if mask_hq is not None:
+        mask_mat = mask_mat * mask_hq.unsqueeze(2).expand(siz)
+    if mask_hp is not None:
+        mask_mat = mask_mat * mask_hp.unsqueeze(1).expand(siz)
+    return mask_mat
+
+def Attention(hq, hp, mask_hq, mask_hp, my_method):
+    standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
+    mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])
+
+    hq_mat = hq.unsqueeze(2).expand(standard_size)
+    hp_mat = hp.unsqueeze(1).expand(standard_size)
+
+    s = my_method(hq_mat, hp_mat)  # (batch_size, len_q, len_p)
+
+    s = s - ((mask_mat == 0).float() * 10000)
+    a = tc.softmax(s, dim=1)
+
+    q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
+    q = tc.sum(q, dim=1)          # (batch_size, len_p, input_size)
+
+    return q
+
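Aside (not part of the patch): Attention above, like the other modules in this file, masks padded positions by subtracting a large constant from their scores before the softmax, a common stand-in for setting them to minus infinity. A minimal sketch of the idiom:

    import torch as tc

    s = tc.tensor([2.0, 1.0, 3.0])     # raw scores
    mask = tc.tensor([1.0, 1.0, 0.0])  # third position is padding
    a = tc.softmax(s - (mask == 0).float() * 10000, dim=0)
    print(a)  # third weight is ~0; the first two renormalize among themselves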
+class ConcatAttention(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
+        super(ConcatAttention, self).__init__()
+
+        if input_size_2 < 0:
+            input_size_2 = input_size
+        self.ln = nn.Linear(input_size + input_size_2, hidden_size)
+        self.v = nn.Linear(hidden_size, 1, bias=False)
+        self.drop = nn.Dropout(dropout)
+
+        self.output_size = input_size
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.v.weight.data)
+        nn.init.xavier_uniform_(self.ln.weight.data)
+        self.ln.bias.data.fill_(0)
+
+    def my_method(self, hq_mat, hp_mat):
+        s = tc.cat([hq_mat, hp_mat], dim=-1)
+        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p)
+        return s
+
+    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
+        '''
+        hq: (batch_size, len_q, input_size)
+        mask_hq: (batch_size, len_q)
+        '''
+        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))
+
+
+class MinusAttention(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout=0.2):
+        super(MinusAttention, self).__init__()
+        self.ln = nn.Linear(input_size, hidden_size)
+        self.v = nn.Linear(hidden_size, 1, bias=False)
+
+        self.drop = nn.Dropout(dropout)
+        self.output_size = input_size
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.v.weight.data)
+        nn.init.xavier_uniform_(self.ln.weight.data)
+        self.ln.bias.data.fill_(0)
+
+    def my_method(self, hq_mat, hp_mat):
+        s = hq_mat - hp_mat
+        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p) s[j,t]
+        return s
+
+    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
+        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))
+
+
+class DotProductAttention(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout=0.2):
+        super(DotProductAttention, self).__init__()
+        self.ln = nn.Linear(input_size, hidden_size)
+        self.v = nn.Linear(hidden_size, 1, bias=False)
+
+        self.drop = nn.Dropout(dropout)
+        self.output_size = input_size
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.v.weight.data)
+        nn.init.xavier_uniform_(self.ln.weight.data)
+        self.ln.bias.data.fill_(0)
+
+    def my_method(self, hq_mat, hp_mat):
+        s = hq_mat * hp_mat
+        s = self.v(tc.tanh(self.ln(s))).squeeze(-1)  # (batch_size, len_q, len_p) s[j,t]
+        return s
+
+    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
+        return self.drop(Attention(hq, hp, mask_hq, mask_hp, self.my_method))
+
+
+class BiLinearAttention(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout=0.2, input_size_2=-1):
+        super(BiLinearAttention, self).__init__()
+
+        input_size_2 = input_size if input_size_2 < 0 else input_size_2
+
+        self.ln = nn.Linear(input_size_2, input_size)
+        self.drop = nn.Dropout(dropout)
+        self.output_size = input_size
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.ln.weight.data)
+        self.ln.bias.data.fill_(0)
+
+    def my_method(self, hq, hp, mask_p):
+        # (bs, len, input_size)
+        hp = self.ln(hp)
+        hp = hp * mask_p.unsqueeze(-1)
+        s = tc.matmul(hq, hp.transpose(-1, -2))
+
+        return s
+
+    def forward(self, hq, hp, mask_hq=None, mask_hp=None):
+        standard_size = (hq.size(0), hq.size(1), hp.size(1), hq.size(-1))
+        mask_mat = get_2dmask(mask_hq, mask_hp, standard_size[:-1])
+
+        s = self.my_method(hq, hp, mask_hp)  # (batch_size, len_q, len_p)
+
+        s = s - ((mask_mat == 0).float() * 10000)
+        a = tc.softmax(s, dim=1)
+
+        hq_mat = hq.unsqueeze(2).expand(standard_size)
+        q = a.unsqueeze(-1) * hq_mat  # (batch_size, len_q, len_p, input_size)
+        q = tc.sum(q, dim=1)          # (batch_size, len_p, input_size)
+
+        return self.drop(q)
+
+
+class AggAttention(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout=0.2):
+        super(AggAttention, self).__init__()
+        self.ln = nn.Linear(input_size + hidden_size, hidden_size)
+        self.v = nn.Linear(hidden_size, 1, bias=False)
+        self.vq = nn.Parameter(tc.rand(hidden_size, 1))
+        self.drop = nn.Dropout(dropout)
+
+        self.output_size = input_size
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.vq.data)
+        nn.init.xavier_uniform_(self.v.weight.data)
+        nn.init.xavier_uniform_(self.ln.weight.data)
+        self.ln.bias.data.fill_(0)
+        self.vq.data = self.vq.data[:, 0]  # keep a 1-d query vector after the 2-d xavier init
+
+    def forward(self, hs, mask):
+        '''
+        hs: [(batch_size, len_q, input_size), ...]
+        mask: (batch_size, len_q)
+        '''
+        hs = tc.cat([h.unsqueeze(0) for h in hs], dim=0)  # (4, batch_size, len_q, input_size)
+
+        vq = self.vq.view(1, 1, 1, -1).expand(hs.size(0), hs.size(1), hs.size(2), self.vq.size(0))
+
+        s = self.v(tc.tanh(self.ln(tc.cat([hs, vq], -1)))).squeeze(-1)  # (4, batch_size, len_q)
+
+        s = s - ((mask.unsqueeze(0) == 0).float() * 10000)
+        a = tc.softmax(s, dim=0)
+
+        x = a.unsqueeze(-1) * hs
+        x = tc.sum(x, dim=0)  # (batch_size, len_q, input_size)
+
+        return self.drop(x)
+
+
+class Aggragator(nn.Module):
+    def __init__(self, input_size, hidden_size, dropout=0.3):
+        super(Aggragator, self).__init__()
+
+        now_size = input_size
+        self.ln = nn.Linear(2 * input_size, 2 * input_size)
+
+        now_size = 2 * input_size
+        self.rnn = Contexualizer(now_size, hidden_size, 2, dropout)
+
+        now_size = self.rnn.output_size
+        self.agg_att = AggAttention(now_size, now_size, dropout)
+
+        now_size = self.agg_att.output_size
+        self.agg_rnn = Contexualizer(now_size, hidden_size, 2, dropout)
+
+        self.drop = nn.Dropout(dropout)
+
+        self.output_size = self.agg_rnn.output_size
+
+    def forward(self, qs, hp, mask):
+        '''
+        qs: [(batch_size, len_p, input_size), ...]
+        hp: (batch_size, len_p, input_size)
+        mask is the same as hp's mask
+        '''
+        hs = [0 for _ in range(len(qs))]
+
+        for i in range(len(qs)):
+            q = qs[i]
+            x = tc.cat([q, hp], dim=-1)
+            g = tc.sigmoid(self.ln(x))
+            x_star = x * g
+            h = self.rnn(x_star, mask)
+
+            hs[i] = h
+
+        x = self.agg_att(hs, mask)  # (batch_size, len_p, output_size)
+        h = self.agg_rnn(x, mask)   # (batch_size, len_p, output_size)
+        return self.drop(h)
+
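Aside (not part of the patch): a shape walkthrough of the Mwan_Imm forward pass that follows, reconstructed from the code, with batch size B, premise length Lq, hypothesis length Lp and hidden size H (the training script uses input_size=300 and H=150):

    # enc_s1 : s1 (B, Lq, input_size) -> hq (B, Lq, 2H)     (enc_s2 is built but unused;
    # enc_s1 : s2 (B, Lp, input_size) -> hp (B, Lp, 2H)      both sentences share enc_s1)
    # att_c/b/d/m : (hq, hp) -> qc, qb, qd, qm (B, Lp, 2H)   four matching views of the hypothesis
    # agg    : ([qc, qb, qd, qm], hp) -> ho (B, Lp, 2H)      gate, BiGRU, AggAttention, BiGRU
    # pred_1 : hq -> rq (B, 2H)                              premise summary vector
    # pred_2 : (ho, rq) -> rp (B, 2H)                        attend over ho with rq as the query
    # ln1 + ReLU, then ln2 : rp -> logits (B, num_class)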
+class Mwan_Imm(nn.Module):
+    def __init__(self, input_size, hidden_size, num_class=3, dropout=0.2, use_allennlp=False):
+        super(Mwan_Imm, self).__init__()
+
+        now_size = input_size
+        self.enc_s1 = Contexualizer(now_size, hidden_size, 2, dropout)
+        self.enc_s2 = Contexualizer(now_size, hidden_size, 2, dropout)
+
+        now_size = self.enc_s1.output_size
+        self.att_c = ConcatAttention(now_size, hidden_size, dropout)
+        self.att_b = BiLinearAttention(now_size, hidden_size, dropout)
+        self.att_d = DotProductAttention(now_size, hidden_size, dropout)
+        self.att_m = MinusAttention(now_size, hidden_size, dropout)
+
+        now_size = self.att_c.output_size
+        self.agg = Aggragator(now_size, hidden_size, dropout)
+
+        now_size = self.enc_s1.output_size
+        self.pred_1 = ConcatAttention_Param(now_size, hidden_size, dropout)
+        now_size = self.agg.output_size
+        self.pred_2 = ConcatAttention(now_size, hidden_size, dropout,
+                                      input_size_2=self.pred_1.output_size)
+
+        now_size = self.pred_2.output_size
+        self.ln1 = nn.Linear(now_size, hidden_size)
+        self.ln2 = nn.Linear(hidden_size, num_class)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.ln1.weight.data)
+        nn.init.xavier_uniform_(self.ln2.weight.data)
+        self.ln1.bias.data.fill_(0)
+        self.ln2.bias.data.fill_(0)
+
+    def forward(self, s1, s2, mas_s1, mas_s2):
+        hq = self.enc_s1(s1, mas_s1)  # (batch_size, len_q, output_size)
+        hp = self.enc_s1(s2, mas_s2)
+
+        mas_s1 = mas_s1[:, :hq.size(1)]
+        mas_s2 = mas_s2[:, :hp.size(1)]
+        mas_q, mas_p = mas_s1, mas_s2
+
+        qc = self.att_c(hq, hp, mas_s1, mas_s2)  # (batch_size, len_p, output_size)
+        qb = self.att_b(hq, hp, mas_s1, mas_s2)
+        qd = self.att_d(hq, hp, mas_s1, mas_s2)
+        qm = self.att_m(hq, hp, mas_s1, mas_s2)
+
+        ho = self.agg([qc, qb, qd, qm], hp, mas_s2)  # (batch_size, len_p, output_size)
+
+        rq = self.pred_1(hq, mas_q)  # (batch_size, output_size)
+        rp = self.pred_2(ho, rq.unsqueeze(1), mas_p)  # (batch_size, 1, output_size)
+        rp = rp.squeeze(1)  # (batch_size, output_size)
+
+        rp = F.relu(self.ln1(rp))
+        rp = self.ln2(rp)
+
+        return rp
+
+class MwanModel(nn.Module):
+    def __init__(self, num_class, EmbLayer, args_of_imm={}, ElmoLayer=None):
+        super(MwanModel, self).__init__()
+
+        self.emb = EmbLayer
+
+        if ElmoLayer is not None:
+            self.elmo = ElmoLayer
+            self.elmo_preln = nn.Linear(3 * self.elmo.emb_size, self.elmo.emb_size)
+            self.elmo_ln = nn.Linear(args_of_imm["input_size"] +
+                                     self.elmo.emb_size, args_of_imm["input_size"])
+
+        else:
+            self.elmo = None
+
+        self.imm = Mwan_Imm(num_class=num_class, **args_of_imm)
+        self.drop = nn.Dropout(args_of_imm["dropout"])
+
+    def forward(self, words1, words2, str_s1=None, str_s2=None, *pargs, **kwargs):
+        '''
+        str_s is for ELMo use; however, ELMo is not used here.
+        str_s: (batch_size, seq_len, word_len)
+        '''
+        s1, s2 = words1, words2
+
+        mas_s1 = (s1 != 0).float()  # mas: (batch_size, seq_len)
+        mas_s2 = (s2 != 0).float()  # mas: (batch_size, seq_len)
+
+        mas_s1.requires_grad = False
+        mas_s2.requires_grad = False
+
+        s1_emb = self.emb(s1)
+        s2_emb = self.emb(s2)
+
+        if self.elmo is not None:
+            s1_elmo = self.elmo(str_s1)
+            s2_elmo = self.elmo(str_s2)
+
+            s1_elmo = tc.tanh(self.elmo_preln(tc.cat(s1_elmo, dim=-1)))
+            s2_elmo = tc.tanh(self.elmo_preln(tc.cat(s2_elmo, dim=-1)))
+
+            s1_emb = tc.cat([s1_emb, s1_elmo], dim=-1)
+            s2_emb = tc.cat([s2_emb, s2_elmo], dim=-1)
+
+            s1_emb = tc.tanh(self.elmo_ln(s1_emb))
+            s2_emb = tc.tanh(self.elmo_ln(s2_emb))
+
+        s1_emb = self.drop(s1_emb)
+        s2_emb = self.drop(s2_emb)
+
+        y = self.imm(s1_emb, s2_emb, mas_s1, mas_s2)
+
+        return {
+            Const.OUTPUT: y,
+        }
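Aside (not part of the patch): a minimal smoke test for the new model, under the assumption that a plain nn.Embedding can stand in for fastNLP's StaticEmbedding, since MwanModel only calls EmbLayer on integer id tensors and treats id 0 as padding. Run it from reproduction/matching/ so that model.mwan is importable; the sizes mirror the training script's defaults.

    import torch as tc
    import torch.nn as nn

    from model.mwan import MwanModel

    model = MwanModel(
        num_class=3,
        EmbLayer=nn.Embedding(100, 300, padding_idx=0),  # toy vocabulary of 100 ids
        ElmoLayer=None,
        args_of_imm={"input_size": 300, "hidden_size": 150,
                     "dropout": 0.3, "use_allennlp": False},
    )
    model.eval()

    words1 = tc.randint(1, 100, (2, 7))  # premise ids, no padding
    words2 = tc.randint(1, 100, (2, 9))  # hypothesis ids
    words2[1, 6:] = 0                    # pad the tail of one hypothesis

    with tc.no_grad():
        out = model(words1, words2)
    print({k: v.shape for k, v in out.items()})  # expect logits of shape (2, 3)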