Browse Source

增加metric注释;修改trainer save过程中的bug

tags/v0.4.10
yh 5 years ago
parent
commit
8d4f26bbd9
2 changed files with 70 additions and 6 deletions
  1. +68
    -4
      fastNLP/core/metrics.py
  2. +2
    -2
      fastNLP/core/trainer.py

+ 68
- 4
fastNLP/core/metrics.py View File

@@ -16,6 +16,69 @@ from fastNLP.core.vocabulary import Vocabulary
class MetricBase(object): class MetricBase(object):
"""Base class for all metrics. """Base class for all metrics.


所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。
evaluate(xxx)中传入的是一个batch的数据。
get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值
以分类问题中,Accuracy计算为例
假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy
class Model(nn.Module):
def __init__(xxx):
# do something
def forward(self, xxx):
# do something
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes
假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target
对应的AccMetric可以按如下的定义
# version1, 只使用这一次
class AccMetric(MetricBase):
def __init__(self):
super().__init__()

# 根据你的情况自定义指标
self.corr_num = 0
self.total = 0

def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
self.total += label.size(0)
self.corr_num += label.eq(pred).sum().item()

def get_metric(self, reset=True): # 在这里定义如何计算metric
acc = self.corr_num/self.total
if reset: # 是否清零以便重新计算
self.corr_num = 0
self.total = 0
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中


# version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred
class AccMetric(MetricBase):
def __init__(self, label=None, pred=None):
# 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时,
# acc_metric = AccMetric(label='y', pred='pred_y')即可。
# 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对
# 应的的值
super().__init__()
self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可
# 如果没有注册该则效果与version1就是一样的

# 根据你的情况自定义指标
self.corr_num = 0
self.total = 0

def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。
# dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric
self.total += label.size(0)
self.corr_num += label.eq(pred).sum().item()

def get_metric(self, reset=True): # 在这里定义如何计算metric
acc = self.corr_num/self.total
if reset: # 是否清零以便重新计算
self.corr_num = 0
self.total = 0
return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中


``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``.
``pred_dict`` is the output of ``forward()`` or prediction function of a model. ``pred_dict`` is the output of ``forward()`` or prediction function of a model.
``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``.
@@ -24,7 +87,6 @@ class MetricBase(object):
1. whether self.evaluate has varargs, which is not supported. 1. whether self.evaluate has varargs, which is not supported.
2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``.
3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``.
4. whether params in ``pred_dict``, ``target_dict`` are not used by evaluate.(Might cause warning)


Besides, before passing params into self.evaluate, this function will filter out params from output_dict and Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
@@ -297,7 +359,7 @@ class AccuracyMetric(MetricBase):
def bmes_tag_to_spans(tags, ignore_labels=None): def bmes_tag_to_spans(tags, ignore_labels=None):
""" """
给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。
返回[('', (0, 1)), ('singer', (1, 2)), ('singer', (2, 3)), ('singer', (3, 4)), ('', (4, 5)), ('', (5, 6))]
返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间)


:param tags: List[str], :param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略 :param ignore_labels: List[str], 在该list中的label将被忽略
@@ -325,7 +387,7 @@ def bmes_tag_to_spans(tags, ignore_labels=None):
def bmeso_tag_to_spans(tags, ignore_labels=None): def bmeso_tag_to_spans(tags, ignore_labels=None):
""" """
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。
返回[('singer', (1, 2)), ('singer', (2, 3)), ('singer', (3, 4))]
返回[('singer', (1, 4))] (左闭右开区间)


:param tags: List[str], :param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略 :param ignore_labels: List[str], 在该list中的label将被忽略
@@ -355,7 +417,7 @@ def bmeso_tag_to_spans(tags, ignore_labels=None):
def bio_tag_to_spans(tags, ignore_labels=None): def bio_tag_to_spans(tags, ignore_labels=None):
""" """
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。
返回[('singer', (1, 4))] (特别注意这是左闭右开区间)
返回[('singer', (1, 4))] (左闭右开区间)


:param tags: List[str], :param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略 :param ignore_labels: List[str], 在该list中的label将被忽略
@@ -386,6 +448,8 @@ def bio_tag_to_spans(tags, ignore_labels=None):
class SpanFPreRecMetric(MetricBase): class SpanFPreRecMetric(MetricBase):
""" """
在序列标注问题中,以span的方式计算F, pre, rec. 在序列标注问题中,以span的方式计算F, pre, rec.
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例)
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。
最后得到的metric结果为 最后得到的metric结果为
{ {
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 'f': xxx, # 这里使用f考虑以后可以计算f_beta值


+ 2
- 2
fastNLP/core/trainer.py View File

@@ -202,7 +202,7 @@ class Trainer(object):
except (CallbackException, KeyboardInterrupt) as e: except (CallbackException, KeyboardInterrupt) as e:
self.callback_manager.on_exception(e, self.model) self.callback_manager.on_exception(e, self.model)


if self.dev_data is not None:
if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
self.tester._format_eval_results(self.best_dev_perf),) self.tester._format_eval_results(self.best_dev_perf),)
results['best_eval'] = self.best_dev_perf results['best_eval'] = self.best_dev_perf
@@ -367,7 +367,7 @@ class Trainer(object):
else: else:
model.cpu() model.cpu()
torch.save(model, model_path) torch.save(model, model_path)
model.cuda()
model.to(self._model_device)


def _load_model(self, model, model_name, only_param=False): def _load_model(self, model, model_name, only_param=False):
# 返回bool值指示是否成功reload模型 # 返回bool值指示是否成功reload模型


Loading…
Cancel
Save