diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py index e11bacde..690146a2 100644 --- a/fastNLP/core/callbacks/more_evaluate_callback.py +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -81,7 +81,8 @@ class MoreEvaluateCallback(HasMonitorCallback): **kwargs): super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better, must_have_monitor=False) - + if watch_monitor is not None and evaluate_every == -1: # 将evaluate_every 弄掉。 + evaluate_every = None if watch_monitor is None and evaluate_every is None: raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") if watch_monitor is not None and evaluate_every is not None: diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index 7fc11ec8..dab5028c 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -176,8 +176,8 @@ class Collator: self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 try: for key, padder in self.padders.items(): - batch = unpack_batch.get(key) - pad_batch[key] = padder(batch) + batch = unpack_batch.get(key) + pad_batch[key] = padder(batch) except BaseException as e: try: logger.error(f"The following exception happens when try to pad the `{key}` field with padder:{padder}:") diff --git a/fastNLP/core/collators/padders/exceptions.py b/fastNLP/core/collators/padders/exceptions.py index 8b08683d..a2b97cbf 100644 --- a/fastNLP/core/collators/padders/exceptions.py +++ b/fastNLP/core/collators/padders/exceptions.py @@ -3,7 +3,8 @@ __all__ = [ 'EleDtypeUnsupportedError', 'EleDtypeDtypeConversionError', 'DtypeUnsupportedError', - "DtypeError" + "DtypeError", + "NoProperPadderError" ] @@ -22,6 +23,12 @@ class DtypeError(BaseException): self.msg = msg +class NoProperPadderError(BaseException): + def __init__(self, msg, *args): + super(NoProperPadderError, self).__init__(msg, *args) + self.msg = msg + + class EleDtypeUnsupportedError(DtypeError): """ 当 batch 中的 element 的类别本身无法 pad 的时候报错。 diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 66d2eee2..dfc228a3 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -49,8 +49,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> f"information please set logger's level to DEBUG." if must_pad: raise InconsistencyError(msg) - logger.debug(msg) - return NullPadder() + raise NoProperPadderError(msg) # 再检查所有的元素 shape 是否一致? shape_lens = set([len(v[0]) for v in catalog.values()]) @@ -60,8 +59,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> f"information please set logger's level to DEBUG." if must_pad: raise InconsistencyError(msg) - logger.debug(msg) - return NullPadder() + raise NoProperPadderError(msg) # 再检查所有的元素 type 是否一致 try: @@ -74,8 +72,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> f"information please set logger's level to DEBUG." if must_pad: raise InconsistencyError(msg) - logger.debug(msg) - return NullPadder() + raise NoProperPadderError(msg) depth = depths.pop() shape_len = shape_lens.pop() @@ -131,8 +128,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> msg = "Does not support pad tensor under nested list. If you need this, please report." if must_pad: raise RuntimeError(msg) - logger.debug(msg) - return NullPadder() + raise NoProperPadderError(msg) except DtypeError as e: msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ @@ -140,8 +136,9 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> if must_pad: logger.error(msg) raise type(e)(msg=msg) - logger.debug(msg) - return NullPadder() + + except NoProperPadderError as e: + logger.debug(f"{e.msg}") except BaseException as e: raise e diff --git a/fastNLP/core/collators/padders/jittor_padder.py b/fastNLP/core/collators/padders/jittor_padder.py index d85893f1..c9b36b89 100644 --- a/fastNLP/core/collators/padders/jittor_padder.py +++ b/fastNLP/core/collators/padders/jittor_padder.py @@ -188,7 +188,7 @@ def fill_tensor(batch_field, padded_batch, dtype): padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype)) elif padded_batch.ndim == 4: try: # 应该是图像,所以直接应该就 ok 了。 - padded_batch = np.array(batch_field) + padded_batch = jittor.Var(batch_field) except: for i, content_i in enumerate(batch_field): for j, content_ii in enumerate(content_i): diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index c208adca..911c7d8c 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -175,7 +175,7 @@ def fill_tensor(batch_field, padded_batch, dtype): padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype) elif padded_batch.ndim == 4: try: # 应该是图像,所以直接应该就 ok 了。 - padded_batch = np.array(batch_field) + padded_batch = torch.tensor(batch_field) except: for i, content_i in enumerate(batch_field): for j, content_ii in enumerate(content_i): diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py index bc25c756..75e3ad79 100644 --- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py @@ -203,7 +203,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, collate_fn: Union[None, str, Callable] = "auto", - non_train_batch_size: int = 16) \ + non_train_batch_size: int = None) \ -> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]: """ ``prepare_jittor_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 :class:`JittorDataLoader` 对象, 详见 :class:`~fastNLP.core.dataloaders.JittorDataLoader`。 diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index a489e5ed..b89ee40e 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -255,7 +255,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, num_workers: int = 0, use_buffer_reader: bool = True, use_shared_memory: bool = True, timeout: int = 0, worker_init_fn: Callable = None, persistent_workers=False, - non_train_batch_size: int = 16) \ + non_train_batch_size: int = None) \ -> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: """ ``prepare_paddle_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``PaddleDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.PaddleDataLoader`。 diff --git a/fastNLP/core/dataloaders/prepare_dataloader.py b/fastNLP/core/dataloaders/prepare_dataloader.py index f9004f76..5f469f2b 100644 --- a/fastNLP/core/dataloaders/prepare_dataloader.py +++ b/fastNLP/core/dataloaders/prepare_dataloader.py @@ -6,7 +6,6 @@ from typing import Union, Callable import os import sys -from ..samplers import RandomBatchSampler, RandomSampler from .torch_dataloader import prepare_torch_dataloader from .paddle_dataloader import prepare_paddle_dataloader from .jittor_dataloader import prepare_jittor_dataloader @@ -16,7 +15,7 @@ from ..log import logger def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0, - seed: int = 0, backend: str = 'auto'): + backend: str = 'auto'): """ 自动创建合适的 ``DataLoader`` 对象。例如,检测当当前环境是 ``torch`` 的,则返回 ``TorchDataLoader`` , 是 ``paddle`` 的则 返回 ``PaddleDataLoader`` 。如果有更多需要定制的参数,请直接使用对应的 ``prepare`` 函数,例如 @@ -37,7 +36,6 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro * 为 ``Callable`` 时,应当接受一个 ``batch`` 的数据作为参数,同时输出一个对象 。 * 为 ``None`` 时,使用各个框架的 DataLoader 的默认 ``collate_fn`` 。 :param num_workers: 使用多少进程进行数据的 fetch 。 - :param seed: 使用的随机数种子。 :param backend: 当前支持 ``["auto", "torch", "paddle", "jittor"]`` 四种类型。 * 为 ``auto`` 时,首先(1) 根据环境变量 "FASTNLP_BACKEND" 进行判断;如果没有设置则,(2)通过当前 @@ -52,18 +50,14 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro if backend == 'auto': backend = _get_backend() if backend == 'torch': - batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, - drop_last=drop_last, seed=seed) - return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, - num_workers=num_workers, shuffle=False, sampler=None) + return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn, + num_workers=num_workers, shuffle=shuffle, sampler=None, + batch_size=batch_size) elif backend == 'paddle': - batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle, - drop_last=drop_last, seed=seed) - return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, - num_workers=num_workers) + return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn, + num_workers=num_workers, batch_size=batch_size, shuffle=shuffle) elif backend == 'jittor': - sampler = RandomSampler(dataset=dataset, shuffle=shuffle, seed=seed) - prepare_jittor_dataloader(ds_or_db=dataset, sampler=sampler, collate_fn=collate_fn, + prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last) else: diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 476eb7f6..d18dbc84 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -222,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db, multiprocessing_context=None, generator=None, prefetch_factor: int = 2, persistent_workers: bool = False, non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, - non_train_batch_size: int = 16) \ + non_train_batch_size: int = None) \ -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: """ ``prepare_torch_dataloader`` 的功能是将输入的单个或多个 dataset 同时转为 ``TorchDataloader``对象, 详见 :class:`~fastNLP.core.dataloaders.TorchDataLoader`。 @@ -254,13 +254,13 @@ def prepare_torch_dataloader(ds_or_db, :param non_train_batch_size: 非 'train' 数据集的 ``TorchDataLoader`` 批次大小,默认为 ``16`` 且当 batch_sampler 为 None 有效。 :param shuffle: 是否打乱数据集, 默认为 ``False``。 :param sampler: 实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , - 默认为None, 当其不为 None 时, shuffle 参数无效。 + 默认为None, 当其不为 None 时, shuffle 参数无效。 :param non_train_sampler: 非 'train' 数据集的的实现了 __len__() 和 __iter__() 的实例化对象,其 __iter__() 方法每次都会返回 dataset 的一个下标 index , - 默认为None, 当其不为 None 时, shuffle 参数无效。 + 默认为None, 当其不为 None 时, shuffle 参数无效。 :param batch_sampler: 实现了 __len__() 和 __iter__() 的实例化对象,,其__iter__() 方法每次都会返回一个 List 对象, List中的值为 - dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 + dataset 的下标 index ;默认为 None,当其不为 None 时,bacth_size, sampler, shuffle 参数均失效。 :param num_workers: 当 ``num_workers > 0`` 时, ``TorchDataLoader`` 会开启 num_workers 个子进程来处理数据, 可以加快 - 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 + 数据处理速度,但同时也消耗大量内存。 当 ``num_workers=0`` 时, 不开启子进程。 默认为 ``0``。 :param collate_fn: 用于从 dataset 取到的一个 batch 数据进行打包处理的 Callable 函数,其值应该为以下三个: ``[None, "auto", Callable]``. * callate_fn 为 'None' 时,需要注意的是此时传进来的 datset 类型不能为 :class:`~fastNLP.core.dataset.DataSet` , 当 collate_fn 为 ``None`` 时, @@ -273,7 +273,7 @@ def prepare_torch_dataloader(ds_or_db, :param pin_memory: 如果其为 ``True``, 那么 ``TorchDataLoader`` 会在返回数据张量之前将其 copy 到 cud a的 pin memory 中。 :param drop_last: 当 ``drop_last=True`` 时,``TorchDataLoader`` 会扔掉最后一个长度小于 ``batch_size`` 的 batch 数据; - 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 + 若 ``drop_last=False`` , 则会返回该 batch 数据。 默认为 ``False`` 。 :param timeout: 子进程的输出队列获取数据的超时值 :param worker_init_fn: init 函数,如果不设置为 None ,则将会在每个子进程初始化时调用该函数。 :param multiprocessing_context: 多进程的上下文环境 diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index d8e9d45b..b28bb3f7 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -46,7 +46,6 @@ class DummyFRichProgress: class FRichProgress(Progress, metaclass=Singleton): def new_progess(self, *columns: Union[str, ProgressColumn], - console: Optional[Console] = None, # 这里将 auto_refresh 关掉是想要避免单独开启线程,同时也是为了避免pdb的时候会持续刷新 auto_refresh: bool = False, refresh_per_second: float = 10, @@ -81,7 +80,7 @@ class FRichProgress(Progress, metaclass=Singleton): self.expand = expand self.live = Live( - console=console or get_console(), + console=get_console(), auto_refresh=auto_refresh, refresh_per_second=refresh_per_second, transient=transient, @@ -92,6 +91,12 @@ class FRichProgress(Progress, metaclass=Singleton): self.get_time = get_time or self.console.get_time self.print = self.console.print self.log = self.console.log + self.auto_refresh = auto_refresh + self.transient = transient + self.redirect_stdout = redirect_stdout + self.redirect_stderr = redirect_stderr + self.refresh_per_second = refresh_per_second + self._need_renew_live = False return self @@ -125,7 +130,19 @@ class FRichProgress(Progress, metaclass=Singleton): from .tqdm_progress import f_tqdm_progress assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop." - if self.live._started is False: + # 如果需要替换,应该是由于destroy的时候给换掉了 + if self._need_renew_live: + self.live = Live( + console=get_console(), + auto_refresh=self.auto_refresh, + refresh_per_second=self.refresh_per_second, + transient=self.transient, + redirect_stdout=self.redirect_stdout, + redirect_stderr=self.redirect_stderr, + get_renderable=self.get_renderable, + ) + self._need_renew_live = False + if not self.live.is_started: self.start() post_desc = fields.pop('post_desc', '') return super().add_task(description=description, @@ -155,6 +172,8 @@ class FRichProgress(Progress, metaclass=Singleton): setattr(self.live.console, 'line', lambda *args,**kwargs:...) self.live.stop() setattr(self.live.console, 'line', old_line) + # 在 jupyter 的情况下需要替换一下,不然会出不打印的问题。 + self._need_renew_live = True if is_notebook() else False def start(self) -> None: super().start() diff --git a/fastNLP/models/torch/seq2seq_generator.py b/fastNLP/models/torch/seq2seq_generator.py index 9ee723e5..68b405ba 100755 --- a/fastNLP/models/torch/seq2seq_generator.py +++ b/fastNLP/models/torch/seq2seq_generator.py @@ -65,7 +65,7 @@ class SequenceGeneratorModel(nn.Module): if tgt_seq_len is not None: mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) - loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) + loss = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:]) return {'loss': loss} def evaluate_step(self, src_tokens, src_seq_len=None): diff --git a/fastNLP/models/torch/seq2seq_model.py b/fastNLP/models/torch/seq2seq_model.py index 057fb93b..1420375e 100755 --- a/fastNLP/models/torch/seq2seq_model.py +++ b/fastNLP/models/torch/seq2seq_model.py @@ -59,7 +59,7 @@ class Seq2SeqModel(nn.Module): if tgt_seq_len is not None: mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1)) tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100) - loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens) + loss = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:]) return {'loss': loss} def prepare_state(self, src_tokens, src_seq_len=None): diff --git a/fastNLP/modules/torch/generator/seq2seq_generator.py b/fastNLP/modules/torch/generator/seq2seq_generator.py index cf9c5306..b54eea28 100755 --- a/fastNLP/modules/torch/generator/seq2seq_generator.py +++ b/fastNLP/modules/torch/generator/seq2seq_generator.py @@ -368,13 +368,13 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_ next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) next_tokens = _tokens.gather(dim=1, index=ids) # (batch_size, 2*num_beams) - from_which_beam = ids // (num_beams + 1) # (batch_size, 2*num_beams) + from_which_beam = torch.floor(ids / (num_beams + 1)).long() # (batch_size, 2*num_beams) else: scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size) _scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size) _scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size) next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams) - from_which_beam = ids // vocab_size # (batch_size, 2*num_beams) + from_which_beam = torch.floor(ids / vocab_size).long() # (batch_size, 2*num_beams) next_tokens = ids % vocab_size # (batch_size, 2*num_beams) # 接下来需要组装下一个batch的结果。 diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index 09ec4af8..8443ef92 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -318,6 +318,13 @@ class TestCollator: with pytest.raises(KeyError): collator.set_pad((1, 2)) + @pytest.mark.torch + def test_torch_4d(self): + collator = Collator(backend='torch') + data = [{'x': [[[0,1], [2,3]]]}, {'x': [[[0,1]]]}] + output = collator(data) + assert output['x'].size() == (2, 1, 2, 2) + @pytest.mark.torch def test_torch_dl(): diff --git a/tests/core/dataloaders/test_prepare_dataloader.py b/tests/core/dataloaders/test_prepare_dataloader.py index 223b7880..be6d5feb 100644 --- a/tests/core/dataloaders/test_prepare_dataloader.py +++ b/tests/core/dataloaders/test_prepare_dataloader.py @@ -2,6 +2,7 @@ import pytest from fastNLP import prepare_dataloader from fastNLP import DataSet +from fastNLP.io import DataBundle @pytest.mark.torch @@ -10,4 +11,18 @@ def test_torch(): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) dl = prepare_dataloader(ds, batch_size=2, shuffle=True) for batch in dl: - assert isinstance(batch['x'], torch.Tensor) \ No newline at end of file + assert isinstance(batch['x'], torch.Tensor) + + +@pytest.mark.torch +def test_torch_data_bundle(): + import torch + ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) + dl = DataBundle() + dl.set_dataset(dataset=ds, name='train') + dl.set_dataset(dataset=ds, name='test') + dls = prepare_dataloader(dl, batch_size=2, shuffle=True) + for dl in dls.values(): + for batch in dl: + assert isinstance(batch['x'], torch.Tensor) + assert batch['x'].size(0) == 2 diff --git a/tutorials/fastnlp_tutorial_0.ipynb b/tutorials/fastnlp_tutorial_0.ipynb index 245eaf91..8312353b 100644 --- a/tutorials/fastnlp_tutorial_0.ipynb +++ b/tutorials/fastnlp_tutorial_0.ipynb @@ -15,15 +15,15 @@ "\n", "    1.3   trainer 内部初始化 evaluater\n", "\n", - "  2   使用 fastNLP 0.8 搭建 argmax 模型\n", + "  2   使用 fastNLP 搭建 argmax 模型\n", "\n", "    2.1   trainer_step 和 evaluator_step\n", "\n", "    2.2   trainer 和 evaluator 的参数匹配\n", "\n", - "    2.3   一个实际案例:argmax 模型\n", + "    2.3   示例:argmax 模型的搭建\n", "\n", - "  3   使用 fastNLP 0.8 训练 argmax 模型\n", + "  3   使用 fastNLP 训练 argmax 模型\n", " \n", "    3.1   trainer 外部初始化的 evaluator\n", "\n", @@ -248,7 +248,7 @@ "id": "f62b7bb1", "metadata": {}, "source": [ - "### 2.3 一个实际案例:argmax 模型\n", + "### 2.3 示例:argmax 模型的搭建\n", "\n", "下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n", "\n", @@ -271,7 +271,7 @@ "\n", "class ArgMaxModel(nn.Module):\n", " def __init__(self, num_labels, feature_dimension):\n", - " super(ArgMaxModel, self).__init__()\n", + " nn.Module.__init__(self)\n", " self.num_labels = num_labels\n", "\n", " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", @@ -434,7 +434,7 @@ "\n", "  通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n", "\n", - "    但对于`\"auto\"`和`\"rich\"`格式,在notebook中,进度条在训练结束后会被丢弃\n", + "    但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条会在训练结束后会被丢弃\n", "\n", "  通过`n_epochs`设定优化迭代轮数,默认为20;全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询" ] @@ -489,7 +489,7 @@ "\n", "  其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n", "\n", - "  此外,可以通过`inspect.getfullargspec(trainer.run)`查询`run`函数的全部参数列表" + "  `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容" ] }, { @@ -590,7 +590,7 @@ "\n", "  其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n", "\n", - "  最终,输出形如`{'acc#acc': acc}`的字典,在notebook中,进度条在评测结束后会被丢弃" + "  最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条会在评测结束后会被丢弃" ] }, { @@ -603,6 +603,20 @@ } }, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -616,11 +630,11 @@ { "data": { "text/html": [ - "
{'acc#acc': 0.37, 'total#acc': 100.0, 'correct#acc': 37.0}\n",
+       "
{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}\n",
        "
\n" ], "text/plain": [ - "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.37\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m37.0\u001b[0m\u001b[1m}\u001b[0m\n" + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, @@ -629,7 +643,7 @@ { "data": { "text/plain": [ - "{'acc#acc': 0.37, 'total#acc': 100.0, 'correct#acc': 37.0}" + "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}" ] }, "execution_count": 9, @@ -650,9 +664,9 @@ "\n", "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n", "\n", - "  通过`progress_bar`同时设定训练和评估进度条格式,在notebook中,在进度条训练结束后会被丢弃\n", + "  通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n", "\n", - "  **通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n", + "  但是中间的评估结果仍会保留;**通过`evaluate_every`设定评估频率**,可以为负数、正数或者函数:\n", "\n", "    **为负数时**,**表示每隔几个`epoch`评估一次**;**为正数时**,**则表示每隔几个`batch`评估一次**" ] @@ -687,9 +701,9 @@ "source": [ "通过使用`Trainer`类的`run`函数,进行训练\n", "\n", - "  还可以通过参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测,默认为2\n", + "  还可以通过**参数`num_eval_sanity_batch`决定每次训练前运行多少个`evaluate_batch`进行评测**,**默认为`2`**\n", "\n", - "  之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此试探性评测" + "  之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**" ] }, { @@ -702,6 +716,33 @@ } }, "outputs": [ + { + "data": { + "text/html": [ + "
[18:28:25] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -712,6 +753,490 @@ "metadata": {}, "output_type": "display_data" }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.31,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 31.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.33,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 33.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.34,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 34.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.37,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 37.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.4,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 40.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -746,6 +1271,20 @@ "id": "c4e9c619", "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -759,7 +1298,7 @@ { "data": { "text/plain": [ - "{'acc#acc': 0.47, 'total#acc': 100.0, 'correct#acc': 47.0}" + "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}" ] }, "execution_count": 12, @@ -773,9 +1312,222 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "db784d5b", "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['__annotations__',\n", + " '__class__',\n", + " '__delattr__',\n", + " '__dict__',\n", + " '__dir__',\n", + " '__doc__',\n", + " '__eq__',\n", + " '__format__',\n", + " '__ge__',\n", + " '__getattribute__',\n", + " '__gt__',\n", + " '__hash__',\n", + " '__init__',\n", + " '__init_subclass__',\n", + " '__le__',\n", + " '__lt__',\n", + " '__module__',\n", + " '__ne__',\n", + " '__new__',\n", + " '__reduce__',\n", + " '__reduce_ex__',\n", + " '__repr__',\n", + " '__setattr__',\n", + " '__sizeof__',\n", + " '__str__',\n", + " '__subclasshook__',\n", + " '__weakref__',\n", + " '_check_callback_called_legality',\n", + " '_check_train_batch_loop_legality',\n", + " '_custom_callbacks',\n", + " '_driver',\n", + " '_evaluate_dataloaders',\n", + " '_fetch_matched_fn_callbacks',\n", + " '_set_num_eval_batch_per_dl',\n", + " '_train_batch_loop',\n", + " '_train_dataloader',\n", + " '_train_step',\n", + " '_train_step_signature_fn',\n", + " 'accumulation_steps',\n", + " 'add_callback_fn',\n", + " 'backward',\n", + " 'batch_idx_in_epoch',\n", + " 'batch_step_fn',\n", + " 'callback_manager',\n", + " 'check_batch_step_fn',\n", + " 'cur_epoch_idx',\n", + " 'data_device',\n", + " 'dataloader',\n", + " 'device',\n", + " 'driver',\n", + " 'driver_name',\n", + " 'epoch_evaluate',\n", + " 'evaluate_batch_step_fn',\n", + " 'evaluate_dataloaders',\n", + " 'evaluate_every',\n", + " 'evaluate_fn',\n", + " 'evaluator',\n", + " 'extract_loss_from_outputs',\n", + " 'fp16',\n", + " 'get_no_sync_context',\n", + " 'global_forward_batches',\n", + " 'has_checked_train_batch_loop',\n", + " 'input_mapping',\n", + " 'kwargs',\n", + " 'larger_better',\n", + " 'load_checkpoint',\n", + " 'load_model',\n", + " 'marker',\n", + " 'metrics',\n", + " 'model',\n", + " 'model_device',\n", + " 'monitor',\n", + " 'move_data_to_device',\n", + " 'n_epochs',\n", + " 'num_batches_per_epoch',\n", + " 'on',\n", + " 'on_after_backward',\n", + " 'on_after_optimizers_step',\n", + " 'on_after_trainer_initialized',\n", + " 'on_after_zero_grad',\n", + " 'on_before_backward',\n", + " 'on_before_optimizers_step',\n", + " 'on_before_zero_grad',\n", + " 'on_evaluate_begin',\n", + " 'on_evaluate_end',\n", + " 'on_exception',\n", + " 'on_fetch_data_begin',\n", + " 'on_fetch_data_end',\n", + " 'on_load_checkpoint',\n", + " 'on_load_model',\n", + " 'on_sanity_check_begin',\n", + " 'on_sanity_check_end',\n", + " 'on_save_checkpoint',\n", + " 'on_save_model',\n", + " 'on_train_batch_begin',\n", + " 'on_train_batch_end',\n", + " 'on_train_begin',\n", + " 'on_train_end',\n", + " 'on_train_epoch_begin',\n", + " 'on_train_epoch_end',\n", + " 'optimizers',\n", + " 'output_mapping',\n", + " 'progress_bar',\n", + " 'run',\n", + " 'run_evaluate',\n", + " 'save_checkpoint',\n", + " 'save_model',\n", + " 'start_batch_idx_in_epoch',\n", + " 'state',\n", + " 'step',\n", + " 'step_evaluate',\n", + " 'total_batches',\n", + " 'train_batch_loop',\n", + " 'train_dataloader',\n", + " 'train_fn',\n", + " 'train_step',\n", + " 'trainer_state',\n", + " 'zero_grad']" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dir(trainer)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "953533c4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Help on method run in module fastNLP.core.controllers.trainer:\n", + "\n", + "run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None) method of fastNLP.core.controllers.trainer.Trainer instance\n", + " 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数;\n", + " \n", + " 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback``\n", + " 去保存断点重训的文件;\n", + " \n", + " :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度;\n", + " :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度;\n", + " :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测;\n", + " :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹;\n", + " :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``,\n", + " 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的\n", + " 其余状态都是保持初始化时的状态不会改变;\n", + " :param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序,\n", + " ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver``\n", + " 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True;\n", + " \n", + " .. warning::\n", + " \n", + " 注意初始化的 ``Trainer`` 只能调用一次 ``run`` 函数,即之后的调用 ``run`` 函数实际不会运行,因为此时\n", + " ``trainer.cur_epoch_idx == trainer.n_epochs``;\n", + " \n", + " 这意味着如果您需要再次调用 ``run`` 函数,您需要重新再初始化一个 ``Trainer``;\n", + " \n", + " .. note::\n", + " \n", + " 您可以使用 ``num_train_batch_per_epoch`` 来简单地对您的训练过程进行验证,例如,当您指定 ``num_train_batch_per_epoch=10`` 后,\n", + " 每一个 epoch 下实际训练的 batch 的数量则会被修改为 10。您可以先使用该值来设定一个较小的训练长度,在验证整体的训练流程没有错误后,再将\n", + " 该值设定为 **-1** 开始真正的训练;\n", + " \n", + " ``num_eval_batch_per_dl`` 的意思和 ``num_train_batch_per_epoch`` 类似,即您可以通过设定 ``num_eval_batch_per_dl`` 来验证\n", + " 整体的验证流程是否正确;\n", + " \n", + " ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用;\n", + " 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的\n", + " 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator``\n", + " 进行验证时会验证的 batch 的数量。\n", + " \n", + " 并且,在实际真正的训练中,``num_train_batch_per_epoch`` 和 ``num_eval_batch_per_dl`` 应当都被设置为 **-1**,但是 ``num_eval_sanity_batch``\n", + " 应当为一个很小的正整数,例如 2;\n", + " \n", + " .. note::\n", + " \n", + " 参数 ``resume_from`` 和 ``resume_training`` 的设立是为了支持断点重训功能;仅当 ``resume_from`` 不为 ``None`` 时,``resume_training`` 才有效;\n", + " \n", + " 断点重训的意思为将上一次训练过程中的 ``Trainer`` 的状态保存下来,包括模型和优化器的状态、当前训练过的 epoch 的数量、对于当前的 epoch\n", + " 已经训练过的 batch 的数量、callbacks 的状态等等;然后在下一次训练时直接加载这些状态,从而直接恢复到上一次训练过程的某一个具体时间点的状态开始训练;\n", + " \n", + " fastNLP 将断点重训分为了 **保存状态** 和 **恢复断点重训** 两部分:\n", + " \n", + " 1. 您需要使用 ``CheckpointCallback`` 来保存训练过程中的 ``Trainer`` 的状态;具体详见 :class:`~fastNLP.core.callbacks.CheckpointCallback`;\n", + " ``CheckpointCallback`` 会帮助您把 ``Trainer`` 的状态保存到一个具体的文件夹下,这个文件夹的名字由 ``CheckpointCallback`` 自己生成;\n", + " 2. 在第二次训练开始时,您需要找到您想要加载的 ``Trainer`` 状态所存放的文件夹,然后传入给参数 ``resume_from``;\n", + " \n", + " 需要注意的是 **保存状态** 和 **恢复断点重训** 是互不影响的。\n", + "\n" + ] + } + ], + "source": [ + "help(trainer.run)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bc7cb4a", + "metadata": {}, "outputs": [], "source": [] } diff --git a/tutorials/fastnlp_tutorial_2.ipynb b/tutorials/fastnlp_tutorial_2.ipynb index 3aa27c86..64c4bc8b 100644 --- a/tutorials/fastnlp_tutorial_2.ipynb +++ b/tutorials/fastnlp_tutorial_2.ipynb @@ -281,13 +281,13 @@ "## 2. fastNLP 中的 tokenizer\n", "\n", "### 2.1 PreTrainTokenizer 的提出\n", - "\n", + "\n", "在`fastNLP 0.8`中,**使用`PreTrainedTokenizer`模块来为数据集中的词语进行词向量的标注**\n", "\n", "  需要注意的是,`PreTrainedTokenizer`模块的下载和导入**需要确保环境安装了`transformers`模块**\n", @@ -867,7 +867,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.13" }, "pycharm": { "stem_cell": { diff --git a/tutorials/fastnlp_tutorial_3.ipynb b/tutorials/fastnlp_tutorial_3.ipynb index 353e4645..172e1232 100644 --- a/tutorials/fastnlp_tutorial_3.ipynb +++ b/tutorials/fastnlp_tutorial_3.ipynb @@ -29,13 +29,7 @@ "\n", "### 1.1 dataloader 的职责描述\n", "\n", - "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前,还存在其他的一些模块,负责例如对文本数据\n", - "\n", - "  进行补零对齐,即 **核对器`collator`模块**,进行分词标注,即 **分词器`tokenizer`模块**\n", - "\n", - "  本节将对`fastNLP`中的核对器`collator`等展开介绍,分词器`tokenizer`将在下一节中详细介绍\n", - "\n", - "在`fastNLP 0.8`中,**核对器`collator`模块负责文本序列的补零对齐**,通过" + "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前" ] }, { @@ -45,13 +39,7 @@ "source": [ "### 1.2 dataloader 的基本使用\n", "\n", - "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前,还存在其他的一些模块,负责例如对文本数据\n", - "\n", - "  进行补零对齐,即 **核对器`collator`模块**,进行分词标注,即 **分词器`tokenizer`模块**\n", - "\n", - "  本节将对`fastNLP`中的核对器`collator`等展开介绍,分词器`tokenizer`将在下一节中详细介绍\n", - "\n", - "在`fastNLP 0.8`中,**核对器`collator`模块负责文本序列的补零对齐**,通过" + "在`fastNLP 0.8`中,在数据加载模块`DataLoader`之前," ] }, { diff --git a/tutorials/fastnlp_tutorial_4.ipynb b/tutorials/fastnlp_tutorial_4.ipynb index 532118b0..ee5a0c6b 100644 --- a/tutorials/fastnlp_tutorial_4.ipynb +++ b/tutorials/fastnlp_tutorial_4.ipynb @@ -5,32 +5,292 @@ "id": "fdd7ff16", "metadata": {}, "source": [ - "# T4. model 的搭建与 driver 的概念\n", + "# T4. trainer 和 evaluator 的深入介绍\n", "\n", - "  1   fastNLP 中预训练模型的使用\n", + "  1   fastNLP 中的更多 metric 类型\n", + "\n", + "    1.1   预定义的 metric 类型\n", + "\n", + "    1.2   自定义的 metric 类型\n", + "\n", + "  2   fastNLP 中 trainer 的补充介绍\n", " \n", - "    1.1   \n", + "    2.1   trainer 的提出构想 \n", "\n", - "    1.2   \n", + "    2.2   trainer 的内部结构\n", "\n", - "  2   fastNLP 中使用 Pytorch 搭建模型\n", + "    2.3   实例:\n", "\n", - "    2.1   \n", + "  3   fastNLP 中的 driver 与 device\n", "\n", - "    2.2   \n", + "    3.1   driver 的提出构想\n", + "\n", + "    3.2   device 与多卡训练" + ] + }, + { + "cell_type": "markdown", + "id": "8d19220c", + "metadata": {}, + "source": [ + "## 1. fastNLP 中的更多 metric 类型\n", "\n", - "  3   fastNLP 中的 driver\n", + "### 1.1 预定义的 metric 类型\n", "\n", - "    3.1   \n", + "在`fastNLP 0.8`中,除了前几篇`tutorial`中经常见到的**正确率`Accuracy`**,还有其他**预定义的评价标准`metric`**\n", + "\n", + "  包括**所有`metric`的基类`Metric`**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n", + "\n", + "    **适用于分类语境下的`F1`值`ClassifyFPreRecMetric`**(其中也包括**召回率`Pre`**、**精确率`Rec`**\n", + "\n", + "    **适用于抽取语境下的`F1`值`SpanFPreRecMetric`**;相关基本信息内容见下表,之后是详细分析\n", + "\n", + "|
代码名称
|
简要介绍
|
代码路径
|\n", + "|:--|:--|:--|\n", + "| `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n", + "| `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n", + "| `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n", + "| `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n", + "| `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |" + ] + }, + { + "cell_type": "markdown", + "id": "fdc083a3", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "大概的描述一下,给出各个正确率的计算公式" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9775ea5e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "8a22f522", + "metadata": {}, + "source": [ + "### 2.2 自定义的 metric 类型\n", "\n", - "    3.2   " + "在`fastNLP 0.8`中,  给一个案例,训练部分留到trainer部分" ] }, { "cell_type": "code", "execution_count": null, + "id": "d8caba1d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e6247dd", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", "id": "08752c5a", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## 2. fastNLP 中 trainer 的补充介绍\n", + "\n", + "### 2.1 trainer 的提出构想\n", + "\n", + "在`fastNLP 0.8`中,  " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "977a6355", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69203cdc", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "ab1cea7d", + "metadata": {}, + "source": [ + "### 2.2 trainer 的内部结构\n", + "\n", + "在`fastNLP 0.8`中,  \n", + "\n", + "'accumulation_steps', 'add_callback_fn', 'backward', 'batch_idx_in_epoch', 'batch_step_fn',\n", + "'callback_manager', 'check_batch_step_fn', 'cur_epoch_idx', 'data_device', 'dataloader',\n", + "'device', 'driver', 'driver_name', 'epoch_evaluate', 'evaluate_batch_step_fn', 'evaluate_dataloaders',\n", + "'evaluate_every', 'evaluate_fn', 'evaluator', 'extract_loss_from_outputs', 'fp16',\n", + "'get_no_sync_context', 'global_forward_batches', 'has_checked_train_batch_loop',\n", + "'input_mapping', 'kwargs', 'larger_better', 'load_checkpoint', 'load_model', 'marker',\n", + "'metrics', 'model', 'model_device', 'monitor', 'move_data_to_device', 'n_epochs', 'num_batches_per_epoch',\n", + "'on', 'on_after_backward', 'on_after_optimizers_step', 'on_after_trainer_initialized',\n", + "'on_after_zero_grad', 'on_before_backward', 'on_before_optimizers_step', 'on_before_zero_grad',\n", + "'on_evaluate_begin', 'on_evaluate_end', 'on_exception', 'on_fetch_data_begin', 'on_fetch_data_end',\n", + "'on_load_checkpoint', 'on_load_model', 'on_sanity_check_begin', 'on_sanity_check_end',\n", + "'on_save_checkpoint', 'on_save_model', 'on_train_batch_begin', 'on_train_batch_end',\n", + "'on_train_begin', 'on_train_end', 'on_train_epoch_begin', 'on_train_epoch_end',\n", + "'optimizers', 'output_mapping', 'progress_bar', 'run', 'run_evaluate',\n", + "'save_checkpoint', 'save_model', 'start_batch_idx_in_epoch', 'state',\n", + "'step', 'step_evaluate', 'total_batches', 'train_batch_loop', 'train_dataloader', 'train_fn', 'train_step',\n", + "'trainer_state', 'zero_grad'\n", + "\n", + "  run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b3c8342e", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d28f2624", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "ce6322b4", "metadata": {}, + "source": [ + "### 2.3 实例:\n", + "\n", + "在`fastNLP 0.8`中,  " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43be274f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c348864c", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "175d6ebb", + "metadata": {}, + "source": [ + "## 3. fastNLP 中的 driver 与 device\n", + "\n", + "### 3.1 driver 的提出构想\n", + "\n", + "在`fastNLP 0.8`中,  " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "47100e7a", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0204a223", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "6e723b87", + "metadata": {}, + "source": [ + "### 3.2 device 与多卡训练\n", + "\n", + "在`fastNLP 0.8`中,  " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ad81ac7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfb28b1b", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [] } @@ -51,7 +311,16 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } } }, "nbformat": 4, diff --git a/tutorials/fastnlp_tutorial_5.ipynb b/tutorials/fastnlp_tutorial_5.ipynb new file mode 100644 index 00000000..cb105c89 --- /dev/null +++ b/tutorials/fastnlp_tutorial_5.ipynb @@ -0,0 +1,1958 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fdd7ff16", + "metadata": {}, + "source": [ + "# T5. fastNLP 中的预定义模型\n", + "\n", + "  1   fastNLP 中 modules 的介绍\n", + " \n", + "    1.1   modules 模块、models 模块 简介\n", + "\n", + "    1.2   示例一:modules 实现 LSTM 分类\n", + "\n", + "  2   fastNLP 中 models 的介绍\n", + " \n", + "    2.1   示例一:models 实现 CNN 分类\n", + "\n", + "    2.3   示例二:models 实现 BiLSTM 标注" + ] + }, + { + "cell_type": "markdown", + "id": "d3d65d53", + "metadata": {}, + "source": [ + "## 1. fastNLP 中 modules 模块的介绍\n", + "\n", + "### 1.1 modules 模块、models 模块 简介\n", + "\n", + "在`fastNLP 0.8`中,**`modules.torch`路径下定义了一些基于`pytorch`实现的基础模块**\n", + "\n", + "    包括长短期记忆网络`LSTM`、条件随机场`CRF`、`transformer`的编解码器模块等,详见下表\n", + "\n", + "|
代码名称
|
简要介绍
|
代码路径
|\n", + "|:--|:--|:--|\n", + "| `LSTM` | 轻量封装`pytorch`的`LSTM` | `/modules/torch/encoder/lstm.py` |\n", + "| `Seq2SeqEncoder` | 序列变换编码器,基类 | `/modules/torch/encoder/seq2seq_encoder.py` |\n", + "| `LSTMSeq2SeqEncoder` | 序列变换编码器,基于`LSTM` | `/modules/torch/encoder/seq2seq_encoder.py` |\n", + "| `TransformerSeq2SeqEncoder` | 序列变换编码器,基于`transformer` | `/modules/torch/encoder/seq2seq_encoder.py` |\n", + "| `StarTransformer` | `Star-Transformer`的编码器部分 | `/modules/torch/encoder/star_transformer.py` |\n", + "| `VarRNN` | 实现`Variational Dropout RNN` | `/modules/torch/encoder/variational_rnn.py` |\n", + "| `VarLSTM` | 实现`Variational Dropout LSTM` | `/modules/torch/encoder/variational_rnn.py` |\n", + "| `VarGRU` | 实现`Variational Dropout GRU` | `/modules/torch/encoder/variational_rnn.py` |\n", + "| `ConditionalRandomField` | 条件随机场模型 | `/modules/torch/decoder/crf.py` |\n", + "| `Seq2SeqDecoder` | 序列变换解码器,基类 | `/modules/torch/decoder/seq2seq_decoder.py` |\n", + "| `LSTMSeq2SeqDecoder` | 序列变换解码器,基于`LSTM` | `/modules/torch/decoder/seq2seq_decoder.py` |\n", + "| `TransformerSeq2SeqDecoder` | 序列变换解码器,基于`transformer` | `/modules/torch/decoder/seq2seq_decoder.py` |\n", + "| `SequenceGenerator` | 序列生成,封装`Seq2SeqDecoder` | `/models/torch/sequence_labeling.py` |\n", + "| `TimestepDropout` | 在每个`timestamp`上`dropout` | `/modules/torch/dropout.py` |" + ] + }, + { + "cell_type": "markdown", + "id": "89ffcf07", + "metadata": {}, + "source": [ + "  **`models.torch`路径下定义了一些基于`pytorch`、`modules`实现的预定义模型** \n", + "\n", + "    例如基于`CNN`的分类模型、基于`BiLSTM+CRF`的标注模型、基于[双仿射注意力机制](https://arxiv.org/pdf/1611.01734.pdf)的分析模型\n", + "\n", + "    基于`modules.torch`中的`LSTM`/`transformer`编/解码器模块的序列变换/生成模型,详见下表\n", + "\n", + "|
代码名称
|
简要介绍
|
代码路径
|\n", + "|:--|:--|:--|\n", + "| `BiaffineParser` | 句法分析模型,基于双仿射注意力 | `/models/torch/biaffine_parser.py` |\n", + "| `CNNText` | 文本分类模型,基于`CNN` | `/models/torch/cnn_text_classification.py` |\n", + "| `Seq2SeqModel` | 序列变换,基类`encoder+decoder` | `/models/torch/seq2seq_model.py` |\n", + "| `LSTMSeq2SeqModel` | 序列变换,基于`LSTM` | `/models/torch/seq2seq_model.py` |\n", + "| `TransformerSeq2SeqModel` | 序列变换,基于`transformer` | `/models/torch/seq2seq_model.py` |\n", + "| `SequenceGeneratorModel` | 封装`Seq2SeqModel`,结合`SequenceGenerator` | `/models/torch/seq2seq_generator.py` |\n", + "| `SeqLabeling` | 标注模型,基类`LSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n", + "| `BiLSTMCRF` | 标注模型,`BiLSTM+FC+CRF` | `/models/torch/sequence_labeling.py` |\n", + "| `AdvSeqLabel` | 标注模型,`LN+BiLSTM*2+LN+FC+CRF` | `/models/torch/sequence_labeling.py` |" + ] + }, + { + "cell_type": "markdown", + "id": "61318354", + "metadata": {}, + "source": [ + "上述`fastNLP`模块,不仅**为入门级用户提供了简单易用的工具**,以解决各种`NLP`任务,或复现相关论文\n", + "\n", + "  同时**也为专业研究人员提供了便捷可操作的接口**,封装部分代码的同时,也能指定参数修改细节\n", + "\n", + "  在接下来的`tutorial`中,我们将通过`SST-2`分类和`CoNLL-2003`标注,展示相关模型使用\n", + "\n", + "注一:**`SST`**,**单句情感分类**数据集,包含电影评论和对应情感极性,1 对应正面情感,0 对应负面情感\n", + "\n", + "  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", + "\n", + "注二:**`CoNLL-2003`**,**文本语法标注**数据集,包含语句和对应的词性标签`pos_tags`(名动形数量代)\n", + "\n", + "  语法结构标签`chunk_tags`(主谓宾定状补)、命名实体标签`ner_tags`(人名、组织名、地名、时间等)\n", + "\n", + "  数据集包括三部分:训练集 14041 条,验证集 3250 条,测试集 3453 条,更多参考[原始论文](https://aclanthology.org/W03-0419.pdf)" + ] + }, + { + "cell_type": "markdown", + "id": "2a36bbe4", + "metadata": {}, + "source": [ + "### 1.2 示例一:modules 实现 LSTM 分类" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "40e66b21", + "metadata": {}, + "outputs": [], + "source": [ + "# import sys\n", + "# sys.path.append('..')\n", + "\n", + "# from fastNLP.io import SST2Pipe # 没有 SST2Pipe 会运行很长时间,并且还会报错\n", + "\n", + "# databundle = SST2Pipe(tokenizer='raw').process_from_file()\n", + "\n", + "# dataset = databundle.get_dataset('train')[:6000]\n", + "\n", + "# dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + "# progress_bar=\"tqdm\")\n", + "# dataset.delete_field('sentence')\n", + "# dataset.delete_field('label')\n", + "# dataset.delete_field('idx')\n", + "\n", + "# from fastNLP import Vocabulary\n", + "\n", + "# vocab = Vocabulary()\n", + "# vocab.from_dataset(dataset, field_name='words')\n", + "# vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "# train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "50960476", + "metadata": {}, + "outputs": [], + "source": [ + "# from fastNLP import prepare_torch_dataloader\n", + "\n", + "# train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "# evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0b25b25c", + "metadata": {}, + "outputs": [], + "source": [ + "# import torch\n", + "# import torch.nn as nn\n", + "\n", + "# from fastNLP.modules.torch import LSTM, MLP # 没有 MLP\n", + "# from fastNLP import Embedding, CrossEntropyLoss\n", + "\n", + "\n", + "# class ClsByModules(nn.Module):\n", + "# def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + "# nn.Module.__init__(self)\n", + "\n", + "# self.embedding = Embedding((vocab_size, embedding_dim))\n", + "# self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", + "# self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n", + " \n", + "# self.loss_fn = CrossEntropyLoss()\n", + "\n", + "# def forward(self, words):\n", + "# output = self.embedding(words)\n", + "# output, (hidden, cell) = self.lstm(output)\n", + "# output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n", + "# return output\n", + " \n", + "# def train_step(self, words, target):\n", + "# pred = self(words)\n", + "# return {\"loss\": self.loss_fn(pred, target)}\n", + "\n", + "# def evaluate_step(self, words, target):\n", + "# pred = self(words)\n", + "# pred = torch.max(pred, dim=-1)[1]\n", + "# return {\"pred\": pred, \"target\": target}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9dbbf50d", + "metadata": {}, + "outputs": [], + "source": [ + "# model = ClsByModules(vocab_size=len(vocabulary), embedding_dim=100, output_dim=2)\n", + "\n", + "# from torch.optim import AdamW\n", + "\n", + "# optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7a93432f", + "metadata": {}, + "outputs": [], + "source": [ + "# from fastNLP import Trainer, Accuracy\n", + "\n", + "# trainer = Trainer(\n", + "# model=model,\n", + "# driver='torch',\n", + "# device=0, # 'cuda'\n", + "# n_epochs=10,\n", + "# optimizers=optimizers,\n", + "# train_dataloader=train_dataloader,\n", + "# evaluate_dataloaders=evaluate_dataloader,\n", + "# metrics={'acc': Accuracy()}\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "31102e0f", + "metadata": {}, + "outputs": [], + "source": [ + "# trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8bc4bfb2", + "metadata": {}, + "outputs": [], + "source": [ + "# trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "d9443213", + "metadata": {}, + "source": [ + "## 2. fastNLP 中 models 模块的介绍\n", + "\n", + "### 2.1 示例一:models 实现 CNN 分类\n", + "\n", + "  本示例使用`fastNLP 0.8`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n", + "\n", + "模型使用方面,如上所述,这里使用**基于卷积神经网络`CNN`的预定义文本分类模型`CNNText`**,结构如下所示\n", + "\n", + "  首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n", + "\n", + "    **感受野为`1`、`3`、`5`的卷积算子变换至`30`维、`40`维、`50`维的卷积特征**,再将三者拼接\n", + "\n", + "  最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n", + "\n", + "```\n", + "CNNText(\n", + " (embed): Embedding(\n", + " (embed): Embedding(5194, 100)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (conv_pool): ConvMaxpool(\n", + " (convs): ModuleList(\n", + " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n", + " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n", + " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (fc): Linear(in_features=120, out_features=2, bias=True)\n", + ")\n", + "```\n", + "\n", + "数据使用方面,此处**使用`datasets`模块中的`load_dataset`函数**,以如下形式,指定`SST-2`数据集自动加载\n", + "\n", + "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1aa5cf6d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n", + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "70cde65067c64fdba1d5e798e2b8d631", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00[17:45:59] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[17:45:59]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=147745;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=708408;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.575,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 92.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.575\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m92.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 121.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.78125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.80625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 129.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.825,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 132.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.825\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m132.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "f47a6a35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.79, 'total#acc': 900.0, 'correct#acc': 711.0}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7c811257",
+   "metadata": {},
+   "source": [
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "c1a2e2ca",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "342"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gc\n",
+    "\n",
+    "del model\n",
+    "del trainer\n",
+    "del dataset\n",
+    "del sst2data\n",
+    "\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6aec2a19",
+   "metadata": {},
+   "source": [
+    "### 2.2  示例二:models 实现 BiLSTM 标注\n",
+    "\n",
+    "  通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
+    "\n",
+    "    针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
+    "\n",
+    "  避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
+    "\n",
+    "模型使用方面,如上所述,这里使用**基于双向`LSTM`+条件随机场`CRF`的标注模型`BiLSTMCRF`**,结构如下所示\n",
+    "\n",
+    "  其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n",
+    "\n",
+    "```\n",
+    "BiLSTMCRF(\n",
+    "  (embed): Embedding(7590, 100)\n",
+    "  (lstm): LSTM(\n",
+    "    (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
+    "  )\n",
+    "  (dropout): Dropout(p=0.1, inplace=False)\n",
+    "  (fc): Linear(in_features=200, out_features=9, bias=True)\n",
+    "  (crf): ConditionalRandomField()\n",
+    ")\n",
+    "```\n",
+    "\n",
+    "数据使用方面,此处仍然**使用`datasets`模块中的`load_dataset`函数**,以如下形式,加载`CoNLL-2003`数据集\n",
+    "\n",
+    "  首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "03e66686",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "3ec9e0ce9a054339a2453420c2c9f28b",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/3 [00:00[17:49:16] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
+       "\n"
+      ],
+      "text/plain": [
+       "\u001b[2;36m[17:49:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=766109;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=787419;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.220374,\n",
+       "  \"pre#F1\": 0.25,\n",
+       "  \"rec#F1\": 0.197026\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.220374\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.25\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.197026\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.442857,\n",
+       "  \"pre#F1\": 0.426117,\n",
+       "  \"rec#F1\": 0.460967\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.442857\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.426117\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.460967\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.572954,\n",
+       "  \"pre#F1\": 0.549488,\n",
+       "  \"rec#F1\": 0.598513\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.572954\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.549488\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.598513\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.665399,\n",
+       "  \"pre#F1\": 0.680934,\n",
+       "  \"rec#F1\": 0.650558\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.665399\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.680934\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.650558\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.734694,\n",
+       "  \"pre#F1\": 0.733333,\n",
+       "  \"rec#F1\": 0.736059\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.734694\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.733333\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.736059\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.742647,\n",
+       "  \"pre#F1\": 0.734545,\n",
+       "  \"rec#F1\": 0.750929\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.742647\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.734545\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.750929\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.773585,\n",
+       "  \"pre#F1\": 0.785441,\n",
+       "  \"rec#F1\": 0.762082\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.773585\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.785441\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.762082\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.770115,\n",
+       "  \"pre#F1\": 0.794466,\n",
+       "  \"rec#F1\": 0.747212\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.770115\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.794466\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.747212\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.7603,\n",
+       "  \"pre#F1\": 0.766038,\n",
+       "  \"rec#F1\": 0.754647\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.7603\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.766038\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.754647\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.743682,\n",
+       "  \"pre#F1\": 0.722807,\n",
+       "  \"rec#F1\": 0.765799\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.743682\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.722807\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.765799\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "37871d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'f#F1': 0.75283, 'pre#F1': 0.727438, 'rec#F1': 0.780059}"
+      ]
+     },
+     "execution_count": 22,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "96bae094",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorials/fastnlp_tutorial_6.ipynb b/tutorials/fastnlp_tutorial_6.ipynb
new file mode 100644
index 00000000..2052189e
--- /dev/null
+++ b/tutorials/fastnlp_tutorial_6.ipynb
@@ -0,0 +1,59 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fdd7ff16",
+   "metadata": {},
+   "source": [
+    "# T6. fastNLP 与 paddle 或 jittor 的结合\n",
+    "\n",
+    "  1   fastNLP 结合 paddle 训练模型\n",
+    " \n",
+    "    1.1   关于 paddle 的简单介绍\n",
+    "\n",
+    "    1.2   使用 paddle 搭建并训练模型\n",
+    "\n",
+    "  2   fastNLP 结合 jittor 训练模型\n",
+    "\n",
+    "    2.1   关于 jittor 的简单介绍\n",
+    "\n",
+    "    2.2   使用 jittor 搭建并训练模型\n",
+    "\n",
+    "  3   fastNLP 实现 paddle 与 pytorch 互转\n",
+    "\n",
+    "    3.1   \n",
+    "\n",
+    "    3.2   "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "08752c5a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorials/fastnlp_tutorial_7.ipynb b/tutorials/fastnlp_tutorial_7.ipynb
new file mode 100644
index 00000000..0a7d6922
--- /dev/null
+++ b/tutorials/fastnlp_tutorial_7.ipynb
@@ -0,0 +1,59 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fdd7ff16",
+   "metadata": {},
+   "source": [
+    "# T7. callback 自定义训练过程\n",
+    "\n",
+    "  1   \n",
+    " \n",
+    "    1.1   \n",
+    "\n",
+    "    1.2   \n",
+    "\n",
+    "  2   \n",
+    "\n",
+    "    2.1   \n",
+    "\n",
+    "    2.2   \n",
+    "\n",
+    "  3   \n",
+    "\n",
+    "    3.1   \n",
+    "\n",
+    "    3.2   "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "08752c5a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorials/fastnlp_tutorial_8.ipynb b/tutorials/fastnlp_tutorial_8.ipynb
new file mode 100644
index 00000000..0664bc41
--- /dev/null
+++ b/tutorials/fastnlp_tutorial_8.ipynb
@@ -0,0 +1,59 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fdd7ff16",
+   "metadata": {},
+   "source": [
+    "# T8. fastNLP 中的文件读取模块\n",
+    "\n",
+    "  1   fastNLP 中的 EmbedLoader 模块\n",
+    " \n",
+    "    1.1   \n",
+    "\n",
+    "    1.2   \n",
+    "\n",
+    "  2   fastNLP 中的 Loader 模块\n",
+    "\n",
+    "    2.1   \n",
+    "\n",
+    "    2.2   \n",
+    "\n",
+    "  3   fastNLP 中的 Pipe 模块\n",
+    "\n",
+    "    3.1   \n",
+    "\n",
+    "    3.2   "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "08752c5a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorials/fastnlp_tutorial_e1.ipynb b/tutorials/fastnlp_tutorial_e1.ipynb
index 6ec04cb4..b0df8fc1 100644
--- a/tutorials/fastnlp_tutorial_e1.ipynb
+++ b/tutorials/fastnlp_tutorial_e1.ipynb
@@ -6,16 +6,16 @@
    "source": [
     "  从这篇开始,我们将开启**`fastNLP v0.8 tutorial`的`example`系列**,在接下来的\n",
     "\n",
-    "  每篇`tutorial`里,我们将会介绍`fastNLP v0.8`在一些自然语言处理任务上的应用"
+    "  每篇`tutorial`里,我们将会介绍`fastNLP v0.8`在自然语言处理任务上的应用实例"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# E1. 使用 Bert + fine-tuning 完成 SST2 分类\n",
+    "# E1. 使用 Bert + fine-tuning 完成 SST-2 分类\n",
     "\n",
-    "  1   基础介绍:`GLUE`通用语言理解评估、`SST2`文本情感二分类数据集 \n",
+    "  1   基础介绍:`GLUE`通用语言理解评估、`SST-2`文本情感二分类数据集 \n",
     "\n",
     "  2   准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n",
     "\n",
@@ -63,7 +63,7 @@
     "\n",
     "import fastNLP\n",
     "from fastNLP import Trainer\n",
-    "from fastNLP.core.metrics import Accuracy\n",
+    "from fastNLP import Accuracy\n",
     "\n",
     "print(transformers.__version__)"
    ]
@@ -72,19 +72,19 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### 1. 基础介绍:GLUE 通用语言理解评估、SST2 文本情感二分类数据集\n",
+    "### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n",
     "\n",
-    "  本示例使用`GLUE`评估基准中的`SST2`数据集,通过`fine-tuning`方式\n",
+    "  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n",
     "\n",
-    "    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST2`\n",
+    "    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n",
     "\n",
     "**`GLUE`**,**全称`General Language Understanding Evaluation`**,**通用语言理解评估**,\n",
     "\n",
     "  包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n",
     "\n",
-    "    **`CoLA`**,文本分类任务,预测单句语法正误分类;**`SST2`**,文本分类任务,预测单句情感二分类\n",
+    "    **`CoLA`**,文本分类任务,预测单句语法正误分类;**`SST-2`**,文本分类任务,预测单句情感二分类\n",
     "\n",
-    "    **`MRPC`**,句对分类任务,预测句对语义一致性;**`STSB`**,相似度打分任务,预测句对语义相似度回归\n",
+    "    **`MRPC`**,句对分类任务,预测句对语义一致性;**`STS-B`**,相似度打分任务,预测句对语义相似度回归\n",
     "\n",
     "    **`QQP`**,句对分类任务,预测问题对语义一致性;**`MNLI`**,文本推理任务,预测句对蕴含/矛盾/中立预测\n",
     "\n",
@@ -92,7 +92,7 @@
     "\n",
     "  诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n",
     "\n",
-    "    此处,我们使用`SST2`来训练`bert`,实现文本分类,其他任务描述见下图"
+    "    此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图"
    ]
   },
   {
@@ -116,9 +116,9 @@
     "\n",
     "  包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n",
     "\n",
-    "  数据集包括三部分:训练集 67350 条,开发集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
+    "  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n",
     "\n",
-    "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST2`数据集,自动加载\n",
+    "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n",
     "\n",
     "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下"
    ]
@@ -134,14 +134,13 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Using the latest cached version of the module from /remote-home/xrliu/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu May 26 15:30:15 2022) since it couldn't be found locally at glue., or remotely on the Hugging Face Hub.\n",
       "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
      ]
     },
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "adc9449171454f658285f220b70126e1",
+       "model_id": "c5915debacf9443986b5b3b34870b303",
        "version_major": 2,
        "version_minor": 0
       },
@@ -163,7 +162,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "  加载之后,根据`GLUE`中`SST2`数据集的格式,尝试打印部分数据,检查加载结果"
+    "  加载之后,根据`GLUE`中`SST-2`数据集的格式,尝试打印部分数据,检查加载结果"
    ]
   },
   {
@@ -216,15 +215,15 @@
     "\n",
     "  即使用较小的、不区分大小写的数据集,**对`bert-base`进行知识蒸馏后的版本**,结构上\n",
     "\n",
-    "  模型包含1个编码层、6个自注意力层,详解见本篇末尾,更多细节请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n",
+    "  包含**1个编码层**、**6个自注意力层**,**参数量`66M`**,详解见本篇末尾,更多请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n",
     "\n",
-    "首先,通过从`transformers`库中导入`AutoTokenizer`模块,使用`from_pretrained`函数初始化\n",
+    "首先,通过从`transformers`库中导入**`AutoTokenizer`模块**,**使用`from_pretrained`函数初始化**\n",
     "\n",
     "  此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n",
     "\n",
-    "  需要注意的是,处理后返回的两个键值,`'input_ids'`表示原始文本对应的词素编号序列\n",
+    "  需要注意的是,处理后返回的两个键值,**`'input_ids'`**表示原始文本对应的词素编号序列\n",
     "\n",
-    "    `'attention_mask'`表示自注意力运算时的掩模(标上`0`的部分对应`padding`的内容"
+    "    **`'attention_mask'`**表示自注意力运算时的掩模(标上`0`的部分对应`padding`的内容"
    ]
   },
   {
@@ -287,7 +286,7 @@
     "\n",
     "  其中,**`__getitem__`函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n",
     "\n",
-    "  例如,`'label'`是`SST2`数据集中原有的内容(包括`'sentence'`和`'label'`\n",
+    "  例如,`'label'`是`SST-2`数据集中原有的内容(包括`'sentence'`和`'label'`\n",
     "\n",
     "    `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段"
    ]
@@ -440,10 +439,10 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight']\n",
+      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n",
       "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
       "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
-      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'classifier.bias', 'pre_classifier.weight']\n",
+      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n",
       "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
      ]
     }
@@ -472,7 +471,7 @@
     "trainer = Trainer(\n",
     "    model=model,\n",
     "    driver='torch',\n",
-    "    device=1,  # 'cuda'\n",
+    "    device=0,  # 'cuda'\n",
     "    n_epochs=10,\n",
     "    optimizers=optimizers,\n",
     "    train_dataloader=dataloader_train,\n",
@@ -495,6 +494,33 @@
    "execution_count": 13,
    "metadata": {},
    "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
[09:12:45] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -505,6 +531,490 @@ "metadata": {}, "output_type": "display_data" }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.878125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 281.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.9,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 288.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 284.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.88125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 282.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 280.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.865625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 277.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.878125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 281.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ @@ -540,10 +1050,14 @@ "outputs": [ { "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
      },
      "metadata": {},
      "output_type": "display_data"
@@ -561,7 +1075,7 @@
     {
      "data": {
       "text/plain": [
-       "{'acc#acc': 0.87156, 'total#acc': 872.0, 'correct#acc': 760.0}"
+       "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}"
       ]
      },
      "execution_count": 14,
diff --git a/tutorials/fastnlp_tutorial_e2.ipynb b/tutorials/fastnlp_tutorial_e2.ipynb
index 93143090..ca0b4754 100644
--- a/tutorials/fastnlp_tutorial_e2.ipynb
+++ b/tutorials/fastnlp_tutorial_e2.ipynb
@@ -4,7 +4,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "# E2. 使用 Bert + prompt 完成 SST2 分类\n",
+    "# E2. 使用 Bert + prompt 完成 SST-2 分类\n",
     "\n",
     "  1   基础介绍:`prompt-based model`简介、与`fastNLP`的结合\n",
     "\n",
@@ -19,37 +19,71 @@
    "source": [
     "### 1. 基础介绍:prompt-based model 简介、与 fastNLP 的结合\n",
     "\n",
-    "  本示例使用`GLUE`评估基准中的`SST2`数据集,通过`prompt-based tuning`方式\n",
+    "  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`prompt-based tuning`方式\n",
     "\n",
     "    微调`bert-base-uncased`模型,实现文本情感的二分类,在此之前本示例\n",
     "\n",
     "    将首先简单介绍提示学习模型的研究,以及与`fastNLP v0.8`结合的优势\n",
     "\n",
-    "**`prompt`**,**提示词、提词器**,最早出自**`PET`**,\n",
+    "**`prompt`**,**提示词**,最早出自论文[Exploiting Cloze Questions for Few Shot TC and NLI](https://arxiv.org/pdf/2001.07676.pdf)中的**`PET`模型**\n",
     "\n",
-    "  \n",
+    "    全称 **`Pattern-Exploiting Training`**,虽然文中并没有提到`prompt`的说法,但仍被视为开山之作\n",
     "\n",
-    "**`prompt-based tuning`**,**基于提示的微调**,描述\n",
+    "  其大致思路包括,对于文本分类任务,假定输入文本为`\" X . \"`,设计**输入模板`template`**,**后来被称为`prompt`**\n",
     "\n",
-    "  **`prompt-based model`**,**基于提示的模型**\n",
+    "    将输入重构为`\" X . It is [MASK] . \"`,**诱导或刺激语言模型在`[MASK]`位置生成含有情感倾向的词汇**\n",
     "\n",
-    "**`prompt-based model`**,**基于提示的模型**,举例\n",
+    "    接着将该词汇**输入分类器中**,**后来被称为`verbalizer`**,从而得到该语句对应的情感倾向,实现文本分类\n",
     "\n",
-    "  案例一:**`P-Tuning v1`**\n",
+    "  其主要贡献在于,通过构造`prompt`,诱导/刺激预训练模型生成期望适应下游任务特征,适合少样本学习的需求\n",
     "\n",
-    "  案例二:**`PromptTuning`**\n",
+    "\n",
     "\n",
-    "  案例三:**`PrefixTuning`**\n",
+    "**`prompt-based tuning`**,**基于提示的微调**,将`prompt`应用于**参数高效微调**,**`parameter-efficient tuning`**\n",
     "\n",
-    "  案例四:**`SoftPrompt`**\n",
+    "  通过**设计模板调整模型输入**或者**调整模型内部状态**,**固定预训练模型**,**诱导/刺激模型**调整输出以适应\n",
     "\n",
-    "使用`fastNLP v0.8`实现`prompt-based model`的优势\n",
+    "  当前任务,极大降低了训练开销,也省去了`verbalizer`的构造,更多参考[prompt综述](https://arxiv.org/pdf/2107.13586.pdf)、[DeltaTuning综述](https://arxiv.org/pdf/2203.06904.pdf)\n",
     "\n",
-    "  \n",
+    "    以下列举些经典的`prompt-based tuning`案例,简单地介绍下`prompt-based tuning`的脉络\n",
     "\n",
-    "  本示例仍使用了`tutorial-E1`的`SST2`数据集,将`bert-base-uncased`作为基础模型\n",
+    "  **案例一**:**`PrefixTuning`**,详细内容参考[PrefixTuning论文](https://arxiv.org/pdf/2101.00190.pdf)\n",
     "\n",
-    "    在后续实现中,意图通过将连续的`prompt`与`model`拼接,解决`SST2`二分类任务"
+    "    其主要贡献在于,**提出连续的、非人工构造的、任务导向的`prompt`**,即**前缀`prefix`**,**调整**\n",
+    "\n",
+    "      **模型内部更新状态**,诱导模型在特定任务下生成期望目标,降低优化难度,提升微调效果\n",
+    "\n",
+    "    其主要研究对象,是`GPT2`和`BART`,主要面向生成任务`NLG`,如`table-to-text`和摘要\n",
+    "\n",
+    "  **案例二**:**`P-Tuning v1`**,详细内容参考[P-Tuning-v1论文](https://arxiv.org/pdf/2103.10385.pdf)\n",
+    "\n",
+    "    其主要贡献在于,**通过连续的、非人工构造的`prompt`调整模型输入**,取代原先基于单词设计的\n",
+    "\n",
+    "      但离散且不易于优化的`prompt`;同时也**证明了`GPT2`在语言理解任务上仍然是可以胜任的**\n",
+    "\n",
+    "    其主要研究对象,是`GPT2`,主要面向知识探测`knowledge probing`和自然语言理解`NLU`\n",
+    "\n",
+    "  **案例三**:**`PromptTuning`**,详细内容参考[PromptTuning论文](https://arxiv.org/pdf/2104.08691.pdf)\n",
+    "\n",
+    "    其主要贡献在于,通过连续的`prompt`调整模型输入,**证明了`prompt-based tuning`的效果**\n",
+    "\n",
+    "      **随模型参数量的增加而提升**,最终**在`10B`左右追上了全参数微调`fine-tuning`的效果**\n",
+    "\n",
+    "    其主要面向自然语言理解`NLU`,通过为每个任务定义不同的`prompt`,从而支持多任务语境\n",
+    "\n",
+    "通过上述介绍可以发现`prompt-based tuning`只是模型微调方式,独立于预训练模型基础`backbone`\n",
+    "\n",
+    "  目前,加载预训练模型的主流方法是使用**`transformers`模块**,而实现微调的框架则\n",
+    "\n",
+    "    可以是`pytorch`、`paddle`、`jittor`等,而不同框架间又存在不兼容的问题\n",
+    "\n",
+    "  因此,**使用`fastNLP v0.8`实现`prompt-based tuning`**,可以**很好地解决`paddle`等框架**\n",
+    "\n",
+    "    **和`transformers`模块之间的桥接**(`transformers`模块基于`pytorch`实现)\n",
+    "\n",
+    "本示例仍使用了`tutorial-E1`的`SST-2`数据集、`distilbert-base-uncased`模型(便于比较\n",
+    "\n",
+    "  使用`pytorch`框架,通过将连续的`prompt`与`model`拼接,解决`SST-2`二分类任务"
    ]
   },
   {
@@ -98,7 +132,7 @@
     "print(transformers.__version__)\n",
     "\n",
     "task = 'sst2'\n",
-    "model_checkpoint = 'bert-base-uncased'"
+    "model_checkpoint = 'distilbert-base-uncased'  # 'bert-base-uncased'"
    ]
   },
   {
@@ -111,20 +145,32 @@
     "\n",
     "    以下首先简述`P-Tuning v2`的论文原理,并由此引出`fastNLP v0.8`的代码实践\n",
     "\n",
-    "`P-Tuning v2`出自论文 [Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n",
+    "**`P-Tuning v2`**出自论文[Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n",
+    "\n",
+    "  其主要贡献在于,**在`PrefixTuning`等深度提示学习基础上**,**提升了其在分类标注等`NLU`任务的表现**\n",
+    "\n",
+    "    并使之在中等规模模型,主要是**参数量在`100M-1B`区间的模型上**,**获得与全参数微调相同的效果**\n",
+    "\n",
+    "  其结构如图所示,通过**在输入序列的分类符`[CLS]`之前**,**加入前缀序列**(**序号对应嵌入是待训练的连续值向量**\n",
     "\n",
-    "  其主要贡献在于,在`PrefixTuning`等深度提示学习基础上,提升了其在分类标注等`NLU`任务的表现\n",
+    "    **刺激模型在新任务下**,从`[CLS]`对应位置,**输出符合微调任务的输出**,从而达到适应微调任务的目的\n",
     "\n",
-    "    并使之在中等规模模型,主要是参数量在`100M-1B`区间的模型上,获得与全参数微调相同的效果\n",
+    "\n",
     "\n",
-    "  其结构如图所示,\n",
+    "本示例使用`bert-base-uncased`模型,作为`P-Tuning v2`的基础`backbone`,设置`requires_grad=False`\n",
     "\n",
-    ""
+    "    固定其参数不参与训练,**设置`pre_seq_len`长的`prefix_tokens`作为输入的提示前缀序列**\n",
+    "\n",
+    "  **使用基于`nn.Embedding`的`prefix_encoder`为提示前缀嵌入**,通过`get_prompt`函数获取,再将之\n",
+    "\n",
+    "    拼接至批量内每笔数据前得到`inputs_embeds`,同时更新自注意力掩模`attention_mask`\n",
+    "\n",
+    "  将`inputs_embeds`、`attention_mask`和`labels`输入`backbone`,**得到输出包括`loss`和`logits`**"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -178,24 +224,24 @@
    "source": [
     "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器\n",
     "\n",
-    "  根据`P-Tuning v2`论文:*Generally, simple classification tasks prefer shorter prompts (less than 20)*\n",
+    "  根据`P-Tuning v2`论文:*`Generally, simple classification tasks prefer shorter prompts (less than 20)`*\n",
     "\n",
     "  此处`pre_seq_len`参数设定为`20`,学习率相应做出调整,其他内容和`tutorial-E1`中的内容一致"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
-      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
-      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
-      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
+      "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n",
+      "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+      "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+      "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n",
       "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
      ]
     }
@@ -212,7 +258,7 @@
    "source": [
     "### 3. 模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n",
     "\n",
-    "  本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST2`数据集\n",
+    "  本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST-2`数据集\n",
     "\n",
     "    以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n",
     "\n",
@@ -225,7 +271,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 4,
    "metadata": {
     "scrolled": false
    },
@@ -240,7 +286,7 @@
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "b72eeebd34354a88a99b2e07ec9a86df",
+       "model_id": "21cbd92c3397497d84dc10f017ec96f4",
        "version_major": 2,
        "version_minor": 0
       },
@@ -262,30 +308,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-18ec0e709f05e61e.arrow\n",
-      "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-e2f02ee7442ad73e.arrow\n"
+      "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n",
+      "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n",
+      "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f44c5576b89f9e6b.arrow\n"
      ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "d15505d825b34f649b719f1ff0d56114",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "  0%|          | 0/2 [00:00[22:53:00] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
+       "\n"
+      ],
+      "text/plain": [
+       "\u001b[2;36m[22:53:00]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=406635;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=951504;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.540625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 173.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.540625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m173.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.5,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 160.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.5\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.509375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 163.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.509375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m163.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.634375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 203.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.634375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m203.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.6125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 196.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.6125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m196.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 216.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m216.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.64375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 206.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.64375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m206.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.665625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 213.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.665625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m213.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.659375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 211.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.659375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m211.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.696875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 223.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.696875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m223.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "trainer.run(num_eval_batch_per_dl=10)" ] @@ -421,14 +988,55 @@ "cell_type": "markdown", "metadata": {}, "source": [ - " " + "可以发现,其效果远远逊色于`fine-tuning`,这是因为`P-Tuning v2`虽然能够适应参数量\n", + "\n", + "  在`100M-1B`区间的模型,但是,**`distilbert-base`的参数量仅为`66M`**,无法触及其下限\n", + "\n", + "另一方面,**`fastNLP v0.8`不支持`jupyter`多卡**,所以无法在笔者的电脑/服务器上,完成\n", + "\n", + "  合适规模模型的学习,例如`110M`的`bert-base`模型,以及`340M`的`bert-large`模型" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.737385, 'total#acc': 872.0, 'correct#acc': 643.0}"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "trainer.evaluator.run()"
    ]
diff --git a/tutorials/figures/E2-fig-pet-model.png b/tutorials/figures/E2-fig-pet-model.png
new file mode 100644
index 00000000..c3c377c0
Binary files /dev/null and b/tutorials/figures/E2-fig-pet-model.png differ