tags/v0.4.10

- add ConllLoader, for all kinds of conll-format files
- add JsonLoader, for json-format files
- add SSTLoader, for SST-2 & SST-5
- change the Callback interface
- fix the batch multi-process fetcher hanging when the main process is killed
- add a README listing models and their performance
@@ -6,7 +6,7 @@
 ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
 [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
 
-FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.
+FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.
 
 A deep learning NLP model is the composition of three types of modules:
 <table>

@@ -58,6 +58,13 @@ Run the following commands to install fastNLP package.
 pip install fastNLP
 ```
+## Models
+fastNLP implements different models for various NLP tasks.
+Each model has been trained and tested carefully.
+Check out the models' performance, usage, and source code here.
+- [Documentation](https://github.com/fastnlp/fastNLP/tree/master/reproduction)
+- [Source Code](https://github.com/fastnlp/fastNLP/tree/master/fastNLP/models)
 
 ## Project Structure

@@ -10,4 +10,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .vocabulary import Vocabulary
 from ..io.dataset_loader import DataSet
+from .callback import Callback

@@ -1,9 +1,16 @@
 import numpy as np
 import torch
+import atexit
 
 from fastNLP.core.sampler import RandomSampler
 import torch.multiprocessing as mp
 
+_python_is_exit = False
+
+
+def _set_python_is_exit():
+    global _python_is_exit
+    _python_is_exit = True
+
+
+atexit.register(_set_python_is_exit)
 
 
 class Batch(object):
     """Batch is an iterable object which iterates over mini-batches.
@@ -95,12 +102,19 @@ def to_tensor(batch, dtype):
 
 def run_fetch(batch, q):
+    global _python_is_exit
     batch.init_iter()
     # print('start fetch')
     while 1:
         res = batch.fetch_one()
         # print('fetch one')
-        q.put(res)
+        while 1:
+            try:
+                q.put(res, timeout=3)
+                break
+            except Exception as e:
+                if _python_is_exit:
+                    return
         if res is None:
             # print('fetch done, waiting processing')
             q.join()
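
The two hunks above are the fix for the fetch worker hanging when the main process is killed: an `atexit` hook flips a module-level flag, and the producer retries `q.put` with a timeout so it can observe the flag and return instead of blocking forever on a full queue. A minimal self-contained sketch of the same pattern (the names and the 3-second timeout are illustrative, not fastNLP's API):

```python
import atexit
import queue

_shutting_down = False

def _mark_exit():
    global _shutting_down
    _shutting_down = True

atexit.register(_mark_exit)

def producer(q, items):
    """Feed items into q, giving up cleanly once the interpreter is exiting."""
    for item in items:
        while True:
            try:
                q.put(item, timeout=3)  # bounded wait instead of blocking forever
                break
            except queue.Full:
                if _shutting_down:  # consumer side is gone; stop retrying
                    return
```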

@@ -15,13 +15,57 @@ class Callback(object):
 
     def __init__(self):
         super(Callback, self).__init__()
-        self.trainer = None  # reassigned inside Trainer
+        self._trainer = None  # reassigned inside Trainer
+
+    @property
+    def trainer(self):
+        return self._trainer
+
+    @property
+    def step(self):
+        """current step number, in range(1, self.n_steps+1)"""
+        return self._trainer.step
+
+    @property
+    def n_steps(self):
+        """total number of steps for training"""
+        return self._trainer.n_steps
+
+    @property
+    def batch_size(self):
+        """batch size for training"""
+        return self._trainer.batch_size
+
+    @property
+    def epoch(self):
+        """current epoch number, in range(1, self.n_epochs+1)"""
+        return self._trainer.epoch
+
+    @property
+    def n_epochs(self):
+        """total number of epochs"""
+        return self._trainer.n_epochs
+
+    @property
+    def optimizer(self):
+        """torch.optim.Optimizer for the current model"""
+        return self._trainer.optimizer
+
+    @property
+    def model(self):
+        """the model being trained"""
+        return self._trainer.model
+
+    @property
+    def pbar(self):
+        """If use_tqdm, return the trainer's tqdm progress bar, else return None."""
+        return self._trainer.pbar
 
     def on_train_begin(self):
         # before the main training loop
         pass
 
-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         # at the beginning of each epoch
         pass
 
@@ -33,14 +77,14 @@ class Callback(object):
         # after data_forward, and before loss computation
         pass
 
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         # after loss computation, and before gradient backward
         pass
 
-    def on_backward_end(self, model):
+    def on_backward_end(self):
         pass
 
-    def on_step_end(self, optimizer):
+    def on_step_end(self):
         pass
 
     def on_batch_end(self, *args):

@@ -50,65 +94,36 @@ class Callback(object):
 
     def on_valid_begin(self):
         pass
 
-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key):
         """
         Called after each evaluation on the validation set; eval_result is passed in.
 
         :param eval_result: Dict[str: Dict[str: float]], the evaluation results
         :param metric_key: str
-        :param optimizer:
         :return:
         """
         pass
 
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
         """
         Called at the end of every epoch.
-
-        :param cur_epoch: int, the current epoch, starting from 1
-        :param n_epoch: int, the total number of epochs
-        :param optimizer: the optimizer passed to Trainer
-        :return:
         """
         pass
 
-    def on_train_end(self, model):
+    def on_train_end(self):
         """
         Called when training finishes.
-
-        :param model: nn.Module, the model passed to Trainer
-        :return:
         """
         pass
 
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         """
         Triggered when an exception occurs during training.
 
         :param exception: an Exception of some type, e.g. KeyboardInterrupt
-        :param model: the model passed to Trainer
-        :return:
         """
         pass
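
Under the new interface, hooks no longer receive trainer objects as arguments; a callback reads trainer state through the properties above. A minimal sketch of a user-defined callback (the class name and logging behavior are illustrative, not part of fastNLP):

```python
from fastNLP.core.callback import Callback

class PrintLossCallback(Callback):
    """Log the running loss every 100 steps via the property-based interface."""

    def on_backward_begin(self, loss):
        if self.step % 100 == 0:
            msg = "epoch {}/{}, step {}/{}: loss={:.4f}".format(
                self.epoch, self.n_epochs, self.step, self.n_steps, loss.item())
            if self.pbar is not None:  # the trainer's tqdm bar when use_tqdm=True
                self.pbar.write(msg)
            else:
                print(msg)
```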
 
-def transfer(func):
-    """Decorator that forwards a call on CallbackManager to every Callback subclass.
-
-    :param func:
-    :return:
-    """
-
-    def wrapper(manager, *arg):
-        returns = []
-        for callback in manager.callbacks:
-            for env_name, env_value in manager.env.items():
-                setattr(callback, env_name, env_value)
-            returns.append(getattr(callback, func.__name__)(*arg))
-        return returns
-
-    return wrapper
-
 
 class CallbackManager(Callback):
     """A manager for all callbacks passed into Trainer.
     It collects resources inside Trainer and raises callbacks.

@@ -119,7 +134,7 @@ class CallbackManager(Callback):
         """
         :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
-        :param Callback callbacks:
+        :param List[Callback] callbacks:
         """
         super(CallbackManager, self).__init__()
 
         # set attribute of trainer environment
@@ -136,56 +151,43 @@ class CallbackManager(Callback):
         else:
             raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
 
-    @transfer
     def on_train_begin(self):
         pass
 
-    @transfer
-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         pass
 
-    @transfer
     def on_batch_begin(self, batch_x, batch_y, indices):
         pass
 
-    @transfer
     def on_loss_begin(self, batch_y, predict_y):
         pass
 
-    @transfer
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         pass
 
-    @transfer
-    def on_backward_end(self, model):
+    def on_backward_end(self):
         pass
 
-    @transfer
-    def on_step_end(self, optimizer):
+    def on_step_end(self):
         pass
 
-    @transfer
     def on_batch_end(self):
         pass
 
-    @transfer
     def on_valid_begin(self):
         pass
 
-    @transfer
-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key):
         pass
 
-    @transfer
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
        pass
 
-    @transfer
-    def on_train_end(self, model):
+    def on_train_end(self):
         pass
 
-    @transfer
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         pass
@@ -193,15 +195,15 @@ class DummyCallback(Callback):
     def on_train_begin(self, *arg):
         print(arg)
 
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
-        print(cur_epoch, n_epoch, optimizer)
+    def on_epoch_end(self):
+        print(self.epoch, self.n_epochs)
 
 class EchoCallback(Callback):
     def on_train_begin(self):
         print("before_train")
 
-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         print("before_epoch")
 
     def on_batch_begin(self, batch_x, batch_y, indices):

@@ -210,16 +212,16 @@ class EchoCallback(Callback):
     def on_loss_begin(self, batch_y, predict_y):
         print("before_loss")
 
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         print("before_backward")
 
     def on_batch_end(self):
         print("after_batch")
 
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
         print("after_epoch")
 
-    def on_train_end(self, model):
+    def on_train_end(self):
         print("after_train")
@@ -247,8 +249,8 @@ class GradientClipCallback(Callback):
         self.parameters = parameters
         self.clip_value = clip_value
 
-    def on_backward_end(self, model):
-        self.clip_fun(model.parameters(), self.clip_value)
+    def on_backward_end(self):
+        self.clip_fun(self.model.parameters(), self.clip_value)
 
 
 class CallbackException(BaseException):
@@ -268,13 +270,10 @@ class EarlyStopCallback(Callback):
         :param int patience: number of epochs to wait before stopping
         """
         super(EarlyStopCallback, self).__init__()
-        self.trainer = None  # override by CallbackManager
         self.patience = patience
         self.wait = 0
-        self.epoch = 0
 
-    def on_valid_end(self, eval_result, metric_key, optimizer):
-        self.epoch += 1
+    def on_valid_end(self, eval_result, metric_key):
         if not self.trainer._better_eval_result(eval_result):
             # current result is getting worse
             if self.wait == self.patience:

@@ -284,7 +283,7 @@ class EarlyStopCallback(Callback):
         else:
             self.wait = 0
 
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         if isinstance(exception, EarlyStopError):
             print("Early Stopping triggered in epoch {}!".format(self.epoch))
         else:
@@ -304,9 +303,9 @@ class LRScheduler(Callback):
         else:
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
 
-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         self.scheduler.step()
-        print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])
+        print("scheduler step ", "lr=", self.optimizer.param_groups[0]["lr"])
 
 
 class ControlC(Callback):

@@ -320,7 +319,7 @@ class ControlC(Callback):
             raise ValueError("In KeyBoardInterrupt, quit_all argument must be a bool.")
         self.quit_all = quit_all
 
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         if isinstance(exception, KeyboardInterrupt):
             if self.quit_all is True:
                 import sys
@@ -366,15 +365,15 @@ class LRFinder(Callback):
         self.find = None
         self.loader = ModelLoader()
 
-    def on_epoch_begin(self, cur_epoch, total_epoch):
-        if cur_epoch == 1:
+    def on_epoch_begin(self):
+        if self.epoch == 1:  # first epoch
             self.opt = self.trainer.optimizer  # pytorch optimizer
             self.opt.param_groups[0]["lr"] = self.start_lr
             # save model
             ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
             self.find = True
 
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         if self.find:
             if torch.isnan(loss) or self.stop is True:
                 self.stop = True

@@ -395,8 +394,8 @@ class LRFinder(Callback):
             self.opt.param_groups[0]["lr"] = lr
         # self.loader.load_pytorch(self.trainer.model, "tmp")
 
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
-        if cur_epoch == 1:
+    def on_epoch_end(self):
+        if self.epoch == 1:  # first epoch
             self.opt.param_groups[0]["lr"] = self.best_lr
             self.find = False
             # reset model
@@ -440,7 +439,7 @@ class TensorboardCallback(Callback):
             # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
             self.graph_added = True
 
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         if "loss" in self.options:
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

@@ -452,18 +451,18 @@ class TensorboardCallback(Callback):
                     self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
                                                     global_step=self.trainer.step)
 
-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key):
         if "metric" in self.options:
             for name, metric in eval_result.items():
                 for metric_key, metric_val in metric.items():
                     self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                     global_step=self.trainer.step)
 
-    def on_train_end(self, model):
+    def on_train_end(self):
         self._summary_writer.close()
         del self._summary_writer
 
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         if hasattr(self, "_summary_writer"):
             self._summary_writer.close()
             del self._summary_writer
@@ -471,5 +470,5 @@ class TensorboardCallback(Callback):
 
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.on_train_begin(10, 11, 12)
+    manager.on_train_begin()
     # print(manager.after_epoch())

@@ -122,6 +122,8 @@ class Trainer(object):
         self.sampler = sampler
         self.prefetch = prefetch
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
+        self.n_steps = (len(self.train_data) // self.batch_size + int(
+            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
 
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
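
The new `n_steps` attribute is ceiling division of the dataset size by the batch size, times the number of epochs. An equivalent formulation, assuming positive integer sizes (variable names are illustrative):

```python
import math

def total_steps(n_examples, batch_size, n_epochs):
    # one step per (possibly partial) batch, repeated every epoch
    return math.ceil(n_examples / batch_size) * n_epochs

assert total_steps(10, 4, 3) == 9  # 3 batches per epoch (4 + 4 + 2), times 3 epochs
```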
@@ -129,6 +131,7 @@ class Trainer(object):
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
 
         self.use_tqdm = use_tqdm
+        self.pbar = None
         self.print_every = abs(self.print_every)
 
         if self.dev_data is not None:
@@ -198,9 +201,9 @@ class Trainer(object):
             try:
                 self.callback_manager.on_train_begin()
                 self._train()
-                self.callback_manager.on_train_end(self.model)
+                self.callback_manager.on_train_end()
             except (CallbackException, KeyboardInterrupt) as e:
-                self.callback_manager.on_exception(e, self.model)
+                self.callback_manager.on_exception(e)
 
             if self.dev_data is not None:
                 print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -227,18 +230,21 @@ class Trainer(object):
         else:
             inner_tqdm = tqdm
         self.step = 0
+        self.epoch = 0
         start = time.time()
-        total_steps = (len(self.train_data) // self.batch_size + int(
-            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
-        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+            self.pbar = pbar if isinstance(pbar, tqdm) else None
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
             for epoch in range(1, self.n_epochs+1):
+                self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
-                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
+                self.callback_manager.on_epoch_begin()
                 for batch_x, batch_y in data_iterator:
+                    self.step += 1
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
@@ -251,14 +257,14 @@ class Trainer(object):
                     avg_loss += loss.item()
 
                     # Is loss NaN or inf? requires_grad = False
-                    self.callback_manager.on_backward_begin(loss, self.model)
+                    self.callback_manager.on_backward_begin(loss)
                     self._grad_backward(loss)
-                    self.callback_manager.on_backward_end(self.model)
+                    self.callback_manager.on_backward_end()
 
                     self._update()
-                    self.callback_manager.on_step_end(self.optimizer)
+                    self.callback_manager.on_step_end()
 
-                    if (self.step+1) % self.print_every == 0:
+                    if self.step % self.print_every == 0:
                         if self.use_tqdm:
                             print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                             pbar.update(self.print_every)
@@ -269,7 +275,6 @@ class Trainer(object):
                                 epoch, self.step, avg_loss, diff)
                         pbar.set_postfix_str(print_output)
                         avg_loss = 0
-                    self.step += 1
                     self.callback_manager.on_batch_end()
 
                     if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
@@ -277,16 +282,17 @@ class Trainer(object):
                             and self.dev_data is not None:
                         eval_res = self._do_validation(epoch=epoch, step=self.step)
                         eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                                    total_steps) + \
+                                                                                    self.n_steps) + \
                                    self.tester._format_eval_results(eval_res)
                         pbar.write(eval_str)
 
                 # ================= mini-batch end ==================== #
 
                 # lr decay; early stopping
-                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
+                self.callback_manager.on_epoch_end()
             # =============== epochs end =================== #
             pbar.close()
+            self.pbar = None
         # ============ tqdm end ============== #
 
     def _do_validation(self, epoch, step):
@@ -303,7 +309,7 @@ class Trainer(object):
                 self.best_dev_epoch = epoch
                 self.best_dev_step = step
         # get validation results; adjust optimizer
-        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
+        self.callback_manager.on_valid_end(res, self.metric_key)
         return res
 
     def _mode(self, model, is_test=False):

@@ -1,4 +1,5 @@
 import os
+import json
 
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -64,6 +65,53 @@ def convert_seq2seq_dataset(data):
     return dataset
 
 
+def download_from_url(url, path):
+    """Download a file from url to path."""
+    from tqdm import tqdm
+    import requests
+
+    r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
+    chunk_size = 16 * 1024
+    total_size = int(r.headers.get('Content-length', 0))
+    with open(path, "wb") as file, \
+            tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
+        for chunk in r.iter_content(chunk_size):
+            if chunk:
+                file.write(chunk)
+                t.update(len(chunk))
+    return
+
+
+def uncompress(src, dst):
+    import zipfile, gzip, tarfile, os
+
+    def unzip(src, dst):
+        with zipfile.ZipFile(src, 'r') as f:
+            f.extractall(dst)
+
+    def ungz(src, dst):
+        with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
+            length = 16 * 1024  # 16KB
+            buf = f.read(length)
+            while buf:
+                uf.write(buf)
+                buf = f.read(length)
+
+    def untar(src, dst):
+        with tarfile.open(src, 'r:gz') as f:
+            f.extractall(dst)
+
+    fn, ext = os.path.splitext(src)
+    _, ext_2 = os.path.splitext(fn)
+    if ext == '.zip':
+        unzip(src, dst)
+    elif ext == '.gz' and ext_2 != '.tar':
+        ungz(src, dst)
+    elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
+        untar(src, dst)
+    else:
+        raise ValueError('unsupported file {}'.format(src))
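
A sketch of combining the two helpers to fetch the SST archive referenced by SSTLoader further down (the URL comes from that loader's docstring; the local paths are illustrative):

```python
url = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'
download_from_url(url, 'trainDevTestTrees_PTB.zip')
uncompress('trainDevTestTrees_PTB.zip', 'data/')  # dispatches on extension: .zip / .gz / .tar.gz / .tgz
```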
 
 class DataSetLoader:
     """Interface for all DataSetLoaders.

@@ -290,41 +338,6 @@ class DummyClassificationReader(DataSetLoader):
         return convert_seq2tag_dataset(data)
 
 
-class ConllLoader(DataSetLoader):
-    """loader for conll format files"""
-
-    def __init__(self):
-        super(ConllLoader, self).__init__()
-
-    def load(self, data_path):
-        with open(data_path, "r", encoding="utf-8") as f:
-            lines = f.readlines()
-        data = self.parse(lines)
-        return self.convert(data)
-
-    @staticmethod
-    def parse(lines):
-        """
-        :param list lines: a list containing all lines in a conll file.
-        :return: a 3D list
-        """
-        sentences = list()
-        tokens = list()
-        for line in lines:
-            if line[0] == "#":
-                # skip the comments
-                continue
-            if line == "\n":
-                sentences.append(tokens)
-                tokens = []
-                continue
-            tokens.append(line.split())
-        return sentences
-
-    def convert(self, data):
-        pass
-
-
 class DummyLMReader(DataSetLoader):
     """A Dummy Language Model Dataset Reader
     """
@@ -434,51 +447,67 @@ class PeopleDailyCorpusLoader(DataSetLoader):
         return data_set
 
 
-class Conll2003Loader(DataSetLoader):
+class ConllLoader:
+    def __init__(self, headers, indexs=None):
+        self.headers = headers
+        if indexs is None:
+            self.indexs = list(range(len(self.headers)))
+        else:
+            if len(indexs) != len(headers):
+                raise ValueError('the number of indexs must equal the number of headers')
+            self.indexs = indexs
+
+    def load(self, path):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            start = next(f)
+            if '-DOCSTART-' not in start:
+                sample.append(start.split())
+            for line in f:
+                if line.startswith('\n'):
+                    if len(sample):
+                        datalist.append(sample)
+                        sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split())
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        data = [self.get_one(sample) for sample in datalist]
+        data = filter(lambda x: x is not None, data)
+
+        ds = DataSet()
+        for sample in data:
+            ins = Instance()
+            for name, idx in zip(self.headers, self.indexs):
+                ins.add_field(field_name=name, field=sample[idx])
+            ds.append(ins)
+        return ds
+
+    def get_one(self, sample):
+        sample = list(map(list, zip(*sample)))
+        for field in sample:
+            if len(field) <= 0:
+                return None
+        return sample
+
+
+class Conll2003Loader(ConllLoader):
     """Loader for conll2003 dataset
 
     More information about the given dataset could be found on
     https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
+
+    Deprecated. Use ConllLoader for all types of conll-format files.
     """
 
     def __init__(self):
-        super(Conll2003Loader, self).__init__()
-
-    def load(self, dataset_path):
-        with open(dataset_path, "r", encoding="utf-8") as f:
-            lines = f.readlines()
-        parsed_data = []
-        sentence = []
-        tokens = []
-        for line in lines:
-            if '-DOCSTART- -X- -X- O' in line or line == '\n':
-                if sentence != []:
-                    parsed_data.append((sentence, tokens))
-                sentence = []
-                tokens = []
-                continue
-            temp = line.strip().split(" ")
-            sentence.append(temp[0])
-            tokens.append(temp[1:4])
-        return self.convert(parsed_data)
-
-    def convert(self, parsed_data):
-        dataset = DataSet()
-        for sample in parsed_data:
-            label0_list = list(map(
-                lambda labels: labels[0], sample[1]))
-            label1_list = list(map(
-                lambda labels: labels[1], sample[1]))
-            label2_list = list(map(
-                lambda labels: labels[2], sample[1]))
-            dataset.append(Instance(tokens=sample[0],
-                                    pos=label0_list,
-                                    chucks=label1_list,
-                                    ner=label2_list))
-        return dataset
+        headers = [
+            'tokens', 'pos', 'chunks', 'ner',
+        ]
+        super(Conll2003Loader, self).__init__(headers=headers)
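
A sketch of driving the new ConllLoader directly (the file path is hypothetical): `headers` names the output fields, and `indexs` selects which whitespace-separated columns feed them; omitting `indexs` reads the first len(headers) columns in order.

```python
from fastNLP.io.dataset_loader import ConllLoader

# equivalent to Conll2003Loader(): read all four columns
loader = ConllLoader(headers=['tokens', 'pos', 'chunks', 'ner'])
ds = loader.load('data/conll2003/train.txt')  # hypothetical path

# read only the token and NER columns
ner_loader = ConllLoader(headers=['tokens', 'ner'], indexs=[0, 3])
```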
 
 class SNLIDataSetReader(DataSetLoader):

@@ -548,6 +577,7 @@ class SNLIDataSetReader(DataSetLoader):
 
 class ConllCWSReader(object):
+    """Deprecated. Use ConllLoader for all types of conll-format files."""
     def __init__(self):
         pass

@@ -700,6 +730,7 @@ def cut_long_sentence(sent, max_sample_length=200):
 
 class ZhConllPOSReader(object):
     """Reads Chinese CoNLL-format data. Returns character-level labels, expanding the original word-level labels with the BMES scheme.
+
+    Deprecated. Use ConllLoader for all types of conll-format files.
     """
     def __init__(self):
         pass
@@ -778,47 +809,78 @@ class ZhConllPOSReader(object):
         return text, pos_tags
 
 
-class ConllxDataLoader(object):
+class ConllxDataLoader(ConllLoader):
     """Returns word-level labels: words, POS tags, (syntactic) head dependencies, and (syntactic) arc labels. Completely different from ``ZhConllPOSReader``.
+
+    Deprecated. Use ConllLoader for all types of conll-format files.
     """
-    def load(self, path):
-        datalist = []
-        with open(path, 'r', encoding='utf-8') as f:
-            sample = []
-            for line in f:
-                if line.startswith('\n'):
-                    datalist.append(sample)
-                    sample = []
-                elif line.startswith('#'):
-                    continue
-                else:
-                    sample.append(line.split('\t'))
-            if len(sample) > 0:
-                datalist.append(sample)
-
-        data = [self.get_one(sample) for sample in datalist]
-        data_list = list(filter(lambda x: x is not None, data))
-
-        ds = DataSet()
-        for example in data_list:
-            ds.append(Instance(words=example[0],
-                               pos_tags=example[1],
-                               heads=example[2],
-                               labels=example[3]))
-        return ds
-
-    def get_one(self, sample):
-        sample = list(map(list, zip(*sample)))
-        if len(sample) == 0:
-            return None
-        for w in sample[7]:
-            if w == '_':
-                print('Error Sample {}'.format(sample))
-                return None
-        # return word_seq, pos_seq, head_seq, head_tag_seq
-        return sample[1], sample[3], list(map(int, sample[6])), sample[7]
+    def __init__(self):
+        headers = [
+            'words', 'pos_tags', 'heads', 'labels',
+        ]
+        indexs = [
+            1, 3, 6, 7,
+        ]
+        super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs)
+
+
+class SSTLoader(DataSetLoader):
+    """load SST data in PTB tree format
+    data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
+    """
+    def __init__(self, subtree=False, fine_grained=False):
+        self.subtree = subtree
+
+        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
+                 '3': 'positive', '4': 'very positive'}
+        if not fine_grained:
+            tag_v['0'] = tag_v['1']
+            tag_v['4'] = tag_v['3']
+        self.tag_v = tag_v
+
+    def load(self, path):
+        with open(path, 'r', encoding='utf-8') as f:
+            datas = []
+            for l in f:
+                datas.extend([(s, self.tag_v[t])
+                              for s, t in self.get_one(l, self.subtree)])
+        ds = DataSet()
+        for words, tag in datas:
+            ds.append(Instance(words=words, raw_tag=tag))
+        return ds
+
+    @staticmethod
+    def get_one(data, subtree):
+        from nltk.tree import Tree
+        tree = Tree.fromstring(data)
+        if subtree:
+            return [(t.leaves(), t.label()) for t in tree.subtrees()]
+        return [(tree.leaves(), tree.label())]
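
A sketch of using SSTLoader (the path is hypothetical): `fine_grained=False` collapses the five SST labels to three by mapping 'very negative' to 'negative' and 'very positive' to 'positive', and `subtree=True` emits one instance per constituent subtree rather than one per sentence. `get_one` relies on `nltk` being installed.

```python
loader = SSTLoader(subtree=False, fine_grained=True)  # keep all five SST-5 labels
ds = loader.load('trees/train.txt')  # hypothetical PTB-tree file, one tree per line
# each instance carries two fields: words (token list) and raw_tag (label string)
```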
+
+
+class JsonLoader(DataSetLoader):
+    """Load json-format data: each line contains one json object (a dict).
+    The keyword arguments name the json keys to load and the DataSet
+    fields they map to.
+    """
+    def __init__(self, **fields):
+        super(JsonLoader, self).__init__()
+        self.fields = {}
+        for k, v in fields.items():
+            self.fields[k] = k if v is None else v
+
+    def load(self, path):
+        with open(path, 'r', encoding='utf-8') as f:
+            datas = [json.loads(l) for l in f]
+        ds = DataSet()
+        for d in datas:
+            ins = Instance()
+            for k, v in d.items():
+                if k in self.fields:
+                    ins.add_field(self.fields[k], v)
+            ds.append(ins)
+        return ds
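
Likewise, a sketch of JsonLoader on a one-object-per-line json file (keys and path are hypothetical): each keyword argument maps a json key to a DataSet field name, and passing None keeps the original key as the field name.

```python
loader = JsonLoader(text='words', label=None)  # json key 'text' -> field 'words'; 'label' kept as-is
ds = loader.load('data/train.jsonl')  # hypothetical file of {"text": [...], "label": "..."} lines
```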
 
 def add_seg_tag(data):
     """

@@ -840,3 +902,4 @@ def add_seg_tag(data):
             new_sample.append((word[-1], 'E-' + pos))
         _processed.append(list(map(list, zip(*new_sample))))
     return _processed
+

@@ -92,9 +92,9 @@ class ENASTrainer(fastNLP.Trainer):
             try:
                 self.callback_manager.on_train_begin()
                 self._train()
-                self.callback_manager.on_train_end(self.model)
+                self.callback_manager.on_train_end()
             except (CallbackException, KeyboardInterrupt) as e:
-                self.callback_manager.on_exception(e, self.model)
+                self.callback_manager.on_exception(e)
 
             if self.dev_data is not None:
                 print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +

@@ -134,7 +134,7 @@ class ENASTrainer(fastNLP.Trainer):
                 if epoch == self.n_epochs + 1 - self.final_epochs:
                     print('Entering the final stage. (Only train the selected structure)')
                 # early stopping
-                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
+                self.callback_manager.on_epoch_begin()
 
                 # 1. Training the shared parameters omega of the child models
                 self.train_shared(pbar)

@@ -155,7 +155,7 @@ class ENASTrainer(fastNLP.Trainer):
                     pbar.write(eval_str)
 
                 # lr decay; early stopping
-                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
+                self.callback_manager.on_epoch_end()
             # =============== epochs end =================== #
             pbar.close()
         # ============ tqdm end ============== #

@@ -234,12 +234,12 @@ class ENASTrainer(fastNLP.Trainer):
                     avg_loss += loss.item()
 
                     # Is loss NaN or inf? requires_grad = False
-                    self.callback_manager.on_backward_begin(loss, self.model)
+                    self.callback_manager.on_backward_begin(loss)
                     self._grad_backward(loss)
-                    self.callback_manager.on_backward_end(self.model)
+                    self.callback_manager.on_backward_end()
 
                     self._update()
-                    self.callback_manager.on_step_end(self.optimizer)
+                    self.callback_manager.on_step_end()
 
                     if (self.step+1) % self.print_every == 0:
                         if self.use_tqdm:

@@ -0,0 +1,181 @@
+from fastNLP.modules.encoder.star_transformer import StarTransformer
+from fastNLP.core.utils import seq_lens_to_masks
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class StarTransEnc(nn.Module):
+    def __init__(self, vocab_size, emb_dim,
+                 hidden_size,
+                 num_layers,
+                 num_head,
+                 head_dim,
+                 max_len,
+                 emb_dropout,
+                 dropout):
+        super(StarTransEnc, self).__init__()
+        self.emb_fc = nn.Linear(emb_dim, hidden_size)
+        self.emb_drop = nn.Dropout(emb_dropout)
+        self.embedding = nn.Embedding(vocab_size, emb_dim)
+        self.encoder = StarTransformer(hidden_size=hidden_size,
+                                       num_layers=num_layers,
+                                       num_head=num_head,
+                                       head_dim=head_dim,
+                                       dropout=dropout,
+                                       max_len=max_len)
+
+    def forward(self, x, mask):
+        x = self.embedding(x)
+        x = self.emb_fc(self.emb_drop(x))
+        nodes, relay = self.encoder(x, mask)
+        return nodes, relay
+
+
+class Cls(nn.Module):
+    def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1):
+        super(Cls, self).__init__()
+        self.fc = nn.Sequential(
+            nn.Linear(in_dim, hid_dim),
+            nn.LeakyReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(hid_dim, num_cls),
+        )
+
+    def forward(self, x):
+        h = self.fc(x)
+        return h
+
+
+class NLICls(nn.Module):
+    def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1):
+        super(NLICls, self).__init__()
+        self.fc = nn.Sequential(
+            nn.Dropout(dropout),
+            nn.Linear(in_dim*4, hid_dim),  # 4
+            nn.LeakyReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(hid_dim, num_cls),
+        )
+
+    def forward(self, x1, x2):
+        x = torch.cat([x1, x2, torch.abs(x1-x2), x1*x2], 1)
+        h = self.fc(x)
+        return h
+
+
+class STSeqLabel(nn.Module):
+    """star-transformer model for sequence labeling
+    """
+    def __init__(self, vocab_size, emb_dim, num_cls,
+                 hidden_size=300,
+                 num_layers=4,
+                 num_head=8,
+                 head_dim=32,
+                 max_len=512,
+                 cls_hidden_size=600,
+                 emb_dropout=0.1,
+                 dropout=0.1,):
+        super(STSeqLabel, self).__init__()
+        self.enc = StarTransEnc(vocab_size=vocab_size,
+                                emb_dim=emb_dim,
+                                hidden_size=hidden_size,
+                                num_layers=num_layers,
+                                num_head=num_head,
+                                head_dim=head_dim,
+                                max_len=max_len,
+                                emb_dropout=emb_dropout,
+                                dropout=dropout)
+        self.cls = Cls(hidden_size, num_cls, cls_hidden_size)
+
+    def forward(self, word_seq, seq_lens):
+        mask = seq_lens_to_masks(seq_lens)
+        nodes, _ = self.enc(word_seq, mask)
+        output = self.cls(nodes)
+        output = output.transpose(1, 2)  # make hidden to be dim 1
+        return {'output': output}  # [bsz, n_cls, seq_len]
+
+    def predict(self, word_seq, seq_lens):
+        y = self.forward(word_seq, seq_lens)
+        _, pred = y['output'].max(1)
+        return {'output': pred, 'seq_lens': seq_lens}
+
+
+class STSeqCls(nn.Module):
+    """star-transformer model for sequence classification
+    """
+    def __init__(self, vocab_size, emb_dim, num_cls,
+                 hidden_size=300,
+                 num_layers=4,
+                 num_head=8,
+                 head_dim=32,
+                 max_len=512,
+                 cls_hidden_size=600,
+                 emb_dropout=0.1,
+                 dropout=0.1,):
+        super(STSeqCls, self).__init__()
+        self.enc = StarTransEnc(vocab_size=vocab_size,
+                                emb_dim=emb_dim,
+                                hidden_size=hidden_size,
+                                num_layers=num_layers,
+                                num_head=num_head,
+                                head_dim=head_dim,
+                                max_len=max_len,
+                                emb_dropout=emb_dropout,
+                                dropout=dropout)
+        self.cls = Cls(hidden_size, num_cls, cls_hidden_size)
+
+    def forward(self, word_seq, seq_lens):
+        mask = seq_lens_to_masks(seq_lens)
+        nodes, relay = self.enc(word_seq, mask)
+        y = 0.5 * (relay + nodes.max(1)[0])
+        output = self.cls(y)  # [bsz, n_cls]
+        return {'output': output}
+
+    def predict(self, word_seq, seq_lens):
+        y = self.forward(word_seq, seq_lens)
+        _, pred = y['output'].max(1)
+        return {'output': pred}
+
+
+class STNLICls(nn.Module):
+    """star-transformer model for NLI
+    """
+    def __init__(self, vocab_size, emb_dim, num_cls,
+                 hidden_size=300,
+                 num_layers=4,
+                 num_head=8,
+                 head_dim=32,
+                 max_len=512,
+                 cls_hidden_size=600,
+                 emb_dropout=0.1,
+                 dropout=0.1,):
+        super(STNLICls, self).__init__()
+        self.enc = StarTransEnc(vocab_size=vocab_size,
+                                emb_dim=emb_dim,
+                                hidden_size=hidden_size,
+                                num_layers=num_layers,
+                                num_head=num_head,
+                                head_dim=head_dim,
+                                max_len=max_len,
+                                emb_dropout=emb_dropout,
+                                dropout=dropout)
+        self.cls = NLICls(hidden_size, num_cls, cls_hidden_size)
+
+    def forward(self, word_seq1, word_seq2, seq_lens1, seq_lens2):
+        mask1 = seq_lens_to_masks(seq_lens1)
+        mask2 = seq_lens_to_masks(seq_lens2)
+
+        def enc(seq, mask):
+            nodes, relay = self.enc(seq, mask)
+            return 0.5 * (relay + nodes.max(1)[0])
+
+        y1 = enc(word_seq1, mask1)
+        y2 = enc(word_seq2, mask2)
+        output = self.cls(y1, y2)  # [bsz, n_cls]
+        return {'output': output}
+
+    def predict(self, word_seq1, word_seq2, seq_lens1, seq_lens2):
+        y = self.forward(word_seq1, word_seq2, seq_lens1, seq_lens2)
+        _, pred = y['output'].max(1)
+        return {'output': pred}
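
A sketch of a forward pass through the classification variant with dummy tensors, to show the expected shapes (all sizes are illustrative; this assumes the StarTransformer encoder module imported at the top of the file is available):

```python
import torch

model = STSeqCls(vocab_size=1000, emb_dim=64, num_cls=5)
words = torch.randint(1, 1000, (4, 20))        # [batch, seq_len] token ids
lens = torch.full((4,), 20, dtype=torch.long)  # true length of each sequence
logits = model(words, lens)['output']          # [4, 5] class scores
pred = model.predict(words, lens)['output']    # [4] predicted class ids
```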

@@ -0,0 +1,44 @@
+# Model Reproduction
+
+Here we reproduce the models implemented in fastNLP, aiming to match the performance reported in their papers.
+
+Reproduced models:
+- Star-Transformer
+- ...
+
+## Star-Transformer
+[reference](https://arxiv.org/abs/1902.09113)
+
+### Performance
+
+| Task | Dataset | SOTA | Our Performance |
+|------|---------|------|-----------------|
+| POS Tagging | CTB 9.0 | - | ACC 92.31 |
+| POS Tagging | CONLL 2012 | - | ACC 96.51 |
+| Named Entity Recognition | CONLL 2012 | - | F1 85.66 |
+| Text Classification | SST | - | 49.18 |
+| Natural Language Inference | SNLI | - | 83.76 |
+
+### Usage
+``` python
+# for sequence labeling (ner, pos tagging, etc.)
+from fastNLP.models.star_transformer import STSeqLabel
+model = STSeqLabel(
+    vocab_size=10000, num_cls=50,
+    emb_dim=300)
+
+# for sequence classification
+from fastNLP.models.star_transformer import STSeqCls
+model = STSeqCls(
+    vocab_size=10000, num_cls=50,
+    emb_dim=300)
+
+# for natural language inference
+from fastNLP.models.star_transformer import STNLICls
+model = STNLICls(
+    vocab_size=10000, num_cls=50,
+    emb_dim=300)
+```
+
+## ...

@@ -353,7 +353,7 @@ class TestTutorial(unittest.TestCase):
         train_data[-1], dev_data[-1], test_data[-1]
 
         # read in the vocab file
-        with open('vocab.txt') as f:
+        with open('vocab.txt', encoding='utf-8') as f:
             lines = f.readlines()
         vocabs = []
         for line in lines: