|
- #!/usr/bin/python
- # -*- coding: utf-8 -*-
-
- # __author__="Danqing Wang"
-
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- from __future__ import division
-
-
- import torch
- import torch.nn.functional as F
-
- from rouge import Rouge
-
- from fastNLP.core.const import Const
- from fastNLP.core.metrics import MetricBase
-
- # from tools.logger import *
- from fastNLP.core._logger import logger
- from tools.utils import pyrouge_score_all, pyrouge_score_all_multi
-
-
class LossMetric(MetricBase):
    """Accumulate a masked cross-entropy loss over evaluation batches.

    Each ``evaluate`` call adds one batch-level loss to a running total;
    ``get_metric`` reports the negated average over all calls.
    """

    def __init__(self, pred=None, target=None, mask=None, padding_idx=-100, reduce='mean'):
        """
        :param pred: forwarded to ``_init_param_map`` as the ``pred`` mapping
        :param target: forwarded to ``_init_param_map`` as the ``target`` mapping
        :param mask: forwarded to ``_init_param_map`` as the ``mask`` mapping
        :param padding_idx: target value passed as ``ignore_index`` to the loss
        :param reduce: stored as ``self.reduce`` for backward compatibility; the
            loss itself is always computed per position (see ``evaluate``)
        """
        super().__init__()

        self._init_param_map(pred=pred, target=target, mask=mask)
        self.padding_idx = padding_idx
        self.reduce = reduce
        self.loss = 0.0
        self.iteration = 0

    def evaluate(self, pred, target, mask):
        """
        Add the masked loss of one batch to the running total.

        :param pred: [batch, N, C] class scores (C was hard-coded to 2 before)
        :param target: [batch, N] class indices
        :param mask: [batch, N], positions equal to ``False`` contribute no loss
        :return:
        """
        batch = pred.size(0)
        flat_pred = pred.reshape(-1, pred.size(-1))
        flat_target = target.reshape(-1)
        # Masking needs one loss value per position. A reduced (scalar) loss
        # cannot be reshaped to [batch, N], so reduction is always 'none'.
        loss = F.cross_entropy(input=flat_pred, target=flat_target,
                               ignore_index=self.padding_idx, reduction='none')
        loss = loss.view(batch, -1)
        loss = loss.masked_fill(mask.eq(False), 0)
        # Sum over positions of each sample, then average over the batch.
        loss = loss.sum(1).mean()
        # Detach so the running total never keeps an autograd graph alive.
        self.loss += loss.detach()
        self.iteration += 1

    def get_metric(self, reset=True):
        """
        Return the negated average loss over all ``evaluate`` calls.

        :param reset: clear the running total and call counter when True
        :return: ``{"loss": -average_loss}``; the sign is flipped, presumably so
            that a larger value means a better model -- confirm with the caller
        :raises ZeroDivisionError: if ``evaluate`` has not been called since the
            last reset
        """
        epoch_avg_loss = self.loss / self.iteration
        if reset:
            self.loss = 0.0
            self.iteration = 0
        metric = {"loss": -epoch_avg_loss}
        logger.info(metric)
        return metric
-
-
-
-
class LabelFMetric(MetricBase):
    """Accumulate accuracy, precision, recall and F1 for label predictions.

    Counts are kept as plain Python numbers. Label value ``1`` is treated as
    the positive class; ``pred.sum()`` / ``target.sum()`` are used as the
    number of predicted / true positives, which assumes 0/1 labels.
    """

    def __init__(self, pred=None, target=None):
        """
        :param pred: forwarded to ``_init_param_map`` as the ``pred`` mapping
        :param target: forwarded to ``_init_param_map`` as the ``target`` mapping
        """
        super().__init__()

        self._init_param_map(pred=pred, target=target)

        self.match = 0.0
        self.pred = 0.0
        self.true = 0.0
        self.match_true = 0.0
        self.total = 0.0

    @staticmethod
    def _safe_div(numerator, denominator):
        """Return ``numerator / denominator``, or 0.0 (with an error log) when the denominator is zero."""
        if not denominator:
            logger.error("[Error] float division by zero")
            return 0.0
        return numerator / denominator

    def evaluate(self, pred, target):
        """
        Add the counts of one batch to the running totals.

        :param pred: [batch, N] int
        :param target: [batch, N] int
        :return:
        """
        target = target.detach()
        pred = pred.detach()
        hit = pred == target
        # .item() turns every count into a Python number, so get_metric also
        # works before the first batch and after a reset.
        self.pred += pred.sum().item()
        self.true += target.sum().item()
        self.match += hit.sum().item()
        self.match_true += (hit & (pred == 1)).sum().item()
        self.total += pred.numel()

    def get_metric(self, reset=True):
        """
        Compute the metrics from the accumulated counts.

        Any ratio whose denominator is zero is reported as 0.0 instead of
        nan/inf.

        :param reset: zero all counters when True
        :return: dict with keys ``accu``, ``p``, ``r`` and ``f`` holding 0-dim
            CPU tensors
        """
        logger.debug((self.match, self.pred, self.true, self.match_true, self.total))
        accu = self._safe_div(self.match, self.total)
        precision = self._safe_div(self.match_true, self.pred)
        recall = self._safe_div(self.match_true, self.true)
        f_score = self._safe_div(2 * precision * recall, precision + recall)
        if reset:
            self.pred, self.true, self.match_true, self.match, self.total = 0.0, 0.0, 0.0, 0.0, 0.0
        # Values are wrapped in tensors because callers previously received tensors.
        ret = {"accu": torch.tensor(accu), "p": torch.tensor(precision),
               "r": torch.tensor(recall), "f": torch.tensor(f_score)}
        logger.info(ret)
        return ret
-
-
class RougeMetric(MetricBase):
    """Collect hypothesis / reference summary pairs for ROUGE scoring.

    ``evaluate`` builds one hypothesis per sample from the sentences whose
    prediction equals 1 and stores it with its reference. ``get_metric`` does
    nothing here; the subclasses in this file override it to do the scoring.
    """

    def __init__(self, hps, pred=None, text=None, refer=None):
        """
        :param hps: hyper-parameter object; only ``hps.m`` is read here
            (presumably the number of sentences to select -- confirm)
        :param pred: forwarded to ``_init_param_map`` as the ``pred`` mapping
        :param text: forwarded to ``_init_param_map`` as the ``text`` mapping
        :param refer: forwarded to ``_init_param_map`` as the ``summary`` mapping
        """
        super().__init__()

        self._hps = hps
        self._init_param_map(pred=pred, text=text, summary=refer)

        self.hyps = []
        self.refers = []

    def evaluate(self, pred, text, summary):
        """
        Build and store the hypothesis / reference pair of every sample.

        :param pred: [batch, N] tensor, 1 marks a selected sentence
        :param text: per sample, the sequence of article sentences
        :param summary: per sample, the sequence of reference summary lines
        :return:
        """

        batch_size, _ = pred.size()
        for j in range(batch_size):
            original_article_sents = text[j]
            sent_max_number = len(original_article_sents)
            refer = "\n".join(summary[j])
            # Convert the row once rather than indexing the tensor per element.
            # Predictions past the last real sentence are ignored.
            selected = pred[j].tolist()[:sent_max_number]
            hyps = "\n".join(original_article_sents[idx]
                             for idx, flag in enumerate(selected) if flag == 1)
            if sent_max_number < self._hps.m and len(hyps) <= 1:
                logger.error("sent_max_number is too short %d, Skip!", sent_max_number)
                continue

            if len(hyps) >= 1 and hyps != '.':
                self.hyps.append(hyps)
                self.refers.append(refer)
            elif refer == "." or refer == "":
                logger.error("Refer is None!")
                logger.debug(refer)
            else:
                # hyps can only be "" or "." here, so the former
                # "Do not select any sentences!" branch was unreachable.
                logger.error("hyps is None!")
                logger.debug("sent_max_number:%d", sent_max_number)
                logger.debug("pred:")
                logger.debug(pred[j])
                logger.debug(hyps)

    def get_metric(self, reset=True):
        """
        Placeholder that returns None; subclasses override it to compute scores.

        :param reset: unused here
        :return: None
        """
        pass
-
class FastRougeMetric(RougeMetric):
    """ROUGE metric that scores the collected pairs with ``Rouge.get_scores``."""

    def __init__(self, hps, pred=None, text=None, refer=None):
        """Forward all arguments unchanged to ``RougeMetric.__init__``."""
        super().__init__(hps, pred=pred, text=text, refer=refer)

    def get_metric(self, reset=True):
        """
        Score all collected hypotheses against their references.

        :param reset: empty the collected hypotheses and references when True
        :return: the averaged result of ``Rouge().get_scores``, or None when
            nothing was collected
        """
        hyp_count = len(self.hyps)
        refer_count = len(self.refers)
        logger.info("[INFO] Hyps and Refer number is %d, %d", hyp_count, refer_count)
        if not (hyp_count and refer_count):
            logger.error("During testing, no hyps or refers is selected!")
            return None
        scores_all = Rouge().get_scores(self.hyps, self.refers, avg=True)
        if reset:
            self.hyps, self.refers = [], []
        logger.info(scores_all)
        return scores_all
-
-
class PyRougeMetric(RougeMetric):
    """ROUGE metric that scores the collected pairs with the ``pyrouge_score_all*`` helpers."""

    def __init__(self, hps, pred=None, text=None, refer=None):
        """Forward all arguments unchanged to ``RougeMetric.__init__``."""
        super().__init__(hps, pred=pred, text=text, refer=refer)

    def get_metric(self, reset=True):
        """
        Score all collected hypotheses against their references.

        The multi-reference helper is used when the first stored reference is
        a list; otherwise the single-reference helper is used.

        :param reset: empty the collected hypotheses and references when True
        :return: the result of the chosen helper, or None when nothing was
            collected
        """
        hyp_count = len(self.hyps)
        refer_count = len(self.refers)
        logger.info("[INFO] Hyps and Refer number is %d, %d", hyp_count, refer_count)
        if not (hyp_count and refer_count):
            logger.error("During testing, no hyps or refers is selected!")
            return None
        multi_reference = isinstance(self.refers[0], list)
        if multi_reference:
            logger.info("Multi Reference summaries!")
        scorer = pyrouge_score_all_multi if multi_reference else pyrouge_score_all
        scores_all = scorer(self.hyps, self.refers)
        if reset:
            self.hyps, self.refers = [], []
        logger.info(scores_all)
        return scores_all
-
-
-
|