From 29e4de36e3e95374ab6388a1e7b6d7cff6a53d66 Mon Sep 17 00:00:00 2001 From: yunfan Date: Tue, 17 Sep 2019 21:38:52 +0800 Subject: [PATCH] [add] tutorial for callback [add] test case for logger, batch [bugfix] batch.py --- fastNLP/core/batch.py | 90 +++-- fastNLP/core/callback.py | 3 + test/core/test_batch.py | 30 +- test/core/test_callbacks.py | 22 +- test/core/test_logger.py | 33 ++ tutorials/tutorial_callback.ipynb | 622 ++++++++++++++++++++++++++++++ 6 files changed, 758 insertions(+), 42 deletions(-) create mode 100644 test/core/test_logger.py create mode 100644 tutorials/tutorial_callback.ipynb diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 76c14005..4ee1916a 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -111,11 +111,31 @@ class SamplerAdapter(torch.utils.data.Sampler): class BatchIter: - def __init__(self): - self.dataiter = None - self.num_batches = None + def __init__(self, dataset, batch_size=1, sampler=None, + num_workers=0, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None, collate_fn=None): + if not isinstance(sampler, torch.utils.data.Sampler): + self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) + else: + self.sampler = sampler + if collate_fn is None: + # pytorch <= 1.1 中不能设置collate_fn=None + self.dataiter = torch.utils.data.DataLoader( + dataset=dataset, batch_size=batch_size, sampler=self.sampler, + num_workers=num_workers, + pin_memory=pin_memory, drop_last=drop_last, + timeout=timeout, worker_init_fn=worker_init_fn) + else: + self.dataiter = torch.utils.data.DataLoader( + dataset=dataset, batch_size=batch_size, sampler=self.sampler, + collate_fn=collate_fn, num_workers=num_workers, + pin_memory=pin_memory, drop_last=drop_last, + timeout=timeout, worker_init_fn=worker_init_fn) + + # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 + self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) + self.batch_size = 
batch_size self.cur_batch_indices = None - self.batch_size = None def init_iter(self): pass @@ -135,12 +155,6 @@ class BatchIter: num_batches += 1 return num_batches - def __iter__(self): - self.init_iter() - for indices, batch_x, batch_y in self.dataiter: - self.cur_batch_indices = indices - yield batch_x, batch_y - def get_batch_indices(self): """ 获取当前已经输出的batch的index。 @@ -170,7 +184,7 @@ class DataSetIter(BatchIter): """ def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, num_workers=0, pin_memory=False, drop_last=False, - timeout=0, worker_init_fn=None): + timeout=0, worker_init_fn=None, collate_fn=None): """ :param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 @@ -187,22 +201,21 @@ class DataSetIter(BatchIter): :param timeout: :param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 """ - super().__init__() assert isinstance(dataset, DataSet) - if not isinstance(sampler, torch.utils.data.Sampler): - self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) - else: - self.sampler = sampler dataset = DataSetGetter(dataset, as_numpy) - collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None - self.dataiter = torch.utils.data.DataLoader( - dataset=dataset, batch_size=batch_size, sampler=self.sampler, - collate_fn=collate_fn, num_workers=num_workers, - pin_memory=pin_memory, drop_last=drop_last, - timeout=timeout, worker_init_fn=worker_init_fn) - # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 - self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) - self.batch_size = batch_size + collate_fn = dataset.collate_fn if collate_fn is None else collate_fn + super().__init__( + dataset=dataset, batch_size=batch_size, sampler=sampler, + num_workers=num_workers, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + collate_fn=collate_fn + ) + + def __iter__(self): + self.init_iter() + for indices, batch_x, batch_y 
in self.dataiter: + self.cur_batch_indices = indices + yield batch_x, batch_y class TorchLoaderIter(BatchIter): @@ -210,12 +223,27 @@ class TorchLoaderIter(BatchIter): 与DataSetIter类似,但用于pytorch的DataSet对象。通过使用TorchLoaderIter封装pytorch的DataSet,然后将其传入到Trainer中。 """ - def __init__(self, dataset): - super().__init__() - assert isinstance(dataset, torch.utils.data.DataLoader) - self.dataiter = dataset - self.num_batches = self.get_num_batches(len(dataset.sampler), dataset.batch_size, dataset.drop_last) - self.batch_size = dataset.batch_size + def __init__(self, dataset, batch_size=1, sampler=None, + num_workers=0, pin_memory=False, drop_last=False, + timeout=0, worker_init_fn=None, collate_fn=None): + assert len(dataset) > 0 + ins = dataset[0] + assert len(ins) == 2 and \ + isinstance(ins[0], dict) and \ + isinstance(ins[1], dict), 'DataSet should return two dict, as X and Y' + + super().__init__( + dataset=dataset, batch_size=batch_size, sampler=sampler, + num_workers=num_workers, pin_memory=pin_memory, + drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + collate_fn=collate_fn + ) + + def __iter__(self): + self.init_iter() + for batch_x, batch_y in self.dataiter: + self.cur_batch_indices = None + yield batch_x, batch_y def _to_tensor(batch, field_dtype): diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 985431bc..6ad98b0b 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -1039,6 +1039,9 @@ class EchoCallback(Callback): class TesterCallback(Callback): def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None): super(TesterCallback, self).__init__() + if hasattr(model, 'module'): + # for data parallel model + model = model.module self.tester = Tester(data, model, metrics=metrics, batch_size=batch_size, num_workers=num_workers, verbose=0) diff --git a/test/core/test_batch.py b/test/core/test_batch.py index aa9808ee..d9898bc7 100644 --- a/test/core/test_batch.py +++ 
b/test/core/test_batch.py @@ -3,7 +3,7 @@ import unittest import numpy as np import torch -from fastNLP import DataSetIter +from fastNLP import DataSetIter, TorchLoaderIter from fastNLP import DataSet from fastNLP import Instance from fastNLP import SequentialSampler @@ -149,7 +149,33 @@ class TestCase1(unittest.TestCase): batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) for batch_x, batch_y in batch: pass - + + def testTensorLoaderIter(self): + class FakeData: + def __init__(self, return_dict=True): + self.x = [[1,2,3], [4,5,6]] + self.return_dict = return_dict + + def __len__(self): + return len(self.x) + + def __getitem__(self, i): + x = self.x[i] + y = 0 + if self.return_dict: + return {'x':x}, {'y':y} + return x, y + + data1 = FakeData() + dataiter = TorchLoaderIter(data1, batch_size=2) + for x, y in dataiter: + print(x, y) + + def func(): + data2 = FakeData(return_dict=False) + dataiter = TorchLoaderIter(data2, batch_size=2) + self.assertRaises(Exception, func) + """ def test_multi_workers_batch(self): batch_size = 32 diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index b36beb06..78f76b65 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -17,6 +17,7 @@ from fastNLP.models.base_model import NaiveClassifier from fastNLP.core.callback import EarlyStopError from fastNLP.core.callback import EvaluateCallback, FitlogCallback, SaveModelCallback from fastNLP.core.callback import WarmupCallback +import tempfile def prepare_env(): def prepare_fake_dataset(): @@ -40,7 +41,13 @@ def prepare_env(): class TestCallback(unittest.TestCase): - + def setUp(self): + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + pass + # shutil.rmtree(self.tempdir) + def test_gradient_clip(self): data_set, model = prepare_env() trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), @@ -93,7 +100,7 @@ class TestCallback(unittest.TestCase): path = 
os.path.join("./", 'tensorboard_logs_{}'.format(trainer.start_time)) if os.path.exists(path): shutil.rmtree(path) - + def test_readonly_property(self): from fastNLP.core.callback import Callback passed_epochs = [] @@ -131,8 +138,7 @@ class TestCallback(unittest.TestCase): def test_fitlog_callback(self): import fitlog - os.makedirs('logs/') - fitlog.set_log_dir('logs/') + fitlog.set_log_dir(self.tempdir) data_set, model = prepare_env() from fastNLP import Tester tester = Tester(data=data_set, model=model, metrics=AccuracyMetric(pred="predict", target="y")) @@ -143,21 +149,19 @@ class TestCallback(unittest.TestCase): metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=fitlog_callback, check_code_level=2) trainer.train() - shutil.rmtree('logs/') def test_save_model_callback(self): data_set, model = prepare_env() top = 3 - save_model_callback = SaveModelCallback('save_models/', top=top) + save_model_callback = SaveModelCallback(self.tempdir, top=top) trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, callbacks=save_model_callback, check_code_level=2) trainer.train() - timestamp = os.listdir('save_models')[0] - self.assertEqual(len(os.listdir(os.path.join('save_models', timestamp))), top) - shutil.rmtree('save_models/') + timestamp = os.listdir(self.tempdir)[0] + self.assertEqual(len(os.listdir(os.path.join(self.tempdir, timestamp))), top) def test_warmup_callback(self): data_set, model = prepare_env() diff --git a/test/core/test_logger.py b/test/core/test_logger.py new file mode 100644 index 00000000..610f42bd --- /dev/null +++ b/test/core/test_logger.py @@ -0,0 +1,33 @@ +from fastNLP import logger +import unittest +from unittest.mock import patch +import os +import io +import tempfile +import shutil + +class TestLogger(unittest.TestCase): + msg = 'some test logger msg' 
+ + def setUp(self): + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self): + pass + # shutil.rmtree(self.tmpdir) + + def test_add_file(self): + fn = os.path.join(self.tmpdir, 'log.txt') + logger.add_file(fn) + logger.info(self.msg) + with open(fn, 'r') as f: + line = ''.join([l for l in f]) + print(line) + self.assertTrue(self.msg in line) + + @patch('sys.stdout', new_callable=io.StringIO) + def test_stdout(self, mock_out): + for i in range(3): + logger.info(self.msg) + + self.assertEqual([self.msg for i in range(3)], mock_out.getvalue().strip().split('\n')) diff --git a/tutorials/tutorial_callback.ipynb b/tutorials/tutorial_callback.ipynb new file mode 100644 index 00000000..ed71a9b0 --- /dev/null +++ b/tutorials/tutorial_callback.ipynb @@ -0,0 +1,622 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 使用 Callback 自定义你的训练过程" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- 什么是 Callback\n", + "- 使用 Callback \n", + "- 一些常用的 Callback\n", + "- 自定义实现 Callback" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "什么是Callback\n", + "------\n", + "\n", + "Callback 是与 Trainer 紧密结合的模块,利用 Callback 可以在 Trainer 训练时,加入自定义的操作,比如梯度裁剪,学习率调节,测试模型的性能等。定义的 Callback 会在训练的特定阶段被调用。\n", + "\n", + "fastNLP 中提供了很多常用的 Callback ,开箱即用。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "使用 Callback\n", + " ------\n", + "\n", + "使用 Callback 很简单,将需要的 callback 按 list 存储,以对应参数 ``callbacks`` 传入对应的 Trainer。Trainer 在训练时就会自动执行这些 Callback 指定的操作了。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2019-09-17T07:34:46.465871Z", + "start_time": "2019-09-17T07:34:30.648758Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In total 3 datasets:\n", + "\ttest has 1200 instances.\n", + "\ttrain has 9600 instances.\n", + "\tdev has 1200 instances.\n", + "In total 2 vocabs:\n", + "\tchars has 4409 
entries.\n", + "\ttarget has 2 entries.\n", + "\n", + "training epochs started 2019-09-17-03-34-34\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.1 seconds!\n", + "Evaluation on dev at Epoch 1/3. Step:300/900: \n", + "AccuracyMetric: acc=0.863333\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.11 seconds!\n", + "Evaluation on dev at Epoch 2/3. Step:600/900: \n", + "AccuracyMetric: acc=0.886667\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.1 seconds!\n", + "Evaluation on dev at Epoch 3/3. 
Step:900/900: \n", + "AccuracyMetric: acc=0.890833\n", + "\n", + "\r\n", + "In Epoch:3/Step:900, got best dev performance:\n", + "AccuracyMetric: acc=0.890833\n", + "Reloaded the best model.\n" + ] + } + ], + "source": [ + "from fastNLP import (Callback, EarlyStopCallback,\n", + " Trainer, CrossEntropyLoss, AccuracyMetric)\n", + "from fastNLP.models import CNNText\n", + "import torch.cuda\n", + "\n", + "# prepare data\n", + "def get_data():\n", + " from fastNLP.io import ChnSentiCorpPipe as pipe\n", + " data = pipe().process_from_file()\n", + " print(data)\n", + " data.rename_field('chars', 'words')\n", + " train_data = data.datasets['train']\n", + " dev_data = data.datasets['dev']\n", + " test_data = data.datasets['test']\n", + " vocab = data.vocabs['words']\n", + " tgt_vocab = data.vocabs['target']\n", + " return train_data, dev_data, test_data, vocab, tgt_vocab\n", + "\n", + "# prepare model\n", + "train_data, dev_data, _, vocab, tgt_vocab = get_data()\n", + "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", + "model = CNNText((len(vocab),50), num_classes=len(tgt_vocab))\n", + "\n", + "# define callback\n", + "callbacks=[EarlyStopCallback(5)]\n", + "\n", + "# pass callbacks to Trainer\n", + "def train_with_callback(cb_list):\n", + " trainer = Trainer(\n", + " device=device,\n", + " n_epochs=3,\n", + " model=model, \n", + " train_data=train_data, \n", + " dev_data=dev_data, \n", + " loss=CrossEntropyLoss(), \n", + " metrics=AccuracyMetric(), \n", + " callbacks=cb_list, \n", + " check_code_level=-1\n", + " )\n", + " trainer.train()\n", + "\n", + "train_with_callback(callbacks)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "fastNLP 中的 Callback\n", + "-------\n", + "fastNLP 中提供了很多常用的 Callback,如梯度裁剪,训练时早停和测试验证集,fitlog 等等。具体 Callback 请参考 fastNLP.core.callback" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2019-09-17T07:35:02.182727Z", + "start_time": 
"2019-09-17T07:34:49.443863Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "training epochs started 2019-09-17-03-34-49\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), HTML(value='')), layout=Layout(display=…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.13 seconds!\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.12 seconds!\n", + "Evaluation on data-test:\n", + "AccuracyMetric: acc=0.890833\n", + "Evaluation on dev at Epoch 1/3. 
Step:300/900: \n", + "AccuracyMetric: acc=0.890833\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.09 seconds!\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.09 seconds!\n", + "Evaluation on data-test:\n", + "AccuracyMetric: acc=0.8875\n", + "Evaluation on dev at Epoch 2/3. 
Step:600/900: \n", + "AccuracyMetric: acc=0.8875\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.11 seconds!\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.1 seconds!\n", + "Evaluation on data-test:\n", + "AccuracyMetric: acc=0.885\n", + "Evaluation on dev at Epoch 3/3. 
Step:900/900: \n", + "AccuracyMetric: acc=0.885\n", + "\n", + "\r\n", + "In Epoch:1/Step:300, got best dev performance:\n", + "AccuracyMetric: acc=0.890833\n", + "Reloaded the best model.\n" + ] + } + ], + "source": [ + "from fastNLP import EarlyStopCallback, GradientClipCallback, EvaluateCallback\n", + "callbacks = [\n", + " EarlyStopCallback(5),\n", + " GradientClipCallback(clip_value=5, clip_type='value'),\n", + " EvaluateCallback(dev_data)\n", + "]\n", + "\n", + "train_with_callback(callbacks)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "自定义 Callback\n", + "------\n", + "\n", + "这里我们以一个简单的 Callback作为例子,它的作用是打印每一个 Epoch 平均训练 loss。\n", + "\n", + "#### 创建 Callback\n", + " \n", + "要自定义 Callback,我们要实现一个类,继承 fastNLP.Callback。\n", + "\n", + "这里我们定义 MyCallBack ,继承 fastNLP.Callback 。\n", + "\n", + "#### 指定 Callback 调用的阶段\n", + " \n", + "Callback 中所有以 on_ 开头的类方法会在 Trainer 的训练中在特定阶段调用。 如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。 具体有哪些类方法,参见 Callback 文档。\n", + "\n", + "这里, MyCallBack 在求得loss时调用 on_backward_begin() 记录当前 loss ,在每一个 epoch 结束时调用 on_epoch_end() ,求当前 epoch 平均loss并输出。\n", + "\n", + "#### 使用 Callback 的属性访问 Trainer 的内部信息\n", + " \n", + "为了方便使用,可以使用 Callback 的属性,访问 Trainer 中的对应信息,如 optimizer, epoch, n_epochs,分别对应训练时的优化器,当前 epoch 数,和总 epoch 数。 具体可访问的属性,参见文档 Callback 。\n", + "\n", + "这里, MyCallBack 为了求平均 loss ,需要知道当前 epoch 的总步数,可以通过 self.step 属性得到当前训练了多少步。\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2019-09-17T07:43:10.907139Z", + "start_time": "2019-09-17T07:42:58.488177Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "training epochs started 2019-09-17-03-42-58\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=900), 
HTML(value='')), layout=Layout(display=…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.11 seconds!\n", + "Evaluation on dev at Epoch 1/3. Step:300/900: \n", + "AccuracyMetric: acc=0.883333\n", + "\n", + "Avg loss at epoch 1, 0.100254\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.1 seconds!\n", + "Evaluation on dev at Epoch 2/3. Step:600/900: \n", + "AccuracyMetric: acc=0.8775\n", + "\n", + "Avg loss at epoch 2, 0.183511\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Evaluate data in 0.13 seconds!\n", + "Evaluation on dev at Epoch 3/3. 
Step:900/900: \n", + "AccuracyMetric: acc=0.875833\n", + "\n", + "Avg loss at epoch 3, 0.257103\n", + "\r\n", + "In Epoch:1/Step:300, got best dev performance:\n", + "AccuracyMetric: acc=0.883333\n", + "Reloaded the best model.\n" + ] + } + ], + "source": [ + "from fastNLP import Callback\n", + "from fastNLP import logger\n", + "\n", + "class MyCallBack(Callback):\n", + " \"\"\"Print average loss in each epoch\"\"\"\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.total_loss = 0\n", + " self.start_step = 0\n", + " \n", + " def on_backward_begin(self, loss):\n", + " self.total_loss += loss.item()\n", + " \n", + " def on_epoch_end(self):\n", + " n_steps = self.step - self.start_step\n", + " avg_loss = self.total_loss / n_steps\n", + " logger.info('Avg loss at epoch %d, %.6f', self.epoch, avg_loss)\n", + " self.start_step = self.step\n", + "\n", + "callbacks = [MyCallBack()]\n", + "train_with_callback(callbacks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 
4, + "nbformat_minor": 4 +}