@@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.utils import load_url | |||
from fastNLP.api.processor import ModelProcessor | |||
from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag | |||
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
@@ -77,12 +77,11 @@ class POS(API): | |||
if not hasattr(self, "pipeline"): | |||
raise ValueError("You have to load model first.") | |||
sentence_list = [] | |||
sentence_list = content | |||
# 1. 检查sentence的类型 | |||
if isinstance(content, str): | |||
sentence_list.append(content) | |||
elif isinstance(content, list): | |||
sentence_list = content | |||
for sentence in sentence_list: | |||
if not all((type(obj) == str for obj in sentence)): | |||
raise ValueError("Input must be list of list of string.") | |||
# 2. 组建dataset | |||
dataset = DataSet() | |||
@@ -91,33 +90,35 @@ class POS(API): | |||
# 3. 使用pipeline | |||
self.pipeline(dataset) | |||
def decode_tags(ins): | |||
pred_tags = ins["tag"] | |||
chars = ins["words"] | |||
words = [] | |||
start_idx = 0 | |||
for idx, tag in enumerate(pred_tags): | |||
if tag[0] == "S": | |||
words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||
start_idx = idx + 1 | |||
elif tag[0] == "E": | |||
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||
start_idx = idx + 1 | |||
return words | |||
dataset.apply(decode_tags, new_field_name="tag_output") | |||
output = dataset.field_arrays["tag_output"].content | |||
# def decode_tags(ins): | |||
# pred_tags = ins["tag"] | |||
# chars = ins["words"] | |||
# words = [] | |||
# start_idx = 0 | |||
# for idx, tag in enumerate(pred_tags): | |||
# if tag[0] == "S": | |||
# words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||
# start_idx = idx + 1 | |||
# elif tag[0] == "E": | |||
# words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||
# start_idx = idx + 1 | |||
# return words | |||
# | |||
# dataset.apply(decode_tags, new_field_name="tag_output") | |||
output = dataset.field_arrays["tag"].content | |||
if isinstance(content, str): | |||
return output[0] | |||
elif isinstance(content, list): | |||
return output | |||
def test(self, file_path): | |||
test_data = ZhConllPOSReader().load(file_path) | |||
test_data = ConllxDataLoader().load(file_path) | |||
tag_vocab = self._dict["tag_vocab"] | |||
pipeline = self._dict["pipeline"] | |||
with open("model_pp_0117.pkl", "rb") as f: | |||
save_dict = torch.load(f) | |||
tag_vocab = save_dict["tag_vocab"] | |||
pipeline = save_dict["pipeline"] | |||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | |||
pipeline.pipeline = [index_tag] + pipeline.pipeline | |||
@@ -169,7 +169,7 @@ class CallbackManager(Callback): | |||
pass | |||
@transfer | |||
def on_exception(self, exception, model, indices): | |||
def on_exception(self, exception, model): | |||
pass | |||
@@ -235,7 +235,12 @@ class GradientClipCallback(Callback): | |||
self.clip_fun(model.parameters(), self.clip_value) | |||
class EarlyStopError(BaseException): | |||
class CallbackException(BaseException): | |||
def __init__(self, msg): | |||
super(CallbackException, self).__init__(msg) | |||
class EarlyStopError(CallbackException): | |||
def __init__(self, msg): | |||
super(EarlyStopError, self).__init__(msg) | |||
@@ -266,6 +271,48 @@ class EarlyStopCallback(Callback): | |||
def on_exception(self, exception, model): | |||
if isinstance(exception, EarlyStopError): | |||
print("Early Stopping triggered in epoch {}!".format(self.epoch)) | |||
else: | |||
raise exception # 抛出陌生Error | |||
class LRScheduler(Callback): | |||
def __init__(self, lr_scheduler): | |||
"""对PyTorch LR Scheduler的包装 | |||
:param lr_scheduler: PyTorch的lr_scheduler | |||
""" | |||
super(LRScheduler, self).__init__() | |||
import torch.optim | |||
if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler): | |||
self.scheduler = lr_scheduler | |||
else: | |||
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") | |||
def before_epoch(self, cur_epoch, total_epoch): | |||
self.scheduler.step() | |||
print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) | |||
class ControlC(Callback): | |||
def __init__(self, quit_all): | |||
""" | |||
:param quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
""" | |||
super(ControlC, self).__init__() | |||
if type(quit_all) != bool: | |||
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") | |||
self.quit_all = quit_all | |||
def on_exception(self, exception, model): | |||
if isinstance(exception, KeyboardInterrupt): | |||
if self.quit_all is True: | |||
import sys | |||
sys.exit(0) # 直接退出程序 | |||
else: | |||
pass | |||
else: | |||
raise exception # 抛出陌生Error | |||
if __name__ == "__main__": | |||
@@ -14,7 +14,7 @@ except: | |||
from fastNLP.core.utils import pseudo_tqdm as tqdm | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.callback import CallbackManager | |||
from fastNLP.core.callback import CallbackManager, CallbackException | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.losses import _prepare_losser | |||
from fastNLP.core.metrics import _prepare_metrics | |||
@@ -122,6 +122,9 @@ class Trainer(object): | |||
self.print_every = int(print_every) | |||
self.validate_every = int(validate_every) if validate_every!=0 else -1 | |||
self.best_metric_indicator = None | |||
self.best_dev_epoch = None | |||
self.best_dev_step = None | |||
self.best_dev_perf = None | |||
self.sampler = sampler | |||
self.num_workers = num_workers | |||
self.pin_memory = pin_memory | |||
@@ -212,7 +215,7 @@ class Trainer(object): | |||
self.callback_manager.before_train() | |||
self._train() | |||
self.callback_manager.after_train(self.model) | |||
except BaseException as e: | |||
except (CallbackException, KeyboardInterrupt) as e: | |||
self.callback_manager.on_exception(e, self.model) | |||
if self.dev_data is not None: | |||
@@ -876,7 +876,7 @@ class ConllPOSReader(object): | |||
class ConllxDataLoader(object): | |||
def load(self, path): | |||
def load(self, path, return_dataset=False): | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
@@ -894,10 +894,12 @@ class ConllxDataLoader(object): | |||
data = [self.get_one(sample) for sample in datalist] | |||
data_list = list(filter(lambda x: x is not None, data)) | |||
ds = DataSet() | |||
for example in data_list: | |||
ds.append(Instance(words=example[0], tag=example[1])) | |||
return ds | |||
if return_dataset is True: | |||
ds = DataSet() | |||
for example in data_list: | |||
ds.append(Instance(words=example[0], tag=example[1])) | |||
data_list = ds | |||
return data_list | |||
def get_one(self, sample): | |||
sample = list(map(list, zip(*sample))) | |||
@@ -1,8 +1,9 @@ | |||
import unittest | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback | |||
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.losses import BCELoss | |||
@@ -76,3 +77,32 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[EarlyStopCallback(5)]) | |||
trainer.train() | |||
def test_lr_scheduler(self): | |||
data_set, model = prepare_env() | |||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=50, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=optimizer, | |||
check_code_level=2, | |||
use_tqdm=False, | |||
dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||
trainer.train() | |||
def test_KeyBoardInterrupt(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=50, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
callbacks=[ControlC(False)]) | |||
trainer.train() |