diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py
index b9bc7b70..38af57b3 100644
--- a/fastNLP/api/api.py
+++ b/fastNLP/api/api.py
@@ -9,7 +9,7 @@
 from fastNLP.core.dataset import DataSet
 from fastNLP.api.utils import load_url
 from fastNLP.api.processor import ModelProcessor
-from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag
+from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.core.metrics import SpanFPreRecMetric
@@ -77,12 +77,11 @@ class POS(API):
         if not hasattr(self, "pipeline"):
             raise ValueError("You have to load model first.")
 
-        sentence_list = []
+        sentence_list = content
         # 1. Check the type of the input sentences
-        if isinstance(content, str):
-            sentence_list.append(content)
-        elif isinstance(content, list):
-            sentence_list = content
+        for sentence in sentence_list:
+            if not all((type(obj) == str for obj in sentence)):
+                raise ValueError("Input must be list of list of string.")
 
         # 2. Build the dataset
         dataset = DataSet()
@@ -91,33 +90,35 @@ class POS(API):
         # 3. Run the pipeline
         self.pipeline(dataset)
 
-        def decode_tags(ins):
-            pred_tags = ins["tag"]
-            chars = ins["words"]
-            words = []
-            start_idx = 0
-            for idx, tag in enumerate(pred_tags):
-                if tag[0] == "S":
-                    words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
-                    start_idx = idx + 1
-                elif tag[0] == "E":
-                    words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
-                    start_idx = idx + 1
-            return words
-
-        dataset.apply(decode_tags, new_field_name="tag_output")
-
-        output = dataset.field_arrays["tag_output"].content
+        # def decode_tags(ins):
+        #     pred_tags = ins["tag"]
+        #     chars = ins["words"]
+        #     words = []
+        #     start_idx = 0
+        #     for idx, tag in enumerate(pred_tags):
+        #         if tag[0] == "S":
+        #             words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
+        #             start_idx = idx + 1
+        #         elif tag[0] == "E":
+        #             words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
+        #             start_idx = idx + 1
+        #     return words
+        #
+        # dataset.apply(decode_tags, new_field_name="tag_output")
+
+        output = dataset.field_arrays["tag"].content
         if isinstance(content, str):
             return output[0]
         elif isinstance(content, list):
             return output
 
     def test(self, file_path):
-        test_data = ZhConllPOSReader().load(file_path)
+        test_data = ConllxDataLoader().load(file_path)
 
-        tag_vocab = self._dict["tag_vocab"]
-        pipeline = self._dict["pipeline"]
+        with open("model_pp_0117.pkl", "rb") as f:
+            save_dict = torch.load(f)
+        tag_vocab = save_dict["tag_vocab"]
+        pipeline = save_dict["pipeline"]
         index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
         pipeline.pipeline = [index_tag] + pipeline.pipeline
 
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index e6760a28..f354ffc6 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -169,7 +169,7 @@ class CallbackManager(Callback):
         pass
 
     @transfer
-    def on_exception(self, exception, model, indices):
+    def on_exception(self, exception, model):
         pass
 
 
@@ -235,7 +235,12 @@ class GradientClipCallback(Callback):
         self.clip_fun(model.parameters(), self.clip_value)
 
 
-class EarlyStopError(BaseException):
+class CallbackException(BaseException):
+    def __init__(self, msg):
+        super(CallbackException, self).__init__(msg)
+
+
+class EarlyStopError(CallbackException):
     def __init__(self, msg):
         super(EarlyStopError, self).__init__(msg)
 
@@ -266,6 +271,48 @@ class EarlyStopCallback(Callback):
     def on_exception(self, exception, model):
         if isinstance(exception, EarlyStopError):
             print("Early Stopping triggered in epoch {}!".format(self.epoch))
+        else:
+            raise exception  # re-raise unrecognized exceptions
+
+
+class LRScheduler(Callback):
+    def __init__(self, lr_scheduler):
+        """A wrapper around a PyTorch learning-rate scheduler.
+
+        :param lr_scheduler: a PyTorch lr_scheduler instance
+        """
+        super(LRScheduler, self).__init__()
+        import torch.optim
+        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
+            self.scheduler = lr_scheduler
+        else:
+            raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
+
+    def before_epoch(self, cur_epoch, total_epoch):
+        self.scheduler.step()
+        print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"])
+
+
+class ControlC(Callback):
+    def __init__(self, quit_all):
+        """
+
+        :param quit_all: if True, exit the whole program when Ctrl+C is detected; otherwise only quit the Trainer
+        """
+        super(ControlC, self).__init__()
+        if type(quit_all) != bool:
+            raise ValueError("In KeyBoardInterrupt, quit_all argument must be a bool.")
+        self.quit_all = quit_all
+
+    def on_exception(self, exception, model):
+        if isinstance(exception, KeyboardInterrupt):
+            if self.quit_all is True:
+                import sys
+                sys.exit(0)  # exit the whole program
+            else:
+                pass
+        else:
+            raise exception  # re-raise unrecognized exceptions
 
 
 if __name__ == "__main__":
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 07d94d11..a5861091 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -14,7 +14,7 @@ except:
     from fastNLP.core.utils import pseudo_tqdm as tqdm
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.callback import CallbackManager
+from fastNLP.core.callback import CallbackManager, CallbackException
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
@@ -122,6 +122,9 @@ class Trainer(object):
         self.print_every = int(print_every)
         self.validate_every = int(validate_every) if validate_every!=0 else -1
         self.best_metric_indicator = None
+        self.best_dev_epoch = None
+        self.best_dev_step = None
+        self.best_dev_perf = None
         self.sampler = sampler
         self.num_workers = num_workers
         self.pin_memory = pin_memory
@@ -212,7 +215,7 @@ class Trainer(object):
             self.callback_manager.before_train()
             self._train()
             self.callback_manager.after_train(self.model)
-        except BaseException as e:
+        except (CallbackException, KeyboardInterrupt) as e:
             self.callback_manager.on_exception(e, self.model)
 
         if self.dev_data is not None:
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index fb781c3e..c1092e53 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -876,7 +876,7 @@ class ConllPOSReader(object):
 
 
 class ConllxDataLoader(object):
-    def load(self, path):
+    def load(self, path, return_dataset=False):
         datalist = []
         with open(path, 'r', encoding='utf-8') as f:
             sample = []
@@ -894,10 +894,12 @@ class ConllxDataLoader(object):
         data = [self.get_one(sample) for sample in datalist]
         data_list = list(filter(lambda x: x is not None, data))
 
-        ds = DataSet()
-        for example in data_list:
-            ds.append(Instance(words=example[0], tag=example[1]))
-        return ds
+        if return_dataset is True:
+            ds = DataSet()
+            for example in data_list:
+                ds.append(Instance(words=example[0], tag=example[1]))
+            data_list = ds
+        return data_list
 
     def get_one(self, sample):
         sample = list(map(list, zip(*sample)))
diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py
index e5c4dc6b..59f2be1b 100644
--- a/test/core/test_callbacks.py
+++ b/test/core/test_callbacks.py
@@ -1,8 +1,9 @@
 import unittest
 
 import numpy as np
+import torch
 
-from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback
+from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
@@ -76,3 +77,32 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[EarlyStopCallback(5)])
         trainer.train()
+
+    def test_lr_scheduler(self):
+        data_set, model = prepare_env()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=50,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=optimizer,
+                          check_code_level=2,
+                          use_tqdm=False,
+                          dev_data=data_set,
+                          metrics=AccuracyMetric(pred="predict", target="y"),
+                          callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))])
+        trainer.train()
+
+    def test_KeyBoardInterrupt(self):
+        data_set, model = prepare_env()
+        trainer = Trainer(data_set, model,
+                          loss=BCELoss(pred="predict", target="y"),
+                          n_epochs=50,
+                          batch_size=32,
+                          print_every=50,
+                          optimizer=SGD(lr=0.1),
+                          check_code_level=2,
+                          use_tqdm=False,
+                          callbacks=[ControlC(False)])
+        trainer.train()
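
Note: the snippet below is not part of the patch. It is a minimal usage sketch showing how the two callbacks added in fastNLP/core/callback.py are meant to be combined with the Trainer, mirroring the new tests; prepare_env() is the helper already defined in test/core/test_callbacks.py and is assumed here to return a toy (DataSet, model) pair.

import torch

from fastNLP.core.callback import LRScheduler, ControlC
from fastNLP.core.losses import BCELoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.trainer import Trainer
from test.core.test_callbacks import prepare_env  # assumed test helper returning (DataSet, model)

data_set, model = prepare_env()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# StepLR decays the learning rate by gamma every 10 epochs; ControlC(quit_all=False)
# makes a Ctrl+C stop only this Trainer instead of exiting the whole program.
trainer = Trainer(data_set, model,
                  loss=BCELoss(pred="predict", target="y"),
                  metrics=AccuracyMetric(pred="predict", target="y"),
                  dev_data=data_set,
                  optimizer=optimizer,
                  n_epochs=50,
                  use_tqdm=False,
                  callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)),
                             ControlC(quit_all=False)])
trainer.train()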