Browse Source

Merge branch 'dev0.5.0' of github.com:fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
yh_cc 5 years ago
parent
commit
4f91fb1ac1
7 changed files with 233 additions and 53 deletions
  1. +12
    -10
      fastNLP/core/predictor.py
  2. +1
    -1
      reproduction/README.md
  3. +26
    -15
      reproduction/matching/README.md
  4. +9
    -2
      reproduction/matching/data/MatchingDataLoader.py
  5. +102
    -0
      reproduction/matching/matching_bert.py
  6. +79
    -23
      reproduction/matching/matching_esim.py
  7. +4
    -2
      reproduction/matching/model/esim.py

+ 12
- 10
fastNLP/core/predictor.py View File

@@ -14,12 +14,12 @@ from .utils import _build_args, _move_dict_value_to_device, _get_model_device


class Predictor(object): class Predictor(object):
""" """
An interface for predicting outputs based on trained models.
一个根据训练模型预测输出的预测器(Predictor)


It does not care about evaluations of the model, which is different from Tester.
This is a high-level model wrapper to be called by FastNLP.
This class does not share any operations with Trainer and Tester.
Currently, Predictor does not support GPU.
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。
:param torch.nn.Module network: 用来完成预测任务的模型
""" """


def __init__(self, network): def __init__(self, network):
@@ -30,18 +30,19 @@ class Predictor(object):
self.batch_size = 1 self.batch_size = 1
self.batch_output = [] self.batch_output = []


def predict(self, data, seq_len_field_name=None):
"""Perform inference using the trained model.
def predict(self, data: DataSet, seq_len_field_name=None):
"""用已经训练好的模型进行inference.


:param data: a DataSet object.
:param str seq_len_field_name: field name indicating sequence lengths
:return: list of batch outputs
:param fastNLP.DataSet data: 待预测的数据集
:param str seq_len_field_name: 表示序列长度信息的field名字
:return: dict dict里面的内容为模型预测的结果
""" """
if not isinstance(data, DataSet): if not isinstance(data, DataSet):
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))


prev_training = self.network.training
self.network.eval() self.network.eval()
network_device = _get_model_device(self.network) network_device = _get_model_device(self.network)
batch_output = defaultdict(list) batch_output = defaultdict(list)
@@ -74,4 +75,5 @@ class Predictor(object):
else: else:
batch_output[key].append(value) batch_output[key].append(value)


self.network.train(prev_training)
return batch_output return batch_output

+ 1
- 1
reproduction/README.md View File

@@ -11,7 +11,7 @@




## Matching (自然语言推理/句子匹配) ## Matching (自然语言推理/句子匹配)
- still in progress
- [Matching 任务复现](matching/)




## Sequence Labeling (序列标注) ## Sequence Labeling (序列标注)


+ 26
- 15
reproduction/matching/README.md View File

@@ -1,32 +1,34 @@
# Matching任务模型复现 # Matching任务模型复现
这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。
这里使用fastNLP复现了几个著名的Matching任务的模型,旨在达到与论文中相符的性能。这几个任务的评价指标均为准确率(%).


复现的模型有(按论文发表时间顺序排序): 复现的模型有(按论文发表时间顺序排序):
- CNTN:复现代码(still in progress)[]().
- CNTN:模型代码(still in progress)[](); 训练代码(still in progress)[]().
论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844). 论文链接:[Convolutional Neural Tensor Network Architecture for Community-based Question Answering](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844).
- ESIM:[复现代码](model/esim.py).
- ESIM:[模型代码](model/esim.py); [训练代码](matching_esim.py).
论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf). 论文链接:[Enhanced LSTM for Natural Language Inference](https://arxiv.org/pdf/1609.06038.pdf).
- DIIN:复现代码(still in progress)[]().
- DIIN:模型代码(still in progress)[](); 训练代码(still in progress)[]().
论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf). 论文链接:[Natural Language Inference over Interaction Space](https://arxiv.org/pdf/1709.04348.pdf).
- MwAN:复现代码(still in progress)[]().
- MwAN:模型代码(still in progress)[](); 训练代码(still in progress)[]().
论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf). 论文链接:[Multiway Attention Networks for Modeling Sentence Pairs](https://www.ijcai.org/proceedings/2018/0613.pdf).
- BERT:[复现代码](model/bert.py).
- BERT:[模型代码](model/bert.py); [训练代码](matching_bert.py).
论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf). 论文链接:[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf).


# 数据集及复现结果汇总 # 数据集及复现结果汇总


使用fastNLP复现的结果vs论文汇报结果
使用fastNLP复现的结果vs论文汇报结果,在前面的表示使用fastNLP复现的结果


'\-'表示我们仍未复现或者论文原文没有汇报 '\-'表示我们仍未复现或者论文原文没有汇报


model name | SNLI | MNLI | RTE | QNLI | Quora model name | SNLI | MNLI | RTE | QNLI | Quora
:---: | :---: | :---: | :---: | :---: | :---: :---: | :---: | :---: | :---: | :---: | :---:
CNTN ; [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | - | - | - | - | - |
ESIM[代码](model/bert.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs - | 57.04(dev) / - | 76.97(dev) / - | - |
DIIN [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
MwAN [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.5 vs 88.3 | - vs 78.5/77.7 | - | - | vs 89.12 |
CNTN [](); [论文](https://www.aaai.org/ocs/index.php/IJCAI/IJCAI15/paper/view/11401/10844) | 74.53 vs - | 60.84/-(dev) vs - | 57.4(dev) vs - | 62.53(dev) vs - | - |
ESIM[代码](model/esim.py); [论文](https://arxiv.org/pdf/1609.06038.pdf) | 88.13(glove) vs 88.0(glove)/88.7(elmo) | 77.78/76.49 vs 72.4/72.1* | 59.21(dev) vs - | 76.97(dev) vs - | - |
DIIN [](); [论文](https://arxiv.org/pdf/1709.04348.pdf) | - vs 88.0 | - vs 78.8/77.8 | - | - | - vs 89.06 |
MwAN [](); [论文](https://www.ijcai.org/proceedings/2018/0613.pdf) | 87.9 vs 88.3 | 77.3/76.7(dev) vs 78.5/77.7 | - | 74.6(dev) vs - | 85.6 vs 89.12 |
BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - | BERT (BASE version)[代码](model/bert.py); [论文](https://arxiv.org/pdf/1810.04805.pdf) | 90.6 vs - | - vs 84.6/83.4| 67.87(dev) vs 66.4 | 90.97(dev) vs 90.5 | - |


*ESIM模型由MNLI官方复现的结果为72.4/72.1,ESIM原论文当中没有汇报MNLI数据集的结果。

# 数据集复现结果及其他主要模型对比 # 数据集复现结果及其他主要模型对比
## SNLI ## SNLI
[Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) [Link to SNLI leaderboard](https://nlp.stanford.edu/projects/snli/)
@@ -42,7 +44,7 @@ Performance on Test set:


model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large model name | CNTN | ESIM | DIIN | MwAN | BERT-Base | BERT-Large
:---: | :---: | :---: | :---: | :---: | :---: | :---: :---: | :---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 88.13 | - | - | 90.6 | 91.16
__performance__ | - | 88.13 | - | 87.9 | 90.6 | 91.16


## MNLI ## MNLI
[Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/) [Link to MNLI main page](https://www.nyu.edu/projects/bowman/multinli/)
@@ -58,11 +60,13 @@ Performance on Test set(matched/mismatched):


model name | CNTN | ESIM | DIIN | MwAN | BERT-Base model name | CNTN | ESIM | DIIN | MwAN | BERT-Base
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
__performance__ | - | - | - | - | - |
__performance__ | - | 77.78/76.49 | - | 77.3/76.7(dev) | - |




## RTE ## RTE


Still in progress.

## QNLI ## QNLI


### From GLUE baselines ### From GLUE baselines
@@ -73,17 +77,24 @@ Performance on Test set:
model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo model name | BiLSTM | BiLSTM + Attn | BiLSTM + ELMo | BiLSTM + Attn + ELMo
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
__performance__ | 74.6 | 74.3 | 75.5 | 79.8 | __performance__ | 74.6 | 74.3 | 75.5 | 79.8 |

*这些LSTM-based的baseline是由QNLI官方实现并测试的。

#### Transformer-based #### Transformer-based
model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN model name | GPT1.0 | BERT-Base | BERT-Large | MT-DNN
:---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
__performance__ | 87.4 | 90.5 | 92.7 | 96.0 | __performance__ | 87.4 | 90.5 | 92.7 | 96.0 |





### 基于fastNLP复现的结果 ### 基于fastNLP复现的结果
Performance on Dev set:
Performance on __Dev__ set:


model name | CNTN | ESIM | DIIN | MwAN | BERT model name | CNTN | ESIM | DIIN | MwAN | BERT
:---: | :---: | :---: | :---: | :---: | :---: :---: | :---: | :---: | :---: | :---: | :---:
__performance__ | - | 76.97 | - | - | -
__performance__ | - | 76.97 | - | 74.6 | -


## Quora ## Quora

Still in progress.


+ 9
- 2
reproduction/matching/data/MatchingDataLoader.py View File

@@ -5,8 +5,8 @@ from typing import Union, Dict


from fastNLP.core.const import Const from fastNLP.core.const import Const
from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import DataInfo
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader
from fastNLP.io.base_loader import DataInfo, DataSetLoader
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from fastNLP.modules.encoder._bert import BertTokenizer from fastNLP.modules.encoder._bert import BertTokenizer


@@ -348,6 +348,9 @@ class MNLILoader(MatchingLoader, CSVLoader):
'dev_mismatched': 'dev_mismatched.tsv', 'dev_mismatched': 'dev_mismatched.tsv',
'test_matched': 'test_matched.tsv', 'test_matched': 'test_matched.tsv',
'test_mismatched': 'test_mismatched.tsv', 'test_mismatched': 'test_mismatched.tsv',
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
# test_0.9_matched与mismatched是MNLI0.9版本的(数据来源:kaggle)
} }
MatchingLoader.__init__(self, paths=paths) MatchingLoader.__init__(self, paths=paths)
CSVLoader.__init__(self, sep='\t') CSVLoader.__init__(self, sep='\t')
@@ -364,6 +367,10 @@ class MNLILoader(MatchingLoader, CSVLoader):
if k in ds.get_field_names(): if k in ds.get_field_names():
ds.rename_field(k, v) ds.rename_field(k, v)


if Const.TARGET in ds.get_field_names():
if ds[0][Const.TARGET] == 'hidden':
ds.delete_field(Const.TARGET)

parentheses_table = str.maketrans({'(': None, ')': None}) parentheses_table = str.maketrans({'(': None, ')': None})


ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),


+ 102
- 0
reproduction/matching/matching_bert.py View File

@@ -0,0 +1,102 @@
import random
import numpy as np
import torch

from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam

from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.bert import BertForNLI


# define hyper-parameters
class BERTConfig:

task = 'snli'
batch_size_per_gpu = 6
n_epochs = 6
lr = 2e-5
seq_len_type = 'bert'
seed = 42
train_dataset_name = 'train'
dev_dataset_name = 'dev'
test_dataset_name = 'test'
save_path = None # 模型存储的位置,None表示不存储模型。
bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹


arg = BERTConfig()

# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
torch.cuda.manual_seed_all(arg.seed)

# load data set
if arg.task == 'snli':
data_info = SNLILoader().process(
paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
bert_tokenizer=arg.bert_dir, cut_text=512,
get_index=True, concat='bert',
)
elif arg.task == 'rte':
data_info = RTELoader().process(
paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
bert_tokenizer=arg.bert_dir, cut_text=512,
get_index=True, concat='bert',
)
elif arg.task == 'qnli':
data_info = QNLILoader().process(
paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
bert_tokenizer=arg.bert_dir, cut_text=512,
get_index=True, concat='bert',
)
elif arg.task == 'mnli':
data_info = MNLILoader().process(
paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
bert_tokenizer=arg.bert_dir, cut_text=512,
get_index=True, concat='bert',
)
elif arg.task == 'quora':
data_info = QuoraLoader().process(
paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type,
bert_tokenizer=arg.bert_dir, cut_text=512,
get_index=True, concat='bert',
)
else:
raise RuntimeError(f'NOT support {arg.task} task yet!')

# define model
model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)

# define trainer
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1,
dev_data=data_info.datasets[arg.dev_dataset_name],
metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1,
save_path=arg.save_path)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
data=data_info.datasets[arg.test_dataset_name],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
device=[i for i in range(torch.cuda.device_count())],
)

# test model
tester.test()



+ 79
- 23
reproduction/matching/matching_esim.py View File

@@ -1,47 +1,103 @@


import argparse
import random
import numpy as np
import torch import torch
from torch.optim import Adamax
from torch.optim.lr_scheduler import StepLR


from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
from fastNLP.core.callback import GradientClipCallback, LRScheduler
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding


from reproduction.matching.data.MatchingDataLoader import SNLILoader
from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.esim import ESIMModel from reproduction.matching.model.esim import ESIMModel


argument = argparse.ArgumentParser()
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove')
argument.add_argument('--batch-size-per-gpu', type=int, default=128)
argument.add_argument('--n-epochs', type=int, default=100)
argument.add_argument('--lr', type=float, default=1e-4)
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len')
argument.add_argument('--save-dir', type=str, default=None)
arg = argument.parse_args()


bert_dirs = 'path/to/bert/dir'
# define hyper-parameters
class ESIMConfig:

task = 'snli'
embedding = 'glove'
batch_size_per_gpu = 196
n_epochs = 30
lr = 2e-3
seq_len_type = 'seq_len'
# seq_len表示在process的时候用len(words)来表示长度信息;
# mask表示用0/1掩码矩阵来表示长度信息;
seed = 42
train_dataset_name = 'train'
dev_dataset_name = 'dev'
test_dataset_name = 'test'
save_path = None # 模型存储的位置,None表示不存储模型。


arg = ESIMConfig()

# set random seed
random.seed(arg.seed)
np.random.seed(arg.seed)
torch.manual_seed(arg.seed)

n_gpu = torch.cuda.device_count()
if n_gpu > 0:
torch.cuda.manual_seed_all(arg.seed)


# load data set # load data set
data_info = SNLILoader().process(
paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
get_index=True, concat=False,
)
if arg.task == 'snli':
data_info = SNLILoader().process(
paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type,
get_index=True, concat=False,
)
elif arg.task == 'rte':
data_info = RTELoader().process(
paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type,
get_index=True, concat=False,
)
elif arg.task == 'qnli':
data_info = QNLILoader().process(
paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
get_index=True, concat=False,
)
elif arg.task == 'mnli':
data_info = MNLILoader().process(
paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
get_index=True, concat=False,
)
elif arg.task == 'quora':
data_info = QuoraLoader().process(
paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type,
get_index=True, concat=False,
)
else:
raise RuntimeError(f'NOT support {arg.task} task yet!')


# load embedding # load embedding
if arg.embedding == 'elmo': if arg.embedding == 'elmo':
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
elif arg.embedding == 'glove': elif arg.embedding == 'glove':
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False)
else: else:
raise ValueError(f'now we only support elmo or glove embedding for esim model!')
raise RuntimeError(f'NOT support {arg.embedding} embedding yet!')


# define model # define model
model = ESIMModel(embedding)
model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET]))

# define optimizer and callback
optimizer = Adamax(lr=arg.lr, params=model.parameters())
scheduler = StepLR(optimizer, step_size=10, gamma=0.5) # 每10个epoch学习率变为原来的0.5倍

callbacks = [
GradientClipCallback(clip_value=10), # 等价于torch.nn.utils.clip_grad_norm_(10)
LRScheduler(scheduler),
]


# define trainer # define trainer
trainer = Trainer(train_data=data_info.datasets['train'], model=model,
optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
optimizer=optimizer,
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1, n_epochs=arg.n_epochs, print_every=-1,
dev_data=data_info.datasets['dev'],
dev_data=data_info.datasets[arg.dev_dataset_name],
metrics=AccuracyMetric(), metric_key='acc', metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())], device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1, check_code_level=-1,
@@ -52,7 +108,7 @@ trainer.train(load_best_model=True)


# define tester # define tester
tester = Tester( tester = Tester(
data=data_info.datasets['test'],
data=data_info.datasets[arg.test_dataset_name],
model=model, model=model,
metrics=AccuracyMetric(), metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,


+ 4
- 2
reproduction/matching/model/esim.py View File

@@ -81,6 +81,7 @@ class ESIMModel(BaseModel):


out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H]
logits = torch.tanh(self.classifier(out)) logits = torch.tanh(self.classifier(out))
# logits = self.classifier(out)


if target is not None: if target is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
@@ -91,7 +92,8 @@ class ESIMModel(BaseModel):
return {Const.OUTPUT: logits} return {Const.OUTPUT: logits}


def predict(self, **kwargs): def predict(self, **kwargs):
return self.forward(**kwargs)
pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1)
return {Const.OUTPUT: pred}


# input [batch_size, len , hidden] # input [batch_size, len , hidden]
# mask [batch_size, len] (111...00) # mask [batch_size, len] (111...00)
@@ -127,7 +129,7 @@ class BiRNN(nn.Module):


def forward(self, x, x_mask): def forward(self, x, x_mask):
# Sort x # Sort x
lengths = x_mask.data.eq(1).long().sum(1).squeeze()
lengths = x_mask.data.eq(1).long().sum(1)
_, idx_sort = torch.sort(lengths, dim=0, descending=True) _, idx_sort = torch.sort(lengths, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0) _, idx_unsort = torch.sort(idx_sort, dim=0)
lengths = list(lengths[idx_sort]) lengths = list(lengths[idx_sort])


Loading…
Cancel
Save