@@ -548,7 +548,7 @@ class LRScheduler(Callback): | |||
else: | |||
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") | |||
def on_epoch_begin(self): | |||
def on_epoch_end(self): | |||
self.scheduler.step(self.epoch) | |||
@@ -801,17 +801,19 @@ class DataSet(object): | |||
else: | |||
return DataSet() | |||
def split(self, ratio): | |||
def split(self, ratio, shuffle=True): | |||
""" | |||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `(1-ratio)` 这么多数据,第二个DataSet拥有`ratio`这么多数据 | |||
:param bool shuffle: 在split前是否shuffle一下 | |||
:return: [DataSet, DataSet] | |||
""" | |||
assert isinstance(ratio, float) | |||
assert 0 < ratio < 1 | |||
all_indices = [_ for _ in range(len(self))] | |||
np.random.shuffle(all_indices) | |||
if shuffle: | |||
np.random.shuffle(all_indices) | |||
split = int(ratio * len(self)) | |||
dev_indices = all_indices[:split] | |||
train_indices = all_indices[split:] | |||
@@ -26,7 +26,7 @@ from .utils import _build_args | |||
from .utils import _check_arg_dict_list | |||
from .utils import _check_function_or_method | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
class LossBase(object): | |||
""" | |||
@@ -223,7 +223,9 @@ class CrossEntropyLoss(LossBase): | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容 | |||
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | |||
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
Example:: | |||
@@ -231,16 +233,18 @@ class CrossEntropyLoss(LossBase): | |||
""" | |||
def __init__(self, pred=None, target=None, padding_idx=-100): | |||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100): | |||
super(CrossEntropyLoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
self.padding_idx = padding_idx | |||
def get_loss(self, pred, target): | |||
def get_loss(self, pred, target, seq_len=None): | |||
if pred.dim()>2: | |||
if pred.size()[:2]==target.size(): | |||
# F.cross_entropy在计算时,如果pred是(16, 10 ,4), 会在第二维上去log_softmax, 所以需要交换一下位置 | |||
pred = pred.transpose(1, 2) | |||
pred = pred.view(-1, pred.size(-1)) | |||
target = target.view(-1) | |||
if seq_len is not None: | |||
mask = seq_len_to_mask(seq_len).view(-1).eq(0) | |||
target = target.masked_fill(mask, self.padding_idx) | |||
return F.cross_entropy(input=pred, target=target, | |||
ignore_index=self.padding_idx) | |||
@@ -36,6 +36,23 @@ class Optimizer(object): | |||
""" | |||
return [param for param in params if param.requires_grad] | |||
class NullOptimizer(Optimizer): | |||
""" | |||
当不希望Trainer更新optimizer时,传入本optimizer,但请确保通过callback的方式对参数进行了更新。 | |||
""" | |||
def __init__(self): | |||
super().__init__(None) | |||
def construct_from_pytorch(self, model_params): | |||
pass | |||
def __getattr__(self, item): | |||
def pass_func(*args, **kwargs): | |||
pass | |||
return pass_func | |||
class SGD(Optimizer): | |||
""" | |||
@@ -9,7 +9,7 @@ import torch | |||
from . import DataSetIter | |||
from . import DataSet | |||
from . import SequentialSampler | |||
from .utils import _build_args | |||
from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||
class Predictor(object): | |||
@@ -43,6 +43,7 @@ class Predictor(object): | |||
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | |||
self.network.eval() | |||
network_device = _get_model_device(self.network) | |||
batch_output = defaultdict(list) | |||
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||
@@ -53,6 +54,7 @@ class Predictor(object): | |||
with torch.no_grad(): | |||
for batch_x, _ in data_iterator: | |||
_move_dict_value_to_device(batch_x, _, device=network_device) | |||
refined_batch_x = _build_args(predict_func, **batch_x) | |||
prediction = predict_func(**refined_batch_x) | |||
@@ -48,6 +48,7 @@ from .utils import _move_dict_value_to_device | |||
from .utils import _get_func_signature | |||
from .utils import _get_model_device | |||
from .utils import _move_model_to_device | |||
from .utils import _data_parallel_wrapper | |||
__all__ = [ | |||
"Tester" | |||
@@ -104,26 +105,25 @@ class Tester(object): | |||
self.data_iterator = data | |||
else: | |||
raise TypeError("data type {} not support".format(type(data))) | |||
# 如果是DataParallel将没有办法使用predict方法 | |||
if isinstance(self._model, nn.DataParallel): | |||
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): | |||
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," | |||
" while DataParallel has no predict() function.") | |||
self._model = self._model.module | |||
# check predict | |||
if hasattr(self._model, 'predict'): | |||
self._predict_func = self._model.predict | |||
if not callable(self._predict_func): | |||
_model_name = model.__class__.__name__ | |||
raise TypeError(f"`{_model_name}.predict` must be callable to be used " | |||
f"for evaluation, not `{type(self._predict_func)}`.") | |||
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ | |||
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and | |||
callable(self._model.module.predict)): | |||
if isinstance(self._model, nn.DataParallel): | |||
self._predict_func_wrapper = _data_parallel_wrapper(self._model.module.predict, self._model.device_ids, | |||
self._model.output_device) | |||
self._predict_func = self._model.module.predict | |||
else: | |||
self._predict_func = self._model.predict | |||
self._predict_func_wrapper = self._model.predict | |||
else: | |||
if isinstance(model, nn.DataParallel): | |||
if isinstance(self._model, nn.DataParallel): | |||
self._predict_func_wrapper = self._model.forward | |||
self._predict_func = self._model.module.forward | |||
else: | |||
self._predict_func = self._model.forward | |||
self._predict_func_wrapper = self._model.forward | |||
def test(self): | |||
"""开始进行验证,并返回验证结果。 | |||
@@ -180,7 +180,7 @@ class Tester(object): | |||
def _data_forward(self, func, x): | |||
"""A forward pass of the model. """ | |||
x = _build_args(func, **x) | |||
y = func(**x) | |||
y = self._predict_func_wrapper(**x) | |||
return y | |||
def _format_eval_results(self, results): | |||
@@ -452,17 +452,15 @@ class Trainer(object): | |||
else: | |||
raise TypeError("train_data type {} not support".format(type(train_data))) | |||
self.model = _move_model_to_device(model, device=device) | |||
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||
_check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | |||
self.model = _move_model_to_device(model, device=device) | |||
self.train_data = train_data | |||
self.dev_data = dev_data # If None, No validation. | |||
self.model = model | |||
self.losser = losser | |||
self.metrics = metrics | |||
self.n_epochs = int(n_epochs) | |||
@@ -480,16 +478,16 @@ class Trainer(object): | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
elif isinstance(optimizer, Optimizer): | |||
self.optimizer = optimizer.construct_from_pytorch(model.parameters()) | |||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | |||
elif optimizer is None: | |||
self.optimizer = torch.optim.Adam(model.parameters(), lr=4e-3) | |||
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) | |||
else: | |||
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) | |||
self.use_tqdm = use_tqdm | |||
self.pbar = None | |||
self.print_every = abs(self.print_every) | |||
if self.dev_data is not None: | |||
self.tester = Tester(model=self.model, | |||
data=self.dev_data, | |||
@@ -617,7 +615,7 @@ class Trainer(object): | |||
if self.step % self.print_every == 0: | |||
avg_loss = float(avg_loss) / self.print_every | |||
if self.use_tqdm: | |||
print_output = "loss:{0:<6.5f}".format(avg_loss) | |||
print_output = "loss:{:<6.5f}".format(avg_loss) | |||
pbar.update(self.print_every) | |||
else: | |||
end = time.time() | |||
@@ -681,7 +679,7 @@ class Trainer(object): | |||
"""Perform weight update on a model. | |||
""" | |||
if self.optimizer is not None and (self.step + 1) % self.update_every == 0: | |||
if self.step % self.update_every == 0: | |||
self.optimizer.step() | |||
def _data_forward(self, network, x): | |||
@@ -699,7 +697,7 @@ class Trainer(object): | |||
For PyTorch, just do "loss.backward()" | |||
""" | |||
if self.step % self.update_every == 0: | |||
if (self.step-1) % self.update_every == 0: | |||
self.model.zero_grad() | |||
loss.backward() | |||
@@ -16,7 +16,9 @@ from collections import Counter, namedtuple | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather | |||
from torch.nn.parallel.replicate import replicate | |||
from torch.nn.parallel.parallel_apply import parallel_apply | |||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
'varargs']) | |||
@@ -277,6 +279,25 @@ def _move_model_to_device(model, device): | |||
model = model.to(device) | |||
return model | |||
def _data_parallel_wrapper(func, device_ids, output_device): | |||
""" | |||
这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 | |||
:param func: callable | |||
:param device_ids: nn.DataParallel中的device_ids | |||
:param inputs: | |||
:param kwargs: | |||
:return: | |||
""" | |||
def wrapper(*inputs, **kwargs): | |||
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) | |||
if len(device_ids) == 1: | |||
return func(*inputs[0], **kwargs[0]) | |||
replicas = replicate(func, device_ids[:len(inputs)]) | |||
outputs = parallel_apply(replicas, inputs, kwargs) | |||
return gather(outputs, output_device) | |||
return wrapper | |||
def _get_model_device(model): | |||
""" | |||
@@ -38,7 +38,8 @@ class EmbedLoader(BaseLoader): | |||
super(EmbedLoader, self).__init__() | |||
@staticmethod | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, error='ignore'): | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore', init_method=None): | |||
""" | |||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
@@ -52,6 +53,7 @@ class EmbedLoader(BaseLoader): | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 | |||
这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 | |||
:param callable init_method: 传入numpy.ndarray, 返回numpy.ndarray, 用以初始化embedding | |||
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||
@@ -69,6 +71,8 @@ class EmbedLoader(BaseLoader): | |||
dim = len(parts) - 1 | |||
f.seek(0) | |||
matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||
if init_method: | |||
matrix = init_method(matrix) | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
parts = line.strip().split() | |||
@@ -91,14 +95,15 @@ class EmbedLoader(BaseLoader): | |||
raise e | |||
total_hits = sum(hit_flags) | |||
print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab))) | |||
found_vectors = matrix[hit_flags] | |||
if len(found_vectors) != 0: | |||
mean = np.mean(found_vectors, axis=0, keepdims=True) | |||
std = np.std(found_vectors, axis=0, keepdims=True) | |||
unfound_vec_num = len(vocab) - total_hits | |||
r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean | |||
matrix[hit_flags == False] = r_vecs | |||
if init_method is None: | |||
found_vectors = matrix[hit_flags] | |||
if len(found_vectors) != 0: | |||
mean = np.mean(found_vectors, axis=0, keepdims=True) | |||
std = np.std(found_vectors, axis=0, keepdims=True) | |||
unfound_vec_num = len(vocab) - total_hits | |||
r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean | |||
matrix[hit_flags == False] = r_vecs | |||
if normalize: | |||
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||
@@ -157,13 +162,17 @@ class EmbedLoader(BaseLoader): | |||
if dim == -1: | |||
raise RuntimeError("{} is an empty file.".format(embed_filepath)) | |||
matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||
for key, vec in vec_dict.items(): | |||
index = vocab.to_index(key) | |||
matrix[index] = vec | |||
if (unknown is not None and not found_unknown) or (padding is not None and not found_pad): | |||
start_idx = 0 | |||
if padding is not None: | |||
start_idx += 1 | |||
if unknown is not None: | |||
start_idx += 1 | |||
mean = np.mean(matrix[start_idx:], axis=0, keepdims=True) | |||
std = np.std(matrix[start_idx:], axis=0, keepdims=True) | |||
if (unknown is not None and not found_unknown): | |||
@@ -171,10 +180,6 @@ class EmbedLoader(BaseLoader): | |||
if (padding is not None and not found_pad): | |||
matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean | |||
for key, vec in vec_dict.items(): | |||
index = vocab.to_index(key) | |||
matrix[index] = vec | |||
if normalize: | |||
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||
@@ -10,6 +10,37 @@ import shutil | |||
import hashlib | |||
PRETRAINED_BERT_MODEL_DIR = { | |||
'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
} | |||
PRETRAINED_ELMO_MODEL_DIR = { | |||
'en': 'elmo_en-d39843fe.tar.gz', | |||
'cn': 'elmo_cn-5e9b34e2.tar.gz' | |||
} | |||
PRETRAIN_STATIC_FILES = { | |||
'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", | |||
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", | |||
'en-fasttext': "cc.en.300.vec-d53187b2.gz", | |||
'cn': "tencent_cn-dab24577.tar.gz", | |||
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", | |||
} | |||
def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | |||
""" | |||
给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 | |||
@@ -26,6 +26,7 @@ from ...core.dataset import DataSet | |||
from ...core.batch import DataSetIter | |||
from ...core.sampler import SequentialSampler | |||
from ...core.utils import _move_model_to_device, _get_model_device | |||
from ...io.file_utils import PRETRAINED_BERT_MODEL_DIR, PRETRAINED_ELMO_MODEL_DIR, PRETRAIN_STATIC_FILES | |||
class Embedding(nn.Module): | |||
@@ -179,23 +180,14 @@ class StaticEmbedding(TokenEmbedding): | |||
的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d, | |||
`en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 | |||
:param requires_grad: 是否需要gradient. 默认为True | |||
:param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.xavier_uniform_ | |||
。调用该方法时传入一个tensor对象。 | |||
:param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 | |||
:param normailize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None, | |||
normalize=False): | |||
super(StaticEmbedding, self).__init__(vocab) | |||
# 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, | |||
PRETRAIN_STATIC_FILES = { | |||
'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", | |||
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", | |||
'en-fasttext': "cc.en.300.vec-d53187b2.gz", | |||
'cn': "tencent_cn-dab24577.tar.gz", | |||
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", | |||
} | |||
# 得到cache_path | |||
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | |||
@@ -210,7 +202,8 @@ class StaticEmbedding(TokenEmbedding): | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
# 读取embedding | |||
embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method) | |||
embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method, | |||
normalize=normalize) | |||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | |||
padding_idx=vocab.padding_idx, | |||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | |||
@@ -231,7 +224,7 @@ class StaticEmbedding(TokenEmbedding): | |||
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'words_to_words' not in name]) | |||
if 'words_to_words' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
@@ -244,8 +237,8 @@ class StaticEmbedding(TokenEmbedding): | |||
continue | |||
param.requires_grad = value | |||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore', init_method=None): | |||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', | |||
normalize=True, error='ignore', init_method=None): | |||
""" | |||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
@@ -265,10 +258,7 @@ class StaticEmbedding(TokenEmbedding): | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||
if not os.path.exists(embed_filepath): | |||
raise FileNotFoundError("`{}` does not exist.".format(embed_filepath)) | |||
if init_method is None: | |||
init_method = nn.init.xavier_uniform_ | |||
with open(embed_filepath, 'r', encoding='utf-8') as f: | |||
found_count = 0 | |||
line = f.readline().strip() | |||
parts = line.split() | |||
start_idx = 0 | |||
@@ -279,7 +269,8 @@ class StaticEmbedding(TokenEmbedding): | |||
dim = len(parts) - 1 | |||
f.seek(0) | |||
matrix = torch.zeros(len(vocab), dim) | |||
init_method(matrix) | |||
if init_method is not None: | |||
init_method(matrix) | |||
hit_flags = np.zeros(len(vocab), dtype=bool) | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
@@ -294,7 +285,6 @@ class StaticEmbedding(TokenEmbedding): | |||
if word in vocab: | |||
index = vocab.to_index(word) | |||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | |||
found_count += 1 | |||
hit_flags[index] = True | |||
except Exception as e: | |||
if error == 'ignore': | |||
@@ -302,7 +292,16 @@ class StaticEmbedding(TokenEmbedding): | |||
else: | |||
print("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
found_count = sum(hit_flags) | |||
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
if init_method is None: | |||
if len(vocab)-found_count>0 and found_count>0: # 有的没找到 | |||
found_vecs = matrix[torch.LongTensor(hit_flags.astype(int)).byte()] | |||
mean = found_vecs.mean(dim=0, keepdim=True) | |||
std = found_vecs.std(dim=0, keepdim=True) | |||
unfound_vec_num = np.sum(hit_flags==False) | |||
unfound_vecs = torch.randn(unfound_vec_num, dim)*std + mean | |||
matrix[torch.LongTensor(hit_flags.astype(int)).eq(0)] = unfound_vecs | |||
if normalize: | |||
matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12) | |||
@@ -329,11 +328,6 @@ class ContextualEmbedding(TokenEmbedding): | |||
""" | |||
由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 | |||
Example:: | |||
>>> | |||
:param datasets: DataSet对象 | |||
:param batch_size: int, 生成cache的sentence表示时使用的batch的大小 | |||
:param device: 参考 :class::fastNLP.Trainer 的device | |||
@@ -363,7 +357,7 @@ class ContextualEmbedding(TokenEmbedding): | |||
seq_len = words.ne(pad_index).sum(dim=-1) | |||
max_len = words.size(1) | |||
# 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。 | |||
seq_len_from_behind =(max_len - seq_len).tolist() | |||
seq_len_from_behind = (max_len - seq_len).tolist() | |||
word_embeds = self(words).detach().cpu().numpy() | |||
for b in range(words.size(0)): | |||
length = seq_len_from_behind[b] | |||
@@ -446,9 +440,6 @@ class ElmoEmbedding(ContextualEmbedding): | |||
self.layers = layers | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | |||
'cn': 'elmo_cn-5e9b34e2.tar.gz'} | |||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('elmo') | |||
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | |||
@@ -532,21 +523,8 @@ class BertEmbedding(ContextualEmbedding): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | |||
super(BertEmbedding, self).__init__(vocab) | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
} | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
@@ -73,21 +73,12 @@ class LSTM(nn.Module): | |||
x = x[:, sort_idx] | |||
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) | |||
output, hx = self.lstm(x, hx) # -> [N,L,C] | |||
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first) | |||
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
if self.batch_first: | |||
output = output[unsort_idx] | |||
else: | |||
output = output[:, unsort_idx] | |||
# 解决LSTM无法在DataParallel下使用的问题问题https://github.com/pytorch/pytorch/issues/1591 | |||
if self.batch_first: | |||
if output.size(1) < max_len: | |||
dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1)) | |||
output = torch.cat([output, dummy_tensor], 1) | |||
else: | |||
if output.size(0) < max_len: | |||
dummy_tensor = output.new_zeros(max_len - output.size(1), batch_size, output.size(-1)) | |||
output = torch.cat([output, dummy_tensor], 0) | |||
else: | |||
output, hx = self.lstm(x, hx) | |||
return output, hx |
@@ -82,6 +82,8 @@ def get_embeddings(init_embed): | |||
if isinstance(init_embed, tuple): | |||
res = nn.Embedding( | |||
num_embeddings=init_embed[0], embedding_dim=init_embed[1]) | |||
nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)), | |||
b=np.sqrt(3/res.weight.data.size(1))) | |||
elif isinstance(init_embed, nn.Module): | |||
res = init_embed | |||
elif isinstance(init_embed, torch.Tensor): | |||
@@ -1,36 +1,63 @@ | |||
import os | |||
from nltk import Tree | |||
from typing import Union, Dict | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.io.base_loader import DataInfo | |||
from fastNLP.io.dataset_loader import JsonLoader | |||
from fastNLP.io.file_utils import _get_base_url, cached_path | |||
from fastNLP.io.dataset_loader import JsonLoader, DataSetLoader, CSVLoader | |||
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from fastNLP.modules.encoder._bert import BertTokenizer | |||
class MatchingLoader(JsonLoader): | |||
class MatchingLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||
读取Matching任务的数据集 | |||
""" | |||
def __init__(self, fields=None, paths: dict=None): | |||
super(MatchingLoader, self).__init__(fields=fields) | |||
def __init__(self, paths: dict=None): | |||
""" | |||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||
""" | |||
self.paths = paths | |||
def _load(self, path): | |||
return super(MatchingLoader, self)._load(path) | |||
def process(self, paths: Union[str, Dict[str, str]], dataset_name=None, | |||
to_lower=False, char_information=False, seq_len_type: str=None, | |||
bert_tokenizer: str=None, get_index=True, set_input: Union[list, str, bool]=True, | |||
""" | |||
:param str path: 待读取数据集的路径名 | |||
:return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 | |||
的原始字符串文本,第三个为标签 | |||
""" | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||
cut_text: int = None, get_index=True, set_input: Union[list, str, bool]=True, | |||
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | |||
""" | |||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||
对应的全路径文件名。 | |||
:param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 | |||
这个数据集的名字,如果不定义则默认为train。 | |||
:param bool to_lower: 是否将文本自动转为小写。默认值为False。 | |||
:param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : | |||
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 | |||
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len | |||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||
:param bool get_index: 是否需要根据词表将文本转为index | |||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||
于此同时其他field不会被设置为input。默认值为True。 | |||
:param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 | |||
:param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。 | |||
如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 | |||
传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. | |||
:return: | |||
""" | |||
if isinstance(set_input, str): | |||
set_input = [set_input] | |||
if isinstance(set_target, str): | |||
@@ -59,7 +86,8 @@ class MatchingLoader(JsonLoader): | |||
if auto_set_input: | |||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||
if auto_set_target: | |||
data_set.set_target(Const.TARGET) | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.set_target(Const.TARGET) | |||
if to_lower: | |||
for data_name, data_set in data_info.datasets.items(): | |||
@@ -69,19 +97,6 @@ class MatchingLoader(JsonLoader): | |||
is_input=auto_set_input) | |||
if bert_tokenizer is not None: | |||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
} | |||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||
@@ -93,6 +108,13 @@ class MatchingLoader(JsonLoader): | |||
else: | |||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||
lines = f.readlines() | |||
lines = [line.strip() for line in lines] | |||
words_vocab.add_word_lst(lines) | |||
words_vocab.build_vocab() | |||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||
for data_name, data_set in data_info.datasets.items(): | |||
@@ -128,14 +150,14 @@ class MatchingLoader(JsonLoader): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.TARGET), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'mask': | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [1] * len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.TARGET), | |||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'bert': | |||
for data_name, data_set in data_info.datasets.items(): | |||
@@ -147,18 +169,26 @@ class MatchingLoader(JsonLoader): | |||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||
if cut_text is not None: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||
is_input=auto_set_input) | |||
data_set_list = [d for n, d in data_info.datasets.items()] | |||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||
if bert_tokenizer is not None: | |||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||
else: | |||
if bert_tokenizer is None: | |||
words_vocab = Vocabulary() | |||
words_vocab = words_vocab.from_dataset(*data_set_list, | |||
field_name=[n for n in data_set_list[0].get_field_names() | |||
if (Const.INPUT in n)]) | |||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=[n for n in data_set_list[0].get_field_names() | |||
if (Const.INPUT in n)], | |||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||
if 'train' not in n]) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET) | |||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||
field_name=Const.TARGET) | |||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||
if get_index: | |||
@@ -168,19 +198,20 @@ class MatchingLoader(JsonLoader): | |||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||
is_input=auto_set_input) | |||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
is_input=auto_set_input, is_target=auto_set_target) | |||
if Const.TARGET in data_set.get_field_names(): | |||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
is_input=auto_set_input, is_target=auto_set_target) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if isinstance(set_input, list): | |||
data_set.set_input(set_input) | |||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||
if isinstance(set_target, list): | |||
data_set.set_target(set_target) | |||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||
return data_info | |||
class SNLILoader(MatchingLoader): | |||
class SNLILoader(MatchingLoader, JsonLoader): | |||
""" | |||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||
@@ -195,29 +226,166 @@ class SNLILoader(MatchingLoader): | |||
def __init__(self, paths: dict=None): | |||
fields = { | |||
'sentence1_parse': Const.INPUTS(0), | |||
'sentence2_parse': Const.INPUTS(1), | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
paths = paths if paths is not None else { | |||
'train': 'snli_1.0_train.jsonl', | |||
'dev': 'snli_1.0_dev.jsonl', | |||
'test': 'snli_1.0_test.jsonl'} | |||
super(SNLILoader, self).__init__(fields=fields, paths=paths) | |||
MatchingLoader.__init__(self, paths=paths) | |||
JsonLoader.__init__(self, fields=fields) | |||
def _load(self, path): | |||
ds = super(SNLILoader, self)._load(path) | |||
ds = JsonLoader._load(self, path) | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
class RTELoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader` | |||
读取RTE数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
# 'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'sentence1': Const.INPUTS(0), | |||
'sentence2': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds | |||
class QNLILoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader` | |||
读取QNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
# 'test': 'test.tsv' # test set has not label | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
self.fields = { | |||
'question': Const.INPUTS(0), | |||
'sentence': Const.INPUTS(1), | |||
'label': Const.TARGET, | |||
} | |||
CSVLoader.__init__(self, sep='\t') | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
ds.rename_field(k, v) | |||
for fields in ds.get_all_fields(): | |||
if Const.INPUT in fields: | |||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||
return ds | |||
class MNLILoader(MatchingLoader, CSVLoader): | |||
""" | |||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader` | |||
读取SNLI数据集,读取的DataSet包含fields:: | |||
words1: list(str),第一句文本, premise | |||
words2: list(str), 第二句文本, hypothesis | |||
target: str, 真实标签 | |||
数据来源: | |||
""" | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev_matched': 'dev_matched.tsv', | |||
'dev_mismatched': 'dev_mismatched.tsv', | |||
'test_matched': 'test_matched.tsv', | |||
'test_mismatched': 'test_mismatched.tsv', | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t') | |||
self.fields = { | |||
'sentence1_binary_parse': Const.INPUTS(0), | |||
'sentence2_binary_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
for k, v in self.fields.items(): | |||
if k in ds.get_field_names(): | |||
ds.rename_field(k, v) | |||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||
new_field_name=Const.INPUTS(1)) | |||
if Const.TARGET in ds.get_field_names(): | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
class QuoraLoader(MatchingLoader, CSVLoader): | |||
def __init__(self, paths: dict=None): | |||
paths = paths if paths is not None else { | |||
'train': 'train.tsv', | |||
'dev': 'dev.tsv', | |||
'test': 'test.tsv', | |||
} | |||
MatchingLoader.__init__(self, paths=paths) | |||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||
def _load(self, path): | |||
ds = CSVLoader._load(self, path) | |||
return ds |
@@ -1,44 +0,0 @@ | |||
import os | |||
import torch | |||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||
from fastNLP.io.dataset_loader import MatchingLoader | |||
from reproduction.matching.model.bert import BertForNLI | |||
from reproduction.matching.model.esim import ESIMModel | |||
bert_dirs = 'path/to/bert/dir' | |||
# load data set | |||
# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(... | |||
data_info = MatchingLoader(data_format='snli', for_model='esim').process( | |||
{'train': './data/snli/snli_1.0_train.jsonl', | |||
'dev': './data/snli/snli_1.0_dev.jsonl', | |||
'test': './data/snli/snli_1.0_test.jsonl'}, | |||
input_field=[Const.TARGET] | |||
) | |||
# model = BertForNLI(bert_dir=bert_dirs) | |||
model = ESIMModel(data_info.embeddings['elmo'],) | |||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||
optimizer=Adam(lr=1e-4, model_params=model.parameters()), | |||
batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1, | |||
dev_data=data_info.datasets['dev'], | |||
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1) | |||
trainer.train(load_best_model=True) | |||
tester = Tester( | |||
data=data_info.datasets['test'], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * 12, | |||
device=[i for i in range(torch.cuda.device_count())], | |||
) | |||
tester.test() | |||
@@ -0,0 +1,65 @@ | |||
import argparse | |||
import torch | |||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||
from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding | |||
from reproduction.matching.data.MatchingDataLoader import SNLILoader | |||
from reproduction.matching.model.esim import ESIMModel | |||
argument = argparse.ArgumentParser() | |||
argument.add_argument('--embedding', choices=['glove', 'elmo'], default='glove') | |||
argument.add_argument('--batch-size-per-gpu', type=int, default=128) | |||
argument.add_argument('--n-epochs', type=int, default=100) | |||
argument.add_argument('--lr', type=float, default=1e-4) | |||
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='seq_len') | |||
argument.add_argument('--save-dir', type=str, default=None) | |||
arg = argument.parse_args() | |||
bert_dirs = 'path/to/bert/dir' | |||
# load data set | |||
data_info = SNLILoader().process( | |||
paths='path/to/snli/data/dir', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None, | |||
get_index=True, concat=False, | |||
) | |||
# load embedding | |||
if arg.embedding == 'elmo': | |||
embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||
elif arg.embedding == 'glove': | |||
embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True) | |||
else: | |||
raise ValueError(f'now we only support elmo or glove embedding for esim model!') | |||
# define model | |||
model = ESIMModel(embedding) | |||
# define trainer | |||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||
optimizer=Adam(lr=arg.lr, model_params=model.parameters()), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
n_epochs=arg.n_epochs, print_every=-1, | |||
dev_data=data_info.datasets['dev'], | |||
metrics=AccuracyMetric(), metric_key='acc', | |||
device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1, | |||
save_path=arg.save_path) | |||
# train model | |||
trainer.train(load_best_model=True) | |||
# define tester | |||
tester = Tester( | |||
data=data_info.datasets['test'], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu, | |||
device=[i for i in range(torch.cuda.device_count())], | |||
) | |||
# test model | |||
tester.test() | |||
@@ -30,24 +30,37 @@ class ESIMModel(BaseModel): | |||
self.bi_attention = SoftmaxAttention() | |||
self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) | |||
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True) | |||
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True,) | |||
self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), | |||
nn.Linear(8 * hidden_size, hidden_size), | |||
nn.Tanh(), | |||
nn.Dropout(p=dropout_rate), | |||
nn.Linear(hidden_size, num_labels)) | |||
self.dropout_rnn = nn.Dropout(p=dropout_rate) | |||
nn.init.xavier_uniform_(self.classifier[1].weight.data) | |||
nn.init.xavier_uniform_(self.classifier[4].weight.data) | |||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | |||
mask1 = seq_len_to_mask(seq_len1) | |||
mask2 = seq_len_to_mask(seq_len2) | |||
""" | |||
:param words1: [batch, seq_len] | |||
:param words2: [batch, seq_len] | |||
:param seq_len1: [batch] | |||
:param seq_len2: [batch] | |||
:param target: | |||
:return: | |||
""" | |||
mask1 = seq_len_to_mask(seq_len1, words1.size(1)) | |||
mask2 = seq_len_to_mask(seq_len2, words2.size(1)) | |||
a0 = self.embedding(words1) # B * len * emb_dim | |||
b0 = self.embedding(words2) | |||
a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) | |||
a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] | |||
b = self.rnn(b0, mask2.byte()) | |||
# a = self.dropout_rnn(self.rnn(a0, seq_len1)[0]) # a: [B, PL, 2 * H] | |||
# b = self.dropout_rnn(self.rnn(b0, seq_len2)[0]) | |||
ai, bi = self.bi_attention(a, mask1, b, mask2) | |||
@@ -58,6 +71,8 @@ class ESIMModel(BaseModel): | |||
a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] | |||
b_h = self.rnn_high(b_f, mask2.byte()) | |||
# a_h = self.dropout_rnn(self.rnn_high(a_f, seq_len1)[0]) # ma: [B, PL, 2 * H] | |||
# b_h = self.dropout_rnn(self.rnn_high(b_f, seq_len2)[0]) | |||
a_avg = self.mean_pooling(a_h, mask1, dim=1) | |||
a_max, _ = self.max_pooling(a_h, mask1, dim=1) | |||
@@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f: | |||
setup( | |||
name='FastNLP', | |||
version='0.4.0', | |||
version='dev0.5.0', | |||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | |||
long_description=readme, | |||
long_description_content_type='text/markdown', | |||