diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 76265a01..3cb6aa88 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -71,7 +71,7 @@ __all__ = [ 'logger' ] -__version__ = '0.4.5' +__version__ = '0.5.0' import sys diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 95a3331f..146b532f 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -7,7 +7,8 @@ __all__ = [ "AccuracyMetric", "SpanFPreRecMetric", "CMRC2018Metric", - "ClassifyFPreRecMetric" + "ClassifyFPreRecMetric", + "ConfusionMatrixMetric" ] import inspect @@ -15,6 +16,7 @@ import warnings from abc import abstractmethod from collections import defaultdict from typing import Union +from copy import deepcopy import re import numpy as np @@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list from .utils import _get_func_signature from .utils import seq_len_to_mask from .vocabulary import Vocabulary +from .utils import ConfusionMatrix class MetricBase(object): @@ -276,6 +279,95 @@ class MetricBase(object): return +class ConfusionMatrixMetric(MetricBase): + r""" + 分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) + + 最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例} + ConfusionMatrix实例的print()函数将输出矩阵字符串。 + + pred_dict = {"pred": torch.Tensor([2,1,3])} + target_dict = {'target': torch.Tensor([2,2,1])} + metric = ConfusionMatrixMetric() + metric(pred_dict=pred_dict, target_dict=target_dict, ) + print(metric.get_metric()) + + {'confusion_matrix': + target 1.0 2.0 3.0 all + pred + 1.0 0 1 0 1 + 2.0 0 1 0 1 + 3.0 1 0 0 1 + all 1 2 0 3} + """ + def __init__(self, vocab=None, pred=None, target=None, seq_len=None): + """ + :param vocab: vocab词表类,要求有to_word()方法。 + :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` + :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` + :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` + """ + super().__init__() + self._init_param_map(pred=pred, target=target, seq_len=seq_len) + self.confusion_matrix = ConfusionMatrix(vocab=vocab) + + def evaluate(self, pred, target, seq_len=None): + """ + evaluate函数将针对一个批次的预测结果做评价指标的累计 + + :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), + torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) + :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), + torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) + :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, torch.Size([B]), 或者torch.Size([B]). 
+ + """ + if not isinstance(pred, torch.Tensor): + raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(pred)}.") + if not isinstance(target, torch.Tensor): + raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(target)}.") + + if seq_len is not None and not isinstance(seq_len, torch.Tensor): + raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(seq_len)}.") + + if pred.dim() == target.dim(): + pass + elif pred.dim() == target.dim() + 1: + pred = pred.argmax(dim=-1) + if seq_len is None and target.dim() > 1: + warnings.warn("You are not passing `seq_len` to exclude pad.") + else: + raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " + f"size:{pred.size()}, target should have size: {pred.size()} or " + f"{pred.size()[:-1]}, got {target.size()}.") + + target = target.to(pred) + if seq_len is not None and target.dim() > 1: + for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()): + l=int(l) + self.confusion_matrix.add_pred_target(p[:l], t[:l]) + elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出 + for p, t in zip(pred.tolist(), target.tolist()): + self.confusion_matrix.add_pred_target(p, t) + else: + self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist()) + + def get_metric(self,reset=True): + """ + get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. + + :param bool reset: 在调用完get_metric后是否清空评价指标统计量. + :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} + """ + confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)} + if reset: + self.confusion_matrix.clear() + return confusion + + class AccuracyMetric(MetricBase): """ 准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index ba9ec850..b1d5f4e2 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -8,18 +8,22 @@ __all__ = [ "get_seq_len" ] -import _pickle import inspect import os import warnings from collections import Counter, namedtuple +from copy import deepcopy +from typing import List + +import _pickle import numpy as np import torch import torch.nn as nn -from typing import List -from ._logger import logger from prettytable import PrettyTable + +from ._logger import logger from ._parallel_utils import _model_contains_inner_module +# from .vocabulary import Vocabulary try: from apex import amp @@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require 'varargs']) + + +class ConfusionMatrix: + """a dict can provide Confusion Matrix""" + def __init__(self, vocab=None): + """ + :param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 + """ + if vocab and not hasattr(vocab, 'to_word'): + raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary," + f"got {type(vocab)}.") + self.confusiondict={} #key: pred index, value:target word ocunt + self.predcount={} #key:pred index, value:count + self.targetcount={} #key:target index, value:count + self.vocab=vocab + + def add_pred_target(self, pred, target): #一组结果 + """ + 通过这个函数向ConfusionMatrix加入一组预测结果 + + :param list pred: 预测的标签列表 + :param list target: 真实值的标签列表 + :return ConfusionMatrix + + confusion=ConfusionMatrix() + pred = [2,1,3] + target = [2,2,1] + confusion.add_pred_target(pred, target) + print(confusion) + + target 1 2 3 all + pred + 1 0 1 0 1 + 2 0 1 0 1 + 3 1 0 0 1 + all 1 2 0 3 + """ + for p,t in 
zip(pred,target): # + self.predcount[p]=self.predcount.get(p,0)+ 1 + self.targetcount[t]=self.targetcount.get(t,0)+1 + if p in self.confusiondict: + self.confusiondict[p][t]=self.confusiondict[p].get(t,0) + 1 + else: + self.confusiondict[p]={} + self.confusiondict[p][t]= 1 + return self.confusiondict + + def clear(self): + """ + 清除一些值,等待再次新加入 + :return: + """ + self.confusiondict={} + self.targetcount={} + self.predcount={} + + def __repr__(self): + """ + :return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 + """ + row2idx={} + idx2row={} + # 已知的所有键/label + totallabel=sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys())))) + lenth=len(totallabel) + # namedict key :idx value:word/idx + namedict=dict([(k,str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel]) + + for label,idx in zip(totallabel,range(lenth)): + idx2row[label]=idx #建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... + row2idx[idx]=label #建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... + # 这里打印东西 + #表头 + head=["\ntarget"]+[str(namedict[row2idx[k]]) for k in row2idx.keys()]+["all"] + output="\t".join(head) + "\n" + "pred" + "\n" + #内容 + for i in row2idx.keys(): #第i行 + p=row2idx[i] + h=namedict[p] + l=[0 for _ in range(lenth)] + if self.confusiondict.get(p,None): + for t,c in self.confusiondict[p].items(): + l[idx2row[t]] = c #完成一行 + l=[h]+[str(n) for n in l]+[str(sum(l))] + output+="\t".join(l) +"\n" + #表尾 + tail=[self.targetcount.get(row2idx[k],0) for k in row2idx.keys()] + tail=["all"]+[str(n) for n in tail]+[str(sum(tail))] + output+="\t".join(tail) + return output + + class Option(dict): """a dict can treat keys as attributes""" diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index 377597ea..d4fba930 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -18,7 +18,9 @@ __all__ = [ 'Loader', - 'YelpLoader', + 'CLSBaseLoader', + 'AGsNewsLoader', + 'DBPediaLoader', 'YelpFullLoader', 'YelpPolarityLoader', 'IMDBLoader', @@ -55,6 +57,9 @@ __all__ = [ "Pipe", + "CLSBasePipe", + "AGsNewsPipe", + "DBPediaPipe", "YelpFullPipe", "YelpPolarityPipe", "SSTPipe", @@ -73,19 +78,6 @@ __all__ = [ "CWSPipe", - "Pipe", - - "CWSPipe", - - "YelpFullPipe", - "YelpPolarityPipe", - "SSTPipe", - "SST2Pipe", - "IMDBPipe", - "ChnSentiCorpPipe", - "THUCNewsPipe", - "WeiboSenti100kPipe", - "Conll2003NERPipe", "OntoNotesNERPipe", "MsraNERPipe", diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index c50ce383..ef537f63 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -47,9 +47,11 @@ fastNLP 目前提供了如下的 Loader __all__ = [ 'Loader', - 'YelpLoader', + 'CLSBaseLoader', 'YelpFullLoader', 'YelpPolarityLoader', + 'AGsNewsLoader', + 'DBPediaLoader', 'IMDBLoader', 'SSTLoader', 'SST2Loader', @@ -84,7 +86,8 @@ __all__ = [ "CMRC2018Loader" ] -from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, \ +from .classification import CLSBaseLoader, YelpFullLoader, YelpPolarityLoader, AGsNewsLoader, IMDBLoader, \ + SSTLoader, SST2Loader, DBPediaLoader, \ ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index 12b10541..94fc993d 100644 --- a/fastNLP/io/loader/classification.py +++ 
b/fastNLP/io/loader/classification.py @@ -1,9 +1,11 @@ """undocumented""" __all__ = [ - "YelpLoader", + "CLSBaseLoader", "YelpFullLoader", "YelpPolarityLoader", + "AGsNewsLoader", + "DBPediaLoader", "IMDBLoader", "SSTLoader", "SST2Loader", @@ -12,6 +14,7 @@ __all__ = [ "WeiboSenti100kLoader" ] + import glob import os import random @@ -22,14 +25,17 @@ import warnings from .loader import Loader from ...core.dataset import DataSet from ...core.instance import Instance +from ...core._logger import logger -class YelpLoader(Loader): +class CLSBaseLoader(Loader): """ + 文本分类Loader的一个基类 + 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 Example:: - + "1","I got 'new' tires from the..." "1","Don't waste your time..." @@ -43,125 +49,112 @@ class YelpLoader(Loader): "...", "..." """ - - def __init__(self): - super(YelpLoader, self).__init__() - - def _load(self, path: str = None): + + def __init__(self, sep=',', has_header=False): + super().__init__() + self.sep = sep + self.has_header = has_header + + def _load(self, path: str): ds = DataSet() - with open(path, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - sep_index = line.index(',') - target = line[:sep_index] - raw_words = line[sep_index + 1:] - if target.startswith("\""): - target = target[1:] - if target.endswith("\""): - target = target[:-1] - if raw_words.endswith("\""): - raw_words = raw_words[:-1] - if raw_words.startswith('"'): - raw_words = raw_words[1:] - raw_words = raw_words.replace('""', '"') # 替换双引号 - if raw_words: - ds.append(Instance(raw_words=raw_words, target=target)) + try: + with open(path, 'r', encoding='utf-8') as f: + read_header = self.has_header + for line in f: + if read_header: + read_header = False + continue + line = line.strip() + sep_index = line.index(self.sep) + target = line[:sep_index] + raw_words = line[sep_index + 1:] + if target.startswith("\""): + target = target[1:] + if target.endswith("\""): + target = target[:-1] + if raw_words.endswith("\""): + raw_words = raw_words[:-1] + if raw_words.startswith('"'): + raw_words = raw_words[1:] + raw_words = raw_words.replace('""', '"') # 替换双引号 + if raw_words: + ds.append(Instance(raw_words=raw_words, target=target)) + except Exception as e: + logger.error(f'Load file `{path}` failed for `{e}`') return ds -class YelpFullLoader(YelpLoader): - def download(self, dev_ratio: float = 0.1, re_download: bool = False): +def _split_dev(dataset_name, data_dir, dev_ratio=0.0, re_download=False, suffix='csv'): + if dev_ratio == 0.0: + return data_dir + modify_time = 0 + for filepath in glob.glob(os.path.join(data_dir, '*')): + modify_time = os.stat(filepath).st_mtime + break + if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 + shutil.rmtree(data_dir) + data_dir = Loader()._get_dataset_path(dataset_name=dataset_name) + + if not os.path.exists(os.path.join(data_dir, f'dev.{suffix}')): + if dev_ratio > 0: + assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 
+ try: + with open(os.path.join(data_dir, f'train.{suffix}'), 'r', encoding='utf-8') as f, \ + open(os.path.join(data_dir, f'middle_file.{suffix}'), 'w', encoding='utf-8') as f1, \ + open(os.path.join(data_dir, f'dev.{suffix}'), 'w', encoding='utf-8') as f2: + for line in f: + if random.random() < dev_ratio: + f2.write(line) + else: + f1.write(line) + os.remove(os.path.join(data_dir, f'train.{suffix}')) + os.renames(os.path.join(data_dir, f'middle_file.{suffix}'), os.path.join(data_dir, f'train.{suffix}')) + finally: + if os.path.exists(os.path.join(data_dir, f'middle_file.{suffix}')): + os.remove(os.path.join(data_dir, f'middle_file.{suffix}')) + + return data_dir + + +class AGsNewsLoader(CLSBaseLoader): + def download(self): """ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015) - 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv, - dev.csv三个文件。 - - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 :return: str, 数据集的目录地址 """ - - dataset_name = 'yelp-review-full' - data_dir = self._get_dataset_path(dataset_name=dataset_name) - modify_time = 0 - for filepath in glob.glob(os.path.join(data_dir, '*')): - modify_time = os.stat(filepath).st_mtime - break - if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 - shutil.rmtree(data_dir) - data_dir = self._get_dataset_path(dataset_name=dataset_name) - - if not os.path.exists(os.path.join(data_dir, 'dev.csv')): - if dev_ratio > 0: - assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." - try: - with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ - open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ - open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: - for line in f: - if random.random() < dev_ratio: - f2.write(line) - else: - f1.write(line) - os.remove(os.path.join(data_dir, 'train.csv')) - os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) - finally: - if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): - os.remove(os.path.join(data_dir, 'middle_file.csv')) - - return data_dir + return self._get_dataset_path(dataset_name='ag-news') -class YelpPolarityLoader(YelpLoader): - def download(self, dev_ratio: float = 0.1, re_download=False): +class DBPediaLoader(CLSBaseLoader): + def download(self, dev_ratio: float = 0.0, re_download: bool = False): """ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015) - 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分dev_ratio这么多作为dev + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv - :param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据。 如果为0,则不划分dev。 + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 :param bool re_download: 是否重新下载数据,以重新切分数据。 :return: str, 数据集的目录地址 """ - dataset_name = 'yelp-review-polarity' + dataset_name = 'dbpedia' data_dir = self._get_dataset_path(dataset_name=dataset_name) - modify_time = 0 - for filepath in glob.glob(os.path.join(data_dir, '*')): - modify_time = os.stat(filepath).st_mtime - break - if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 - shutil.rmtree(data_dir) - data_dir = self._get_dataset_path(dataset_name=dataset_name) - - if not os.path.exists(os.path.join(data_dir, 'dev.csv')): - if dev_ratio > 0: - assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." - try: - with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \ - open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \ - open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2: - for line in f: - if random.random() < dev_ratio: - f2.write(line) - else: - f1.write(line) - os.remove(os.path.join(data_dir, 'train.csv')) - os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv')) - finally: - if os.path.exists(os.path.join(data_dir, 'middle_file.csv')): - os.remove(os.path.join(data_dir, 'middle_file.csv')) - + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') return data_dir -class IMDBLoader(Loader): +class IMDBLoader(CLSBaseLoader): """ 原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。 @@ -181,35 +174,16 @@ class IMDBLoader(Loader): "...", "..." """ - def __init__(self): - super(IMDBLoader, self).__init__() - - def _load(self, path: str): - dataset = DataSet() - with open(path, 'r', encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - parts = line.split('\t') - target = parts[0] - words = parts[1] - if words: - dataset.append(Instance(raw_words=words, target=target)) - - if len(dataset) == 0: - raise RuntimeError(f"{path} has no valid data.") - - return dataset - - def download(self, dev_ratio: float = 0.1, re_download=False): + super().__init__(sep='\t') + + def download(self, dev_ratio: float = 0.0, re_download=False): """ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 http://www.aclweb.org/anthology/P11-1015 - 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev + 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后不从train中切分dev :param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev :param bool re_download: 是否重新下载数据,以重新切分数据。 @@ -217,32 +191,11 @@ class IMDBLoader(Loader): """ dataset_name = 'aclImdb' data_dir = self._get_dataset_path(dataset_name=dataset_name) - modify_time = 0 - for filepath in glob.glob(os.path.join(data_dir, '*')): - modify_time = os.stat(filepath).st_mtime - break - if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 - shutil.rmtree(data_dir) - data_dir = self._get_dataset_path(dataset_name=dataset_name) - - if not os.path.exists(os.path.join(data_dir, 'dev.txt')): - if dev_ratio > 0: - assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 
- try: - with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ - open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ - open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: - for line in f: - if random.random() < dev_ratio: - f2.write(line) - else: - f1.write(line) - os.remove(os.path.join(data_dir, 'train.txt')) - os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) - finally: - if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): - os.remove(os.path.join(data_dir, 'middle_file.txt')) - + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='txt') return data_dir @@ -267,10 +220,10 @@ class SSTLoader(Loader): raw_words列是str。 """ - + def __init__(self): super().__init__() - + def _load(self, path: str): """ 从path读取SST文件 @@ -285,7 +238,7 @@ class SSTLoader(Loader): if line: ds.append(Instance(raw_words=line)) return ds - + def download(self): """ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 @@ -298,6 +251,56 @@ class SSTLoader(Loader): return output_dir +class YelpFullLoader(CLSBaseLoader): + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = 'yelp-review-full' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + +class YelpPolarityLoader(CLSBaseLoader): + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + """ + 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + + Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances + in Neural Information Processing Systems 28 (NIPS 2015) + + 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 + 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + + :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 + :param bool re_download: 是否重新下载数据,以重新切分数据。 + :return: str, 数据集的目录地址 + """ + dataset_name = 'yelp-review-polarity' + data_dir = self._get_dataset_path(dataset_name=dataset_name) + data_dir = _split_dev(dataset_name=dataset_name, + data_dir=data_dir, + dev_ratio=dev_ratio, + re_download=re_download, + suffix='csv') + return data_dir + + class SST2Loader(Loader): """ 原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label @@ -319,19 +322,18 @@ class SST2Loader(Loader): test的DataSet没有target列。 """ - + def __init__(self): super().__init__() - + def _load(self, path: str): - """ - 从path读取SST2文件 + """从path读取SST2文件 :param str path: 数据路径 :return: DataSet """ ds = DataSet() - + with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if 'test' in os.path.split(path)[1]: @@ -341,8 +343,9 @@ class SST2Loader(Loader): if line: sep_index = line.index('\t') raw_words = line[sep_index + 1:] + index = int(line[: sep_index]) if raw_words: - ds.append(Instance(raw_words=raw_words)) + ds.append(Instance(raw_words=raw_words, index=index)) else: for line in f: line = line.strip() @@ -352,13 +355,11 @@ class SST2Loader(Loader): if raw_words: ds.append(Instance(raw_words=raw_words, target=target)) return ds - + def download(self): """ 自动下载数据集,如果你使用了该数据集,请引用以下的文章 - https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf - :return: """ output_dir = self._get_dataset_path(dataset_name='sst-2') @@ -389,7 +390,7 @@ class ChnSentiCorpLoader(Loader): def __init__(self): super().__init__() - def _load(self, path:str): + def _load(self, path: str): """ 从path中读取数据 @@ -404,7 +405,7 @@ class ChnSentiCorpLoader(Loader): tab_index = line.index('\t') if tab_index != -1: target = line[:tab_index] - raw_chars = line[tab_index+1:] + raw_chars = line[tab_index + 1:] if raw_chars: ds.append(Instance(raw_chars=raw_chars, target=target)) return ds @@ -432,10 +433,10 @@ class THUCNewsLoader(Loader): 读取后的Dataset将具有以下数据结构: .. csv-table:: - :header: "raw_words", "target" - - "调查-您如何评价热火客场胜绿军总分3-1夺赛点?...", "体育" - "...", "..." + :header: "raw_words", "target" + + "调查-您如何评价热火客场胜绿军总分3-1夺赛点?...", "体育" + "...", "..." """ @@ -481,7 +482,7 @@ class WeiboSenti100kLoader(Loader): .. csv-table:: :header: "raw_chars", "target" - + "多谢小莲,好运满满[爱你]", "1" "能在他乡遇老友真不赖,哈哈,珠儿,我也要用...", "1" "...", "..." 
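
The loader refactor above replaces the copy-pasted download/split code in YelpFullLoader, YelpPolarityLoader and IMDBLoader with the shared _split_dev helper and the new CLSBaseLoader base class. A minimal usage sketch of the new loaders, assuming the usual fastNLP pattern of feeding the downloaded directory back into load(); the 0.1 dev ratio is illustrative and not part of the patch::

    from fastNLP.io.loader import DBPediaLoader, AGsNewsLoader

    # DBPedia: download, then optionally carve dev.csv out of train.csv.
    # With the new default dev_ratio=0.0 no dev split is created.
    loader = DBPediaLoader()
    data_dir = loader.download(dev_ratio=0.1, re_download=False)
    data_bundle = loader.load(data_dir)   # DataBundle holding the train/dev/test DataSets

    # AG's News ships only train.csv and test.csv, so download() takes no arguments.
    ag_bundle = AGsNewsLoader().load(AGsNewsLoader().download())
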
diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 9c4c90d9..cf5d8130 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -56,15 +56,16 @@ class MNLILoader(Loader): with open(path, 'r', encoding='utf-8') as f: f.readline() # 跳过header if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): - warnings.warn("RTE's test file has no target.") + warnings.warn("MNLI's test file has no target.") for line in f: line = line.strip() if line: parts = line.split('\t') raw_words1 = parts[8] raw_words2 = parts[9] + idx = int(parts[0]) if raw_words1 and raw_words2: - ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, index=idx)) else: for line in f: line = line.strip() @@ -73,8 +74,9 @@ class MNLILoader(Loader): raw_words1 = parts[8] raw_words2 = parts[9] target = parts[-1] + idx = int(parts[0]) if raw_words1 and raw_words2 and target: - ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) + ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target, index=idx)) return ds def load(self, paths: str = None): diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index aa2a59ca..450fdfcb 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -12,6 +12,9 @@ __all__ = [ "CWSPipe", + "CLSBasePipe", + "AGsNewsPipe", + "DBPediaPipe", "YelpFullPipe", "YelpPolarityPipe", "SSTPipe", @@ -55,8 +58,8 @@ __all__ = [ "CMRC2018BertPipe" ] -from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ - WeiboSenti100kPipe +from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ + WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe from .conll import Conll2003Pipe from .coreference import CoReferencePipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index ab31c9de..d254dccf 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -1,6 +1,9 @@ """undocumented""" __all__ = [ + "CLSBasePipe", + "AGsNewsPipe", + "DBPediaPipe", "YelpFullPipe", "YelpPolarityPipe", "SSTPipe", @@ -17,29 +20,24 @@ import warnings from nltk import Tree from .pipe import Pipe -from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field +from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize from ..data_bundle import DataBundle from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader -from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader +from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \ + AGsNewsLoader, DBPediaLoader from ...core._logger import logger from ...core.const import Const from ...core.dataset import DataSet from ...core.instance import Instance -from ...core.vocabulary import Vocabulary -nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') +class CLSBasePipe(Pipe): -class _CLSPipe(Pipe): - """ - 分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列 - - """ - - def __init__(self, tokenizer: str = 'spacy', lang='en'): - + def __init__(self, lower: bool=False, tokenizer: 
str='spacy', lang='en'): + super().__init__() + self.lower = lower self.tokenizer = get_tokenizer(tokenizer, lang=lang) - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): """ 将DataBundle中的数据进行tokenize @@ -52,47 +50,49 @@ class _CLSPipe(Pipe): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) - + return data_bundle - - def _granularize(self, data_bundle, tag_map): + + def process(self, data_bundle: DataBundle): """ - 该函数对data_bundle中'target'列中的内容进行转换。 + 传入的DataSet应该具备如下的结构 + + .. csv-table:: + :header: "raw_words", "target" + + "I got 'new' tires from them and... ", "1" + "Don't waste your time. We had two...", "1" + "...", "..." :param data_bundle: - :param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, - 且将"1"认为是第0类。 - :return: 传入的data_bundle + :return: """ - for name in list(data_bundle.datasets.keys()): - dataset = data_bundle.get_dataset(name) - dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET, - new_field_name=Const.TARGET) - dataset.drop(lambda ins: ins[Const.TARGET] == -100) - data_bundle.set_dataset(dataset, name) + # 复制一列words + data_bundle = _add_words_field(data_bundle, lower=self.lower) + # 进行tokenize + data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) + # 建立词表并index + data_bundle = _indexize(data_bundle=data_bundle) + + for name, dataset in data_bundle.datasets.items(): + dataset.add_seq_len(Const.INPUT) + + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) + data_bundle.set_target(Const.TARGET) + return data_bundle + def process_from_file(self, paths) -> DataBundle: + """ + 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` -def _clean_str(words): - """ - heavily borrowed from github - https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb - :param sentence: is a str - :return: - """ - words_collection = [] - for word in words: - if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']: - continue - tt = nonalpnum.split(word) - t = ''.join(tt) - if t != '': - words_collection.append(t) - - return words_collection + :param paths: + :return: DataBundle + """ + raise NotImplementedError -class YelpFullPipe(_CLSPipe): +class YelpFullPipe(CLSBasePipe): """ 处理YelpFull的数据, 处理之后DataSet中的内容如下 @@ -124,32 +124,16 @@ class YelpFullPipe(_CLSPipe): 1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(tokenizer=tokenizer, lang='en') - self.lower = lower + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') assert granularity in (2, 3, 5), "granularity can only be 2,3,5." 
self.granularity = granularity if granularity == 2: - self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1} + self.tag_map = {"1": "negative", "2": "negative", "4": "positive", "5": "positive"} elif granularity == 3: - self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2} + self.tag_map = {"1": "negative", "2": "negative", "3": "medium", "4": "positive", "5": "positive"} else: - self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4} - - def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): - """ - 将DataBundle中的数据进行tokenize - - :param DataBundle data_bundle: - :param str field_name: - :param str new_field_name: - :return: 传入的DataBundle对象 - """ - new_field_name = new_field_name or field_name - for name, dataset in data_bundle.datasets.items(): - dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) - dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name) - return data_bundle + self.tag_map = None def process(self, data_bundle): """ @@ -165,27 +149,10 @@ class YelpFullPipe(_CLSPipe): :param data_bundle: :return: """ - - # 复制一列words - data_bundle = _add_words_field(data_bundle, lower=self.lower) - - # 进行tokenize - data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) - - # 根据granularity设置tag - data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) - - # 删除空行 - data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT) - - # index - data_bundle = _indexize(data_bundle=data_bundle) - - for name, dataset in data_bundle.datasets.items(): - dataset.add_seq_len(Const.INPUT) - - data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) - data_bundle.set_target(Const.TARGET) + if self.tag_map is not None: + data_bundle = _granularize(data_bundle, self.tag_map) + + data_bundle = super().process(data_bundle) return data_bundle @@ -199,7 +166,7 @@ class YelpFullPipe(_CLSPipe): return self.process(data_bundle=data_bundle) -class YelpPolarityPipe(_CLSPipe): +class YelpPolarityPipe(CLSBasePipe): """ 处理YelpPolarity的数据, 处理之后DataSet中的内容如下 @@ -229,50 +196,101 @@ class YelpPolarityPipe(_CLSPipe): :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(tokenizer=tokenizer, lang='en') - self.lower = lower + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') - def process(self, data_bundle): + def process_from_file(self, paths=None): """ - 传入的DataSet应该具备如下的结构 - .. csv-table:: - :header: "raw_words", "target" + :param str paths: + :return: DataBundle + """ + data_bundle = YelpPolarityLoader().load(paths) + return self.process(data_bundle=data_bundle) - "I got 'new' tires from them and... ", "1" - "Don't waste your time. We had two...", "1" - "...", "..." - :param data_bundle: - :return: +class AGsNewsPipe(CLSBasePipe): + """ + 处理AG's News的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field + :header: "raw_words", "target", "words", "seq_len" + + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . 
+ + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): """ - # 复制一列words - data_bundle = _add_words_field(data_bundle, lower=self.lower) - - # 进行tokenize - data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) - # index - data_bundle = _indexize(data_bundle=data_bundle) - - for name, dataset in data_bundle.datasets.items(): - dataset.add_seq_len(Const.INPUT) - - data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) - data_bundle.set_target(Const.TARGET) - - return data_bundle - + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + def process_from_file(self, paths=None): """ + :param str paths: + :return: DataBundle + """ + data_bundle = AGsNewsLoader().load(paths) + return self.process(data_bundle=data_bundle) + + +class DBPediaPipe(CLSBasePipe): + """ + 处理DBPedia的数据, 处理之后DataSet中的内容如下 + + .. csv-table:: 下面是使用DBPediaPipe处理后的DataSet所具备的field + :header: "raw_words", "target", "words", "seq_len" + + "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 + " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 + "...", ., "[...]", . + + dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: + + +-------------+-----------+--------+-------+---------+ + | field_names | raw_words | target | words | seq_len | + +-------------+-----------+--------+-------+---------+ + | is_input | False | False | True | True | + | is_target | False | True | False | False | + | ignore_type | | False | False | False | + | pad_value | | 0 | 0 | 0 | + +-------------+-----------+--------+-------+---------+ + + """ + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): + """ + + :param bool lower: 是否对输入进行小写化。 + :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 + """ + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') + def process_from_file(self, paths=None): + """ :param str paths: :return: DataBundle """ - data_bundle = YelpPolarityLoader().load(paths) + data_bundle = DBPediaLoader().load(paths) return self.process(data_bundle=data_bundle) -class SSTPipe(_CLSPipe): +class SSTPipe(CLSBasePipe): """ 经过该Pipe之后,DataSet中具备的field如下所示 @@ -314,11 +332,11 @@ class SSTPipe(_CLSPipe): self.granularity = granularity if granularity == 2: - self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1} + self.tag_map = {"0": "negative", "1": "negative", "3": "positive", "4": "positive"} elif granularity == 3: - self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2} + self.tag_map = {"0": "negative", "1": "negative", "2": "medium", "3": "positive", "4": "positive"} else: - self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} + self.tag_map = None def process(self, data_bundle: DataBundle): """ @@ -340,7 +358,7 @@ class SSTPipe(_CLSPipe): ds = DataSet() use_subtree = self.subtree or (name == 'train' and self.train_tree) for ins in dataset: - raw_words = ins['raw_words'] + raw_words 
= ins[Const.RAW_WORD] tree = Tree.fromstring(raw_words) if use_subtree: for t in tree.subtrees(): @@ -351,23 +369,11 @@ class SSTPipe(_CLSPipe): instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) ds.append(instance) data_bundle.set_dataset(ds, name) - - _add_words_field(data_bundle, lower=self.lower) - - # 进行tokenize - data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) - + # 根据granularity设置tag - data_bundle = self._granularize(data_bundle, tag_map=self.tag_map) - - # index - data_bundle = _indexize(data_bundle=data_bundle) - - for name, dataset in data_bundle.datasets.items(): - dataset.add_seq_len(Const.INPUT) + data_bundle = _granularize(data_bundle, tag_map=self.tag_map) - data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) - data_bundle.set_target(Const.TARGET) + data_bundle = super().process(data_bundle) return data_bundle @@ -376,7 +382,7 @@ class SSTPipe(_CLSPipe): return self.process(data_bundle=data_bundle) -class SST2Pipe(_CLSPipe): +class SST2Pipe(CLSBasePipe): """ 加载SST2的数据, 处理完成之后DataSet将拥有以下的field @@ -406,61 +412,7 @@ class SST2Pipe(_CLSPipe): :param bool lower: 是否对输入进行小写化。 :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ - super().__init__(tokenizer=tokenizer, lang='en') - self.lower = lower - - def process(self, data_bundle: DataBundle): - """ - 可以处理的DataSet应该具备如下的结构 - - .. csv-table:: - :header: "raw_words", "target" - - "it 's a charming and often affecting...", "1" - "unflinchingly bleak and...", "0" - "..." - - :param data_bundle: - :return: - """ - _add_words_field(data_bundle, self.lower) - - data_bundle = self._tokenize(data_bundle=data_bundle) - - src_vocab = Vocabulary() - src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT, - no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if - name != 'train']) - src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT) - - tgt_vocab = Vocabulary(unknown=None, padding=None) - tgt_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], - field_name=Const.TARGET, - no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() - if ('train' not in name) and (ds.has_field(Const.TARGET))] - ) - if len(tgt_vocab._no_create_word) > 0: - warn_msg = f"There are {len(tgt_vocab._no_create_word)} target labels" \ - f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ - f"data set but not in train data set!." 
- warnings.warn(warn_msg) - logger.warning(warn_msg) - datasets = [] - for name, dataset in data_bundle.datasets.items(): - if dataset.has_field(Const.TARGET): - datasets.append(dataset) - tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET) - - data_bundle.set_vocab(src_vocab, Const.INPUT) - data_bundle.set_vocab(tgt_vocab, Const.TARGET) - - for name, dataset in data_bundle.datasets.items(): - dataset.add_seq_len(Const.INPUT) - - data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) - data_bundle.set_target(Const.TARGET) - - return data_bundle + super().__init__(lower=lower, tokenizer=tokenizer, lang='en') def process_from_file(self, paths=None): """ @@ -472,7 +424,7 @@ class SST2Pipe(_CLSPipe): return self.process(data_bundle) -class IMDBPipe(_CLSPipe): +class IMDBPipe(CLSBasePipe): """ 经过本Pipe处理后DataSet将如下 @@ -532,14 +484,7 @@ class IMDBPipe(_CLSPipe): for name, dataset in data_bundle.datasets.items(): dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD) - _add_words_field(data_bundle, lower=self.lower) - self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT) - _indexize(data_bundle) - - for name, dataset in data_bundle.datasets.items(): - dataset.add_seq_len(Const.INPUT) - dataset.set_input(Const.INPUT, Const.INPUT_LEN) - dataset.set_target(Const.TARGET) + data_bundle = super().process(data_bundle) return data_bundle @@ -663,7 +608,7 @@ class ChnSentiCorpPipe(Pipe): return data_bundle -class THUCNewsPipe(_CLSPipe): +class THUCNewsPipe(CLSBasePipe): """ 处理之后的DataSet有以下的结构 @@ -727,7 +672,7 @@ class THUCNewsPipe(_CLSPipe): """ # 根据granularity设置tag tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9} - data_bundle = self._granularize(data_bundle=data_bundle, tag_map=tag_map) + data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map) # clean,lower @@ -775,7 +720,7 @@ class THUCNewsPipe(_CLSPipe): return data_bundle -class WeiboSenti100kPipe(_CLSPipe): +class WeiboSenti100kPipe(CLSBasePipe): """ 处理之后的DataSet有以下的结构 @@ -820,7 +765,6 @@ class WeiboSenti100kPipe(_CLSPipe): dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) return data_bundle - def process(self, data_bundle: DataBundle): """ 可处理的DataSet应具备以下的field diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py index d05ffe96..0a82c434 100644 --- a/fastNLP/io/pipe/utils.py +++ b/fastNLP/io/pipe/utils.py @@ -136,7 +136,7 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con f"These label(s) are {tgt_vocab._no_create_word}" warnings.warn(warn_msg) logger.warning(warn_msg) - tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name) + tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name) data_bundle.set_vocab(tgt_vocab, target_field_name) return data_bundle @@ -198,3 +198,23 @@ def _drop_empty_instance(data_bundle, field_name): dataset.drop(empty_instance) return data_bundle + + +def _granularize(data_bundle, tag_map): + """ + 该函数对data_bundle中'target'列中的内容进行转换。 + + :param data_bundle: + :param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, + 且将"1"认为是第0类。 + :return: 传入的data_bundle + """ + if tag_map is None: + return data_bundle + for name in list(data_bundle.datasets.keys()): + dataset = data_bundle.get_dataset(name) + dataset.apply_field(lambda target: tag_map.get(target, -100), 
field_name=Const.TARGET, + new_field_name=Const.TARGET) + dataset.drop(lambda ins: ins[Const.TARGET] == -100) + data_bundle.set_dataset(dataset, name) + return data_bundle diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 32581e23..f6cbbb4f 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric from fastNLP.core.metrics import _pred_topk, _accuracy_topk from fastNLP.core.vocabulary import Vocabulary from collections import Counter -from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric +from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric def _generate_tags(encoding_type, number_labels=4): @@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result): allen_result[key] = round(value, 6) return allen_result + + +class TestConfusionMatrixMetric(unittest.TestCase): + def test_ConfusionMatrixMetric1(self): + pred_dict = {"pred": torch.zeros(4,3)} + target_dict = {'target': torch.zeros(4)} + metric = ConfusionMatrixMetric() + + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + + def test_ConfusionMatrixMetric2(self): + # (2) with corrupted size + try: + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric = ConfusionMatrixMetric() + + metric(pred_dict=pred_dict, target_dict=target_dict, ) + print(metric.get_metric()) + except Exception as e: + print(e) + return + print("No exception catches.") + + def test_ConfusionMatrixMetric3(self): + # (3) the second batch is corrupted size + try: + metric = ConfusionMatrixMetric() + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric(pred_dict=pred_dict, target_dict=target_dict) + + print(metric.get_metric()) + except Exception as e: + print(e) + return + assert(True, False), "No exception catches." + + def test_ConfusionMatrixMetric4(self): + # (4) check reset + metric = ConfusionMatrixMetric() + pred_dict = {"pred": torch.randn(4, 3, 2)} + target_dict = {'target': torch.ones(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + res = metric.get_metric() + self.assertTrue(isinstance(res, dict)) + print(res) + + def test_ConfusionMatrixMetric5(self): + # (5) check numpy array is not acceptable + try: + metric = ConfusionMatrixMetric() + pred_dict = {"pred": np.zeros((4, 3, 2))} + target_dict = {'target': np.zeros((4, 3))} + metric(pred_dict=pred_dict, target_dict=target_dict) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." 
+ + def test_ConfusionMatrixMetric6(self): + # (6) check map, match + metric = ConfusionMatrixMetric(pred='predictions', target='targets') + pred_dict = {"predictions": torch.randn(4, 3, 2)} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + res = metric.get_metric() + print(res) + + def test_ConfusionMatrixMetric7(self): + # (7) check map, include unused + try: + metric = ConfusionMatrixMetric(pred='prediction', target='targets') + pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + + def test_ConfusionMatrixMetric8(self): +# (8) check _fast_metric + try: + metric = ConfusionMatrixMetric() + pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + + def test_duplicate(self): + # 0.4.1的潜在bug,不能出现形参重复的情况 + metric = ConfusionMatrixMetric(pred='predictions', target='targets') + pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} + target_dict = {'targets':torch.zeros(4, 3), 'target': 0} + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + + + def test_seq_len(self): + N = 256 + seq_len = torch.zeros(N).long() + seq_len[0] = 2 + pred = {'pred': torch.ones(N, 2)} + target = {'target': torch.ones(N, 2), 'seq_len': seq_len} + metric = ConfusionMatrixMetric() + metric(pred_dict=pred, target_dict=target) + metric.get_metric(reset=False) + seq_len[1:] = 1 + metric(pred_dict=pred, target_dict=target) + metric.get_metric() + + def test_vocab(self): + vocab = Vocabulary() + word_list = "this is a word list".split() + vocab.update(word_list) + + pred_dict = {"pred": torch.zeros(4,3)} + target_dict = {'target': torch.zeros(4)} + metric = ConfusionMatrixMetric(vocab=vocab) + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + + + + class TestAccuracyMetric(unittest.TestCase): def test_AccuracyMetric1(self): # (1) only input, targets passed @@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase): def test_AccuaryMetric8(self): try: metric = AccuracyMetric(pred='predictions', target='targets') - pred_dict = {"prediction": torch.zeros(4, 3, 2)} + pred_dict = {"predictions": torch.zeros(4, 3, 2)} target_dict = {'targets': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict, ) self.assertDictEqual(metric.get_metric(), {'acc': 1}) diff --git a/test/data_for_tests/io/ag/test.csv b/test/data_for_tests/io/ag/test.csv new file mode 100644 index 00000000..3a4cc0ae --- /dev/null +++ b/test/data_for_tests/io/ag/test.csv @@ -0,0 +1,5 @@ +"3","Fears for T N pension after talks","Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul." +"4","The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com)","SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket." +"4","Ky. 
Company Wins Grant to Study Peptides (AP)","AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins." +"4","Prediction Unit Helps Forecast Wildfires (AP)","AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar." +"4","Calif. Aims to Limit Farm-Related Smog (AP)","AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure." diff --git a/test/data_for_tests/io/ag/train.csv b/test/data_for_tests/io/ag/train.csv new file mode 100644 index 00000000..e766a481 --- /dev/null +++ b/test/data_for_tests/io/ag/train.csv @@ -0,0 +1,4 @@ +"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again." +"4","Building Dedicated to Columbia Astronauts (AP)","AP - A former dormitory converted to classrooms at the Pensacola Naval Air Station was dedicated Friday to two Columbia astronauts who were among the seven who died in the shuttle disaster Feb. 1, 2003." +"2","Phelps On Relay Team","Michael Phelps is named to the 4x100-meter freestyle relay team that will compete in Sunday's final, keeping alive his quest for a possible eight Olympic gold medals." +"1","Venezuelans Vote Early in Referendum on Chavez Rule (Reuters)","Reuters - Venezuelans turned out early\and in large numbers on Sunday to vote in a historic referendum\that will either remove left-wing President Hugo Chavez from\office or give him a new mandate to govern for the next two\years." diff --git a/test/data_for_tests/io/dbpedia/test.csv b/test/data_for_tests/io/dbpedia/test.csv new file mode 100644 index 00000000..4e50b3fb --- /dev/null +++ b/test/data_for_tests/io/dbpedia/test.csv @@ -0,0 +1,5 @@ +1,"TY KU"," TY KU /taɪkuː/ is an American alcoholic beverage company that specializes in sake and other spirits. The privately-held company was founded in 2004 and is headquartered in New York City New York. While based in New York TY KU's beverages are made in Japan through a joint venture with two sake breweries. Since 2011 TY KU's growth has extended its products into all 50 states." +1,"Odd Lot Entertainment"," OddLot Entertainment founded in 2001 by longtime producers Gigi Pritzker and Deborah Del Prete (The Wedding Planner) is a film production and financing company based in Culver City California.OddLot produced the film version of Orson Scott Card's sci-fi novel Ender's Game. A film version of this novel had been in the works in one form or another for more than a decade by the time of its release." +1,"Henkel"," Henkel AG & Company KGaA operates worldwide with leading brands and technologies in three business areas: Laundry & Home Care Beauty Care and Adhesive Technologies. Henkel is the name behind some of America’s favorite brands." +1,"GOAT Store"," The GOAT Store (Games Of All Type Store) LLC is one of the largest retro gaming online stores and an Independent Video Game Publishing Label. Additionally they are one of the primary sponsors for Midwest Gaming Classic." 
+1,"RagWing Aircraft Designs"," RagWing Aircraft Designs (also called the RagWing Aeroplane Company and RagWing Aviation) was an American aircraft design and manufacturing company based in Belton South Carolina." diff --git a/test/data_for_tests/io/dbpedia/train.csv b/test/data_for_tests/io/dbpedia/train.csv new file mode 100644 index 00000000..d3698589 --- /dev/null +++ b/test/data_for_tests/io/dbpedia/train.csv @@ -0,0 +1,14 @@ +1,"Boneau/Bryan-Brown"," Boneau/Bryan-Brown Inc. is a public relations company based in Manhattan New York USA largely supporting Broadway theatre productions as a theatrical press agency.The company was formed by the partnership of Chris Boneau and Adrian Bryan-Brown in 1991. Broadway productions supported include among hundreds the musical Guys and Dolls in 1992. The company initially represented the rock musical Spider-Man: Turn Off the Dark which finally opened on Broadway in 2011." +2,"Dubai Gem Private School & Nursery"," Dubai Gem Private School (DGPS) is a British school located in the Oud Metha area of Dubai United Arab Emirates. Dubai Gem Nursery is located in Jumeirah. Together the institutions enroll almost 1500 students aged 3 to 18." +3,"Shahar Marcus"," Shahar Marcus (born 1971 in Petach Tikva) is an Israeli performance artist." +4,"Martin McKinnon"," Martin Marty McKinnon (born 5 July 1975 in Adelaide) is a former Australian rules footballer who played with Adelaide Geelong and the Brisbane Lions in the Australian Football League (AFL).McKinnon was recruited by Adelaide in the 1992 AFL Draft with their first ever national draft pick. He was the youngest player on Adelaide's list at the time and played for Central District in the SANFL when not appearing with Adelaide." +5,"Steve Howitt"," Steven S. Howitt is the current member of the Massachusetts House of Representatives for the 4th Bristol district." +6,"Wedell-Williams XP-34"," The Wedell-Williams XP-34 was a fighter aircraft design submitted to the United States Army Air Corps (USAAC) before World War II by Marguerite Clark Williams widow of millionaire Harry P. Williams former owner and co-founder of the Wedell-Williams Air Service Corporation." +7,"Nationality Rooms"," The Nationality Rooms are a collection of 29 classrooms in the University of Pittsburgh's Cathedral of Learning depicting and donated by the ethnic groups that helped build the city of Pittsburgh." +8,"Duruitoarea River"," The Duruitoarea River is a tributary of the Camenca River in Romania." +9,"Shirvan Shahlu"," Shirvan Shahlu (Persian: شيروان شاهلو‎ also Romanized as Shīrvān Shāhlū; also known as Shīravān Shāmnū) is a village in Gavdul-e Sharqi Rural District in the Central District of Malekan County East Azerbaijan Province Iran. At the 2006 census its population was 137 in 35 families." +10,"Oenopota impressa"," Oenopota impressa is a species of sea snail a marine gastropod mollusk in the family Mangeliidae." +11,"Utricularia simulans"," Utricularia simulans the fringed bladderwort is a small to medium-sized probably perennial carnivorous plant that belongs to the genus Utricularia. U. simulans is native to tropical Africa and the Americas. It grows as a terrestrial plant in damp sandy soils in open savanna at altitudes from near sea level to 1575 m (5167 ft). U. simulans was originally described and published by Robert Knud Friedrich Pilger in 1914." +12,"Global Chillage"," Global Chillage is the second album by The Irresistible Force released in 1994 through Rising High Records." 
+13,"The Nuisance (1933 film)"," The Nuisance is a 1933 film starring Lee Tracy as a lawyer Madge Evans as his love interest (with a secret) and Frank Morgan as his accomplice." +14,"Razadarit Ayedawbon"," Razadarit Ayedawbon (Burmese: ရာဇာဓိရာဇ် အရေးတော်ပုံ) is a Burmese chronicle covering the history of Ramanya from 1287 to 1421. The chronicle consists of accounts of court intrigues rebellions diplomatic missions wars etc. About half of the chronicle is devoted to the reign of King Razadarit (r." diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py index 036530c3..987327d4 100644 --- a/test/io/pipe/test_classification.py +++ b/test/io/pipe/test_classification.py @@ -2,7 +2,8 @@ import unittest import os from fastNLP.io import DataBundle -from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe +from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe, \ + AGsNewsPipe, DBPediaPipe from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe @@ -36,10 +37,12 @@ class TestRunClassificationPipe(unittest.TestCase): def test_process_from_file(self): data_set_dict = { 'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, (6, 6, 6), (1176, 2), False), - 'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1023, 5), False), + 'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1166, 5), False), 'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True), 'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False), 'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False), + 'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, (5, 4), (257, 4), False), + 'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, (5, 14), (496, 14), False), 'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False), 'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, (9, 9, 9), (1864, 9), False), 'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, (7, 6, 6), (452, 2), False),