Browse Source

Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.5.5
yh_cc 4 years ago
parent
commit
9bed203a35
16 changed files with 715 additions and 396 deletions
  1. +1
    -1
      fastNLP/__init__.py
  2. +93
    -1
      fastNLP/core/metrics.py
  3. +99
    -3
      fastNLP/core/utils.py
  4. +6
    -14
      fastNLP/io/__init__.py
  5. +5
    -2
      fastNLP/io/loader/__init__.py
  6. +163
    -162
      fastNLP/io/loader/classification.py
  7. +5
    -3
      fastNLP/io/loader/matching.py
  8. +5
    -2
      fastNLP/io/pipe/__init__.py
  9. +147
    -203
      fastNLP/io/pipe/classification.py
  10. +21
    -1
      fastNLP/io/pipe/utils.py
  11. +137
    -2
      test/core/test_metrics.py
  12. +5
    -0
      test/data_for_tests/io/ag/test.csv
  13. +4
    -0
      test/data_for_tests/io/ag/train.csv
  14. +5
    -0
      test/data_for_tests/io/dbpedia/test.csv
  15. +14
    -0
      test/data_for_tests/io/dbpedia/train.csv
  16. +5
    -2
      test/io/pipe/test_classification.py

+ 1
- 1
fastNLP/__init__.py View File

@@ -71,7 +71,7 @@ __all__ = [
'logger'
]
__version__ = '0.4.5'
__version__ = '0.5.0'

import sys



+ 93
- 1
fastNLP/core/metrics.py View File

@@ -7,7 +7,8 @@ __all__ = [
"AccuracyMetric",
"SpanFPreRecMetric",
"CMRC2018Metric",
"ClassifyFPreRecMetric"
"ClassifyFPreRecMetric",
"ConfusionMatrixMetric"
]

import inspect
@@ -15,6 +16,7 @@ import warnings
from abc import abstractmethod
from collections import defaultdict
from typing import Union
from copy import deepcopy
import re

import numpy as np
@@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list
from .utils import _get_func_signature
from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
from .utils import ConfusionMatrix


class MetricBase(object):
@@ -276,6 +279,95 @@ class MetricBase(object):
return


class ConfusionMatrixMetric(MetricBase):
r"""
分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )

最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例}
ConfusionMatrix实例的print()函数将输出矩阵字符串。

pred_dict = {"pred": torch.Tensor([2,1,3])}
target_dict = {'target': torch.Tensor([2,2,1])}
metric = ConfusionMatrixMetric()
metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric())

{'confusion_matrix':
target 1.0 2.0 3.0 all
pred
1.0 0 1 0 1
2.0 0 1 0 1
3.0 1 0 0 1
all 1 2 0 3}
"""
def __init__(self, vocab=None, pred=None, target=None, seq_len=None):
"""
:param vocab: vocab词表类,要求有to_word()方法。
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
"""
super().__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.confusion_matrix = ConfusionMatrix(vocab=vocab)

def evaluate(self, pred, target, seq_len=None):
"""
evaluate函数将针对一个批次的预测结果做评价指标的累计

:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len])
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, torch.Size([B]), 或者torch.Size([B]).
"""
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if seq_len is not None and not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.")

if pred.dim() == target.dim():
pass
elif pred.dim() == target.dim() + 1:
pred = pred.argmax(dim=-1)
if seq_len is None and target.dim() > 1:
warnings.warn("You are not passing `seq_len` to exclude pad.")
else:
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

target = target.to(pred)
if seq_len is not None and target.dim() > 1:
for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()):
l=int(l)
self.confusion_matrix.add_pred_target(p[:l], t[:l])
elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出
for p, t in zip(pred.tolist(), target.tolist()):
self.confusion_matrix.add_pred_target(p, t)
else:
self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist())

def get_metric(self,reset=True):
"""
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.

:param bool reset: 在调用完get_metric后是否清空评价指标统计量.
:return dict evaluate_result: {"confusion_matrix": ConfusionMatrix}
"""
confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)}
if reset:
self.confusion_matrix.clear()
return confusion


class AccuracyMetric(MetricBase):
"""
准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )


+ 99
- 3
fastNLP/core/utils.py View File

@@ -8,18 +8,22 @@ __all__ = [
"get_seq_len"
]

import _pickle
import inspect
import os
import warnings
from collections import Counter, namedtuple
from copy import deepcopy
from typing import List

import _pickle
import numpy as np
import torch
import torch.nn as nn
from typing import List
from ._logger import logger
from prettytable import PrettyTable

from ._logger import logger
from ._parallel_utils import _model_contains_inner_module
# from .vocabulary import Vocabulary

try:
from apex import amp
@@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
'varargs'])




class ConfusionMatrix:
"""a dict can provide Confusion Matrix"""
def __init__(self, vocab=None):
"""
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。
"""
if vocab and not hasattr(vocab, 'to_word'):
raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary,"
f"got {type(vocab)}.")
self.confusiondict={} #key: pred index, value:target word ocunt
self.predcount={} #key:pred index, value:count
self.targetcount={} #key:target index, value:count
self.vocab=vocab
def add_pred_target(self, pred, target): #一组结果
"""
通过这个函数向ConfusionMatrix加入一组预测结果

:param list pred: 预测的标签列表
:param list target: 真实值的标签列表
:return ConfusionMatrix

confusion=ConfusionMatrix()
pred = [2,1,3]
target = [2,2,1]
confusion.add_pred_target(pred, target)
print(confusion)

target 1 2 3 all
pred
1 0 1 0 1
2 0 1 0 1
3 1 0 0 1
all 1 2 0 3
"""
for p,t in zip(pred,target): #<int, int>
self.predcount[p]=self.predcount.get(p,0)+ 1
self.targetcount[t]=self.targetcount.get(t,0)+1
if p in self.confusiondict:
self.confusiondict[p][t]=self.confusiondict[p].get(t,0) + 1
else:
self.confusiondict[p]={}
self.confusiondict[p][t]= 1
return self.confusiondict

def clear(self):
"""
清除一些值,等待再次新加入
:return:
"""
self.confusiondict={}
self.targetcount={}
self.predcount={}
def __repr__(self):
"""
:return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。
"""
row2idx={}
idx2row={}
# 已知的所有键/label
totallabel=sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys()))))
lenth=len(totallabel)
# namedict key :idx value:word/idx
namedict=dict([(k,str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel])

for label,idx in zip(totallabel,range(lenth)):
idx2row[label]=idx #建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
row2idx[idx]=label #建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
# 这里打印东西
#表头
head=["\ntarget"]+[str(namedict[row2idx[k]]) for k in row2idx.keys()]+["all"]
output="\t".join(head) + "\n" + "pred" + "\n"
#内容
for i in row2idx.keys(): #第i行
p=row2idx[i]
h=namedict[p]
l=[0 for _ in range(lenth)]
if self.confusiondict.get(p,None):
for t,c in self.confusiondict[p].items():
l[idx2row[t]] = c #完成一行
l=[h]+[str(n) for n in l]+[str(sum(l))]
output+="\t".join(l) +"\n"
#表尾
tail=[self.targetcount.get(row2idx[k],0) for k in row2idx.keys()]
tail=["all"]+[str(n) for n in tail]+[str(sum(tail))]
output+="\t".join(tail)
return output


class Option(dict):
"""a dict can treat keys as attributes"""



+ 6
- 14
fastNLP/io/__init__.py View File

@@ -18,7 +18,9 @@ __all__ = [
'Loader',
'YelpLoader',
'CLSBaseLoader',
'AGsNewsLoader',
'DBPediaLoader',
'YelpFullLoader',
'YelpPolarityLoader',
'IMDBLoader',
@@ -55,6 +57,9 @@ __all__ = [

"Pipe",

"CLSBasePipe",
"AGsNewsPipe",
"DBPediaPipe",
"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
@@ -73,19 +78,6 @@ __all__ = [

"CWSPipe",
"Pipe",
"CWSPipe",
"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
"SST2Pipe",
"IMDBPipe",
"ChnSentiCorpPipe",
"THUCNewsPipe",
"WeiboSenti100kPipe",
"Conll2003NERPipe",
"OntoNotesNERPipe",
"MsraNERPipe",


+ 5
- 2
fastNLP/io/loader/__init__.py View File

@@ -47,9 +47,11 @@ fastNLP 目前提供了如下的 Loader
__all__ = [
'Loader',
'YelpLoader',
'CLSBaseLoader',
'YelpFullLoader',
'YelpPolarityLoader',
'AGsNewsLoader',
'DBPediaLoader',
'IMDBLoader',
'SSTLoader',
'SST2Loader',
@@ -84,7 +86,8 @@ __all__ = [

"CMRC2018Loader"
]
from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, \
from .classification import CLSBaseLoader, YelpFullLoader, YelpPolarityLoader, AGsNewsLoader, IMDBLoader, \
SSTLoader, SST2Loader, DBPediaLoader, \
ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader


+ 163
- 162
fastNLP/io/loader/classification.py View File

@@ -1,9 +1,11 @@
"""undocumented"""

__all__ = [
"YelpLoader",
"CLSBaseLoader",
"YelpFullLoader",
"YelpPolarityLoader",
"AGsNewsLoader",
"DBPediaLoader",
"IMDBLoader",
"SSTLoader",
"SST2Loader",
@@ -12,6 +14,7 @@ __all__ = [
"WeiboSenti100kLoader"
]


import glob
import os
import random
@@ -22,14 +25,17 @@ import warnings
from .loader import Loader
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core._logger import logger


class YelpLoader(Loader):
class CLSBaseLoader(Loader):
"""
文本分类Loader的一个基类

原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。

Example::
"1","I got 'new' tires from the..."
"1","Don't waste your time..."

@@ -43,125 +49,112 @@ class YelpLoader(Loader):
"...", "..."

"""
def __init__(self):
super(YelpLoader, self).__init__()
def _load(self, path: str = None):

def __init__(self, sep=',', has_header=False):
super().__init__()
self.sep = sep
self.has_header = has_header

def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
sep_index = line.index(',')
target = line[:sep_index]
raw_words = line[sep_index + 1:]
if target.startswith("\""):
target = target[1:]
if target.endswith("\""):
target = target[:-1]
if raw_words.endswith("\""):
raw_words = raw_words[:-1]
if raw_words.startswith('"'):
raw_words = raw_words[1:]
raw_words = raw_words.replace('""', '"') # 替换双引号
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
try:
with open(path, 'r', encoding='utf-8') as f:
read_header = self.has_header
for line in f:
if read_header:
read_header = False
continue
line = line.strip()
sep_index = line.index(self.sep)
target = line[:sep_index]
raw_words = line[sep_index + 1:]
if target.startswith("\""):
target = target[1:]
if target.endswith("\""):
target = target[:-1]
if raw_words.endswith("\""):
raw_words = raw_words[:-1]
if raw_words.startswith('"'):
raw_words = raw_words[1:]
raw_words = raw_words.replace('""', '"') # 替换双引号
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
except Exception as e:
logger.error(f'Load file `{path}` failed for `{e}`')
return ds


class YelpFullLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, re_download: bool = False):
def _split_dev(dataset_name, data_dir, dev_ratio=0.0, re_download=False, suffix='csv'):
if dev_ratio == 0.0:
return data_dir
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = Loader()._get_dataset_path(dataset_name=dataset_name)

if not os.path.exists(os.path.join(data_dir, f'dev.{suffix}')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
try:
with open(os.path.join(data_dir, f'train.{suffix}'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, f'middle_file.{suffix}'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, f'dev.{suffix}'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, f'train.{suffix}'))
os.renames(os.path.join(data_dir, f'middle_file.{suffix}'), os.path.join(data_dir, f'train.{suffix}'))
finally:
if os.path.exists(os.path.join(data_dir, f'middle_file.{suffix}')):
os.remove(os.path.join(data_dir, f'middle_file.{suffix}'))

return data_dir


class AGsNewsLoader(CLSBaseLoader):
def download(self):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015)

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv,
dev.csv三个文件。

:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-full'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
try:
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, 'train.csv'))
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
os.remove(os.path.join(data_dir, 'middle_file.csv'))
return data_dir
return self._get_dataset_path(dataset_name='ag-news')


class YelpPolarityLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, re_download=False):
class DBPediaLoader(CLSBaseLoader):
def download(self, dev_ratio: float = 0.0, re_download: bool = False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015)

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分dev_ratio这么多作为dev
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv

:param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据。 如果为0,则不划分dev。
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-polarity'
dataset_name = 'dbpedia'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
try:
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, 'train.csv'))
os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
os.remove(os.path.join(data_dir, 'middle_file.csv'))
data_dir = _split_dev(dataset_name=dataset_name,
data_dir=data_dir,
dev_ratio=dev_ratio,
re_download=re_download,
suffix='csv')
return data_dir


class IMDBLoader(Loader):
class IMDBLoader(CLSBaseLoader):
"""
原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。

@@ -181,35 +174,16 @@ class IMDBLoader(Loader):
"...", "..."

"""
def __init__(self):
super(IMDBLoader, self).__init__()
def _load(self, path: str):
dataset = DataSet()
with open(path, 'r', encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
parts = line.split('\t')
target = parts[0]
words = parts[1]
if words:
dataset.append(Instance(raw_words=words, target=target))
if len(dataset) == 0:
raise RuntimeError(f"{path} has no valid data.")
return dataset
def download(self, dev_ratio: float = 0.1, re_download=False):
super().__init__(sep='\t')

def download(self, dev_ratio: float = 0.0, re_download=False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

http://www.aclweb.org/anthology/P11-1015

根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后从train中切分0.1作为dev
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后不从train中切分dev

:param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev
:param bool re_download: 是否重新下载数据,以重新切分数据。
@@ -217,32 +191,11 @@ class IMDBLoader(Loader):
"""
dataset_name = 'aclImdb'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
try:
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, 'train.txt'))
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
os.remove(os.path.join(data_dir, 'middle_file.txt'))
data_dir = _split_dev(dataset_name=dataset_name,
data_dir=data_dir,
dev_ratio=dev_ratio,
re_download=re_download,
suffix='txt')
return data_dir


@@ -267,10 +220,10 @@ class SSTLoader(Loader):
raw_words列是str。

"""
def __init__(self):
super().__init__()
def _load(self, path: str):
"""
从path读取SST文件
@@ -285,7 +238,7 @@ class SSTLoader(Loader):
if line:
ds.append(Instance(raw_words=line))
return ds
def download(self):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -298,6 +251,56 @@ class SSTLoader(Loader):
return output_dir


class YelpFullLoader(CLSBaseLoader):
def download(self, dev_ratio: float = 0.0, re_download: bool = False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015)

如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv

:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-full'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
data_dir = _split_dev(dataset_name=dataset_name,
data_dir=data_dir,
dev_ratio=dev_ratio,
re_download=re_download,
suffix='csv')
return data_dir


class YelpPolarityLoader(CLSBaseLoader):
def download(self, dev_ratio: float = 0.0, re_download: bool = False):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章

Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
in Neural Information Processing Systems 28 (NIPS 2015)

如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv

:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
:param bool re_download: 是否重新下载数据,以重新切分数据。
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-polarity'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
data_dir = _split_dev(dataset_name=dataset_name,
data_dir=data_dir,
dev_ratio=dev_ratio,
re_download=re_download,
suffix='csv')
return data_dir


class SST2Loader(Loader):
"""
原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label
@@ -319,19 +322,18 @@ class SST2Loader(Loader):

test的DataSet没有target列。
"""
def __init__(self):
super().__init__()
def _load(self, path: str):
"""
从path读取SST2文件
"""从path读取SST2文件

:param str path: 数据路径
:return: DataSet
"""
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if 'test' in os.path.split(path)[1]:
@@ -341,8 +343,9 @@ class SST2Loader(Loader):
if line:
sep_index = line.index('\t')
raw_words = line[sep_index + 1:]
index = int(line[: sep_index])
if raw_words:
ds.append(Instance(raw_words=raw_words))
ds.append(Instance(raw_words=raw_words, index=index))
else:
for line in f:
line = line.strip()
@@ -352,13 +355,11 @@ class SST2Loader(Loader):
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
return ds
def download(self):
"""
自动下载数据集,如果你使用了该数据集,请引用以下的文章

https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf

:return:
"""
output_dir = self._get_dataset_path(dataset_name='sst-2')
@@ -389,7 +390,7 @@ class ChnSentiCorpLoader(Loader):
def __init__(self):
super().__init__()

def _load(self, path:str):
def _load(self, path: str):
"""
从path中读取数据

@@ -404,7 +405,7 @@ class ChnSentiCorpLoader(Loader):
tab_index = line.index('\t')
if tab_index != -1:
target = line[:tab_index]
raw_chars = line[tab_index+1:]
raw_chars = line[tab_index + 1:]
if raw_chars:
ds.append(Instance(raw_chars=raw_chars, target=target))
return ds
@@ -432,10 +433,10 @@ class THUCNewsLoader(Loader):
读取后的Dataset将具有以下数据结构:

.. csv-table::
:header: "raw_words", "target"
"调查-您如何评价热火客场胜绿军总分3-1夺赛点?...", "体育"
"...", "..."
:header: "raw_words", "target"
"调查-您如何评价热火客场胜绿军总分3-1夺赛点?...", "体育"
"...", "..."

"""

@@ -481,7 +482,7 @@ class WeiboSenti100kLoader(Loader):

.. csv-table::
:header: "raw_chars", "target"
"多谢小莲,好运满满[爱你]", "1"
"能在他乡遇老友真不赖,哈哈,珠儿,我也要用...", "1"
"...", "..."


+ 5
- 3
fastNLP/io/loader/matching.py View File

@@ -56,15 +56,16 @@ class MNLILoader(Loader):
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
warnings.warn("RTE's test file has no target.")
warnings.warn("MNLI's test file has no target.")
for line in f:
line = line.strip()
if line:
parts = line.split('\t')
raw_words1 = parts[8]
raw_words2 = parts[9]
idx = int(parts[0])
if raw_words1 and raw_words2:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2))
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, index=idx))
else:
for line in f:
line = line.strip()
@@ -73,8 +74,9 @@ class MNLILoader(Loader):
raw_words1 = parts[8]
raw_words2 = parts[9]
target = parts[-1]
idx = int(parts[0])
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target, index=idx))
return ds
def load(self, paths: str = None):


+ 5
- 2
fastNLP/io/pipe/__init__.py View File

@@ -12,6 +12,9 @@ __all__ = [
"CWSPipe",
"CLSBasePipe",
"AGsNewsPipe",
"DBPediaPipe",
"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
@@ -55,8 +58,8 @@ __all__ = [
"CMRC2018BertPipe"
]

from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
WeiboSenti100kPipe
from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
from .conll import Conll2003Pipe
from .coreference import CoReferencePipe


+ 147
- 203
fastNLP/io/pipe/classification.py View File

@@ -1,6 +1,9 @@
"""undocumented"""

__all__ = [
"CLSBasePipe",
"AGsNewsPipe",
"DBPediaPipe",
"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
@@ -17,29 +20,24 @@ import warnings
from nltk import Tree

from .pipe import Pipe
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance, _add_chars_field
from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize
from ..data_bundle import DataBundle
from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \
AGsNewsLoader, DBPediaLoader
from ...core._logger import logger
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.vocabulary import Vocabulary

nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')

class CLSBasePipe(Pipe):

class _CLSPipe(Pipe):
"""
分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列

"""
def __init__(self, tokenizer: str = 'spacy', lang='en'):
def __init__(self, lower: bool=False, tokenizer: str='spacy', lang='en'):
super().__init__()
self.lower = lower
self.tokenizer = get_tokenizer(tokenizer, lang=lang)

def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize
@@ -52,47 +50,49 @@ class _CLSPipe(Pipe):
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
return data_bundle
def _granularize(self, data_bundle, tag_map):
def process(self, data_bundle: DataBundle):
"""
该函数对data_bundle中'target'列中的内容进行转换。
传入的DataSet应该具备如下的结构

.. csv-table::
:header: "raw_words", "target"

"I got 'new' tires from them and... ", "1"
"Don't waste your time. We had two...", "1"
"...", "..."

:param data_bundle:
:param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance,
且将"1"认为是第0类。
:return: 传入的data_bundle
:return:
"""
for name in list(data_bundle.datasets.keys()):
dataset = data_bundle.get_dataset(name)
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
new_field_name=Const.TARGET)
dataset.drop(lambda ins: ins[Const.TARGET] == -100)
data_bundle.set_dataset(dataset, name)
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# 建立词表并index
data_bundle = _indexize(data_bundle=data_bundle)

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)

data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths) -> DataBundle:
"""
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()`

def _clean_str(words):
"""
heavily borrowed from github
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
:param sentence: is a str
:return:
"""
words_collection = []
for word in words:
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']:
continue
tt = nonalpnum.split(word)
t = ''.join(tt)
if t != '':
words_collection.append(t)
return words_collection
:param paths:
:return: DataBundle
"""
raise NotImplementedError


class YelpFullPipe(_CLSPipe):
class YelpFullPipe(CLSBasePipe):
"""
处理YelpFull的数据, 处理之后DataSet中的内容如下

@@ -124,32 +124,16 @@ class YelpFullPipe(_CLSPipe):
1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity
if granularity == 2:
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
self.tag_map = {"1": "negative", "2": "negative", "4": "positive", "5": "positive"}
elif granularity == 3:
self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2}
self.tag_map = {"1": "negative", "2": "negative", "3": "medium", "4": "positive", "5": "positive"}
else:
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
将DataBundle中的数据进行tokenize

:param DataBundle data_bundle:
:param str field_name:
:param str new_field_name:
:return: 传入的DataBundle对象
"""
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name)
return data_bundle
self.tag_map = None
def process(self, data_bundle):
"""
@@ -165,27 +149,10 @@ class YelpFullPipe(_CLSPipe):
:param data_bundle:
:return:
"""
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
# 删除空行
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
if self.tag_map is not None:
data_bundle = _granularize(data_bundle, self.tag_map)

data_bundle = super().process(data_bundle)
return data_bundle
@@ -199,7 +166,7 @@ class YelpFullPipe(_CLSPipe):
return self.process(data_bundle=data_bundle)


class YelpPolarityPipe(_CLSPipe):
class YelpPolarityPipe(CLSBasePipe):
"""
处理YelpPolarity的数据, 处理之后DataSet中的内容如下

@@ -229,50 +196,101 @@ class YelpPolarityPipe(_CLSPipe):
:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
def process(self, data_bundle):
def process_from_file(self, paths=None):
"""
传入的DataSet应该具备如下的结构

.. csv-table::
:header: "raw_words", "target"
:param str paths:
:return: DataBundle
"""
data_bundle = YelpPolarityLoader().load(paths)
return self.process(data_bundle=data_bundle)

"I got 'new' tires from them and... ", "1"
"Don't waste your time. We had two...", "1"
"...", "..."

:param data_bundle:
:return:
class AGsNewsPipe(CLSBasePipe):
"""
处理AG's News的数据, 处理之后DataSet中的内容如下

.. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field
:header: "raw_words", "target", "words", "seq_len"

"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
"...", ., "[...]", .

dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为::

+-------------+-----------+--------+-------+---------+
| field_names | raw_words | target | words | seq_len |
+-------------+-----------+--------+-------+---------+
| is_input | False | False | True | True |
| is_target | False | True | False | False |
| ignore_type | | False | False | False |
| pad_value | | 0 | 0 | 0 |
+-------------+-----------+--------+-------+---------+

"""

def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
"""
# 复制一列words
data_bundle = _add_words_field(data_bundle, lower=self.lower)
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle

:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
super().__init__(lower=lower, tokenizer=tokenizer, lang='en')

def process_from_file(self, paths=None):
"""
:param str paths:
:return: DataBundle
"""
data_bundle = AGsNewsLoader().load(paths)
return self.process(data_bundle=data_bundle)


class DBPediaPipe(CLSBasePipe):
"""
处理DBPedia的数据, 处理之后DataSet中的内容如下

.. csv-table:: 下面是使用DBPediaPipe处理后的DataSet所具备的field
:header: "raw_words", "target", "words", "seq_len"

"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
"...", ., "[...]", .

dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为::

+-------------+-----------+--------+-------+---------+
| field_names | raw_words | target | words | seq_len |
+-------------+-----------+--------+-------+---------+
| is_input | False | False | True | True |
| is_target | False | True | False | False |
| ignore_type | | False | False | False |
| pad_value | | 0 | 0 | 0 |
+-------------+-----------+--------+-------+---------+

"""

def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
"""

:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
super().__init__(lower=lower, tokenizer=tokenizer, lang='en')

def process_from_file(self, paths=None):
"""
:param str paths:
:return: DataBundle
"""
data_bundle = YelpPolarityLoader().load(paths)
data_bundle = DBPediaLoader().load(paths)
return self.process(data_bundle=data_bundle)


class SSTPipe(_CLSPipe):
class SSTPipe(CLSBasePipe):
"""
经过该Pipe之后,DataSet中具备的field如下所示

@@ -314,11 +332,11 @@ class SSTPipe(_CLSPipe):
self.granularity = granularity
if granularity == 2:
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1}
self.tag_map = {"0": "negative", "1": "negative", "3": "positive", "4": "positive"}
elif granularity == 3:
self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2}
self.tag_map = {"0": "negative", "1": "negative", "2": "medium", "3": "positive", "4": "positive"}
else:
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
self.tag_map = None
def process(self, data_bundle: DataBundle):
"""
@@ -340,7 +358,7 @@ class SSTPipe(_CLSPipe):
ds = DataSet()
use_subtree = self.subtree or (name == 'train' and self.train_tree)
for ins in dataset:
raw_words = ins['raw_words']
raw_words = ins[Const.RAW_WORD]
tree = Tree.fromstring(raw_words)
if use_subtree:
for t in tree.subtrees():
@@ -351,23 +369,11 @@ class SSTPipe(_CLSPipe):
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
ds.append(instance)
data_bundle.set_dataset(ds, name)
_add_words_field(data_bundle, lower=self.lower)
# 进行tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)

# 根据granularity设置tag
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle = _granularize(data_bundle, tag_map=self.tag_map)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
data_bundle = super().process(data_bundle)
return data_bundle
@@ -376,7 +382,7 @@ class SSTPipe(_CLSPipe):
return self.process(data_bundle=data_bundle)


class SST2Pipe(_CLSPipe):
class SST2Pipe(CLSBasePipe):
"""
加载SST2的数据, 处理完成之后DataSet将拥有以下的field

@@ -406,61 +412,7 @@ class SST2Pipe(_CLSPipe):
:param bool lower: 是否对输入进行小写化。
:param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
"""
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
def process(self, data_bundle: DataBundle):
"""
可以处理的DataSet应该具备如下的结构

.. csv-table::
:header: "raw_words", "target"

"it 's a charming and often affecting...", "1"
"unflinchingly bleak and...", "0"
"..."

:param data_bundle:
:return:
"""
_add_words_field(data_bundle, self.lower)
data_bundle = self._tokenize(data_bundle=data_bundle)
src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name],
field_name=Const.TARGET,
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets()
if ('train' not in name) and (ds.has_field(Const.TARGET))]
)
if len(tgt_vocab._no_create_word) > 0:
warn_msg = f"There are {len(tgt_vocab._no_create_word)} target labels" \
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
f"data set but not in train data set!."
warnings.warn(warn_msg)
logger.warning(warn_msg)
datasets = []
for name, dataset in data_bundle.datasets.items():
if dataset.has_field(Const.TARGET):
datasets.append(dataset)
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET)
data_bundle.set_vocab(src_vocab, Const.INPUT)
data_bundle.set_vocab(tgt_vocab, Const.TARGET)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
super().__init__(lower=lower, tokenizer=tokenizer, lang='en')
def process_from_file(self, paths=None):
"""
@@ -472,7 +424,7 @@ class SST2Pipe(_CLSPipe):
return self.process(data_bundle)


class IMDBPipe(_CLSPipe):
class IMDBPipe(CLSBasePipe):
"""
经过本Pipe处理后DataSet将如下

@@ -532,14 +484,7 @@ class IMDBPipe(_CLSPipe):
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
_add_words_field(data_bundle, lower=self.lower)
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT)
_indexize(data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(Const.INPUT, Const.INPUT_LEN)
dataset.set_target(Const.TARGET)
data_bundle = super().process(data_bundle)
return data_bundle
@@ -663,7 +608,7 @@ class ChnSentiCorpPipe(Pipe):
return data_bundle


class THUCNewsPipe(_CLSPipe):
class THUCNewsPipe(CLSBasePipe):
"""
处理之后的DataSet有以下的结构

@@ -727,7 +672,7 @@ class THUCNewsPipe(_CLSPipe):
"""
# 根据granularity设置tag
tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9}
data_bundle = self._granularize(data_bundle=data_bundle, tag_map=tag_map)
data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map)

# clean,lower

@@ -775,7 +720,7 @@ class THUCNewsPipe(_CLSPipe):
return data_bundle


class WeiboSenti100kPipe(_CLSPipe):
class WeiboSenti100kPipe(CLSBasePipe):
"""
处理之后的DataSet有以下的结构

@@ -820,7 +765,6 @@ class WeiboSenti100kPipe(_CLSPipe):
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name)
return data_bundle


def process(self, data_bundle: DataBundle):
"""
可处理的DataSet应具备以下的field


+ 21
- 1
fastNLP/io/pipe/utils.py View File

@@ -136,7 +136,7 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con
f"These label(s) are {tgt_vocab._no_create_word}"
warnings.warn(warn_msg)
logger.warning(warn_msg)
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)
tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name)
data_bundle.set_vocab(tgt_vocab, target_field_name)
return data_bundle
@@ -198,3 +198,23 @@ def _drop_empty_instance(data_bundle, field_name):
dataset.drop(empty_instance)
return data_bundle


def _granularize(data_bundle, tag_map):
"""
该函数对data_bundle中'target'列中的内容进行转换。

:param data_bundle:
:param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance,
且将"1"认为是第0类。
:return: 传入的data_bundle
"""
if tag_map is None:
return data_bundle
for name in list(data_bundle.datasets.keys()):
dataset = data_bundle.get_dataset(name)
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
new_field_name=Const.TARGET)
dataset.drop(lambda ins: ins[Const.TARGET] == -100)
data_bundle.set_dataset(dataset, name)
return data_bundle

+ 137
- 2
test/core/test_metrics.py View File

@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric
from fastNLP.core.metrics import _pred_topk, _accuracy_topk
from fastNLP.core.vocabulary import Vocabulary
from collections import Counter
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric


def _generate_tags(encoding_type, number_labels=4):
@@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result):
allen_result[key] = round(value, 6)
return allen_result



class TestConfusionMatrixMetric(unittest.TestCase):
def test_ConfusionMatrixMetric1(self):
pred_dict = {"pred": torch.zeros(4,3)}
target_dict = {'target': torch.zeros(4)}
metric = ConfusionMatrixMetric()

metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())

def test_ConfusionMatrixMetric2(self):
# (2) with corrupted size
try:
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)}
metric = ConfusionMatrixMetric()
metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric())
except Exception as e:
print(e)
return
print("No exception catches.")

def test_ConfusionMatrixMetric3(self):
# (3) the second batch is corrupted size
try:
metric = ConfusionMatrixMetric()
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)}
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())
except Exception as e:
print(e)
return
assert(True, False), "No exception catches."

def test_ConfusionMatrixMetric4(self):
# (4) check reset
metric = ConfusionMatrixMetric()
pred_dict = {"pred": torch.randn(4, 3, 2)}
target_dict = {'target': torch.ones(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
res = metric.get_metric()
self.assertTrue(isinstance(res, dict))
print(res)

def test_ConfusionMatrixMetric5(self):
# (5) check numpy array is not acceptable
try:
metric = ConfusionMatrixMetric()
pred_dict = {"pred": np.zeros((4, 3, 2))}
target_dict = {'target': np.zeros((4, 3))}
metric(pred_dict=pred_dict, target_dict=target_dict)
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."
def test_ConfusionMatrixMetric6(self):
# (6) check map, match
metric = ConfusionMatrixMetric(pred='predictions', target='targets')
pred_dict = {"predictions": torch.randn(4, 3, 2)}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
res = metric.get_metric()
print(res)

def test_ConfusionMatrixMetric7(self):
# (7) check map, include unused
try:
metric = ConfusionMatrixMetric(pred='prediction', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."
def test_ConfusionMatrixMetric8(self):
# (8) check _fast_metric
try:
metric = ConfusionMatrixMetric()
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."

def test_duplicate(self):
# 0.4.1的潜在bug,不能出现形参重复的情况
metric = ConfusionMatrixMetric(pred='predictions', target='targets')
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0}
target_dict = {'targets':torch.zeros(4, 3), 'target': 0}
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())


def test_seq_len(self):
N = 256
seq_len = torch.zeros(N).long()
seq_len[0] = 2
pred = {'pred': torch.ones(N, 2)}
target = {'target': torch.ones(N, 2), 'seq_len': seq_len}
metric = ConfusionMatrixMetric()
metric(pred_dict=pred, target_dict=target)
metric.get_metric(reset=False)
seq_len[1:] = 1
metric(pred_dict=pred, target_dict=target)
metric.get_metric()

def test_vocab(self):
vocab = Vocabulary()
word_list = "this is a word list".split()
vocab.update(word_list)
pred_dict = {"pred": torch.zeros(4,3)}
target_dict = {'target': torch.zeros(4)}
metric = ConfusionMatrixMetric(vocab=vocab)
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())




class TestAccuracyMetric(unittest.TestCase):
def test_AccuracyMetric1(self):
# (1) only input, targets passed
@@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase):
def test_AccuaryMetric8(self):
try:
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2)}
pred_dict = {"predictions": torch.zeros(4, 3, 2)}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict, )
self.assertDictEqual(metric.get_metric(), {'acc': 1})


+ 5
- 0
test/data_for_tests/io/ag/test.csv View File

@@ -0,0 +1,5 @@
"3","Fears for T N pension after talks","Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."
"4","The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com)","SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket."
"4","Ky. Company Wins Grant to Study Peptides (AP)","AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins."
"4","Prediction Unit Helps Forecast Wildfires (AP)","AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will strike in places he expects. Winds will pick up, moist places will dry and flames will roar."
"4","Calif. Aims to Limit Farm-Related Smog (AP)","AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure."

+ 4
- 0
test/data_for_tests/io/ag/train.csv View File

@@ -0,0 +1,4 @@
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
"4","Building Dedicated to Columbia Astronauts (AP)","AP - A former dormitory converted to classrooms at the Pensacola Naval Air Station was dedicated Friday to two Columbia astronauts who were among the seven who died in the shuttle disaster Feb. 1, 2003."
"2","Phelps On Relay Team","Michael Phelps is named to the 4x100-meter freestyle relay team that will compete in Sunday's final, keeping alive his quest for a possible eight Olympic gold medals."
"1","Venezuelans Vote Early in Referendum on Chavez Rule (Reuters)","Reuters - Venezuelans turned out early\and in large numbers on Sunday to vote in a historic referendum\that will either remove left-wing President Hugo Chavez from\office or give him a new mandate to govern for the next two\years."

+ 5
- 0
test/data_for_tests/io/dbpedia/test.csv View File

@@ -0,0 +1,5 @@
1,"TY KU"," TY KU /taɪkuː/ is an American alcoholic beverage company that specializes in sake and other spirits. The privately-held company was founded in 2004 and is headquartered in New York City New York. While based in New York TY KU's beverages are made in Japan through a joint venture with two sake breweries. Since 2011 TY KU's growth has extended its products into all 50 states."
1,"Odd Lot Entertainment"," OddLot Entertainment founded in 2001 by longtime producers Gigi Pritzker and Deborah Del Prete (The Wedding Planner) is a film production and financing company based in Culver City California.OddLot produced the film version of Orson Scott Card's sci-fi novel Ender's Game. A film version of this novel had been in the works in one form or another for more than a decade by the time of its release."
1,"Henkel"," Henkel AG & Company KGaA operates worldwide with leading brands and technologies in three business areas: Laundry & Home Care Beauty Care and Adhesive Technologies. Henkel is the name behind some of America’s favorite brands."
1,"GOAT Store"," The GOAT Store (Games Of All Type Store) LLC is one of the largest retro gaming online stores and an Independent Video Game Publishing Label. Additionally they are one of the primary sponsors for Midwest Gaming Classic."
1,"RagWing Aircraft Designs"," RagWing Aircraft Designs (also called the RagWing Aeroplane Company and RagWing Aviation) was an American aircraft design and manufacturing company based in Belton South Carolina."

+ 14
- 0
test/data_for_tests/io/dbpedia/train.csv View File

@@ -0,0 +1,14 @@
1,"Boneau/Bryan-Brown"," Boneau/Bryan-Brown Inc. is a public relations company based in Manhattan New York USA largely supporting Broadway theatre productions as a theatrical press agency.The company was formed by the partnership of Chris Boneau and Adrian Bryan-Brown in 1991. Broadway productions supported include among hundreds the musical Guys and Dolls in 1992. The company initially represented the rock musical Spider-Man: Turn Off the Dark which finally opened on Broadway in 2011."
2,"Dubai Gem Private School & Nursery"," Dubai Gem Private School (DGPS) is a British school located in the Oud Metha area of Dubai United Arab Emirates. Dubai Gem Nursery is located in Jumeirah. Together the institutions enroll almost 1500 students aged 3 to 18."
3,"Shahar Marcus"," Shahar Marcus (born 1971 in Petach Tikva) is an Israeli performance artist."
4,"Martin McKinnon"," Martin Marty McKinnon (born 5 July 1975 in Adelaide) is a former Australian rules footballer who played with Adelaide Geelong and the Brisbane Lions in the Australian Football League (AFL).McKinnon was recruited by Adelaide in the 1992 AFL Draft with their first ever national draft pick. He was the youngest player on Adelaide's list at the time and played for Central District in the SANFL when not appearing with Adelaide."
5,"Steve Howitt"," Steven S. Howitt is the current member of the Massachusetts House of Representatives for the 4th Bristol district."
6,"Wedell-Williams XP-34"," The Wedell-Williams XP-34 was a fighter aircraft design submitted to the United States Army Air Corps (USAAC) before World War II by Marguerite Clark Williams widow of millionaire Harry P. Williams former owner and co-founder of the Wedell-Williams Air Service Corporation."
7,"Nationality Rooms"," The Nationality Rooms are a collection of 29 classrooms in the University of Pittsburgh's Cathedral of Learning depicting and donated by the ethnic groups that helped build the city of Pittsburgh."
8,"Duruitoarea River"," The Duruitoarea River is a tributary of the Camenca River in Romania."
9,"Shirvan Shahlu"," Shirvan Shahlu (Persian: شيروان شاهلو‎ also Romanized as Shīrvān Shāhlū; also known as Shīravān Shāmnū) is a village in Gavdul-e Sharqi Rural District in the Central District of Malekan County East Azerbaijan Province Iran. At the 2006 census its population was 137 in 35 families."
10,"Oenopota impressa"," Oenopota impressa is a species of sea snail a marine gastropod mollusk in the family Mangeliidae."
11,"Utricularia simulans"," Utricularia simulans the fringed bladderwort is a small to medium-sized probably perennial carnivorous plant that belongs to the genus Utricularia. U. simulans is native to tropical Africa and the Americas. It grows as a terrestrial plant in damp sandy soils in open savanna at altitudes from near sea level to 1575 m (5167 ft). U. simulans was originally described and published by Robert Knud Friedrich Pilger in 1914."
12,"Global Chillage"," Global Chillage is the second album by The Irresistible Force released in 1994 through Rising High Records."
13,"The Nuisance (1933 film)"," The Nuisance is a 1933 film starring Lee Tracy as a lawyer Madge Evans as his love interest (with a secret) and Frank Morgan as his accomplice."
14,"Razadarit Ayedawbon"," Razadarit Ayedawbon (Burmese: ရာဇာဓိရာဇ် အရေးတော်ပုံ) is a Burmese chronicle covering the history of Ramanya from 1287 to 1421. The chronicle consists of accounts of court intrigues rebellions diplomatic missions wars etc. About half of the chronicle is devoted to the reign of King Razadarit (r."

+ 5
- 2
test/io/pipe/test_classification.py View File

@@ -2,7 +2,8 @@ import unittest
import os

from fastNLP.io import DataBundle
from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe
from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe, \
AGsNewsPipe, DBPediaPipe
from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe


@@ -36,10 +37,12 @@ class TestRunClassificationPipe(unittest.TestCase):
def test_process_from_file(self):
data_set_dict = {
'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, (6, 6, 6), (1176, 2), False),
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1023, 5), False),
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1166, 5), False),
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True),
'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False),
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False),
'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, (5, 4), (257, 4), False),
'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, (5, 14), (496, 14), False),
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False),
'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, (9, 9, 9), (1864, 9), False),
'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, (7, 6, 6), (452, 2), False),


Loading…
Cancel
Save