@@ -6,7 +6,7 @@
 ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
 [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)

-FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.
+FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.

 A deep learning NLP model is the composition of three types of modules:
 <table>
@@ -58,6 +58,13 @@ Run the following commands to install fastNLP package.
 pip install fastNLP
 ```

+## Models
+fastNLP implements various models for different NLP tasks.
+Each model has been trained and tested carefully.
+Check out the models' performance, usage and source code here.
+
+- [Documentation](reproduction/)
+- [Source Code](fastNLP/models/)
+
 ## Project Structure
@@ -10,4 +10,4 @@ from .tester import Tester
 from .trainer import Trainer
 from .vocabulary import Vocabulary
 from ..io.dataset_loader import DataSet
+from .callback import Callback
@@ -1,9 +1,16 @@
 import numpy as np
 import torch
+import atexit

 from fastNLP.core.sampler import RandomSampler
 import torch.multiprocessing as mp

+_python_is_exit = False
+
+
+def _set_python_is_exit():
+    global _python_is_exit
+    _python_is_exit = True
+
+
+atexit.register(_set_python_is_exit)
+

 class Batch(object):
     """Batch is an iterable object which iterates over mini-batches.
@@ -14,15 +21,17 @@ class Batch(object):

     :param DataSet dataset: a DataSet object
     :param int batch_size: the size of the batch
-    :param Sampler sampler: a Sampler object
+    :param Sampler sampler: a Sampler object. If None, use fastNLP.sampler.RandomSampler
     :param bool as_numpy: If True, return Numpy arrays. Otherwise, return torch tensors.
     :param bool prefetch: If True, use multiprocessing to fetch the next batch when training.
     :param str or torch.device device: the batch's device; if as_numpy is True, device is ignored.
     """

-    def __init__(self, dataset, batch_size, sampler=RandomSampler(), as_numpy=False, prefetch=False):
+    def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
         self.dataset = dataset
         self.batch_size = batch_size
+        if sampler is None:
+            sampler = RandomSampler()
         self.sampler = sampler
         self.as_numpy = as_numpy
         self.idx_list = None
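Moving the default from `sampler=RandomSampler()` to `sampler=None` avoids Python's mutable-default-argument trap: a default instance is built once, when the `def` executes, and is then shared by every call that omits the argument. A standalone sketch of the difference (illustrative names, not fastNLP code):

```python
class Sampler:
    def __init__(self):
        self.state = []

def make_batch(sampler=Sampler()):  # default object created once, at def time
    sampler.state.append("seen")
    return sampler

a = make_batch()
b = make_batch()
assert a is b and a.state == ["seen", "seen"]  # both calls share one sampler

def make_batch_safe(sampler=None):  # the pattern this hunk adopts
    if sampler is None:
        sampler = Sampler()         # fresh instance per call
    return sampler

assert make_batch_safe() is not make_batch_safe()
```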
@@ -95,12 +104,19 @@ def to_tensor(batch, dtype):

 def run_fetch(batch, q):
+    global _python_is_exit
     batch.init_iter()
     # print('start fetch')
     while 1:
         res = batch.fetch_one()
         # print('fetch one')
-        q.put(res)
+        while 1:
+            try:
+                q.put(res, timeout=3)
+                break
+            except Exception as e:
+                if _python_is_exit:
+                    return
         if res is None:
             # print('fetch done, waiting processing')
             q.join()
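The `timeout=3` retry loop exists because a plain blocking `q.put(res)` can hang forever when the queue is full and the consumer has already exited; waking up every few seconds lets the worker check the `atexit` flag and return. A standalone sketch of the same pattern (illustrative, not fastNLP's actual classes):

```python
import atexit
import queue  # for the Full exception raised on put() timeout

_python_is_exit = False

def _set_python_is_exit():
    global _python_is_exit
    _python_is_exit = True

atexit.register(_set_python_is_exit)

def producer(q, items):
    """Feed items into a multiprocessing.Queue, but never block past shutdown."""
    for item in items:
        while True:
            try:
                q.put(item, timeout=3)  # wake up periodically instead of blocking forever
                break
            except queue.Full:
                if _python_is_exit:     # interpreter is shutting down; stop quietly
                    return
```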
@@ -15,45 +15,57 @@ class Callback(object):

     def __init__(self):
         super(Callback, self).__init__()
-        self.trainer = None  # reassigned inside Trainer
-        # read-only attributes for callbacks
-        self._n_epochs = None
-        self._n_steps = None
-        self._batch_size = None
-        self._model = None
-        self._pbar = None
-        self._optimizer = None
+        self._trainer = None  # reassigned inside Trainer
+
+    @property
+    def trainer(self):
+        return self._trainer

     @property
-    def n_epochs(self):
-        return self._n_epochs
+    def step(self):
+        """current step number, in range(1, self.n_steps+1)"""
+        return self._trainer.step

     @property
     def n_steps(self):
-        return self._n_steps
+        """total number of steps for training"""
+        return self._trainer.n_steps

     @property
     def batch_size(self):
-        return self._batch_size
+        """batch size for training"""
+        return self._trainer.batch_size

     @property
-    def model(self):
-        return self._model
+    def epoch(self):
+        """current epoch number, in range(1, self.n_epochs+1)"""
+        return self._trainer.epoch

     @property
-    def pbar(self):
-        return self._pbar
+    def n_epochs(self):
+        """total number of epochs"""
+        return self._trainer.n_epochs

     @property
     def optimizer(self):
-        return self._optimizer
+        """torch.optim.Optimizer for the current model"""
+        return self._trainer.optimizer
+
+    @property
+    def model(self):
+        """training model"""
+        return self._trainer.model
+
+    @property
+    def pbar(self):
+        """If use_tqdm, return trainer's tqdm print bar, else return None."""
+        return self._trainer.pbar

     def on_train_begin(self):
         # before the main training loop
         pass

-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         # at the beginning of each epoch
         pass
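Under the new interface, a callback pulls whatever it needs from the read-only properties instead of hook arguments. A hypothetical subclass (not part of this diff) illustrating the pattern:

```python
class PrintProgressCallback(Callback):
    """Report progress using the properties that now delegate to the trainer."""

    def on_epoch_begin(self):
        # the cur_epoch/total_epoch arguments are gone; read them from the trainer
        if self.pbar is not None:
            self.pbar.write("Epoch {}/{}".format(self.epoch, self.n_epochs))

    def on_backward_begin(self, loss):
        # model and optimizer are properties now, not hook parameters
        lr = self.optimizer.param_groups[0]["lr"]
        print("step {}/{} loss={:.4f} lr={}".format(self.step, self.n_steps, loss.item(), lr))
```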
@@ -65,14 +77,14 @@ class Callback(object):
         # after data_forward, and before loss computation
         pass

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         # after loss computation, and before gradient backward
         pass

-    def on_backward_end(self, model):
+    def on_backward_end(self):
         pass

-    def on_step_end(self, optimizer):
+    def on_step_end(self):
         pass

     def on_batch_end(self, *args):
@@ -82,50 +94,40 @@ class Callback(object):
     def on_valid_begin(self):
         pass

-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         """
         Called after each evaluation on the validation set; eval_result is passed in.

         :param eval_result: Dict[str: Dict[str: float]], the evaluation results
         :param metric_key: str
-        :param optimizer:
+        :param optimizer: optimizer passed to trainer
+        :param is_better_eval: bool, whether the current dev result is better than the previous best
         :return:
         """
         pass

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
         """
         Called at the end of each epoch.
-        :param cur_epoch: int, the current epoch, starting from 1
-        :param n_epoch: int, the total number of epochs
-        :param optimizer: the optimizer passed to the Trainer
-        :return:
         """
         pass

-    def on_train_end(self, model):
+    def on_train_end(self):
         """
         Called when training ends.
-        :param model: nn.Module, the model passed to the Trainer
-        :return:
         """
         pass

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         """
         Called when an exception is raised during training.

         :param exception: an Exception of some type, such as KeyboardInterrupt
-        :param model: the model passed to the Trainer
-        :return:
         """
         pass


 def transfer(func):
     """A decorator that forwards a call on the CallbackManager to each Callback member.

     :param func:
     :return:
     """
@@ -145,12 +147,11 @@ class CallbackManager(Callback):
     """

-    def __init__(self, env, attr, callbacks=None):
+    def __init__(self, env, callbacks=None):
         """

         :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself.
-        :param dict attr: read-only attributes for all callbacks
-        :param Callback callbacks:
+        :param List[Callback] callbacks:
         """
         super(CallbackManager, self).__init__()

         # set attribute of trainer environment

@@ -168,27 +169,14 @@ class CallbackManager(Callback):

         for env_name, env_val in env.items():
             for callback in self.callbacks:
-                setattr(callback, env_name, env_val)  # Callback.trainer
-        self.set_property(**attr)
-
-    def set_property(self, **kwargs):
-        """Set the read-only attributes of all callbacks.
-
-        :param kwargs:
-        :return:
-        """
-        for callback in self.callbacks:
-            for k, v in kwargs.items():
-                setattr(callback, "_" + k, v)
+                setattr(callback, '_'+env_name, env_val)  # Callback.trainer

     @transfer
     def on_train_begin(self):
         pass

     @transfer
-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         pass

     @transfer
@@ -200,15 +188,15 @@ class CallbackManager(Callback):
         pass

     @transfer
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         pass

     @transfer
-    def on_backward_end(self, model):
+    def on_backward_end(self):
         pass

     @transfer
-    def on_step_end(self, optimizer):
+    def on_step_end(self):
         pass

     @transfer

@@ -220,19 +208,19 @@ class CallbackManager(Callback):
         pass

     @transfer
-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         pass

     @transfer
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
         pass

     @transfer
-    def on_train_end(self, model):
+    def on_train_end(self):
         pass

     @transfer
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         pass
@@ -240,15 +228,15 @@ class DummyCallback(Callback):
     def on_train_begin(self, *arg):
         print(arg)

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
-        print(cur_epoch, n_epoch, optimizer)
+    def on_epoch_end(self):
+        print(self.epoch, self.n_epochs)


 class EchoCallback(Callback):
     def on_train_begin(self):
         print("before_train")

-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         print("before_epoch")

     def on_batch_begin(self, batch_x, batch_y, indices):

@@ -257,16 +245,16 @@ class EchoCallback(Callback):
     def on_loss_begin(self, batch_y, predict_y):
         print("before_loss")

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         print("before_backward")

     def on_batch_end(self):
         print("after_batch")

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
         print("after_epoch")

-    def on_train_end(self, model):
+    def on_train_end(self):
         print("after_train")
@@ -294,9 +282,9 @@ class GradientClipCallback(Callback):
         self.parameters = parameters
         self.clip_value = clip_value

-    def on_backward_end(self, model):
+    def on_backward_end(self):
         if self.parameters is None:
-            self.clip_fun(model.parameters(), self.clip_value)
+            self.clip_fun(self.model.parameters(), self.clip_value)
         else:
             self.clip_fun(self.parameters, self.clip_value)
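With `model` gone from the hook signature, the callback reaches the model through `self.model`, which the manager wires to the trainer. Usage is unchanged (a sketch, assuming `train_data`, `model` and `loss` are already prepared):

```python
trainer = Trainer(train_data=train_data, model=model, loss=loss,
                  callbacks=[GradientClipCallback(clip_value=5)])
trainer.train()
```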
@@ -318,14 +306,11 @@ class EarlyStopCallback(Callback):

         :param int patience: the number of epochs to wait before stopping
         """
         super(EarlyStopCallback, self).__init__()
-        self.trainer = None  # override by CallbackManager
         self.patience = patience
         self.wait = 0
-        self.epoch = 0

-    def on_valid_end(self, eval_result, metric_key, optimizer):
-        self.epoch += 1
-        if not self.trainer._better_eval_result(eval_result):
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
+        if not is_better_eval:
             # current result is getting worse
             if self.wait == self.patience:
                 raise EarlyStopError("Early stopping raised.")

@@ -334,7 +319,7 @@ class EarlyStopCallback(Callback):
         else:
             self.wait = 0

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         if isinstance(exception, EarlyStopError):
             print("Early Stopping triggered in epoch {}!".format(self.epoch))
         else:
@@ -354,7 +339,7 @@ class LRScheduler(Callback):
         else:
             raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")

-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         self.scheduler.step()

@@ -369,7 +354,7 @@ class ControlC(Callback):
             raise ValueError("In KeyBoardInterrupt, quit_all argument must be a bool.")
         self.quit_all = quit_all

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         if isinstance(exception, KeyboardInterrupt):
             if self.quit_all is True:
                 import sys
@@ -415,15 +400,15 @@ class LRFinder(Callback):
         self.find = None
         self.loader = ModelLoader()

-    def on_epoch_begin(self, cur_epoch, total_epoch):
-        if cur_epoch == 1:
+    def on_epoch_begin(self):
+        if self.epoch == 1:  # first epoch
             self.opt = self.trainer.optimizer  # pytorch optimizer
             self.opt.param_groups[0]["lr"] = self.start_lr
             # save model
             ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
             self.find = True

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         if self.find:
             if torch.isnan(loss) or self.stop is True:
                 self.stop = True

@@ -444,8 +429,8 @@ class LRFinder(Callback):
             self.opt.param_groups[0]["lr"] = lr
         # self.loader.load_pytorch(self.trainer.model, "tmp")

-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
-        if cur_epoch == 1:
+    def on_epoch_end(self):
+        if self.epoch == 1:  # first epoch
             self.opt.param_groups[0]["lr"] = self.best_lr
             self.find = False
             # reset model
@@ -489,7 +474,7 @@ class TensorboardCallback(Callback):
             # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2))
             self.graph_added = True

-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         if "loss" in self.options:
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)

@@ -501,18 +486,18 @@ class TensorboardCallback(Callback):
                     self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(),
                                                     global_step=self.trainer.step)

-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         if "metric" in self.options:
             for name, metric in eval_result.items():
                 for metric_key, metric_val in metric.items():
                     self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                     global_step=self.trainer.step)

-    def on_train_end(self, model):
+    def on_train_end(self):
         self._summary_writer.close()
         del self._summary_writer

-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         if hasattr(self, "_summary_writer"):
             self._summary_writer.close()
             del self._summary_writer

@@ -520,5 +505,5 @@ class TensorboardCallback(Callback):
 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.on_train_begin(10, 11, 12)
+    manager.on_train_begin()
     # print(manager.after_epoch())
@@ -90,7 +90,7 @@ class DataSet(object):
             data_set = DataSet()
             for field in self.field_arrays.values():
                 data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder,
-                                   is_input=field.is_input, is_target=field.is_target)
+                                   is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type)
             return data_set
         elif isinstance(idx, str):
             if idx not in self:
@@ -313,16 +313,23 @@ class DataSet(object):
         else:
             return results

-    def drop(self, func):
+    def drop(self, func, inplace=True):
         """Drop instances if a condition holds.

         :param func: a function that takes an Instance object as input, and returns bool.
             The instance will be dropped if the function returns True.
+        :param inplace: bool, whether to drop in place. Otherwise, a new dataset will be returned.

         """
-        results = [ins for ins in self._inner_iter() if not func(ins)]
-        for name, old_field in self.field_arrays.items():
-            self.field_arrays[name].content = [ins[name] for ins in results]
+        if inplace:
+            results = [ins for ins in self._inner_iter() if not func(ins)]
+            for name, old_field in self.field_arrays.items():
+                self.field_arrays[name].content = [ins[name] for ins in results]
+        else:
+            results = [ins for ins in self if not func(ins)]
+            data = DataSet(results)
+            for field_name, field in self.field_arrays.items():
+                data.field_arrays[field_name].to(field)
+            return data
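A quick sketch of the new `inplace` switch (with the `return data` added above, the non-inplace branch hands back a filtered copy and leaves the original untouched):

```python
ds = DataSet({"x": [[1, 2], [3, 4, 5], [6]], "y": [1, 0, 1]})

ds.drop(lambda ins: len(ins["x"]) < 2)        # inplace=True: mutates ds itself
kept = ds.drop(lambda ins: ins["y"] == 0,     # inplace=False: ds stays intact,
               inplace=False)                 # a new DataSet comes back
```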
     def split(self, dev_ratio):
         """Split the dataset into training and development(validation) set.

@@ -346,19 +353,8 @@ class DataSet(object):
         for idx in train_indices:
             train_set.append(self[idx])
         for field_name in self.field_arrays:
-            train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
-            train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target
-            train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder
-            train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype
-            train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype
-            train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim
-
-            dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
-            dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target
-            dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder
-            dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype
-            dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype
-            dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim
+            train_set.field_arrays[field_name].to(self.field_arrays[field_name])
+            dev_set.field_arrays[field_name].to(self.field_arrays[field_name])

         return train_set, dev_set
@@ -383,6 +383,23 @@ class FieldArray(object):
         """
         return len(self.content)

+    def to(self, other):
+        """
+        Copy the attributes of another FieldArray into this one (must be a FieldArray), including
+        is_input, is_target, padder, dtype, pytype, content_dim and ignore_type.
+
+        :param other: FieldArray
+        :return:
+        """
+        assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))
+
+        self.is_input = other.is_input
+        self.is_target = other.is_target
+        self.padder = other.padder
+        self.dtype = other.dtype
+        self.pytype = other.pytype
+        self.content_dim = other.content_dim
+        self.ignore_type = other.ignore_type
+

 def is_iterable(content):
     try:
@@ -91,7 +91,6 @@ class MetricBase(object):
     Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
     target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
     will be conducted.)
-    However, in some cases where type check is not necessary, ``_fast_param_map`` will be used.

     """
     def __init__(self):

@@ -146,21 +145,6 @@ class MetricBase(object):
     def get_metric(self, reset=True):
         raise NotImplemented

-    def _fast_param_map(self, pred_dict, target_dict):
-        """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
-            such as pred_dict has one element, target_dict has one element
-
-        :param pred_dict:
-        :param target_dict:
-        :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
-        """
-        fast_param = {}
-        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(target_dict.values())[0]
-            return fast_param
-        return fast_param
-
     def __call__(self, pred_dict, target_dict):
         """

@@ -172,7 +156,6 @@ class MetricBase(object):
         Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
         target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
         will be conducted.)
-        This function also support _fast_param_map.
         :param pred_dict: usually the output of forward or prediction function
         :param target_dict: usually features set as target..
         :return:

@@ -180,11 +163,6 @@ class MetricBase(object):
         if not callable(self.evaluate):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

-        fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
-        if fast_param:
-            self.evaluate(**fast_param)
-            return
-
         if not self._checked:
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
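Without the fast path, every call now goes through the explicit key-map checks, so ambiguous forward/target dicts must be resolved at construction time. For example (a sketch, assuming the usual key-map constructor arguments):

```python
# forward() returns {"output": ...} while the target field is named "label",
# so tell the metric how those keys map onto evaluate(pred=..., target=...):
metric = AccuracyMetric(pred="output", target="label")
```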
@@ -262,50 +240,14 @@ class AccuracyMetric(MetricBase):
         self.total = 0
         self.acc_count = 0

-    def _fast_param_map(self, pred_dict, target_dict):
-        """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
-            such as pred_dict has one element, target_dict has one element
-
-        :param pred_dict:
-        :param target_dict:
-        :return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping.
-        """
-        fast_param = {}
-        targets = list(target_dict.values())
-        if len(targets) == 1 and isinstance(targets[0], torch.Tensor):
-            if len(pred_dict) == 1:
-                pred = list(pred_dict.values())[0]
-                fast_param['pred'] = pred
-            elif len(pred_dict) == 2:
-                pred1 = list(pred_dict.values())[0]
-                pred2 = list(pred_dict.values())[1]
-                if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
-                    return fast_param
-                if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
-                    seq_lens = pred1
-                    pred = pred2
-                elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
-                    seq_lens = pred2
-                    pred = pred1
-                else:
-                    return fast_param
-                fast_param['pred'] = pred
-                fast_param['seq_lens'] = seq_lens
-            else:
-                return fast_param
-            fast_param['target'] = targets[0]
-        # TODO need to make sure they all have same batch_size
-        return fast_param
-
     def evaluate(self, pred, target, seq_lens=None):
         """

-        :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
-                torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
-        :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
-        :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
+        :param pred: Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]),
+                torch.Size([B, max_len, n_classes])
+        :param target: Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]),
+                torch.Size([B, max_len])
+        :param seq_lens: Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.

         """
         # TODO the error reporting here needs to change: users don't know what `pred` is; report the actual value
@@ -321,7 +263,7 @@ class AccuracyMetric(MetricBase):
                             f"got {type(seq_lens)}.")

         if seq_lens is not None:
-            masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
+            masks = seq_lens_to_masks(seq_lens=seq_lens)
         else:
             masks = None

@@ -334,14 +276,12 @@ class AccuracyMetric(MetricBase):
                              f"size:{pred.size()}, target should have size: {pred.size()} or "
                              f"{pred.size()[:-1]}, got {target.size()}.")

-        pred = pred.float()
-        target = target.float()
+        target = target.to(pred)

         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
-            self.total += torch.sum(masks.float()).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
+            self.total += torch.sum(masks).item()
         else:
-            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
+            self.acc_count += torch.sum(torch.eq(pred, target)).item()
             self.total += np.prod(list(pred.size()))
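The rewritten counting compares predictions as integers and silences padding via `masked_fill` instead of multiplying float tensors. A standalone sketch of the same arithmetic (assuming masks flag valid positions with 1):

```python
import torch

pred   = torch.tensor([[1, 2, 0], [3, 0, 0]])  # [B, max_len]
target = torch.tensor([[1, 2, 9], [3, 9, 9]])  # the 9s sit in padding
masks  = torch.tensor([[1, 1, 0], [1, 0, 0]]).eq(1)

correct = torch.eq(pred, target).masked_fill(masks.eq(0), 0)  # zero out padding hits
acc = round(torch.sum(correct).item() / (torch.sum(masks).item() + 1e-12), 6)
assert acc == 1.0  # all three valid positions match
```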
     def get_metric(self, reset=True):

@@ -350,7 +290,7 @@ class AccuracyMetric(MetricBase):

         :param bool reset: whether to recount next time.
         :return evaluate_result: {"acc": float}
         """
-        evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
+        evaluate_result = {'acc': round(float(self.acc_count) / (self.total + 1e-12), 6)}
         if reset:
             self.acc_count = 0
             self.total = 0

@@ -441,8 +381,7 @@ def bio_tag_to_spans(tags, ignore_labels=None):
         prev_bio_tag = bio_tag
     return [(span[0], (span[1][0], span[1][1]+1))
             for span in spans
-            if span[0] not in ignore_labels
-            ]
+            if span[0] not in ignore_labels]


 class SpanFPreRecMetric(MetricBase):
@@ -34,7 +34,7 @@ class Trainer(object):
     def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, save_path=None, optimizer=None,
                  check_code_level=0, metric_key=None, sampler=None, prefetch=False, use_tqdm=True,
-                 use_cuda=False, callbacks=None):
+                 use_cuda=False, callbacks=None, update_every=1):
         """
         :param DataSet train_data: the training data
         :param torch.nn.modules.module model: a PyTorch model

@@ -62,6 +62,8 @@ class Trainer(object):
         :param bool use_tqdm: whether to use tqdm to show train progress.
         :param callbacks: List[Callback]. Callbacks that hook into the training loop; early stopping,
             negative sampling and the like can be implemented through the callback mechanism.
+        :param update_every: int, update the gradients once every `update_every` steps. Meant for gradient
+            accumulation: e.g., if a batch size of 128 does not fit into memory, set batch_size=32 and
+            update_every=4 to get the same effective batch size.

         """
         super(Trainer, self).__init__()
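`update_every` is plain gradient accumulation: scale each loss down, let gradients add up, and step the optimizer once per `update_every` batches. The raw-PyTorch equivalent of what the `_grad_backward`/`_update` changes further down implement (a sketch with hypothetical `model`, `optimizer`, `criterion` and `batches`):

```python
update_every = 4  # effective batch size = batch_size * update_every

for step, (x, y) in enumerate(batches, start=1):
    loss = criterion(model(x), y) / update_every  # scale so accumulated grads average
    loss.backward()                               # grads accumulate across iterations
    if step % update_every == 0:
        optimizer.step()                          # one parameter update per 4 batches
        optimizer.zero_grad()
```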
@@ -76,6 +78,10 @@ class Trainer(object):
         if metrics and (dev_data is None):
             raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

+        # check update every
+        assert update_every >= 1, "update_every must be no less than 1."
+        self.update_every = int(update_every)
+
         # check save_path
         if not (save_path is None or isinstance(save_path, str)):
             raise ValueError("save_path can only be None or `str`.")

@@ -121,6 +127,9 @@ class Trainer(object):
         self.best_dev_perf = None
         self.sampler = sampler if sampler is not None else RandomSampler()
         self.prefetch = prefetch
+        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
+        self.n_steps = (len(self.train_data) // self.batch_size + int(
+            len(self.train_data) % self.batch_size != 0)) * self.n_epochs

         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer

@@ -130,6 +139,7 @@ class Trainer(object):
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

         self.use_tqdm = use_tqdm
+        self.pbar = None
         self.print_every = abs(self.print_every)

         if self.dev_data is not None:
@@ -144,11 +154,9 @@ class Trainer(object):

         self.start_time = None  # start timestamp

         self.callback_manager = CallbackManager(env={"trainer": self},
-                                                attr={"n_epochs": self.n_epochs, "n_steps": self.step,
-                                                      "batch_size": self.batch_size, "model": self.model,
-                                                      "optimizer": self.optimizer},
                                                 callbacks=callbacks)

     def train(self, load_best_model=True):
         """
@@ -205,9 +213,9 @@ class Trainer(object):
         try:
             self.callback_manager.on_train_begin()
             self._train()
-            self.callback_manager.on_train_end(self.model)
+            self.callback_manager.on_train_end()
         except (CallbackException, KeyboardInterrupt) as e:
-            self.callback_manager.on_exception(e, self.model)
+            self.callback_manager.on_exception(e)

         if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
             print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -234,19 +242,21 @@ class Trainer(object):
         else:
             inner_tqdm = tqdm
         self.step = 0
+        self.epoch = 0
         start = time.time()
-        total_steps = (len(self.train_data) // self.batch_size + int(
-            len(self.train_data) % self.batch_size != 0)) * self.n_epochs
-        with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+        with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
+            self.pbar = pbar if isinstance(pbar, tqdm) else None
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
-            self.callback_manager.set_property(pbar=pbar)
             for epoch in range(1, self.n_epochs+1):
+                self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
                 # early stopping
-                self.callback_manager.on_epoch_begin(epoch, self.n_epochs)
+                self.callback_manager.on_epoch_begin()
                 for batch_x, batch_y in data_iterator:
+                    self.step += 1
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     indices = data_iterator.get_batch_indices()
                     # negative sampling; replace unknown; re-weight batch_y
@@ -257,18 +267,20 @@ class Trainer(object):
                     self.callback_manager.on_loss_begin(batch_y, prediction)
                     loss = self._compute_loss(prediction, batch_y)
                     avg_loss += loss.item()
+                    loss = loss / self.update_every

                     # Is loss NaN or inf? requires_grad = False
-                    self.callback_manager.on_backward_begin(loss, self.model)
+                    self.callback_manager.on_backward_begin(loss)
                     self._grad_backward(loss)
-                    self.callback_manager.on_backward_end(self.model)
+                    self.callback_manager.on_backward_end()

                     self._update()
-                    self.callback_manager.on_step_end(self.optimizer)
+                    self.callback_manager.on_step_end()

                     if (self.step+1) % self.print_every == 0:
+                        avg_loss = avg_loss / self.print_every
                         if self.use_tqdm:
-                            print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
+                            print_output = "loss:{0:<6.5f}".format(avg_loss)
                             pbar.update(self.print_every)
                         else:
                             end = time.time()

@@ -277,7 +289,6 @@ class Trainer(object):
                                 epoch, self.step, avg_loss, diff)
                         pbar.set_postfix_str(print_output)
                         avg_loss = 0
-                    self.step += 1
                     self.callback_manager.on_batch_end()

                     if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
@@ -285,22 +296,24 @@ class Trainer(object):
                             and self.dev_data is not None:
                         eval_res = self._do_validation(epoch=epoch, step=self.step)
                         eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                                    total_steps) + \
-                                   self.tester._format_eval_results(eval_res)
-                        pbar.write(eval_str)
+                                                                                    self.n_steps) + \
+                                   self.tester._format_eval_results(eval_res)
+                        pbar.write(eval_str + '\n')

             # ================= mini-batch end ==================== #

             # lr decay; early stopping
-            self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer)
+            self.callback_manager.on_epoch_end()
         # =============== epochs end =================== #
         pbar.close()
+        self.pbar = None
         # ============ tqdm end ============== #

     def _do_validation(self, epoch, step):
         self.callback_manager.on_valid_begin()
         res = self.tester.test()

+        is_better_eval = False
         if self._better_eval_result(res):
             if self.save_path is not None:
                 self._save_model(self.model,
@@ -310,8 +323,9 @@ class Trainer(object):
             self.best_dev_perf = res
             self.best_dev_epoch = epoch
             self.best_dev_step = step
+            is_better_eval = True
         # get validation results; adjust optimizer
-        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer)
+        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
         return res

     def _mode(self, model, is_test=False):
@@ -330,7 +344,8 @@ class Trainer(object):
         """Perform weight update on a model.

         """
-        self.optimizer.step()
+        if (self.step + 1) % self.update_every == 0:
+            self.optimizer.step()

     def _data_forward(self, network, x):
         x = _build_args(network.forward, **x)

@@ -346,7 +361,8 @@ class Trainer(object):
         For PyTorch, just do "loss.backward()"
         """
-        self.model.zero_grad()
+        if self.step % self.update_every == 0:
+            self.model.zero_grad()
         loss.backward()

     def _compute_loss(self, predict, truth):
@@ -1,4 +1,5 @@
 import os
+import json

 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -64,6 +65,53 @@ def convert_seq2seq_dataset(data):
     return dataset


+def download_from_url(url, path):
+    """Download a file from `url` into `path`."""
+    from tqdm import tqdm
+    import requests
+
+    r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
+    chunk_size = 16 * 1024
+    total_size = int(r.headers.get('Content-length', 0))
+    with open(path, "wb") as file, \
+            tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
+        for chunk in r.iter_content(chunk_size):
+            if chunk:
+                file.write(chunk)
+                t.update(len(chunk))
+    return
+
+
+def uncompress(src, dst):
+    import zipfile, gzip, tarfile, os
+
+    def unzip(src, dst):
+        with zipfile.ZipFile(src, 'r') as f:
+            f.extractall(dst)
+
+    def ungz(src, dst):
+        with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
+            length = 16 * 1024  # 16KB
+            buf = f.read(length)
+            while buf:
+                uf.write(buf)
+                buf = f.read(length)
+
+    def untar(src, dst):
+        with tarfile.open(src, 'r:gz') as f:
+            f.extractall(dst)
+
+    fn, ext = os.path.splitext(src)
+    _, ext_2 = os.path.splitext(fn)
+    if ext == '.zip':
+        unzip(src, dst)
+    elif ext == '.gz' and ext_2 != '.tar':
+        ungz(src, dst)
+    elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
+        untar(src, dst)
+    else:
+        raise ValueError('unsupported file {}'.format(src))
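A usage sketch for the two helpers, pointed at the SST archive referenced later in this diff (the local paths are illustrative):

```python
url = "https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip"
download_from_url(url, "trainDevTestTrees_PTB.zip")
uncompress("trainDevTestTrees_PTB.zip", "data/sst")  # dispatches on .zip / .gz / .tar.gz / .tgz
```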
 class DataSetLoader:
     """Interface for all DataSetLoaders.
@@ -290,41 +338,6 @@ class DummyClassificationReader(DataSetLoader):
         return convert_seq2tag_dataset(data)


-class ConllLoader(DataSetLoader):
-    """loader for conll format files"""
-
-    def __init__(self):
-        super(ConllLoader, self).__init__()
-
-    def load(self, data_path):
-        with open(data_path, "r", encoding="utf-8") as f:
-            lines = f.readlines()
-        data = self.parse(lines)
-        return self.convert(data)
-
-    @staticmethod
-    def parse(lines):
-        """
-        :param list lines: a list containing all lines in a conll file.
-        :return: a 3D list
-        """
-        sentences = list()
-        tokens = list()
-        for line in lines:
-            if line[0] == "#":
-                # skip the comments
-                continue
-            if line == "\n":
-                sentences.append(tokens)
-                tokens = []
-                continue
-            tokens.append(line.split())
-        return sentences
-
-    def convert(self, data):
-        pass
-
-
 class DummyLMReader(DataSetLoader):
     """A Dummy Language Model Dataset Reader
     """
@@ -434,51 +447,67 @@ class PeopleDailyCorpusLoader(DataSetLoader):
         return data_set


-class Conll2003Loader(DataSetLoader):
+class ConllLoader:
+    def __init__(self, headers, indexs=None):
+        self.headers = headers
+        if indexs is None:
+            self.indexs = list(range(len(self.headers)))
+        else:
+            if len(indexs) != len(headers):
+                raise ValueError
+            self.indexs = indexs
+
+    def load(self, path):
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            start = next(f)
+            if '-DOCSTART-' not in start:
+                sample.append(start.split())
+            for line in f:
+                if line.startswith('\n'):
+                    if len(sample):
+                        datalist.append(sample)
+                        sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split())
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        data = [self.get_one(sample) for sample in datalist]
+        data = filter(lambda x: x is not None, data)
+
+        ds = DataSet()
+        for sample in data:
+            ins = Instance()
+            for name, idx in zip(self.headers, self.indexs):
+                ins.add_field(field_name=name, field=sample[idx])
+            ds.append(ins)
+        return ds
+
+    def get_one(self, sample):
+        sample = list(map(list, zip(*sample)))
+        for field in sample:
+            if len(field) <= 0:
+                return None
+        return sample
+
+
+class Conll2003Loader(ConllLoader):
     """Loader for conll2003 dataset

     More information about the given dataset could be found on
     https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data

+    Deprecated. Use ConllLoader for all types of conll-format files.
     """
     def __init__(self):
-        super(Conll2003Loader, self).__init__()
-
-    def load(self, dataset_path):
-        with open(dataset_path, "r", encoding="utf-8") as f:
-            lines = f.readlines()
-
-        parsed_data = []
-        sentence = []
-        tokens = []
-        for line in lines:
-            if '-DOCSTART- -X- -X- O' in line or line == '\n':
-                if sentence != []:
-                    parsed_data.append((sentence, tokens))
-                sentence = []
-                tokens = []
-                continue
-
-            temp = line.strip().split(" ")
-            sentence.append(temp[0])
-            tokens.append(temp[1:4])
-
-        return self.convert(parsed_data)
-
-    def convert(self, parsed_data):
-        dataset = DataSet()
-        for sample in parsed_data:
-            label0_list = list(map(
-                lambda labels: labels[0], sample[1]))
-            label1_list = list(map(
-                lambda labels: labels[1], sample[1]))
-            label2_list = list(map(
-                lambda labels: labels[2], sample[1]))
-            dataset.append(Instance(tokens=sample[0],
-                                    pos=label0_list,
-                                    chucks=label1_list,
-                                    ner=label2_list))
-
-        return dataset
+        headers = [
+            'tokens', 'pos', 'chunks', 'ner',
+        ]
+        super(Conll2003Loader, self).__init__(headers=headers)
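The new `ConllLoader` is driven entirely by `headers` (field names) and `indexs` (column numbers), so a single class now covers CoNLL-2003, CoNLL-X and similar column formats. A usage sketch (file paths illustrative):

```python
# CoNLL-2003 keeps all four columns, so Conll2003Loader only fixes the headers:
ds = Conll2003Loader().load("conll2003/train.txt")

# a custom column format: words from column 1, NER tags from column 5
loader = ConllLoader(headers=["words", "ner"], indexs=[1, 5])
ds = loader.load("my_corpus/train.conll")
```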
 class SNLIDataSetReader(DataSetLoader):

@@ -548,6 +577,7 @@ class SNLIDataSetReader(DataSetLoader):

 class ConllCWSReader(object):
+    """Deprecated. Use ConllLoader for all types of conll-format files."""
     def __init__(self):
         pass

@@ -700,6 +730,7 @@ def cut_long_sentence(sent, max_sample_length=200):

 class ZhConllPOSReader(object):
     """Reads Chinese CoNLL-format data and returns character-level tags, extending the original
     word-level tags with BMES notation.
+    Deprecated. Use ConllLoader for all types of conll-format files.
     """
     def __init__(self):
         pass
@@ -778,10 +809,35 @@ class ZhConllPOSReader(object):
         return text, pos_tags


-class ConllxDataLoader(object):
+class ConllxDataLoader(ConllLoader):
     """Returns word-level label information: words, POS tags, (syntactic) head dependencies and
     (syntactic) edge labels. Completely different from ``ZhConllPOSReader``.

+    Deprecated. Use ConllLoader for all types of conll-format files.
+    """
+    def __init__(self):
+        headers = [
+            'words', 'pos_tags', 'heads', 'labels',
+        ]
+        indexs = [
+            1, 3, 6, 7,
+        ]
+        super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs)
+
+
+class SSTLoader(DataSetLoader):
+    """Load SST data in PTB tree format.
+    data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
     """
+    def __init__(self, subtree=False, fine_grained=False):
+        self.subtree = subtree
+
+        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
+                 '3': 'positive', '4': 'very positive'}
+        if not fine_grained:
+            tag_v['0'] = tag_v['1']
+            tag_v['4'] = tag_v['3']
+        self.tag_v = tag_v

     def load(self, path):
         """
     def load(self, path):
         """

@@ -793,40 +849,47 @@ class ConllxDataLoader(object):
         """
         datalist = []
         with open(path, 'r', encoding='utf-8') as f:
-            sample = []
-            for line in f:
-                if line.startswith('\n'):
-                    datalist.append(sample)
-                    sample = []
-                elif line.startswith('#'):
-                    continue
-                else:
-                    sample.append(line.split('\t'))
-            if len(sample) > 0:
-                datalist.append(sample)
+            datas = []
+            for l in f:
+                datas.extend([(s, self.tag_v[t])
+                              for s, t in self.get_one(l, self.subtree)])
+        ds = DataSet()
+        for words, tag in datas:
+            ds.append(Instance(words=words, raw_tag=tag))
+        return ds

-        data = [self.get_one(sample) for sample in datalist]
-        data_list = list(filter(lambda x: x is not None, data))
+    @staticmethod
+    def get_one(data, subtree):
+        from nltk.tree import Tree
+        tree = Tree.fromstring(data)
+        if subtree:
+            return [(t.leaves(), t.label()) for t in tree.subtrees()]
+        return [(tree.leaves(), tree.label())]
+
+
+class JsonLoader(DataSetLoader):
+    """Load json-format data where every line contains a json object (like a dict);
+    `fields` gives the dict keys that need to be loaded.
+    """
+    def __init__(self, **fields):
+        super(JsonLoader, self).__init__()
+        self.fields = {}
+        for k, v in fields.items():
+            self.fields[k] = k if v is None else v
+
+    def load(self, path):
+        with open(path, 'r', encoding='utf-8') as f:
+            datas = [json.loads(l) for l in f]
         ds = DataSet()
-        for example in data_list:
-            ds.append(Instance(words=example[0],
-                               pos_tags=example[1],
-                               heads=example[2],
-                               labels=example[3]))
+        for d in datas:
+            ins = Instance()
+            for k, v in d.items():
+                if k in self.fields:
+                    ins.add_field(self.fields[k], v)
+            ds.append(ins)
         return ds

-    def get_one(self, sample):
-        sample = list(map(list, zip(*sample)))
-        if len(sample) == 0:
-            return None
-        for w in sample[7]:
-            if w == '_':
-                print('Error Sample {}'.format(sample))
-                return None
-        # return word_seq, pos_seq, head_seq, head_tag_seq
-        return sample[1], sample[3], list(map(int, sample[6])), sample[7]
def add_seg_tag(data): | def add_seg_tag(data): | ||||
""" | """ | ||||
@@ -848,3 +911,4 @@ def add_seg_tag(data): | |||||
new_sample.append((word[-1], 'E-' + pos)) | new_sample.append((word[-1], 'E-' + pos)) | ||||
_processed.append(list(map(list, zip(*new_sample)))) | _processed.append(list(map(list, zip(*new_sample)))) | ||||
return _processed | return _processed | ||||
@@ -0,0 +1,223 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""A module with NAS controller-related code.""" | |||||
import collections | |||||
import os | |||||
import torch | |||||
import torch.nn.functional as F | |||||
import fastNLP | |||||
import fastNLP.models.enas_utils as utils | |||||
from fastNLP.models.enas_utils import Node | |||||
def _construct_dags(prev_nodes, activations, func_names, num_blocks): | |||||
"""Constructs a set of DAGs based on the actions, i.e., previous nodes and | |||||
activation functions, sampled from the controller/policy pi. | |||||
Args: | |||||
prev_nodes: Previous node actions from the policy. | |||||
activations: Activations sampled from the policy. | |||||
func_names: Mapping from activation function names to functions. | |||||
num_blocks: Number of blocks in the target RNN cell. | |||||
Returns: | |||||
A list of DAGs defined by the inputs. | |||||
RNN cell DAGs are represented in the following way: | |||||
1. Each element (node) in a DAG is a list of `Node`s. | |||||
2. The `Node`s in the list dag[i] correspond to the subsequent nodes | |||||
that take the output from node i as their own input. | |||||
3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}. | |||||
dag[-1] always feeds dag[0]. | |||||
dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its | |||||
weights. | |||||
4. dag[N - 1] is the node that produces the hidden state passed to | |||||
the next timestep. dag[N - 1] is also always a leaf node, and therefore | |||||
is always averaged with the other leaf nodes and fed to the output | |||||
decoder. | |||||
""" | |||||
dags = [] | |||||
for nodes, func_ids in zip(prev_nodes, activations): | |||||
dag = collections.defaultdict(list) | |||||
# add first node | |||||
dag[-1] = [Node(0, func_names[func_ids[0]])] | |||||
dag[-2] = [Node(0, func_names[func_ids[0]])] | |||||
# add following nodes | |||||
for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])): | |||||
dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id])) | |||||
leaf_nodes = set(range(num_blocks)) - dag.keys() | |||||
# merge with avg | |||||
for idx in leaf_nodes: | |||||
dag[idx] = [Node(num_blocks, 'avg')] | |||||
# This is actually y^{(t)}. h^{(t)} is node N - 1 in | |||||
        # the graph, where N is the number of nodes. I.e., h^{(t)} takes
# only one other node as its input. | |||||
# last h[t] node | |||||
last_node = Node(num_blocks + 1, 'h[t]') | |||||
dag[num_blocks] = [last_node] | |||||
dags.append(dag) | |||||
return dags | |||||
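# Illustrative example (hypothetical sampled actions): with num_blocks=4,
# prev_nodes=[[0, 0, 1]] and activations=[[0, 1, 2, 3]], one DAG comes out as
#   dag[-1] = dag[-2] = [Node(0, 'tanh')]             # input node
#   dag[0]  = [Node(1, 'ReLU'), Node(2, 'identity')]  # node 0 feeds nodes 1 and 2
#   dag[1]  = [Node(3, 'sigmoid')]
#   dag[2]  = dag[3] = [Node(4, 'avg')]               # leaves, averaged together
#   dag[4]  = [Node(5, 'h[t]')]                       # produces the next hidden state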
class Controller(torch.nn.Module): | |||||
"""Based on | |||||
https://github.com/pytorch/examples/blob/master/word_language_model/model.py | |||||
    RL controllers do not necessarily have much to do with
    language models.

    TODO: base the controller RNN on the GRU from
    https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
    (the current implementation uses an LSTM cell instead).
""" | |||||
def __init__(self, num_blocks=4, controller_hid=100, cuda=False): | |||||
torch.nn.Module.__init__(self) | |||||
        # `num_tokens` lists the output-space size of each decoding step:
        # the activation-function choices on even steps, and the
        # previous-node choices on odd steps (filled in below).
self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid'] | |||||
self.num_tokens = [len(self.shared_rnn_activations)] | |||||
self.controller_hid = controller_hid | |||||
self.use_cuda = cuda | |||||
self.num_blocks = num_blocks | |||||
for idx in range(num_blocks): | |||||
self.num_tokens += [idx + 1, len(self.shared_rnn_activations)] | |||||
self.func_names = self.shared_rnn_activations | |||||
num_total_tokens = sum(self.num_tokens) | |||||
self.encoder = torch.nn.Embedding(num_total_tokens, | |||||
controller_hid) | |||||
self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid) | |||||
# Perhaps these weights in the decoder should be | |||||
# shared? At least for the activation functions, which all have the | |||||
# same size. | |||||
self.decoders = [] | |||||
for idx, size in enumerate(self.num_tokens): | |||||
decoder = torch.nn.Linear(controller_hid, size) | |||||
self.decoders.append(decoder) | |||||
self._decoders = torch.nn.ModuleList(self.decoders) | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def _get_default_hidden(key): | |||||
return utils.get_variable( | |||||
torch.zeros(key, self.controller_hid), | |||||
self.use_cuda, | |||||
requires_grad=False) | |||||
self.static_inputs = utils.keydefaultdict(_get_default_hidden) | |||||
def reset_parameters(self): | |||||
init_range = 0.1 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
for decoder in self.decoders: | |||||
decoder.bias.data.fill_(0) | |||||
def forward(self, # pylint:disable=arguments-differ | |||||
inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed): | |||||
if not is_embed: | |||||
embed = self.encoder(inputs) | |||||
else: | |||||
embed = inputs | |||||
hx, cx = self.lstm(embed, hidden) | |||||
logits = self.decoders[block_idx](hx) | |||||
logits /= 5.0 | |||||
# # exploration | |||||
# if self.args.mode == 'train': | |||||
# logits = (2.5 * F.tanh(logits)) | |||||
return logits, (hx, cx) | |||||
def sample(self, batch_size=1, with_details=False, save_dir=None): | |||||
"""Samples a set of `args.num_blocks` many computational nodes from the | |||||
controller, where each node is made up of an activation function, and | |||||
each node except the last also includes a previous node. | |||||
""" | |||||
if batch_size < 1: | |||||
            raise ValueError(f'Wrong batch_size: {batch_size} < 1')
# [B, L, H] | |||||
inputs = self.static_inputs[batch_size] | |||||
hidden = self.static_init_hidden[batch_size] | |||||
activations = [] | |||||
entropies = [] | |||||
log_probs = [] | |||||
prev_nodes = [] | |||||
# The RNN controller alternately outputs an activation, | |||||
# followed by a previous node, for each block except the last one, | |||||
# which only gets an activation function. The last node is the output | |||||
# node, and its previous node is the average of all leaf nodes. | |||||
for block_idx in range(2*(self.num_blocks - 1) + 1): | |||||
logits, hidden = self.forward(inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed=(block_idx == 0)) | |||||
probs = F.softmax(logits, dim=-1) | |||||
log_prob = F.log_softmax(logits, dim=-1) | |||||
# .mean() for entropy? | |||||
entropy = -(log_prob * probs).sum(1, keepdim=False) | |||||
action = probs.multinomial(num_samples=1).data | |||||
selected_log_prob = log_prob.gather( | |||||
1, utils.get_variable(action, requires_grad=False)) | |||||
# why the [:, 0] here? Should it be .squeeze(), or | |||||
# .view()? Same below with `action`. | |||||
entropies.append(entropy) | |||||
log_probs.append(selected_log_prob[:, 0]) | |||||
# 0: function, 1: previous node | |||||
mode = block_idx % 2 | |||||
inputs = utils.get_variable( | |||||
action[:, 0] + sum(self.num_tokens[:mode]), | |||||
requires_grad=False) | |||||
if mode == 0: | |||||
activations.append(action[:, 0]) | |||||
elif mode == 1: | |||||
prev_nodes.append(action[:, 0]) | |||||
prev_nodes = torch.stack(prev_nodes).transpose(0, 1) | |||||
activations = torch.stack(activations).transpose(0, 1) | |||||
dags = _construct_dags(prev_nodes, | |||||
activations, | |||||
self.func_names, | |||||
self.num_blocks) | |||||
if save_dir is not None: | |||||
for idx, dag in enumerate(dags): | |||||
utils.draw_network(dag, | |||||
os.path.join(save_dir, f'graph{idx}.png')) | |||||
if with_details: | |||||
return dags, torch.cat(log_probs), torch.cat(entropies) | |||||
return dags | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.controller_hid) | |||||
return (utils.get_variable(zeros, self.use_cuda, requires_grad=False), | |||||
utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False)) |
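# Minimal usage sketch (CPU, default sizes; a hedged illustration only):
#
#   controller = Controller(num_blocks=4, controller_hid=100, cuda=False)
#   dags, log_probs, entropies = controller.sample(batch_size=1, with_details=True)
#   # each dag in `dags` can then be fixed into the shared model via
#   # ENASModel.setDAG(dag)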
@@ -0,0 +1,388 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""Module containing the shared RNN model.""" | |||||
import numpy as np | |||||
import collections | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
from torch.autograd import Variable | |||||
import fastNLP.models.enas_utils as utils | |||||
from fastNLP.models.base_model import BaseModel | |||||
import fastNLP.modules.encoder as encoder | |||||
def _get_dropped_weights(w_raw, dropout_p, is_training): | |||||
"""Drops out weights to implement DropConnect. | |||||
Args: | |||||
w_raw: Full, pre-dropout, weights to be dropped out. | |||||
dropout_p: Proportion of weights to drop out. | |||||
is_training: True iff _shared_ model is training. | |||||
Returns: | |||||
The dropped weights. | |||||
Why does torch.nn.functional.dropout() return: | |||||
1. `torch.autograd.Variable()` on the training loop | |||||
2. `torch.nn.Parameter()` on the controller or eval loop, when | |||||
training = False... | |||||
Even though the call to `_setweights` in the Smerity repo's | |||||
`weight_drop.py` does not have this behaviour, and `F.dropout` always | |||||
returns `torch.autograd.Variable` there, even when `training=False`? | |||||
    This open question is the reason for the hacky check for
    `torch.nn.Parameter` below.
""" | |||||
dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training) | |||||
if isinstance(dropped_w, torch.nn.Parameter): | |||||
dropped_w = dropped_w.clone() | |||||
return dropped_w | |||||
class EmbeddingDropout(torch.nn.Embedding): | |||||
"""Class for dropping out embeddings by zero'ing out parameters in the | |||||
embedding matrix. | |||||
This is equivalent to dropping out particular words, e.g., in the sentence | |||||
'the quick brown fox jumps over the lazy dog', dropping out 'the' would | |||||
lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the | |||||
embedding vector space). | |||||
See 'A Theoretically Grounded Application of Dropout in Recurrent Neural | |||||
Networks', (Gal and Ghahramani, 2016). | |||||
""" | |||||
def __init__(self, | |||||
num_embeddings, | |||||
embedding_dim, | |||||
max_norm=None, | |||||
norm_type=2, | |||||
scale_grad_by_freq=False, | |||||
sparse=False, | |||||
dropout=0.1, | |||||
scale=None): | |||||
"""Embedding constructor. | |||||
Args: | |||||
dropout: Dropout probability. | |||||
scale: Used to scale parameters of embedding weight matrix that are | |||||
not dropped out. Note that this is _in addition_ to the | |||||
`1/(1 - dropout)` scaling. | |||||
See `torch.nn.Embedding` for remaining arguments. | |||||
""" | |||||
torch.nn.Embedding.__init__(self, | |||||
num_embeddings=num_embeddings, | |||||
embedding_dim=embedding_dim, | |||||
max_norm=max_norm, | |||||
norm_type=norm_type, | |||||
scale_grad_by_freq=scale_grad_by_freq, | |||||
sparse=sparse) | |||||
self.dropout = dropout | |||||
assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' | |||||
'and < 1.0') | |||||
self.scale = scale | |||||
def forward(self, inputs): # pylint:disable=arguments-differ | |||||
"""Embeds `inputs` with the dropped out embedding weight matrix.""" | |||||
if self.training: | |||||
dropout = self.dropout | |||||
else: | |||||
dropout = 0 | |||||
if dropout: | |||||
mask = self.weight.data.new(self.weight.size(0), 1) | |||||
mask.bernoulli_(1 - dropout) | |||||
mask = mask.expand_as(self.weight) | |||||
mask = mask / (1 - dropout) | |||||
masked_weight = self.weight * Variable(mask) | |||||
else: | |||||
masked_weight = self.weight | |||||
if self.scale and self.scale != 1: | |||||
masked_weight = masked_weight * self.scale | |||||
return F.embedding(inputs, | |||||
masked_weight, | |||||
max_norm=self.max_norm, | |||||
norm_type=self.norm_type, | |||||
scale_grad_by_freq=self.scale_grad_by_freq, | |||||
sparse=self.sparse) | |||||
class LockedDropout(nn.Module): | |||||
# code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py | |||||
def __init__(self): | |||||
super().__init__() | |||||
def forward(self, x, dropout=0.5): | |||||
if not self.training or not dropout: | |||||
return x | |||||
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) | |||||
mask = Variable(m, requires_grad=False) / (1 - dropout) | |||||
mask = mask.expand_as(x) | |||||
return mask * x | |||||
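# Note: the dropout mask above is sampled once per sequence with shape
# [1, batch, hidden] and broadcast over timesteps, i.e. "locked"/variational
# dropout: every timestep drops the same units, unlike standard dropout,
# which resamples the mask for every element.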
class ENASModel(BaseModel): | |||||
"""Shared RNN model.""" | |||||
def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000): | |||||
super(ENASModel, self).__init__() | |||||
self.use_cuda = cuda | |||||
self.shared_hid = shared_hid | |||||
self.num_blocks = num_blocks | |||||
self.decoder = nn.Linear(self.shared_hid, num_classes) | |||||
self.encoder = EmbeddingDropout(embed_num, | |||||
shared_embed, | |||||
dropout=0.1) | |||||
self.lockdrop = LockedDropout() | |||||
self.dag = None | |||||
# Tie weights | |||||
# self.decoder.weight = self.encoder.weight | |||||
# Since W^{x, c} and W^{h, c} are always summed, there | |||||
# is no point duplicating their bias offset parameter. Likewise for | |||||
# W^{x, h} and W^{h, h}. | |||||
self.w_xc = nn.Linear(shared_embed, self.shared_hid) | |||||
self.w_xh = nn.Linear(shared_embed, self.shared_hid) | |||||
# The raw weights are stored here because the hidden-to-hidden weights | |||||
# are weight dropped on the forward pass. | |||||
self.w_hc_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hh_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hc = None | |||||
self.w_hh = None | |||||
self.w_h = collections.defaultdict(dict) | |||||
self.w_c = collections.defaultdict(dict) | |||||
for idx in range(self.num_blocks): | |||||
for jdx in range(idx + 1, self.num_blocks): | |||||
self.w_h[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self.w_c[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self._w_h = nn.ModuleList([self.w_h[idx][jdx] | |||||
for idx in self.w_h | |||||
for jdx in self.w_h[idx]]) | |||||
self._w_c = nn.ModuleList([self.w_c[idx][jdx] | |||||
for idx in self.w_c | |||||
for jdx in self.w_c[idx]]) | |||||
self.batch_norm = None | |||||
# if args.mode == 'train': | |||||
# self.batch_norm = nn.BatchNorm1d(self.shared_hid) | |||||
# else: | |||||
# self.batch_norm = None | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def setDAG(self, dag): | |||||
if self.dag is None: | |||||
self.dag = dag | |||||
def forward(self, word_seq, hidden=None): | |||||
inputs = torch.transpose(word_seq, 0, 1) | |||||
time_steps = inputs.size(0) | |||||
batch_size = inputs.size(1) | |||||
self.w_hh = _get_dropped_weights(self.w_hh_raw, | |||||
0.5, | |||||
self.training) | |||||
self.w_hc = _get_dropped_weights(self.w_hc_raw, | |||||
0.5, | |||||
self.training) | |||||
# hidden = self.static_init_hidden[batch_size] if hidden is None else hidden | |||||
hidden = self.static_init_hidden[batch_size] | |||||
embed = self.encoder(inputs) | |||||
embed = self.lockdrop(embed, 0.65 if self.training else 0) | |||||
        # The norms of the hidden states are clipped here because
# otherwise ENAS is especially prone to exploding activations on the | |||||
# forward pass. This could probably be fixed in a more elegant way, but | |||||
# it might be exposing a weakness in the ENAS algorithm as currently | |||||
# proposed. | |||||
# | |||||
# For more details, see | |||||
# https://github.com/carpedm20/ENAS-pytorch/issues/6 | |||||
clipped_num = 0 | |||||
max_clipped_norm = 0 | |||||
h1tohT = [] | |||||
logits = [] | |||||
for step in range(time_steps): | |||||
x_t = embed[step] | |||||
logit, hidden = self.cell(x_t, hidden, self.dag) | |||||
hidden_norms = hidden.norm(dim=-1) | |||||
max_norm = 25.0 | |||||
if hidden_norms.data.max() > max_norm: | |||||
# Just directly use the torch slice operations | |||||
# in PyTorch v0.4. | |||||
# | |||||
# This workaround for PyTorch v0.3.1 does everything in numpy, | |||||
# because the PyTorch slicing and slice assignment is too | |||||
# flaky. | |||||
hidden_norms = hidden_norms.data.cpu().numpy() | |||||
clipped_num += 1 | |||||
if hidden_norms.max() > max_clipped_norm: | |||||
max_clipped_norm = hidden_norms.max() | |||||
clip_select = hidden_norms > max_norm | |||||
clip_norms = hidden_norms[clip_select] | |||||
mask = np.ones(hidden.size()) | |||||
normalizer = max_norm/clip_norms | |||||
normalizer = normalizer[:, np.newaxis] | |||||
mask[clip_select] = normalizer | |||||
if self.use_cuda: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask).cuda(), requires_grad=False) | |||||
else: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask), requires_grad=False) | |||||
logits.append(logit) | |||||
h1tohT.append(hidden) | |||||
h1tohT = torch.stack(h1tohT) | |||||
output = torch.stack(logits) | |||||
raw_output = output | |||||
output = self.lockdrop(output, 0.4 if self.training else 0) | |||||
        # pooling: average the step-wise outputs over time
output = torch.mean(output, 0) | |||||
decoded = self.decoder(output) | |||||
extra_out = {'dropped': decoded, | |||||
'hiddens': h1tohT, | |||||
'raw': raw_output} | |||||
return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out} | |||||
def cell(self, x, h_prev, dag): | |||||
"""Computes a single pass through the discovered RNN cell.""" | |||||
c = {} | |||||
h = {} | |||||
f = {} | |||||
f[0] = self.get_f(dag[-1][0].name) | |||||
c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None)) | |||||
h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) + | |||||
(1 - c[0])*h_prev) | |||||
leaf_node_ids = [] | |||||
q = collections.deque() | |||||
q.append(0) | |||||
# Computes connections from the parent nodes `node_id` | |||||
# to their child nodes `next_id` recursively, skipping leaf nodes. A | |||||
# leaf node is a node whose id == `self.num_blocks`. | |||||
# | |||||
# Connections between parent i and child j should be computed as | |||||
# h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i, | |||||
# where c_j = \sigmoid{(W^c_{ij}*h_i)} | |||||
# | |||||
# See Training details from Section 3.1 of the paper. | |||||
# | |||||
# The following algorithm does a breadth-first (since `q.popleft()` is | |||||
# used) search over the nodes and computes all the hidden states. | |||||
while True: | |||||
if len(q) == 0: | |||||
break | |||||
node_id = q.popleft() | |||||
nodes = dag[node_id] | |||||
for next_node in nodes: | |||||
next_id = next_node.id | |||||
if next_id == self.num_blocks: | |||||
leaf_node_ids.append(node_id) | |||||
assert len(nodes) == 1, ('parent of leaf node should have ' | |||||
'only one child') | |||||
continue | |||||
w_h = self.w_h[node_id][next_id] | |||||
w_c = self.w_c[node_id][next_id] | |||||
f[next_id] = self.get_f(next_node.name) | |||||
c[next_id] = torch.sigmoid(w_c(h[node_id])) | |||||
h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) + | |||||
(1 - c[next_id])*h[node_id]) | |||||
q.append(next_id) | |||||
# Instead of averaging loose ends, perhaps there should | |||||
# be a set of separate unshared weights for each "loose" connection | |||||
# between each node in a cell and the output. | |||||
# | |||||
# As it stands, all weights W^h_{ij} are doing double duty by | |||||
# connecting both from i to j, as well as from i to the output. | |||||
# average all the loose ends | |||||
leaf_nodes = [h[node_id] for node_id in leaf_node_ids] | |||||
output = torch.mean(torch.stack(leaf_nodes, 2), -1) | |||||
        # stabilizing the updates of omega (the shared parameters)
if self.batch_norm is not None: | |||||
output = self.batch_norm(output) | |||||
return output, h[self.num_blocks - 1] | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.shared_hid) | |||||
return utils.get_variable(zeros, self.use_cuda, requires_grad=False) | |||||
    def get_f(self, name):
        name = name.lower()
        if name == 'relu':
            f = torch.relu
        elif name == 'tanh':
            f = torch.tanh
        elif name == 'identity':
            f = lambda x: x
        elif name == 'sigmoid':
            f = torch.sigmoid
        else:
            raise ValueError(f'unknown activation function: {name}')
        return f
@property | |||||
def num_parameters(self): | |||||
def size(p): | |||||
return np.prod(p.size()) | |||||
return sum([size(param) for param in self.parameters()]) | |||||
def reset_parameters(self): | |||||
init_range = 0.025 | |||||
# init_range = 0.025 if self.args.mode == 'train' else 0.04 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
self.decoder.bias.data.fill_(0) | |||||
def predict(self, word_seq): | |||||
""" | |||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||||
""" | |||||
output = self(word_seq) | |||||
_, predict = output['pred'].max(dim=1) | |||||
return {'pred': predict} |
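# Minimal forward-pass sketch (hypothetical sizes; a dag must be fixed first,
# e.g. from a Controller as in enas_controller.py):
#
#   model = ENASModel(embed_num=10000, num_classes=5)
#   model.setDAG(controller.sample(1)[0])
#   word_seq = torch.randint(0, 10000, (32, 20))  # [batch, seq_len]
#   out = model(word_seq)
#   out['pred'].shape                             # -> torch.Size([32, 5])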
@@ -0,0 +1,385 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
import os | |||||
import time | |||||
from datetime import datetime | |||||
from datetime import timedelta | |||||
import numpy as np | |||||
import torch | |||||
import math | |||||
from torch import nn | |||||
try: | |||||
from tqdm.autonotebook import tqdm | |||||
except ImportError:
from fastNLP.core.utils import pseudo_tqdm as tqdm | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.callback import CallbackManager, CallbackException | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
import fastNLP | |||||
import fastNLP.models.enas_utils as utils | |||||
from fastNLP.core.utils import _build_args | |||||
from torch.optim import Adam | |||||
def _get_no_grad_ctx_mgr(): | |||||
"""Returns a the `torch.no_grad` context manager for PyTorch version >= | |||||
0.4, or a no-op context manager otherwise. | |||||
""" | |||||
return torch.no_grad() | |||||
class ENASTrainer(fastNLP.Trainer): | |||||
"""A class to wrap training code.""" | |||||
def __init__(self, train_data, model, controller, **kwargs): | |||||
"""Constructor for training algorithm. | |||||
:param DataSet train_data: the training data | |||||
:param torch.nn.modules.module model: a PyTorch model | |||||
:param torch.nn.modules.module controller: a PyTorch model | |||||
""" | |||||
self.final_epochs = kwargs['final_epochs'] | |||||
kwargs.pop('final_epochs') | |||||
super(ENASTrainer, self).__init__(train_data, model, **kwargs) | |||||
self.controller_step = 0 | |||||
self.shared_step = 0 | |||||
self.max_length = 35 | |||||
self.shared = model | |||||
self.controller = controller | |||||
self.shared_optim = Adam( | |||||
self.shared.parameters(), | |||||
lr=20.0, | |||||
weight_decay=1e-7) | |||||
self.controller_optim = Adam( | |||||
self.controller.parameters(), | |||||
lr=3.5e-4) | |||||
def train(self, load_best_model=True): | |||||
""" | |||||
        :param bool load_best_model: only effective when dev_data was provided at initialization. If True,
            the trainer reloads the parameters that achieved the best dev performance before returning.
        :return results: a dict with the following entries::

                seconds: float, total training time
                The following three entries are only present when dev_data was provided.
                best_eval: Dict of Dict, the evaluation results
                best_epoch: int, the epoch in which the best value was obtained
                best_step: int, the step (batch update) at which the best value was obtained
""" | |||||
results = {} | |||||
if self.n_epochs <= 0: | |||||
print(f"training epoch is {self.n_epochs}, nothing was done.") | |||||
results['seconds'] = 0. | |||||
return results | |||||
try: | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self.model = self.model.cuda() | |||||
self._model_device = self.model.parameters().__next__().device | |||||
self._mode(self.model, is_test=False) | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||||
start_time = time.time() | |||||
print("training epochs started " + self.start_time, flush=True) | |||||
try: | |||||
self.callback_manager.on_train_begin() | |||||
self._train() | |||||
self.callback_manager.on_train_end() | |||||
except (CallbackException, KeyboardInterrupt) as e: | |||||
self.callback_manager.on_exception(e) | |||||
if self.dev_data is not None: | |||||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||||
self.tester._format_eval_results(self.best_dev_perf),) | |||||
results['best_eval'] = self.best_dev_perf | |||||
results['best_epoch'] = self.best_dev_epoch | |||||
results['best_step'] = self.best_dev_step | |||||
if load_best_model: | |||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | |||||
load_succeed = self._load_model(self.model, model_name) | |||||
if load_succeed: | |||||
print("Reloaded the best model.") | |||||
else: | |||||
print("Fail to reload best model.") | |||||
finally: | |||||
pass | |||||
results['seconds'] = round(time.time() - start_time, 2) | |||||
return results | |||||
def _train(self): | |||||
if not self.use_tqdm: | |||||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||||
else: | |||||
inner_tqdm = tqdm | |||||
self.step = 0 | |||||
start = time.time() | |||||
total_steps = (len(self.train_data) // self.batch_size + int( | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
avg_loss = 0 | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for epoch in range(1, self.n_epochs+1): | |||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||||
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs) | |||||
if epoch == self.n_epochs + 1 - self.final_epochs: | |||||
print('Entering the final stage. (Only train the selected structure)') | |||||
# early stopping | |||||
self.callback_manager.on_epoch_begin() | |||||
# 1. Training the shared parameters omega of the child models | |||||
self.train_shared(pbar) | |||||
# 2. Training the controller parameters theta | |||||
if not last_stage: | |||||
self.train_controller() | |||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||||
and self.dev_data is not None: | |||||
if not last_stage: | |||||
self.derive() | |||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
# lr decay; early stopping | |||||
self.callback_manager.on_epoch_end() | |||||
# =============== epochs end =================== # | |||||
pbar.close() | |||||
# ============ tqdm end ============== # | |||||
def get_loss(self, inputs, targets, hidden, dags): | |||||
"""Computes the loss for the same batch for M models. | |||||
This amounts to an estimate of the loss, which is turned into an | |||||
estimate for the gradients of the shared model. | |||||
""" | |||||
if not isinstance(dags, list): | |||||
dags = [dags] | |||||
loss = 0 | |||||
for dag in dags: | |||||
self.shared.setDAG(dag) | |||||
inputs = _build_args(self.shared.forward, **inputs) | |||||
inputs['hidden'] = hidden | |||||
result = self.shared(**inputs) | |||||
output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out'] | |||||
self.callback_manager.on_loss_begin(targets, result) | |||||
sample_loss = self._compute_loss(result, targets) | |||||
loss += sample_loss | |||||
        assert len(dags) == 1, 'there are multiple `hidden` for multiple `dags`'
return loss, hidden, extra_out | |||||
def train_shared(self, pbar=None, max_step=None, dag=None): | |||||
"""Train the language model for 400 steps of minibatches of 64 | |||||
examples. | |||||
Args: | |||||
max_step: Used to run extra training steps as a warm-up. | |||||
dag: If not None, is used instead of calling sample(). | |||||
BPTT is truncated at 35 timesteps. | |||||
For each weight update, gradients are estimated by sampling M models | |||||
from the fixed controller policy, and averaging their gradients | |||||
computed on a batch of training data. | |||||
""" | |||||
model = self.shared | |||||
model.train() | |||||
self.controller.eval() | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
abs_max_grad = 0 | |||||
abs_max_hidden_norm = 0 | |||||
        step = 0
        start = time.time()
raw_total_loss = 0 | |||||
total_loss = 0 | |||||
train_idx = 0 | |||||
avg_loss = 0 | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
indices = data_iterator.get_batch_indices() | |||||
# negative sampling; replace unknown; re-weight batch_y | |||||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||||
# prediction = self._data_forward(self.model, batch_x) | |||||
dags = self.controller.sample(1) | |||||
inputs, targets = batch_x, batch_y | |||||
# self.callback_manager.on_loss_begin(batch_y, prediction) | |||||
loss, hidden, extra_out = self.get_loss(inputs, | |||||
targets, | |||||
hidden, | |||||
dags) | |||||
hidden.detach_() | |||||
avg_loss += loss.item() | |||||
# Is loss NaN or inf? requires_grad = False | |||||
self.callback_manager.on_backward_begin(loss) | |||||
self._grad_backward(loss) | |||||
self.callback_manager.on_backward_end() | |||||
self._update() | |||||
self.callback_manager.on_step_end() | |||||
if (self.step+1) % self.print_every == 0: | |||||
if self.use_tqdm: | |||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | |||||
pbar.update(self.print_every) | |||||
                else:
                    end = time.time()
                    diff = timedelta(seconds=round(end - start))
                    # `epoch` is not in scope here, so only the global step is reported
                    print_output = "[step: {:>4}] train loss: {:>4.6} time: {}".format(
                        self.step, avg_loss, diff)
pbar.set_postfix_str(print_output) | |||||
avg_loss = 0 | |||||
self.step += 1 | |||||
step += 1 | |||||
self.shared_step += 1 | |||||
self.callback_manager.on_batch_end() | |||||
# ================= mini-batch end ==================== # | |||||
def get_reward(self, dag, entropies, hidden, valid_idx=0): | |||||
"""Computes the perplexity of a single sampled model on a minibatch of | |||||
validation data. | |||||
""" | |||||
if not isinstance(entropies, np.ndarray): | |||||
entropies = entropies.data.cpu().numpy() | |||||
data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for inputs, targets in data_iterator: | |||||
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) | |||||
valid_loss = utils.to_item(valid_loss.data) | |||||
valid_ppl = math.exp(valid_loss) | |||||
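            # reward R = c / valid_ppl with c = 80, plus a small entropy bonus
            # (1e-4 * entropies) to encourage exploration of the search space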
R = 80 / valid_ppl | |||||
rewards = R + 1e-4 * entropies | |||||
return rewards, hidden | |||||
def train_controller(self): | |||||
"""Fixes the shared parameters and updates the controller parameters. | |||||
The controller is updated with a score function gradient estimator | |||||
(i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl | |||||
is computed on a minibatch of validation data. | |||||
A moving average baseline is used. | |||||
        The controller is trained for a fixed number of steps per epoch (20 here;
        2000 in the original ENAS setup), i.e., the first (Train Shared) phase is
        followed by the second (Train Controller) phase.
""" | |||||
model = self.controller | |||||
model.train() | |||||
# Why can't we call shared.eval() here? Leads to loss | |||||
# being uniformly zero for the controller. | |||||
# self.shared.eval() | |||||
avg_reward_base = None | |||||
baseline = None | |||||
adv_history = [] | |||||
entropy_history = [] | |||||
reward_history = [] | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
total_loss = 0 | |||||
valid_idx = 0 | |||||
for step in range(20): | |||||
# sample models | |||||
dags, log_probs, entropies = self.controller.sample( | |||||
with_details=True) | |||||
# calculate reward | |||||
np_entropies = entropies.data.cpu().numpy() | |||||
# No gradients should be backpropagated to the | |||||
# shared model during controller training, obviously. | |||||
with _get_no_grad_ctx_mgr(): | |||||
rewards, hidden = self.get_reward(dags, | |||||
np_entropies, | |||||
hidden, | |||||
valid_idx) | |||||
reward_history.extend(rewards) | |||||
entropy_history.extend(np_entropies) | |||||
# moving average baseline | |||||
if baseline is None: | |||||
baseline = rewards | |||||
else: | |||||
decay = 0.95 | |||||
baseline = decay * baseline + (1 - decay) * rewards | |||||
adv = rewards - baseline | |||||
adv_history.extend(adv) | |||||
# policy loss | |||||
loss = -log_probs*utils.get_variable(adv, | |||||
self.use_cuda, | |||||
requires_grad=False) | |||||
loss = loss.sum() # or loss.mean() | |||||
# update | |||||
self.controller_optim.zero_grad() | |||||
loss.backward() | |||||
self.controller_optim.step() | |||||
total_loss += utils.to_item(loss.data) | |||||
if ((step % 50) == 0) and (step > 0): | |||||
reward_history, adv_history, entropy_history = [], [], [] | |||||
total_loss = 0 | |||||
self.controller_step += 1 | |||||
# prev_valid_idx = valid_idx | |||||
# valid_idx = ((valid_idx + self.max_length) % | |||||
# (self.valid_data.size(0) - 1)) | |||||
# # Whenever we wrap around to the beginning of the | |||||
# # validation data, we reset the hidden states. | |||||
# if prev_valid_idx > valid_idx: | |||||
# hidden = self.shared.init_hidden(self.batch_size) | |||||
def derive(self, sample_num=10, valid_idx=0): | |||||
"""We are always deriving based on the very first batch | |||||
of validation data? This seems wrong... | |||||
""" | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
dags, _, entropies = self.controller.sample(sample_num, | |||||
with_details=True) | |||||
max_R = 0 | |||||
best_dag = None | |||||
for dag in dags: | |||||
R, _ = self.get_reward(dag, entropies, hidden, valid_idx) | |||||
if R.max() > max_R: | |||||
max_R = R.max() | |||||
best_dag = dag | |||||
self.model.setDAG(best_dag) |
@@ -0,0 +1,56 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
from __future__ import print_function | |||||
from collections import defaultdict | |||||
import collections | |||||
from datetime import datetime | |||||
import os | |||||
import json | |||||
import numpy as np | |||||
import torch | |||||
from torch.autograd import Variable | |||||
def detach(h): | |||||
    if isinstance(h, Variable):
return Variable(h.data) | |||||
else: | |||||
return tuple(detach(v) for v in h) | |||||
def get_variable(inputs, cuda=False, **kwargs): | |||||
if type(inputs) in [list, np.ndarray]: | |||||
inputs = torch.Tensor(inputs) | |||||
if cuda: | |||||
out = Variable(inputs.cuda(), **kwargs) | |||||
else: | |||||
out = Variable(inputs, **kwargs) | |||||
return out | |||||
def update_lr(optimizer, lr): | |||||
for param_group in optimizer.param_groups: | |||||
param_group['lr'] = lr | |||||
Node = collections.namedtuple('Node', ['id', 'name']) | |||||
class keydefaultdict(defaultdict): | |||||
def __missing__(self, key): | |||||
if self.default_factory is None: | |||||
raise KeyError(key) | |||||
else: | |||||
ret = self[key] = self.default_factory(key) | |||||
return ret | |||||
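# Example: the factory receives the missing key itself, e.g.
#   sizes = keydefaultdict(lambda k: k * 2)
#   sizes[3]   # -> 6, and now cached under key 3
# (this is how static_init_hidden/static_inputs are built per batch size)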
def to_item(x): | |||||
"""Converts x, possibly scalar and possibly tensor, to a Python scalar.""" | |||||
if isinstance(x, (float, int)): | |||||
return x | |||||
if float(torch.__version__[0:3]) < 0.4: | |||||
assert (x.dim() == 1) and (len(x) == 1) | |||||
return x[0] | |||||
return x.item() |
@@ -0,0 +1,181 @@ | |||||
from fastNLP.modules.encoder.star_transformer import StarTransformer | |||||
from fastNLP.core.utils import seq_lens_to_masks | |||||
import torch | |||||
from torch import nn | |||||
import torch.nn.functional as F | |||||
class StarTransEnc(nn.Module): | |||||
def __init__(self, vocab_size, emb_dim, | |||||
hidden_size, | |||||
num_layers, | |||||
num_head, | |||||
head_dim, | |||||
max_len, | |||||
emb_dropout, | |||||
dropout): | |||||
super(StarTransEnc, self).__init__() | |||||
self.emb_fc = nn.Linear(emb_dim, hidden_size) | |||||
self.emb_drop = nn.Dropout(emb_dropout) | |||||
self.embedding = nn.Embedding(vocab_size, emb_dim) | |||||
self.encoder = StarTransformer(hidden_size=hidden_size, | |||||
num_layers=num_layers, | |||||
num_head=num_head, | |||||
head_dim=head_dim, | |||||
dropout=dropout, | |||||
max_len=max_len) | |||||
def forward(self, x, mask): | |||||
x = self.embedding(x) | |||||
x = self.emb_fc(self.emb_drop(x)) | |||||
nodes, relay = self.encoder(x, mask) | |||||
return nodes, relay | |||||
class Cls(nn.Module): | |||||
def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): | |||||
super(Cls, self).__init__() | |||||
self.fc = nn.Sequential( | |||||
nn.Linear(in_dim, hid_dim), | |||||
nn.LeakyReLU(), | |||||
nn.Dropout(dropout), | |||||
nn.Linear(hid_dim, num_cls), | |||||
) | |||||
def forward(self, x): | |||||
h = self.fc(x) | |||||
return h | |||||
class NLICls(nn.Module): | |||||
def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): | |||||
super(NLICls, self).__init__() | |||||
self.fc = nn.Sequential( | |||||
nn.Dropout(dropout), | |||||
            nn.Linear(in_dim * 4, hid_dim),  # 4 concatenated features: x1, x2, |x1-x2|, x1*x2
nn.LeakyReLU(), | |||||
nn.Dropout(dropout), | |||||
nn.Linear(hid_dim, num_cls), | |||||
) | |||||
def forward(self, x1, x2): | |||||
x = torch.cat([x1, x2, torch.abs(x1-x2), x1*x2], 1) | |||||
h = self.fc(x) | |||||
return h | |||||
class STSeqLabel(nn.Module): | |||||
"""star-transformer model for sequence labeling | |||||
""" | |||||
def __init__(self, vocab_size, emb_dim, num_cls, | |||||
hidden_size=300, | |||||
num_layers=4, | |||||
num_head=8, | |||||
head_dim=32, | |||||
max_len=512, | |||||
cls_hidden_size=600, | |||||
emb_dropout=0.1, | |||||
dropout=0.1,): | |||||
super(STSeqLabel, self).__init__() | |||||
self.enc = StarTransEnc(vocab_size=vocab_size, | |||||
emb_dim=emb_dim, | |||||
hidden_size=hidden_size, | |||||
num_layers=num_layers, | |||||
num_head=num_head, | |||||
head_dim=head_dim, | |||||
max_len=max_len, | |||||
emb_dropout=emb_dropout, | |||||
dropout=dropout) | |||||
self.cls = Cls(hidden_size, num_cls, cls_hidden_size) | |||||
def forward(self, word_seq, seq_lens): | |||||
mask = seq_lens_to_masks(seq_lens) | |||||
nodes, _ = self.enc(word_seq, mask) | |||||
output = self.cls(nodes) | |||||
        output = output.transpose(1, 2)  # move the class dimension to dim 1
return {'output': output} # [bsz, n_cls, seq_len] | |||||
def predict(self, word_seq, seq_lens): | |||||
y = self.forward(word_seq, seq_lens) | |||||
_, pred = y['output'].max(1) | |||||
return {'output': pred, 'seq_lens': seq_lens} | |||||
class STSeqCls(nn.Module): | |||||
"""star-transformer model for sequence classification | |||||
""" | |||||
def __init__(self, vocab_size, emb_dim, num_cls, | |||||
hidden_size=300, | |||||
num_layers=4, | |||||
num_head=8, | |||||
head_dim=32, | |||||
max_len=512, | |||||
cls_hidden_size=600, | |||||
emb_dropout=0.1, | |||||
dropout=0.1,): | |||||
super(STSeqCls, self).__init__() | |||||
self.enc = StarTransEnc(vocab_size=vocab_size, | |||||
emb_dim=emb_dim, | |||||
hidden_size=hidden_size, | |||||
num_layers=num_layers, | |||||
num_head=num_head, | |||||
head_dim=head_dim, | |||||
max_len=max_len, | |||||
emb_dropout=emb_dropout, | |||||
dropout=dropout) | |||||
self.cls = Cls(hidden_size, num_cls, cls_hidden_size) | |||||
def forward(self, word_seq, seq_lens): | |||||
mask = seq_lens_to_masks(seq_lens) | |||||
nodes, relay = self.enc(word_seq, mask) | |||||
y = 0.5 * (relay + nodes.max(1)[0]) | |||||
output = self.cls(y) # [bsz, n_cls] | |||||
return {'output': output} | |||||
def predict(self, word_seq, seq_lens): | |||||
y = self.forward(word_seq, seq_lens) | |||||
_, pred = y['output'].max(1) | |||||
return {'output': pred} | |||||
class STNLICls(nn.Module): | |||||
"""star-transformer model for NLI | |||||
""" | |||||
def __init__(self, vocab_size, emb_dim, num_cls, | |||||
hidden_size=300, | |||||
num_layers=4, | |||||
num_head=8, | |||||
head_dim=32, | |||||
max_len=512, | |||||
cls_hidden_size=600, | |||||
emb_dropout=0.1, | |||||
dropout=0.1,): | |||||
super(STNLICls, self).__init__() | |||||
self.enc = StarTransEnc(vocab_size=vocab_size, | |||||
emb_dim=emb_dim, | |||||
hidden_size=hidden_size, | |||||
num_layers=num_layers, | |||||
num_head=num_head, | |||||
head_dim=head_dim, | |||||
max_len=max_len, | |||||
emb_dropout=emb_dropout, | |||||
dropout=dropout) | |||||
self.cls = NLICls(hidden_size, num_cls, cls_hidden_size) | |||||
def forward(self, word_seq1, word_seq2, seq_lens1, seq_lens2): | |||||
mask1 = seq_lens_to_masks(seq_lens1) | |||||
mask2 = seq_lens_to_masks(seq_lens2) | |||||
def enc(seq, mask): | |||||
nodes, relay = self.enc(seq, mask) | |||||
return 0.5 * (relay + nodes.max(1)[0]) | |||||
y1 = enc(word_seq1, mask1) | |||||
y2 = enc(word_seq2, mask2) | |||||
output = self.cls(y1, y2) # [bsz, n_cls] | |||||
return {'output': output} | |||||
def predict(self, word_seq1, word_seq2, seq_lens1, seq_lens2): | |||||
y = self.forward(word_seq1, word_seq2, seq_lens1, seq_lens2) | |||||
_, pred = y['output'].max(1) | |||||
return {'output': pred} |
@@ -0,0 +1,145 @@ | |||||
import torch | |||||
from torch import nn | |||||
from torch.nn import functional as F | |||||
import numpy as NP | |||||
class StarTransformer(nn.Module): | |||||
"""Star-Transformer Encoder part。 | |||||
paper: https://arxiv.org/abs/1902.09113 | |||||
:param hidden_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
:param num_layers: int, star-transformer的层数 | |||||
:param num_head: int,head的数量。 | |||||
:param head_dim: int, 每个head的维度大小。 | |||||
:param dropout: float dropout 概率 | |||||
:param max_len: int or None, 如果为int,输入序列的最大长度, | |||||
模型会为属于序列加上position embedding。 | |||||
若为None,忽略加上position embedding的步骤 | |||||
""" | |||||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | |||||
super(StarTransformer, self).__init__() | |||||
self.iters = num_layers | |||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | |||||
self.ring_att = nn.ModuleList( | |||||
[MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
for _ in range(self.iters)]) | |||||
self.star_att = nn.ModuleList( | |||||
[MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
for _ in range(self.iters)]) | |||||
if max_len is not None: | |||||
            self.pos_emb = nn.Embedding(max_len, hidden_size)
else: | |||||
self.pos_emb = None | |||||
def forward(self, data, mask): | |||||
""" | |||||
:param FloatTensor data: [batch, length, hidden] the input sequence | |||||
:param ByteTensor mask: [batch, length] the padding mask for input, in which padding pos is 0 | |||||
:return: [batch, length, hidden] the output sequence | |||||
[batch, hidden] the global relay node | |||||
""" | |||||
def norm_func(f, x): | |||||
# B, H, L, 1 | |||||
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) | |||||
B, L, H = data.size() | |||||
        mask = (mask == 0)  # flip the mask for masked_fill_
        smask = torch.cat([torch.zeros(B, 1).byte().to(mask), mask], 1)
embs = data.permute(0, 2, 1)[:,:,:,None] # B H L 1 | |||||
        if self.pos_emb is not None:
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device)\ | |||||
.view(1, L)).permute(0, 2, 1).contiguous()[:, :, :, None] # 1 H L 1 | |||||
embs = embs + P | |||||
nodes = embs | |||||
relay = embs.mean(2, keepdim=True) | |||||
ex_mask = mask[:, None, :, None].expand(B, H, L, 1) | |||||
r_embs = embs.view(B, H, 1, L) | |||||
for i in range(self.iters): | |||||
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2) | |||||
nodes = nodes + F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax)) | |||||
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask)) | |||||
nodes = nodes.masked_fill_(ex_mask, 0) | |||||
nodes = nodes.view(B, H, L).permute(0, 2, 1) | |||||
return nodes, relay.view(B, H) | |||||
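# Shape sketch (hypothetical sizes): with batch B=2, length L=5, hidden H:
#   data [2, 5, H], mask [2, 5] (1 = real token, 0 = padding)
#   -> nodes [2, 5, H] (per-token states), relay [2, H] (global relay node)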
class MSA1(nn.Module): | |||||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||||
super(MSA1, self).__init__() | |||||
# Multi-head Self Attention Case 1, doing self-attention for small regions | |||||
        # Due to GPU architecture, a Hadamard product plus summation is faster than
        # a dot product when unfold_size is very small.
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||||
self.drop = nn.Dropout(dropout) | |||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||||
def forward(self, x, ax=None): | |||||
# x: B, H, L, 1, ax : B, H, X, L append features | |||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||||
B, H, L, _ = x.shape | |||||
q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1) | |||||
if ax is not None: | |||||
aL = ax.shape[2] | |||||
ak = self.WK(ax).view(B, nhead, head_dim, aL, L) | |||||
av = self.WV(ax).view(B, nhead, head_dim, aL, L) | |||||
q = q.view(B, nhead, head_dim, 1, L) | |||||
k = F.unfold(k.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
v = F.unfold(v.view(B, nhead * head_dim, L, 1), (unfold_size, 1), padding=(unfold_size // 2, 0))\ | |||||
.view(B, nhead, head_dim, unfold_size, L) | |||||
if ax is not None: | |||||
k = torch.cat([k, ak], 3) | |||||
v = torch.cat([v, av], 3) | |||||
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3)) # B N L 1 U | |||||
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1) | |||||
ret = self.WO(att) | |||||
return ret | |||||
class MSA2(nn.Module): | |||||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | |||||
# Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value | |||||
super(MSA2, self).__init__() | |||||
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | |||||
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1) | |||||
self.drop = nn.Dropout(dropout) | |||||
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim) | |||||
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3 | |||||
def forward(self, x, y, mask=None): | |||||
# x: B, H, 1, 1, 1 y: B H L 1 | |||||
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size | |||||
B, H, L, _ = y.shape | |||||
q, k, v = self.WQ(x), self.WK(y), self.WV(y) | |||||
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h | |||||
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L | |||||
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h | |||||
pre_a = torch.matmul(q, k) / NP.sqrt(head_dim) | |||||
if mask is not None: | |||||
pre_a = pre_a.masked_fill(mask[:, None, None, :], -float('inf')) | |||||
alphas = self.drop(F.softmax(pre_a, 3)) # B, N, 1, L | |||||
att = torch.matmul(alphas, v).view(B, -1, 1, 1) # B, N, 1, h -> B, N*h, 1, 1 | |||||
return self.WO(att) |
@@ -5,17 +5,18 @@ from ..dropout import TimestepDropout | |||||
class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
"""transformer的encoder模块,不包含embedding层 | |||||
:param num_layers: int, transformer的层数 | |||||
:param model_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
:param inner_size: int, FFN层的hidden大小 | |||||
:param key_size: int, 每个head的维度大小。 | |||||
:param value_size: int,每个head中value的维度。 | |||||
:param num_head: int,head的数量。 | |||||
:param dropout: float。 | |||||
""" | |||||
class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | ||||
""" | |||||
:param model_size: int, 输入维度的大小。同时也是输出维度的大小。 | |||||
:param inner_size: int, FFN层的hidden大小 | |||||
:param key_size: int, 每个head的维度大小。 | |||||
:param value_size: int,每个head中value的维度。 | |||||
:param num_head: int,head的数量。 | |||||
:param dropout: float。 | |||||
""" | |||||
super(TransformerEncoder.SubLayer, self).__init__() | super(TransformerEncoder.SubLayer, self).__init__() | ||||
self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) | ||||
self.norm1 = nn.LayerNorm(model_size) | self.norm1 = nn.LayerNorm(model_size) | ||||
@@ -45,6 +46,11 @@ class TransformerEncoder(nn.Module): | |||||
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)]) | ||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
""" | |||||
:param x: [batch, seq_len, model_size] 输入序列 | |||||
:param seq_mask: [batch, seq_len] 输入序列的padding mask | |||||
:return: [batch, seq_len, model_size] 输出序列 | |||||
""" | |||||
output = x | output = x | ||||
if seq_mask is None: | if seq_mask is None: | ||||
atte_mask_out = None | atte_mask_out = None | ||||
@@ -0,0 +1,44 @@ | |||||
# Model Reproduction
This directory reproduces the models implemented in fastNLP, aiming to match the performance reported in the original papers.
Reproduced models:
- Star-Transformer
- ...
## Star-Transformer | |||||
[reference](https://arxiv.org/abs/1902.09113) | |||||
### Performance (still in progress) | |||||
| Task | Dataset | SOTA | Our Model |
|------|------|------|------|
| POS Tagging | CTB 9.0 | - | ACC 92.31 |
| POS Tagging | CoNLL 2012 | - | ACC 96.51 |
| Named Entity Recognition | CoNLL 2012 | - | F1 85.66 |
| Text Classification | SST | - | 49.18 |
| Natural Language Inference | SNLI | - | 83.76 |
### Usage | |||||
``` python | |||||
# for sequence labeling(ner, pos tagging, etc) | |||||
from fastNLP.models.star_transformer import STSeqLabel | |||||
model = STSeqLabel( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for sequence classification | |||||
from fastNLP.models.star_transformer import STSeqCls | |||||
model = STSeqCls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
# for natural language inference | |||||
from fastNLP.models.star_transformer import STNLICls | |||||
model = STNLICls( | |||||
vocab_size=10000, num_cls=50, | |||||
emb_dim=300) | |||||
``` | |||||
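A minimal forward-pass sketch (hypothetical vocabulary size and input shapes; the dict outputs follow the `forward`/`predict` interfaces of the models above):

``` python
import torch
from fastNLP.models.star_transformer import STSeqCls

model = STSeqCls(vocab_size=10000, num_cls=5, emb_dim=300)
word_seq = torch.randint(1, 10000, (2, 7))  # [batch, seq_len] token ids
seq_lens = torch.tensor([7, 5])             # true length of each sample
out = model(word_seq, seq_lens)             # {'output': [2, 5] logits}
pred = model.predict(word_seq, seq_lens)    # {'output': [2] predicted class ids}
```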
## ... |
@@ -13,12 +13,12 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
setup( | setup( | ||||
name='FastNLP', | name='FastNLP', | ||||
version='0.1.1', | |||||
version='0.4.0', | |||||
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
long_description=readme, | long_description=readme, | ||||
license=license, | license=license, | ||||
author='FudanNLP', | author='FudanNLP', | ||||
python_requires='>=3.5', | |||||
python_requires='>=3.6', | |||||
packages=find_packages(), | packages=find_packages(), | ||||
install_requires=reqs.strip().split('\n'), | install_requires=reqs.strip().split('\n'), | ||||
) | ) |
@@ -35,7 +35,7 @@ class TestENAS(unittest.TestCase): | |||||
print(dataset[0]) | print(dataset[0]) | ||||
        # DataSet.drop(func) filters out matching instances
dataset.drop(lambda x: x['seq_len'] <= 3) | |||||
dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True) | |||||
print(len(dataset)) | print(len(dataset)) | ||||
        # specify which fields in the DataSet should be converted to tensors
@@ -139,11 +139,14 @@ class TestCallback(unittest.TestCase): | |||||
def test_readonly_property(self): | def test_readonly_property(self): | ||||
from fastNLP.core.callback import Callback | from fastNLP.core.callback import Callback | ||||
passed_epochs = [] | |||||
total_epochs = 5 | |||||
class MyCallback(Callback): | class MyCallback(Callback): | ||||
def __init__(self): | def __init__(self): | ||||
super(MyCallback, self).__init__() | super(MyCallback, self).__init__() | ||||
def on_epoch_begin(self, cur_epoch, total_epoch): | |||||
def on_epoch_begin(self): | |||||
passed_epochs.append(self.epoch) | |||||
print(self.n_epochs, self.n_steps, self.batch_size) | print(self.n_epochs, self.n_steps, self.batch_size) | ||||
print(self.model) | print(self.model) | ||||
print(self.optimizer) | print(self.optimizer) | ||||
@@ -151,7 +154,7 @@ class TestCallback(unittest.TestCase): | |||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
loss=BCELoss(pred="predict", target="y"), | loss=BCELoss(pred="predict", target="y"), | ||||
n_epochs=5, | |||||
n_epochs=total_epochs, | |||||
batch_size=32, | batch_size=32, | ||||
print_every=50, | print_every=50, | ||||
optimizer=SGD(lr=0.1), | optimizer=SGD(lr=0.1), | ||||
@@ -161,3 +164,4 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[MyCallback()]) | callbacks=[MyCallback()]) | ||||
trainer.train() | trainer.train() | ||||
assert passed_epochs == list(range(1, total_epochs+1)) |
@@ -125,7 +125,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
def test_drop(self): | def test_drop(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ||||
ds.drop(lambda ins: len(ins["y"]) < 3) | |||||
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | |||||
self.assertEqual(len(ds), 20) | self.assertEqual(len(ds), 20) | ||||
def test_contains(self): | def test_contains(self): | ||||
@@ -169,7 +169,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | ||||
sep='\t') | sep='\t') | ||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0) | |||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | |||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | dataset.apply(split_sent, new_field_name='words', is_input=True) | ||||
# print(dataset) | # print(dataset) | ||||
@@ -217,9 +217,10 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertTrue(len(ds) > 0) | self.assertTrue(len(ds) > 0) | ||||
def test_add_null(self): | def test_add_null(self): | ||||
# TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||||
ds = DataSet() | ds = DataSet() | ||||
ds.add_field('test', []) | |||||
ds.set_target('test') | |||||
        with self.assertRaises(RuntimeError):
ds.add_field('test', []) | |||||
class TestDataSetIter(unittest.TestCase): | class TestDataSetIter(unittest.TestCase): | ||||
@@ -15,7 +15,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
def test_AccuracyMetric2(self): | def test_AccuracyMetric2(self): | ||||
@@ -30,7 +30,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
except Exception as e: | except Exception as e: | ||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | |||||
print("No exception catches.") | |||||
def test_AccuracyMetric3(self): | def test_AccuracyMetric3(self): | ||||
# (3) the second batch is corrupted size | # (3) the second batch is corrupted size | ||||
@@ -95,10 +95,9 @@ class TestAccuracyMetric(unittest.TestCase):
         self.assertAlmostEqual(res["acc"], float(ans), places=4)

     def test_AccuaryMetric8(self):
-        # (8) check map, does not match. use stop_fast_param to stop fast param map
         try:
             metric = AccuracyMetric(pred='predictions', target='targets')
-            pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
+            pred_dict = {"prediction": torch.zeros(4, 3, 2)}
             target_dict = {'targets': torch.zeros(4, 3)}
             metric(pred_dict=pred_dict, target_dict=target_dict, )
             self.assertDictEqual(metric.get_metric(), {'acc': 1})
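These metric hunks exercise fastNLP's name-mapping machinery: `AccuracyMetric(pred=..., target=...)` names the keys the metric reads from `pred_dict` and `target_dict`, and the removed `stop_fast_param` escape hatch goes away along with the fast-parameter map. A sketch of the mapping with keys that do match (import path and exact rounding of the result are assumptions):

```python
import torch
from fastNLP.core.metrics import AccuracyMetric  # import path assumed

# pred/target name the dictionary keys the metric should read.
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {'predictions': torch.zeros(4, 3, 2)}  # last dim holds class scores
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())  # expected: {'acc': 1.0}
```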
@@ -141,11 +140,11 @@ class SpanF1PreRecMetric(unittest.TestCase):
         bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8']
         bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8']
         expect_bmes_res = set()
-        expect_bmes_res.update([('8', (0, 0)), ('2', (1, 1)), ('0', (2, 2)), ('9', (3, 3)), ('6', (4, 4)),
-                                ('5', (5, 5)), ('7', (6, 6)), ('2', (7, 7)), ('7', (8, 8)), ('8', (9, 9))])
+        expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)),
+                                ('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))])
         expect_bio_res = set()
-        expect_bio_res.update([('7', (8, 8)), ('0', (2, 2)), ('2', (7, 7)), ('5', (5, 5)),
-                               ('6', (4, 4)), ('7', (6, 6))])
+        expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)),
+                               ('6', (4, 5)), ('7', (6, 7))])
         self.assertSetEqual(expect_bmes_res, set(bmes_tag_to_spans(bmes_lst)))
         self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
         # Verified against the corresponding allennlp functions; since the tests cannot depend on allennlp, only the fixed example above is checked here.
@@ -168,9 +167,9 @@ class SpanF1PreRecMetric(unittest.TestCase):
         bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E']
         bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O']
         expect_bmes_res = set()
-        expect_bmes_res.update([('', (0, 1)), ('', (2, 2)), ('', (3, 3)), ('', (4, 6)), ('', (7, 7)), ('', (8, 9))])
+        expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))])
         expect_bio_res = set()
-        expect_bio_res.update([('', (7, 7)), ('', (6, 6)), ('', (4, 4)), ('', (0, 0)), ('', (1, 1))])
+        expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))])
         self.assertSetEqual(expect_bmes_res, set(bmes_tag_to_spans(bmes_lst)))
         self.assertSetEqual(expect_bio_res, set(bio_tag_to_spans(bio_lst)))
         # Verified against the corresponding allennlp functions; since the tests cannot depend on allennlp, only the fixed example above is checked here.
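Both span fixtures moved from inclusive `(start, end)` pairs to Python-style half-open spans where `end` is exclusive, so a single-token span at position `i` becomes `(i, i + 1)`. A sketch of a BIO-to-span conversion under the new convention, consistent with the fixtures above but not the library's actual implementation:

```python
def bio_spans(tags):
    """Convert BIO tags such as 'B-7' / 'I-7' / 'O' into (label, (start, end))
    spans with an exclusive end index. A bare 'I' that does not continue the
    previous span starts a new one, matching the test fixtures."""
    spans, start, prev = [], None, None
    for i, tag in enumerate(tags):
        prefix, _, label = tag.partition('-')
        if prefix == 'B' or (prefix == 'I' and (start is None or label != prev)):
            if start is not None:
                spans.append((prev, (start, i)))   # close the open span
            start, prev = i, label
        elif prefix == 'O':
            if start is not None:
                spans.append((prev, (start, i)))
            start, prev = None, None
        # an 'I' continuing the current span needs no action
    if start is not None:
        spans.append((prev, (start, len(tags))))   # close the trailing span
    return spans

assert set(bio_spans(['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'])) == \
    {('', (0, 1)), ('', (1, 2)), ('', (4, 5)), ('', (6, 7)), ('', (7, 8))}
```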
@@ -6,7 +6,7 @@ from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver

 class TestConfigSaver(unittest.TestCase):
     def test_case_1(self):
-        config_file_dir = "test/io/"
+        config_file_dir = "test/io"
         config_file_name = "config"
         config_file_path = os.path.join(config_file_dir, config_file_name)
@@ -17,11 +17,3 @@ class TestDatasetLoader(unittest.TestCase):
     def test_PeopleDailyCorpusLoader(self):
         data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt")
-
-    def test_ConllCWSReader(self):
-        dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt")
-
-    def test_ZhConllPOSReader(self):
-        dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx")
-
-    def test_ConllxDataLoader(self):
-        dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx")
@@ -118,7 +118,7 @@ class TestCRF(unittest.TestCase):
         feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
         crf = ConditionalRandomField(num_tags, include_start_end_trans)
         optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
-        for _ in range(10000):
+        for _ in range(10):
             loss = crf(feats, tags, masks).mean()
             optimizer.zero_grad()
             loss.backward()
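Cutting the loop from 10000 to 10 iterations turns a slow convergence run into a smoke test: the point is that the CRF loss is finite and the backward/optimizer steps execute, not that training reaches an optimum. The same pattern in a self-contained sketch (a generic linear model stands in for fastNLP's CRF):

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 4)
optimizer = optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(32, 8)
y = torch.randint(0, 4, (32,))

for _ in range(10):  # a handful of steps is enough for a smoke test
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

assert not torch.isnan(loss).item()  # the step ran and stayed numerically sane
```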
@@ -3,6 +3,7 @@ import unittest
 import torch

 from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine
+from fastNLP.modules.encoder.star_transformer import StarTransformer


 class TestGroupNorm(unittest.TestCase):
@@ -49,3 +50,12 @@ class TestBiAffine(unittest.TestCase):
         encoder_input = torch.randn((batch_size, decoder_length, 10))
         y = layer(decoder_input, encoder_input)
         self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1))
+
+
+class TestStarTransformer(unittest.TestCase):
+    def test_1(self):
+        model = StarTransformer(num_layers=6, hidden_size=100, num_head=8, head_dim=20, max_len=100)
+        x = torch.rand(16, 45, 100)
+        mask = torch.ones(16, 45).byte()
+        y, yn = model(x, mask)
+        self.assertEqual(tuple(y.size()), (16, 45, 100))
+        self.assertEqual(tuple(yn.size()), (16, 100))
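The new test doubles as usage documentation for `StarTransformer`: given token embeddings of shape `(batch, seq_len, hidden_size)` and a byte mask, the forward pass returns per-token states `y` of the same shape plus a pooled relay-node state `yn` of shape `(batch, hidden_size)`. A sketch of feeding the relay node into a classifier head (the head itself is an illustration, not part of the diff):

```python
import torch
import torch.nn as nn
from fastNLP.modules.encoder.star_transformer import StarTransformer

encoder = StarTransformer(num_layers=6, hidden_size=100, num_head=8,
                          head_dim=20, max_len=100)
classifier = nn.Linear(100, 5)       # hypothetical 5-class head

x = torch.rand(16, 45, 100)          # (batch, seq_len, hidden_size)
mask = torch.ones(16, 45).byte()     # 1 = real token, 0 = padding
y, yn = encoder(x, mask)             # token states and relay-node state
logits = classifier(yn)              # sentence-level prediction from the relay node
print(logits.shape)                  # torch.Size([16, 5])
```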
@@ -35,7 +35,7 @@ class TestTutorial(unittest.TestCase):
         print(dataset[0])

         # Filter instances with DataSet.drop(func)
-        dataset.drop(lambda x: x['seq_len'] <= 3)
+        dataset.drop(lambda x: x['seq_len'] <= 3, inplace=True)
         print(len(dataset))

         # Set which fields in the DataSet should be converted to tensors
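The comment above marks the tutorial step where fields are flagged for batching. In fastNLP the usual calls are `set_input` and `set_target`, which decide what the batch iterator feeds to the model versus what the loss and metrics read; a sketch using the field names this tutorial relies on elsewhere (`word_seq`, `label_seq`):

```python
# Fields marked as input are batched into the model's forward();
# fields marked as target are handed to the loss and the metrics.
dataset.set_input('word_seq')
dataset.set_target('label_seq')
```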
@@ -152,7 +152,7 @@ class TestTutorial(unittest.TestCase):
                           train_data=train_data,
                           dev_data=dev_data,
                           loss=CrossEntropyLoss(),
-                          metrics=AccuracyMetric()
+                          metrics=AccuracyMetric(target='label_seq')
                           )
         trainer.train()
         print('Train finished!')
@@ -296,7 +296,7 @@ class TestTutorial(unittest.TestCase):
         # Filter the data
         origin_data_set_len = len(data_set)
-        data_set.drop(lambda x: len(x['premise']) <= 6)
+        data_set.drop(lambda x: len(x['premise']) <= 6, inplace=True)
         origin_data_set_len, len(data_set)

         # In[17]:
@@ -353,7 +353,7 @@ class TestTutorial(unittest.TestCase):
         train_data[-1], dev_data[-1], test_data[-1]

         # Read in the vocab file
-        with open('vocab.txt') as f:
+        with open('vocab.txt', encoding='utf-8') as f:
             lines = f.readlines()
         vocabs = []
         for line in lines:
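Passing `encoding='utf-8'` matters because a bare `open()` falls back to the platform default encoding (e.g. cp1252 on Windows), which can crash or silently mangle a vocabulary containing non-ASCII tokens. The hunk's context ends inside the loop, so the body below is a plausible reconstruction, not the tutorial's exact code:

```python
vocabs = []
with open('vocab.txt', encoding='utf-8') as f:  # decode deterministically on every platform
    for line in f:
        vocabs.append(line.strip())
```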
@@ -407,7 +407,7 @@ class TestTutorial(unittest.TestCase):
             train_data=train_data,
             model=model,
             loss=CrossEntropyLoss(pred='pred', target='label'),
-            metrics=AccuracyMetric(),
+            metrics=AccuracyMetric(target='label'),
             n_epochs=3,
             batch_size=16,
             print_every=-1,
@@ -424,7 +424,7 @@ class TestTutorial(unittest.TestCase):
         tester = Tester(
             data=test_data,
             model=model,
-            metrics=AccuracyMetric(),
+            metrics=AccuracyMetric(target='label'),
             batch_size=args["batch_size"],
         )
         tester.test()
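The `Trainer` and `Tester` hunks make the same fix: `AccuracyMetric()` with no arguments looks for a field literally named `target`, so when the gold field is called `label` the mapping must be spelled out. A standalone sketch of the difference, with shapes and values chosen so the accuracy is trivially 1.0 (import path assumed):

```python
import torch
from fastNLP.core.metrics import AccuracyMetric  # import path assumed

pred_dict = {'pred': torch.zeros(4, 3)}    # last dim holds class scores
target_dict = {'label': torch.zeros(4)}    # gold field is named 'label', not 'target'

metric = AccuracyMetric(target='label')    # map the metric onto the 'label' field
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())                 # expected: {'acc': 1.0}
```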