Browse Source

[add] tutorial for callback

[add] test case for logger, batch
[bugfix] batch.py
tags/v0.4.10
yunfan 5 years ago
parent
commit
29e4de36e3
6 changed files with 758 additions and 42 deletions
  1. +59
    -31
      fastNLP/core/batch.py
  2. +3
    -0
      fastNLP/core/callback.py
  3. +28
    -2
      test/core/test_batch.py
  4. +13
    -9
      test/core/test_callbacks.py
  5. +33
    -0
      test/core/test_logger.py
  6. +622
    -0
      tutorials/tutorial_callback.ipynb

+ 59
- 31
fastNLP/core/batch.py View File

@@ -111,11 +111,31 @@ class SamplerAdapter(torch.utils.data.Sampler):


class BatchIter:
def __init__(self):
self.dataiter = None
self.num_batches = None
def __init__(self, dataset, batch_size=1, sampler=None,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None):
if not isinstance(sampler, torch.utils.data.Sampler):
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
else:
self.sampler = sampler
if collate_fn is None:
# pytoch <= 1.1 中不能设置collate_fn=None
self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=self.sampler,
num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
else:
self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=self.sampler,
collate_fn=collate_fn, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)

# 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了
self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
self.batch_size = batch_size
self.cur_batch_indices = None
self.batch_size = None

def init_iter(self):
pass
@@ -135,12 +155,6 @@ class BatchIter:
num_batches += 1
return num_batches

def __iter__(self):
self.init_iter()
for indices, batch_x, batch_y in self.dataiter:
self.cur_batch_indices = indices
yield batch_x, batch_y

def get_batch_indices(self):
"""
获取当前已经输出的batch的index。
@@ -170,7 +184,7 @@ class DataSetIter(BatchIter):
"""
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
timeout=0, worker_init_fn=None, collate_fn=None):
"""
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
@@ -187,22 +201,21 @@ class DataSetIter(BatchIter):
:param timeout:
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。
"""
super().__init__()
assert isinstance(dataset, DataSet)
if not isinstance(sampler, torch.utils.data.Sampler):
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
else:
self.sampler = sampler
dataset = DataSetGetter(dataset, as_numpy)
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=self.sampler,
collate_fn=collate_fn, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
# 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了
self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
self.batch_size = batch_size
collate_fn = dataset.collate_fn if collate_fn is None else collate_fn
super().__init__(
dataset=dataset, batch_size=batch_size, sampler=sampler,
num_workers=num_workers, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
collate_fn=collate_fn
)

def __iter__(self):
self.init_iter()
for indices, batch_x, batch_y in self.dataiter:
self.cur_batch_indices = indices
yield batch_x, batch_y


class TorchLoaderIter(BatchIter):
@@ -210,12 +223,27 @@ class TorchLoaderIter(BatchIter):
与DataSetIter类似,但用于pytorch的DataSet对象。通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。

"""
def __init__(self, dataset):
super().__init__()
assert isinstance(dataset, torch.utils.data.DataLoader)
self.dataiter = dataset
self.num_batches = self.get_num_batches(len(dataset.sampler), dataset.batch_size, dataset.drop_last)
self.batch_size = dataset.batch_size
def __init__(self, dataset, batch_size=1, sampler=None,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, collate_fn=None):
assert len(dataset) > 0
ins = dataset[0]
assert len(ins) == 2 and \
isinstance(ins[0], dict) and \
isinstance(ins[1], dict), 'DataSet should return two dict, as X and Y'

super().__init__(
dataset=dataset, batch_size=batch_size, sampler=sampler,
num_workers=num_workers, pin_memory=pin_memory,
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
collate_fn=collate_fn
)

def __iter__(self):
self.init_iter()
for batch_x, batch_y in self.dataiter:
self.cur_batch_indices = None
yield batch_x, batch_y


def _to_tensor(batch, field_dtype):


+ 3
- 0
fastNLP/core/callback.py View File

@@ -1039,6 +1039,9 @@ class EchoCallback(Callback):
class TesterCallback(Callback):
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
super(TesterCallback, self).__init__()
if hasattr(model, 'module'):
# for data parallel model
model = model.module
self.tester = Tester(data, model,
metrics=metrics, batch_size=batch_size,
num_workers=num_workers, verbose=0)


+ 28
- 2
test/core/test_batch.py View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch

from fastNLP import DataSetIter
from fastNLP import DataSetIter, TorchLoaderIter
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import SequentialSampler
@@ -149,7 +149,33 @@ class TestCase1(unittest.TestCase):
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_x, batch_y in batch:
pass

def testTensorLoaderIter(self):
class FakeData:
def __init__(self, return_dict=True):
self.x = [[1,2,3], [4,5,6]]
self.return_dict = return_dict

def __len__(self):
return len(self.x)

def __getitem__(self, i):
x = self.x[i]
y = 0
if self.return_dict:
return {'x':x}, {'y':y}
return x, y

data1 = FakeData()
dataiter = TorchLoaderIter(data1, batch_size=2)
for x, y in dataiter:
print(x, y)

def func():
data2 = FakeData(return_dict=False)
dataiter = TorchLoaderIter(data2, batch_size=2)
self.assertRaises(Exception, func)

"""
def test_multi_workers_batch(self):
batch_size = 32


+ 13
- 9
test/core/test_callbacks.py View File

@@ -17,6 +17,7 @@ from fastNLP.models.base_model import NaiveClassifier
from fastNLP.core.callback import EarlyStopError
from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback
from fastNLP.core.callback import WarmupCallback
import tempfile

def prepare_env():
def prepare_fake_dataset():
@@ -40,7 +41,13 @@ def prepare_env():


class TestCallback(unittest.TestCase):
def setUp(self):
self.tempdir = tempfile.mkdtemp()

def tearDown(self):
pass
# shutil.rmtree(self.tempdir)

def test_gradient_clip(self):
data_set, model = prepare_env()
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
@@ -93,7 +100,7 @@ class TestCallback(unittest.TestCase):
path = os.path.join("./", 'tensorboard_logs_{}'.format(trainer.start_time))
if os.path.exists(path):
shutil.rmtree(path)
def test_readonly_property(self):
from fastNLP.core.callback import Callback
passed_epochs = []
@@ -131,8 +138,7 @@ class TestCallback(unittest.TestCase):

def test_fitlog_callback(self):
import fitlog
os.makedirs('logs/')
fitlog.set_log_dir('logs/')
fitlog.set_log_dir(self.tempdir)
data_set, model = prepare_env()
from fastNLP import Tester
tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y"))
@@ -143,21 +149,19 @@ class TestCallback(unittest.TestCase):
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=fitlog_callback, check_code_level=2)
trainer.train()
shutil.rmtree('logs/')

def test_save_model_callback(self):
data_set, model = prepare_env()
top = 3
save_model_callback = SaveModelCallback('save_models/', top=top)
save_model_callback = SaveModelCallback(self.tempdir, top=top)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=save_model_callback, check_code_level=2)
trainer.train()

timestamp = os.listdir('save_models')[0]
self.assertEqual(len(os.listdir(os.path.join('save_models', timestamp))), top)
shutil.rmtree('save_models/')
timestamp = os.listdir(self.tempdir)[0]
self.assertEqual(len(os.listdir(os.path.join(self.tempdir, timestamp))), top)

def test_warmup_callback(self):
data_set, model = prepare_env()


+ 33
- 0
test/core/test_logger.py View File

@@ -0,0 +1,33 @@
from fastNLP import logger
import unittest
from unittest.mock import patch
import os
import io
import tempfile
import shutil

class TestLogger(unittest.TestCase):
msg = 'some test logger msg'

def setUp(self):
self.tmpdir = tempfile.mkdtemp()

def tearDown(self):
pass
# shutil.rmtree(self.tmpdir)

def test_add_file(self):
fn = os.path.join(self.tmpdir, 'log.txt')
logger.add_file(fn)
logger.info(self.msg)
with open(fn, 'r') as f:
line = ''.join([l for l in f])
print(line)
self.assertTrue(self.msg in line)

@patch('sys.stdout', new_callable=io.StringIO)
def test_stdout(self, mock_out):
for i in range(3):
logger.info(self.msg)

self.assertEqual([self.msg for i in range(3)], mock_out.getvalue().strip().split('\n'))

+ 622
- 0
tutorials/tutorial_callback.ipynb View File

@@ -0,0 +1,622 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 使用 Callback 自定义你的训练过程"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 什么是 Callback\n",
"- 使用 Callback \n",
"- 一些常用的 Callback\n",
"- 自定义实现 Callback"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"什么是Callback\n",
"------\n",
"\n",
"Callback 是与 Trainer 紧密结合的模块,利用 Callback 可以在 Trainer 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。\n",
"\n",
"fastNLP 中提供了很多常用的 Callback ,开箱即用。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"使用 Callback\n",
" ------\n",
"\n",
"使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:34:46.465871Z",
"start_time": "2019-09-17T07:34:30.648758Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In total 3 datasets:\n",
"\ttest has 1200 instances.\n",
"\ttrain has 9600 instances.\n",
"\tdev has 1200 instances.\n",
"In total 2 vocabs:\n",
"\tchars has 4409 entries.\n",
"\ttarget has 2 entries.\n",
"\n",
"training epochs started 2019-09-17-03-34-34\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.863333\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.886667\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.890833\n",
"\n",
"\r\n",
"In Epoch:3/Step:900, got best dev performance:\n",
"AccuracyMetric: acc=0.890833\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import (Callback, EarlyStopCallback,\n",
" Trainer, CrossEntropyLoss, AccuracyMetric)\n",
"from fastNLP.models import CNNText\n",
"import torch.cuda\n",
"\n",
"# prepare data\n",
"def get_data():\n",
" from fastNLP.io import ChnSentiCorpPipe as pipe\n",
" data = pipe().process_from_file()\n",
" print(data)\n",
" data.rename_field('chars', 'words')\n",
" train_data = data.datasets['train']\n",
" dev_data = data.datasets['dev']\n",
" test_data = data.datasets['test']\n",
" vocab = data.vocabs['words']\n",
" tgt_vocab = data.vocabs['target']\n",
" return train_data, dev_data, test_data, vocab, tgt_vocab\n",
"\n",
"# prepare model\n",
"train_data, dev_data, _, vocab, tgt_vocab = get_data()\n",
"device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
"model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))\n",
"\n",
"# define callback\n",
"callbacks=[EarlyStopCallback(5)]\n",
"\n",
"# pass callbacks to Trainer\n",
"def train_with_callback(cb_list):\n",
" trainer = Trainer(\n",
" device=device,\n",
" n_epochs=3,\n",
" model=model, \n",
" train_data=train_data, \n",
" dev_data=dev_data, \n",
" loss=CrossEntropyLoss(), \n",
" metrics=AccuracyMetric(), \n",
" callbacks=cb_list, \n",
" check_code_level=-1\n",
" )\n",
" trainer.train()\n",
"\n",
"train_with_callback(callbacks)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastNLP 中的 Callback\n",
"-------\n",
"fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 fastNLP.core.callbacks"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:35:02.182727Z",
"start_time": "2019-09-17T07:34:49.443863Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2019-09-17-03-34-49\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.13 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.12 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.890833\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.890833\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.09 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.09 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.8875\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.8875\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on data-test:\n",
"AccuracyMetric: acc=0.885\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.885\n",
"\n",
"\r\n",
"In Epoch:1/Step:300, got best dev performance:\n",
"AccuracyMetric: acc=0.890833\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n",
"callbacks = [\n",
" EarlyStopCallback(5),\n",
" GradientClipCallback(clip_value=5, clip_type='value'),\n",
" EvaluateCallback(dev_data)\n",
"]\n",
"\n",
"train_with_callback(callbacks)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"自定义 Callback\n",
"------\n",
"\n",
"这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。\n",
"\n",
"#### 创建 Callback\n",
" \n",
"要自定义 Callback,我们要实现一个类,继承 fastNLP.Callback。\n",
"\n",
"这里我们定义 MyCallBack ,继承 fastNLP.Callback 。\n",
"\n",
"#### 指定 Callback 调用的阶段\n",
" \n",
"Callback 中所有以 on_ 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 Callback 文档。\n",
"\n",
"这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录当前 loss ,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。\n",
"\n",
"#### 使用 Callback 的属性访问 Trainer 的内部信息\n",
" \n",
"为了方便使用,可以使用 Callback 的属性,访问 Trainer 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见文档 Callback 。\n",
"\n",
"这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步数,可以通过 self.step 属性得到当前训练了多少步。\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-09-17T07:43:10.907139Z",
"start_time": "2019-09-17T07:42:58.488177Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2019-09-17-03-42-58\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.11 seconds!\n",
"Evaluation on dev at Epoch 1/3. Step:300/900: \n",
"AccuracyMetric: acc=0.883333\n",
"\n",
"Avg loss at epoch 1, 0.100254\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.1 seconds!\n",
"Evaluation on dev at Epoch 2/3. Step:600/900: \n",
"AccuracyMetric: acc=0.8775\n",
"\n",
"Avg loss at epoch 2, 0.183511\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Evaluate data in 0.13 seconds!\n",
"Evaluation on dev at Epoch 3/3. Step:900/900: \n",
"AccuracyMetric: acc=0.875833\n",
"\n",
"Avg loss at epoch 3, 0.257103\n",
"\r\n",
"In Epoch:1/Step:300, got best dev performance:\n",
"AccuracyMetric: acc=0.883333\n",
"Reloaded the best model.\n"
]
}
],
"source": [
"from fastNLP import Callback\n",
"from fastNLP import logger\n",
"\n",
"class MyCallBack(Callback):\n",
" \"\"\"Print average loss in each epoch\"\"\"\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.total_loss = 0\n",
" self.start_step = 0\n",
" \n",
" def on_backward_begin(self, loss):\n",
" self.total_loss += loss.item()\n",
" \n",
" def on_epoch_end(self):\n",
" n_steps = self.step - self.start_step\n",
" avg_loss = self.total_loss / n_steps\n",
" logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n",
" self.start_step = self.step\n",
"\n",
"callbacks = [MyCallBack()]\n",
"train_with_callback(callbacks)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}

Loading…
Cancel
Save