
- add Star-Transformer model

- add ConllLoader, for all kinds of conll-format files
- add JsonLoader, for json-format files
- add SSTLoader, for SST-2 & SST-5
- change the Callback interface (a usage sketch follows the changed-files list below)
- fix multi-process batch prefetching hanging when the main process exits
- add a README listing the models and their performance

Tag: v0.4.10
Author: yunfan
Commit: 70fb4a2284

10 changed files with 530 additions and 216 deletions:

1. README.md (+8, -1)
2. fastNLP/core/__init__.py (+1, -1)
3. fastNLP/core/batch.py (+15, -1)
4. fastNLP/core/callback.py (+85, -86)
5. fastNLP/core/trainer.py (+20, -14)
6. fastNLP/io/dataset_loader.py (+168, -105)
7. fastNLP/models/enas_trainer.py (+7, -7)
8. fastNLP/models/star_transformer.py (+181, -0)
9. reproduction/README.md (+44, -0)
10. test/test_tutorials.py (+1, -1)
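
The "change the Callback interface" item drops the explicit `cur_epoch`/`n_epoch`/`model`/`optimizer` arguments from the hook methods and instead exposes the same information as properties on the callback itself (see the `fastNLP/core/callback.py` diff below). A minimal sketch of a callback written against the new interface, assuming only the hook signatures and property names shown in that diff; the callback class itself is hypothetical and would presumably still be passed to the Trainer via its `callbacks` argument:

```python
from fastNLP.core.callback import Callback


class PrintEpochCallback(Callback):
    """Hypothetical callback written against the new, argument-free hooks."""

    def __init__(self):
        super(PrintEpochCallback, self).__init__()
        self.last_loss = 0.0

    def on_backward_begin(self, loss):
        # the model/optimizer are no longer passed in; they are available
        # as the self.model / self.optimizer properties when needed
        self.last_loss = loss.item()

    def on_epoch_end(self):
        # cur_epoch / n_epoch / optimizer arguments are gone; use the properties
        print("epoch {}/{} done, last loss {:.4f}".format(
            self.epoch, self.n_epochs, self.last_loss))
```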

README.md (+8, -1)

@@ -6,7 +6,7 @@
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)

FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.

A deep learning NLP model is the composition of three types of modules:
<table>
@@ -58,6 +58,13 @@ Run the following commands to install fastNLP package.
pip install fastNLP
```

+## Models
+fastNLP implements different models for various NLP tasks.
+Each model has been trained and tested carefully.
+
+Check out the models' performance, usage and source code here.
+- [Documentation](https://github.com/fastnlp/fastNLP/tree/master/reproduction)
+- [Source Code](https://github.com/fastnlp/fastNLP/tree/master/fastNLP/models)

## Project Structure

fastNLP/core/__init__.py (+1, -1)

@@ -10,4 +10,4 @@ from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary
from ..io.dataset_loader import DataSet
+from .callback import Callback

fastNLP/core/batch.py (+15, -1)

@@ -1,9 +1,16 @@
import numpy as np
import torch
+import atexit

from fastNLP.core.sampler import RandomSampler
import torch.multiprocessing as mp

+_python_is_exit = False
+def _set_python_is_exit():
+    global _python_is_exit
+    _python_is_exit = True
+atexit.register(_set_python_is_exit)
+
class Batch(object):
    """Batch is an iterable object which iterates over mini-batches.

@@ -95,12 +102,19 @@ def to_tensor(batch, dtype):


def run_fetch(batch, q):
+    global _python_is_exit
    batch.init_iter()
    # print('start fetch')
    while 1:
        res = batch.fetch_one()
        # print('fetch one')
-        q.put(res)
+        while 1:
+            try:
+                q.put(res, timeout=3)
+                break
+            except Exception as e:
+                if _python_is_exit:
+                    return
        if res is None:
            # print('fetch done, waiting processing')
            q.join()


fastNLP/core/callback.py (+85, -86)

@@ -15,13 +15,57 @@ class Callback(object):

    def __init__(self):
        super(Callback, self).__init__()
-        self.trainer = None  # reassigned inside Trainer
+        self._trainer = None  # reassigned inside Trainer
+
+    @property
+    def trainer(self):
+        return self._trainer
+
+    @property
+    def step(self):
+        """current step number, in range(1, self.n_steps+1)"""
+        return self._trainer.step
+
+    @property
+    def n_steps(self):
+        """total number of steps for training"""
+        return self._trainer.n_steps
+
+    @property
+    def batch_size(self):
+        """batch size for training"""
+        return self._trainer.batch_size
+
+    @property
+    def epoch(self):
+        """current epoch number, in range(1, self.n_epochs+1)"""
+        return self._trainer.epoch
+
+    @property
+    def n_epochs(self):
+        """total number of epochs"""
+        return self._trainer.n_epochs
+
+    @property
+    def optimizer(self):
+        """torch.optim.Optimizer for current model"""
+        return self._trainer.optimizer
+
+    @property
+    def model(self):
+        """training model"""
+        return self._trainer.model
+
+    @property
+    def pbar(self):
+        """If use_tqdm, return trainer's tqdm print bar, else return None."""
+        return self._trainer.pbar

    def on_train_begin(self):
        # before the main training loop
        pass

-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
        # at the beginning of each epoch
        pass

@@ -33,14 +77,14 @@ class Callback(object):
        # after data_forward, and before loss computation
        pass

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
        # after loss computation, and before gradient backward
        pass

-    def on_backward_end(self, model):
+    def on_backward_end(self):
        pass

-    def on_step_end(self, optimizer):
+    def on_step_end(self):
        pass

    def on_batch_end(self, *args):
@@ -50,65 +94,36 @@ class Callback(object):
    def on_valid_begin(self):
        pass

-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key):
        """
        Called after each evaluation on the dev data; eval_result is passed in.

        :param eval_result: Dict[str: Dict[str: float]], the evaluation results
        :param metric_key: str
-        :param optimizer:
        :return:
        """
        pass

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
        """
        Called at the end of every epoch.
-
-        :param cur_epoch: int, the current epoch, starting from 1
-        :param n_epoch: int, the total number of epochs
-        :param optimizer: the optimizer passed to Trainer
-        :return:
        """
        pass

-    def on_train_end(self, model):
+    def on_train_end(self):
        """
        Called when training ends.
-
-        :param model: nn.Module, the model passed to Trainer
-        :return:
        """
        pass

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
        """
        Triggered when an exception occurs during training.
        :param exception: some type of Exception, e.g. KeyboardInterrupt
-        :param model: the model passed to Trainer
-        :return:
        """
        pass


-def transfer(func):
-    """Decorator that forwards a call on CallbackManager to every registered Callback.
-
-    :param func:
-    :return:
-    """
-
-    def wrapper(manager, *arg):
-        returns = []
-        for callback in manager.callbacks:
-            for env_name, env_value in manager.env.items():
-                setattr(callback, env_name, env_value)
-            returns.append(getattr(callback, func.__name__)(*arg))
-        return returns
-
-    return wrapper
-
-
class CallbackManager(Callback):
    """A manager for all callbacks passed into Trainer.
    It collects resources inside Trainer and raise callbacks.
@@ -119,7 +134,7 @@ class CallbackManager(Callback):
        """

        :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
-        :param Callback callbacks:
+        :param List[Callback] callbacks:
        """
        super(CallbackManager, self).__init__()
        # set attribute of trainer environment
@@ -136,56 +151,43 @@ class CallbackManager(Callback):
        else:
            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")

-    @transfer
    def on_train_begin(self):
        pass

-    @transfer
-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
        pass

-    @transfer
    def on_batch_begin(self, batch_x, batch_y, indices):
        pass

-    @transfer
    def on_loss_begin(self, batch_y, predict_y):
        pass

-    @transfer
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
        pass

-    @transfer
-    def on_backward_end(self, model):
+    def on_backward_end(self):
        pass

-    @transfer
-    def on_step_end(self, optimizer):
+    def on_step_end(self):
        pass

-    @transfer
    def on_batch_end(self):
        pass

-    @transfer
    def on_valid_begin(self):
        pass

-    @transfer
-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key):
        pass

-    @transfer
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
        pass

-    @transfer
-    def on_train_end(self, model):
+    def on_train_end(self):
        pass

-    @transfer
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
        pass


@@ -193,15 +195,15 @@ class DummyCallback(Callback):
    def on_train_begin(self, *arg):
        print(arg)

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
-        print(cur_epoch, n_epoch, optimizer)
+    def on_epoch_end(self):
+        print(self.epoch, self.n_epochs)


class EchoCallback(Callback):
    def on_train_begin(self):
        print("before_train")

-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
        print("before_epoch")

    def on_batch_begin(self, batch_x, batch_y, indices):
@@ -210,16 +212,16 @@ class EchoCallback(Callback):
    def on_loss_begin(self, batch_y, predict_y):
        print("before_loss")

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
        print("before_backward")

    def on_batch_end(self):
        print("after_batch")

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
        print("after_epoch")

-    def on_train_end(self, model):
+    def on_train_end(self):
        print("after_train")


@@ -247,8 +249,8 @@ class GradientClipCallback(Callback):
        self.parameters = parameters
        self.clip_value = clip_value

-    def on_backward_end(self, model):
-        self.clip_fun(model.parameters(), self.clip_value)
+    def on_backward_end(self):
+        self.clip_fun(self.model.parameters(), self.clip_value)


class CallbackException(BaseException):
@@ -268,13 +270,10 @@ class EarlyStopCallback(Callback):
        :param int patience: number of epochs to wait before stopping
        """
        super(EarlyStopCallback, self).__init__()
-        self.trainer = None  # override by CallbackManager
        self.patience = patience
        self.wait = 0
-        self.epoch = 0

-    def on_valid_end(self, eval_result, metric_key, optimizer):
-        self.epoch += 1
+    def on_valid_end(self, eval_result, metric_key):
        if not self.trainer._better_eval_result(eval_result):
            # current result is getting worse
            if self.wait == self.patience:
@@ -284,7 +283,7 @@ class EarlyStopCallback(Callback):
        else:
            self.wait = 0

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
        if isinstance(exception, EarlyStopError):
            print("Early Stopping triggered in epoch {}!".format(self.epoch))
        else:
@@ -304,9 +303,9 @@ class LRScheduler(Callback):
        else:
            raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")

-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
        self.scheduler.step()
-        print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])
+        print("scheduler step ", "lr=", self.optimizer.param_groups[0]["lr"])


class ControlC(Callback):
@@ -320,7 +319,7 @@ class ControlC(Callback):
            raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.")
        self.quit_all = quit_all

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
        if isinstance(exception, KeyboardInterrupt):
            if self.quit_all is True:
                import sys
@@ -366,15 +365,15 @@ class LRFinder(Callback):
        self.find = None
        self.loader = ModelLoader()

-    def on_epoch_begin(self, cur_epoch, total_epoch):
-        if cur_epoch == 1:
+    def on_epoch_begin(self):
+        if self.epoch == 1:  # first epoch
            self.opt = self.trainer.optimizer  # pytorch optimizer
            self.opt.param_groups[0]["lr"] = self.start_lr
            # save model
            ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
            self.find = True

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
        if self.find:
            if torch.isnan(loss) or self.stop is True:
                self.stop = True
@@ -395,8 +394,8 @@ class LRFinder(Callback):
            self.opt.param_groups[0]["lr"] = lr
        # self.loader.load_pytorch(self.trainer.model, "tmp")

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
-        if cur_epoch == 1:
+    def on_epoch_end(self):
+        if self.epoch == 1:  # first epoch
            self.opt.param_groups[0]["lr"] = self.best_lr
            self.find = False
            # reset model
@@ -440,7 +439,7 @@ class TensorboardCallback(Callback):
            # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
            self.graph_added = True

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
        if "loss" in self.options:
            self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

@@ -452,18 +451,18 @@ class TensorboardCallback(Callback):
                    self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
                                                    global_step=self.trainer.step)

-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key):
        if "metric" in self.options:
            for name, metric in eval_result.items():
                for metric_key, metric_val in metric.items():
                    self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                    global_step=self.trainer.step)

-    def on_train_end(self, model):
+    def on_train_end(self):
        self._summary_writer.close()
        del self._summary_writer

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
        if hasattr(self, "_summary_writer"):
            self._summary_writer.close()
            del self._summary_writer
@@ -471,5 +470,5 @@ class TensorboardCallback(Callback):


if __name__ == "__main__":
    manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.on_train_begin(10, 11, 12)
+    manager.on_train_begin()
    # print(manager.after_epoch())

fastNLP/core/trainer.py (+20, -14)

@@ -122,6 +122,8 @@ class Trainer(object):
        self.sampler = sampler
        self.prefetch = prefetch
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
+        self.n_steps = (len(self.train_data) // self.batch_size + int(
+            len(self.train_data) % self.batch_size != 0)) * self.n_epochs

        if isinstance(optimizer, torch.optim.Optimizer):
            self.optimizer = optimizer
@@ -129,6 +131,7 @@ class Trainer(object):
            self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

        self.use_tqdm = use_tqdm
+        self.pbar = None
        self.print_every = abs(self.print_every)

        if self.dev_data is not None:
@@ -198,9 +201,9 @@ class Trainer(object):
        try:
            self.callback_manager.on_train_begin()
            self._train()
-            self.callback_manager.on_train_end(self.model)
+            self.callback_manager.on_train_end()
        except (CallbackException, KeyboardInterrupt) as e:
-            self.callback_manager.on_exception(e, self.model)
+            self.callback_manager.on_exception(e)

        if self.dev_data is not None:
            print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -227,18 +230,21 @@ class Trainer(object):
        else:
            inner_tqdm = tqdm
        self.step = 0
+        self.epoch = 0
        start = time.time()
-        total_steps = (len(self.train_data) // self.batch_size + int(
-            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
-        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+            self.pbar = pbar if isinstance(pbar, tqdm) else None
            avg_loss = 0
            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                  prefetch=self.prefetch)
            for epoch in range(1, self.n_epochs+1):
+                self.epoch = epoch
                pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                # early stopping
-                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
+                self.callback_manager.on_epoch_begin()
                for batch_x, batch_y in data_iterator:
+                    self.step += 1
                    _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                    indices = data_iterator.get_batch_indices()
                    # negative sampling; replace unknown; re-weight batch_y
@@ -251,14 +257,14 @@ class Trainer(object):
                    avg_loss += loss.item()

                    # Is loss NaN or inf? requires_grad = False
-                    self.callback_manager.on_backward_begin(loss, self.model)
+                    self.callback_manager.on_backward_begin(loss)
                    self._grad_backward(loss)
-                    self.callback_manager.on_backward_end(self.model)
+                    self.callback_manager.on_backward_end()

                    self._update()
-                    self.callback_manager.on_step_end(self.optimizer)
+                    self.callback_manager.on_step_end()

-                    if (self.step+1) % self.print_every == 0:
+                    if self.step % self.print_every == 0:
                        if self.use_tqdm:
                            print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
                            pbar.update(self.print_every)
@@ -269,7 +275,6 @@ class Trainer(object):
                                epoch, self.step, avg_loss, diff)
                            pbar.set_postfix_str(print_output)
                        avg_loss = 0
-                    self.step += 1
                    self.callback_manager.on_batch_end()

                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
@@ -277,16 +282,17 @@ class Trainer(object):
                            and self.dev_data is not None:
                        eval_res = self._do_validation(epoch=epoch, step=self.step)
                        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                                    total_steps) + \
+                                                                                    self.n_steps) + \
                                   self.tester._format_eval_results(eval_res)
                        pbar.write(eval_str)

                # ================= mini-batch end ==================== #

                # lr decay; early stopping
-                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
+                self.callback_manager.on_epoch_end()
            # =============== epochs end =================== #
            pbar.close()
+            self.pbar = None
        # ============ tqdm end ============== #

    def _do_validation(self, epoch, step):
@@ -303,7 +309,7 @@ class Trainer(object):
            self.best_dev_epoch = epoch
            self.best_dev_step = step
        # get validation results; adjust optimizer
-        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
+        self.callback_manager.on_valid_end(res, self.metric_key)
        return res

    def _mode(self, model, is_test=False):

fastNLP/io/dataset_loader.py (+168, -105)

@@ -1,4 +1,5 @@
import os
+import json

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
@@ -64,6 +65,53 @@ def convert_seq2seq_dataset(data):
    return dataset


+def download_from_url(url, path):
+    from tqdm import tqdm
+    import requests
+
+    """Download file"""
+    r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
+    chunk_size = 16 * 1024
+    total_size = int(r.headers.get('Content-length', 0))
+    with open(path, "wb") as file ,\
+        tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
+        for chunk in r.iter_content(chunk_size):
+            if chunk:
+                file.write(chunk)
+                t.update(len(chunk))
+    return
+
+def uncompress(src, dst):
+    import zipfile, gzip, tarfile, os
+
+    def unzip(src, dst):
+        with zipfile.ZipFile(src, 'r') as f:
+            f.extractall(dst)
+
+    def ungz(src, dst):
+        with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
+            length = 16 * 1024  # 16KB
+            buf = f.read(length)
+            while buf:
+                uf.write(buf)
+                buf = f.read(length)
+
+    def untar(src, dst):
+        with tarfile.open(src, 'r:gz') as f:
+            f.extractall(dst)
+
+    fn, ext = os.path.splitext(src)
+    _, ext_2 = os.path.splitext(fn)
+    if ext == '.zip':
+        unzip(src, dst)
+    elif ext == '.gz' and ext_2 != '.tar':
+        ungz(src, dst)
+    elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
+        untar(src, dst)
+    else:
+        raise ValueError('unsupported file {}'.format(src))
+
+
class DataSetLoader:
    """Interface for all DataSetLoaders.

@@ -290,41 +338,6 @@ class DummyClassificationReader(DataSetLoader):
        return convert_seq2tag_dataset(data)


-class ConllLoader(DataSetLoader):
-    """loader for conll format files"""
-
-    def __init__(self):
-        super(ConllLoader, self).__init__()
-
-    def load(self, data_path):
-        with open(data_path, "r", encoding="utf-8") as f:
-            lines = f.readlines()
-        data = self.parse(lines)
-        return self.convert(data)
-
-    @staticmethod
-    def parse(lines):
-        """
-        :param list lines: a list containing all lines in a conll file.
-        :return: a 3D list
-        """
-        sentences = list()
-        tokens = list()
-        for line in lines:
-            if line[0] == "#":
-                # skip the comments
-                continue
-            if line == "\n":
-                sentences.append(tokens)
-                tokens = []
-                continue
-            tokens.append(line.split())
-        return sentences
-
-    def convert(self, data):
-        pass
-
-
class DummyLMReader(DataSetLoader):
    """A Dummy Language Model Dataset Reader
    """
@@ -434,51 +447,67 @@ class PeopleDailyCorpusLoader(DataSetLoader):
        return data_set


-class Conll2003Loader(DataSetLoader):
+class ConllLoader:
+    def __init__(self, headers, indexs=None):
+        self.headers = headers
+        if indexs is None:
+            self.indexs = list(range(len(self.headers)))
+        else:
+            if len(indexs) != len(headers):
+                raise ValueError
+            self.indexs = indexs
+
+    def load(self, path):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            start = next(f)
+            if '-DOCSTART-' not in start:
+                sample.append(start.split())
+            for line in f:
+                if line.startswith('\n'):
+                    if len(sample):
+                        datalist.append(sample)
+                    sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split())
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        data = [self.get_one(sample) for sample in datalist]
+        data = filter(lambda x: x is not None, data)
+
+        ds = DataSet()
+        for sample in data:
+            ins = Instance()
+            for name, idx in zip(self.headers, self.indexs):
+                ins.add_field(field_name=name, field=sample[idx])
+            ds.append(ins)
+        return ds
+
+    def get_one(self, sample):
+        sample = list(map(list, zip(*sample)))
+        for field in sample:
+            if len(field) <= 0:
+                return None
+        return sample
+
+
+class Conll2003Loader(ConllLoader):
    """Loader for conll2003 dataset
    More information about the given dataset cound be found on
    https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
+
+    Deprecated. Use ConllLoader for all types of conll-format files.
    """
    def __init__(self):
-        super(Conll2003Loader, self).__init__()
-
-    def load(self, dataset_path):
-        with open(dataset_path, "r", encoding="utf-8") as f:
-            lines = f.readlines()
-        parsed_data = []
-        sentence = []
-        tokens = []
-        for line in lines:
-            if '-DOCSTART- -X- -X- O' in line or line == '\n':
-                if sentence != []:
-                    parsed_data.append((sentence, tokens))
-                sentence = []
-                tokens = []
-                continue
-
-            temp = line.strip().split(" ")
-            sentence.append(temp[0])
-            tokens.append(temp[1:4])
-
-        return self.convert(parsed_data)
-
-    def convert(self, parsed_data):
-        dataset = DataSet()
-        for sample in parsed_data:
-            label0_list = list(map(
-                lambda labels: labels[0], sample[1]))
-            label1_list = list(map(
-                lambda labels: labels[1], sample[1]))
-            label2_list = list(map(
-                lambda labels: labels[2], sample[1]))
-            dataset.append(Instance(tokens=sample[0],
-                                    pos=label0_list,
-                                    chucks=label1_list,
-                                    ner=label2_list))
-
-        return dataset
+        headers = [
+            'tokens', 'pos', 'chunks', 'ner',
+        ]
+        super(Conll2003Loader, self).__init__(headers=headers)


class SNLIDataSetReader(DataSetLoader):
@@ -548,6 +577,7 @@ class SNLIDataSetReader(DataSetLoader):


class ConllCWSReader(object):
+    """Deprecated. Use ConllLoader for all types of conll-format files."""
    def __init__(self):
        pass

@@ -700,6 +730,7 @@ def cut_long_sentence(sent, max_sample_length=200):
class ZhConllPOSReader(object):
    """Reads Chinese Conll-format data and returns character-level tags, expanding the original word-level tags with BMES labels.

+    Deprecated. Use ConllLoader for all types of conll-format files.
    """
    def __init__(self):
        pass
@@ -778,47 +809,78 @@ class ZhConllPOSReader(object):
        return text, pos_tags


-class ConllxDataLoader(object):
+class ConllxDataLoader(ConllLoader):
    """Returns word-level annotations: words, POS tags, (syntactic) head indices and (syntactic) dependency labels. Completely different from ``ZhConllPOSReader``.

+    Deprecated. Use ConllLoader for all types of conll-format files.
    """
+    def __init__(self):
+        headers = [
+            'words', 'pos_tags', 'heads', 'labels',
+        ]
+        indexs = [
+            1, 3, 6, 7,
+        ]
+        super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs)
+
+
+class SSTLoader(DataSetLoader):
+    """load SST data in PTB tree format
+    data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
+    """
+    def __init__(self, subtree=False, fine_grained=False):
+        self.subtree = subtree
+
+        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
+                 '3': 'positive', '4': 'very positive'}
+        if not fine_grained:
+            tag_v['0'] = tag_v['1']
+            tag_v['4'] = tag_v['3']
+        self.tag_v = tag_v
+
    def load(self, path):
-        datalist = []
        with open(path, 'r', encoding='utf-8') as f:
-            sample = []
-            for line in f:
-                if line.startswith('\n'):
-                    datalist.append(sample)
-                    sample = []
-                elif line.startswith('#'):
-                    continue
-                else:
-                    sample.append(line.split('\t'))
-            if len(sample) > 0:
-                datalist.append(sample)
+            datas = []
+            for l in f:
+                datas.extend([(s, self.tag_v[t])
+                              for s, t in self.get_one(l, self.subtree)])
+        ds = DataSet()
+        for words, tag in datas:
+            ds.append(Instance(words=words, raw_tag=tag))
+        return ds

-        data = [self.get_one(sample) for sample in datalist]
-        data_list = list(filter(lambda x: x is not None, data))
+    @staticmethod
+    def get_one(data, subtree):
+        from nltk.tree import Tree
+        tree = Tree.fromstring(data)
+        if subtree:
+            return [(t.leaves(), t.label()) for t in tree.subtrees()]
+        return [(tree.leaves(), tree.label())]
+
+
+class JsonLoader(DataSetLoader):
+    """Load json-format data,
+    every line contains a json obj, like a dict
+    fields is the dict key that need to be load
+    """
+    def __init__(self, **fields):
+        super(JsonLoader, self).__init__()
+        self.fields = {}
+        for k, v in fields.items():
+            self.fields[k] = k if v is None else v
+
+    def load(self, path):
+        with open(path, 'r', encoding='utf-8') as f:
+            datas = [json.loads(l) for l in f]
        ds = DataSet()
-        for example in data_list:
-            ds.append(Instance(words=example[0],
-                               pos_tags=example[1],
-                               heads=example[2],
-                               labels=example[3]))
+        for d in datas:
+            ins = Instance()
+            for k, v in d.items():
+                if k in self.fields:
+                    ins.add_field(self.fields[k], v)
+            ds.append(ins)
        return ds

-    def get_one(self, sample):
-        sample = list(map(list, zip(*sample)))
-        if len(sample) == 0:
-            return None
-        for w in sample[7]:
-            if w == '_':
-                print('Error Sample {}'.format(sample))
-                return None
-        # return word_seq, pos_seq, head_seq, head_tag_seq
-        return sample[1], sample[3], list(map(int, sample[6])), sample[7]


def add_seg_tag(data):
    """
@@ -840,3 +902,4 @@ def add_seg_tag(data):
            new_sample.append((word[-1], 'E-' + pos))
        _processed.append(list(map(list, zip(*new_sample))))
    return _processed
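
The dataset_loader changes above replace the old `ConllLoader`/`Conll2003Loader` readers and add `SSTLoader` and `JsonLoader`. A hedged usage sketch based only on the constructors and `load()` methods shown in this diff; the file paths and column names are made-up placeholders:

```python
from fastNLP.io.dataset_loader import ConllLoader, SSTLoader, JsonLoader

# generic conll-format file: one name per column, optional column indexes
conll_ds = ConllLoader(headers=['words', 'pos'], indexs=[0, 1]).load('data/sample.conll')

# SST in PTB tree format; fine_grained=False collapses the 5 labels to 3
sst_ds = SSTLoader(subtree=False, fine_grained=False).load('data/sst_train.txt')

# one json object per line; keyword arguments select which keys to keep
json_ds = JsonLoader(words=None, label=None).load('data/sample.jsonl')
```

Each call returns a fastNLP `DataSet` whose field names match the headers (for conll), `words`/`raw_tag` (for SST), or the selected json keys.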


fastNLP/models/enas_trainer.py (+7, -7)

@@ -92,9 +92,9 @@ class ENASTrainer(fastNLP.Trainer):
        try:
            self.callback_manager.on_train_begin()
            self._train()
-            self.callback_manager.on_train_end(self.model)
+            self.callback_manager.on_train_end()
        except (CallbackException, KeyboardInterrupt) as e:
-            self.callback_manager.on_exception(e, self.model)
+            self.callback_manager.on_exception(e)

        if self.dev_data is not None:
            print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -134,7 +134,7 @@ class ENASTrainer(fastNLP.Trainer):
                if epoch == self.n_epochs + 1 - self.final_epochs:
                    print('Entering the final stage. (Only train the selected structure)')
                # early stopping
-                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
+                self.callback_manager.on_epoch_begin()

                # 1. Training the shared parameters omega of the child models
                self.train_shared(pbar)
@@ -155,7 +155,7 @@ class ENASTrainer(fastNLP.Trainer):
                    pbar.write(eval_str)

                # lr decay; early stopping
-                self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
+                self.callback_manager.on_epoch_end()
            # =============== epochs end =================== #
            pbar.close()
        # ============ tqdm end ============== #
@@ -234,12 +234,12 @@ class ENASTrainer(fastNLP.Trainer):
                avg_loss += loss.item()

                # Is loss NaN or inf? requires_grad = False
-                self.callback_manager.on_backward_begin(loss, self.model)
+                self.callback_manager.on_backward_begin(loss)
                self._grad_backward(loss)
-                self.callback_manager.on_backward_end(self.model)
+                self.callback_manager.on_backward_end()

                self._update()
-                self.callback_manager.on_step_end(self.optimizer)
+                self.callback_manager.on_step_end()

                if (self.step+1) % self.print_every == 0:
                    if self.use_tqdm:

fastNLP/models/star_transformer.py (+181, -0)

@@ -0,0 +1,181 @@
from fastNLP.modules.encoder.star_transformer import StarTransformer
from fastNLP.core.utils import seq_lens_to_masks

import torch
from torch import nn
import torch.nn.functional as F


class StarTransEnc(nn.Module):
    def __init__(self, vocab_size, emb_dim,
                 hidden_size,
                 num_layers,
                 num_head,
                 head_dim,
                 max_len,
                 emb_dropout,
                 dropout):
        super(StarTransEnc, self).__init__()
        self.emb_fc = nn.Linear(emb_dim, hidden_size)
        self.emb_drop = nn.Dropout(emb_dropout)
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = StarTransformer(hidden_size=hidden_size,
                                       num_layers=num_layers,
                                       num_head=num_head,
                                       head_dim=head_dim,
                                       dropout=dropout,
                                       max_len=max_len)

    def forward(self, x, mask):
        x = self.embedding(x)
        x = self.emb_fc(self.emb_drop(x))
        nodes, relay = self.encoder(x, mask)
        return nodes, relay


class Cls(nn.Module):
    def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1):
        super(Cls, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x):
        h = self.fc(x)
        return h


class NLICls(nn.Module):
    def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1):
        super(NLICls, self).__init__()
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim*4, hid_dim),  # 4
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(hid_dim, num_cls),
        )

    def forward(self, x1, x2):
        x = torch.cat([x1, x2, torch.abs(x1-x2), x1*x2], 1)
        h = self.fc(x)
        return h


class STSeqLabel(nn.Module):
    """star-transformer model for sequence labeling
    """
    def __init__(self, vocab_size, emb_dim, num_cls,
                 hidden_size=300,
                 num_layers=4,
                 num_head=8,
                 head_dim=32,
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
                 dropout=0.1,):
        super(STSeqLabel, self).__init__()
        self.enc = StarTransEnc(vocab_size=vocab_size,
                                emb_dim=emb_dim,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                num_head=num_head,
                                head_dim=head_dim,
                                max_len=max_len,
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = Cls(hidden_size, num_cls, cls_hidden_size)

    def forward(self, word_seq, seq_lens):
        mask = seq_lens_to_masks(seq_lens)
        nodes, _ = self.enc(word_seq, mask)
        output = self.cls(nodes)
        output = output.transpose(1, 2)  # make hidden to be dim 1
        return {'output': output}  # [bsz, n_cls, seq_len]

    def predict(self, word_seq, seq_lens):
        y = self.forward(word_seq, seq_lens)
        _, pred = y['output'].max(1)
        return {'output': pred, 'seq_lens': seq_lens}


class STSeqCls(nn.Module):
    """star-transformer model for sequence classification
    """

    def __init__(self, vocab_size, emb_dim, num_cls,
                 hidden_size=300,
                 num_layers=4,
                 num_head=8,
                 head_dim=32,
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
                 dropout=0.1,):
        super(STSeqCls, self).__init__()
        self.enc = StarTransEnc(vocab_size=vocab_size,
                                emb_dim=emb_dim,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                num_head=num_head,
                                head_dim=head_dim,
                                max_len=max_len,
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = Cls(hidden_size, num_cls, cls_hidden_size)

    def forward(self, word_seq, seq_lens):
        mask = seq_lens_to_masks(seq_lens)
        nodes, relay = self.enc(word_seq, mask)
        y = 0.5 * (relay + nodes.max(1)[0])
        output = self.cls(y)  # [bsz, n_cls]
        return {'output': output}

    def predict(self, word_seq, seq_lens):
        y = self.forward(word_seq, seq_lens)
        _, pred = y['output'].max(1)
        return {'output': pred}


class STNLICls(nn.Module):
    """star-transformer model for NLI
    """

    def __init__(self, vocab_size, emb_dim, num_cls,
                 hidden_size=300,
                 num_layers=4,
                 num_head=8,
                 head_dim=32,
                 max_len=512,
                 cls_hidden_size=600,
                 emb_dropout=0.1,
                 dropout=0.1,):
        super(STNLICls, self).__init__()
        self.enc = StarTransEnc(vocab_size=vocab_size,
                                emb_dim=emb_dim,
                                hidden_size=hidden_size,
                                num_layers=num_layers,
                                num_head=num_head,
                                head_dim=head_dim,
                                max_len=max_len,
                                emb_dropout=emb_dropout,
                                dropout=dropout)
        self.cls = NLICls(hidden_size, num_cls, cls_hidden_size)

    def forward(self, word_seq1, word_seq2, seq_lens1, seq_lens2):
        mask1 = seq_lens_to_masks(seq_lens1)
        mask2 = seq_lens_to_masks(seq_lens2)

        def enc(seq, mask):
            nodes, relay = self.enc(seq, mask)
            return 0.5 * (relay + nodes.max(1)[0])

        y1 = enc(word_seq1, mask1)
        y2 = enc(word_seq2, mask2)
        output = self.cls(y1, y2)  # [bsz, n_cls]
        return {'output': output}

    def predict(self, word_seq1, word_seq2, seq_lens1, seq_lens2):
        y = self.forward(word_seq1, word_seq2, seq_lens1, seq_lens2)
        _, pred = y['output'].max(1)
        return {'output': pred}
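
All three new models follow the same convention: `forward(word_seq, seq_lens)` returns a dict of logits under the key `'output'`, and `predict(...)` returns the argmax indices. A small sketch of calling the classification variant, with made-up sizes and random input, assuming `word_seq` is a padded LongTensor of token ids and `seq_lens` holds the true lengths:

```python
import torch
from fastNLP.models.star_transformer import STSeqCls

model = STSeqCls(vocab_size=10000, emb_dim=300, num_cls=5)

word_seq = torch.randint(0, 10000, (2, 7))  # [batch, max_len] token ids
seq_lens = torch.tensor([7, 5])             # true length of each sequence
logits = model(word_seq, seq_lens)['output']         # [batch, num_cls]
pred = model.predict(word_seq, seq_lens)['output']    # [batch] class indices
```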

reproduction/README.md (+44, -0)

@@ -0,0 +1,44 @@
# Model Reproduction
This directory reproduces the models implemented in fastNLP, aiming to match the performance reported in the corresponding papers.

Reproduced models:
- Star-Transformer
- ...


## Star-Transformer
[reference](https://arxiv.org/abs/1902.09113)
### Performance
| Task | Dataset | SOTA | Our Result |
|------|------|------|------|
| POS Tagging | CTB 9.0 | - | ACC 92.31 |
| POS Tagging | CONLL 2012 | - | ACC 96.51 |
| Named Entity Recognition | CONLL 2012 | - | F1 85.66 |
| Text Classification | SST | - | 49.18 |
| Natural Language Inference | SNLI | - | 83.76 |

### Usage
``` python
# for sequence labeling (ner, pos tagging, etc.)
from fastNLP.models.star_transformer import STSeqLabel
model = STSeqLabel(
    vocab_size=10000, num_cls=50,
    emb_dim=300)


# for sequence classification
from fastNLP.models.star_transformer import STSeqCls
model = STSeqCls(
    vocab_size=10000, num_cls=50,
    emb_dim=300)


# for natural language inference
from fastNLP.models.star_transformer import STNLICls
model = STNLICls(
    vocab_size=10000, num_cls=50,
    emb_dim=300)

```

## ...

test/test_tutorials.py (+1, -1)

@@ -353,7 +353,7 @@ class TestTutorial(unittest.TestCase):
        train_data[-1], dev_data[-1], test_data[-1]

        # read in the vocab file
-        with open('vocab.txt') as f:
+        with open('vocab.txt', encoding='utf-8') as f:
            lines = f.readlines()
        vocabs = []
        for line in lines:

