From 70fb4a2284625d539c8fccec378ed0e56e8386d8 Mon Sep 17 00:00:00 2001 From: yunfan Date: Fri, 12 Apr 2019 15:35:22 +0800 Subject: [PATCH] - add star transformer model - add ConllLoader, for all kinds of conll-format files - add JsonLoader, for json-format files - add SSTLoader, for SST-2 & SST-5 - change Callback interface - fix batch multi-process when killed - add README to list models and their performance --- README.md | 9 +- fastNLP/core/__init__.py | 2 +- fastNLP/core/batch.py | 16 +- fastNLP/core/callback.py | 171 +++++++++--------- fastNLP/core/trainer.py | 34 ++-- fastNLP/io/dataset_loader.py | 273 ++++++++++++++++++----------- fastNLP/models/enas_trainer.py | 14 +- fastNLP/models/star_transformer.py | 181 +++++++++++++++++++ reproduction/README.md | 44 +++++ test/test_tutorials.py | 2 +- 10 files changed, 530 insertions(+), 216 deletions(-) create mode 100644 fastNLP/models/star_transformer.py create mode 100644 reproduction/README.md diff --git a/README.md b/README.md index 5346fbd7..5e51cf62 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg) [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest) -FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models. +FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models. A deep learning NLP model is the composition of three types of modules: @@ -58,6 +58,13 @@ Run the following commands to install fastNLP package. pip install fastNLP ``` +## Models +fastNLP implements different models for variant NLP tasks. +Each model has been trained and tested carefully. + +Check out models' performance, usage and source code here. +- [Documentation](https://github.com/fastnlp/fastNLP/tree/master/reproduction) +- [Source Code](https://github.com/fastnlp/fastNLP/tree/master/fastNLP/models) ## Project Structure diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 038ca12f..0bb6a2dd 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -10,4 +10,4 @@ from .tester import Tester from .trainer import Trainer from .vocabulary import Vocabulary from ..io.dataset_loader import DataSet - +from .callback import Callback diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 88d9185d..bddecab3 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -1,9 +1,16 @@ import numpy as np import torch +import atexit from fastNLP.core.sampler import RandomSampler import torch.multiprocessing as mp +_python_is_exit = False +def _set_python_is_exit(): + global _python_is_exit + _python_is_exit = True +atexit.register(_set_python_is_exit) + class Batch(object): """Batch is an iterable object which iterates over mini-batches. 
@@ -95,12 +102,19 @@ def to_tensor(batch, dtype):
 
 
 def run_fetch(batch, q):
+    global _python_is_exit
     batch.init_iter()
     # print('start fetch')
     while 1:
         res = batch.fetch_one()
         # print('fetch one')
-        q.put(res)
+        while 1:
+            try:
+                q.put(res, timeout=3)
+                break
+            except Exception as e:
+                if _python_is_exit:
+                    return
         if res is None:
             # print('fetch done, waiting processing')
             q.join()
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index b1a480cc..9cabba15 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -15,13 +15,57 @@ class Callback(object):
 
     def __init__(self):
         super(Callback, self).__init__()
-        self.trainer = None  # 在Trainer内部被重新赋值
+        self._trainer = None  # reassigned inside Trainer
+
+    @property
+    def trainer(self):
+        return self._trainer
+
+    @property
+    def step(self):
+        """current step number, in range(1, self.n_steps+1)"""
+        return self._trainer.step
+
+    @property
+    def n_steps(self):
+        """total number of steps for training"""
+        return self._trainer.n_steps
+
+    @property
+    def batch_size(self):
+        """batch size for training"""
+        return self._trainer.batch_size
+
+    @property
+    def epoch(self):
+        """current epoch number, in range(1, self.n_epochs+1)"""
+        return self._trainer.epoch
+
+    @property
+    def n_epochs(self):
+        """total number of epochs"""
+        return self._trainer.n_epochs
+
+    @property
+    def optimizer(self):
+        """torch.optim.Optimizer for the current model"""
+        return self._trainer.optimizer
+
+    @property
+    def model(self):
+        """the model being trained"""
+        return self._trainer.model
+
+    @property
+    def pbar(self):
+        """If use_tqdm, return the trainer's tqdm progress bar, else return None."""
+        return self._trainer.pbar
 
     def on_train_begin(self):
         # before the main training loop
         pass
 
-    def on_epoch_begin(self, cur_epoch, total_epoch):
+    def on_epoch_begin(self):
         # at the beginning of each epoch
         pass
 
@@ -33,14 +77,14 @@ class Callback(object):
         # after data_forward, and before loss computation
         pass
 
-    def on_backward_begin(self, loss, model):
+    def on_backward_begin(self, loss):
         # after loss computation, and before gradient backward
         pass
 
-    def on_backward_end(self, model):
+    def on_backward_end(self):
         pass
 
-    def on_step_end(self, optimizer):
+    def on_step_end(self):
         pass
 
     def on_batch_end(self, *args):
@@ -50,65 +94,36 @@ class Callback(object):
     def on_valid_begin(self):
         pass
 
-    def on_valid_end(self, eval_result, metric_key, optimizer):
+    def on_valid_end(self, eval_result, metric_key):
         """
         每次执行验证机的evaluation后会调用。传入eval_result
 
         :param eval_result: Dict[str: Dict[str: float]], evaluation的结果
         :param metric_key: str
-        :param optimizer:
         :return:
         """
         pass
 
-    def on_epoch_end(self, cur_epoch, n_epoch, optimizer):
+    def on_epoch_end(self):
         """
         每个epoch结束将会调用该方法
-
-        :param cur_epoch: int, 当前的batch。从1开始。
-        :param n_epoch: int, 总的batch数
-        :param optimizer: 传入Trainer的optimizer。
-        :return:
         """
         pass
 
-    def on_train_end(self, model):
+    def on_train_end(self):
         """
         训练结束,调用该方法
-
-        :param model: nn.Module, 传入Trainer的模型
-        :return:
         """
         pass
 
-    def on_exception(self, exception, model):
+    def on_exception(self, exception):
         """
         当训练过程出现异常,会触发该方法
 
         :param exception: 某种类型的Exception,比如KeyboardInterrupt等
-        :param model: 传入Trainer的模型
-        :return:
         """
         pass
 
 
-def transfer(func):
-    """装饰器,将对CallbackManager的调用转发到各个Callback子类.
- - :param func: - :return: - """ - - def wrapper(manager, *arg): - returns = [] - for callback in manager.callbacks: - for env_name, env_value in manager.env.items(): - setattr(callback, env_name, env_value) - returns.append(getattr(callback, func.__name__)(*arg)) - return returns - - return wrapper - - class CallbackManager(Callback): """A manager for all callbacks passed into Trainer. It collects resources inside Trainer and raise callbacks. @@ -119,7 +134,7 @@ class CallbackManager(Callback): """ :param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. - :param Callback callbacks: + :param List[Callback] callbacks: """ super(CallbackManager, self).__init__() # set attribute of trainer environment @@ -136,56 +151,43 @@ class CallbackManager(Callback): else: raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") - @transfer def on_train_begin(self): pass - @transfer - def on_epoch_begin(self, cur_epoch, total_epoch): + def on_epoch_begin(self): pass - @transfer def on_batch_begin(self, batch_x, batch_y, indices): pass - @transfer def on_loss_begin(self, batch_y, predict_y): pass - @transfer - def on_backward_begin(self, loss, model): + def on_backward_begin(self, loss): pass - @transfer - def on_backward_end(self, model): + def on_backward_end(self): pass - @transfer - def on_step_end(self, optimizer): + def on_step_end(self): pass - @transfer def on_batch_end(self): pass - @transfer def on_valid_begin(self): pass - @transfer - def on_valid_end(self, eval_result, metric_key, optimizer): + def on_valid_end(self, eval_result, metric_key): pass - @transfer - def on_epoch_end(self, cur_epoch, n_epoch, optimizer): + def on_epoch_end(self): pass - @transfer - def on_train_end(self, model): + def on_train_end(self): pass - @transfer - def on_exception(self, exception, model): + def on_exception(self, exception): pass @@ -193,15 +195,15 @@ class DummyCallback(Callback): def on_train_begin(self, *arg): print(arg) - def on_epoch_end(self, cur_epoch, n_epoch, optimizer): - print(cur_epoch, n_epoch, optimizer) + def on_epoch_end(self): + print(self.epoch, self.n_epochs) class EchoCallback(Callback): def on_train_begin(self): print("before_train") - def on_epoch_begin(self, cur_epoch, total_epoch): + def on_epoch_begin(self): print("before_epoch") def on_batch_begin(self, batch_x, batch_y, indices): @@ -210,16 +212,16 @@ class EchoCallback(Callback): def on_loss_begin(self, batch_y, predict_y): print("before_loss") - def on_backward_begin(self, loss, model): + def on_backward_begin(self, loss): print("before_backward") def on_batch_end(self): print("after_batch") - def on_epoch_end(self, cur_epoch, n_epoch, optimizer): + def on_epoch_end(self): print("after_epoch") - def on_train_end(self, model): + def on_train_end(self): print("after_train") @@ -247,8 +249,8 @@ class GradientClipCallback(Callback): self.parameters = parameters self.clip_value = clip_value - def on_backward_end(self, model): - self.clip_fun(model.parameters(), self.clip_value) + def on_backward_end(self): + self.clip_fun(self.model.parameters(), self.clip_value) class CallbackException(BaseException): @@ -268,13 +270,10 @@ class EarlyStopCallback(Callback): :param int patience: 停止之前等待的epoch数 """ super(EarlyStopCallback, self).__init__() - self.trainer = None # override by CallbackManager self.patience = patience self.wait = 0 - self.epoch = 0 - def on_valid_end(self, eval_result, metric_key, optimizer): - self.epoch += 1 + def 
on_valid_end(self, eval_result, metric_key): if not self.trainer._better_eval_result(eval_result): # current result is getting worse if self.wait == self.patience: @@ -284,7 +283,7 @@ class EarlyStopCallback(Callback): else: self.wait = 0 - def on_exception(self, exception, model): + def on_exception(self, exception): if isinstance(exception, EarlyStopError): print("Early Stopping triggered in epoch {}!".format(self.epoch)) else: @@ -304,9 +303,9 @@ class LRScheduler(Callback): else: raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") - def on_epoch_begin(self, cur_epoch, total_epoch): + def on_epoch_begin(self): self.scheduler.step() - print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) + print("scheduler step ", "lr=", self.optimizer.param_groups[0]["lr"]) class ControlC(Callback): @@ -320,7 +319,7 @@ class ControlC(Callback): raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") self.quit_all = quit_all - def on_exception(self, exception, model): + def on_exception(self, exception): if isinstance(exception, KeyboardInterrupt): if self.quit_all is True: import sys @@ -366,15 +365,15 @@ class LRFinder(Callback): self.find = None self.loader = ModelLoader() - def on_epoch_begin(self, cur_epoch, total_epoch): - if cur_epoch == 1: + def on_epoch_begin(self): + if self.epoch == 1: # first epoch self.opt = self.trainer.optimizer # pytorch optimizer self.opt.param_groups[0]["lr"] = self.start_lr # save model ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) self.find = True - def on_backward_begin(self, loss, model): + def on_backward_begin(self, loss): if self.find: if torch.isnan(loss) or self.stop is True: self.stop = True @@ -395,8 +394,8 @@ class LRFinder(Callback): self.opt.param_groups[0]["lr"] = lr # self.loader.load_pytorch(self.trainer.model, "tmp") - def on_epoch_end(self, cur_epoch, n_epoch, optimizer): - if cur_epoch == 1: + def on_epoch_end(self): + if self.epoch == 1: # first epoch self.opt.param_groups[0]["lr"] = self.best_lr self.find = False # reset model @@ -440,7 +439,7 @@ class TensorboardCallback(Callback): # self._summary_writer.add_graph(self.trainer.model, torch.zeros(32, 2)) self.graph_added = True - def on_backward_begin(self, loss, model): + def on_backward_begin(self, loss): if "loss" in self.options: self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step) @@ -452,18 +451,18 @@ class TensorboardCallback(Callback): self._summary_writer.add_scalar(name + "_grad_mean", param.grad.mean(), global_step=self.trainer.step) - def on_valid_end(self, eval_result, metric_key, optimizer): + def on_valid_end(self, eval_result, metric_key): if "metric" in self.options: for name, metric in eval_result.items(): for metric_key, metric_val in metric.items(): self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val, global_step=self.trainer.step) - def on_train_end(self, model): + def on_train_end(self): self._summary_writer.close() del self._summary_writer - def on_exception(self, exception, model): + def on_exception(self, exception): if hasattr(self, "_summary_writer"): self._summary_writer.close() del self._summary_writer @@ -471,5 +470,5 @@ class TensorboardCallback(Callback): if __name__ == "__main__": manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()]) - manager.on_train_begin(10, 11, 12) + manager.on_train_begin() # print(manager.after_epoch()) diff --git 
a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index ddd35b28..25a32787 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -122,6 +122,8 @@ class Trainer(object): self.sampler = sampler self.prefetch = prefetch self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) + self.n_steps = (len(self.train_data) // self.batch_size + int( + len(self.train_data) % self.batch_size != 0)) * self.n_epochs if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer @@ -129,6 +131,7 @@ class Trainer(object): self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) self.use_tqdm = use_tqdm + self.pbar = None self.print_every = abs(self.print_every) if self.dev_data is not None: @@ -198,9 +201,9 @@ class Trainer(object): try: self.callback_manager.on_train_begin() self._train() - self.callback_manager.on_train_end(self.model) + self.callback_manager.on_train_end() except (CallbackException, KeyboardInterrupt) as e: - self.callback_manager.on_exception(e, self.model) + self.callback_manager.on_exception(e) if self.dev_data is not None: print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + @@ -227,18 +230,21 @@ class Trainer(object): else: inner_tqdm = tqdm self.step = 0 + self.epoch = 0 start = time.time() - total_steps = (len(self.train_data) // self.batch_size + int( - len(self.train_data) % self.batch_size != 0)) * self.n_epochs - with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: + + with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: + self.pbar = pbar if isinstance(pbar, tqdm) else None avg_loss = 0 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, prefetch=self.prefetch) for epoch in range(1, self.n_epochs+1): + self.epoch = epoch pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping - self.callback_manager.on_epoch_begin(epoch, self.n_epochs) + self.callback_manager.on_epoch_begin() for batch_x, batch_y in data_iterator: + self.step += 1 _move_dict_value_to_device(batch_x, batch_y, device=self._model_device) indices = data_iterator.get_batch_indices() # negative sampling; replace unknown; re-weight batch_y @@ -251,14 +257,14 @@ class Trainer(object): avg_loss += loss.item() # Is loss NaN or inf? requires_grad = False - self.callback_manager.on_backward_begin(loss, self.model) + self.callback_manager.on_backward_begin(loss) self._grad_backward(loss) - self.callback_manager.on_backward_end(self.model) + self.callback_manager.on_backward_end() self._update() - self.callback_manager.on_step_end(self.optimizer) + self.callback_manager.on_step_end() - if (self.step+1) % self.print_every == 0: + if self.step % self.print_every == 0: if self.use_tqdm: print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) pbar.update(self.print_every) @@ -269,7 +275,6 @@ class Trainer(object): epoch, self.step, avg_loss, diff) pbar.set_postfix_str(print_output) avg_loss = 0 - self.step += 1 self.callback_manager.on_batch_end() if ((self.validate_every > 0 and self.step % self.validate_every == 0) or @@ -277,16 +282,17 @@ class Trainer(object): and self.dev_data is not None: eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. 
".format(epoch, self.n_epochs, self.step, - total_steps) + \ + self.n_steps) + \ self.tester._format_eval_results(eval_res) pbar.write(eval_str) # ================= mini-batch end ==================== # # lr decay; early stopping - self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer) + self.callback_manager.on_epoch_end() # =============== epochs end =================== # pbar.close() + self.pbar = None # ============ tqdm end ============== # def _do_validation(self, epoch, step): @@ -303,7 +309,7 @@ class Trainer(object): self.best_dev_epoch = epoch self.best_dev_step = step # get validation results; adjust optimizer - self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer) + self.callback_manager.on_valid_end(res, self.metric_key) return res def _mode(self, model, is_test=False): diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 07b721c5..93e95033 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -1,4 +1,5 @@ import os +import json from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance @@ -64,6 +65,53 @@ def convert_seq2seq_dataset(data): return dataset +def download_from_url(url, path): + from tqdm import tqdm + import requests + + """Download file""" + r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) + chunk_size = 16 * 1024 + total_size = int(r.headers.get('Content-length', 0)) + with open(path, "wb") as file ,\ + tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: + for chunk in r.iter_content(chunk_size): + if chunk: + file.write(chunk) + t.update(len(chunk)) + return + +def uncompress(src, dst): + import zipfile, gzip, tarfile, os + + def unzip(src, dst): + with zipfile.ZipFile(src, 'r') as f: + f.extractall(dst) + + def ungz(src, dst): + with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: + length = 16 * 1024 # 16KB + buf = f.read(length) + while buf: + uf.write(buf) + buf = f.read(length) + + def untar(src, dst): + with tarfile.open(src, 'r:gz') as f: + f.extractall(dst) + + fn, ext = os.path.splitext(src) + _, ext_2 = os.path.splitext(fn) + if ext == '.zip': + unzip(src, dst) + elif ext == '.gz' and ext_2 != '.tar': + ungz(src, dst) + elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': + untar(src, dst) + else: + raise ValueError('unsupported file {}'.format(src)) + + class DataSetLoader: """Interface for all DataSetLoaders. @@ -290,41 +338,6 @@ class DummyClassificationReader(DataSetLoader): return convert_seq2tag_dataset(data) -class ConllLoader(DataSetLoader): - """loader for conll format files""" - - def __init__(self): - super(ConllLoader, self).__init__() - - def load(self, data_path): - with open(data_path, "r", encoding="utf-8") as f: - lines = f.readlines() - data = self.parse(lines) - return self.convert(data) - - @staticmethod - def parse(lines): - """ - :param list lines: a list containing all lines in a conll file. 
- :return: a 3D list - """ - sentences = list() - tokens = list() - for line in lines: - if line[0] == "#": - # skip the comments - continue - if line == "\n": - sentences.append(tokens) - tokens = [] - continue - tokens.append(line.split()) - return sentences - - def convert(self, data): - pass - - class DummyLMReader(DataSetLoader): """A Dummy Language Model Dataset Reader """ @@ -434,51 +447,67 @@ class PeopleDailyCorpusLoader(DataSetLoader): return data_set -class Conll2003Loader(DataSetLoader): +class ConllLoader: + def __init__(self, headers, indexs=None): + self.headers = headers + if indexs is None: + self.indexs = list(range(len(self.headers))) + else: + if len(indexs) != len(headers): + raise ValueError + self.indexs = indexs + + def load(self, path): + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + sample = [] + start = next(f) + if '-DOCSTART-' not in start: + sample.append(start.split()) + for line in f: + if line.startswith('\n'): + if len(sample): + datalist.append(sample) + sample = [] + elif line.startswith('#'): + continue + else: + sample.append(line.split()) + if len(sample) > 0: + datalist.append(sample) + + data = [self.get_one(sample) for sample in datalist] + data = filter(lambda x: x is not None, data) + + ds = DataSet() + for sample in data: + ins = Instance() + for name, idx in zip(self.headers, self.indexs): + ins.add_field(field_name=name, field=sample[idx]) + ds.append(ins) + return ds + + def get_one(self, sample): + sample = list(map(list, zip(*sample))) + for field in sample: + if len(field) <= 0: + return None + return sample + + +class Conll2003Loader(ConllLoader): """Loader for conll2003 dataset More information about the given dataset cound be found on https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data - + + Deprecated. Use ConllLoader for all types of conll-format files. """ def __init__(self): - super(Conll2003Loader, self).__init__() - - def load(self, dataset_path): - with open(dataset_path, "r", encoding="utf-8") as f: - lines = f.readlines() - parsed_data = [] - sentence = [] - tokens = [] - for line in lines: - if '-DOCSTART- -X- -X- O' in line or line == '\n': - if sentence != []: - parsed_data.append((sentence, tokens)) - sentence = [] - tokens = [] - continue - - temp = line.strip().split(" ") - sentence.append(temp[0]) - tokens.append(temp[1:4]) - - return self.convert(parsed_data) - - def convert(self, parsed_data): - dataset = DataSet() - for sample in parsed_data: - label0_list = list(map( - lambda labels: labels[0], sample[1])) - label1_list = list(map( - lambda labels: labels[1], sample[1])) - label2_list = list(map( - lambda labels: labels[2], sample[1])) - dataset.append(Instance(tokens=sample[0], - pos=label0_list, - chucks=label1_list, - ner=label2_list)) - - return dataset + headers = [ + 'tokens', 'pos', 'chunks', 'ner', + ] + super(Conll2003Loader, self).__init__(headers=headers) class SNLIDataSetReader(DataSetLoader): @@ -548,6 +577,7 @@ class SNLIDataSetReader(DataSetLoader): class ConllCWSReader(object): + """Deprecated. Use ConllLoader for all types of conll-format files.""" def __init__(self): pass @@ -700,6 +730,7 @@ def cut_long_sentence(sent, max_sample_length=200): class ZhConllPOSReader(object): """读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。 + Deprecated. Use ConllLoader for all types of conll-format files. 
""" def __init__(self): pass @@ -778,47 +809,78 @@ class ZhConllPOSReader(object): return text, pos_tags -class ConllxDataLoader(object): +class ConllxDataLoader(ConllLoader): """返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 + Deprecated. Use ConllLoader for all types of conll-format files. """ + def __init__(self): + headers = [ + 'words', 'pos_tags', 'heads', 'labels', + ] + indexs = [ + 1, 3, 6, 7, + ] + super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs) + + +class SSTLoader(DataSetLoader): + """load SST data in PTB tree format + data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + """ + def __init__(self, subtree=False, fine_grained=False): + self.subtree = subtree + + tag_v = {'0':'very negative', '1':'negative', '2':'neutral', + '3':'positive', '4':'very positive'} + if not fine_grained: + tag_v['0'] = tag_v['1'] + tag_v['4'] = tag_v['3'] + self.tag_v = tag_v + def load(self, path): - datalist = [] with open(path, 'r', encoding='utf-8') as f: - sample = [] - for line in f: - if line.startswith('\n'): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.split('\t')) - if len(sample) > 0: - datalist.append(sample) + datas = [] + for l in f: + datas.extend([(s, self.tag_v[t]) + for s, t in self.get_one(l, self.subtree)]) + ds = DataSet() + for words, tag in datas: + ds.append(Instance(words=words, raw_tag=tag)) + return ds - data = [self.get_one(sample) for sample in datalist] - data_list = list(filter(lambda x: x is not None, data)) + @staticmethod + def get_one(data, subtree): + from nltk.tree import Tree + tree = Tree.fromstring(data) + if subtree: + return [(t.leaves(), t.label()) for t in tree.subtrees()] + return [(tree.leaves(), tree.label())] + + +class JsonLoader(DataSetLoader): + """Load json-format data, + every line contains a json obj, like a dict + fields is the dict key that need to be load + """ + def __init__(self, **fields): + super(JsonLoader, self).__init__() + self.fields = {} + for k, v in fields.items(): + self.fields[k] = k if v is None else v + def load(self, path): + with open(path, 'r', encoding='utf-8') as f: + datas = [json.loads(l) for l in f] ds = DataSet() - for example in data_list: - ds.append(Instance(words=example[0], - pos_tags=example[1], - heads=example[2], - labels=example[3])) + for d in datas: + ins = Instance() + for k, v in d.items(): + if k in self.fields: + ins.add_field(self.fields[k], v) + ds.append(ins) return ds - def get_one(self, sample): - sample = list(map(list, zip(*sample))) - if len(sample) == 0: - return None - for w in sample[7]: - if w == '_': - print('Error Sample {}'.format(sample)) - return None - # return word_seq, pos_seq, head_seq, head_tag_seq - return sample[1], sample[3], list(map(int, sample[6])), sample[7] - def add_seg_tag(data): """ @@ -840,3 +902,4 @@ def add_seg_tag(data): new_sample.append((word[-1], 'E-' + pos)) _processed.append(list(map(list, zip(*new_sample)))) return _processed + diff --git a/fastNLP/models/enas_trainer.py b/fastNLP/models/enas_trainer.py index 22e323ce..6b51c897 100644 --- a/fastNLP/models/enas_trainer.py +++ b/fastNLP/models/enas_trainer.py @@ -92,9 +92,9 @@ class ENASTrainer(fastNLP.Trainer): try: self.callback_manager.on_train_begin() self._train() - self.callback_manager.on_train_end(self.model) + self.callback_manager.on_train_end() except (CallbackException, KeyboardInterrupt) as e: - self.callback_manager.on_exception(e, self.model) + 
self.callback_manager.on_exception(e) if self.dev_data is not None: print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + @@ -134,7 +134,7 @@ class ENASTrainer(fastNLP.Trainer): if epoch == self.n_epochs + 1 - self.final_epochs: print('Entering the final stage. (Only train the selected structure)') # early stopping - self.callback_manager.on_epoch_begin(epoch, self.n_epochs) + self.callback_manager.on_epoch_begin() # 1. Training the shared parameters omega of the child models self.train_shared(pbar) @@ -155,7 +155,7 @@ class ENASTrainer(fastNLP.Trainer): pbar.write(eval_str) # lr decay; early stopping - self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer) + self.callback_manager.on_epoch_end() # =============== epochs end =================== # pbar.close() # ============ tqdm end ============== # @@ -234,12 +234,12 @@ class ENASTrainer(fastNLP.Trainer): avg_loss += loss.item() # Is loss NaN or inf? requires_grad = False - self.callback_manager.on_backward_begin(loss, self.model) + self.callback_manager.on_backward_begin(loss) self._grad_backward(loss) - self.callback_manager.on_backward_end(self.model) + self.callback_manager.on_backward_end() self._update() - self.callback_manager.on_step_end(self.optimizer) + self.callback_manager.on_step_end() if (self.step+1) % self.print_every == 0: if self.use_tqdm: diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py new file mode 100644 index 00000000..3af3fe19 --- /dev/null +++ b/fastNLP/models/star_transformer.py @@ -0,0 +1,181 @@ +from fastNLP.modules.encoder.star_transformer import StarTransformer +from fastNLP.core.utils import seq_lens_to_masks + +import torch +from torch import nn +import torch.nn.functional as F + + +class StarTransEnc(nn.Module): + def __init__(self, vocab_size, emb_dim, + hidden_size, + num_layers, + num_head, + head_dim, + max_len, + emb_dropout, + dropout): + super(StarTransEnc, self).__init__() + self.emb_fc = nn.Linear(emb_dim, hidden_size) + self.emb_drop = nn.Dropout(emb_dropout) + self.embedding = nn.Embedding(vocab_size, emb_dim) + self.encoder = StarTransformer(hidden_size=hidden_size, + num_layers=num_layers, + num_head=num_head, + head_dim=head_dim, + dropout=dropout, + max_len=max_len) + + def forward(self, x, mask): + x = self.embedding(x) + x = self.emb_fc(self.emb_drop(x)) + nodes, relay = self.encoder(x, mask) + return nodes, relay + + +class Cls(nn.Module): + def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): + super(Cls, self).__init__() + self.fc = nn.Sequential( + nn.Linear(in_dim, hid_dim), + nn.LeakyReLU(), + nn.Dropout(dropout), + nn.Linear(hid_dim, num_cls), + ) + + def forward(self, x): + h = self.fc(x) + return h + + +class NLICls(nn.Module): + def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): + super(NLICls, self).__init__() + self.fc = nn.Sequential( + nn.Dropout(dropout), + nn.Linear(in_dim*4, hid_dim), #4 + nn.LeakyReLU(), + nn.Dropout(dropout), + nn.Linear(hid_dim, num_cls), + ) + + def forward(self, x1, x2): + x = torch.cat([x1, x2, torch.abs(x1-x2), x1*x2], 1) + h = self.fc(x) + return h + +class STSeqLabel(nn.Module): + """star-transformer model for sequence labeling + """ + def __init__(self, vocab_size, emb_dim, num_cls, + hidden_size=300, + num_layers=4, + num_head=8, + head_dim=32, + max_len=512, + cls_hidden_size=600, + emb_dropout=0.1, + dropout=0.1,): + super(STSeqLabel, self).__init__() + self.enc = StarTransEnc(vocab_size=vocab_size, + emb_dim=emb_dim, + 
+                                hidden_size=hidden_size,
+                                num_layers=num_layers,
+                                num_head=num_head,
+                                head_dim=head_dim,
+                                max_len=max_len,
+                                emb_dropout=emb_dropout,
+                                dropout=dropout)
+        self.cls = Cls(hidden_size, num_cls, cls_hidden_size)
+
+    def forward(self, word_seq, seq_lens):
+        mask = seq_lens_to_masks(seq_lens)
+        nodes, _ = self.enc(word_seq, mask)
+        output = self.cls(nodes)
+        output = output.transpose(1, 2)  # put the class dimension at dim 1
+        return {'output': output}  # [bsz, n_cls, seq_len]
+
+    def predict(self, word_seq, seq_lens):
+        y = self.forward(word_seq, seq_lens)
+        _, pred = y['output'].max(1)
+        return {'output': pred, 'seq_lens': seq_lens}
+
+
+class STSeqCls(nn.Module):
+    """star-transformer model for sequence classification
+    """
+
+    def __init__(self, vocab_size, emb_dim, num_cls,
+                 hidden_size=300,
+                 num_layers=4,
+                 num_head=8,
+                 head_dim=32,
+                 max_len=512,
+                 cls_hidden_size=600,
+                 emb_dropout=0.1,
+                 dropout=0.1,):
+        super(STSeqCls, self).__init__()
+        self.enc = StarTransEnc(vocab_size=vocab_size,
+                                emb_dim=emb_dim,
+                                hidden_size=hidden_size,
+                                num_layers=num_layers,
+                                num_head=num_head,
+                                head_dim=head_dim,
+                                max_len=max_len,
+                                emb_dropout=emb_dropout,
+                                dropout=dropout)
+        self.cls = Cls(hidden_size, num_cls, cls_hidden_size)
+
+    def forward(self, word_seq, seq_lens):
+        mask = seq_lens_to_masks(seq_lens)
+        nodes, relay = self.enc(word_seq, mask)
+        y = 0.5 * (relay + nodes.max(1)[0])
+        output = self.cls(y)  # [bsz, n_cls]
+        return {'output': output}
+
+    def predict(self, word_seq, seq_lens):
+        y = self.forward(word_seq, seq_lens)
+        _, pred = y['output'].max(1)
+        return {'output': pred}
+
+
+class STNLICls(nn.Module):
+    """star-transformer model for NLI
+    """
+
+    def __init__(self, vocab_size, emb_dim, num_cls,
+                 hidden_size=300,
+                 num_layers=4,
+                 num_head=8,
+                 head_dim=32,
+                 max_len=512,
+                 cls_hidden_size=600,
+                 emb_dropout=0.1,
+                 dropout=0.1,):
+        super(STNLICls, self).__init__()
+        self.enc = StarTransEnc(vocab_size=vocab_size,
+                                emb_dim=emb_dim,
+                                hidden_size=hidden_size,
+                                num_layers=num_layers,
+                                num_head=num_head,
+                                head_dim=head_dim,
+                                max_len=max_len,
+                                emb_dropout=emb_dropout,
+                                dropout=dropout)
+        self.cls = NLICls(hidden_size, num_cls, cls_hidden_size)
+
+    def forward(self, word_seq1, word_seq2, seq_lens1, seq_lens2):
+        mask1 = seq_lens_to_masks(seq_lens1)
+        mask2 = seq_lens_to_masks(seq_lens2)
+        def enc(seq, mask):
+            nodes, relay = self.enc(seq, mask)
+            return 0.5 * (relay + nodes.max(1)[0])
+        y1 = enc(word_seq1, mask1)
+        y2 = enc(word_seq2, mask2)
+        output = self.cls(y1, y2)  # [bsz, n_cls]
+        return {'output': output}
+
+    def predict(self, word_seq1, word_seq2, seq_lens1, seq_lens2):
+        y = self.forward(word_seq1, word_seq2, seq_lens1, seq_lens2)
+        _, pred = y['output'].max(1)
+        return {'output': pred}
diff --git a/reproduction/README.md b/reproduction/README.md
new file mode 100644
index 00000000..1c93c6bc
--- /dev/null
+++ b/reproduction/README.md
@@ -0,0 +1,44 @@
+# Model Reproduction
+This directory reproduces the models implemented in fastNLP, aiming to match the performance reported in the original papers.
+
+Reproduced models:
+- Star-Transformer
+- ...
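+
+The datasets for these experiments can be read with the loaders in `fastNLP/io/dataset_loader.py`.
+A minimal loading sketch for the SST and conll-format tasks (file paths are placeholders; field
+names follow the loaders' defaults):
+
+``` python
+from fastNLP.io.dataset_loader import SSTLoader, ConllLoader
+
+# SST in PTB tree format; pass fine_grained=True for the five-class labels
+sst_loader = SSTLoader(subtree=False, fine_grained=False)
+sst_train = sst_loader.load('path/to/sst/train.txt')  # fields: words, raw_tag
+
+# conll-format files: name the columns you want to keep
+conll_loader = ConllLoader(headers=['tokens', 'pos', 'chunks', 'ner'])
+ner_train = conll_loader.load('path/to/conll/train.txt')
+```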
+
+
+## Star-Transformer
+[reference](https://arxiv.org/abs/1902.09113)
+### Performance
+| Task | Dataset | SOTA | Ours |
+|------|------| ------| ------|
+|POS Tagging|CTB 9.0|-|ACC 92.31|
+|POS Tagging|CONLL 2012|-|ACC 96.51|
+|Named Entity Recognition|CONLL 2012|-|F1 85.66|
+|Text Classification|SST|-|49.18|
+|Natural Language Inference|SNLI|-|83.76|
+
+### Usage
+``` python
+# for sequence labeling (NER, POS tagging, etc.)
+from fastNLP.models.star_transformer import STSeqLabel
+model = STSeqLabel(
+    vocab_size=10000, num_cls=50,
+    emb_dim=300)
+
+
+# for sequence classification
+from fastNLP.models.star_transformer import STSeqCls
+model = STSeqCls(
+    vocab_size=10000, num_cls=50,
+    emb_dim=300)
+
+
+# for natural language inference
+from fastNLP.models.star_transformer import STNLICls
+model = STNLICls(
+    vocab_size=10000, num_cls=50,
+    emb_dim=300)
+
+```
+
+## ...
diff --git a/test/test_tutorials.py b/test/test_tutorials.py
index 68c874fa..ee5b0e58 100644
--- a/test/test_tutorials.py
+++ b/test/test_tutorials.py
@@ -353,7 +353,7 @@ class TestTutorial(unittest.TestCase):
         train_data[-1], dev_data[-1], test_data[-1]
 
         # 读入vocab文件
-        with open('vocab.txt') as f:
+        with open('vocab.txt', encoding='utf-8') as f:
             lines = f.readlines()
         vocabs = []
         for line in lines:
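Note on the Callback change above: hooks no longer receive trainer state as arguments; a callback reads the current state through the properties defined on `Callback` (`self.step`, `self.epoch`, `self.n_epochs`, `self.optimizer`, `self.model`, `self.pbar`, ...). A minimal sketch of a user-defined callback written against this interface (the loss-averaging logic here is illustrative, not something this patch adds):

```python
from fastNLP.core.callback import Callback


class AvgLossCallback(Callback):
    """Illustrative callback: report the mean training loss at the end of each epoch."""

    def __init__(self):
        super(AvgLossCallback, self).__init__()
        self.total_loss = 0.0
        self.n_batches = 0

    def on_epoch_begin(self):
        # no cur_epoch/total_epoch arguments any more; use self.epoch / self.n_epochs
        self.total_loss = 0.0
        self.n_batches = 0

    def on_backward_begin(self, loss):
        self.total_loss += loss.item()
        self.n_batches += 1

    def on_epoch_end(self):
        mean_loss = self.total_loss / max(self.n_batches, 1)
        if self.pbar is not None:
            self.pbar.write("Epoch {}/{} done, mean training loss {:.4f}".format(
                self.epoch, self.n_epochs, mean_loss))
```

The instance is passed to `Trainer` through its `callbacks` argument, the same way the built-in callbacks such as `GradientClipCallback` are.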