Browse Source

* 重构POS API,改成接受word作为输入

* 添加两类Callback
* 完善Trainer对error的捕捉
tags/v0.3.1^2
FengZiYjun 6 years ago
parent
commit
ab953b43ab
5 changed files with 119 additions and 36 deletions
  1. +27
    -26
      fastNLP/api/api.py
  2. +49
    -2
      fastNLP/core/callback.py
  3. +5
    -2
      fastNLP/core/trainer.py
  4. +7
    -5
      fastNLP/io/dataset_loader.py
  5. +31
    -1
      test/core/test_callbacks.py

+ 27
- 26
fastNLP/api/api.py View File

@@ -9,7 +9,7 @@ from fastNLP.core.dataset import DataSet

from fastNLP.api.utils import load_url
from fastNLP.api.processor import ModelProcessor
from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader, add_seg_tag
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
@@ -77,12 +77,11 @@ class POS(API):
if not hasattr(self, "pipeline"):
raise ValueError("You have to load model first.")

sentence_list = []
sentence_list = content
# 1. 检查sentence的类型
if isinstance(content, str):
sentence_list.append(content)
elif isinstance(content, list):
sentence_list = content
for sentence in sentence_list:
if not all((type(obj) == str for obj in sentence)):
raise ValueError("Input must be list of list of string.")

# 2. 组建dataset
dataset = DataSet()
@@ -91,33 +90,35 @@ class POS(API):
# 3. 使用pipeline
self.pipeline(dataset)

def decode_tags(ins):
pred_tags = ins["tag"]
chars = ins["words"]
words = []
start_idx = 0
for idx, tag in enumerate(pred_tags):
if tag[0] == "S":
words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
start_idx = idx + 1
elif tag[0] == "E":
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
start_idx = idx + 1
return words
dataset.apply(decode_tags, new_field_name="tag_output")
output = dataset.field_arrays["tag_output"].content
# def decode_tags(ins):
# pred_tags = ins["tag"]
# chars = ins["words"]
# words = []
# start_idx = 0
# for idx, tag in enumerate(pred_tags):
# if tag[0] == "S":
# words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
# start_idx = idx + 1
# elif tag[0] == "E":
# words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
# start_idx = idx + 1
# return words
#
# dataset.apply(decode_tags, new_field_name="tag_output")
output = dataset.field_arrays["tag"].content
if isinstance(content, str):
return output[0]
elif isinstance(content, list):
return output

def test(self, file_path):
test_data = ZhConllPOSReader().load(file_path)
test_data = ConllxDataLoader().load(file_path)

tag_vocab = self._dict["tag_vocab"]
pipeline = self._dict["pipeline"]
with open("model_pp_0117.pkl", "rb") as f:
save_dict = torch.load(f)
tag_vocab = save_dict["tag_vocab"]
pipeline = save_dict["pipeline"]
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
pipeline.pipeline = [index_tag] + pipeline.pipeline



+ 49
- 2
fastNLP/core/callback.py View File

@@ -169,7 +169,7 @@ class CallbackManager(Callback):
pass

@transfer
def on_exception(self, exception, model, indices):
def on_exception(self, exception, model):
pass


@@ -235,7 +235,12 @@ class GradientClipCallback(Callback):
self.clip_fun(model.parameters(), self.clip_value)


class EarlyStopError(BaseException):
class CallbackException(BaseException):
    """Base class for exceptions raised by callbacks.

    It derives from ``BaseException`` rather than ``Exception``, so a plain
    ``except Exception`` clause does not catch it.
    """

    def __init__(self, msg):
        """
        :param msg: message forwarded unchanged to ``BaseException``
        """
        super().__init__(msg)


class EarlyStopError(CallbackException):
    """Callback exception type that ``EarlyStopCallback.on_exception`` recognises as an early stop."""

    def __init__(self, msg):
        """
        :param msg: message forwarded unchanged to ``CallbackException``
        """
        super().__init__(msg)

@@ -266,6 +271,48 @@ class EarlyStopCallback(Callback):
def on_exception(self, exception, model):
    """Report an early stop; re-raise any other exception.

    :param exception: the exception handed to this callback
    :param model: unused here
    """
    if not isinstance(exception, EarlyStopError):
        raise exception  # unknown error: propagate it to the caller
    print("Early Stopping triggered in epoch {}!".format(self.epoch))


class LRScheduler(Callback):
    """Callback wrapping a PyTorch learning-rate scheduler."""

    def __init__(self, lr_scheduler):
        """Store the scheduler after checking its type.

        :param lr_scheduler: a PyTorch lr_scheduler instance
        :raises ValueError: if ``lr_scheduler`` is not an instance of
            ``torch.optim.lr_scheduler._LRScheduler``
        """
        super(LRScheduler, self).__init__()
        import torch.optim
        if not isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            raise ValueError(f"Expect torch.optim.lr_scheduler for LRScheduler. Got {type(lr_scheduler)}.")
        self.scheduler = lr_scheduler

    def before_epoch(self, cur_epoch, total_epoch):
        """Step the wrapped scheduler, then print the first param group's learning rate."""
        self.scheduler.step()
        # NOTE(review): self.trainer is not assigned in this class; presumably
        # it is attached by the trainer/callback manager -- confirm.
        current_lr = self.trainer.optimizer.param_groups[0]["lr"]
        print("scheduler step ", "lr=", current_lr)


class ControlC(Callback):
    """Callback that decides what happens when a ``KeyboardInterrupt`` reaches it."""

    def __init__(self, quit_all):
        """
        :param quit_all: if True, a detected Ctrl+C exits the whole program via
            ``sys.exit(0)``; otherwise the interrupt is swallowed by this
            callback (per the original docstring: only the Trainer is exited)
        :raises ValueError: if ``quit_all`` is not a bool
        """
        super(ControlC, self).__init__()
        if not isinstance(quit_all, bool):
            raise ValueError("In ControlC, quit_all argument must be a bool.")
        self.quit_all = quit_all

    def on_exception(self, exception, model):
        """Handle ``KeyboardInterrupt``; re-raise every other exception.

        :param exception: the exception handed to this callback
        :param model: unused here
        """
        if not isinstance(exception, KeyboardInterrupt):
            raise exception  # unknown error: propagate it to the caller
        if self.quit_all:
            import sys
            sys.exit(0)  # quit the whole program immediately
        # quit_all is False: return normally so the interrupt is suppressed here.


if __name__ == "__main__":


+ 5
- 2
fastNLP/core/trainer.py View File

@@ -14,7 +14,7 @@ except:
from fastNLP.core.utils import pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
@@ -122,6 +122,9 @@ class Trainer(object):
self.print_every = int(print_every)
self.validate_every = int(validate_every) if validate_every!=0 else -1
self.best_metric_indicator = None
self.best_dev_epoch = None
self.best_dev_step = None
self.best_dev_perf = None
self.sampler = sampler
self.num_workers = num_workers
self.pin_memory = pin_memory
@@ -212,7 +215,7 @@ class Trainer(object):
self.callback_manager.before_train()
self._train()
self.callback_manager.after_train(self.model)
except BaseException as e:
except (CallbackException, KeyboardInterrupt) as e:
self.callback_manager.on_exception(e, self.model)

if self.dev_data is not None:


+ 7
- 5
fastNLP/io/dataset_loader.py View File

@@ -876,7 +876,7 @@ class ConllPOSReader(object):


class ConllxDataLoader(object):
def load(self, path):
def load(self, path, return_dataset=False):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
@@ -894,10 +894,12 @@ class ConllxDataLoader(object):
data = [self.get_one(sample) for sample in datalist]
data_list = list(filter(lambda x: x is not None, data))

ds = DataSet()
for example in data_list:
ds.append(Instance(words=example[0], tag=example[1]))
return ds
if return_dataset is True:
ds = DataSet()
for example in data_list:
ds.append(Instance(words=example[0], tag=example[1]))
data_list = ds
return data_list

def get_one(self, sample):
sample = list(map(list, zip(*sample)))


+ 31
- 1
test/core/test_callbacks.py View File

@@ -1,8 +1,9 @@
import unittest

import numpy as np
import torch

from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback
from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss
@@ -76,3 +77,32 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[EarlyStopCallback(5)])
trainer.train()

def test_lr_scheduler(self):
    """Run a full training with an LRScheduler callback wrapping a StepLR schedule."""
    data_set, model = prepare_env()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    step_lr = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    trainer = Trainer(
        data_set,
        model,
        loss=BCELoss(pred="predict", target="y"),
        n_epochs=50,
        batch_size=32,
        print_every=50,
        optimizer=optimizer,
        check_code_level=2,
        use_tqdm=False,
        dev_data=data_set,
        metrics=AccuracyMetric(pred="predict", target="y"),
        callbacks=[LRScheduler(step_lr)],
    )
    trainer.train()

def test_KeyBoardInterrupt(self):
    """Run a full training with a ControlC(False) callback registered.

    The test itself raises no KeyboardInterrupt; it only checks that training
    completes with the callback attached.
    """
    data_set, model = prepare_env()
    interrupt_callbacks = [ControlC(False)]
    trainer = Trainer(
        data_set,
        model,
        loss=BCELoss(pred="predict", target="y"),
        n_epochs=50,
        batch_size=32,
        print_every=50,
        optimizer=SGD(lr=0.1),
        check_code_level=2,
        use_tqdm=False,
        callbacks=interrupt_callbacks,
    )
    trainer.train()

Loading…
Cancel
Save