@@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.utils import load_url | from fastNLP.api.utils import load_url | ||||
from fastNLP.api.processor import ModelProcessor | from fastNLP.api.processor import ModelProcessor | ||||
from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag | |||||
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag | |||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.api.pipeline import Pipeline | from fastNLP.api.pipeline import Pipeline | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
@@ -77,12 +77,11 @@ class POS(API): | |||||
if not hasattr(self, "pipeline"): | if not hasattr(self, "pipeline"): | ||||
raise ValueError("You have to load model first.") | raise ValueError("You have to load model first.") | ||||
sentence_list = [] | |||||
sentence_list = content | |||||
# 1. 检查sentence的类型 | # 1. 检查sentence的类型 | ||||
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
for sentence in sentence_list: | |||||
if not all((type(obj) == str for obj in sentence)): | |||||
raise ValueError("Input must be list of list of string.") | |||||
# 2. 组建dataset | # 2. 组建dataset | ||||
dataset = DataSet() | dataset = DataSet() | ||||
@@ -91,33 +90,35 @@ class POS(API): | |||||
# 3. 使用pipeline | # 3. 使用pipeline | ||||
self.pipeline(dataset) | self.pipeline(dataset) | ||||
def decode_tags(ins): | |||||
pred_tags = ins["tag"] | |||||
chars = ins["words"] | |||||
words = [] | |||||
start_idx = 0 | |||||
for idx, tag in enumerate(pred_tags): | |||||
if tag[0] == "S": | |||||
words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
elif tag[0] == "E": | |||||
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||||
start_idx = idx + 1 | |||||
return words | |||||
dataset.apply(decode_tags, new_field_name="tag_output") | |||||
output = dataset.field_arrays["tag_output"].content | |||||
# def decode_tags(ins): | |||||
# pred_tags = ins["tag"] | |||||
# chars = ins["words"] | |||||
# words = [] | |||||
# start_idx = 0 | |||||
# for idx, tag in enumerate(pred_tags): | |||||
# if tag[0] == "S": | |||||
# words.append(chars[start_idx:idx + 1] + "/" + tag[2:]) | |||||
# start_idx = idx + 1 | |||||
# elif tag[0] == "E": | |||||
# words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:]) | |||||
# start_idx = idx + 1 | |||||
# return words | |||||
# | |||||
# dataset.apply(decode_tags, new_field_name="tag_output") | |||||
output = dataset.field_arrays["tag"].content | |||||
if isinstance(content, str): | if isinstance(content, str): | ||||
return output[0] | return output[0] | ||||
elif isinstance(content, list): | elif isinstance(content, list): | ||||
return output | return output | ||||
def test(self, file_path): | def test(self, file_path): | ||||
test_data = ZhConllPOSReader().load(file_path) | |||||
test_data = ConllxDataLoader().load(file_path) | |||||
tag_vocab = self._dict["tag_vocab"] | |||||
pipeline = self._dict["pipeline"] | |||||
with open("model_pp_0117.pkl", "rb") as f: | |||||
save_dict = torch.load(f) | |||||
tag_vocab = save_dict["tag_vocab"] | |||||
pipeline = save_dict["pipeline"] | |||||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | ||||
pipeline.pipeline = [index_tag] + pipeline.pipeline | pipeline.pipeline = [index_tag] + pipeline.pipeline | ||||
@@ -169,7 +169,7 @@ class CallbackManager(Callback): | |||||
pass | pass | ||||
@transfer | @transfer | ||||
def on_exception(self, exception, model, indices): | |||||
def on_exception(self, exception, model): | |||||
pass | pass | ||||
@@ -235,7 +235,12 @@ class GradientClipCallback(Callback): | |||||
self.clip_fun(model.parameters(), self.clip_value) | self.clip_fun(model.parameters(), self.clip_value) | ||||
class EarlyStopError(BaseException): | |||||
class CallbackException(BaseException): | |||||
def __init__(self, msg): | |||||
super(CallbackException, self).__init__(msg) | |||||
class EarlyStopError(CallbackException): | |||||
def __init__(self, msg): | def __init__(self, msg): | ||||
super(EarlyStopError, self).__init__(msg) | super(EarlyStopError, self).__init__(msg) | ||||
@@ -266,6 +271,48 @@ class EarlyStopCallback(Callback): | |||||
def on_exception(self, exception, model): | def on_exception(self, exception, model): | ||||
if isinstance(exception, EarlyStopError): | if isinstance(exception, EarlyStopError): | ||||
print("Early Stopping triggered in epoch {}!".format(self.epoch)) | print("Early Stopping triggered in epoch {}!".format(self.epoch)) | ||||
else: | |||||
raise exception # 抛出陌生Error | |||||
class LRScheduler(Callback): | |||||
def __init__(self, lr_scheduler): | |||||
"""对PyTorch LR Scheduler的包装 | |||||
:param lr_scheduler: PyTorch的lr_scheduler | |||||
""" | |||||
super(LRScheduler, self).__init__() | |||||
import torch.optim | |||||
if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler): | |||||
self.scheduler = lr_scheduler | |||||
else: | |||||
raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.") | |||||
def before_epoch(self, cur_epoch, total_epoch): | |||||
self.scheduler.step() | |||||
print("scheduler step ", "lr=", self.trainer.optimizer.param_groups[0]["lr"]) | |||||
class ControlC(Callback): | |||||
def __init__(self, quit_all): | |||||
""" | |||||
:param quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||||
""" | |||||
super(ControlC, self).__init__() | |||||
if type(quit_all) != bool: | |||||
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") | |||||
self.quit_all = quit_all | |||||
def on_exception(self, exception, model): | |||||
if isinstance(exception, KeyboardInterrupt): | |||||
if self.quit_all is True: | |||||
import sys | |||||
sys.exit(0) # 直接退出程序 | |||||
else: | |||||
pass | |||||
else: | |||||
raise exception # 抛出陌生Error | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
@@ -14,7 +14,7 @@ except: | |||||
from fastNLP.core.utils import pseudo_tqdm as tqdm | from fastNLP.core.utils import pseudo_tqdm as tqdm | ||||
from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
from fastNLP.core.callback import CallbackManager | |||||
from fastNLP.core.callback import CallbackManager, CallbackException | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.losses import _prepare_losser | from fastNLP.core.losses import _prepare_losser | ||||
from fastNLP.core.metrics import _prepare_metrics | from fastNLP.core.metrics import _prepare_metrics | ||||
@@ -122,6 +122,9 @@ class Trainer(object): | |||||
self.print_every = int(print_every) | self.print_every = int(print_every) | ||||
self.validate_every = int(validate_every) if validate_every!=0 else -1 | self.validate_every = int(validate_every) if validate_every!=0 else -1 | ||||
self.best_metric_indicator = None | self.best_metric_indicator = None | ||||
self.best_dev_epoch = None | |||||
self.best_dev_step = None | |||||
self.best_dev_perf = None | |||||
self.sampler = sampler | self.sampler = sampler | ||||
self.num_workers = num_workers | self.num_workers = num_workers | ||||
self.pin_memory = pin_memory | self.pin_memory = pin_memory | ||||
@@ -212,7 +215,7 @@ class Trainer(object): | |||||
self.callback_manager.before_train() | self.callback_manager.before_train() | ||||
self._train() | self._train() | ||||
self.callback_manager.after_train(self.model) | self.callback_manager.after_train(self.model) | ||||
except BaseException as e: | |||||
except (CallbackException, KeyboardInterrupt) as e: | |||||
self.callback_manager.on_exception(e, self.model) | self.callback_manager.on_exception(e, self.model) | ||||
if self.dev_data is not None: | if self.dev_data is not None: | ||||
@@ -876,7 +876,7 @@ class ConllPOSReader(object): | |||||
class ConllxDataLoader(object): | class ConllxDataLoader(object): | ||||
def load(self, path): | |||||
def load(self, path, return_dataset=False): | |||||
datalist = [] | datalist = [] | ||||
with open(path, 'r', encoding='utf-8') as f: | with open(path, 'r', encoding='utf-8') as f: | ||||
sample = [] | sample = [] | ||||
@@ -894,10 +894,12 @@ class ConllxDataLoader(object): | |||||
data = [self.get_one(sample) for sample in datalist] | data = [self.get_one(sample) for sample in datalist] | ||||
data_list = list(filter(lambda x: x is not None, data)) | data_list = list(filter(lambda x: x is not None, data)) | ||||
ds = DataSet() | |||||
for example in data_list: | |||||
ds.append(Instance(words=example[0], tag=example[1])) | |||||
return ds | |||||
if return_dataset is True: | |||||
ds = DataSet() | |||||
for example in data_list: | |||||
ds.append(Instance(words=example[0], tag=example[1])) | |||||
data_list = ds | |||||
return data_list | |||||
def get_one(self, sample): | def get_one(self, sample): | ||||
sample = list(map(list, zip(*sample))) | sample = list(map(list, zip(*sample))) | ||||
@@ -1,8 +1,9 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | import numpy as np | ||||
import torch | |||||
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback | |||||
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.losses import BCELoss | from fastNLP.core.losses import BCELoss | ||||
@@ -76,3 +77,32 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[EarlyStopCallback(5)]) | callbacks=[EarlyStopCallback(5)]) | ||||
trainer.train() | trainer.train() | ||||
def test_lr_scheduler(self): | |||||
data_set, model = prepare_env() | |||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=50, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=optimizer, | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||||
trainer.train() | |||||
def test_KeyBoardInterrupt(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=50, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
callbacks=[ControlC(False)]) | |||||
trainer.train() |