Browse Source

[new] add seperator for conll loader (#293)

* add ConfusionMatrix, ConfusionMatrixMetric

* add confusionmatrix to utils

* add ConfusionMatrixmetric

* add ConfusionMatrixMetric

* init for test

* begin test

* test finish

* doc finish

* revised confusion

* revised two

* revise two

* add sep for conll loader

* with remote

* withdraw some update

* finish merge

* update test

* update test

* to avoid none situation
tags/v0.6.0
ROGERDJQ GitHub 4 years ago
parent
commit
4e95989e97
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 24 additions and 9 deletions
  1. +1
    -1
      fastNLP/core/metrics.py
  2. +1
    -1
      fastNLP/core/utils.py
  3. +4
    -3
      fastNLP/io/file_reader.py
  4. +4
    -2
      fastNLP/io/loader/conll.py
  5. +6
    -1
      test/core/test_metrics.py
  6. +8
    -1
      test/io/loader/test_conll_loader.py

+ 1
- 1
fastNLP/core/metrics.py View File

@@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase):
print_ratio=False print_ratio=False
): ):
r""" r"""

:param vocab: vocab词表类,要求有to_word()方法。 :param vocab: vocab词表类,要求有to_word()方法。
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
@@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase):
def evaluate(self, pred, target, seq_len=None): def evaluate(self, pred, target, seq_len=None):
r""" r"""
evaluate函数将针对一个批次的预测结果做评价指标的累计 evaluate函数将针对一个批次的预测结果做评价指标的累计
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),


+ 1
- 1
fastNLP/core/utils.py View File

@@ -62,6 +62,7 @@ class ConfusionMatrix:
target = [2,2,1] target = [2,2,1]
confusion.add_pred_target(pred, target) confusion.add_pred_target(pred, target)
print(confusion) print(confusion)

target 1 2 3 all target 1 2 3 all
pred pred
1 0 1 0 1 1 0 1 0 1
@@ -157,7 +158,6 @@ class ConfusionMatrix:
(k, str(k if self.vocab == None else self.vocab.to_word(k))) (k, str(k if self.vocab == None else self.vocab.to_word(k)))
for k in totallabel for k in totallabel
]) ])

for label, idx in zip(totallabel, range(lenth)): for label, idx in zip(totallabel, range(lenth)):
idx2row[ idx2row[
label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...


+ 4
- 3
fastNLP/io/file_reader.py View File

@@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True):
yield line_idx, _res yield line_idx, _res




def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
r""" r"""
Construct a generator to read conll items. Construct a generator to read conll items.


:param path: file path :param path: file path
:param encoding: file's encoding, default: utf-8 :param encoding: file's encoding, default: utf-8
:param sep: seperator
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
:param dropna: weather to ignore and drop invalid data, :param dropna: weather to ignore and drop invalid data,
:if False, raise ValueError when reading invalid data. default: True :if False, raise ValueError when reading invalid data. default: True
@@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
sample = [] sample = []
start = next(f).strip() start = next(f).strip()
if start != '': if start != '':
sample.append(start.split())
sample.append(start.split(sep)) if sep else sample.append(start.split())
for line_idx, line in enumerate(f, 1): for line_idx, line in enumerate(f, 1):
line = line.strip() line = line.strip()
if line == '': if line == '':
@@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
elif line.startswith('#'): elif line.startswith('#'):
continue continue
else: else:
sample.append(line.split())
sample.append(line.split(sep)) if sep else sample.append(line.split())
if len(sample) > 0: if len(sample) > 0:
try: try:
res = parse_conll(sample) res = parse_conll(sample)


+ 4
- 2
fastNLP/io/loader/conll.py View File

@@ -55,10 +55,11 @@ class ConllLoader(Loader):


""" """
def __init__(self, headers, indexes=None, dropna=True):
def __init__(self, headers, sep=None, indexes=None, dropna=True):
r""" r"""
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
:param list sep: 指定分隔符,默认为制表符
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
""" """
@@ -68,6 +69,7 @@ class ConllLoader(Loader):
'invalid headers: {}, should be list of strings'.format(headers)) 'invalid headers: {}, should be list of strings'.format(headers))
self.headers = headers self.headers = headers
self.dropna = dropna self.dropna = dropna
self.sep=sep
if indexes is None: if indexes is None:
self.indexes = list(range(len(self.headers))) self.indexes = list(range(len(self.headers)))
else: else:
@@ -83,7 +85,7 @@ class ConllLoader(Loader):
:return: DataSet :return: DataSet
""" """
ds = DataSet() ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna):
ins = {h: data[i] for i, h in enumerate(self.headers)} ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins)) ds.append(Instance(**ins))
return ds return ds


+ 6
- 1
test/core/test_metrics.py View File

@@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result):
return allen_result return allen_result





class TestConfusionMatrixMetric(unittest.TestCase): class TestConfusionMatrixMetric(unittest.TestCase):
def test_ConfusionMatrixMetric1(self): def test_ConfusionMatrixMetric1(self):
pred_dict = {"pred": torch.zeros(4,3)} pred_dict = {"pred": torch.zeros(4,3)}
@@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):


def test_ConfusionMatrixMetric2(self): def test_ConfusionMatrixMetric2(self):
# (2) with corrupted size # (2) with corrupted size

with self.assertRaises(Exception): with self.assertRaises(Exception):
pred_dict = {"pred": torch.zeros(4, 3, 2)} pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)} target_dict = {'target': torch.zeros(4)}
@@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase):
print(metric.get_metric()) print(metric.get_metric())



def test_ConfusionMatrixMetric4(self): def test_ConfusionMatrixMetric4(self):
# (4) check reset # (4) check reset
metric = ConfusionMatrixMetric() metric = ConfusionMatrixMetric()
@@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):


def test_ConfusionMatrixMetric5(self): def test_ConfusionMatrixMetric5(self):
# (5) check numpy array is not acceptable # (5) check numpy array is not acceptable

with self.assertRaises(Exception): with self.assertRaises(Exception):
metric = ConfusionMatrixMetric() metric = ConfusionMatrixMetric()
pred_dict = {"pred": np.zeros((4, 3, 2))} pred_dict = {"pred": np.zeros((4, 3, 2))}
@@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
metric(pred_dict=pred_dict, target_dict=target_dict) metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric()) print(metric.get_metric())



def test_duplicate(self): def test_duplicate(self):
# 0.4.1的潜在bug,不能出现形参重复的情况 # 0.4.1的潜在bug,不能出现形参重复的情况
metric = ConfusionMatrixMetric(pred='predictions', target='targets') metric = ConfusionMatrixMetric(pred='predictions', target='targets')
@@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
metric(pred_dict=pred_dict, target_dict=target_dict) metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric()) print(metric.get_metric())



def test_seq_len(self): def test_seq_len(self):
N = 256 N = 256
seq_len = torch.zeros(N).long() seq_len = torch.zeros(N).long()
@@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
print(metric.get_metric()) print(metric.get_metric())





class TestAccuracyMetric(unittest.TestCase): class TestAccuracyMetric(unittest.TestCase):
def test_AccuracyMetric1(self): def test_AccuracyMetric1(self):
# (1) only input, targets passed # (1) only input, targets passed


+ 8
- 1
test/io/loader/test_conll_loader.py View File

@@ -2,7 +2,7 @@
import unittest import unittest
import os import os
from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \ from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \
Conll2003Loader
Conll2003Loader, ConllLoader




class TestMSRANER(unittest.TestCase): class TestMSRANER(unittest.TestCase):
@@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase):
db = Conll2003Loader().load('test/data_for_tests/io/conll2003') db = Conll2003Loader().load('test/data_for_tests/io/conll2003')
print(db) print(db)


class TestConllLoader(unittest.TestCase):
def test_sep(self):
headers = [
'raw_words', 'ner',
]
db = ConllLoader(headers = headers,sep="\n").load('test/data_for_tests/io/MSRA_NER')
print(db)

Loading…
Cancel
Save