@@ -69,6 +69,7 @@ __all__ = [
     "LossFunc",
     "CrossEntropyLoss",
+    "MSELoss",
     "L1Loss",
     "BCELoss",
     "NLLLoss",
@@ -65,6 +65,7 @@ __all__ = [
     "NLLLoss",
     "LossInForward",
     "CMRC2018Loss",
+    "MSELoss",
     "LossBase",
     "MetricBase",
@@ -94,7 +95,8 @@ from .const import Const
 from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
 from .instance import Instance
-from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss, LossBase
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
+    LossInForward, CMRC2018Loss, LossBase, MSELoss
 from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase, \
     ConfusionMatrixMetric
 from .optimizer import Optimizer, SGD, Adam, AdamW
@@ -861,10 +861,9 @@ class DataSet(object):
             3. ignore_type: bool, if True, set the ignore_type of the field named `new_field_name` to True, ignoring its type
 
         :return List[Any]: the elements are the return values of func, so the length of the list equals the length of the DataSet
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
-        if field_name not in self:
+        if not self.has_field(field_name=field_name):
             raise KeyError("DataSet has no field named `{}`.".format(field_name))
         return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)
@@ -888,10 +887,10 @@ class DataSet(object):
             3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their types
 
-        :return Dict[int:Field]: returns a dict
+        :return Dict[str:Field]: returns a dict
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
-        if field_name not in self:
+        if not self.has_field(field_name=field_name):
             raise KeyError("DataSet has no field named `{}`.".format(field_name))
         return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
@@ -950,7 +949,7 @@ class DataSet(object):
             3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their types
 
-        :return Dict[int:Field]: returns a dict
+        :return Dict[str:Field]: returns a dict
         """
         # func returns a dict; check that its keys are always the same
         assert callable(func), "The func you provide is not callable."
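The corrected annotation matters in practice: `apply_more` keys its result by field *name*, not by index. A minimal sketch of that contract, using a toy `DataSet` (the field names here are illustrative):

```python
from fastNLP import DataSet

ds = DataSet({"words": [["a", "b"], ["c"]]})

# func returns a dict; each key is a field *name* (str), which is why the
# docstring now reads Dict[str:Field] rather than Dict[int:Field]
res = ds.apply_more(lambda ins: {"seq_len": len(ins["words"])})
print(list(res.keys()))  # ['seq_len']
```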
@@ -12,6 +12,7 @@ __all__ = [
     "BCELoss",
     "L1Loss",
     "NLLLoss",
+    "MSELoss",
     "CMRC2018Loss"
@@ -265,6 +266,26 @@ class L1Loss(LossBase):
         return F.l1_loss(input=pred, target=target, reduction=self.reduction)
 
 
+class MSELoss(LossBase):
+    r"""
+    MSE loss function (mean squared error)
+
+    :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
+    :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
+    :param str reduction: supports 'mean', 'sum' and 'none'.
+    """
+
+    def __init__(self, pred=None, target=None, reduction='mean'):
+        super(MSELoss, self).__init__()
+        self._init_param_map(pred=pred, target=target)
+        assert reduction in ('mean', 'sum', 'none')
+        self.reduction = reduction
+
+    def get_loss(self, pred, target):
+        return F.mse_loss(input=pred, target=target, reduction=self.reduction)
+
+
 class BCELoss(LossBase):
     r"""
     Binary cross-entropy loss function
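A minimal sketch of how the new loss could be wired up; the dict keys are illustrative, and the call follows the usual `LossBase` convention of mapping the constructor arguments onto `pred_dict`/`target_dict` keys:

```python
import torch
from fastNLP import MSELoss  # exported via the __all__ additions above

# map the loss's `pred`/`target` arguments onto differently named keys
loss = MSELoss(pred="output", target="label", reduction="mean")

pred_dict = {"output": torch.randn(4)}
target_dict = {"label": torch.randn(4)}
print(loss(pred_dict, target_dict))  # a scalar tensor
```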
@@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase):
                  print_ratio=False
                  ):
         r"""
+
         :param vocab: a vocabulary class; it must provide a to_word() method.
         :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
         :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
@@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase):
     def evaluate(self, pred, target, seq_len=None):
         r"""
         The evaluate function accumulates the metric over one batch of predictions
-
         :param torch.Tensor pred: the predicted tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
             torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
         :param torch.Tensor target: the ground-truth tensor; its shape can be torch.Size([B,]),
@@ -62,6 +62,7 @@ class ConfusionMatrix:
         target = [2,2,1]
         confusion.add_pred_target(pred, target)
         print(confusion)
+
         target  1   2   3   all
         pred
         1       0   1   0   1
@@ -157,7 +158,6 @@ class ConfusionMatrix:
             (k, str(k if self.vocab == None else self.vocab.to_word(k)))
             for k in totallabel
         ])
-
         for label, idx in zip(totallabel, range(lenth)):
             idx2row[
                 label] = idx  # build a temporary dict; key: vocab index, value: row/column index, 1,3,5... -> 0,1,2,...
@@ -166,7 +166,7 @@ class DataBundle:
             dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
         return self
 
-    def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
+    def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
         r"""
         Set the padding value of the Field named field_name in every DataSet of the DataBundle to pad_val.
@@ -282,7 +282,7 @@ class DataBundle:
         """
         return list(self.datasets.keys())
 
-    def get_vocab_names(self)->List[str]:
+    def get_vocab_names(self) -> List[str]:
         r"""
         Return the names of the Vocabularies in the DataBundle
@@ -304,9 +304,9 @@ class DataBundle:
         for field_name, vocab in self.vocabs.items():
             yield field_name, vocab
 
-    def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
+    def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
         r"""
-        Apply the apply_field method to every dataset in the DataBundle
+        Apply :meth:`~fastNLP.DataSet.apply_field` to every dataset in the :class:`~fastNLP.io.DataBundle`
 
         :param callable func: the input is the content of the field named `field_name` in each instance.
         :param str field_name: which field is passed to func.
@@ -329,8 +329,41 @@ class DataBundle:
                 raise KeyError(f"{field_name} not found DataSet:{name}.")
         return self
 
-    def apply(self, func, new_field_name:str, **kwargs):
+    def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs):
         r"""
+        Apply :meth:`~fastNLP.DataSet.apply_field_more` to every dataset in the :class:`~fastNLP.io.DataBundle`
+
+        .. note::
+            For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of the
+            difference between ``apply_more`` and ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
+
+        :param callable func: takes an ``Instance`` of the ``DataSet``; returns a dict whose keys are field names and whose values are the corresponding results
+        :param str field_name: which field is passed to func.
+        :param bool modify_fields: whether to modify the `Field` in the `DataSet` with the results; default True
+        :param bool ignore_miss_dataset: when a dataset has no field with this name, skip that DataSet if True;
+            raise an error if False
+        :param optional kwargs: supports is_input, is_target, ignore_type
+
+            1. is_input: bool, if True, set the modified fields as input
+            2. is_target: bool, if True, set the modified fields as target
+            3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their types
+
+        :return Dict[str:Dict[str:Field]]: a dict of dicts; the first-level keys are dataset names, the second-level keys are field names
+        """
+        res = {}
+        for name, dataset in self.datasets.items():
+            if dataset.has_field(field_name=field_name):
+                res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs)
+            elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found DataSet:{name} .")
+        return res
+
+    def apply(self, func, new_field_name: str, **kwargs):
+        r"""
+        Apply :meth:`~fastNLP.DataSet.apply` to every dataset in the :class:`~fastNLP.io.DataBundle`
         Apply the apply method to every dataset in the DataBundle
 
         :param callable func: the input is the content of the field named `field_name` in each instance.
@@ -348,6 +381,31 @@ class DataBundle:
             dataset.apply(func, new_field_name=new_field_name, **kwargs)
         return self
 
+    def apply_more(self, func, modify_fields=True, **kwargs):
+        r"""
+        Apply :meth:`~fastNLP.DataSet.apply_more` to every dataset in the :class:`~fastNLP.io.DataBundle`
+
+        .. note::
+            For the difference between ``apply_more`` and ``apply``, see the discussion of the difference between
+            ``apply_more`` and ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
+
+        :param callable func: takes an ``Instance`` of the ``DataSet``; returns a dict whose keys are field names and whose values are the corresponding results
+        :param bool modify_fields: whether to modify the ``Field`` in the ``DataSet`` with the results; default True
+        :param optional kwargs: supports is_input, is_target, ignore_type
+
+            1. is_input: bool, if True, set the modified fields as input
+            2. is_target: bool, if True, set the modified fields as target
+            3. ignore_type: bool, if True, set the ignore_type of the modified fields to True, ignoring their types
+
+        :return Dict[str:Dict[str:Field]]: a dict of dicts; the first-level keys are dataset names, the second-level keys are field names
+        """
+        res = {}
+        for name, dataset in self.datasets.items():
+            res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
+        return res
+
     def add_collate_fn(self, fn, name=None):
         r"""
         Add collate_fn to every DataSet; for details on collate_fn see :class:`~fastNLP.DataSet`.
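A short sketch of the new `DataBundle.apply_more` fan-out, using toy splits (the split names and fields are illustrative):

```python
from fastNLP import DataSet
from fastNLP.io import DataBundle

bundle = DataBundle(datasets={
    "train": DataSet({"words": [["a", "b"], ["c"]]}),
    "dev": DataSet({"words": [["d"]]}),
})

# one call fans out to every split; the result is keyed first by dataset
# name, then by the field names produced by func
res = bundle.apply_more(lambda ins: {"seq_len": len(ins["words"])})
print(list(res.keys()))  # ['train', 'dev']
```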
@@ -380,5 +438,3 @@ class DataBundle:
         for name, vocab in self.vocabs.items():
             _str += '\t{} has {} entries.\n'.format(name, len(vocab))
         return _str
-
-
@@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True):
         yield line_idx, _res
 
 
-def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
+def _read_conll(path, encoding='utf-8', sep=None, indexes=None, dropna=True):
     r"""
     Construct a generator to read conll items.
 
     :param path: file path
     :param encoding: file's encoding, default: utf-8
+    :param sep: separator; if None, split on any whitespace. default: None
     :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
     :param dropna: weather to ignore and drop invalid data,
            :if False, raise ValueError when reading invalid data. default: True
@@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
             sample = []
             start = next(f).strip()
             if start != '':
-                sample.append(start.split())
+                sample.append(start.split(sep) if sep else start.split())
             for line_idx, line in enumerate(f, 1):
                 line = line.strip()
                 if line == '':
@@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
             elif line.startswith('#'):
                 continue
             else:
-                sample.append(line.split())
+                sample.append(line.split(sep) if sep else line.split())
         if len(sample) > 0:
             try:
                 res = parse_conll(sample)
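The two branches only differ when a field itself contains whitespace: `str.split()` with no argument splits on any whitespace run, while an explicit `sep` preserves spaces inside columns.

```python
line = "New York\tLOC"

print(line.split())      # ['New', 'York', 'LOC']  -- default whitespace split
print(line.split('\t'))  # ['New York', 'LOC']     -- explicit tab separator
```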
@@ -55,10 +55,11 @@ class ConllLoader(Loader):
     """
 
-    def __init__(self, headers, indexes=None, dropna=True):
+    def __init__(self, headers, sep=None, indexes=None, dropna=True):
         r"""
         :param list headers: the name of each data column; must be a List or Tuple of str. ``header`` and ``indexes`` correspond one-to-one
+        :param str sep: the column separator; if ``None`` (default), columns are split on any whitespace
         :param list indexes: the indexes of the data columns to keep, starting from 0. If ``None``, all columns are kept. Default: ``None``
         :param bool dropna: whether to ignore invalid data; if ``False``, a ``ValueError`` is raised when invalid data is encountered. Default: ``True``
         """
@@ -68,6 +69,7 @@ class ConllLoader(Loader):
                 'invalid headers: {}, should be list of strings'.format(headers))
         self.headers = headers
         self.dropna = dropna
+        self.sep = sep
         if indexes is None:
             self.indexes = list(range(len(self.headers)))
         else:
@@ -83,7 +85,7 @@ class ConllLoader(Loader):
         :return: DataSet
         """
         ds = DataSet()
-        for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
+        for idx, data in _read_conll(path, sep=self.sep, indexes=self.indexes, dropna=self.dropna):
             ins = {h: data[i] for i, h in enumerate(self.headers)}
             ds.append(Instance(**ins))
         return ds
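Putting the pieces together, a hedged usage sketch (the file path and column names are illustrative):

```python
from fastNLP.io.loader.conll import ConllLoader

# a CoNLL-style file whose columns are tab-separated, e.g. "New York\tLOC"
loader = ConllLoader(headers=["raw_words", "ner"], sep="\t")
data_bundle = loader.load("path/to/conll_dir")  # illustrative path
print(data_bundle)
```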
@@ -77,6 +77,8 @@ class BertForSequenceClassification(BaseModel):
         hidden = self.dropout(self.bert(words))
         cls_hidden = hidden[:, 0]
         logits = self.classifier(cls_hidden)
+        if logits.size(-1) == 1:
+            logits = logits.squeeze(-1)
 
         return {Const.OUTPUT: logits}
@@ -86,7 +88,10 @@ class BertForSequenceClassification(BaseModel):
         :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
         """
         logits = self.forward(words)[Const.OUTPUT]
-        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        if self.num_labels > 1:
+            return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        else:
+            return {Const.OUTPUT: logits}
 
 
 class BertForSentenceMatching(BaseModel):
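The net effect for `num_labels == 1` (regression) is that both `forward` and `predict` now emit a `[batch_size]` tensor of raw scores instead of an argmax over classes. A standalone sketch of the shape logic:

```python
import torch

logits = torch.randn(4, 1)       # classifier output for a batch of 4, num_labels == 1
if logits.size(-1) == 1:
    logits = logits.squeeze(-1)  # [4, 1] -> [4], ready for MSELoss/L1Loss
print(logits.shape)              # torch.Size([4])
```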
@@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result):
     return allen_result
 
+
 class TestConfusionMatrixMetric(unittest.TestCase):
     def test_ConfusionMatrixMetric1(self):
         pred_dict = {"pred": torch.zeros(4,3)}
@@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
 
+
     def test_ConfusionMatrixMetric2(self):
         # (2) with corrupted size
         with self.assertRaises(Exception):
             pred_dict = {"pred": torch.zeros(4, 3, 2)}
             target_dict = {'target': torch.zeros(4)}
@@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         print(metric.get_metric())
 
-
     def test_ConfusionMatrixMetric4(self):
         # (4) check reset
         metric = ConfusionMatrixMetric()
@@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
 
+
     def test_ConfusionMatrixMetric5(self):
         # (5) check numpy array is not acceptable
         with self.assertRaises(Exception):
             metric = ConfusionMatrixMetric()
             pred_dict = {"pred": np.zeros((4, 3, 2))}
@@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
             metric(pred_dict=pred_dict, target_dict=target_dict)
             print(metric.get_metric())
 
+
     def test_duplicate(self):
         # potential bug in 0.4.1: duplicate parameter names must not occur
         metric = ConfusionMatrixMetric(pred='predictions', target='targets')
@@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())
 
+
     def test_seq_len(self):
         N = 256
         seq_len = torch.zeros(N).long()
@@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         print(metric.get_metric())
 
+
 class TestAccuracyMetric(unittest.TestCase):
     def test_AccuracyMetric1(self):
         # (1) only input, targets passed
@@ -2,7 +2,7 @@
 import unittest
 import os
 
 from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \
-    Conll2003Loader
+    Conll2003Loader, ConllLoader
 
 
 class TestMSRANER(unittest.TestCase):
@@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase):
         db = Conll2003Loader().load('test/data_for_tests/io/conll2003')
         print(db)
+
+
+class TestConllLoaderSep(unittest.TestCase):
+    def test_sep(self):
+        headers = [
+            'raw_words', 'ner',
+        ]
+        db = ConllLoader(headers=headers, sep="\n").load('test/data_for_tests/io/MSRA_NER')
+        print(db)