
update data loader of matching

tags/v0.4.10
xuyige
commit e0b23b16db
5 changed files with 178 additions and 72 deletions
  1. +31 -0 fastNLP/io/file_utils.py
  2. +6 -35 fastNLP/modules/encoder/embedding.py
  3. +58 -34 reproduction/matching/data/MatchingDataLoader.py
  4. +65 -0 reproduction/matching/matching_esim.py
  5. +18 -3 reproduction/matching/model/esim.py

+31 -0 fastNLP/io/file_utils.py

@@ -10,6 +10,37 @@ import shutil
import hashlib




PRETRAINED_BERT_MODEL_DIR = {
'en': 'bert-base-cased-f89bfe08.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
}

PRETRAINED_ELMO_MODEL_DIR = {
'en': 'elmo_en-d39843fe.tar.gz',
'cn': 'elmo_cn-5e9b34e2.tar.gz'
}

PRETRAIN_STATIC_FILES = {
'en': 'glove.840B.300d-cc1ad5e1.tar.gz',
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
'en-fasttext': "cc.en.300.vec-d53187b2.gz",
'cn': "tencent_cn-dab24577.tar.gz",
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz",
}


def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
    """
    Given a URL or a file name (either a concrete file name or a file), first look for the file under cache_dir; if it does not exist, download it, and
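The per-class copies of these download tables are replaced by the module-level dictionaries above. A minimal sketch of how they can be combined with _get_base_url and cached_path to resolve a model alias to a local path; the helper name resolve_bert_dir is hypothetical and only mirrors the lookup pattern used by BertEmbedding later in this diff:

# Hypothetical helper; assumes _get_base_url('bert') returns the download base URL
# and cached_path() downloads the archive (or reuses the cache) and returns its local path.
from pathlib import Path
from fastNLP.io.file_utils import PRETRAINED_BERT_MODEL_DIR, _get_base_url, cached_path

def resolve_bert_dir(alias: str) -> Path:
    model_name = PRETRAINED_BERT_MODEL_DIR[alias.lower()]   # e.g. 'en-base-uncased'
    return cached_path(_get_base_url('bert') + model_name)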


+6 -35 fastNLP/modules/encoder/embedding.py

@@ -26,6 +26,7 @@ from ...core.dataset import DataSet
from ...core.batch import DataSetIter
from ...core.sampler import SequentialSampler
from ...core.utils import _move_model_to_device, _get_model_device
from ...io.file_utils import PRETRAINED_BERT_MODEL_DIR, PRETRAINED_ELMO_MODEL_DIR, PRETRAIN_STATIC_FILES




class Embedding(nn.Module):
@@ -187,15 +188,6 @@ class StaticEmbedding(TokenEmbedding):
super(StaticEmbedding, self).__init__(vocab)


# Define up front which static embeddings can be downloaded. We probably need to set up our own server for this,
PRETRAIN_STATIC_FILES = {
'en': 'glove.840B.300d-cc1ad5e1.tar.gz',
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
'en-fasttext': "cc.en.300.vec-d53187b2.gz",
'cn': "tencent_cn-dab24577.tar.gz",
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz",
}


# Obtain cache_path
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
@@ -231,7 +223,7 @@ class StaticEmbedding(TokenEmbedding):
:return:
"""
requires_grads = set([param.requires_grad for name, param in self.named_parameters()
if 'words_to_words' not in name])
if len(requires_grads) == 1:
return requires_grads.pop()
else:
@@ -244,8 +236,8 @@ class StaticEmbedding(TokenEmbedding):
continue
param.requires_grad = value


def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True,
error='ignore', init_method=None):
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>',
normalize=True, error='ignore', init_method=None):
""" """
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是
word2vec(第一行只有两个元素)还是glove格式的数据。 word2vec(第一行只有两个元素)还是glove格式的数据。
@@ -329,11 +321,6 @@ class ContextualEmbedding(TokenEmbedding):
""" """
由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。


Example::

>>>


:param datasets: DataSet objects
:param batch_size: int, the batch size used when generating cached sentence representations
:param device: see the device argument of :class::fastNLP.Trainer
@@ -363,7 +350,7 @@ class ContextualEmbedding(TokenEmbedding):
seq_len = words.ne(pad_index).sum(dim=-1)
max_len = words.size(1)
# Some inputs may contain CLS and SEP, so counting from the back is safer.
seq_len_from_behind =(max_len - seq_len).tolist()
seq_len_from_behind = (max_len - seq_len).tolist()
word_embeds = self(words).detach().cpu().numpy()
for b in range(words.size(0)):
length = seq_len_from_behind[b]
@@ -446,9 +433,6 @@ class ElmoEmbedding(ContextualEmbedding):
self.layers = layers


# Check whether model_dir_or_name exists locally and download it if necessary
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz',
'cn': 'elmo_cn-5e9b34e2.tar.gz'}

if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
PRETRAIN_URL = _get_base_url('elmo')
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name]
@@ -532,21 +516,8 @@ class BertEmbedding(ContextualEmbedding):
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False):
super(BertEmbedding, self).__init__(vocab)
# Check whether model_dir_or_name exists locally and download it if necessary
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
}


# Check whether model_dir_or_name exists locally and download it if necessary
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
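With the tables now shared from fastNLP.io.file_utils, every embedding class resolves the same aliases. A minimal, assumed usage sketch; the alias 'en-glove-6b-50' comes from PRETRAIN_STATIC_FILES, and the StaticEmbedding keyword follows the model_dir_or_name parameter used above:

# Assumed usage sketch: build a tiny vocabulary and attach a static embedding to it.
from fastNLP import Vocabulary
from fastNLP.modules.encoder.embedding import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("a small example sentence".split())
vocab.build_vocab()

# 'en-glove-6b-50' is one of the aliases defined in PRETRAIN_STATIC_FILES
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)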


+58 -34 reproduction/matching/data/MatchingDataLoader.py

@@ -6,31 +6,58 @@ from typing import Union, Dict


from fastNLP.core.const import Const
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.io.base_loader import DataInfo
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.io.file_utils import _get_base_url, cached_path
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from fastNLP.modules.encoder._bert import BertTokenizer




class MatchingLoader(JsonLoader):
class MatchingLoader(DataSetLoader):
""" """
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`


读取Matching任务的数据集 读取Matching任务的数据集
""" """


def __init__(self, fields=None, paths: dict=None):
super(MatchingLoader, self).__init__(fields=fields)
def __init__(self, paths: dict=None):
"""
:param dict paths: keys are dataset names (e.g. train, dev, test) and values are the corresponding file names
"""
self.paths = paths


def _load(self, path):
return super(MatchingLoader, self)._load(path)

def process(self, paths: Union[str, Dict[str, str]], dataset_name=None,
to_lower=False, char_information=False, seq_len_type: str=None,
bert_tokenizer: str=None, get_index=True, set_input: Union[list, str, bool]=True,
"""
:param str path: path of the dataset file to read
:return: fastNLP.DataSet ds: returns a DataSet object, which must contain 3 fields: two hold the raw
text of the two sentences and the third holds the label
"""
raise NotImplementedError

def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
get_index=True, set_input: Union[list, str, bool]=True,
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo:
"""
:param paths: str or Dict[str, str]. If a str, it is either the folder containing the datasets or a full file path: for a folder,
the dataset names and file names are looked up in self.paths. If a Dict, it maps dataset names (e.g. train, dev, test) to
their full file paths.
:param str dataset_name: if paths is the full file path of a single dataset, dataset_name can be used to name
that dataset; if not given it defaults to train.
:param bool to_lower: whether to automatically lowercase the text. Defaults to False.
:param str seq_len_type: which kind of sequence-length information to provide. Supports ``seq_len``: a single number giving the sentence length; ``mask``:
a 0/1 mask matrix standing in for the sentence length; ``bert``: segment_type_id (0 for the first sentence, 1 for the second) plus an
attention-mask matrix (a 0/1 mask). Defaults to None, i.e. no seq_len is provided.
:param str bert_tokenizer: path of the folder containing the vocabulary used by the bert tokenizer
:param bool get_index: whether to convert the text to indices according to the vocabulary
:param set_input: if True, the relevant fields (those whose names contain Const.INPUT) are automatically set as input; if False,
no field is set as input. If a str or List[str] is passed, the corresponding fields are set as input
while all other fields are not. Defaults to True.
:param set_target: controls which fields are set as target; usage is the same as for set_input. Defaults to True.
:param concat: whether to concatenate the two sentences. If False, they are not concatenated. If True, a <sep> token is inserted between them.
If a list of length 4 is passed, its elements are the tokens inserted before the first sentence, after the first sentence, before the second
sentence, and after the second sentence, respectively. If the string ``bert`` is passed, bert-style concatenation is used, equivalent to ['[CLS]', '[SEP]', '', '[SEP]'].
:return:
"""
if isinstance(set_input, str):
set_input = [set_input]
if isinstance(set_target, str):
@@ -69,19 +96,6 @@ class MatchingLoader(JsonLoader):
is_input=auto_set_input)


if bert_tokenizer is not None:
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
}
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
@@ -128,14 +142,14 @@ class MatchingLoader(JsonLoader):
for fields in data_set.get_field_names():
if Const.INPUT in fields:
data_set.apply(lambda x: len(x[fields]),
new_field_name=fields.replace(Const.INPUT, Const.TARGET),
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
is_input=auto_set_input)
elif seq_len_type == 'mask':
for data_name, data_set in data_info.datasets.items():
for fields in data_set.get_field_names():
if Const.INPUT in fields:
data_set.apply(lambda x: [1] * len(x[fields]),
new_field_name=fields.replace(Const.INPUT, Const.TARGET),
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
is_input=auto_set_input)
elif seq_len_type == 'bert':
for data_name, data_set in data_info.datasets.items():
@@ -152,11 +166,18 @@ class MatchingLoader(JsonLoader):


if bert_tokenizer is not None:
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
lines = f.readlines()
lines = [line.strip() for line in lines]
words_vocab.add_word_lst(lines)
words_vocab.build_vocab()
else:
words_vocab = Vocabulary()
words_vocab = words_vocab.from_dataset(*data_set_list,
field_name=[n for n in data_set_list[0].get_field_names()
if (Const.INPUT in n)])
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
field_name=[n for n in data_set_list[0].get_field_names()
if (Const.INPUT in n)],
no_create_entry_dataset=[d for n, d in data_info.datasets.items()
if 'train' not in n])
target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET)
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
@@ -173,14 +194,14 @@ class MatchingLoader(JsonLoader):


for data_name, data_set in data_info.datasets.items():
if isinstance(set_input, list):
data_set.set_input(set_input)
data_set.set_input(*set_input)
if isinstance(set_target, list):
data_set.set_target(set_target)
data_set.set_target(*set_target)


return data_info




class SNLILoader(MatchingLoader):
class SNLILoader(MatchingLoader, JsonLoader):
""" """
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`


@@ -203,10 +224,13 @@ class SNLILoader(MatchingLoader):
'train': 'snli_1.0_train.jsonl',
'dev': 'snli_1.0_dev.jsonl',
'test': 'snli_1.0_test.jsonl'}
super(SNLILoader, self).__init__(fields=fields, paths=paths)
# super(SNLILoader, self).__init__(fields=fields, paths=paths)
MatchingLoader.__init__(self, paths=paths)
JsonLoader.__init__(self, fields=fields)


def _load(self, path):
ds = super(SNLILoader, self)._load(path)
# ds = super(SNLILoader, self)._load(path)
ds = JsonLoader._load(self, path)


def parse_tree(x):
t = Tree.fromstring(x)
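Because MatchingLoader._load now raises NotImplementedError, a concrete loader pairs it with an actual reader, as SNLILoader does with JsonLoader. A hedged sketch of what a minimal custom subclass could look like for a tab-separated corpus; the class name, field names, and file format here are illustrative and not part of this commit:

# Illustrative only: each line is "sentence1<TAB>sentence2<TAB>label".
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

class TsvMatchingLoader(MatchingLoader):
    def __init__(self, paths: dict = None):
        super(TsvMatchingLoader, self).__init__(paths=paths)

    def _load(self, path):
        ds = DataSet()
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                sent1, sent2, label = line.rstrip('\n').split('\t')
                # two raw-text fields plus a label, as the _load docstring requires
                ds.append(Instance(words1=sent1, words2=sent2, target=label))
        return ds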


+65 -0 reproduction/matching/matching_esim.py

@@ -0,0 +1,65 @@

import argparse
import torch

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding

from reproduction.matching.data.MatchingDataLoader import SNLILoader
from reproduction.matching.model.esim import ESIMModel

argument = argparse.ArgumentParser()
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove')
argument.add_argument('--batch-size-per-gpu', type=int, default=128)
argument.add_argument('--n-epochs', type=int, default=100)
argument.add_argument('--lr', type=float, default=1e-4)
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len')
argument.add_argument('--save-dir', type=str, default=None)
arg = argument.parse_args()

bert_dirs = 'path/to/bert/dir'

# load data set
data_info = SNLILoader().process(
paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
get_index=True, concat=False,
)

# load embedding
if arg.embedding == 'elmo':
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
elif arg.embedding == 'glove':
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
else:
raise ValueError(f'now we only support elmo or glove embedding for esim model!')

# define model
model = ESIMModel(embedding)

# define trainer
trainer = Trainer(train_data=data_info.datasets['train'], model=model,
optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1,
dev_data=data_info.datasets['dev'],
metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1,
save_path=arg.save_dir)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
data=data_info.datasets['test'],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
device=[i for i in range(torch.cuda.device_count())],
)

# test model
tester.test()
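One point worth noting in the script above: the batch size and device list passed to Trainer and Tester are derived from torch.cuda.device_count(), so the script assumes at least one visible GPU. A small sketch of that derivation, with an assumed CPU fallback added purely for illustration:

import torch

n_gpu = torch.cuda.device_count()
per_gpu_batch = 128  # the script's --batch-size-per-gpu default

if n_gpu > 0:
    batch_size = n_gpu * per_gpu_batch   # total batch, split across GPUs by Trainer
    device = list(range(n_gpu))          # e.g. [0, 1] on a two-GPU machine
else:
    batch_size = per_gpu_batch           # assumed fallback; not in the original script
    device = 'cpu'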



+18 -3 reproduction/matching/model/esim.py

@@ -30,24 +30,37 @@ class ESIMModel(BaseModel):
self.bi_attention = SoftmaxAttention()


self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate)
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True)
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,)


self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate),
nn.Linear(8 * hidden_size, hidden_size),
nn.Tanh(),
nn.Dropout(p=dropout_rate),
nn.Linear(hidden_size, num_labels))

self.dropout_rnn = nn.Dropout(p=dropout_rate)

nn.init.xavier_uniform_(self.classifier[1].weight.data)
nn.init.xavier_uniform_(self.classifier[4].weight.data)


def forward(self, words1, words2, seq_len1, seq_len2, target=None):
mask1 = seq_len_to_mask(seq_len1)
mask2 = seq_len_to_mask(seq_len2)
"""
:param words1: [batch, seq_len]
:param words2: [batch, seq_len]
:param seq_len1: [batch]
:param seq_len2: [batch]
:param target:
:return:
"""
mask1 = seq_len_to_mask(seq_len1, words1.size(1))
mask2 = seq_len_to_mask(seq_len2, words2.size(1))
a0 = self.embedding(words1)  # B * len * emb_dim
b0 = self.embedding(words2)
a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0)
a = self.rnn(a0, mask1.byte())  # a: [B, PL, 2 * H]
b = self.rnn(b0, mask2.byte())
# a = self.dropout_rnn(self.rnn(a0, seq_len1)[0]) # a: [B, PL, 2 * H]
# b = self.dropout_rnn(self.rnn(b0, seq_len2)[0])


ai, bi = self.bi_attention(a, mask1, b, mask2)


@@ -58,6 +71,8 @@ class ESIMModel(BaseModel):


a_h = self.rnn_high(a_f, mask1.byte())  # ma: [B, PL, 2 * H]
b_h = self.rnn_high(b_f, mask2.byte())
# a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0]) # ma: [B, PL, 2 * H]
# b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0])


a_avg = self.mean_pooling(a_h, mask1, dim=1)
a_max, _ = self.max_pooling(a_h, mask1, dim=1)
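The switch to seq_len_to_mask(seq_len, words.size(1)) pins the mask width to the padded width of each words tensor rather than to the longest value in seq_len, so the masks always align with words1 and words2 even when the two sides are padded to different widths. A small illustrative sketch, assuming seq_len_to_mask is importable from fastNLP.core.utils and accepts the explicit width as its second argument, as used above:

import torch
from fastNLP.core.utils import seq_len_to_mask  # assumed import path

seq_len = torch.tensor([3, 5])
words = torch.zeros(2, 7, dtype=torch.long)      # batch padded to width 7

mask = seq_len_to_mask(seq_len, words.size(1))   # shape [2, 7], matches words
# without the explicit width, the mask would be [2, 5] (seq_len.max())
# and would not align with words.size(1)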

