|
|
@@ -2,7 +2,7 @@ |
|
|
|
|
|
|
|
from typing import Dict, Iterable, List |
|
|
|
|
|
|
|
from nltk.translate.bleu_score import sentence_bleu |
|
|
|
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu |
|
|
|
from rouge import Rouge |
|
|
|
|
|
|
|
from modelscope.metainfo import Metrics |
|
|
@@ -63,14 +63,18 @@ class TextGenerationMetric(Metric): |
|
|
|
rouge_scores = self.rouge.get_scores(hyps=preds, refs=tgts) |
|
|
|
rouge_1 = mean(map(lambda score: score['rouge-1']['f'], rouge_scores)) |
|
|
|
rouge_l = mean(map(lambda score: score['rouge-l']['f'], rouge_scores)) |
|
|
|
pred_split = tuple(pred.split(' ') for pred in self.preds) |
|
|
|
tgt_split = tuple(tgt.split(' ') for tgt in self.tgts) |
|
|
|
bleu_1 = mean( |
|
|
|
sentence_bleu([tgt], pred, weights=(1, 0, 0, 0)) |
|
|
|
for pred, tgt in zip(pred_split, tgt_split)) |
|
|
|
bleu_4 = mean( |
|
|
|
sentence_bleu([tgt], pred) |
|
|
|
for pred, tgt in zip(pred_split, tgt_split)) |
|
|
|
|
|
|
|
pred_list = [each.strip().split(' ') for each in self.preds] |
|
|
|
tgt_list = [[each.strip().split(' ')] for each in self.tgts] |
|
|
|
bleu_1 = corpus_bleu( |
|
|
|
tgt_list, |
|
|
|
pred_list, |
|
|
|
weights=(1, 0, 0, 0), |
|
|
|
smoothing_function=SmoothingFunction().method3) |
|
|
|
bleu_4 = corpus_bleu( |
|
|
|
tgt_list, |
|
|
|
pred_list, |
|
|
|
smoothing_function=SmoothingFunction().method3) |
|
|
|
return { |
|
|
|
MetricKeys.ROUGE_1: rouge_1, |
|
|
|
MetricKeys.ROUGE_L: rouge_l, |
|
|
|