
Prepare for 0.5.6 development

tags/v0.6.0
yh_cc 5 years ago
commit 1efd40c075
12 changed files with 122 additions and 23 deletions
 1. fastNLP/__init__.py                    +1  -0
 2. fastNLP/core/__init__.py               +3  -1
 3. fastNLP/core/dataset.py                +4  -5
 4. fastNLP/core/losses.py                 +21 -0
 5. fastNLP/core/metrics.py                +1  -1
 6. fastNLP/core/utils.py                  +1  -1
 7. fastNLP/io/data_bundle.py              +63 -7
 8. fastNLP/io/file_reader.py              +4  -3
 9. fastNLP/io/loader/conll.py             +4  -2
10. fastNLP/models/bert.py                 +6  -1
11. test/core/test_metrics.py              +6  -1
12. test/io/loader/test_conll_loader.py    +8  -1

fastNLP/__init__.py  (+1 -0)

@@ -69,6 +69,7 @@ __all__ = [
"LossFunc", "LossFunc",
"CrossEntropyLoss", "CrossEntropyLoss",
"MSELoss",
"L1Loss", "L1Loss",
"BCELoss", "BCELoss",
"NLLLoss", "NLLLoss",


fastNLP/core/__init__.py  (+3 -1)

@@ -65,6 +65,7 @@ __all__ = [
"NLLLoss", "NLLLoss",
"LossInForward", "LossInForward",
"CMRC2018Loss", "CMRC2018Loss",
"MSELoss",
"LossBase", "LossBase",


"MetricBase", "MetricBase",
@@ -94,7 +95,8 @@ from .const import Const
from .dataset import DataSet from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, CMRC2018Loss, LossBase
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
LossInForward, CMRC2018Loss, LossBase, MSELoss
from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
ConfusionMatrixMetric ConfusionMatrixMetric
from .optimizer import Optimizer, SGD, Adam, AdamW from .optimizer import Optimizer, SGD, Adam, AdamW


fastNLP/core/dataset.py  (+4 -5)

@@ -861,10 +861,9 @@ class DataSet(object):
 
         3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
         :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
-
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
-        if field_name not in self:
+        if not self.has_field(field_name=field_name):
             raise KeyError("DataSet has no field named `{}`.".format(field_name))
         return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)
 
@@ -888,10 +887,10 @@ class DataSet(object):
 
         3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
 
-        :return Dict[int:Field]: 返回一个字典
+        :return Dict[str:Field]: 返回一个字典
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
-        if field_name not in self:
+        if not self.has_field(field_name=field_name):
             raise KeyError("DataSet has no field named `{}`.".format(field_name))
         return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
@@ -950,7 +949,7 @@ class DataSet(object):
 
         3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
 
-        :return Dict[int:Field]: 返回一个字典
+        :return Dict[str:Field]: 返回一个字典
         """
         # 返回 dict , 检查是否一直相同
         assert callable(func), "The func you provide is not callable."
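
Note: both apply_field and apply_field_more now check field membership via DataSet.has_field() instead of the `in` operator. A minimal sketch of the check, with illustrative field names (not from this commit):

    from fastNLP import DataSet, Instance

    ds = DataSet()
    ds.append(Instance(words=['a', 'b']))
    assert ds.has_field('words')        # present: apply_field proceeds
    assert not ds.has_field('labels')   # absent: a KeyError would be raised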


fastNLP/core/losses.py  (+21 -0)

@@ -12,6 +12,7 @@ __all__ = [
     "BCELoss",
     "L1Loss",
     "NLLLoss",
+    "MSELoss",
 
     "CMRC2018Loss"
 
@@ -265,6 +266,26 @@ class L1Loss(LossBase):
         return F.l1_loss(input=pred, target=target, reduction=self.reduction)
 
 
+class MSELoss(LossBase):
+    r"""
+    MSE损失函数
+
+    :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
+    :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
+    :param str reduction: 支持'mean','sum'和'none'.
+
+    """
+
+    def __init__(self, pred=None, target=None, reduction='mean'):
+        super(MSELoss, self).__init__()
+        self._init_param_map(pred=pred, target=target)
+        assert reduction in ('mean', 'sum', 'none')
+        self.reduction = reduction
+
+    def get_loss(self, pred, target):
+        return F.mse_loss(input=pred, target=target, reduction=self.reduction)
+
+
 class BCELoss(LossBase):
     r"""
     二分类交叉熵损失函数
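
Note: with MSELoss exported from fastNLP.core, it can be used like the other mapped losses. A minimal sketch calling get_loss directly (tensor values are illustrative):

    import torch
    from fastNLP import MSELoss

    loss_fn = MSELoss(reduction='mean')    # default mapping: pred -> pred, target -> target
    pred = torch.randn(4)                  # regression predictions, shape [batch]
    target = torch.randn(4)                # gold values
    print(loss_fn.get_loss(pred, target))  # scalar mean-squared error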


fastNLP/core/metrics.py  (+1 -1)

@@ -316,6 +316,7 @@ class ConfusionMatrixMetric(MetricBase):
                  print_ratio=False
                  ):
         r"""
+
         :param vocab: vocab词表类,要求有to_word()方法。
         :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
         :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
@@ -332,7 +333,6 @@ class ConfusionMatrixMetric(MetricBase):
     def evaluate(self, pred, target, seq_len=None):
         r"""
         evaluate函数将针对一个批次的预测结果做评价指标的累计
-
         :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
             torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
         :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),


fastNLP/core/utils.py  (+1 -1)

@@ -62,6 +62,7 @@ class ConfusionMatrix:
         target = [2,2,1]
         confusion.add_pred_target(pred, target)
         print(confusion)
+
         target  1  2  3  all
         pred
         1       0  1  0  1
@@ -157,7 +158,6 @@ class ConfusionMatrix:
             (k, str(k if self.vocab == None else self.vocab.to_word(k)))
             for k in totallabel
         ])
-
         for label, idx in zip(totallabel, range(lenth)):
             idx2row[
                 label] = idx  # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...


fastNLP/io/data_bundle.py  (+63 -7)

@@ -166,7 +166,7 @@ class DataBundle:
             dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
         return self
 
-    def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
+    def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
         r"""
         将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val.
 
@@ -282,7 +282,7 @@ class DataBundle:
         """
         return list(self.datasets.keys())
 
-    def get_vocab_names(self)->List[str]:
+    def get_vocab_names(self) -> List[str]:
         r"""
         返回DataBundle中Vocabulary的名称
 
@@ -304,9 +304,9 @@ class DataBundle:
         for field_name, vocab in self.vocabs.items():
             yield field_name, vocab
 
-    def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
+    def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
         r"""
-        对DataBundle中所有的dataset使用apply_field方法
+        对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法
 
         :param callable func: input是instance中名为 `field_name` 的field的内容。
         :param str field_name: 传入func的是哪个field。
@@ -329,8 +329,41 @@ class DataBundle:
                 raise KeyError(f"{field_name} not found DataSet:{name}.")
         return self
 
-    def apply(self, func, new_field_name:str, **kwargs):
+    def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs):
         r"""
+        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法
+
+        .. note::
+            ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
+            ``apply`` 区别的介绍。
+
+        :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
+        :param str field_name: 传入func的是哪个field。
+        :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
+        :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
+            如果为False,则报错
+        :param optional kwargs: 支持输入is_input, is_target, ignore_type
+
+            1. is_input: bool, 如果为True则将被修改的field设置为input
+
+            2. is_target: bool, 如果为True则将被修改的field设置为target
+
+            3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
+
+        :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
+        """
+        res = {}
+        for name, dataset in self.datasets.items():
+            if dataset.has_field(field_name=field_name):
+                res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs)
+            elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found DataSet:{name} .")
+        return res
+
+    def apply(self, func, new_field_name: str, **kwargs):
+        r"""
+        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法
 
         对DataBundle中所有的dataset使用apply方法
 
         :param callable func: input是instance中名为 `field_name` 的field的内容。
@@ -348,6 +381,31 @@ class DataBundle:
             dataset.apply(func, new_field_name=new_field_name, **kwargs)
         return self
 
+    def apply_more(self, func, modify_fields=True, **kwargs):
+        r"""
+        对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法
+
+        .. note::
+            ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
+            ``apply`` 区别的介绍。
+
+        :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
+        :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
+        :param optional kwargs: 支持输入is_input,is_target,ignore_type
+
+            1. is_input: bool, 如果为True则将被修改的的field设置为input
+
+            2. is_target: bool, 如果为True则将被修改的的field设置为target
+
+            3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
+
+        :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
+        """
+        res = {}
+        for name, dataset in self.datasets.items():
+            res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
+        return res
+
     def add_collate_fn(self, fn, name=None):
         r"""
         向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明.
@@ -380,5 +438,3 @@ class DataBundle:
         for name, vocab in self.vocabs.items():
             _str += '\t{} has {} entries.\n'.format(name, len(vocab))
         return _str
-
-
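
Note: a minimal sketch of the new apply_more (the 'train' dataset and field names are illustrative; apply_field_more works the same way but additionally takes the source field_name):

    from fastNLP import DataSet, Instance
    from fastNLP.io import DataBundle

    ds = DataSet()
    ds.append(Instance(raw_words='hello world'))
    bundle = DataBundle(datasets={'train': ds})

    # func receives an Instance and returns {field_name: value}; with
    # modify_fields=True the returned fields are written back to each DataSet
    res = bundle.apply_more(lambda ins: {'words': ins['raw_words'].split(),
                                         'seq_len': len(ins['raw_words'].split())})
    # res is {dataset_name: {field_name: field}}, one inner dict per DataSet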

fastNLP/io/file_reader.py  (+4 -3)

@@ -81,12 +81,13 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True):
             yield line_idx, _res
 
 
-def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
+def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
     r"""
     Construct a generator to read conll items.
 
     :param path: file path
     :param encoding: file's encoding, default: utf-8
+    :param sep: separator
     :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
     :param dropna: weather to ignore and drop invalid data,
         :if False, raise ValueError when reading invalid data. default: True
@@ -105,7 +106,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
             sample = []
             start = next(f).strip()
             if start != '':
-                sample.append(start.split())
+                sample.append(start.split(sep)) if sep else sample.append(start.split())
             for line_idx, line in enumerate(f, 1):
                 line = line.strip()
                 if line == '':
@@ -123,7 +124,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
             elif line.startswith('#'):
                 continue
             else:
-                sample.append(line.split())
+                sample.append(line.split(sep)) if sep else sample.append(line.split())
             if len(sample) > 0:
                 try:
                     res = parse_conll(sample)
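
Note: the sep branch matters because str.split() and str.split(sep) are not equivalent; an explicit separator preserves empty columns:

    line = 'Alice\t\tB-PER'
    print(line.split())      # ['Alice', 'B-PER']       runs of whitespace collapse
    print(line.split('\t'))  # ['Alice', '', 'B-PER']   empty columns survive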


fastNLP/io/loader/conll.py  (+4 -2)

@@ -55,10 +55,11 @@ class ConllLoader(Loader):
 
     """
-    def __init__(self, headers, indexes=None, dropna=True):
+    def __init__(self, headers, sep=None, indexes=None, dropna=True):
         r"""
         :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
+        :param str sep: 指定分隔符,默认为 ``None`` ,按任意空白符切分
         :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
         :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
         """
@@ -68,6 +69,7 @@ class ConllLoader(Loader):
                 'invalid headers: {}, should be list of strings'.format(headers))
         self.headers = headers
         self.dropna = dropna
+        self.sep=sep
         if indexes is None:
             self.indexes = list(range(len(self.headers)))
         else:
@@ -83,7 +85,7 @@ class ConllLoader(Loader):
         :return: DataSet
         """
         ds = DataSet()
-        for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
+        for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna):
             ins = {h: data[i] for i, h in enumerate(self.headers)}
             ds.append(Instance(**ins))
         return ds
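
Note: a minimal usage sketch of the new sep argument (headers and path are illustrative):

    from fastNLP.io.loader.conll import ConllLoader

    # tab-separated columns; leaving sep=None keeps the old whitespace splitting
    db = ConllLoader(headers=['raw_words', 'ner'], sep='\t').load('path/to/conll_data')  # hypothetical path
    print(db)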


fastNLP/models/bert.py  (+6 -1)

@@ -77,6 +77,8 @@ class BertForSequenceClassification(BaseModel):
         hidden = self.dropout(self.bert(words))
         cls_hidden = hidden[:, 0]
         logits = self.classifier(cls_hidden)
+        if logits.size(-1) == 1:
+            logits = logits.squeeze(-1)
 
         return {Const.OUTPUT: logits}
 
@@ -86,7 +88,10 @@ class BertForSequenceClassification(BaseModel):
         :return: { :attr:`fastNLP.Const.OUTPUT` : logits}: torch.LongTensor [batch_size]
         """
         logits = self.forward(words)[Const.OUTPUT]
-        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        if self.num_labels > 1:
+            return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
+        else:
+            return {Const.OUTPUT: logits}
 
 
 class BertForSentenceMatching(BaseModel):
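
Note: a shape-level sketch of the new num_labels == 1 path (a plain Linear stands in for the BERT classifier head; no pretrained weights are involved):

    import torch
    import torch.nn as nn

    batch, hidden = 4, 8
    cls_hidden = torch.randn(batch, hidden)  # stands in for hidden[:, 0]
    classifier = nn.Linear(hidden, 1)        # num_labels == 1, i.e. a regression head

    logits = classifier(cls_hidden)
    if logits.size(-1) == 1:                 # the branch added in this commit
        logits = logits.squeeze(-1)
    assert logits.shape == (batch,)          # [batch_size], ready for MSELoss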


test/core/test_metrics.py  (+6 -1)

@@ -45,6 +45,7 @@ def _convert_res_to_fastnlp_res(metric_result):
     return allen_result
 
 
+
 class TestConfusionMatrixMetric(unittest.TestCase):
     def test_ConfusionMatrixMetric1(self):
         pred_dict = {"pred": torch.zeros(4,3)}
@@ -56,6 +57,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
 
     def test_ConfusionMatrixMetric2(self):
         # (2) with corrupted size
+
         with self.assertRaises(Exception):
             pred_dict = {"pred": torch.zeros(4, 3, 2)}
             target_dict = {'target': torch.zeros(4)}
@@ -78,7 +80,6 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         print(metric.get_metric())
 
-
     def test_ConfusionMatrixMetric4(self):
         # (4) check reset
         metric = ConfusionMatrixMetric()
@@ -91,6 +92,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
 
     def test_ConfusionMatrixMetric5(self):
         # (5) check numpy array is not acceptable
+
         with self.assertRaises(Exception):
             metric = ConfusionMatrixMetric()
             pred_dict = {"pred": np.zeros((4, 3, 2))}
@@ -122,6 +124,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())
 
+
     def test_duplicate(self):
         # 0.4.1的潜在bug,不能出现形参重复的情况
         metric = ConfusionMatrixMetric(pred='predictions', target='targets')
@@ -130,6 +133,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())
 
+
     def test_seq_len(self):
         N = 256
         seq_len = torch.zeros(N).long()
@@ -155,6 +159,7 @@ class TestConfusionMatrixMetric(unittest.TestCase):
         print(metric.get_metric())
 
 
+
 class TestAccuracyMetric(unittest.TestCase):
     def test_AccuracyMetric1(self):
         # (1) only input, targets passed


test/io/loader/test_conll_loader.py  (+8 -1)

@@ -2,7 +2,7 @@
 import unittest
 import os
 from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, \
-    Conll2003Loader
+    Conll2003Loader, ConllLoader
 
 
 class TestMSRANER(unittest.TestCase):
@@ -35,3 +35,10 @@ class TestConllLoader(unittest.TestCase):
         db = Conll2003Loader().load('test/data_for_tests/io/conll2003')
         print(db)
+
+
+class TestConllLoader(unittest.TestCase):
+    def test_sep(self):
+        headers = [
+            'raw_words', 'ner',
+        ]
+        db = ConllLoader(headers = headers,sep="\n").load('test/data_for_tests/io/MSRA_NER')
+        print(db)
