* add ConfusionMatrix, ConfusionMatrixMetric * add confusionmatrix to utils * add ConfusionMatrixmetric * add ConfusionMatrixMetric * init for test * begin test * test finish * doc finish * revised confusion * revised two * revise two * add sep for conll loader * with remote * withdraw some update * finish merge * update test * update test * to avoid none situation
tags/v0.6.0
| @@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase): | |||||
| print_ratio=False | print_ratio=False | ||||
| ): | ): | ||||
| r""" | r""" | ||||
| :param vocab: vocab词表类,要求有to_word()方法。 | :param vocab: vocab词表类,要求有to_word()方法。 | ||||
| :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | ||||
| :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | ||||
| @@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase): | |||||
| def evaluate(self, pred, target, seq_len=None): | def evaluate(self, pred, target, seq_len=None): | ||||
| r""" | r""" | ||||
| evaluate函数将针对一个批次的预测结果做评价指标的累计 | evaluate函数将针对一个批次的预测结果做评价指标的累计 | ||||
| :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | ||||
| torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | ||||
| :param torch.Tensor target: 真实值的tensor, tensor的形状可以是torch.Size([B,]), | :param torch.Tensor target: 真实值的tensor, tensor的形状可以是torch.Size([B,]), | ||||
| @@ -62,6 +62,7 @@ class ConfusionMatrix: | |||||
| target = [2,2,1] | target = [2,2,1] | ||||
| confusion.add_pred_target(pred, target) | confusion.add_pred_target(pred, target) | ||||
| print(confusion) | print(confusion) | ||||
| target 1 2 3 all | target 1 2 3 all | ||||
| pred | pred | ||||
| 1 0 1 0 1 | 1 0 1 0 1 | ||||
| @@ -157,7 +158,6 @@ class ConfusionMatrix: | |||||
| (k, str(k if self.vocab == None else self.vocab.to_word(k))) | (k, str(k if self.vocab == None else self.vocab.to_word(k))) | ||||
| for k in totallabel | for k in totallabel | ||||
| ]) | ]) | ||||
| for label, idx in zip(totallabel, range(lenth)): | for label, idx in zip(totallabel, range(lenth)): | ||||
| idx2row[ | idx2row[ | ||||
| label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... | label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... | ||||
| @@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True): | |||||
| yield line_idx, _res | yield line_idx, _res | ||||
| def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
| def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||||
| r""" | r""" | ||||
| Construct a generator to read conll items. | Construct a generator to read conll items. | ||||
| :param path: file path | :param path: file path | ||||
| :param encoding: file's encoding, default: utf-8 | :param encoding: file's encoding, default: utf-8 | ||||
| :param sep: column separator passed to ``str.split``; if None, split on any whitespace. default: None | |||||
| :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None | :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None | ||||
| :param dropna: whether to ignore and drop invalid data, | :param dropna: whether to ignore and drop invalid data, | ||||
| :if False, raise ValueError when reading invalid data. default: True | :if False, raise ValueError when reading invalid data. default: True | ||||
| @@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
| sample = [] | sample = [] | ||||
| start = next(f).strip() | start = next(f).strip() | ||||
| if start != '': | if start != '': | ||||
| sample.append(start.split()) | |||||
| sample.append(start.split(sep)) if sep else sample.append(start.split()) | |||||
| for line_idx, line in enumerate(f, 1): | for line_idx, line in enumerate(f, 1): | ||||
| line = line.strip() | line = line.strip() | ||||
| if line == '': | if line == '': | ||||
| @@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
| elif line.startswith('#'): | elif line.startswith('#'): | ||||
| continue | continue | ||||
| else: | else: | ||||
| sample.append(line.split()) | |||||
| sample.append(line.split(sep)) if sep else sample.append(line.split()) | |||||
| if len(sample) > 0: | if len(sample) > 0: | ||||
| try: | try: | ||||
| res = parse_conll(sample) | res = parse_conll(sample) | ||||
| @@ -55,10 +55,11 @@ class ConllLoader(Loader): | |||||
| """ | """ | ||||
| def __init__(self, headers, indexes=None, dropna=True): | |||||
| def __init__(self, headers, sep=None, indexes=None, dropna=True): | |||||
| r""" | r""" | ||||
| :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | ||||
| :param str sep: 指定分隔符;若为 ``None`` ,则按任意空白字符切分。Default: ``None`` | |||||
| :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | ||||
| :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | ||||
| """ | """ | ||||
| @@ -68,6 +69,7 @@ class ConllLoader(Loader): | |||||
| 'invalid headers: {}, should be list of strings'.format(headers)) | 'invalid headers: {}, should be list of strings'.format(headers)) | ||||
| self.headers = headers | self.headers = headers | ||||
| self.dropna = dropna | self.dropna = dropna | ||||
| self.sep=sep | |||||
| if indexes is None: | if indexes is None: | ||||
| self.indexes = list(range(len(self.headers))) | self.indexes = list(range(len(self.headers))) | ||||
| else: | else: | ||||
| @@ -83,7 +85,7 @@ class ConllLoader(Loader): | |||||
| :return: DataSet | :return: DataSet | ||||
| """ | """ | ||||
| ds = DataSet() | ds = DataSet() | ||||
| for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||||
| for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna): | |||||
| ins = {h: data[i] for i, h in enumerate(self.headers)} | ins = {h: data[i] for i, h in enumerate(self.headers)} | ||||
| ds.append(Instance(**ins)) | ds.append(Instance(**ins)) | ||||
| return ds | return ds | ||||
| @@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result): | |||||
| return allen_result | return allen_result | ||||
| class TestConfusionMatrixMetric(unittest.TestCase): | class TestConfusionMatrixMetric(unittest.TestCase): | ||||
| def test_ConfusionMatrixMetric1(self): | def test_ConfusionMatrixMetric1(self): | ||||
| pred_dict = {"pred": torch.zeros(4,3)} | pred_dict = {"pred": torch.zeros(4,3)} | ||||
| @@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
| def test_ConfusionMatrixMetric2(self): | def test_ConfusionMatrixMetric2(self): | ||||
| # (2) with corrupted size | # (2) with corrupted size | ||||
| with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
| pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
| target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
| @@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
| print(metric.get_metric()) | print(metric.get_metric()) | ||||
| def test_ConfusionMatrixMetric4(self): | def test_ConfusionMatrixMetric4(self): | ||||
| # (4) check reset | # (4) check reset | ||||
| metric = ConfusionMatrixMetric() | metric = ConfusionMatrixMetric() | ||||
| @@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
| def test_ConfusionMatrixMetric5(self): | def test_ConfusionMatrixMetric5(self): | ||||
| # (5) check numpy array is not acceptable | # (5) check numpy array is not acceptable | ||||
| with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
| metric = ConfusionMatrixMetric() | metric = ConfusionMatrixMetric() | ||||
| pred_dict = {"pred": np.zeros((4, 3, 2))} | pred_dict = {"pred": np.zeros((4, 3, 2))} | ||||
| @@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
| metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
| print(metric.get_metric()) | print(metric.get_metric()) | ||||
| def test_duplicate(self): | def test_duplicate(self): | ||||
| # 0.4.1的潜在bug,不能出现形参重复的情况 | # 0.4.1的潜在bug,不能出现形参重复的情况 | ||||
| metric = ConfusionMatrixMetric(pred='predictions', target='targets') | metric = ConfusionMatrixMetric(pred='predictions', target='targets') | ||||
| @@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
| metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
| print(metric.get_metric()) | print(metric.get_metric()) | ||||
| def test_seq_len(self): | def test_seq_len(self): | ||||
| N = 256 | N = 256 | ||||
| seq_len = torch.zeros(N).long() | seq_len = torch.zeros(N).long() | ||||
| @@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): | |||||
| print(metric.get_metric()) | print(metric.get_metric()) | ||||
| class TestAccuracyMetric(unittest.TestCase): | class TestAccuracyMetric(unittest.TestCase): | ||||
| def test_AccuracyMetric1(self): | def test_AccuracyMetric1(self): | ||||
| # (1) only input, targets passed | # (1) only input, targets passed | ||||
| @@ -2,7 +2,7 @@ | |||||
| import unittest | import unittest | ||||
| import os | import os | ||||
| from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \ | from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \ | ||||
| Conll2003Loader | |||||
| Conll2003Loader, ConllLoader | |||||
| class TestMSRANER(unittest.TestCase): | class TestMSRANER(unittest.TestCase): | ||||
| @@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase): | |||||
| db = Conll2003Loader().load('test/data_for_tests/io/conll2003') | db = Conll2003Loader().load('test/data_for_tests/io/conll2003') | ||||
| print(db) | print(db) | ||||
| class TestConllLoader(unittest.TestCase): | |||||
| def test_sep(self): | |||||
| headers = [ | |||||
| 'raw_words', 'ner', | |||||
| ] | |||||
| db = ConllLoader(headers = headers,sep="\n").load('test/data_for_tests/io/MSRA_NER') | |||||
| print(db) | |||||