| @@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet | |||
| from fastNLP.api.utils import load_url | |||
| from fastNLP.api.processor import ModelProcessor | |||
| from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag | |||
| from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.api.pipeline import Pipeline | |||
| from fastNLP.core.metrics import SpanFPreRecMetric | |||
| @@ -77,12 +77,11 @@ class POS(API): | |||
| if not hasattr(self, "pipeline"): | |||
| raise ValueError("You have to load model first.") | |||
| sentence_list = [] | |||
| sentence_list = content | |||
| # 1. 检查sentence的类型 | |||
| if isinstance(content, str): | |||
| sentence_list.append(content) | |||
| elif isinstance(content, list): | |||
| sentence_list = content | |||
| for sentence in sentence_list: | |||
| if not all((type(obj) == str for obj in sentence)): | |||
| raise ValueError("Input must be list of list of string.") | |||
| # 2. 组建dataset | |||
| dataset = DataSet() | |||
| @@ -91,33 +90,35 @@ class POS(API): | |||
| # 3. 使用pipeline | |||
| self.pipeline(dataset) | |||
def decode_tags(ins):
    """Merge per-character tags into ``"<word>/<label>"`` strings.

    :param ins: mapping-like instance. ``ins["tag"]`` holds one predicted tag
        per character; ``ins["words"]`` holds the characters, either as a
        string or as a sequence of single-character strings.
    :return: list of ``"<word>/<label>"`` strings, where ``<label>`` is the
        tag text after its two-character prefix (``tag[2:]``).
    """
    pred_tags = ins["tag"]
    chars = ins["words"]
    words = []
    start_idx = 0
    for idx, tag in enumerate(pred_tags):
        # A tag starting with "S" or "E" closes the word that began at
        # start_idx (presumably BMES-style prefixes -- confirm against the
        # tag vocabulary). tag[:1] is used so an empty tag is skipped
        # instead of raising IndexError.
        if tag[:1] in ("S", "E"):
            # "".join works for both str and list slices; the previous "S"
            # branch concatenated a list slice with a str and raised TypeError.
            words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
            start_idx = idx + 1
    return words
| dataset.apply(decode_tags, new_field_name="tag_output") | |||
| output = dataset.field_arrays["tag_output"].content | |||
| # def decode_tags(ins): | |||
| # pred_tags = ins["tag"] | |||
| # chars = ins["words"] | |||
| # words = [] | |||
| # start_idx = 0 | |||
| # for idx, tag in enumerate(pred_tags): | |||
| # if tag[0] == "S": | |||
| # words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||
| # start_idx = idx + 1 | |||
| # elif tag[0] == "E": | |||
| # words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||
| # start_idx = idx + 1 | |||
| # return words | |||
| # | |||
| # dataset.apply(decode_tags, new_field_name="tag_output") | |||
| output = dataset.field_arrays["tag"].content | |||
| if isinstance(content, str): | |||
| return output[0] | |||
| elif isinstance(content, list): | |||
| return output | |||
| def test(self, file_path): | |||
| test_data = ZhConllPOSReader().load(file_path) | |||
| test_data = ConllxDataLoader().load(file_path) | |||
| tag_vocab = self._dict["tag_vocab"] | |||
| pipeline = self._dict["pipeline"] | |||
| with open("model_pp_0117.pkl", "rb") as f: | |||
| save_dict = torch.load(f) | |||
| tag_vocab = save_dict["tag_vocab"] | |||
| pipeline = save_dict["pipeline"] | |||
| index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | |||
| pipeline.pipeline = [index_tag] + pipeline.pipeline | |||
| @@ -169,7 +169,7 @@ class CallbackManager(Callback): | |||
| pass | |||
| @transfer | |||
| def on_exception(self, exception, model, indices): | |||
| def on_exception(self, exception, model): | |||
| pass | |||
| @@ -235,7 +235,12 @@ class GradientClipCallback(Callback): | |||
| self.clip_fun(model.parameters(), self.clip_value) | |||
| class EarlyStopError(BaseException): | |||
class CallbackException(BaseException):
    """Base class for exceptions raised by callbacks.

    Derives from ``BaseException`` (not ``Exception``), so handlers written as
    ``except Exception`` do not catch it.
    """

    def __init__(self, msg):
        """
        :param msg: message stored as the exception's single argument
        """
        super().__init__(msg)
class EarlyStopError(CallbackException):
    """Callback exception type that ``EarlyStopCallback.on_exception`` checks for."""

    def __init__(self, msg):
        """
        :param msg: message forwarded to ``CallbackException``
        """
        super().__init__(msg)
| @@ -266,6 +271,48 @@ class EarlyStopCallback(Callback): | |||
def on_exception(self, exception, model):
    """Report early stopping; re-raise every other exception.

    :param exception: the exception handed to the callback
    :param model: unused here
    :raises: ``exception`` itself when it is not an ``EarlyStopError``
    """
    if not isinstance(exception, EarlyStopError):
        raise exception  # re-raise unfamiliar errors
    print("Early Stopping triggered in epoch {}!".format(self.epoch))
class LRScheduler(Callback):
    """Callback that steps a wrapped PyTorch LR scheduler before every epoch."""

    def __init__(self, lr_scheduler):
        """Wrap a PyTorch LR scheduler.

        :param lr_scheduler: a PyTorch ``lr_scheduler`` instance
        :raises ValueError: if ``lr_scheduler`` is not an instance of
            ``torch.optim.lr_scheduler._LRScheduler``
        """
        super().__init__()
        import torch.optim
        if not isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
        self.scheduler = lr_scheduler

    def before_epoch(self, cur_epoch, total_epoch):
        """Advance the scheduler one step and print the first param group's lr."""
        self.scheduler.step()
        current_lr = self.trainer.optimizer.param_groups[0]["lr"]
        print("scheduler step ", "lr=", current_lr)
class ControlC(Callback):
    """Callback that decides what happens when a KeyboardInterrupt is handled."""

    def __init__(self, quit_all):
        """
        :param quit_all: if True, exit the whole program as soon as Ctrl+C is
            detected; otherwise only the Trainer is left.
        :raises ValueError: if ``quit_all`` is not a bool
        """
        super(ControlC, self).__init__()
        if not isinstance(quit_all, bool):
            # Message previously misspelled "argument" and named a class that
            # does not exist ("KeyBoardInterrupt").
            raise ValueError("In ControlC, quit_all argument must be a bool.")
        self.quit_all = quit_all

    def on_exception(self, exception, model):
        """Handle ``KeyboardInterrupt``; re-raise anything else.

        :param exception: the exception handed to the callback
        :param model: unused here
        :raises: ``exception`` itself when it is not a ``KeyboardInterrupt``
        """
        if not isinstance(exception, KeyboardInterrupt):
            raise exception  # re-raise unfamiliar errors
        if self.quit_all:
            import sys
            sys.exit(0)  # exit the program immediately
        # quit_all is False: swallow the interrupt and return normally.
| if __name__ == "__main__": | |||
| @@ -14,7 +14,7 @@ except: | |||
| from fastNLP.core.utils import pseudo_tqdm as tqdm | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.callback import CallbackManager | |||
| from fastNLP.core.callback import CallbackManager, CallbackException | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.losses import _prepare_losser | |||
| from fastNLP.core.metrics import _prepare_metrics | |||
| @@ -122,6 +122,9 @@ class Trainer(object): | |||
| self.print_every = int(print_every) | |||
| self.validate_every = int(validate_every) if validate_every!=0 else -1 | |||
| self.best_metric_indicator = None | |||
| self.best_dev_epoch = None | |||
| self.best_dev_step = None | |||
| self.best_dev_perf = None | |||
| self.sampler = sampler | |||
| self.num_workers = num_workers | |||
| self.pin_memory = pin_memory | |||
| @@ -212,7 +215,7 @@ class Trainer(object): | |||
| self.callback_manager.before_train() | |||
| self._train() | |||
| self.callback_manager.after_train(self.model) | |||
| except BaseException as e: | |||
| except (CallbackException, KeyboardInterrupt) as e: | |||
| self.callback_manager.on_exception(e, self.model) | |||
| if self.dev_data is not None: | |||
| @@ -876,7 +876,7 @@ class ConllPOSReader(object): | |||
| class ConllxDataLoader(object): | |||
| def load(self, path): | |||
| def load(self, path, return_dataset=False): | |||
| datalist = [] | |||
| with open(path, 'r', encoding='utf-8') as f: | |||
| sample = [] | |||
| @@ -894,10 +894,12 @@ class ConllxDataLoader(object): | |||
| data = [self.get_one(sample) for sample in datalist] | |||
| data_list = list(filter(lambda x: x is not None, data)) | |||
| ds = DataSet() | |||
| for example in data_list: | |||
| ds.append(Instance(words=example[0], tag=example[1])) | |||
| return ds | |||
| if return_dataset is True: | |||
| ds = DataSet() | |||
| for example in data_list: | |||
| ds.append(Instance(words=example[0], tag=example[1])) | |||
| data_list = ds | |||
| return data_list | |||
| def get_one(self, sample): | |||
| sample = list(map(list, zip(*sample))) | |||
| @@ -1,8 +1,9 @@ | |||
| import unittest | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback | |||
| from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.instance import Instance | |||
| from fastNLP.core.losses import BCELoss | |||
| @@ -76,3 +77,32 @@ class TestCallback(unittest.TestCase): | |||
| metrics=AccuracyMetric(pred="predict", target="y"), | |||
| callbacks=[EarlyStopCallback(5)]) | |||
| trainer.train() | |||
def test_lr_scheduler(self):
    """Run a training loop with an LRScheduler callback wrapping StepLR."""
    data_set, model = prepare_env()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    bce_loss = BCELoss(pred="predict", target="y")
    accuracy = AccuracyMetric(pred="predict", target="y")
    step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    trainer = Trainer(
        data_set,
        model,
        loss=bce_loss,
        n_epochs=50,
        batch_size=32,
        print_every=50,
        optimizer=optimizer,
        check_code_level=2,
        use_tqdm=False,
        dev_data=data_set,
        metrics=accuracy,
        callbacks=[LRScheduler(step_lr)],
    )
    trainer.train()
def test_KeyBoardInterrupt(self):
    """Run a training loop with a ControlC(False) callback installed."""
    data_set, model = prepare_env()
    bce_loss = BCELoss(pred="predict", target="y")
    sgd = SGD(lr=0.1)
    interrupt_handler = ControlC(False)
    trainer = Trainer(
        data_set,
        model,
        loss=bce_loss,
        n_epochs=50,
        batch_size=32,
        print_every=50,
        optimizer=sgd,
        check_code_level=2,
        use_tqdm=False,
        callbacks=[interrupt_handler],
    )
    trainer.train()