Browse Source

增加对Trainer的注释

tags/v0.4.10
yh 5 years ago
parent
commit
4a57011315
3 changed files with 261 additions and 39 deletions
  1. +24
    -4
      fastNLP/core/metrics.py
  2. +215
    -12
      fastNLP/core/trainer.py
  3. +22
    -23
      fastNLP/core/utils.py

+ 24
- 4
fastNLP/core/metrics.py View File

@@ -119,6 +119,9 @@ class MetricBase(object):
def evaluate(self, *args, **kwargs): def evaluate(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError


def get_metric(self, reset=True):
raise NotImplemented

def _init_param_map(self, key_map=None, **kwargs): def _init_param_map(self, key_map=None, **kwargs):
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map """检查key_map和其他参数map,并将这些映射关系添加到self.param_map


@@ -161,8 +164,20 @@ class MetricBase(object):
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change its signature.") f"initialization parameters, or change its signature.")


def get_metric(self, reset=True):
raise NotImplemented
def _fast_param_map(self, pred_dict, target_dict):
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
such as pred_dict has one element, target_dict has one element

:param pred_dict:
:param target_dict:
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
return fast_param


def __call__(self, pred_dict, target_dict): def __call__(self, pred_dict, target_dict):
""" """
@@ -178,10 +193,15 @@ class MetricBase(object):
:param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容)
:return: :return:
""" """
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

fast_param = self._fast_param_map(pred_dict, target_dict)
if fast_param:
self.evaluate(**fast_param)
return


if not self._checked: if not self._checked:
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
# 1. check consistence between signature and param_map # 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate) func_spect = inspect.getfullargspec(self.evaluate)
func_args = set([arg for arg in func_spect.args if arg != 'self']) func_args = set([arg for arg in func_spect.args if arg != 'self'])


+ 215
- 12
fastNLP/core/trainer.py View File

@@ -7,6 +7,7 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在
对Batch进行pad; (4) 每个epoch结束或一定step后进行验证集验证; (5) 保存获得更好验证性能的模型等。 对Batch进行pad; (4) 每个epoch结束或一定step后进行验证集验证; (5) 保存获得更好验证性能的模型等。


1. Trainer的基本使用 1. Trainer的基本使用

下面的例子是使用神经网络来进行预测一个序列中是否有偶数个1。 下面的例子是使用神经网络来进行预测一个序列中是否有偶数个1。


Example:: Example::
@@ -53,20 +54,23 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在
trainer.train() trainer.train()


由上面的例子可以看出通过使用Trainer,可以使得训练部分的代码大幅减少。 由上面的例子可以看出通过使用Trainer,可以使得训练部分的代码大幅减少。
使用Trainer需要满足以下几个条件
使用Trainer需要满足以下几个条件:


1. 模型 1. 模型


1. 模型的forward()的参数名需要与DataSet中的名字对应。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是 1. 模型的forward()的参数名需要与DataSet中的名字对应。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是
通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该 通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该
改名为'data'。 改名为'data'。

2. 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递 2. 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递
给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。 给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。

3. 模型的forward()返回值需要为一个dict。 3. 模型的forward()返回值需要为一个dict。


2. Loss与Metric
fastNLP中的为了不限制forward函数的返回内容数量,以及对多Metric等的支持等, Loss_ 与 Metric_ 都使用了通过名称来匹配相
应内容。如上面的例子中
2. Loss

fastNLP中的为了不限制forward函数的返回内容数量(比如一些复杂任务需要返回多个内容,如Dependency Parsing, Loss_ 与 Metric_ 都使
用了通过名称来匹配相应内容的策略。如上面的例子中


Example:: Example::


@@ -74,17 +78,216 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在
optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000, optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
dev_data = dev_data, metrics=AccuracyMetric(target='label')) dev_data = dev_data, metrics=AccuracyMetric(target='label'))


loss被设置为了 CrossEntropyLoss_ , 但在初始化的时候传入了一个target='label'这个参数, CrossEntropyLoss_ 的初始化
loss被设置为了 CrossEntropyLoss_ , 但在初始化的时候传入了target='label'这个参数, CrossEntropyLoss_ 的初始化
参数为(pred=None, target=None, padding_idx=-100)。这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值 参数为(pred=None, target=None, padding_idx=-100)。这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值
与ground truth的label。其中'pred'一般是模型forward()返回结果的内容,'target'一般是来自于DataSet中被设置为target的
field。
与真实值。其中'pred'一般来自于模型forward()的返回结果,'target'一般是来自于DataSet中被设置为target的
field。由于每个人对真实值或者model的返回值取名并不一样,所以fastNLP的 Loss_ 提供一种类似于映射的机制来匹配
对应的值,比如这里 CrossEntropyLoss_ 将尝试找到名为'label'的内容来作为真实值得到loss;而pred=None, 则 CrossEntropyLoss_
使用'pred'作为名称匹配预测值,正好forward的返回值也叫pred,所以这里不需要申明pred。

尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。fastNLP中提供了
LossInForward_ 这个loss。这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor,
并使用它作为loss。 如果Trainer初始化没有提供loss则使用这个loss TODO 补充一个例子

3. Metric


2. Trainer与callback
Metric_ 使用了与上述Loss一样的策略,即使用名称进行匹配。AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。


在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法,
如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,传入到predict()的参数也是从DataSet中的input的选择
出来的; 与forward()一样,返回值需要为一个dict。具体例子可以参考 TODO 补充一个例子


3. Trainer的代码检查
2. Trainer的代码检查


由于在fastNLP中采取了映射的机制,所以难免可能存在对应出错的情况。Trainer提供一种映射检查机制,可以通过check_code_level来进行控制
比如下面的例子中,由于各种原因产生的报错

Example1::

import numpy as np
from torch import nn
import torch
from torch.optim import SGD
from fastNLP import Trainer
from fastNLP import DataSet

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, x, b):
loss = torch.mean((self.fc(x)-b)**2)
return {'loss': loss}
model = Model()

dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
dataset.set_input('a', 'b')

trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))

trainer = Trainer(dataset, model, SGD(model.parameters()))
# 会报以下的错误
# input fields after batch(if batch size is 2):
# a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
# b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
# There is no target field.
# ....
# NameError:
# Problems occurred when calling Model.forward(self, x, b)
# missing param: ['x']
# unused field: ['a']
# Suggestion: You need to provide ['x'] in DataSet and set it as input.

这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里由两类
信息可以为你提供参考

1. 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及进行shape。这里
因为train dataset没有target所以没有显示。根据这里你可以看出是否正确将需要的内容设置为了input或target。

2. 如果出现了映射错误,出现NameError。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断
出当前是在调取forward出错),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能
就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x'
,或者model的参数从'x'修改为'a'都可以解决问题。

下面的例子是由于loss计算的时候找不到需要的值

Example2::

import numpy as np
from torch import nn
from torch.optim import SGD
from fastNLP import Trainer
from fastNLP import DataSet
from fastNLP.core.losses import L1Loss
import torch

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, a):
return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1}

model = Model()

dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})

dataset.set_input('a')
dataset.set_target('b')

trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001))
# 报错信息如下
# input fields after batch(if batch size is 2):
# a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
# target fields after batch(if batch size is 2):
# b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
# ....
# NameError:
# Problems occurred when calling L1Loss.get_loss(self, pred, target)
# missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)']
# unused field: ['b']
# unused param: ['pred_b', 'No use']
# target field: ['b']
# param from Model.forward(self, a): ['pred_b', 'No use']
# Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a).
# (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a).

报错信息也包含两部分:

1. 第一部分与上面是一样的

2. 这里报错的原因是由于计算loss的时候找不到相应的值(通过L1Loss.get_loss(self, pred, target)判断出来的);报错的原因是因为
`pred`和`label`(我们在初始化L1Loss时将target指定为了label)都没有找到。这里'unused field'是DataSet中出现了,但却没有
被设置为input或者target的field;'unused param'是forward()中返回且没有被使用到的内容;'target field'是被设置为了
target的field; 'param from Model.forward(self, a)'是forward()返回的所有key。"Suggestion"是关于当前错误处理的建议。

但是在一些情况下,比如forward()返回值只有一个,target也只有一个,fastNLP不会进行匹配,而直接将forward()的结果作为pred, 将
DataSet中的target设置为target。上面的例子在返回值中加入了一个'No use'则只是为了使得Loss去匹配结果。


下面是带有dev dataset时如果出现错误会发生的报错,

Example3::

import numpy as np
from torch import nn
from torch.optim import SGD
from fastNLP import Trainer
from fastNLP import DataSet
from fastNLP.core.metrics import AccuracyMetric
import torch

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(1, 1)
def forward(self, a, b):
loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2)
return {'loss': loss}
def predict(self, a): # 使用predict()进行验证
return {'output':self.fc(a.float().unsqueeze(1))} #这里return的值不包含'pred'这个key
model = Model()

dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2})

dataset.set_input('a', 'b')
dev_data.set_input('a') # 这里没有设置target

trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
dev_data=dev_data, metrics=AccuracyMetric())

# 报错信息
# ...
# NameError:
# Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None)
# missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)']
# unused param: ['output']
# target field: []
# param from Model.predict(self, a): ['output']
# Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a).
# (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).

报错信息和前面都是类似的,但是可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是evaluation
的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation的弄错的情况。这里的修改是通过在初始化metric的时候
指明通过'output'获取`pred`, 即AccuracyMetric(pred='output').

可以通过check_code_level调节检查的强度。默认为0,即进行检查。

3. Trainer与callback

虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,比如负采样,learning rate decay, Early Stop等。
为了解决这个问题fastNLP引入了callback的机制,Callback_ 是一种在Trainer训练过程中特定阶段会运行的类,所有的 Callback_ 都具有
on_*(比如on_train_start, on_backward_begin)等函数。如果 Callback 实现了该函数,则Trainer运行至对应阶段,会进行调用。

我们将Train.train()这个函数内部分为以下的阶段

Example::


callback.on_train_begin() # 开始进行训练
for i in range(1, n_epochs+1):
callback.on_epoch_begin() # 开始新的epoch
for batch_x, batch_y in Batch:
callback.on_batch_begin(batch_x, batch_y, indices) # batch_x是设置为input的field,batch_y是设置为target的field
获取模型输出
callback.on_loss_begin()
计算loss
callback.on_backward_begin() # 可以进行一些检查,比如loss是否为None
反向梯度回传
callback.on_backward_end() # 进行梯度截断等
进行参数更新
callback.on_step_end()
callback.on_batch_end()
# 根据设置进行evaluation,比如这是本epoch最后一个batch或者达到一定step
if do evaluation:
callback.on_valid_begin()
进行dev data上的验证
callback.on_valid_end() # 可以进行在其它数据集上进行验证
callback.on_epoch_end() # epoch结束调用
callback.on_train_end() # 训练结束
callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里

fastNLP已经自带了很多callback函数供使用,可以参考 Callback_ 。一些关于callback的例子,请参考 #TODO callback的例子


""" """


@@ -123,7 +326,7 @@ from fastNLP.core.utils import _get_device




class Trainer(object): class Trainer(object):
def __init__(self, train_data, model, loss, optimizer,
def __init__(self, train_data, model, optimizer, loss=None,
batch_size=32, sampler=None, update_every=1, batch_size=32, sampler=None, update_every=1,
n_epochs=10, print_every=5, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None, dev_data=None, metrics=None, metric_key=None,
@@ -135,9 +338,9 @@ class Trainer(object):
:param DataSet train_data: 训练集 :param DataSet train_data: 训练集
:param nn.modules model: 待训练的模型 :param nn.modules model: 待训练的模型
:param Optimizer,None optimizer: 优化器,pytorch的torch.optim.Optimizer类型。如果为None,则Trainer不会更新模型, :param Optimizer,None optimizer: 优化器,pytorch的torch.optim.Optimizer类型。如果为None,则Trainer不会更新模型,
请确保已在callback中进行了更新
:param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。
请确保已在callback中进行了更新。
:param int batch_size: 训练和验证的时候的batch大小。 :param int batch_size: 训练和验证的时候的batch大小。
:param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。当loss为None时,默认使用 LossInForward_ 。
:param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。 :param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。


+ 22
- 23
fastNLP/core/utils.py View File

@@ -373,14 +373,13 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
if check_res.missing: if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}") errs.append(f"\tmissing param: {check_res.missing}")
import re import re
mapped_missing = []
unmapped_missing = []
mapped_missing = [] # 提供了映射的参数
unmapped_missing = [] # 没有指定映射的参数
input_func_map = {} input_func_map = {}
for _miss in check_res.missing:
if '(' in _miss:
# if they are like 'SomeParam(assign to xxx)'
_miss = _miss.split('(')[0]
matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
for _miss_ in check_res.missing:
# they shoudl like 'SomeParam(assign to xxx)'
_miss = _miss_.split('(')[0]
matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss_)
if len(matches) == 2: if len(matches) == 2:
fun_arg, module_name = matches fun_arg, module_name = matches
input_func_map[_miss] = fun_arg input_func_map[_miss] = fun_arg
@@ -391,30 +390,30 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
else: else:
unmapped_missing.append(_miss) unmapped_missing.append(_miss)


for _miss in mapped_missing:
for _miss in mapped_missing + unmapped_missing:
if _miss in dataset: if _miss in dataset:
suggestions.append(f"Set {_miss} as target.")
suggestions.append(f"Set `{_miss}` as target.")
else: else:
_tmp = '' _tmp = ''
if check_res.unused: if check_res.unused:
_tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." _tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
if _tmp: if _tmp:
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
else:
_tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
suggestions.append(_tmp)
for _miss in unmapped_missing:
if _miss in dataset:
suggestions.append(f"Set {_miss} as target.")
else:
_tmp = ''
if check_res.unused:
_tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
if _tmp:
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
_tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
else: else:
_tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.'
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.'
suggestions.append(_tmp) suggestions.append(_tmp)
# for _miss in unmapped_missing:
# if _miss in dataset:
# suggestions.append(f"Set `{_miss}` as target.")
# else:
# _tmp = ''
# if check_res.unused:
# _tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
# if _tmp:
# _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
# else:
# _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.'
# suggestions.append(_tmp)


if check_res.duplicated: if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}.") errs.append(f"\tduplicated param: {check_res.duplicated}.")


Loading…
Cancel
Save