From 4e95989e973f59b2ecb7f718647257e8b6fea0c7 Mon Sep 17 00:00:00 2001 From: ROGERDJQ Date: Sun, 3 May 2020 19:46:11 +0800 Subject: [PATCH 1/3] [new] add seperator for conll loader (#293) * add ConfusionMatrix, ConfusionMatrixMetric * add confusionmatrix to utils * add ConfusionMatrixmetric * add ConfusionMatrixMetric * init for test * begin test * test finish * doc finish * revised confusion * revised two * revise two * add sep for conll loader * with remote * withdraw some update * finish merge * update test * update test * to avoid none situation --- fastNLP/core/metrics.py | 2 +- fastNLP/core/utils.py | 2 +- fastNLP/io/file_reader.py | 7 ++++--- fastNLP/io/loader/conll.py | 6 ++++-- test/core/test_metrics.py | 7 ++++++- test/io/loader/test_conll_loader.py | 9 ++++++++- 6 files changed, 24 insertions(+), 9 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 7930978f..f893bd74 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase): print_ratio=False ): r""" + :param vocab: vocab词表类,要求有to_word()方法。 :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` @@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase): def evaluate(self, pred, target, seq_len=None): r""" evaluate函数将针对一个批次的预测结果做评价指标的累计 - :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 0821c6ff..74c812a2 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -62,6 +62,7 @@ class ConfusionMatrix: target = [2,2,1] confusion.add_pred_target(pred, target) print(confusion) + target 1 2 3 all pred 1 0 1 0 1 @@ -157,7 +158,6 @@ class ConfusionMatrix: (k, str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel ]) - for label, idx in zip(totallabel, range(lenth)): idx2row[ label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 6f8a53bb..e70440de 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True): yield line_idx, _res -def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): +def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): r""" Construct a generator to read conll items. :param path: file path :param encoding: file's encoding, default: utf-8 + :param sep: seperator :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None :param dropna: weather to ignore and drop invalid data, :if False, raise ValueError when reading invalid data. default: True @@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): sample = [] start = next(f).strip() if start != '': - sample.append(start.split()) + sample.append(start.split(sep)) if sep else sample.append(start.split()) for line_idx, line in enumerate(f, 1): line = line.strip() if line == '': @@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): elif line.startswith('#'): continue else: - sample.append(line.split()) + sample.append(line.split(sep)) if sep else sample.append(line.split()) if len(sample) > 0: try: res = parse_conll(sample) diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py index 0457ee0d..36289db8 100644 --- a/fastNLP/io/loader/conll.py +++ b/fastNLP/io/loader/conll.py @@ -55,10 +55,11 @@ class ConllLoader(Loader): """ - def __init__(self, headers, indexes=None, dropna=True): + def __init__(self, headers, sep=None, indexes=None, dropna=True): r""" :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 + :param list sep: 指定分隔符,默认为制表符 :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` """ @@ -68,6 +69,7 @@ class ConllLoader(Loader): 'invalid headers: {}, should be list of strings'.format(headers)) self.headers = headers self.dropna = dropna + self.sep=sep if indexes is None: self.indexes = list(range(len(self.headers))) else: @@ -83,7 +85,7 @@ class ConllLoader(Loader): :return: DataSet """ ds = DataSet() - for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): + for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna): ins = {h: data[i] for i, h in enumerate(self.headers)} ds.append(Instance(**ins)) return ds diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 06144ff3..27799c54 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result): return allen_result + class TestConfusionMatrixMetric(unittest.TestCase): def test_ConfusionMatrixMetric1(self): pred_dict = {"pred": torch.zeros(4,3)} @@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): def test_ConfusionMatrixMetric2(self): # (2) with corrupted size + with self.assertRaises(Exception): pred_dict = {"pred": torch.zeros(4, 3, 2)} target_dict = {'target': torch.zeros(4)} @@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase): print(metric.get_metric()) - def test_ConfusionMatrixMetric4(self): # (4) check reset metric = ConfusionMatrixMetric() @@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): def test_ConfusionMatrixMetric5(self): # (5) check numpy array is not acceptable + with self.assertRaises(Exception): metric = ConfusionMatrixMetric() pred_dict = {"pred": np.zeros((4, 3, 2))} @@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): metric(pred_dict=pred_dict, target_dict=target_dict) print(metric.get_metric()) + def test_duplicate(self): # 0.4.1的潜在bug,不能出现形参重复的情况 metric = ConfusionMatrixMetric(pred='predictions', target='targets') @@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): metric(pred_dict=pred_dict, target_dict=target_dict) print(metric.get_metric()) + def test_seq_len(self): N = 256 seq_len = torch.zeros(N).long() @@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase): print(metric.get_metric()) + class TestAccuracyMetric(unittest.TestCase): def test_AccuracyMetric1(self): # (1) only input, targets passed diff --git a/test/io/loader/test_conll_loader.py b/test/io/loader/test_conll_loader.py index 6668cccf..bf0ebb47 100644 --- a/test/io/loader/test_conll_loader.py +++ b/test/io/loader/test_conll_loader.py @@ -2,7 +2,7 @@ import unittest import os from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \ - Conll2003Loader + Conll2003Loader, ConllLoader class TestMSRANER(unittest.TestCase): @@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase): db = Conll2003Loader().load('test/data_for_tests/io/conll2003') print(db) +class TestConllLoader(unittest.TestCase): + def test_sep(self): + headers = [ + 'raw_words', 'ner', + ] + db = ConllLoader(headers = headers,sep="\n").load('test/data_for_tests/io/MSRA_NER') + print(db) From 1f27d007d1cbdc018e2e844ccb9a7f698718a1d9 Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Thu, 21 May 2020 12:48:08 +0800 Subject: [PATCH 2/3] [add] add MSELoss --- fastNLP/__init__.py | 3 ++- fastNLP/core/__init__.py | 4 +++- fastNLP/core/losses.py | 21 +++++++++++++++++++++ fastNLP/models/bert.py | 7 ++++++- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 5f18561a..4a963a46 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -69,6 +69,7 @@ __all__ = [ "LossFunc", "CrossEntropyLoss", + "MSELoss", "L1Loss", "BCELoss", "NLLLoss", @@ -81,7 +82,7 @@ __all__ = [ 'logger', "init_logger_dist", ] -__version__ = '0.5.0' +__version__ = '0.5.5' import sys diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index f4e42ab3..6eb3e424 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -65,6 +65,7 @@ __all__ = [ "NLLLoss", "LossInForward", "CMRC2018Loss", + "MSELoss", "LossBase", "MetricBase", @@ -94,7 +95,8 @@ from .const import Const from .dataset import DataSet from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder from .instance import Instance -from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss, LossBase +from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \ + LossInForward, CMRC2018Loss, LossBase, MSELoss from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ ConfusionMatrixMetric from .optimizer import Optimizer, SGD, Adam, AdamW diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 6788e6da..574738bb 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -12,6 +12,7 @@ __all__ = [ "BCELoss", "L1Loss", "NLLLoss", + "MSELoss", "CMRC2018Loss" @@ -265,6 +266,26 @@ class L1Loss(LossBase): return F.l1_loss(input=pred, target=target, reduction=self.reduction) +class MSELoss(LossBase): + r""" + MSE损失函数 + + :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` + :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` >`target` + :param str reduction: 支持'mean','sum'和'none'. + + """ + + def __init__(self, pred=None, target=None, reduction='mean'): + super(MSELoss, self).__init__() + self._init_param_map(pred=pred, target=target) + assert reduction in ('mean', 'sum', 'none') + self.reduction = reduction + + def get_loss(self, pred, target): + return F.mse_loss(input=pred, target=target, reduction=self.reduction) + + class BCELoss(LossBase): r""" 二分类交叉熵损失函数 diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py index 976b4638..827717d0 100644 --- a/fastNLP/models/bert.py +++ b/fastNLP/models/bert.py @@ -76,6 +76,8 @@ class BertForSequenceClassification(BaseModel): hidden = self.dropout(self.bert(words)) cls_hidden = hidden[:, 0] logits = self.classifier(cls_hidden) + if logits.size(-1) == 1: + logits = logits.squeeze(-1) return {Const.OUTPUT: logits} @@ -85,7 +87,10 @@ class BertForSequenceClassification(BaseModel): :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size] """ logits = self.forward(words)[Const.OUTPUT] - return {Const.OUTPUT: torch.argmax(logits, dim=-1)} + if self.num_labels > 1: + return {Const.OUTPUT: torch.argmax(logits, dim=-1)} + else: + return {Const.OUTPUT: logits} class BertForSentenceMatching(BaseModel): From 8732dfd979aaff0960f063d1a9005f95130313d4 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Sat, 6 Jun 2020 11:06:17 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=BA=86=20DataBundle=20?= =?UTF-8?q?=E7=9A=84=20apply=5Fmore=20=E5=92=8C=20apply=5Ffield=5Fmore=20?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E3=80=82=E9=9C=80=E8=A6=81=E8=BF=9B=E4=B8=80?= =?UTF-8?q?=E6=AD=A5=E8=AF=95=E7=94=A8=E5=92=8C=E6=B5=8B=E8=AF=95=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 9 +++-- fastNLP/io/data_bundle.py | 70 +++++++++++++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 464a6446..5e80a6fb 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -861,10 +861,9 @@ class DataSet(object): 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 - """ assert len(self) != 0, "Null DataSet cannot use apply_field()." - if field_name not in self: + if not self.has_field(field_name=field_name): raise KeyError("DataSet has no field named `{}`.".format(field_name)) return self.apply(func, new_field_name, _apply_field=field_name, **kwargs) @@ -888,10 +887,10 @@ class DataSet(object): 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 - :return Dict[int:Field]: 返回一个字典 + :return Dict[str:Field]: 返回一个字典 """ assert len(self) != 0, "Null DataSet cannot use apply_field()." - if field_name not in self: + if not self.has_field(field_name=field_name): raise KeyError("DataSet has no field named `{}`.".format(field_name)) return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs) @@ -950,7 +949,7 @@ class DataSet(object): 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 - :return Dict[int:Field]: 返回一个字典 + :return Dict[str:Field]: 返回一个字典 """ # 返回 dict , 检查是否一直相同 assert callable(func), "The func you provide is not callable." diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index bcb8a211..e911a26f 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -166,7 +166,7 @@ class DataBundle: dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) return self - def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): + def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): r""" 将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. @@ -282,7 +282,7 @@ class DataBundle: """ return list(self.datasets.keys()) - def get_vocab_names(self)->List[str]: + def get_vocab_names(self) -> List[str]: r""" 返回DataBundle中Vocabulary的名称 @@ -304,9 +304,9 @@ class DataBundle: for field_name, vocab in self.vocabs.items(): yield field_name, vocab - def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): + def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): r""" - 对DataBundle中所有的dataset使用apply_field方法 + 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 :param callable func: input是instance中名为 `field_name` 的field的内容。 :param str field_name: 传入func的是哪个field。 @@ -329,8 +329,41 @@ class DataBundle: raise KeyError(f"{field_name} not found DataSet:{name}.") return self - def apply(self, func, new_field_name:str, **kwargs): + def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs): r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 + + .. note:: + ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply`` 区别的介绍。 + + :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param str field_name: 传入func的是哪个field。 + :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; + 如果为False,则报错 + :param optional kwargs: 支持输入is_input, is_target, ignore_type + + 1. is_input: bool, 如果为True则将被修改的field设置为input + + 2. is_target: bool, 如果为True则将被修改的field设置为target + + 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 + + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 + """ + res = {} + for name, dataset in self.datasets.items(): + if dataset.has_field(field_name=field_name): + res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs) + elif not ignore_miss_dataset: + raise KeyError(f"{field_name} not found DataSet:{name} .") + return res + + def apply(self, func, new_field_name: str, **kwargs): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 + 对DataBundle中所有的dataset使用apply方法 :param callable func: input是instance中名为 `field_name` 的field的内容。 @@ -348,6 +381,31 @@ class DataBundle: dataset.apply(func, new_field_name=new_field_name, **kwargs) return self + def apply_more(self, func, modify_fields=True, **kwargs): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 + + .. note:: + ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply`` 区别的介绍。 + + :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True + :param optional kwargs: 支持输入is_input,is_target,ignore_type + + 1. is_input: bool, 如果为True则将被修改的的field设置为input + + 2. is_target: bool, 如果为True则将被修改的的field设置为target + + 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 + + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 + """ + res = {} + for name, dataset in self.datasets.items(): + res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs) + return res + def add_collate_fn(self, fn, name=None): r""" 向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明. @@ -380,5 +438,3 @@ class DataBundle: for name, vocab in self.vocabs.items(): _str += '\t{} has {} entries.\n'.format(name, len(vocab)) return _str - -