
Improved the documentation for trainer, callback, etc.; renamed parts of the code so that they are hidden from the documentation

tags/v0.4.10
yh_cc committed 6 years ago (commit 6e265e5ae9)
13 changed files with 373 additions and 351 deletions
  1. +2 -2 fastNLP/automl/enas_trainer.py
  2. +115 -98 fastNLP/core/callback.py
  3. +3 -3 fastNLP/core/dataset.py
  4. +35 -57 fastNLP/core/fieldarray.py
  5. +14 -15 fastNLP/core/instance.py
  6. +15 -15 fastNLP/core/losses.py
  7. +25 -25 fastNLP/core/metrics.py
  8. +6 -6 fastNLP/core/tester.py
  9. +34 -57 fastNLP/core/trainer.py
  10. +107 -56 fastNLP/core/utils.py
  11. +13 -13 fastNLP/core/vocabulary.py
  12. +3 -3 fastNLP/models/enas_trainer.py
  13. +1 -1 test/core/test_tester.py

+2 -2 fastNLP/automl/enas_trainer.py

@@ -11,7 +11,7 @@ import torch
try:
from tqdm.autonotebook import tqdm
except:
from fastNLP.core.utils import pseudo_tqdm as tqdm
from fastNLP.core.utils import _pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackException
@@ -115,7 +115,7 @@ class ENASTrainer(fastNLP.Trainer):

def _train(self):
if not self.use_tqdm:
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
self.step = 0


+115 -98 fastNLP/core/callback.py

@@ -2,21 +2,21 @@

.. _Callback:

"""
Callback is the class in fastNLP designed to extend Trainer_ . If a Callback is passed to a Trainer_ , the Trainer_ will call the callback's
functions at the corresponding stages; see Trainer_ for the exact timing of each call.

"""

import os

import torch
from tensorboardX import SummaryWriter

from fastNLP.io.model_io import ModelSaver, ModelLoader

try:
from tensorboardX import SummaryWriter
except:
pass

class Callback(object):
"""An Interface for all callbacks.

Any customized callback should implement at least one of the following methods.
"""这是Callback的基类,所有的callback必须继承自这个类。

"""

@@ -26,93 +26,150 @@ class Callback(object):

@property
def trainer(self):
"""
The Trainer object can be accessed via self.trainer; in most cases this property is not needed.
:return:
"""
return self._trainer

@property
def step(self):
"""current step number, in range(1, self.n_steps+1)"""
"""当前运行到的step, 范围为[1, self.n_steps+1)"""
return self._trainer.step

@property
def n_steps(self):
"""total number of steps for training"""
"""Trainer一共会运行多少步"""
return self._trainer.n_steps

@property
def batch_size(self):
"""batch size for training"""
"""train和evaluate时的batch_size为多大"""
return self._trainer.batch_size

@property
def epoch(self):
"""current epoch number, in range(1, self.n_epochs+1)"""
"""当前运行的epoch数,范围是[1, self.n_epochs+1)"""
return self._trainer.epoch

@property
def n_epochs(self):
"""total number of epochs"""
"""一共会运行多少个epoch"""
return self._trainer.n_epochs

@property
def optimizer(self):
"""torch.optim.Optimizer for current model"""
"""初始化Trainer时传递的Optimizer"""
return self._trainer.optimizer

@property
def model(self):
"""training model"""
"""正在被Trainer训练的模型"""
return self._trainer.model

@property
def pbar(self):
"""If use_tqdm, return trainer's tqdm print bar, else return None."""
"""如果在Callback中需要打印内容,请使用self.pbar.write(str)。否则可能出现命令行显示效果不太好的问题。"""
return self._trainer.pbar

@property
def update_every(self):
"""The model in trainer will update parameters every `update_every` batches."""
"""Trainer中的模型多少次反向传播才进行一次梯度更新,在Trainer初始化时传入的。"""
return self._trainer.update_every

@property
def batch_per_epoch(self):
"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。"""
return self._trainer.batch_per_epoch

def on_train_begin(self):
# before the main training loop
"""
Called before the training process begins.

:return:
"""
pass

def on_epoch_begin(self):
# at the beginning of each epoch
"""
Called once before each epoch begins

:return:
"""
pass

def on_batch_begin(self, batch_x, batch_y, indices):
# at the beginning of each step/mini-batch
"""
Called once each time a batch of data has been drawn. Deleting from or adding to batch_x or batch_y here does affect
what the Trainer sees, so operations such as negative sampling can be performed at this step

:param dict batch_x: batch of the fields set as input in the DataSet.
:param dict batch_y: batch of the fields set as target in the DataSet.
:param list(int) indices: the indices used by this sampling; DataSet[indices] retrieves the Instances drawn for this
batch, which in some cases helps to locate which sample caused an error.
:return:
"""
pass

def on_loss_begin(self, batch_y, predict_y):
# after data_forward, and before loss computation
"""
Called before the loss is computed, i.e. modifying the values of batch_y or predict_y here does affect the loss computation.

:param dict batch_y: batch of the fields set as target in the DataSet.
:param dict predict_y: the result returned by the model's forward().
:return:
"""
pass

def on_backward_begin(self, loss):
# after loss computation, and before gradient backward
"""
Called after the loss is obtained but before backpropagation. A check for NaN loss, for instance, could be done here.

:param torch.Tensor loss: the computed loss value
:return:
"""
pass

def on_backward_end(self):
"""
Backpropagation of the gradients has finished, but because of the update_every setting there may not be gradients on every call. At this step, the parameters have not been updated yet.

:return:
"""
pass

def on_step_end(self):
"""
At this point the model's parameters have been updated according to the gradients, though under the influence of update_every, not every call performs an update.

:return:
"""
pass

def on_batch_end(self, *args):
# at the end of each step/mini-batch
def on_batch_end(self):
"""
This step immediately follows on_step_end; it was added only for symmetry.

"""
pass

def on_valid_begin(self):
"""
If validation is configured in the Trainer, this function is called before validation takes place

:return:
"""
pass

def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
"""
Called after each evaluation on the validation set. eval_result is passed in
Called after each evaluation on the validation set.

:param eval_result: Dict[str: Dict[str: float]], the result of the evaluation
:param metric_key: str
:param optimizer: optimizer passed to trainer
:param is_better_eval: bool, whether the current dev result is better than the previous one
:param Dict[str: Dict[str: float]] eval_result: the result of the evaluation. An example is {'AccuracyMetric':{'acc':1.0}},
i.e. the dict passed in has two levels: the first level is the metric's name, the second level the metric's concrete measures.
:param str metric_key: the metric_key passed when the Trainer was initialized.
:param torch.Optimizer optimizer: the optimizer used in the Trainer.
:param bool is_better_eval: whether the current dev result is better than the previous one.
:return:
"""
pass
@@ -137,7 +194,7 @@ class Callback(object):
pass


def transfer(func):
def _transfer(func):
"""装饰器,将对CallbackManager的调用转发到各个Callback子类.
:param func:
:return:
@@ -153,9 +210,7 @@ def transfer(func):


class CallbackManager(Callback):
"""A manager for all callbacks passed into Trainer.
It collects resources inside Trainer and raise callbacks.

"""内部使用的Callback管理类
"""

def __init__(self, env, callbacks=None):
@@ -182,104 +237,70 @@ class CallbackManager(Callback):
for callback in self.callbacks:
setattr(callback, '_'+env_name, env_val) # Callback.trainer

@transfer
@_transfer
def on_train_begin(self):
pass

@transfer
@_transfer
def on_epoch_begin(self):
pass

@transfer
@_transfer
def on_batch_begin(self, batch_x, batch_y, indices):
pass

@transfer
@_transfer
def on_loss_begin(self, batch_y, predict_y):
pass

@transfer
@_transfer
def on_backward_begin(self, loss):
pass

@transfer
@_transfer
def on_backward_end(self):
pass

@transfer
@_transfer
def on_step_end(self):
pass

@transfer
@_transfer
def on_batch_end(self):
pass

@transfer
@_transfer
def on_valid_begin(self):
pass

@transfer
@_transfer
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
pass

@transfer
@_transfer
def on_epoch_end(self):
pass

@transfer
@_transfer
def on_train_end(self):
pass

@transfer
@_transfer
def on_exception(self, exception):
pass


class DummyCallback(Callback):
def on_train_begin(self, *arg):
print(arg)

def on_epoch_end(self):
print(self.epoch, self.n_epochs)


class EchoCallback(Callback):
def on_train_begin(self):
print("before_train")

def on_epoch_begin(self):
print("before_epoch")

def on_batch_begin(self, batch_x, batch_y, indices):
print("before_batch")

def on_loss_begin(self, batch_y, predict_y):
print("before_loss")

def on_backward_begin(self, loss):
print("before_backward")

def on_batch_end(self):
print("after_batch")

def on_epoch_end(self):
print("after_epoch")

def on_train_end(self):
print("after_train")


class GradientClipCallback(Callback):
def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
"""每次backward前,将parameter的gradient clip到某个范围。

:param parameters: None, torch.Tensor或List[torch.Tensor], 一般通过model.parameters()获得。如果为None则默认对Trainer
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。如果为None则默认对Trainer
的model中所有参数进行clip
:param clip_value: float, 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
:param clip_type: str, 支持'norm', 'value'两种。
(1) 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
(2) 'value', 将gradient限制在[-clip_value, clip_value], 小于-clip_value的gradient被赋值为-clip_value; 大于
clip_value的gradient被赋值为clip_value.
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数
:param str clip_type: 支持'norm', 'value'两种。
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value]
2. 'value', 将gradient限制在[-clip_value, clip_value], 小于-clip_value的gradient被赋值为-clip_value; 大于
clip_value的gradient被赋值为clip_value.
"""
super().__init__()
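
A usage sketch of the two clip modes described above (constructor arguments as documented in this hunk; how the
callback is handed to the Trainer is an assumption)::

    from fastNLP.core.callback import GradientClipCallback

    # clamp every gradient entry to [-5, 5] before each update
    clip_by_value = GradientClipCallback(clip_value=5, clip_type='value')
    # or rescale the gradient norm instead (the default mode)
    clip_by_norm = GradientClipCallback(clip_value=1, clip_type='norm')
    # trainer = Trainer(..., callbacks=[clip_by_value])  # assumed wiring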

@@ -314,7 +335,7 @@ class EarlyStopCallback(Callback):
def __init__(self, patience):
"""

:param int patience: the number of epochs to wait before stopping
:param int patience: stop training after this many epochs without improvement
"""
super(EarlyStopCallback, self).__init__()
self.patience = patience
@@ -341,7 +362,7 @@ class LRScheduler(Callback):
def __init__(self, lr_scheduler):
"""对PyTorch LR Scheduler的包装

:param lr_scheduler: PyTorch的lr_scheduler
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler
"""
super(LRScheduler, self).__init__()
import torch.optim
@@ -358,7 +379,7 @@ class ControlC(Callback):
def __init__(self, quit_all):
"""

:param quit_all: if True, exit the program directly when Ctrl+C is detected; otherwise only exit the Trainer
:param bool quit_all: if True, exit the program directly when Ctrl+C is detected; otherwise only exit the Trainer
"""
super(ControlC, self).__init__()
if type(quit_all) != bool:
@@ -389,16 +410,16 @@ class SmoothValue(object):


class LRFinder(Callback):
def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
def __init__(self, start_lr=1e-6, end_lr=10):
"""用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它

:param n_batch: 一个epoch内的iteration数
:param start_lr: 学习率下界
:param end_lr: 学习率上界
:param int n_batch: 一个epoch内的iteration数
:param float start_lr: 学习率下界
:param float end_lr: 学习率上界
"""
super(LRFinder, self).__init__()
self.start_lr, self.end_lr = start_lr, end_lr
self.num_it = n_batch
self.num_it = self.batch_per_epoch
self.stop = False
self.best_loss = 0.
self.best_lr = None
@@ -514,7 +535,3 @@ class TensorboardCallback(Callback):
del self._summary_writer


if __name__ == "__main__":
manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
manager.on_train_begin()
# print(manager.after_epoch())
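
For illustration, the hooks documented in this file compose into custom callbacks like the sketch below (written
against the 0.4.x Callback API shown in this diff; the loss threshold and the `callbacks=` Trainer argument are
assumptions)::

    from fastNLP.core.callback import Callback

    class LossWatcher(Callback):
        """Warn through the trainer's progress bar whenever the loss spikes."""
        def __init__(self, threshold=100.0):  # hypothetical threshold
            super().__init__()
            self.threshold = threshold

        def on_backward_begin(self, loss):
            # loss is the computed torch.Tensor; self.pbar.write() is the
            # documented way to print without breaking the tqdm bar.
            if loss.item() > self.threshold:
                self.pbar.write("step {}: large loss {:.4f}".format(self.step, loss.item()))

    # trainer = Trainer(..., callbacks=[LossWatcher()])  # assumed wiring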

+3 -3 fastNLP/core/dataset.py

@@ -277,7 +277,7 @@ import warnings
from fastNLP.core.fieldarray import AutoPadder
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _get_func_signature

class DataSet(object):
"""fastNLP的数据容器
@@ -642,7 +642,7 @@ class DataSet(object):
print("Exception happens at the `{}`th instance.".format(idx))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))

if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)
@@ -707,7 +707,7 @@ class DataSet(object):
raise e
# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))

if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)
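
Both hunks guard the same contract of DataSet.apply(): when new_field_name is given, func must return a non-None
value for at least one instance. A small sketch of the normal case, mirroring the apply() usage shown elsewhere in
this commit (field names invented for illustration)::

    from fastNLP.core.dataset import DataSet

    ds = DataSet({'raw': ['a b c', 'd e']})
    # each returned list becomes the value of the new field 'words'
    ds.apply(lambda ins: ins['raw'].split(), new_field_name='words')
    # a func that returned None for every instance would raise the ValueError above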


+35 -57 fastNLP/core/fieldarray.py

@@ -1,5 +1,5 @@
"""
FieldArray is how one column of a DataSet_ is stored
FieldArray is how one column of a DataSet_ is stored; for the underlying principles, see DataSet_

.. _FieldArray:

@@ -11,41 +11,19 @@ from copy import deepcopy


class FieldArray(object):
"""``FieldArray`` is the collection of ``Instance``s of the same field.
It is the basic element of ``DataSet`` class.

:param str name: the name of the FieldArray
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
:param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used to the model input.
:param Padder padder: of type PadderBase. The padder object assigned to the fieldarray is deepcopied; to change the
padder's parameters you must go through fieldarray.set_pad_val().
Defaults to None: (1) if a field is a scalar, no padding is done; (2) if it is a one-dimensional list and the
fieldarray's dtype is float or int, padding is done; (3) in other cases no padding is applied.
If, say, the characters of English words need padding, a different padder must be used,
as when ignore_type is True but padding is still needed.
:param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray.
(default: False)
"""

def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
"""DataSet在初始化时会有两类方法对FieldArray操作:
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])})
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
1.4) list of array: DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]})
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray;
然后后面的样本使用FieldArray.append进行添加。
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])])
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))])
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])

类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。
ignore_type用来控制是否进行类型检查,如果为True,则不检查。

"""FieldArray是用于保存 DataSet_ 中一个field的实体。

:param str name: FieldArray的名称
:param list,numpy.ndarray content: 列表的元素可以为list,int,float,
:param bool is_target: 这个field是否是一个target field。
:param bool is_input: 这个field是否是一个input field。
:param Padder padder: PadderBase类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过
fieldarray.set_pad_val()。默认为None,即使用 AutoPadder_ 。
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, 就
可以设置为True。具体意义请参考 DataSet_ 。
"""

self.name = name
if isinstance(content, list):
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list
@@ -211,10 +189,10 @@ class FieldArray(object):
return "FieldArray {}: {}".format(self.name, self.content.__repr__())

def append(self, val):
"""将val增加到FieldArray中,若该field的ignore_type为True则直接append到这个field中;若ignore_type为False,且当前field为
input或者target,则会检查传入的content是否与之前的内容在dimension, 元素的类型上是匹配的。
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有
的内容是匹配的。

:param val: Any.
:param Any val: 需要append的值。
"""
if self.ignore_type is False:
if isinstance(val, list):
@@ -262,8 +240,8 @@ class FieldArray(object):
def get(self, indices, pad=True):
"""根据给定的indices返回内容

:param indices: (int, List[int]), 获取indices对应的内容。
:param pad: bool, 是否对返回的结果进行padding。仅对indices为List[int]时有效
:param int,list(int) indices:, 获取indices对应的内容。
:param bool pad: , 是否对返回的结果进行padding。仅对indices为List[int]时有效
:return: (single, List)
"""
if isinstance(indices, int):
@@ -281,7 +259,7 @@ class FieldArray(object):
"""
Set the padder; when this field is padded, this padder does the padding. If None, no padding is performed.

:param padder: (None, Padder). Set to None to remove the padder.
:param None,Padder padder: set to None to remove the padder.
:return:
"""
if padder is not None:
@@ -293,7 +271,7 @@ class FieldArray(object):
def set_pad_val(self, pad_val):
"""修改padder的pad_val.

:param pad_val: int。将该field的pad值设置为该值
:param int pad_val: 该field的pad值设置为该值
:return:
"""
if self.padder is not None:
@@ -312,8 +290,8 @@ class FieldArray(object):
"""
Copy other's attributes to this FieldArray (other must be of type FieldArray). The attributes include is_input, is_target, padder, ignore_type

:param other: FieldArray
:return:
:param FieldArray other: the field to copy attributes from
:return: FieldArray
"""
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))

@@ -324,7 +302,7 @@ class FieldArray(object):

return self

def is_iterable(content):
def _is_iterable(content):
try:
_ = (e for e in content)
except TypeError:
@@ -350,11 +328,10 @@ class Padder:
"""
The content passed in is a List. Suppose we have the following DataSet.

:param contents: List[element]. The elements passed in are used in place, i.e. modifying an element directly may
:param list(Any) contents: the elements passed in are used in place, i.e. modifying an element directly may
change the data; deepcopying before in-place modification is recommended.
:param field_name: str, the name of the field.
:param field_ele_dtype: (np.int64, np.float64, np.str, None), the type of the field's inner elements. If the field's
ignore_type is True, this value is None.
:param str field_name: the name of the field.
:param np.int64,np.float64,np.str,None field_ele_dtype: the type of the field's inner elements. If the field's ignore_type is True, this value is None.
:return: np.array([padded_element])

Example::
@@ -400,10 +377,10 @@ class AutoPadder(Padder):

2 If the element type is (np.int64, np.float64),

2.1 if the field's content is (np.int64, np.float64), e.g. a seq_len, no padding is performed
2.1 if the field's content is (np.int64, np.float64), e.g. a seq_len, no padding is performed

2.2 if the field's content is a List, the Lists in a Batch are padded to equal length. If the inner Lists also need
padding, use a different padder. If an instance's field is [1, 2, 3] it can be padded; if it is [[1,2], [3,4, ...]] it cannot be
2.2 if the field's content is a List, the Lists in a Batch are padded to equal length. If the inner Lists also need
padding, use a different padder. That is, if an Instance's field looks like [1, 2, 3, ...] it can be padded; if it is [[1,2], [3,4, ...]] it cannot be
"""

def __init__(self, pad_val=0):
@@ -427,7 +404,7 @@ class AutoPadder(Padder):
return False

def __call__(self, contents, field_name, field_ele_dtype):
if not is_iterable(contents[0]):
if not _is_iterable(contents[0]):
array = np.array([content for content in contents], dtype=field_ele_dtype)
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents):
max_len = max([len(content) for content in contents])
@@ -454,7 +431,7 @@ class EngChar2DPadder(Padder):
Example::

from fastNLP import DataSet
from fastNLP import EnChar2DPadder
from fastNLP import EngChar2DPadder
from fastNLP import Vocabulary
dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']})
dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars')
@@ -462,14 +439,15 @@ class EngChar2DPadder(Padder):
vocab.from_dataset(dataset, field_name='chars')
vocab.index_dataset(dataset, field_name='chars')
dataset.set_input('chars')
padder = EnChar2DPadder()
padder = EngChar2DPadder()
dataset.set_padder('chars', padder) # the padder of the 'chars' field is set to EngChar2DPadder

"""
def __init__(self, pad_val=0, pad_length=0):
"""
:param pad_val: int, padded positions are filled with this index
:param pad_length: int, if 0, the longest word length within a batch is used as the padding length. If a number
greater than 0, all words are padded or truncated to that length.
:param pad_length: int, if 0, the longest word length within a batch is used as the padding length. If a number
greater than 0, all words are padded or truncated to that length.
"""
super().__init__(pad_val=pad_val)

@@ -494,7 +472,7 @@ class EngChar2DPadder(Padder):
except:
raise ValueError("Field:{} only has two dimensions.".format(field_name))

if is_iterable(value):
if _is_iterable(value):
raise ValueError("Field:{} has more than 3 dimension.".format(field_name))

def __call__(self, contents, field_name, field_ele_dtype):
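
A sketch of rule 2.2 above, using the constructor and get() signatures shown in this diff (field name invented for
illustration)::

    from fastNLP.core.fieldarray import FieldArray, AutoPadder

    fa = FieldArray('tokens', [[1, 2], [3, 4, 5]], is_input=True, padder=AutoPadder(pad_val=0))
    # the shorter list is padded with pad_val up to the batch's maximum length
    print(fa.get([0, 1]))  # expected: [[1, 2, 0], [3, 4, 5]]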


+14 -15 fastNLP/core/instance.py

@@ -3,34 +3,33 @@ Instance documentation

.. _Instance:

Test
Instance is the class in fastNLP that corresponds to one sample; a sample can be regarded as one Instance object in
fastNLP. A concrete picture of it is one row of the table shown under DataSet_ .

"""



class Instance(object):
"""An Instance is an example of data.
Example::
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
ins["field_1"]
>>[1, 1, 1]
ins.add_field("field_3", [3, 3, 3])
"""

def __init__(self, **fields):
"""
"""Instance的初始化如下面的Example所示

Example::

ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
ins["field_1"]
>>[1, 1, 1]
ins.add_field("field_3", [3, 3, 3])

:param fields: may be a one- or two-dimensional list or np.array
ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))})
"""
self.fields = fields

def add_field(self, field_name, field):
"""Add a new field to the instance.
"""向Instance中增加一个field

:param field_name: str, the name of the field.
:param str field_name: the name of the new field
:param Any field: the content of the new field
"""
self.fields[field_name] = field



+15 -15 fastNLP/core/losses.py

@@ -12,12 +12,12 @@ from collections import defaultdict
import torch
import torch.nn.functional as F

from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _CheckRes
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _check_function_or_method
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _get_func_signature


class LossBase(object):
@@ -70,7 +70,7 @@ class LossBase(object):
for func_param, input_param in self.param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the "
f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the "
f"initialization parameters, or change its signature.")

# evaluate should not have varargs.
@@ -111,7 +111,7 @@ class LossBase(object):
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.")
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")

# 2. only part of the param_map are passed, left are not
for arg in func_args:
@@ -151,16 +151,16 @@ class LossBase(object):
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"

check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)
check_res = _CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.get_loss))
raise _CheckError(check_res=check_res,
func_signature=_get_func_signature(self.get_loss))
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)

loss = self.get_loss(**refined_args)
@@ -289,14 +289,14 @@ class LossInForward(LossBase):

def get_loss(self, **kwargs):
if self.loss_key not in kwargs:
check_res = CheckRes(
check_res = _CheckRes(
missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"],
unused=[],
duplicated=[],
required=[],
all_needed=[],
varargs=[])
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss))
return kwargs[self.loss_key]

def __call__(self, pred_dict, target_dict, check=False):
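
The _CheckRes/_CheckError path above fires when loss_key is absent from forward()'s output. A sketch of the normal
LossInForward flow (the loss_key argument is assumed from the description in the trainer documentation)::

    import torch
    from fastNLP.core.losses import LossInForward

    loss_fn = LossInForward(loss_key='loss')  # assuming 'loss' is the default key
    # mimics a model whose forward() computes its own loss (e.g. a CRF model)
    pred_dict = {'loss': torch.tensor(0.5)}
    print(loss_fn(pred_dict, target_dict={}))  # -> tensor(0.5000)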


+25 -25 fastNLP/core/metrics.py

@@ -13,11 +13,11 @@ from collections import defaultdict
import numpy as np
import torch

from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _CheckRes
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _get_func_signature
from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.core.vocabulary import Vocabulary

@@ -161,7 +161,7 @@ class MetricBase(object):
for func_param, input_param in self.param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change its signature.")

def _fast_param_map(self, pred_dict, target_dict):
@@ -207,7 +207,7 @@ class MetricBase(object):
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")

# 2. only part of the param_map are passed, left are not
for arg in func_args:
@@ -248,16 +248,16 @@ class MetricBase(object):
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"

check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)
check_res = _CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.evaluate))
raise _CheckError(check_res=check_res,
func_signature=_get_func_signature(self.evaluate))
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)

self.evaluate(**refined_args)
@@ -294,14 +294,14 @@ class AccuracyMetric(MetricBase):
"""
# TODO this error message needs changing, since the user has no idea what pred is; the actual value should be reported to the user
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if seq_len is not None and not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")

if seq_len is not None:
@@ -314,7 +314,7 @@ class AccuracyMetric(MetricBase):
elif len(pred.size()) == len(target.size()) + 1:
pred = pred.argmax(dim=-1)
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

@@ -516,14 +516,14 @@ class SpanFPreRecMetric(MetricBase):
:return:
"""
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.")

if pred.size() == target.size() and len(target.size()) == 2:
@@ -535,7 +535,7 @@ class SpanFPreRecMetric(MetricBase):
raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
"id >= {}, the number of classes.".format(num_classes))
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

@@ -714,14 +714,14 @@ class BMESF1PreRecMetric(MetricBase):
:return:
"""
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.")

if pred.size() == target.size() and len(target.size()) == 2:
@@ -729,7 +729,7 @@ class BMESF1PreRecMetric(MetricBase):
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
pred = pred.argmax(dim=-1)
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")



+6 -6 fastNLP/core/tester.py

@@ -5,11 +5,11 @@ from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _get_func_signature
from fastNLP.core.utils import _get_device


@@ -75,19 +75,19 @@ class Tester(object):
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
pred_dict = self._data_forward(self._predict_func, batch_x)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
f"must be `dict`, got {type(pred_dict)}.")
for metric in self.metrics:
metric(pred_dict, batch_y)
for metric in self.metrics:
eval_result = metric.get_metric()
if not isinstance(eval_result, dict):
raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
f"`dict`, got {type(eval_result)}")
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result
except CheckError as e:
prev_func_signature = get_func_signature(self._predict_func)
except _CheckError as e:
prev_func_signature = _get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0)
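
A self-contained sketch of the Tester loop in this hunk: the model must return a dict, and each metric's get_metric()
must return a dict (the exact Tester keyword arguments are an assumption)::

    import torch
    from torch import nn
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.metrics import AccuracyMetric
    from fastNLP.core.tester import Tester

    class ConstModel(nn.Module):
        """Toy model that always predicts class 0, just to drive the loop."""
        def forward(self, x):
            return {'pred': torch.zeros(x.size(0), dtype=torch.long)}

    ds = DataSet({'x': [[1.0], [2.0]], 'target': [0, 0]})
    ds.set_input('x')
    ds.set_target('target')
    tester = Tester(data=ds, model=ConstModel(), metrics=AccuracyMetric())
    print(tester.test())  # e.g. {'AccuracyMetric': {'acc': 1.0}}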


+34 -57 fastNLP/core/trainer.py

@@ -85,17 +85,17 @@ Trainer is used in fastNLP to organize the training process of a single task, saving users from
the corresponding value; for example, here CrossEntropyLoss_ will try to find the content named 'label' to use as the ground truth for
the loss, while pred=None makes CrossEntropyLoss_ match the prediction by the name 'pred', which happens to be what forward's return
value is called as well, so pred does not need to be declared here.

Although fastNLP uses this mapping mechanism to make loss computation fairly flexible, in some cases the loss must be computed inside
the model, e.g. models that use a CRF. fastNLP provides the LossInForward_ loss for this. Its principle is to find, in the result
returned by forward(), the tensor designated by loss_key (it looks for 'loss' by default) and use it as the loss. If no loss is
provided when the Trainer is initialized, this loss is used TODO add an example
Although fastNLP uses this mapping mechanism to make loss computation fairly flexible, in some cases the loss must be computed inside
the model, e.g. models that use a CRF. fastNLP provides the LossInForward_ loss for this. Its principle is to find, in the result
returned by forward(), the tensor designated by loss_key (it looks for 'loss' by default) and use it as the loss. If no loss is
provided when the Trainer is initialized, LossInForward_ is used by default. For a detailed example refer to TODO add an example

3. Metric

Metric_ uses the same strategy as the Loss above, i.e. matching by name. AccuracyMetric(target='label') works the same way as CrossEntropyLoss.

During validation, the computation needed may not match forward(), making it impossible to obtain the prediction directly from
forward()'s result; in that case the model can provide a predict() method.
If the model has a predict method, predict() is called during validation to obtain the predictions; the arguments passed to
predict() are likewise selected from the input of the DataSet; like forward(), it must return a dict. For a concrete example see TODO add an example
If the model has a predict method, predict() is called during validation to obtain the predictions; the arguments passed to
predict() are likewise selected from the fields set as input in the DataSet; like forward(), it must return a dict. For a concrete example see TODO add an example

2. The Trainer's code check

@@ -112,12 +112,12 @@ Trainer is used in fastNLP to organize the training process of a single task, saving users from
from fastNLP import DataSet

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, x, b):
loss = torch.mean((self.fc(x)-b)**2)
return {'loss': loss}
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, x, b):
loss = torch.mean((self.fc(x)-b)**2)
return {'loss': loss}
model = Model()

dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
@@ -138,16 +138,15 @@ Trainer is used in fastNLP to organize the training process of a single task, saving users from
# unused field: ['a']
# Suggestion: You need to provide ['x'] in DataSet and set it as input.

This is because, at Trainer initialization, fastNLP tries to run forward() and backward() once with a batch of batch_size=2. Two kinds of
This is because, at Trainer initialization, fastNLP tries to run forward() and backward() once with a batch of batch_size=2. Two kinds of
information here can serve as a reference

1. 'input fields after batch...' shows, for the train dataset after the Batch operation, the type and shape of each field. Nothing is shown here
because the train dataset has no target. From this you can check whether the required content was correctly set as input or target.
because the train dataset has no target. From this you can check whether the required content was correctly set as input or target.

2. If a mapping error occurs, a NameError appears. The error is raised because, while attempting the forward computation
(Model.forward(self, x, b) tells you the failure happened while invoking forward), the 'x' required by forward() was not found;
the message also points out that 'x' is missing while 'a' is unused, so the field name is probably wrong. Renaming the field 'a'
in the dataset to 'x', or changing the model's parameter from 'x' to 'a', both solve the problem.
2. NameError: a NameError occurs when the mapping goes wrong. The error is raised because, while attempting the forward computation
(Model.forward(self, x, b) tells you forward is being invoked), the 'x' required by forward() was not found; the message also points
out that 'x' is missing while 'a' is unused, so the field name is probably wrong. Renaming the field 'a' in the dataset to 'x', or changing the model's parameter from 'x' to 'a', both solve the problem.

The example below is caused by the loss computation failing to find a value it needs

@@ -249,18 +248,18 @@ Trainer is used in fastNLP to organize the training process of a single task, saving users from
# (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).

The error message is similar to the previous ones, but 'AccuracyMetric.evaluate(self, pred, target, seq_len=None)' shows that this time the error
happened during evaluation. This avoids only discovering a broken evaluation after a whole epoch of training. The fix here is to state, when
initializing the metric, that `pred` is obtained via 'output', i.e. AccuracyMetric(pred='output').
happened during evaluation. This avoids only discovering a broken evaluation after a whole epoch of training. The fix here is to state, when
initializing the metric, that `pred` is obtained via 'output', i.e. AccuracyMetric(pred='output')

The strength of the check can be adjusted through check_code_level. It defaults to 0, i.e. the check is performed.

3. Trainer and callback

Although the Trainer itself already integrates some functionality, it still cannot cover everything that may be needed during training,
such as negative sampling, learning rate decay, early stopping, etc.
To solve this, fastNLP introduces the callback mechanism. A Callback_ is something that runs at specific stages of the Trainer's training process; every Callback_ has
To solve this, fastNLP introduces the callback mechanism. A Callback_ is a collection of functions that run at specific stages of the Trainer's training process; every Callback_ has
functions named on_* (e.g. on_train_start, on_backward_begin). If a Callback implements such a function, the Trainer calls it when execution reaches the corresponding stage.

We divide the inside of Train.train() into the following stages
We divide the inside of Train.train() into the following stages; the corresponding calls are triggered at each stage.

Example::

@@ -305,7 +304,7 @@ import warnings
try:
from tqdm.autonotebook import tqdm
except:
from fastNLP.core.utils import pseudo_tqdm as tqdm
from fastNLP.core.utils import _pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
@@ -316,12 +315,12 @@ from fastNLP.core.sampler import Sampler
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import _get_func_signature
from fastNLP.core.utils import _get_device


@@ -466,34 +465,11 @@ class Trainer(object):

def train(self, load_best_model=True):
"""

Start the training process. The main steps are::

    for epoch in range(num_epochs):
        # Use Batch to draw data from the DataSet batch by batch, automatically padding the fields of the DataSet
        whose dtype is (float, int) and converting them to Tensor. Arguments of non-float, non-int types are
        neither converted to Tensor nor padded.
        for batch_x, batch_y in Batch(DataSet)
            # batch_x is a dict; the fields set as input appear in this dict,
            with the DataSet field_name as key and that field's value as value
            # batch_y is also a dict; the fields set as target appear in this dict,
            with the DataSet field_name as key and that field's value as value
    2. Feed the data in batch_x to the model.forward function and collect the result. Arguments are passed by
        matching the keys in batch_x against forward's formal parameters. For example, given
        forward(self, x, seq_lens) # fastNLP finds the value whose key is "x" in batch_x and passes it to x, and
        the value whose key is "seq_lens" to seq_lens. If not all required parameters are found in batch_x, an
        error is raised. If forward has a default parameter and its key is not in batch_x, the default is used.
    3. Feed batch_y together with model.forward's result into the loss. Computing a loss generally involves pred
        and target, but in different settings pred may be called output or prediction, and target may be called y
        or label. fastNLP locates pred and target through the mapping passed when the loss was initialized. E.g.,
        if the Trainer's loss is initialized as CrossEntropyLoss(pred='output', target='y'), then when computing
        the loss fastNLP uses "output" to find pred in batch_y and in forward's result, and "y" to find target
        there, and completes the loss computation.
    4. After obtaining the loss, backpropagate and update the gradients
    Test on the validation set as scheduled
    Run evaluation according to metrics, and decide whether to save the model depending on whether save_path was provided
Use this function to make the Trainer start training.

:param bool load_best_model: only effective if dev_data was provided at initialization; if True, the trainer reloads,
before returning, the parameters of the model that performed best on dev.
:return results: returns a dict,
:return dict: returns a dict,
containing the following::

    seconds: float, the duration of training
@@ -547,7 +523,7 @@ class Trainer(object):

def _train(self):
if not self.use_tqdm:
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
self.step = 0
@@ -559,6 +535,7 @@ class Trainer(object):
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
self.batch_per_epoch = data_iterator.num_batches
for epoch in range(1, self.n_epochs + 1):
self.epoch = epoch
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -660,7 +637,7 @@ class Trainer(object):
x = _build_args(network.forward, **x)
y = network(**x)
if not isinstance(y, dict):
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
raise TypeError(f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y

def _grad_backward(self, loss):
@@ -796,7 +773,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_

refined_batch_x = _build_args(model.forward, **batch_x)
pred_dict = model(**refined_batch_x)
func_signature = get_func_signature(model.forward)
func_signature = _get_func_signature(model.forward)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")

@@ -807,16 +784,16 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise TypeError(
f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
f"The return value of {_get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
f"but got `{type(loss)}`.")
if len(loss.size()) != 0:
raise ValueError(
f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
f"The size of return value of {_get_func_signature(losser.get_loss)} is {loss.size()}, "
f"should be torch.size([])")
loss.backward()
except CheckError as e:
# TODO: another error raised if CheckError caught
pre_func_signature = get_func_signature(model.forward)
except _CheckError as e:
# TODO: another error raised if _CheckError caught
pre_func_signature = _get_func_signature(model.forward)
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=dataset, check_level=check_level)
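
An end-to-end sketch reusing the toy model from the code-check example above, with the field names fixed so the name
matching succeeds; since no loss is given, LossInForward picks up the 'loss' key (the Trainer's positional arguments
and defaults are assumptions)::

    import torch
    from torch import nn
    from fastNLP import DataSet, Trainer

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, x, b):
            loss = torch.mean((self.fc(x) - b) ** 2)
            return {'loss': loss}  # found by LossInForward when no loss is supplied

    dataset = DataSet({'x': [[float(i)] for i in range(10)],
                       'b': [[2.0 * i] for i in range(10)]})
    dataset.set_input('x', 'b')  # keys must match forward(self, x, b)
    trainer = Trainer(dataset, Model())
    trainer.train()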


+107 -56 fastNLP/core/utils.py

@@ -9,7 +9,7 @@ import numpy as np
import torch
from torch import nn

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'])

def _prepare_cache_filepath(filepath):
@@ -28,6 +28,57 @@ def _prepare_cache_filepath(filepath):

# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。
def cache_results(_cache_fp, _refresh=False, _verbose=1):
"""
cache_results is the decorator in fastNLP used for caching data. The example below shows how to use it

Example::

    import time
    import numpy as np
    from fastNLP import cache_results

    @cache_results('cache.pkl')
    def process_data():
        # some time-consuming work, e.g. reading and preprocessing data; time.sleep() stands in for the cost here
        time.sleep(1)
        return np.random.randint(5, size=(10, 20))

    start_time = time.time()
    process_data()
    print(time.time() - start_time)

    start_time = time.time()
    process_data()
    print(time.time() - start_time)

    # the output looks like this
    # Save cache to cache.pkl.
    # 1.0015439987182617
    # Read cache from cache.pkl.
    # 0.00013065338134765625

Note that the second run takes only about 0.0001s, because it reads the data directly from the file cache.pkl instead
of running the preprocessing again

Example::

    # continuing the example above: to generate a separate cache, e.g. for another dataset, call it as follows
    process_data(_cache_fp='cache2.pkl')  # leaves the previous 'cache.pkl' completely untouched

The _cache_fp above is a parameter that cache_results recognizes; it caches/reads the data at 'cache2.pkl', i.e. the
'cache2.pkl' here overrides the default 'cache.pkl'. Once @cache_results() is placed in front of your function, the
function gains three parameters [_cache_fp, _refresh, _verbose]. The example above shows the use of _cache_fp. These
three parameters are not passed through to your function, and accordingly your function's own parameter names must
not include these three names.

Example::

    process_data(_cache_fp='cache2.pkl', _refresh=True)  # force regeneration of the preprocessing cache here.
    # _verbose controls the output: if 0, nothing is printed; if 1, it reports whether the current step read the
    # cache or generated a new one

:param str _cache_fp: where to cache the returned result, or where to read the cache from. If None, cache_results has
no effect at all, unless _cache_fp is passed when the function is called.
:param bool _refresh: whether to regenerate the cache.
:param int _verbose: whether to print cache information.
:return:
"""
def wrapper_(func):
signature = inspect.signature(func)
for key, _ in signature.parameters.items():
@@ -74,48 +125,48 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
return wrapper
return wrapper_

def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file.
:param obj: an object
:param pickle_path: str, the directory where the pickle file is to be saved
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
"""
if not os.path.exists(pickle_path):
os.mkdir(pickle_path)
print("make dir {} before saving pickle file".format(pickle_path))
with open(os.path.join(pickle_path, file_name), "wb") as f:
_pickle.dump(obj, f)
print("{} saved in {}".format(file_name, pickle_path))
def load_pickle(pickle_path, file_name):
"""Load an object from a given pickle file.
:param pickle_path: str, the directory where the pickle file is.
:param file_name: str, the name of the pickle file.
:return obj: an object stored in the pickle
"""
with open(os.path.join(pickle_path, file_name), "rb") as f:
obj = _pickle.load(f)
print("{} loaded from {}".format(file_name, pickle_path))
return obj
def pickle_exist(pickle_path, pickle_name):
"""Check if a given pickle file exists in the directory.
:param pickle_path: the directory of target pickle file
:param pickle_name: the filename of target pickle file
:return: True if file exists else False
"""
if not os.path.exists(pickle_path):
os.makedirs(pickle_path)
file_name = os.path.join(pickle_path, pickle_name)
if os.path.exists(file_name):
return True
else:
return False
# def save_pickle(obj, pickle_path, file_name):
# """Save an object into a pickle file.
#
# :param obj: an object
# :param pickle_path: str, the directory where the pickle file is to be saved
# :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
# """
# if not os.path.exists(pickle_path):
# os.mkdir(pickle_path)
# print("make dir {} before saving pickle file".format(pickle_path))
# with open(os.path.join(pickle_path, file_name), "wb") as f:
# _pickle.dump(obj, f)
# print("{} saved in {}".format(file_name, pickle_path))
#
#
# def load_pickle(pickle_path, file_name):
# """Load an object from a given pickle file.
#
# :param pickle_path: str, the directory where the pickle file is.
# :param file_name: str, the name of the pickle file.
# :return obj: an object stored in the pickle
# """
# with open(os.path.join(pickle_path, file_name), "rb") as f:
# obj = _pickle.load(f)
# print("{} loaded from {}".format(file_name, pickle_path))
# return obj
#
#
# def pickle_exist(pickle_path, pickle_name):
# """Check if a given pickle file exists in the directory.
#
# :param pickle_path: the directory of target pickle file
# :param pickle_name: the filename of target pickle file
# :return: True if file exists else False
# """
# if not os.path.exists(pickle_path):
# os.makedirs(pickle_path)
# file_name = os.path.join(pickle_path, pickle_name)
# if os.path.exists(file_name):
# return True
# else:
# return False

def _get_device(device, check_exist=False):
"""
@@ -232,15 +283,15 @@ def _check_arg_dict_list(func, args):
missing = list(require_args - input_args)
unused = list(input_args - all_args)
varargs = [] if not spect.varargs else [spect.varargs]
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args),
varargs=varargs)
return _CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args),
varargs=varargs)


def get_func_signature(func):
def _get_func_signature(func):
"""

Given a function or method, return its signature.
@@ -318,13 +369,13 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
raise TypeError("Only support `dict` type right now.")


class CheckError(Exception):
class _CheckError(Exception):
"""

CheckError. Used in losses.LossBase, metrics.MetricBase.
_CheckError. Used in losses.LossBase, metrics.MetricBase.
"""

def __init__(self, check_res: CheckRes, func_signature: str):
def __init__(self, check_res: _CheckRes, func_signature: str):
errs = [f'Problems occurred when calling `{func_signature}`']

if check_res.varargs:
@@ -347,7 +398,7 @@ WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2


def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: _CheckRes,
pred_dict: dict, target_dict: dict, dataset, check_level=0):
errs = []
unuseds = []
@@ -449,7 +500,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re

def _check_forward_error(forward_func, batch_x, dataset, check_level):
check_res = _check_arg_dict_list(forward_func, batch_x)
func_signature = get_func_signature(forward_func)
func_signature = _get_func_signature(forward_func)

errs = []
suggestions = []
@@ -543,7 +594,7 @@ def seq_mask(seq_len, max_len):
return torch.gt(seq_len, seq_range) # [batch_size, max_len]


class pseudo_tqdm:
class _pseudo_tqdm:
"""
Used to print data when tqdm cannot be imported, or when use_tqdm is set to false in the Trainer
"""


+13 -13 fastNLP/core/vocabulary.py

@@ -2,7 +2,7 @@ from functools import wraps
from collections import Counter
from fastNLP.core.dataset import DataSet

def check_build_vocab(func):
def _check_build_vocab(func):
"""A decorator to make sure the indexing is built before used.

"""
@@ -15,7 +15,7 @@ def check_build_vocab(func):
return _wrapper


def check_build_status(func):
def _check_build_status(func):
"""A decorator to check whether the vocabulary updates after the last build.

"""
@@ -67,7 +67,7 @@ class Vocabulary(object):
self.idx2word = None
self.rebuild = True

@check_build_status
@_check_build_status
def update(self, word_lst):
"""依次增加序列中词在词典中的出现频率

@@ -75,7 +75,7 @@ class Vocabulary(object):
"""
self.word_count.update(word_lst)

@check_build_status
@_check_build_status
def add(self, word):
"""
Increase the occurrence count of a new word in the vocabulary
@@ -84,7 +84,7 @@ class Vocabulary(object):
"""
self.word_count[word] += 1

@check_build_status
@_check_build_status
def add_word(self, word):
"""
Increase the occurrence count of a new word in the vocabulary
@@ -93,7 +93,7 @@ class Vocabulary(object):
"""
self.add(word)

@check_build_status
@_check_build_status
def add_word_lst(self, word_lst):
"""
Increase, one by one, the occurrence count in the vocabulary of each word in the sequence
@@ -132,11 +132,11 @@ class Vocabulary(object):
"""
self.idx2word = {i: w for w, i in self.word2idx.items()}

@check_build_vocab
@_check_build_vocab
def __len__(self):
return len(self.word2idx)

@check_build_vocab
@_check_build_vocab
def __contains__(self, item):
"""
Check whether a word has been recorded
@@ -161,7 +161,7 @@ class Vocabulary(object):
"""
return self.__contains__(w)

@check_build_vocab
@_check_build_vocab
def __getitem__(self, w):
"""
To support usage like::
@@ -175,7 +175,7 @@ class Vocabulary(object):
else:
raise ValueError("word {} not in vocabulary".format(w))

@check_build_vocab
@_check_build_vocab
def index_dataset(self, *datasets, field_name, new_field_name=None):
"""
Convert the words of the corresponding field in the DataSet to indices.
@@ -275,7 +275,7 @@ class Vocabulary(object):
return self.__getitem__(w)

@property
@check_build_vocab
@_check_build_vocab
def unknown_idx(self):
"""
The index corresponding to unknown.
@@ -285,7 +285,7 @@ class Vocabulary(object):
return self.word2idx[self.unknown]

@property
@check_build_vocab
@_check_build_vocab
def padding_idx(self):
"""
The index corresponding to padding
@@ -294,7 +294,7 @@ class Vocabulary(object):
return None
return self.word2idx[self.padding]

@check_build_vocab
@_check_build_vocab
def to_word(self, idx):
"""
Given an index, convert it to the corresponding word.
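
A small sketch of the workflow these decorators guard: updates go through _check_build_status, while lookups trigger
the lazy build via _check_build_vocab (all calls below appear in this diff)::

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.add_word_lst(['this', 'is', 'a', 'demo'])
    idx = vocab['demo']        # the first lookup builds word2idx
    print(vocab.to_word(idx))  # -> 'demo'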


+3 -3 fastNLP/models/enas_trainer.py

@@ -13,12 +13,12 @@ from torch import nn
try:
from tqdm.autonotebook import tqdm
except:
from fastNLP.core.utils import pseudo_tqdm as tqdm
from fastNLP.core.utils import _pseudo_tqdm as tqdm

from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP
import fastNLP.models.enas_utils as utils
@@ -118,7 +118,7 @@ class ENASTrainer(fastNLP.Trainer):

def _train(self):
if not self.use_tqdm:
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
self.step = 0


+1 -1 test/core/test_tester.py

@@ -8,7 +8,7 @@ import numpy as np
import torch.nn.functional as F
from torch import nn
import time
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _CheckError
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss

