diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 7930978f..f893bd74 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase): print_ratio=False ): r""" + :param vocab: vocab词表类,要求有to_word()方法。 :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` @@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase): def evaluate(self, pred, target, seq_len=None): r""" evaluate函数将针对一个批次的预测结果做评价指标的累计 - :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 0821c6ff..74c812a2 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -62,6 +62,7 @@ class ConfusionMatrix: target = [2,2,1] confusion.add_pred_target(pred, target) print(confusion) + target 1 2 3 all pred 1 0 1 0 1 @@ -157,7 +158,6 @@ class ConfusionMatrix: (k, str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel ]) - for label, idx in zip(totallabel, range(lenth)): idx2row[ label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 6f8a53bb..e70440de 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True): yield line_idx, _res -def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): +def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): r""" Construct a generator to read conll items. :param path: file path :param encoding: file's encoding, default: utf-8 + :param sep: seperator :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None :param dropna: weather to ignore and drop invalid data, :if False, raise ValueError when reading invalid data. default: True @@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): sample = [] start = next(f).strip() if start != '': - sample.append(start.split()) + sample.append(start.split(sep)) if sep else sample.append(start.split()) for line_idx, line in enumerate(f, 1): line = line.strip() if line == '': @@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): elif line.startswith('#'): continue else: - sample.append(line.split()) + sample.append(line.split(sep)) if sep else sample.append(line.split()) if len(sample) > 0: try: res = parse_conll(sample) diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py index 0457ee0d..36289db8 100644 --- a/fastNLP/io/loader/conll.py +++ b/fastNLP/io/loader/conll.py @@ -55,10 +55,11 @@ class ConllLoader(Loader): """ - def __init__(self, headers, indexes=None, dropna=True): + def __init__(self, headers, sep=None, indexes=None, dropna=True): r""" :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param list sep: 指定分隔符,默认为制表符 :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` """ @@ -68,6 +69,7 @@ class ConllLoader(Loader): 'invalid headers: {}, should be list of strings'.format(headers)) self.headers = headers self.dropna = dropna + self.sep=sep if indexes is None: self.indexes = list(range(len(self.headers))) else: @@ -83,7 +85,7 @@ class ConllLoader(Loader): :return: DataSet """ ds = DataSet() - for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna): ins = {h: data[i] for i, h in enumerate(self.headers)} ds.append(Instance(**ins)) return ds diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 06144ff3..27799c54 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result): return allen_result + class TestConfusionMatrixMetric(unittest.TestCase): def test_ConfusionMatrixMetric1(self): pred_dict = {"pred": torch.zeros(4,3)} @@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): def test_ConfusionMatrixMetric2(self): # (2) with corrupted size + with self.assertRaises(Exception): pred_dict = {"pred": torch.zeros(4, 3, 2)} target_dict = {'target': torch.zeros(4)} @@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase): print(metric.get_metric()) - def test_ConfusionMatrixMetric4(self): # (4) check reset metric = ConfusionMatrixMetric() @@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): def test_ConfusionMatrixMetric5(self): # (5) check numpy array is not acceptable + with self.assertRaises(Exception): metric = ConfusionMatrixMetric() pred_dict = {"pred": np.zeros((4, 3, 2))} @@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): metric(pred_dict=pred_dict, target_dict=target_dict) print(metric.get_metric()) + def test_duplicate(self): # 0.4.1的潜在bug,不能出现形参重复的情况 metric = ConfusionMatrixMetric(pred='predictions', target='targets') @@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): metric(pred_dict=pred_dict, target_dict=target_dict) print(metric.get_metric()) + def test_seq_len(self): N = 256 seq_len = torch.zeros(N).long() @@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): print(metric.get_metric()) + class TestAccuracyMetric(unittest.TestCase): def test_AccuracyMetric1(self): # (1) only input, targets passed diff --git a/test/io/loader/test_conll_loader.py b/test/io/loader/test_conll_loader.py index 6668cccf..bf0ebb47 100644 --- a/test/io/loader/test_conll_loader.py +++ b/test/io/loader/test_conll_loader.py @@ -2,7 +2,7 @@ import unittest import os from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \ - Conll2003Loader + Conll2003Loader, ConllLoader class TestMSRANER(unittest.TestCase): @@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase): db = Conll2003Loader().load('test/data_for_tests/io/conll2003') print(db) +class TestConllLoader(unittest.TestCase): + def test_sep(self): + headers = [ + 'raw_words', 'ner', + ] + db = ConllLoader(headers = headers,sep="\n").load('test/data_for_tests/io/MSRA_NER') + print(db)