diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 4dab772c..64f9dfbe 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -119,6 +119,9 @@ class MetricBase(object): def evaluate(self, *args, **kwargs): raise NotImplementedError + def get_metric(self, reset=True): + raise NotImplemented + def _init_param_map(self, key_map=None, **kwargs): """检查key_map和其他参数map,并将这些映射关系添加到self.param_map @@ -161,8 +164,20 @@ class MetricBase(object): f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " f"initialization parameters, or change its signature.") - def get_metric(self, reset=True): - raise NotImplemented + def _fast_param_map(self, pred_dict, target_dict): + """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. + such as pred_dict has one element, target_dict has one element + + :param pred_dict: + :param target_dict: + :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. + """ + fast_param = {} + if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + fast_param['pred'] = list(pred_dict.values())[0] + fast_param['target'] = list(target_dict.values())[0] + return fast_param + return fast_param def __call__(self, pred_dict, target_dict): """ @@ -178,10 +193,15 @@ class MetricBase(object): :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) :return: """ - if not callable(self.evaluate): - raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") + + fast_param = self._fast_param_map(pred_dict, target_dict) + if fast_param: + self.evaluate(**fast_param) + return if not self._checked: + if not callable(self.evaluate): + raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") # 1. check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = set([arg for arg in func_spect.args if arg != 'self']) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 7e07dd18..09075940 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -7,6 +7,7 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 对Batch进行pad; (4) 每个epoch结束或一定step后进行验证集验证; (5) 保存获得更好验证性能的模型等。 1. Trainer的基本使用 + 下面的例子是使用神经网络来进行预测一个序列中是否有偶数个1。 Example:: @@ -53,20 +54,23 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 trainer.train() 由上面的例子可以看出通过使用Trainer,可以使得训练部分的代码大幅减少。 - 使用Trainer需要满足以下几个条件 + 使用Trainer需要满足以下几个条件: 1. 模型 1. 模型的forward()的参数名需要与DataSet中的名字对应。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是 通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该 改名为'data'。 + 2. 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递 给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。 + 3. 模型的forward()返回值需要为一个dict。 - 2. Loss与Metric - fastNLP中的为了不限制forward函数的返回内容数量,以及对多Metric等的支持等, Loss_ 与 Metric_ 都使用了通过名称来匹配相 - 应内容。如上面的例子中 + 2. Loss + + fastNLP中的为了不限制forward函数的返回内容数量(比如一些复杂任务需要返回多个内容,如Dependency Parsing, Loss_ 与 Metric_ 都使 + 用了通过名称来匹配相应内容的策略。如上面的例子中 Example:: @@ -74,17 +78,216 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在 optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000, dev_data = dev_data, metrics=AccuracyMetric(target='label')) - loss被设置为了 CrossEntropyLoss_ , 但在初始化的时候传入了一个target='label'这个参数, CrossEntropyLoss_ 的初始化 + loss被设置为了 CrossEntropyLoss_ , 但在初始化的时候传入了target='label'这个参数, CrossEntropyLoss_ 的初始化 参数为(pred=None, target=None, padding_idx=-100)。这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值 - 与ground truth的label。其中'pred'一般是模型forward()返回结果的内容,'target'一般是来自于DataSet中被设置为target的 - field。 + 与真实值。其中'pred'一般来自于模型forward()的返回结果,'target'一般是来自于DataSet中被设置为target的 + field。由于每个人对真实值或者model的返回值取名并不一样,所以fastNLP的 Loss_ 提供一种类似于映射的机制来匹配 + 对应的值,比如这里 CrossEntropyLoss_ 将尝试找到名为'label'的内容来作为真实值得到loss;而pred=None, 则 CrossEntropyLoss_ + 使用'pred'作为名称匹配预测值,正好forward的返回值也叫pred,所以这里不需要申明pred。 + + 尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。fastNLP中提供了 + LossInForward_ 这个loss。这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor, + 并使用它作为loss。 如果Trainer初始化没有提供loss则使用这个loss TODO 补充一个例子 + + 3. Metric -2. Trainer与callback + Metric_ 使用了与上述Loss一样的策略,即使用名称进行匹配。AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。 + 在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法, + 如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,传入到predict()的参数也是从DataSet中的input的选择 + 出来的; 与forward()一样,返回值需要为一个dict。具体例子可以参考 TODO 补充一个例子 -3. Trainer的代码检查 +2. Trainer的代码检查 + 由于在fastNLP中采取了映射的机制,所以难免可能存在对应出错的情况。Trainer提供一种映射检查机制,可以通过check_code_level来进行控制 + 比如下面的例子中,由于各种原因产生的报错 + + Example1:: + + import numpy as np + from torch import nn + import torch + from torch.optim import SGD + from fastNLP import Trainer + from fastNLP import DataSet + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(1, 1) + def forward(self, x, b): + loss = torch.mean((self.fc(x)-b)**2) + return {'loss': loss} + model = Model() + + dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2}) + dataset.set_input('a', 'b') + + trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001)) + + trainer = Trainer(dataset, model, SGD(model.parameters())) + # 会报以下的错误 + # input fields after batch(if batch size is 2): + # a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) + # b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) + # There is no target field. + # .... + # NameError: + # Problems occurred when calling Model.forward(self, x, b) + # missing param: ['x'] + # unused field: ['a'] + # Suggestion: You need to provide ['x'] in DataSet and set it as input. + + 这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里由两类 + 信息可以为你提供参考 + + 1. 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及进行shape。这里 + 因为train dataset没有target所以没有显示。根据这里你可以看出是否正确将需要的内容设置为了input或target。 + + 2. 如果出现了映射错误,出现NameError。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断 + 出当前是在调取forward出错),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能 + 就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x' + ,或者model的参数从'x'修改为'a'都可以解决问题。 + + 下面的例子是由于loss计算的时候找不到需要的值 + + Example2:: + + import numpy as np + from torch import nn + from torch.optim import SGD + from fastNLP import Trainer + from fastNLP import DataSet + from fastNLP.core.losses import L1Loss + import torch + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(1, 1) + def forward(self, a): + return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1} + + model = Model() + + dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2}) + + dataset.set_input('a') + dataset.set_target('b') + + trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001)) + # 报错信息如下 + # input fields after batch(if batch size is 2): + # a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2]) + # target fields after batch(if batch size is 2): + # b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2]) + # .... + # NameError: + # Problems occurred when calling L1Loss.get_loss(self, pred, target) + # missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)'] + # unused field: ['b'] + # unused param: ['pred_b', 'No use'] + # target field: ['b'] + # param from Model.forward(self, a): ['pred_b', 'No use'] + # Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a). + # (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a). + + 报错信息也包含两部分: + + 1. 第一部分与上面是一样的 + + 2. 这里报错的原因是由于计算loss的时候找不到相应的值(通过L1Loss.get_loss(self, pred, target)判断出来的);报错的原因是因为 + `pred`和`label`(我们在初始化L1Loss时将target指定为了label)都没有找到。这里'unused field'是DataSet中出现了,但却没有 + 被设置为input或者target的field;'unused param'是forward()中返回且没有被使用到的内容;'target field'是被设置为了 + target的field; 'param from Model.forward(self, a)'是forward()返回的所有key。"Suggestion"是关于当前错误处理的建议。 + + 但是在一些情况下,比如forward()返回值只有一个,target也只有一个,fastNLP不会进行匹配,而直接将forward()的结果作为pred, 将 + DataSet中的target设置为target。上面的例子在返回值中加入了一个'No use'则只是为了使得Loss去匹配结果。 + + + 下面是带有dev dataset时如果出现错误会发生的报错, + + Example3:: + + import numpy as np + from torch import nn + from torch.optim import SGD + from fastNLP import Trainer + from fastNLP import DataSet + from fastNLP.core.metrics import AccuracyMetric + import torch + + class Model(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(1, 1) + def forward(self, a, b): + loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2) + return {'loss': loss} + def predict(self, a): # 使用predict()进行验证 + return {'output':self.fc(a.float().unsqueeze(1))} #这里return的值不包含'pred'这个key + model = Model() + + dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2}) + dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2}) + + dataset.set_input('a', 'b') + dev_data.set_input('a') # 这里没有设置target + + trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001), + dev_data=dev_data, metrics=AccuracyMetric()) + + # 报错信息 + # ... + # NameError: + # Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None) + # missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)'] + # unused param: ['output'] + # target field: [] + # param from Model.predict(self, a): ['output'] + # Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a). + # (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a). + + 报错信息和前面都是类似的,但是可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是evaluation + 的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation的弄错的情况。这里的修改是通过在初始化metric的时候 + 指明通过'output'获取`pred`, 即AccuracyMetric(pred='output'). + + 可以通过check_code_level调节检查的强度。默认为0,即进行检查。 + +3. Trainer与callback + + 虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,比如负采样,learning rate decay, Early Stop等。 + 为了解决这个问题fastNLP引入了callback的机制,Callback_ 是一种在Trainer训练过程中特定阶段会运行的类,所有的 Callback_ 都具有 + on_*(比如on_train_start, on_backward_begin)等函数。如果 Callback 实现了该函数,则Trainer运行至对应阶段,会进行调用。 + + 我们将Train.train()这个函数内部分为以下的阶段 + + Example:: + callback.on_train_begin() # 开始进行训练 + for i in range(1, n_epochs+1): + callback.on_epoch_begin() # 开始新的epoch + for batch_x, batch_y in Batch: + callback.on_batch_begin(batch_x, batch_y, indices) # batch_x是设置为input的field,batch_y是设置为target的field + 获取模型输出 + callback.on_loss_begin() + 计算loss + callback.on_backward_begin() # 可以进行一些检查,比如loss是否为None + 反向梯度回传 + callback.on_backward_end() # 进行梯度截断等 + 进行参数更新 + callback.on_step_end() + callback.on_batch_end() + # 根据设置进行evaluation,比如这是本epoch最后一个batch或者达到一定step + if do evaluation: + callback.on_valid_begin() + 进行dev data上的验证 + callback.on_valid_end() # 可以进行在其它数据集上进行验证 + callback.on_epoch_end() # epoch结束调用 + callback.on_train_end() # 训练结束 + callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里 + + fastNLP已经自带了很多callback函数供使用,可以参考 Callback_ 。一些关于callback的例子,请参考 #TODO callback的例子 """ @@ -123,7 +326,7 @@ from fastNLP.core.utils import _get_device class Trainer(object): - def __init__(self, train_data, model, loss, optimizer, + def __init__(self, train_data, model, optimizer, loss=None, batch_size=32, sampler=None, update_every=1, n_epochs=10, print_every=5, dev_data=None, metrics=None, metric_key=None, @@ -135,9 +338,9 @@ class Trainer(object): :param DataSet train_data: 训练集 :param nn.modules model: 待训练的模型 :param Optimizer,None optimizer: 优化器,pytorch的torch.optim.Optimizer类型。如果为None,则Trainer不会更新模型, - 请确保已在callback中进行了更新 - :param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。 + 请确保已在callback中进行了更新。 :param int batch_size: 训练和验证的时候的batch大小。 + :param LossBase loss: 使用的Loss对象。 详见 LossBase_ 。当loss为None时,默认使用 LossInForward_ 。 :param Sampler sampler: Batch数据生成的顺序。详见 Sampler_ 。如果为None,默认使用 RandomSampler_ 。 :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index c7a6fdd8..af1a7db6 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -373,14 +373,13 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re if check_res.missing: errs.append(f"\tmissing param: {check_res.missing}") import re - mapped_missing = [] - unmapped_missing = [] + mapped_missing = [] # 提供了映射的参数 + unmapped_missing = [] # 没有指定映射的参数 input_func_map = {} - for _miss in check_res.missing: - if '(' in _miss: - # if they are like 'SomeParam(assign to xxx)' - _miss = _miss.split('(')[0] - matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss) + for _miss_ in check_res.missing: + # they shoudl like 'SomeParam(assign to xxx)' + _miss = _miss_.split('(')[0] + matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss_) if len(matches) == 2: fun_arg, module_name = matches input_func_map[_miss] = fun_arg @@ -391,30 +390,30 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re else: unmapped_missing.append(_miss) - for _miss in mapped_missing: + for _miss in mapped_missing + unmapped_missing: if _miss in dataset: - suggestions.append(f"Set {_miss} as target.") + suggestions.append(f"Set `{_miss}` as target.") else: _tmp = '' if check_res.unused: _tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." if _tmp: - _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' - else: - _tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.' - suggestions.append(_tmp) - for _miss in unmapped_missing: - if _miss in dataset: - suggestions.append(f"Set {_miss} as target.") - else: - _tmp = '' - if check_res.unused: - _tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." - if _tmp: - _tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.' + _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' else: - _tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.' + _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' suggestions.append(_tmp) + # for _miss in unmapped_missing: + # if _miss in dataset: + # suggestions.append(f"Set `{_miss}` as target.") + # else: + # _tmp = '' + # if check_res.unused: + # _tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." + # if _tmp: + # _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' + # else: + # _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.' + # suggestions.append(_tmp) if check_res.duplicated: errs.append(f"\tduplicated param: {check_res.duplicated}.")