@@ -69,6 +69,7 @@ __all__ = [
     "LossFunc",
     "CrossEntropyLoss",
+    "MSELoss",
     "L1Loss",
     "BCELoss",
     "NLLLoss",
@@ -65,6 +65,7 @@ __all__ = [
     "NLLLoss",
     "LossInForward",
     "CMRC2018Loss",
+    "MSELoss",
     "LossBase",
     "MetricBase",
@@ -94,7 +95,8 @@ from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
 from .instance import Instance
-from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss, LossBase
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
+    LossInForward, CMRC2018Loss, LossBase, MSELoss
 from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase, \
     ConfusionMatrixMetric
 from .optimizer import Optimizer, SGD, Adam, AdamW
@@ -861,10 +861,9 @@ class DataSet(object):
             3. ignore_type: bool, if True, set the ignore_type of the field named `new_field_name` to True, ignoring its type

         :return List[Any]: the elements are the return values of func, so the list length equals the length of the DataSet
         """
-        assert len(self) != 0, "Null DataSet cannot use apply_field()."
-        if field_name not in self:
+        if not self.has_field(field_name=field_name):
             raise KeyError("DataSet has no field named `{}`.".format(field_name))
         return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)
@@ -888,10 +887,10 @@ class DataSet(object):
             3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their type

-        :return Dict[int:Field]: a dict is returned
+        :return Dict[str:Field]: a dict is returned
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field_more()."
-        if field_name not in self:
+        if not self.has_field(field_name=field_name):
             raise KeyError("DataSet has no field named `{}`.".format(field_name))
         return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
@@ -950,7 +949,7 @@ class DataSet(object):
             3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their type

-        :return Dict[int:Field]: a dict is returned
+        :return Dict[str:Field]: a dict is returned
         """
         # returns a dict; check that its keys stay the same throughout
         assert callable(func), "The func you provide is not callable."
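A minimal sketch of the dict-return convention these docstrings describe (not part of the diff; the dataset content and field names are made up):

```python
from fastNLP import DataSet

ds = DataSet({'raw_words': ['hello world', 'fastNLP is here']})

# func takes one Instance and returns a dict: each key names a field,
# each value becomes that field's new content for the instance.
res = ds.apply_more(lambda ins: {'words': ins['raw_words'].split(),
                                 'seq_len': len(ins['raw_words'].split())})

# with modify_fields=True (the default) `ds` now also has `words` and `seq_len`
print(ds)
```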
@@ -12,6 +12,7 @@ __all__ = [
     "BCELoss",
     "L1Loss",
     "NLLLoss",
+    "MSELoss",
     "CMRC2018Loss"
@@ -265,6 +266,26 @@ class L1Loss(LossBase):
         return F.l1_loss(input=pred, target=target, reduction=self.reduction)


+class MSELoss(LossBase):
+    r"""
+    MSE loss function (mean squared error)
+
+    :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
+    :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
+    :param str reduction: one of 'mean', 'sum' and 'none'.
+    """
+
+    def __init__(self, pred=None, target=None, reduction='mean'):
+        super(MSELoss, self).__init__()
+        self._init_param_map(pred=pred, target=target)
+        assert reduction in ('mean', 'sum', 'none')
+        self.reduction = reduction
+
+    def get_loss(self, pred, target):
+        return F.mse_loss(input=pred, target=target, reduction=self.reduction)
+
+
 class BCELoss(LossBase):
     r"""
     binary cross-entropy loss function
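A quick sanity check of the new loss, outside the Trainer (the `output`/`label_y` key names are invented for illustration):

```python
import torch
from fastNLP import MSELoss

# `pred`/`target` set up the key mapping the Trainer would use:
# the model's `output` key maps to pred, the dataset's `label_y` to target.
loss_fn = MSELoss(pred='output', target='label_y', reduction='mean')

# get_loss is what ultimately runs once that mapping is resolved
loss = loss_fn.get_loss(pred=torch.randn(4), target=torch.randn(4))
print(loss)  # a scalar tensor
```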
@@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase):
                  print_ratio=False
                  ):
         r"""
+        :param vocab: a Vocabulary object; it must provide a to_word() method.
         :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
         :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
@@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase):
     def evaluate(self, pred, target, seq_len=None):
         r"""
         evaluate() accumulates the metric over the predictions of one batch
-
         :param torch.Tensor pred: the predicted tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
             torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
         :param torch.Tensor target: the ground-truth tensor; its shape can be torch.Size([B,]),
@@ -62,6 +62,7 @@ class ConfusionMatrix:
         target = [2,2,1]
         confusion.add_pred_target(pred, target)
         print(confusion)
+
         target  1   2   3   all
           pred
              1   0   1   0    1
@@ -157,7 +158,6 @@ class ConfusionMatrix:
             (k, str(k if self.vocab == None else self.vocab.to_word(k)))
             for k in totallabel
         ])
-
         for label, idx in zip(totallabel, range(lenth)):
             idx2row[
                 label] = idx  # temporary dict: key = index in the vocab, value = row/column index (1,3,5,... -> 0,1,2,...)
@@ -166,7 +166,7 @@ class DataBundle:
             dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
         return self

     def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
         r"""
         Set the padding value of the Field named field_name to pad_val in every DataSet of the DataBundle.
@@ -282,7 +282,7 @@ class DataBundle:
         """
         return list(self.datasets.keys())

-    def get_vocab_names(self)->List[str]:
+    def get_vocab_names(self) -> List[str]:
         r"""
         Return the names of the Vocabularies in the DataBundle
@@ -304,9 +304,9 @@ class DataBundle:
         for field_name, vocab in self.vocabs.items():
             yield field_name, vocab

     def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
         r"""
-        apply the apply_field method to every dataset in the DataBundle
+        Apply :meth:`~fastNLP.DataSet.apply_field` to every dataset in the :class:`~fastNLP.io.DataBundle`

         :param callable func: its input is the content of the field named `field_name` in each instance.
         :param str field_name: the field whose content is passed to func.
@@ -329,8 +329,41 @@ class DataBundle:
                 raise KeyError(f"{field_name} not found DataSet:{name}.")
         return self

-    def apply(self, func, new_field_name:str, **kwargs):
+    def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs):
+        r"""
+        Apply :meth:`~fastNLP.DataSet.apply_field_more` to every dataset in the :class:`~fastNLP.io.DataBundle`
+
+        .. note::
+            For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of
+            ``apply_more`` versus ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
+
+        :param callable func: its input is the content of the field named `field_name`; it returns a dict whose
+            keys are field names and whose values are the corresponding results
+        :param str field_name: the field whose content is passed to func.
+        :param bool modify_fields: whether to modify the `Field` s of the `DataSet` with the results; default True
+        :param bool ignore_miss_dataset: if True, a dataset that has no field of this name is silently skipped;
+            if False, an error is raised
+        :param optional kwargs: supports is_input, is_target, ignore_type
+
+            1. is_input: bool, if True, set the modified fields as input
+
+            2. is_target: bool, if True, set the modified fields as target
+
+            3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their type
+
+        :return Dict[str:Dict[str:Field]]: a nested dict; the first-level keys are dataset names, the second-level
+            keys are field names
+        """
+        res = {}
+        for name, dataset in self.datasets.items():
+            if dataset.has_field(field_name=field_name):
+                res[name] = dataset.apply_field_more(func=func, field_name=field_name,
+                                                     modify_fields=modify_fields, **kwargs)
+            elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found DataSet:{name}.")
+        return res
+    def apply(self, func, new_field_name: str, **kwargs):
         r"""
-        apply the apply method to every dataset in the DataBundle
+        Apply :meth:`~fastNLP.DataSet.apply` to every dataset in the :class:`~fastNLP.io.DataBundle`

         :param callable func: its input is an Instance of the DataSet.
@@ -348,6 +381,31 @@ class DataBundle:
             dataset.apply(func, new_field_name=new_field_name, **kwargs)
         return self

+    def apply_more(self, func, modify_fields=True, **kwargs):
+        r"""
+        Apply :meth:`~fastNLP.DataSet.apply_more` to every dataset in the :class:`~fastNLP.io.DataBundle`
+
+        .. note::
+            For the difference between ``apply_more`` and ``apply``, see :meth:`fastNLP.DataSet.apply_more`.
+
+        :param callable func: its parameter is an ``Instance`` of the ``DataSet``; it returns a dict whose keys
+            are field names and whose values are the corresponding results
+        :param bool modify_fields: whether to modify the ``Field`` s of the ``DataSet`` with the results; default True
+        :param optional kwargs: supports is_input, is_target, ignore_type
+
+            1. is_input: bool, if True, set the modified fields as input
+
+            2. is_target: bool, if True, set the modified fields as target
+
+            3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their type
+
+        :return Dict[str:Dict[str:Field]]: a nested dict; the first-level keys are dataset names, the second-level
+            keys are field names
+        """
+        res = {}
+        for name, dataset in self.datasets.items():
+            res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
+        return res
+
     def add_collate_fn(self, fn, name=None):
         r"""
         Add a collate_fn to every DataSet; see :class:`~fastNLP.DataSet` for details about collate_fn.
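A minimal sketch of the two new DataBundle methods side by side (not part of the diff; dataset and field names are made up). `apply_field_more` hands `func` the content of one field, while `apply_more` hands it the whole `Instance`:

```python
from fastNLP import DataSet
from fastNLP.io import DataBundle

bundle = DataBundle(datasets={'train': DataSet({'raw_words': ['a b', 'c d e']})})

# per-field: func receives the content of `raw_words`
res = bundle.apply_field_more(lambda rw: {'words': rw.split(),
                                          'seq_len': len(rw.split())},
                              field_name='raw_words')

# per-instance: func receives the Instance itself
res = bundle.apply_more(lambda ins: {'n_chars': len(ins['raw_words'])})

# both return {dataset_name: {field_name: Field}}
print(res['train'].keys())
```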
@@ -380,5 +438,3 @@ class DataBundle:
         for name, vocab in self.vocabs.items():
             _str += '\t{} has {} entries.\n'.format(name, len(vocab))
         return _str
-
-
@@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True):
             yield line_idx, _res


-def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
+def _read_conll(path, encoding='utf-8', sep=None, indexes=None, dropna=True):
     r"""
     Construct a generator to read conll items.

     :param path: file path
     :param encoding: file's encoding, default: utf-8
+    :param sep: separator
     :param indexes: conll object's column indexes that are needed; if None, all columns are needed. default: None
     :param dropna: whether to ignore and drop invalid data;
         if False, raise ValueError when reading invalid data. default: True
@@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
         sample = []
         start = next(f).strip()
         if start != '':
-            sample.append(start.split())
+            sample.append(start.split(sep) if sep else start.split())
         for line_idx, line in enumerate(f, 1):
             line = line.strip()
             if line == '':
@@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
             elif line.startswith('#'):
                 continue
             else:
-                sample.append(line.split())
+                sample.append(line.split(sep) if sep else line.split())
         if len(sample) > 0:
             try:
                 res = parse_conll(sample)
@@ -55,10 +55,11 @@ class ConllLoader(Loader):
     """

-    def __init__(self, headers, indexes=None, dropna=True):
+    def __init__(self, headers, sep=None, indexes=None, dropna=True):
         r"""
         :param list headers: the name of each column; must be a list or tuple of str. ``header`` and ``indexes`` correspond one-to-one
+        :param str sep: the separator used to split columns; ``None`` (the default) splits on any whitespace
         :param list indexes: indexes of the columns to keep, counted from 0. If ``None``, all columns are kept. Default: ``None``
         :param bool dropna: whether to ignore invalid data; if ``False``, a ``ValueError`` is raised on invalid data. Default: ``True``
         """
@@ -68,6 +69,7 @@ class ConllLoader(Loader): | |||
'invalid headers: {}, should be list of strings'.format(headers)) | |||
self.headers = headers | |||
self.dropna = dropna | |||
self.sep=sep | |||
if indexes is None: | |||
self.indexes = list(range(len(self.headers))) | |||
else: | |||
@@ -83,7 +85,7 @@ class ConllLoader(Loader):
         :return: DataSet
         """
         ds = DataSet()
-        for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
+        for idx, data in _read_conll(path, sep=self.sep, indexes=self.indexes, dropna=self.dropna):
             ins = {h: data[i] for i, h in enumerate(self.headers)}
             ds.append(Instance(**ins))
         return ds
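Hypothetical usage of the new ``sep`` parameter (the path and column layout are invented); passing ``sep=None`` keeps the old whitespace-splitting behaviour, so existing callers are unaffected:

```python
from fastNLP.io.loader.conll import ConllLoader

# a CoNLL-style file whose columns are tab-separated
db = ConllLoader(headers=['raw_words', 'pos', 'ner'], sep='\t').load('path/to/conll_data')
print(db)
```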
@@ -77,6 +77,8 @@ class BertForSequenceClassification(BaseModel):
         hidden = self.dropout(self.bert(words))
         cls_hidden = hidden[:, 0]
         logits = self.classifier(cls_hidden)
+        if logits.size(-1) == 1:
+            logits = logits.squeeze(-1)

         return {Const.OUTPUT: logits}
@@ -86,7 +88,10 @@ class BertForSequenceClassification(BaseModel):
         :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
         """
         logits = self.forward(words)[Const.OUTPUT]
-        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        if self.num_labels > 1:
+            return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        else:
+            return {Const.OUTPUT: logits}


 class BertForSentenceMatching(BaseModel):
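The net effect for a single-label head, shown as a runnable sketch of just the new control flow (tensors are dummies): with ``num_labels=1`` the model now returns raw scores of shape ``[batch_size]``, which pairs with the new ``MSELoss`` for regression:

```python
import torch

logits = torch.randn(4, 1)       # classifier output: [batch_size, num_labels=1]
if logits.size(-1) == 1:
    logits = logits.squeeze(-1)  # forward() now yields shape [batch_size]

num_labels = 1                   # predict() skips argmax for a regression head
pred = torch.argmax(logits, dim=-1) if num_labels > 1 else logits
print(pred.shape)                # torch.Size([4]) -- raw scores for MSELoss
```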
@@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result):
     return allen_result

+
 class TestConfusionMatrixMetric(unittest.TestCase):
     def test_ConfusionMatrixMetric1(self):
         pred_dict = {"pred": torch.zeros(4,3)}
@@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
+
     def test_ConfusionMatrixMetric2(self):
         # (2) with corrupted size
         with self.assertRaises(Exception):
             pred_dict = {"pred": torch.zeros(4, 3, 2)}
             target_dict = {'target': torch.zeros(4)}
@@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         print(metric.get_metric())
-
     def test_ConfusionMatrixMetric4(self):
         # (4) check reset
         metric = ConfusionMatrixMetric()
@@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
+
     def test_ConfusionMatrixMetric5(self):
         # (5) check numpy array is not acceptable
         with self.assertRaises(Exception):
             metric = ConfusionMatrixMetric()
             pred_dict = {"pred": np.zeros((4, 3, 2))}
@@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())
+
     def test_duplicate(self):
         # latent bug in 0.4.1: duplicated parameter names must not appear
         metric = ConfusionMatrixMetric(pred='predictions', target='targets')
@@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())
+
     def test_seq_len(self):
         N = 256
         seq_len = torch.zeros(N).long()
@@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         print(metric.get_metric())

+
 class TestAccuracyMetric(unittest.TestCase):
     def test_AccuracyMetric1(self):
         # (1) only input, targets passed
@@ -2,7 +2,7 @@
 import unittest
 import os

 from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \
-    Conll2003Loader
+    Conll2003Loader, ConllLoader


 class TestMSRANER(unittest.TestCase):
@@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase):
         db = Conll2003Loader().load('test/data_for_tests/io/conll2003')
         print(db)
+class TestConllLoaderSep(unittest.TestCase):
+    def test_sep(self):
+        headers = [
+            'raw_words', 'ner',
+        ]
+        db = ConllLoader(headers=headers, sep="\n").load('test/data_for_tests/io/MSRA_NER')
+        print(db)