
[add] tutorial for callback

[add] test case for logger, batch
[bugfix] batch.py
tags/v0.4.10
yunfan committed 5 years ago
commit 29e4de36e3
6 changed files with 758 additions and 42 deletions
1. fastNLP/core/batch.py (+59, -31)
2. fastNLP/core/callback.py (+3, -0)
3. test/core/test_batch.py (+28, -2)
4. test/core/test_callbacks.py (+13, -9)
5. test/core/test_logger.py (+33, -0)
6. tutorials/tutorial_callback.ipynb (+622, -0)

fastNLP/core/batch.py (+59, -31)

@@ -111,11 +111,31 @@ class SamplerAdapter(torch.utils.data.Sampler):




 class BatchIter:
-    def __init__(self):
-        self.dataiter = None
-        self.num_batches = None
+    def __init__(self, dataset, batch_size=1, sampler=None,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None, collate_fn=None):
+        if not isinstance(sampler, torch.utils.data.Sampler):
+            self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
+        else:
+            self.sampler = sampler
+        if collate_fn is None:
+            # PyTorch <= 1.1 does not accept collate_fn=None
+            self.dataiter = torch.utils.data.DataLoader(
+                dataset=dataset, batch_size=batch_size, sampler=self.sampler,
+                num_workers=num_workers,
+                pin_memory=pin_memory, drop_last=drop_last,
+                timeout=timeout, worker_init_fn=worker_init_fn)
+        else:
+            self.dataiter = torch.utils.data.DataLoader(
+                dataset=dataset, batch_size=batch_size, sampler=self.sampler,
+                collate_fn=collate_fn, num_workers=num_workers,
+                pin_memory=pin_memory, drop_last=drop_last,
+                timeout=timeout, worker_init_fn=worker_init_fn)
+
+        # base the count on the sampler's length, because with DistributedSampler each process does not use all of the data
+        self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
+        self.batch_size = batch_size
         self.cur_batch_indices = None
-        self.batch_size = None

     def init_iter(self):
         pass
@@ -135,12 +155,6 @@ class BatchIter:
             num_batches += 1
         return num_batches

-    def __iter__(self):
-        self.init_iter()
-        for indices, batch_x, batch_y in self.dataiter:
-            self.cur_batch_indices = indices
-            yield batch_x, batch_y

     def get_batch_indices(self):
         """
         Get the indices of the batch that was just produced.
@@ -170,7 +184,7 @@ class DataSetIter(BatchIter):
""" """
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False, num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
timeout=0, worker_init_fn=None, collate_fn=None):
""" """
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
@@ -187,22 +201,21 @@ class DataSetIter(BatchIter):
:param timeout: :param timeout:
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
""" """
super().__init__()
assert isinstance(dataset, DataSet) assert isinstance(dataset, DataSet)
if not isinstance(sampler, torch.utils.data.Sampler):
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
else:
self.sampler = sampler
dataset = DataSetGetter(dataset, as_numpy) dataset = DataSetGetter(dataset, as_numpy)
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=self.sampler,
collate_fn=collate_fn, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
# 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了
self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
self.batch_size = batch_size
collate_fn = dataset.collate_fn if collate_fn is None else collate_fn
super().__init__(
dataset=dataset, batch_size=batch_size, sampler=sampler,
num_workers=num_workers, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
collate_fn=collate_fn
)

def __iter__(self):
self.init_iter()
for indices, batch_x, batch_y in self.dataiter:
self.cur_batch_indices = indices
yield batch_x, batch_y




@@ -210,12 +223,27 @@ class TorchLoaderIter(BatchIter):
 class TorchLoaderIter(BatchIter):
     Similar to DataSetIter, but for PyTorch Dataset objects: wrap the PyTorch Dataset with TorchLoaderIter, then pass it to the Trainer.

     """
-    def __init__(self, dataset):
-        super().__init__()
-        assert isinstance(dataset, torch.utils.data.DataLoader)
-        self.dataiter = dataset
-        self.num_batches = self.get_num_batches(len(dataset.sampler), dataset.batch_size, dataset.drop_last)
-        self.batch_size = dataset.batch_size
+    def __init__(self, dataset, batch_size=1, sampler=None,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None, collate_fn=None):
+        assert len(dataset) > 0
+        ins = dataset[0]
+        assert len(ins) == 2 and \
+            isinstance(ins[0], dict) and \
+            isinstance(ins[1], dict), 'DataSet should return two dict, as X and Y'
+
+        super().__init__(
+            dataset=dataset, batch_size=batch_size, sampler=sampler,
+            num_workers=num_workers, pin_memory=pin_memory,
+            drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
+            collate_fn=collate_fn
+        )
+
+    def __iter__(self):
+        self.init_iter()
+        for batch_x, batch_y in self.dataiter:
+            self.cur_batch_indices = None
+            yield batch_x, batch_y




 def _to_tensor(batch, field_dtype):
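After this change, TorchLoaderIter no longer takes a prebuilt DataLoader; it builds one itself from a map-style dataset whose items are (X, Y) dict pairs. A minimal usage sketch (not part of the commit; `PairDataset` is a hypothetical dataset invented for illustration):

```python
from fastNLP import TorchLoaderIter  # re-exported at the top level, as the new test imports it

class PairDataset:
    """A hypothetical torch-style dataset; each item is two dicts, as X and Y."""
    def __init__(self):
        self.data = [([1, 2, 3], 0), ([4, 5, 6], 1)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        x, y = self.data[i]
        return {'x': x}, {'y': y}  # the new __init__ asserts exactly this shape

data_iter = TorchLoaderIter(PairDataset(), batch_size=2)
for batch_x, batch_y in data_iter:  # the new __iter__ yields (batch_x, batch_y)
    print(batch_x, batch_y)
```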


fastNLP/core/callback.py (+3, -0)

@@ -1039,6 +1039,9 @@ class EchoCallback(Callback):
 class TesterCallback(Callback):
     def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
         super(TesterCallback, self).__init__()
+        if hasattr(model, 'module'):
+            # for data parallel model
+            model = model.module
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
                              num_workers=num_workers, verbose=0)
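The three added lines matter when the model was wrapped in torch.nn.DataParallel: the wrapper stores the original network in its .module attribute, and the Tester needs the unwrapped model. A quick standalone sketch (not part of the commit) of the behavior the hasattr check relies on:

```python
import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

assert hasattr(wrapped, 'module') and wrapped.module is net  # unwrapping recovers the original
assert not hasattr(net, 'module')  # a plain module is passed through unchanged
```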


test/core/test_batch.py (+28, -2)

@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch

-from fastNLP import DataSetIter
+from fastNLP import DataSetIter, TorchLoaderIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler
@@ -149,7 +149,33 @@ class TestCase1(unittest.TestCase):
         batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
         for batch_x, batch_y in batch:
             pass

+    def testTensorLoaderIter(self):
+        class FakeData:
+            def __init__(self, return_dict=True):
+                self.x = [[1,2,3], [4,5,6]]
+                self.return_dict = return_dict
+
+            def __len__(self):
+                return len(self.x)
+
+            def __getitem__(self, i):
+                x = self.x[i]
+                y = 0
+                if self.return_dict:
+                    return {'x':x}, {'y':y}
+                return x, y
+
+        data1 = FakeData()
+        dataiter = TorchLoaderIter(data1, batch_size=2)
+        for x, y in dataiter:
+            print(x, y)
+
+        def func():
+            data2 = FakeData(return_dict=False)
+            dataiter = TorchLoaderIter(data2, batch_size=2)
+        self.assertRaises(Exception, func)

     """
     def test_multi_workers_batch(self):
         batch_size = 32


test/core/test_callbacks.py (+13, -9)

@@ -17,6 +17,7 @@ from fastNLP.models.base_model import NaiveClassifier
 from fastNLP.core.callback import EarlyStopError
 from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback
 from fastNLP.core.callback import WarmupCallback
+import tempfile


 def prepare_env():
     def prepare_fake_dataset():
@@ -40,7 +41,13 @@ def prepare_env():


 class TestCallback(unittest.TestCase):
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        pass
+        # shutil.rmtree(self.tempdir)
+
     def test_gradient_clip(self):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
@@ -93,7 +100,7 @@ class TestCallback(unittest.TestCase):
         path = os.path.join("./", 'tensorboard_logs_{}'.format(trainer.start_time))
         if os.path.exists(path):
             shutil.rmtree(path)

     def test_readonly_property(self):
         from fastNLP.core.callback import Callback
         passed_epochs = []
@@ -131,8 +138,7 @@ class TestCallback(unittest.TestCase):

     def test_fitlog_callback(self):
         import fitlog
-        os.makedirs('logs/')
-        fitlog.set_log_dir('logs/')
+        fitlog.set_log_dir(self.tempdir)
         data_set, model = prepare_env()
         from fastNLP import Tester
         tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
@@ -143,21 +149,19 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                           callbacks=fitlog_callback, check_code_level=2)
         trainer.train()
-        shutil.rmtree('logs/')

     def test_save_model_callback(self):
         data_set, model = prepare_env()
         top = 3
-        save_model_callback = SaveModelCallback('save_models/', top=top)
+        save_model_callback = SaveModelCallback(self.tempdir, top=top)
         trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                           batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
                           metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                           callbacks=save_model_callback, check_code_level=2)
         trainer.train()

-        timestamp = os.listdir('save_models')[0]
-        self.assertEqual(len(os.listdir(os.path.join('save_models', timestamp))), top)
-        shutil.rmtree('save_models/')
+        timestamp = os.listdir(self.tempdir)[0]
+        self.assertEqual(len(os.listdir(os.path.join(self.tempdir, timestamp))), top)

     def test_warmup_callback(self):
         data_set, model = prepare_env()


test/core/test_logger.py (+33, -0)

@@ -0,0 +1,33 @@
from fastNLP import logger
import unittest
from unittest.mock import patch
import os
import io
import tempfile
import shutil

class TestLogger(unittest.TestCase):
    msg = 'some test logger msg'

    def setUp(self):
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        pass
        # shutil.rmtree(self.tmpdir)

    def test_add_file(self):
        fn = os.path.join(self.tmpdir, 'log.txt')
        logger.add_file(fn)
        logger.info(self.msg)
        with open(fn, 'r') as f:
            line = ''.join([l for l in f])
        print(line)
        self.assertTrue(self.msg in line)

    @patch('sys.stdout', new_callable=io.StringIO)
    def test_stdout(self, mock_out):
        for i in range(3):
            logger.info(self.msg)

        self.assertEqual([self.msg for i in range(3)], mock_out.getvalue().strip().split('\n'))
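For reference, typical use of the API this test exercises (a sketch assuming only logger.add_file and logger.info, as shown above):

```python
import os
import tempfile
from fastNLP import logger

log_path = os.path.join(tempfile.mkdtemp(), 'train.log')
logger.add_file(log_path)        # from here on, records also go to train.log
logger.info('training started')  # printed to stdout and appended to the file

with open(log_path) as f:
    assert 'training started' in f.read()
```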

tutorials/tutorial_callback.ipynb (+622, -0)

@@ -0,0 +1,622 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 使用 Callback 自定义你的训练过程"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 什么是 Callback\n",
"- 使用 Callback \n",
"- 一些常用的 Callback\n",
"- 自定义实现 Callback"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"什么是Callback\n",
"------\n",
"\n",
"Callback 是与 Trainer 紧密结合的模块,利用 Callback 可以在 Trainer 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。\n",
"\n",
"fastNLP 中提供了很多常用的 Callback ,开箱即用。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"使用 Callback\n",
" ------\n",
"\n",
"使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:34:46.465871Z",
"start_time": "2019-09-17T07:34:30.648758Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 3 datasets:\n",
"\ttest has 1200 instances.\n",
"\ttrain has 9600 instances.\n",
"\tdev has 1200 instances.\n",
"In total 2 vocabs:\n",
"\tchars has 4409 entries.\n",
"\ttarget has 2 entries.\n",
"\n",
"training epochs started 2019-09-17-03-34-34\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.863333\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.886667\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.890833\n",
"\n",
"\r\n",
"In Epoch:3/Step:900, got best dev performance:\n",
"AccuracyMetric: acc=0.890833\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import (Callback, EarlyStopCallback,\n",
" Trainer, CrossEntropyLoss, AccuracyMetric)\n",
"from fastNLP.models import CNNText\n",
"import torch.cuda\n",
"\n",
"# prepare data\n",
"def get_data():\n",
" from fastNLP.io import ChnSentiCorpPipe as pipe\n",
" data = pipe().process_from_file()\n",
" print(data)\n",
" data.rename_field('chars', 'words')\n",
" train_data = data.datasets['train']\n",
" dev_data = data.datasets['dev']\n",
" test_data = data.datasets['test']\n",
" vocab = data.vocabs['words']\n",
" tgt_vocab = data.vocabs['target']\n",
" return train_data, dev_data, test_data, vocab, tgt_vocab\n",
"\n",
"# prepare model\n",
"train_data, dev_data, _, vocab, tgt_vocab = get_data()\n",
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
"model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))\n",
"\n",
"# define callback\n",
"callbacks=[EarlyStopCallback(5)]\n",
"\n",
"# pass callbacks to Trainer\n",
"def train_with_callback(cb_list):\n",
" trainer = Trainer(\n",
" device=device,\n",
" n_epochs=3,\n",
" model=model, \n",
" train_data=train_data, \n",
" dev_data=dev_data, \n",
" loss=CrossEntropyLoss(), \n",
" metrics=AccuracyMetric(), \n",
" callbacks=cb_list, \n",
" check_code_level=-1\n",
" )\n",
" trainer.train()\n",
"\n",
"train_with_callback(callbacks)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastNLP 中的 Callback\n",
"-------\n",
"fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 fastNLP.core.callbacks"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:35:02.182727Z",
"start_time": "2019-09-17T07:34:49.443863Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2019-09-17-03-34-49\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.13 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.12 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.890833\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.890833\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.09 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.09 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.8875\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.8875\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.885\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.885\n",
"\n",
"\r\n",
"In Epoch:1/Step:300, got best dev performance:\n",
"AccuracyMetric: acc=0.890833\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n",
"callbacks = [\n",
" EarlyStopCallback(5),\n",
" GradientClipCallback(clip_value=5, clip_type='value'),\n",
" EvaluateCallback(dev_data)\n",
"]\n",
"\n",
"train_with_callback(callbacks)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"自定义 Callback\n",
"------\n",
"\n",
"这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。\n",
"\n",
"#### 创建 Callback\n",
" \n",
"要自定义 Callback,我们要实现一个类,继承 fastNLP.Callback。\n",
"\n",
"这里我们定义 MyCallBack ,继承 fastNLP.Callback 。\n",
"\n",
"#### 指定 Callback 调用的阶段\n",
" \n",
"Callback 中所有以 on_ 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 Callback 文档。\n",
"\n",
"这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录当前 loss ,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。\n",
"\n",
"#### 使用 Callback 的属性访问 Trainer 的内部信息\n",
" \n",
"为了方便使用,可以使用 Callback 的属性,访问 Trainer 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见文档 Callback 。\n",
"\n",
"这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步数,可以通过 self.step 属性得到当前训练了多少步。\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:43:10.907139Z",
"start_time": "2019-09-17T07:42:58.488177Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2019-09-17-03-42-58\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.883333\n",
"\n",
"Avg loss at epoch 1, 0.100254\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.8775\n",
"\n",
"Avg loss at epoch 2, 0.183511\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.13 seconds!\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.875833\n",
"\n",
"Avg loss at epoch 3, 0.257103\n",
"\r\n",
"In Epoch:1/Step:300, got best dev performance:\n",
"AccuracyMetric: acc=0.883333\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import Callback\n",
"from fastNLP import logger\n",
"\n",
"class MyCallBack(Callback):\n",
" \"\"\"Print average loss in each epoch\"\"\"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.total_loss = 0\n",
" self.start_step = 0\n",
" \n",
" def on_backward_begin(self, loss):\n",
" self.total_loss += loss.item()\n",
" \n",
" def on_epoch_end(self):\n",
" n_steps = self.step - self.start_step\n",
" avg_loss = self.total_loss / n_steps\n",
" logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n",
" self.start_step = self.step\n",
"\n",
"callbacks = [MyCallBack()]\n",
"train_with_callback(callbacks)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}
