
Improved the documentation of trainer, callback, etc.; renamed parts of the code so that they are hidden from the documentation

tags/v0.4.10
yh_cc 6 years ago
commit 6e265e5ae9
13 changed files with 373 additions and 351 deletions
  1. fastNLP/automl/enas_trainer.py (+2, -2)
  2. fastNLP/core/callback.py (+115, -98)
  3. fastNLP/core/dataset.py (+3, -3)
  4. fastNLP/core/fieldarray.py (+35, -57)
  5. fastNLP/core/instance.py (+14, -15)
  6. fastNLP/core/losses.py (+15, -15)
  7. fastNLP/core/metrics.py (+25, -25)
  8. fastNLP/core/tester.py (+6, -6)
  9. fastNLP/core/trainer.py (+34, -57)
  10. fastNLP/core/utils.py (+107, -56)
  11. fastNLP/core/vocabulary.py (+13, -13)
  12. fastNLP/models/enas_trainer.py (+3, -3)
  13. test/core/test_tester.py (+1, -1)

fastNLP/automl/enas_trainer.py (+2, -2)

@@ -11,7 +11,7 @@ import torch
 try:
     from tqdm.autonotebook import tqdm
 except:
-    from fastNLP.core.utils import pseudo_tqdm as tqdm
+    from fastNLP.core.utils import _pseudo_tqdm as tqdm

 from fastNLP.core.batch import Batch
 from fastNLP.core.callback import CallbackException
@@ -115,7 +115,7 @@ class ENASTrainer(fastNLP.Trainer):

     def _train(self):
         if not self.use_tqdm:
-            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
+            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
         else:
             inner_tqdm = tqdm
         self.step = 0


fastNLP/core/callback.py (+115, -98)

@@ -2,21 +2,21 @@

 .. _Callback:

-"""
+Callback is the class in fastNLP designed to enhance Trainer_. If a Callback is passed to Trainer_, then Trainer_
+calls the Callback's functions at the corresponding stages; the exact call timing can be found in Trainer_.
+
+"""

 import os

 import torch
-from tensorboardX import SummaryWriter

 from fastNLP.io.model_io import ModelSaver, ModelLoader

+try:
+    from tensorboardX import SummaryWriter
+except:
+    pass


 class Callback(object):
-    """An Interface for all callbacks.
-
-    Any customized callback should implement at least one of the following methods.
+    """This is the base class of Callback; all callbacks must inherit from this class.

     """


@@ -26,93 +26,150 @@ class Callback(object):

     @property
     def trainer(self):
+        """
+        This property can be accessed via self.trainer; normally it is not needed.
+        :return:
+        """
         return self._trainer

     @property
     def step(self):
-        """current step number, in range(1, self.n_steps+1)"""
+        """The step currently being run, in the range [1, self.n_steps+1)"""
         return self._trainer.step

     @property
     def n_steps(self):
-        """total number of steps for training"""
+        """The total number of steps the Trainer will run"""
         return self._trainer.n_steps

     @property
     def batch_size(self):
-        """batch size for training"""
+        """The batch_size used for training and evaluation"""
         return self._trainer.batch_size

     @property
     def epoch(self):
-        """current epoch number, in range(1, self.n_epochs+1)"""
+        """The epoch currently being run, in the range [1, self.n_epochs+1)"""
         return self._trainer.epoch

     @property
     def n_epochs(self):
-        """total number of epochs"""
+        """The total number of epochs that will be run"""
         return self._trainer.n_epochs

     @property
     def optimizer(self):
-        """torch.optim.Optimizer for current model"""
+        """The Optimizer passed when the Trainer was initialized"""
         return self._trainer.optimizer

     @property
     def model(self):
-        """training model"""
+        """The model being trained by the Trainer"""
         return self._trainer.model

     @property
     def pbar(self):
-        """If use_tqdm, return trainer's tqdm print bar, else return None."""
+        """If you need to print inside a Callback, use self.pbar.write(str); otherwise the command-line output may display poorly."""
         return self._trainer.pbar

     @property
     def update_every(self):
-        """The model in trainer will update parameters every `update_every` batches."""
+        """How many backward passes the model in the Trainer takes before one gradient update; passed in when the Trainer was initialized."""
         return self._trainer.update_every

+    @property
+    def batch_per_epoch(self):
+        """How many batches each epoch has; this property can only be accessed after on_epoch_begin."""
+        return self._trainer.batch_per_epoch
+
     def on_train_begin(self):
-        # before the main training loop
+        """
+        Called before the training process begins.
+
+        :return:
+        """
         pass

     def on_epoch_begin(self):
-        # at the beginning of each epoch
+        """
+        Called once before each epoch begins.
+
+        :return:
+        """
         pass

     def on_batch_begin(self, batch_x, batch_y, indices):
-        # at the beginning of each step/mini-batch
+        """
+        Called each time a batch of data has been sampled. Deleting from or adding to batch_x or batch_y here does
+        affect what the Trainer sees, so operations such as negative sampling can be performed at this step.
+
+        :param dict batch_x: the batch of the fields set as input in the DataSet.
+        :param dict batch_y: the batch of the fields set as target in the DataSet.
+        :param list(int) indices: the indices used by this sampling; the Instances of this batch can be retrieved via
+            DataSet[indices], which in some cases helps locate the sample that caused an error.
+        :return:
+        """
         pass

     def on_loss_begin(self, batch_y, predict_y):
-        # after data_forward, and before loss computation
+        """
+        Called before the loss is computed, i.e. modifying batch_y or predict_y here can affect the loss computation.
+
+        :param dict batch_y: the batch of the fields set as target in the DataSet.
+        :param dict predict_y: the result returned by the model's forward().
+        :return:
+        """
         pass

     def on_backward_begin(self, loss):
-        # after loss computation, and before gradient backward
+        """
+        After the loss has been obtained but before backpropagation. A check for a NaN loss could be done here.
+
+        :param torch.Tensor loss: the computed loss value
+        :return:
+        """
         pass

     def on_backward_end(self):
+        """
+        Backpropagation of the gradients has finished, but because of the update_every setting there may not be
+        gradients on every call. At this step, the parameters have not been updated yet.
+
+        :return:
+        """
         pass

     def on_step_end(self):
+        """
+        At this point the model's parameters have been updated according to the gradients, although because of
+        update_every they may not be updated on every call.
+
+        :return:
+        """
         pass

-    def on_batch_end(self, *args):
-        # at the end of each step/mini-batch
+    def on_batch_end(self):
+        """
+        This step immediately follows on_step_end; it was only added for symmetry.
+
+        """
         pass

     def on_valid_begin(self):
+        """
+        If validation is configured in the Trainer, this function is called before validation takes place.
+
+        :return:
+        """
         pass

     def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         """
-        Called after each evaluation on the validation set. eval_result is passed in
+        Called after each evaluation on the validation set.

-        :param eval_result: Dict[str: Dict[str: float]], the result of the evaluation
-        :param metric_key: str
-        :param optimizer: optimizer passed to trainer
-        :param is_better_eval: bool, whether the current dev result is better than before
+        :param Dict[str: Dict[str: float]] eval_result: the result of the evaluation. An example is
+            {'AccuracyMetric': {'acc': 1.0}}, i.e. the dict has two levels: the first level is the metric's name,
+            the second level holds the metric's concrete values.
+        :param str metric_key: the metric_key passed when the Trainer was initialized.
+        :param torch.Optimizer optimizer: the optimizer used in the Trainer.
+        :param bool is_better_eval: whether the current dev result is better than the previous ones.
         :return:
         """
         pass
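A minimal custom callback built on the hooks above; the class name and behavior are illustrative, not part of this commit:

Example::

    import torch
    from fastNLP.core.callback import Callback

    class NanGuardCallback(Callback):
        # hypothetical example: warn as soon as the loss becomes NaN,
        # using only documented hooks and properties
        def on_backward_begin(self, loss):
            if torch.isnan(loss).item():
                self.pbar.write("NaN loss at step {}/{}".format(self.step, self.n_steps))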
@@ -137,7 +194,7 @@ class Callback(object):
         pass


-def transfer(func):
+def _transfer(func):
     """Decorator that forwards calls on the CallbackManager to each Callback subclass.

     :param func:
     :return:
@@ -153,9 +210,7 @@ def transfer(func):


 class CallbackManager(Callback):
-    """A manager for all callbacks passed into Trainer.
-    It collects resources inside Trainer and raise callbacks.
-
+    """Callback manager class for internal use.
     """


     def __init__(self, env, callbacks=None):
@@ -182,104 +237,70 @@ class CallbackManager(Callback):
         for callback in self.callbacks:
             setattr(callback, '_'+env_name, env_val)  # Callback.trainer

-    @transfer
+    @_transfer
     def on_train_begin(self):
         pass

-    @transfer
+    @_transfer
     def on_epoch_begin(self):
         pass

-    @transfer
+    @_transfer
     def on_batch_begin(self, batch_x, batch_y, indices):
         pass

-    @transfer
+    @_transfer
     def on_loss_begin(self, batch_y, predict_y):
         pass

-    @transfer
+    @_transfer
     def on_backward_begin(self, loss):
         pass

-    @transfer
+    @_transfer
     def on_backward_end(self):
         pass

-    @transfer
+    @_transfer
     def on_step_end(self):
         pass

-    @transfer
+    @_transfer
     def on_batch_end(self):
         pass

-    @transfer
+    @_transfer
     def on_valid_begin(self):
         pass

-    @transfer
+    @_transfer
     def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         pass

-    @transfer
+    @_transfer
     def on_epoch_end(self):
         pass

-    @transfer
+    @_transfer
     def on_train_end(self):
         pass

-    @transfer
+    @_transfer
     def on_exception(self, exception):
         pass
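The body of the `_transfer` decorator applied throughout the manager above falls outside this diff; a forwarding decorator of this kind can be sketched as follows (an illustrative assumption, not the committed implementation):

Example::

    def _transfer(func):
        # call the identically-named hook on every callback held by the manager
        def wrapper(manager, *args, **kwargs):
            for callback in manager.callbacks:
                getattr(callback, func.__name__)(*args, **kwargs)
        return wrapper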




-class DummyCallback(Callback):
-    def on_train_begin(self, *arg):
-        print(arg)
-
-    def on_epoch_end(self):
-        print(self.epoch, self.n_epochs)
-
-
-class EchoCallback(Callback):
-    def on_train_begin(self):
-        print("before_train")
-
-    def on_epoch_begin(self):
-        print("before_epoch")
-
-    def on_batch_begin(self, batch_x, batch_y, indices):
-        print("before_batch")
-
-    def on_loss_begin(self, batch_y, predict_y):
-        print("before_loss")
-
-    def on_backward_begin(self, loss):
-        print("before_backward")
-
-    def on_batch_end(self):
-        print("after_batch")
-
-    def on_epoch_end(self):
-        print("after_epoch")
-
-    def on_train_end(self):
-        print("after_train")


 class GradientClipCallback(Callback):
     def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
         """Before each backward pass, clip the parameters' gradients to a certain range.

-        :param parameters: None, torch.Tensor or List[torch.Tensor], usually obtained via model.parameters(). If None,
+        :param None,torch.Tensor,List[torch.Tensor] parameters: usually obtained via model.parameters(). If None,
             all parameters of the Trainer's model are clipped by default
-        :param clip_value: float, clip the gradient to [-clip_value, clip_value]. clip_value should be a positive number
-        :param clip_type: str, supports 'norm' and 'value':
-            (1) 'norm': rescale the gradient norm to [-clip_value, clip_value]
-            (2) 'value': clamp the gradient to [-clip_value, clip_value]; gradients smaller than -clip_value are set
-            to -clip_value, and gradients larger than clip_value are set to clip_value.
+        :param float clip_value: clip the gradient to [-clip_value, clip_value]. clip_value should be a positive number
+        :param str clip_type: supports 'norm' and 'value':
+            1. 'norm': rescale the gradient norm to [-clip_value, clip_value]
+            2. 'value': clamp the gradient to [-clip_value, clip_value]; gradients smaller than -clip_value are set
+               to -clip_value, and gradients larger than clip_value are set to clip_value.
         """
         super().__init__()
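A usage sketch for the two clipping modes documented above (the clip_value of 5 is illustrative):

Example::

    from fastNLP.core.callback import GradientClipCallback

    # clamp every gradient entry to [-5, 5] before each update
    clip_cb = GradientClipCallback(clip_value=5, clip_type='value')
    # then pass callbacks=[clip_cb] when constructing the Trainer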


@@ -314,7 +335,7 @@ class EarlyStopCallback(Callback):
     def __init__(self, patience):
         """

-        :param int patience: number of epochs to wait before stopping
+        :param int patience: stop training after this many epochs without improvement
         """
         super(EarlyStopCallback, self).__init__()
         self.patience = patience
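Usage sketch (the patience value is illustrative):

Example::

    from fastNLP.core.callback import EarlyStopCallback

    # stop training once the dev result has not improved for 5 epochs
    early_stop = EarlyStopCallback(patience=5)
    # pass callbacks=[early_stop] together with dev_data and metrics to the Trainer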
@@ -341,7 +362,7 @@ class LRScheduler(Callback):
     def __init__(self, lr_scheduler):
         """A wrapper around PyTorch LR schedulers

-        :param lr_scheduler: a PyTorch lr_scheduler
+        :param torch.optim.lr_scheduler._LRScheduler lr_scheduler: a PyTorch lr_scheduler
         """
         super(LRScheduler, self).__init__()
         import torch.optim
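A usage sketch wrapping a standard PyTorch scheduler (the optimizer and schedule parameters are illustrative):

Example::

    import torch.optim as optim
    from fastNLP.core.callback import LRScheduler

    optimizer = optim.SGD(model.parameters(), lr=0.1)  # `model` assumed defined elsewhere
    # halve the learning rate every 10 epochs
    scheduler_cb = LRScheduler(optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5))
    # pass optimizer=optimizer and callbacks=[scheduler_cb] to the Trainer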
@@ -358,7 +379,7 @@ class ControlC(Callback):
     def __init__(self, quit_all):
         """

-        :param quit_all: if True, exit the whole program on Ctrl+C; otherwise only exit the Trainer
+        :param bool quit_all: if True, exit the whole program on Ctrl+C; otherwise only exit the Trainer
         """
         super(ControlC, self).__init__()
         if type(quit_all) != bool:
@@ -389,16 +410,16 @@ class SmoothValue(object):


 class LRFinder(Callback):
-    def __init__(self, n_batch, start_lr=1e-6, end_lr=10):
+    def __init__(self, start_lr=1e-6, end_lr=10):
         """Use the first epoch to find the best learning rate, and apply it from the second epoch onward

-        :param n_batch: the number of iterations in one epoch
-        :param start_lr: lower bound of the learning rate
-        :param end_lr: upper bound of the learning rate
+        :param float start_lr: lower bound of the learning rate
+        :param float end_lr: upper bound of the learning rate
         """
         super(LRFinder, self).__init__()
         self.start_lr, self.end_lr = start_lr, end_lr
-        self.num_it = n_batch
+        self.num_it = self.batch_per_epoch
         self.stop = False
         self.best_loss = 0.
         self.best_lr = None
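With the n_batch parameter removed, constructing the callback only takes the learning-rate bounds; the iteration count now comes from the Trainer's batch_per_epoch property. Usage sketch:

Example::

    from fastNLP.core.callback import LRFinder

    # sweep learning rates in [1e-6, 10] during the first epoch
    finder = LRFinder(start_lr=1e-6, end_lr=10)
    # pass callbacks=[finder] to the Trainer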
@@ -514,7 +535,3 @@ class TensorboardCallback(Callback):
         del self._summary_writer

-
-if __name__ == "__main__":
-    manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
-    manager.on_train_begin()
-    # print(manager.after_epoch())

fastNLP/core/dataset.py (+3, -3)

@@ -277,7 +277,7 @@ import warnings
 from fastNLP.core.fieldarray import AutoPadder
 from fastNLP.core.fieldarray import FieldArray
 from fastNLP.core.instance import Instance
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature


 class DataSet(object):
     """fastNLP's data container
@@ -642,7 +642,7 @@ class DataSet(object):
                 print("Exception happens at the `{}`th instance.".format(idx))
                 raise e
         if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(get_func_signature(func=func)))
+            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))

         if new_field_name is not None:
             self._add_apply_field(results, new_field_name, kwargs)
@@ -707,7 +707,7 @@ class DataSet(object):
             raise e
         # results = [func(ins) for ins in self._inner_iter()]
         if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(get_func_signature(func=func)))
+            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))

         if new_field_name is not None:
             self._add_apply_field(results, new_field_name, kwargs)


fastNLP/core/fieldarray.py (+35, -57)

@@ -1,5 +1,5 @@
 """
-FieldArray is how one column of a DataSet_ is stored
+FieldArray is how one column of a DataSet_ is stored; for the underlying principles please refer to DataSet_

 .. _FieldArray:

@@ -11,41 +11,19 @@ from copy import deepcopy


 class FieldArray(object):
-    """``FieldArray`` is the collection of ``Instance``s of the same field.
-    It is the basic element of ``DataSet`` class.
-
-    :param str name: the name of the FieldArray
-    :param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
-    :param bool is_target: If True, this FieldArray is used to compute loss.
-    :param bool is_input: If True, this FieldArray is used to the model input.
-    :param Padder padder: of PadderBase type. The padder object assigned to a fieldarray is deepcopied; to modify the
-        padder's parameters you must go through fieldarray.set_pad_val(). Defaults to None: (1) if the field is a
-        scalar, no padding is applied; (2) if it is a one-dimensional list and the fieldarray's dtype is float or int,
-        padding is applied; (3) otherwise no padding is applied. To pad, e.g., the characters of English words, or to
-        pad when ignore_type is True, use a different padder.
-    :param bool ignore_type: whether to ignore type. If True, no type detection will rise for this FieldArray.
-        (default: False)
-    """
-
     def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
-        """At initialization a DataSet operates on FieldArray in two ways:
-        1) if the DataSet is initialized from a dict, add_field constructs the FieldArray:
-            1.1) 2-d list:       DataSet({"x": [[1, 2], [3, 4]]})
-            1.2) 2-d array:      DataSet({"x": np.array([[1, 2], [3, 4]])})
-            1.3) 3-d list:       DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
-            1.4) list of arrays: DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]})
-        2) if the DataSet is initialized from a list of Instance, append first builds the FieldArray from the first
-           sample, and the following samples are added with FieldArray.append:
-            2.1) 1-d list:  DataSet([Instance(x=[1, 2, 3, 4])])
-            2.2) 1-d array: DataSet([Instance(x=np.array([1, 2, 3, 4]))])
-            2.3) 2-d list:  DataSet([Instance(x=[[1, 2], [3, 4]])])
-            2.4) 2-d array: DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
-
-        Type detection (dtype check) happens when the field is set as is_input or is_target.
-        ignore_type controls whether the type check is performed; if True, no check is done.
-
+        """FieldArray is the entity that stores one field of a DataSet_.
+
+        :param str name: the name of the FieldArray
+        :param list,numpy.ndarray content: the elements of the list may be list, int, float,
+        :param bool is_target: whether this field is a target field.
+        :param bool is_input: whether this field is an input field.
+        :param Padder padder: of PadderBase type. The padder object assigned to a fieldarray is deepcopied; to modify
+            the padder's parameters you must use fieldarray.set_pad_val(). Defaults to None, i.e. AutoPadder_ is used.
+        :param bool ignore_type: whether to ignore this field's type. Generally, if this field does not need to be
+            converted to torch.FloatTensor or torch.LongTensor, it can be set to True. For the exact meaning please
+            refer to DataSet_.
         """
         self.name = name
         if isinstance(content, list):
             # if the DataSet was initialized from a dict, content may be a 2-d list / 2-d array / 3-d list
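The comment above refers to the two initialization routes of a DataSet; a short sketch of both (the values are illustrative):

Example::

    import numpy as np
    from fastNLP import DataSet, Instance

    # dict initialization: add_field builds one FieldArray per column
    ds1 = DataSet({"x": [[1, 2], [3, 4]], "y": [0, 1]})

    # list-of-Instance initialization: FieldArray.append adds samples one by one
    ds2 = DataSet([Instance(x=np.array([1, 2]), y=0), Instance(x=np.array([3, 4]), y=1)])

    ds1.set_input("x")   # triggers the dtype check unless ignore_type=True
    ds1.set_target("y")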
@@ -211,10 +189,10 @@ class FieldArray(object):
         return "FieldArray {}: {}".format(self.name, self.content.__repr__())

     def append(self, val):
-        """Add val to the FieldArray. If this field's ignore_type is True, val is appended directly; if ignore_type is
-        False and the current field is input or target, the passed content is checked for matching dimension and
-        element type against the existing content.
+        """Append val to the end of this field. If this field has been set as input or target, the type is checked
+        against the existing content before appending.

-        :param val: Any.
+        :param Any val: the value to append.
         """
         if self.ignore_type is False:
             if isinstance(val, list):
@@ -262,8 +240,8 @@ class FieldArray(object):
     def get(self, indices, pad=True):
         """Return the content corresponding to the given indices

-        :param indices: (int, List[int]), get the content corresponding to indices.
-        :param pad: bool, whether to pad the returned result. Only effective when indices is a List[int]
+        :param int,list(int) indices: get the content corresponding to indices.
+        :param bool pad: whether to pad the returned result. Only effective when indices is a List[int]
         :return: (single, List)
         """
         if isinstance(indices, int):
@@ -281,7 +259,7 @@ class FieldArray(object):
         """
         Set the padder; when this field is padded, this padder does the padding. If None, no padding is done.

-        :param padder: (None, Padder). Setting None removes the padder.
+        :param None,Padder padder: setting None removes the padder.
         :return:
         """
         if padder is not None:
@@ -293,7 +271,7 @@ class FieldArray(object):
     def set_pad_val(self, pad_val):
         """Modify the padder's pad_val.

-        :param pad_val: int. Set this field's pad value to this value
+        :param int pad_val: set this field's pad value to this value
         :return:
         """
         if self.padder is not None:
@@ -312,8 +290,8 @@ class FieldArray(object):
         """
         Copy other's attributes to this FieldArray (other must be of FieldArray type). The attributes include
         is_input, is_target, padder and ignore_type

-        :param other: FieldArray
-        :return:
+        :param FieldArray other: the field to copy attributes from
+        :return: FieldArray
         """
         assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))

@@ -324,7 +302,7 @@ class FieldArray(object):

         return self

-def is_iterable(content):
+def _is_iterable(content):
     try:
         _ = (e for e in content)
     except TypeError:
@@ -350,11 +328,10 @@ class Padder:
     """
     The input is the List content. Suppose we have the following DataSet.

-    :param contents: List[element]. The passed elements are in-place: modifying an element directly may change the
+    :param list(Any) contents: the passed elements are in-place: modifying an element directly may change the
         data, so deepcopy before any in-place modification is recommended.
-    :param field_name: str, the name of the field.
-    :param field_ele_dtype: (np.int64, np.float64, np.str, None), the dtype of the field's inner elements. If the
-        field's ignore_type is True, this value is None.
+    :param str field_name: the name of the field.
+    :param np.int64,np.float64,np.str,None field_ele_dtype: the dtype of the field's inner elements. If the field's
+        ignore_type is True, this value is None.
     :return: np.array([padded_element])

     Example::
@@ -400,10 +377,10 @@ class AutoPadder(Padder):

     2 If the element type is (np.int64, np.float64),

         2.1 If the field's content is (np.int64, np.float64), e.g. a seq_len, no padding is applied

-        2.2 If the field's content is a List, the Lists within a Batch are padded to the same length. If the inner
-            Lists also need padding, use another padder. If a field in some instance is [1, 2, 3] it can be padded;
-            if it is [[1,2], [3,4, ...]] it cannot be padded
+        2.2 If the field's content is a List, the Lists within a Batch are padded to the same length. If the inner
+            Lists also need padding, use another padder. That is, if the field in an Instance looks like
+            [1, 2, 3, ...] it can be padded; if it is [[1,2], [3,4, ...]] it cannot be padded
     """

     def __init__(self, pad_val=0):
@@ -427,7 +404,7 @@ class AutoPadder(Padder):
             return False

     def __call__(self, contents, field_name, field_ele_dtype):
-        if not is_iterable(contents[0]):
+        if not _is_iterable(contents[0]):
             array = np.array([content for content in contents], dtype=field_ele_dtype)
         elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents):
             max_len = max([len(content) for content in contents])
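A small sketch of rule 2.2 above, using AutoPadder's call signature from this diff (the printed output is the expected result, assuming the default pad_val of 0):

Example::

    import numpy as np
    from fastNLP.core.fieldarray import AutoPadder

    padder = AutoPadder(pad_val=0)
    # one-dimensional int lists of different lengths are padded to the batch maximum
    print(padder([[1, 2, 3], [4, 5]], field_name='words', field_ele_dtype=np.int64))
    # expected: [[1 2 3] [4 5 0]]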
@@ -454,7 +431,7 @@ class EngChar2DPadder(Padder):
     Example::

         from fastNLP import DataSet
-        from fastNLP import EnChar2DPadder
+        from fastNLP import EngChar2DPadder
         from fastNLP import Vocabulary
         dataset = DataSet({'sent': ['This is the first demo', 'This is the second demo']})
         dataset.apply(lambda ins:[list(word) for word in ins['sent'].split()], new_field_name='chars')
@@ -462,14 +439,15 @@ class EngChar2DPadder(Padder):
         vocab.from_dataset(dataset, field_name='chars')
         vocab.index_dataset(dataset, field_name='chars')
         dataset.set_input('chars')
-        padder = EnChar2DPadder()
+        padder = EngChar2DPadder()
         dataset.set_padder('chars', padder)  # the chars field is set to use EngChar2DPadder

     """
     def __init__(self, pad_val=0, pad_length=0):
         """
         :param pad_val: int, this index is used at the padded positions
-        :param pad_length: int, if 0, the maximum word length within a batch is used as the padding length; if a number
-            greater than 0, all words are padded or truncated to that length.
+        :param pad_length: int, if 0, the maximum word length within a batch is used as the padding length; if a
+            number greater than 0, all words are padded or truncated to that length.
         """
         super().__init__(pad_val=pad_val)

@@ -494,7 +472,7 @@ class EngChar2DPadder(Padder):
         except:
             raise ValueError("Field:{} only has two dimensions.".format(field_name))

-        if is_iterable(value):
+        if _is_iterable(value):
             raise ValueError("Field:{} has more than 3 dimension.".format(field_name))

     def __call__(self, contents, field_name, field_ele_dtype):


fastNLP/core/instance.py (+14, -15)

@@ -3,34 +3,33 @@ Instance documentation

 .. _Instance:

-Test
+Instance is the class in fastNLP that corresponds to one sample. A sample can be regarded as one Instance object in
+fastNLP; a concrete picture of it is one row of the table shown at DataSet_.

 """


 class Instance(object):
-    """An Instance is an example of data.
-    Example::
-        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
-        ins["field_1"]
-        >>[1, 1, 1]
-        ins.add_field("field_3", [3, 3, 3])
-    """
-
     def __init__(self, **fields):
-        """
-
-        :param fields: may be a one- or two-dimensional list or np.array
+        """The initialization of an Instance is as shown in the Example below

+        Example::
+
+            ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
+            ins["field_1"]
+            >>[1, 1, 1]
+            ins.add_field("field_3", [3, 3, 3])
+
+            ins = Instance(**{'x1': 1, 'x2': np.zeros((3, 4))})
         """
         self.fields = fields

     def add_field(self, field_name, field):
-        """Add a new field to the instance.
+        """Add a field to the Instance

-        :param field_name: str, the name of the field.
+        :param str field_name: the name of the new field
+        :param Any field: the content of the new field
         """
         self.fields[field_name] = field




fastNLP/core/losses.py (+15, -15)

@@ -12,12 +12,12 @@ from collections import defaultdict
 import torch
 import torch.nn.functional as F

-from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import CheckRes
+from fastNLP.core.utils import _CheckError
+from fastNLP.core.utils import _CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _check_function_or_method
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature


 class LossBase(object):
@@ -70,7 +70,7 @@ class LossBase(object):
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
                 raise NameError(
-                    f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the "
+                    f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the "
                     f"initialization parameters, or change its signature.")

         # evaluate should not have varargs.
@@ -111,7 +111,7 @@ class LossBase(object):
         func_args = set([arg for arg in func_spect.args if arg != 'self'])
         for func_arg, input_arg in self.param_map.items():
             if func_arg not in func_args:
-                raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.")
+                raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")

         # 2. only part of the param_map are passed, left are not
         for arg in func_args:
@@ -151,16 +151,16 @@ class LossBase(object):
             replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                     f"in `{self.__class__.__name__}`)"

-        check_res = CheckRes(missing=replaced_missing,
-                             unused=check_res.unused,
-                             duplicated=duplicated,
-                             required=check_res.required,
-                             all_needed=check_res.all_needed,
-                             varargs=check_res.varargs)
+        check_res = _CheckRes(missing=replaced_missing,
+                              unused=check_res.unused,
+                              duplicated=duplicated,
+                              required=check_res.required,
+                              all_needed=check_res.all_needed,
+                              varargs=check_res.varargs)

         if check_res.missing or check_res.duplicated:
-            raise CheckError(check_res=check_res,
-                             func_signature=get_func_signature(self.get_loss))
+            raise _CheckError(check_res=check_res,
+                              func_signature=_get_func_signature(self.get_loss))
         refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)

         loss = self.get_loss(**refined_args)
@@ -289,14 +289,14 @@ class LossInForward(LossBase):

     def get_loss(self, **kwargs):
         if self.loss_key not in kwargs:
-            check_res = CheckRes(
+            check_res = _CheckRes(
                 missing=[self.loss_key + f"(assign to `{self.loss_key}` in `{self.__class__.__name__}`"],
                 unused=[],
                 duplicated=[],
                 required=[],
                 all_needed=[],
                 varargs=[])
-            raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
+            raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss))
         return kwargs[self.loss_key]

     def __call__(self, pred_dict, target_dict, check=False):


fastNLP/core/metrics.py (+25, -25)

@@ -13,11 +13,11 @@ from collections import defaultdict
 import numpy as np
 import torch

-from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import CheckRes
+from fastNLP.core.utils import _CheckError
+from fastNLP.core.utils import _CheckRes
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 from fastNLP.core.utils import seq_lens_to_masks
 from fastNLP.core.vocabulary import Vocabulary

@@ -161,7 +161,7 @@ class MetricBase(object):
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
                 raise NameError(
-                    f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
+                    f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
                     f"initialization parameters, or change its signature.")

     def _fast_param_map(self, pred_dict, target_dict):
@@ -207,7 +207,7 @@ class MetricBase(object):
         func_args = set([arg for arg in func_spect.args if arg != 'self'])
         for func_arg, input_arg in self.param_map.items():
             if func_arg not in func_args:
-                raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")
+                raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")

         # 2. only part of the param_map are passed, left are not
         for arg in func_args:
@@ -248,16 +248,16 @@ class MetricBase(object):
             replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                                     f"in `{self.__class__.__name__}`)"

-        check_res = CheckRes(missing=replaced_missing,
-                             unused=check_res.unused,
-                             duplicated=duplicated,
-                             required=check_res.required,
-                             all_needed=check_res.all_needed,
-                             varargs=check_res.varargs)
+        check_res = _CheckRes(missing=replaced_missing,
+                              unused=check_res.unused,
+                              duplicated=duplicated,
+                              required=check_res.required,
+                              all_needed=check_res.all_needed,
+                              varargs=check_res.varargs)

         if check_res.missing or check_res.duplicated:
-            raise CheckError(check_res=check_res,
-                             func_signature=get_func_signature(self.evaluate))
+            raise _CheckError(check_res=check_res,
+                              func_signature=_get_func_signature(self.evaluate))
         refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)

         self.evaluate(**refined_args)
@@ -294,14 +294,14 @@ class AccuracyMetric(MetricBase):
         """
         # TODO the error message here needs changing: the user does not know what pred is, so the actual value should be reported
         if not isinstance(pred, torch.Tensor):
-            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
         if not isinstance(target, torch.Tensor):
-            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")

         if seq_len is not None and not isinstance(seq_len, torch.Tensor):
-            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")

         if seq_len is not None:
@@ -314,7 +314,7 @@ class AccuracyMetric(MetricBase):
         elif len(pred.size()) == len(target.size()) + 1:
             pred = pred.argmax(dim=-1)
         else:
-            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")

@@ -516,14 +516,14 @@ class SpanFPreRecMetric(MetricBase):
         :return:
         """
         if not isinstance(pred, torch.Tensor):
-            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
         if not isinstance(target, torch.Tensor):
-            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")

         if not isinstance(seq_len, torch.Tensor):
-            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_len)}.")

         if pred.size() == target.size() and len(target.size()) == 2:
@@ -535,7 +535,7 @@ class SpanFPreRecMetric(MetricBase):
             raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
                              "id >= {}, the number of classes.".format(num_classes))
         else:
-            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")

@@ -714,14 +714,14 @@ class BMESF1PreRecMetric(MetricBase):
         :return:
         """
         if not isinstance(pred, torch.Tensor):
-            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(pred)}.")
         if not isinstance(target, torch.Tensor):
-            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")

         if not isinstance(seq_len, torch.Tensor):
-            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+            raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_len)}.")

         if pred.size() == target.size() and len(target.size()) == 2:
@@ -729,7 +729,7 @@ class BMESF1PreRecMetric(MetricBase):
         elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
             pred = pred.argmax(dim=-1)
         else:
-            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")




fastNLP/core/tester.py (+6, -6)

@@ -5,11 +5,11 @@ from fastNLP.core.batch import Batch
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.sampler import SequentialSampler
-from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 from fastNLP.core.utils import _get_device


@@ -75,19 +75,19 @@ class Tester(object):
                     _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                     pred_dict = self._data_forward(self._predict_func, batch_x)
                     if not isinstance(pred_dict, dict):
-                        raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
+                        raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
                                         f"must be `dict`, got {type(pred_dict)}.")
                     for metric in self.metrics:
                         metric(pred_dict, batch_y)
                 for metric in self.metrics:
                     eval_result = metric.get_metric()
                     if not isinstance(eval_result, dict):
-                        raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
+                        raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
                                         f"`dict`, got {type(eval_result)}")
                     metric_name = metric.__class__.__name__
                     eval_results[metric_name] = eval_result
-        except CheckError as e:
-            prev_func_signature = get_func_signature(self._predict_func)
+        except _CheckError as e:
+            prev_func_signature = _get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
                                  check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                  dataset=self.data, check_level=0)


fastNLP/core/trainer.py (+34, -57)

@@ -85,17 +85,17 @@ In fastNLP, Trainer organizes the training process of a single task, saving the user from
 the corresponding value; for example, here CrossEntropyLoss_ will try to find the content named 'label' to use as the ground truth for the loss; and with pred=None, CrossEntropyLoss_
 matches the prediction by the name 'pred', which happens to be the name of forward's return value too, so pred does not need to be declared here.

-Although fastNLP uses a mapping mechanism to make the loss computation flexible, in some cases the loss must be computed inside the model, for example in models using a CRF. For this, fastNLP provides
-the LossInForward_ loss. This loss works by looking in the result returned by forward() for the tensor specified by loss_key (by default it looks for 'loss'),
-and using it as the loss. It is used if no loss is provided at Trainer initialization. TODO add an example
+Although fastNLP uses a mapping mechanism to make the loss computation flexible, in some cases the loss must be computed inside the model, for example in models using a CRF. For this, fastNLP provides the LossInForward_
+loss. This loss works by looking in the result returned by forward() for the tensor specified by loss_key (by default it looks for 'loss'),
+and using it as the loss. If no loss is provided at Trainer initialization, LossInForward_ is used by default. For a detailed example see TODO add an example
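The example is left as a TODO above; a hedged sketch of what using LossInForward_ can look like (the model and field names are made up):

Example::

    import torch
    from torch import nn
    from fastNLP.core.losses import LossInForward

    class CRFLikeModel(nn.Module):
        # illustrative model that computes its own loss inside forward()
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(1, 1)
        def forward(self, x, y):
            # the loss travels in the return dict under the default key 'loss'
            return {'loss': torch.mean((self.fc(x) - y) ** 2)}

    # with loss unset (or loss=LossInForward()), the Trainer picks the
    # 'loss' tensor out of forward()'s return dict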


 3. Metric

 Metric_ uses the same strategy as the Loss above, i.e. matching by name. AccuracyMetric(target='label') works the same way as CrossEntropyLoss.

 During validation, the computation needed may differ from forward(), and the prediction may not be obtainable directly from forward()'s result; in that case the model can provide a predict() method;
-if the model provides a predict method, predict() is called during validation to obtain the predictions, and the arguments passed to predict() are also selected from the DataSet's input;
-as with forward(), the return value must be a dict. For a concrete example see TODO add an example
+if the model provides a predict method, predict() is called during validation to obtain the predictions, and the arguments passed to predict() are also selected from the fields
+set as input in the DataSet; as with forward(), the return value must be a dict. For a concrete example see TODO add an example
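The predict() convention is likewise left as a TODO; a hedged sketch (the names and sizes are illustrative):

Example::

    import torch.nn.functional as F
    from torch import nn

    class ModelWithPredict(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)
        def forward(self, x, target):
            return {'loss': F.cross_entropy(self.fc(x), target)}
        def predict(self, x):
            # called instead of forward() during validation; must return a dict
            return {'pred': self.fc(x).argmax(dim=-1)}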


 2. Trainer's code check

@@ -112,12 +112,12 @@ In fastNLP, Trainer organizes the training process of a single task, saving the user from
     from fastNLP import DataSet

     class Model(nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.fc = nn.Linear(1, 1)
-    def forward(self, x, b):
-        loss = torch.mean((self.fc(x)-b)**2)
-        return {'loss': loss}
+        def __init__(self):
+            super().__init__()
+            self.fc = nn.Linear(1, 1)
+        def forward(self, x, b):
+            loss = torch.mean((self.fc(x)-b)**2)
+            return {'loss': loss}
     model = Model()

     dataset = DataSet({'a': np.arange(10), 'b': np.arange(10)*2})
@@ -138,16 +138,15 @@ In fastNLP, Trainer organizes the training process of a single task, saving the user from
     # unused field: ['a']
     # Suggestion: You need to provide ['x'] in DataSet and set it as input.

 This happens because, at Trainer initialization, fastNLP tries to run forward() and backward() once with a batch of batch_size=2. Two kinds of information
 here can serve as a reference

 1. 'input fields after batch...' shows each field's type and shape for the train dataset after the Batch operation. Here,
-because the train dataset has no target, nothing is shown for it. From this you can check whether the needed content was correctly set as input or target.
+   because the train dataset has no target, nothing is shown for it. From this you can check whether the needed content was correctly set as input or target.

-2. If a mapping error occurs, a NameError is raised. The reason for the error here is that when the forward computation was attempted (from Model.forward(self, x, b) you can tell
-   that the failure happened while calling forward), the 'x' required by forward() was not found; the message also points out that 'x' is missing while 'a' is unused, so the
-   problem is probably the field's name. Renaming the field 'a' in the dataset to 'x'
-   , or renaming the model's parameter from 'x' to 'a', both solve the problem.
+2. NameError: a NameError occurs when the mapping goes wrong. The reason for the error here is that when the forward computation was attempted (from Model.forward(self, x, b) you can tell
+   that forward was being called), the 'x' required by forward() was not found; the message also points out that 'x' is missing while 'a' is unused, so the
+   problem is probably the field's name. Renaming the field 'a' in the dataset to 'x', or renaming the model's parameter from 'x' to 'a', both solve the problem.

 The following example fails because the value needed for the loss computation cannot be found


@@ -249,18 +248,18 @@ In fastNLP, Trainer organizes the training process of a single task, saving the user from
     # (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).

 The error message is similar to the previous ones, but from 'AccuracyMetric.evaluate(self, pred, target, seq_len=None)' you can tell that the error happened during
-evaluation. This avoids having to finish a whole epoch of training before discovering that the evaluation is broken. The fix here is to state, when initializing the metric,
-that `pred` is obtained via 'output', i.e. AccuracyMetric(pred='output').
+evaluation. This avoids having to finish a whole epoch of training before discovering that the evaluation is broken. The fix here is to state, when initializing the metric,
+that `pred` is obtained via 'output', i.e. AccuracyMetric(pred='output')

 The strictness of the check can be adjusted via check_code_level. The default is 0, i.e. the check is performed.

 3. Trainer and callback

 Although the Trainer itself already integrates some functionality, it is still not enough to cover everything that may be needed during training, such as negative sampling, learning rate decay, and Early Stop.
-To solve this problem, fastNLP introduces the callback mechanism: a Callback_ is something that runs at specific stages of the Trainer's training process, and every Callback_ has
+To solve this problem, fastNLP introduces the callback mechanism: a Callback_ is a collection of functions that run at specific stages of the Trainer's training process, and every Callback_ has
 on_* functions (such as on_train_start, on_backward_begin). If a Callback implements such a function, the Trainer calls it when it reaches the corresponding stage.

-We divide the inside of Train.train() into the following stages
+We divide the inside of Train.train() into the following stages; at each stage the corresponding calls are triggered.

 Example::


@@ -305,7 +304,7 @@ import warnings
 try:
     from tqdm.autonotebook import tqdm
 except:
-    from fastNLP.core.utils import pseudo_tqdm as tqdm
+    from fastNLP.core.utils import _pseudo_tqdm as tqdm

 from fastNLP.core.batch import Batch
 from fastNLP.core.callback import CallbackManager, CallbackException
@@ -316,12 +315,12 @@ from fastNLP.core.sampler import Sampler
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
-from fastNLP.core.utils import get_func_signature
+from fastNLP.core.utils import _get_func_signature
 from fastNLP.core.utils import _get_device




@@ -466,34 +465,11 @@ class Trainer(object):

     def train(self, load_best_model=True):
         """
-        Start the training process. The main steps are::
-
-        for epoch in range(num_epochs):
-            # use Batch to fetch data from the DataSet batch by batch, automatically padding the fields whose dtype
-              is (float, int) and converting them to Tensor. Arguments that are not of float or int type are neither
-              converted to Tensor nor padded.
-            for batch_x, batch_y in Batch(DataSet)
-                # batch_x is a dict; the fields set as input appear in this dict,
-                  with the DataSet's field_name as key and the field's value as value
-                # batch_y is also a dict; the fields set as target appear in this dict,
-                  with the DataSet's field_name as key and the field's value as value
-            2. feed the data in batch_x to model.forward and collect the result; arguments are passed by matching
-               the keys in batch_x against the parameters of forward. For example,
-               forward(self, x, seq_lens)  # fastNLP finds the value whose key is "x" in batch_x and passes it to x,
-               and the value whose key is "seq_lens" to seq_lens. If any required argument is not found in batch_x,
-               an error is raised. If forward has a default argument whose key is not in batch_x, the default is used.
-            3. feed batch_y and the result of model.forward into the loss. Loss computation generally involves pred
-               and target, but in different settings pred may be called output or prediction, and target may be
-               called y or label. fastNLP finds pred and target through the mapping passed when the loss was
-               initialized; e.g. with loss=CrossEntropyLoss(pred='output', target='y') at Trainer initialization,
-               fastNLP uses "output" to find pred in batch_y and forward's result, and "y" to find target,
-               and then computes the loss.
-            4. after the loss is obtained, backpropagate and update the gradients,
-               run validation when appropriate,
-               evaluate according to the metrics, and decide whether to save the model depending on whether save_path was provided
+        Calling this function starts the Trainer's training.

         :param bool load_best_model: effective only when dev_data was provided at initialization; if True, the trainer
             reloads the model parameters that performed best on dev before returning.
-        :return results: returns data of dict type,
+        :return dict: returns data of dict type,
             containing the following::

                 seconds: float, the training duration
@@ -547,7 +523,7 @@ class Trainer(object):

     def _train(self):
         if not self.use_tqdm:
-            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
+            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
         else:
             inner_tqdm = tqdm
         self.step = 0
@@ -559,6 +535,7 @@ class Trainer(object):
             avg_loss = 0
             data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
                                   prefetch=self.prefetch)
+            self.batch_per_epoch = data_iterator.num_batches
             for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch
                 pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -660,7 +637,7 @@ class Trainer(object):
         x = _build_args(network.forward, **x)
         y = network(**x)
         if not isinstance(y, dict):
-            raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
+            raise TypeError(f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y

     def _grad_backward(self, loss):
@@ -796,7 +773,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_

             refined_batch_x = _build_args(model.forward, **batch_x)
             pred_dict = model(**refined_batch_x)
-            func_signature = get_func_signature(model.forward)
+            func_signature = _get_func_signature(model.forward)
             if not isinstance(pred_dict, dict):
                 raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")

@@ -807,16 +784,16 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
             if batch_count == 0:
                 if not isinstance(loss, torch.Tensor):
                     raise TypeError(
-                        f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
+                        f"The return value of {_get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
                         f"but got `{type(loss)}`.")
                 if len(loss.size()) != 0:
                     raise ValueError(
-                        f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
+                        f"The size of return value of {_get_func_signature(losser.get_loss)} is {loss.size()}, "
                         f"should be torch.size([])")
                 loss.backward()
-        except CheckError as e:
-            # TODO: another error raised if CheckError caught
-            pre_func_signature = get_func_signature(model.forward)
+        except _CheckError as e:
+            # TODO: another error raised if _CheckError caught
+            pre_func_signature = _get_func_signature(model.forward)
             _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
                                  check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
                                  dataset=dataset, check_level=check_level)


fastNLP/core/utils.py (+107, -56)

@@ -9,7 +9,7 @@ import numpy as np
 import torch
 from torch import nn

-CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
-                                   'varargs'])
+_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
+                                     'varargs'])

 def _prepare_cache_filepath(filepath):
@@ -28,6 +28,57 @@ def _prepare_cache_filepath(filepath):

 # TODO the parameters used at caching time could be saved; if they differ at load time, issue a warning.
 def cache_results(_cache_fp, _refresh=False, _verbose=1):
+    """
+    cache_results is fastNLP's decorator for caching data. The following example shows how to use it
+
+    Example::
+
+        import time
+        import numpy as np
+        from fastNLP import cache_results
+
+        @cache_results('cache.pkl')
+        def process_data():
+            # some time-consuming work such as reading and preprocessing data; time.sleep() stands in for it here
+            time.sleep(1)
+            return np.random.randint(5, size=(10, 20))
+
+        start_time = time.time()
+        process_data()
+        print(time.time() - start_time)
+
+        start_time = time.time()
+        process_data()
+        print(time.time() - start_time)
+
+        # the output is as follows
+        # Save cache to cache.pkl.
+        # 1.0015439987182617
+        # Read cache from cache.pkl.
+        # 0.00013065338134765625
+
+    As you can see, the second run takes only about 0.0001s, because it reads the data directly from cache.pkl
+    instead of preprocessing again
+
+    Example::
+
+        # continuing the example above: to generate another cache, e.g. for another dataset, call it as follows
+        process_data(_cache_fp='cache2.pkl')  # does not affect the previous 'cache.pkl' at all
+
+    The _cache_fp above is a parameter recognized by cache_results; it caches/reads the data at 'cache2.pkl', i.e.
+    'cache2.pkl' here overrides the default 'cache.pkl'. If you put @cache_results() in front of your function, the
+    function gains three parameters [_cache_fp, _refresh, _verbose]. The example above shows the use of _cache_fp;
+    these three parameters are not passed into your function, and of course your function's own parameter names may
+    not include these three names.
+
+    Example::
+
+        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force regeneration of the preprocessing cache here.
+        # _verbose controls the output: if 0, nothing is printed; if 1, it reports whether the current step read the
+        # cache or generated a new one
+
+    :param str _cache_fp: where to cache the return value, or where to read the cache from. If None, cache_results
+        has no effect, unless _cache_fp is passed when the function is called.
+    :param bool _refresh: whether to regenerate the cache.
+    :param int _verbose: whether to print cache information.
+    :return:
+    """
     def wrapper_(func):
         signature = inspect.signature(func)
         for key, _ in signature.parameters.items():
@@ -74,48 +125,48 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
         return wrapper
     return wrapper_

-def save_pickle(obj, pickle_path, file_name):
-    """Save an object into a pickle file.
-
-    :param obj: an object
-    :param pickle_path: str, the directory where the pickle file is to be saved
-    :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
-    """
-    if not os.path.exists(pickle_path):
-        os.mkdir(pickle_path)
-        print("make dir {} before saving pickle file".format(pickle_path))
-    with open(os.path.join(pickle_path, file_name), "wb") as f:
-        _pickle.dump(obj, f)
-    print("{} saved in {}".format(file_name, pickle_path))
-
-def load_pickle(pickle_path, file_name):
-    """Load an object from a given pickle file.
-
-    :param pickle_path: str, the directory where the pickle file is.
-    :param file_name: str, the name of the pickle file.
-    :return obj: an object stored in the pickle
-    """
-    with open(os.path.join(pickle_path, file_name), "rb") as f:
-        obj = _pickle.load(f)
-    print("{} loaded from {}".format(file_name, pickle_path))
-    return obj
-
-def pickle_exist(pickle_path, pickle_name):
-    """Check if a given pickle file exists in the directory.
-
-    :param pickle_path: the directory of target pickle file
-    :param pickle_name: the filename of target pickle file
-    :return: True if file exists else False
-    """
-    if not os.path.exists(pickle_path):
-        os.makedirs(pickle_path)
-    file_name = os.path.join(pickle_path, pickle_name)
-    if os.path.exists(file_name):
-        return True
-    else:
-        return False
+# def save_pickle(obj, pickle_path, file_name):
+#     """Save an object into a pickle file.
+#
+#     :param obj: an object
+#     :param pickle_path: str, the directory where the pickle file is to be saved
+#     :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
+#     """
+#     if not os.path.exists(pickle_path):
+#         os.mkdir(pickle_path)
+#         print("make dir {} before saving pickle file".format(pickle_path))
+#     with open(os.path.join(pickle_path, file_name), "wb") as f:
+#         _pickle.dump(obj, f)
+#     print("{} saved in {}".format(file_name, pickle_path))
+#
+#
+# def load_pickle(pickle_path, file_name):
+#     """Load an object from a given pickle file.
+#
+#     :param pickle_path: str, the directory where the pickle file is.
+#     :param file_name: str, the name of the pickle file.
+#     :return obj: an object stored in the pickle
+#     """
+#     with open(os.path.join(pickle_path, file_name), "rb") as f:
+#         obj = _pickle.load(f)
+#     print("{} loaded from {}".format(file_name, pickle_path))
+#     return obj
+#
+#
+# def pickle_exist(pickle_path, pickle_name):
+#     """Check if a given pickle file exists in the directory.
+#
+#     :param pickle_path: the directory of target pickle file
+#     :param pickle_name: the filename of target pickle file
+#     :return: True if file exists else False
+#     """
+#     if not os.path.exists(pickle_path):
+#         os.makedirs(pickle_path)
+#     file_name = os.path.join(pickle_path, pickle_name)
+#     if os.path.exists(file_name):
+#         return True
+#     else:
+#         return False


def _get_device(device, check_exist=False):
    """
@@ -232,15 +283,15 @@ def _check_arg_dict_list(func, args):
    missing = list(require_args - input_args)
    unused = list(input_args - all_args)
    varargs = [] if not spect.varargs else [spect.varargs]
    return CheckRes(missing=missing,
                    unused=unused,
                    duplicated=duplicated,
                    required=list(require_args),
                    all_needed=list(all_args),
                    varargs=varargs)
    return _CheckRes(missing=missing,
                     unused=unused,
                     duplicated=duplicated,
                     required=list(require_args),
                     all_needed=list(all_args),
                     varargs=varargs)
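_CheckRes itself is not defined in this diff. Judging from the keyword arguments used above, it is presumably a lightweight record type, e.g. a namedtuple; the definition below is an inference, not the actual fastNLP source::

    from collections import namedtuple

    _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated',
                                         'required', 'all_needed', 'varargs'])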




def get_func_signature(func):
def _get_func_signature(func):
    """


    Given a function or method, return its signature.
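The implementation is cut off here; a minimal sketch of such a helper using the standard inspect module (the exact output format of fastNLP's version may differ)::

    import inspect

    def _get_func_signature_sketch(func):
        # render e.g. "forward(words, seq_len=None)" from a function or method
        return '{}{}'.format(func.__name__, inspect.signature(func))

    def forward(words, seq_len=None):
        return words

    print(_get_func_signature_sketch(forward))  # forward(words, seq_len=None)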
@@ -318,13 +369,13 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
        raise TypeError("Only support `dict` type right now.")




class CheckError(Exception):
class _CheckError(Exception):
    """


    CheckError. Used in losses.LossBase, metrics.MetricBase.
    _CheckError. Used in losses.LossBase, metrics.MetricBase.
    """


    def __init__(self, check_res: CheckRes, func_signature: str):
    def __init__(self, check_res: _CheckRes, func_signature: str):
        errs = [f'Problems occurred when calling `{func_signature}`']


        if check_res.varargs:
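For orientation, a self-contained sketch of how a _CheckError-style exception is built and consumed; the message layout and field choices below are assumptions::

    from collections import namedtuple

    _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated',
                                         'required', 'all_needed', 'varargs'])

    class _CheckErrorSketch(Exception):
        """Carries the check result plus the offending signature."""
        def __init__(self, check_res, func_signature):
            errs = ['Problems occurred when calling `{}`'.format(func_signature)]
            if check_res.missing:
                errs.append('missing param: {}'.format(check_res.missing))
            if check_res.unused:
                errs.append('unused param: {}'.format(check_res.unused))
            super().__init__('\n'.join(errs))
            self.check_res = check_res
            self.func_signature = func_signature

    res = _CheckRes(missing=['seq_len'], unused=['chars'], duplicated=[],
                    required=['words', 'seq_len'], all_needed=['words', 'seq_len'],
                    varargs=[])
    try:
        raise _CheckErrorSketch(res, 'forward(self, words, seq_len)')
    except _CheckErrorSketch as e:
        print(e)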
@@ -347,7 +398,7 @@ WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2




def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: _CheckRes,
                         pred_dict: dict, target_dict: dict, dataset, check_level=0):
    errs = []
    unuseds = []
@@ -449,7 +500,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re


def _check_forward_error(forward_func, batch_x, dataset, check_level):
    check_res = _check_arg_dict_list(forward_func, batch_x)
    func_signature = get_func_signature(forward_func)
    func_signature = _get_func_signature(forward_func)


    errs = []
    suggestions = []
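The gist of this forward check, reduced to a runnable sketch: compare the forward signature against the keys that batch_x actually provides (the function and field names are illustrative)::

    import inspect

    def check_forward_args_sketch(forward_func, batch_x):
        sig = inspect.signature(forward_func)
        params = {name: p for name, p in sig.parameters.items() if name != 'self'}
        required = [name for name, p in params.items()
                    if p.default is inspect.Parameter.empty]
        missing = [name for name in required if name not in batch_x]
        unused = [key for key in batch_x if key not in params]
        return missing, unused

    def forward(words, seq_len):
        return words

    print(check_forward_args_sketch(forward, {'words': [1, 2], 'chars': [3]}))
    # (['seq_len'], ['chars'])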
@@ -543,7 +594,7 @@ def seq_mask(seq_len, max_len):
    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
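For context, the seq_mask idiom this hunk touches, expanded into a runnable sketch; the two lines before the return are assumptions, since only the comparison itself appears in the diff::

    import torch

    def seq_mask_sketch(seq_len, max_len):
        seq_len = seq_len.unsqueeze(-1)                        # [batch_size, 1]
        seq_range = torch.arange(max_len).long().unsqueeze(0)  # [1, max_len]
        return torch.gt(seq_len, seq_range)                    # [batch_size, max_len]

    print(seq_mask_sketch(torch.LongTensor([2, 3]), max_len=4))
    # positions before each sequence's length are True, the rest False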




class pseudo_tqdm:
class _pseudo_tqdm:
    """
    Used to print progress information when tqdm cannot be imported, or when use_tqdm is set to False in the Trainer.
    """


+ 13
- 13
fastNLP/core/vocabulary.py View File

@@ -2,7 +2,7 @@ from functools import wraps
from collections import Counter
from fastNLP.core.dataset import DataSet


def check_build_vocab(func):
def _check_build_vocab(func):
    """A decorator to make sure the indexing is built before used.


    """
@@ -15,7 +15,7 @@ def check_build_vocab(func):
    return _wrapper




def check_build_status(func):
def _check_build_status(func):
    """A decorator to check whether the vocabulary updates after the last build.


    """
@@ -67,7 +67,7 @@ class Vocabulary(object):
        self.idx2word = None
        self.rebuild = True


    @check_build_status
    @_check_build_status
    def update(self, word_lst):
        """Increment the occurrence count of each word in the sequence, in order.


@@ -75,7 +75,7 @@ class Vocabulary(object):
        """
        self.word_count.update(word_lst)


    @check_build_status
    @_check_build_status
    def add(self, word):
        """
        Increment the occurrence count of a word in the vocabulary.
@@ -84,7 +84,7 @@ class Vocabulary(object):
        """
        self.word_count[word] += 1


    @check_build_status
    @_check_build_status
    def add_word(self, word):
        """
        Increment the occurrence count of a word in the vocabulary.
@@ -93,7 +93,7 @@ class Vocabulary(object):
        """
        self.add(word)


    @check_build_status
    @_check_build_status
    def add_word_lst(self, word_lst):
        """
        Increment the occurrence count of each word in the sequence, in order.
@@ -132,11 +132,11 @@ class Vocabulary(object):
        """
        self.idx2word = {i: w for w, i in self.word2idx.items()}


    @check_build_vocab
    @_check_build_vocab
    def __len__(self):
        return len(self.word2idx)


    @check_build_vocab
    @_check_build_vocab
    def __contains__(self, item):
        """
        Check whether a word has been recorded.
@@ -161,7 +161,7 @@ class Vocabulary(object):
        """
        return self.__contains__(w)


    @check_build_vocab
    @_check_build_vocab
    def __getitem__(self, w):
        """
        To support usage like::
@@ -175,7 +175,7 @@ class Vocabulary(object):
        else:
            raise ValueError("word {} not in vocabulary".format(w))


    @check_build_vocab
    @_check_build_vocab
    def index_dataset(self, *datasets, field_name, new_field_name=None):
        """
        Convert the words in the given field of the DataSet(s) to indices.
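A short usage sketch for index_dataset; the dataset construction and the 'words' field name are illustrative, not taken from this diff::

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance
    from fastNLP.core.vocabulary import Vocabulary

    dataset = DataSet()
    dataset.append(Instance(words=['this', 'is', 'a', 'test']))

    vocab = Vocabulary()
    vocab.update(['this', 'is', 'a', 'test'])
    # write the indexed version of the 'words' field into a new 'word_ids' field
    vocab.index_dataset(dataset, field_name='words', new_field_name='word_ids')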
@@ -275,7 +275,7 @@ class Vocabulary(object):
        return self.__getitem__(w)


    @property
    @check_build_vocab
    @_check_build_vocab
    def unknown_idx(self):
        """
        The index corresponding to unknown.
@@ -285,7 +285,7 @@ class Vocabulary(object):
        return self.word2idx[self.unknown]


    @property
    @check_build_vocab
    @_check_build_vocab
    def padding_idx(self):
        """
        The index corresponding to padding.
@@ -294,7 +294,7 @@ class Vocabulary(object):
            return None
        return self.word2idx[self.padding]


    @check_build_vocab
    @_check_build_vocab
    def to_word(self, idx):
        """
        Given an index, convert it to the corresponding word.
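Putting the renamed pieces together, a round-trip usage sketch (the word list is illustrative)::

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.add_word_lst(['apple', 'banana', 'apple'])

    idx = vocab['apple']                  # __getitem__, guarded by _check_build_vocab
    assert vocab.to_word(idx) == 'apple'  # the inverse mapping shown above
    print(len(vocab), 'apple' in vocab)   # length and membership also trigger the build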


+ 3
- 3
fastNLP/models/enas_trainer.py View File

@@ -13,12 +13,12 @@ from torch import nn
try:
    from tqdm.autonotebook import tqdm
except:
    from fastNLP.core.utils import pseudo_tqdm as tqdm
    from fastNLP.core.utils import _pseudo_tqdm as tqdm


from fastNLP.core.batch import Batch
from fastNLP.core.callback import CallbackManager, CallbackException
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _CheckError
from fastNLP.core.utils import _move_dict_value_to_device
import fastNLP
import fastNLP.models.enas_utils as utils
@@ -118,7 +118,7 @@ class ENASTrainer(fastNLP.Trainer):


    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm
        self.step = 0
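The import-fallback pattern used in both hunks above, isolated as a runnable sketch: prefer the real tqdm, otherwise fall back to a stand-in, and let a runtime flag choose between them (the stand-in class here is an assumption)::

    try:
        from tqdm.autonotebook import tqdm as real_tqdm
        HAS_TQDM = True
    except ImportError:
        HAS_TQDM = False

    class PseudoTqdm:
        """Minimal stand-in exposing the two methods used below."""
        def update(self, n=1):
            pass
        def write(self, s):
            print(s)

    def get_progress_bar(use_tqdm):
        # honor the flag only when tqdm is actually importable
        if use_tqdm and HAS_TQDM:
            return real_tqdm(total=100)
        return PseudoTqdm()

    bar = get_progress_bar(use_tqdm=False)
    bar.write("step 10, loss 0.42")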


+ 1
- 1
test/core/test_tester.py View File

@@ -8,7 +8,7 @@ import numpy as np
import torch.nn.functional as F
from torch import nn
import time
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _CheckError
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss

