@@ -69,13 +69,20 @@ class DataSetGetter: | |||||
def may_to_tensor(data):
    """
    Convert the batched content of one field to a tensor unless numpy output was requested.

    ``self`` and ``n`` are taken from the enclosing scope; ``n`` is only used in the log message.

    :param data: the batched content of a single field
    :return: the first value returned by ``_to_tensor`` when ``self.as_numpy`` is falsy, otherwise ``data``
        unchanged
    :raises TypeError: re-raised from ``_to_tensor`` after the failing field has been logged
    """
    # The element type is always inspected, even when no conversion follows.
    dtype, _dim = _get_ele_type_and_dim(data)
    if not self.as_numpy:
        try:
            data, _converted = _to_tensor(data, dtype)
        except TypeError:
            logger.error(f"Field {n} cannot be converted to torch.tensor.")
            # Bare raise keeps the original traceback intact.
            raise
    return data
def pad(batch_dict): | def pad(batch_dict): | ||||
@@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype): | |||||
if field_dtype is not None and isinstance(field_dtype, type)\ | if field_dtype is not None and isinstance(field_dtype, type)\ | ||||
and issubclass(field_dtype, Number) \ | and issubclass(field_dtype, Number) \ | ||||
and not isinstance(batch, torch.Tensor): | and not isinstance(batch, torch.Tensor): | ||||
if issubclass(field_dtype, np.floating): | |||||
new_batch = torch.as_tensor(batch).float() # 默认使用float32 | |||||
elif issubclass(field_dtype, np.integer): | |||||
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 | |||||
else: | |||||
new_batch = torch.as_tensor(batch) | |||||
return new_batch, True | |||||
new_batch = torch.as_tensor(batch) | |||||
flag = True | |||||
else: | else: | ||||
return batch, False | |||||
new_batch = batch | |||||
flag = False | |||||
if torch.is_tensor(new_batch): | |||||
if 'float' in new_batch.dtype.__repr__(): | |||||
new_batch = new_batch.float() | |||||
elif 'int' in new_batch.dtype.__repr__(): | |||||
new_batch = new_batch.long() | |||||
return new_batch, flag | |||||
except Exception as e: | except Exception as e: | ||||
raise e | raise e |
@@ -118,6 +118,12 @@ class Collector: | |||||
def outputs(self):
    """Return the keys view of ``self.output2fn``."""
    registered = self.output2fn
    return registered.keys()
def copy_from(self, col):
    """
    Replace this collector's ``fns``, ``input2fn`` and ``output2fn`` with shallow copies of those of ``col``,
    then call ``self._clear_fn2io()``.

    :param Collector col: the collector whose registrations are copied
    """
    assert isinstance(col, Collector)
    # Copied in the same order as the attributes are listed; each copy is shallow (``.copy()``).
    for attr in ("fns", "input2fn", "output2fn"):
        setattr(self, attr, getattr(col, attr).copy())
    self._clear_fn2io()
class CollectFn: | class CollectFn: | ||||
def __init__(self): | def __init__(self): | ||||
@@ -284,7 +284,7 @@ | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"DataSet" | |||||
"DataSet", | |||||
] | ] | ||||
import _pickle as pickle | import _pickle as pickle | ||||
@@ -305,6 +305,12 @@ from .utils import pretty_table_printer | |||||
from .collect_fn import Collector | from .collect_fn import Collector | ||||
class ApplyResultException(Exception):
    """
    Exception that carries a message together with the index of the data item at which the problem occurred.

    :param msg: the error message, also stored as ``self.msg``
    :param index: index of the data item that ran into the problem, ``None`` by default
    """

    def __init__(self, msg, index=None):
        Exception.__init__(self, msg)
        self.index = index  # marks which data item ran into the problem
        self.msg = msg
class DataSet(object): | class DataSet(object): | ||||
""" | """ | ||||
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` | fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` | ||||
@@ -569,6 +575,7 @@ class DataSet(object): | |||||
:param str field_name: 需要删除的field的名称. | :param str field_name: 需要删除的field的名称. | ||||
""" | """ | ||||
self.field_arrays.pop(field_name) | self.field_arrays.pop(field_name) | ||||
self.collector.drop_field(field_name) | |||||
return self | return self | ||||
def copy_field(self, field_name, new_field_name): | def copy_field(self, field_name, new_field_name): | ||||
@@ -641,6 +648,7 @@ class DataSet(object): | |||||
if field_name in self.field_arrays: | if field_name in self.field_arrays: | ||||
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | ||||
self.field_arrays[new_field_name].name = new_field_name | self.field_arrays[new_field_name].name = new_field_name | ||||
self.collector.rename_field(field_name, new_field_name) | |||||
else: | else: | ||||
raise KeyError("DataSet has no field named {}.".format(field_name)) | raise KeyError("DataSet has no field named {}.".format(field_name)) | ||||
return self | return self | ||||
@@ -778,23 +786,35 @@ class DataSet(object): | |||||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | assert len(self) != 0, "Null DataSet cannot use apply_field()." | ||||
if field_name not in self: | if field_name not in self: | ||||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | raise KeyError("DataSet has no field named `{}`.".format(field_name)) | ||||
results = [] | |||||
idx = -1 | |||||
try: | |||||
for idx, ins in enumerate(self._inner_iter()): | |||||
results.append(func(ins[field_name])) | |||||
except Exception as e: | |||||
if idx != -1: | |||||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1)) | |||||
raise e | |||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||||
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs) | |||||
if new_field_name is not None: | |||||
self._add_apply_field(results, new_field_name, kwargs) | |||||
def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
    """
    Pass the field named `field_name` of every ``Instance`` in the ``DataSet`` to func and collect its return
    values. func may return results for one or more fields.

    .. note::
        For the difference between ``apply_field_more`` and ``apply_field`` see the description of
        ``apply_more`` versus ``apply`` in :meth:`~fastNLP.DataSet.apply_more`.

    :param callable func: receives the value of `field_name` of one ``Instance`` and returns a dict whose keys
        are field names and whose values are the corresponding results
    :param str field_name: which field is passed to func
    :param bool modify_fields: whether to modify the fields of the ``DataSet`` with the results, default True
    :param optional kwargs: supports is_input, is_target, ignore_type

        1. is_input: bool, if True the modified fields are set as input
        2. is_target: bool, if True the modified fields are set as target
        3. ignore_type: bool, if True the ignore_type of the modified fields is set to true, i.e. their type
           is ignored

    :return Dict[str, list]: whatever :meth:`apply_more` returns, i.e. a dict that maps every key returned by
        func to the list of values collected over all instances
    :raises KeyError: if the ``DataSet`` has no field named `field_name`
    """
    assert len(self) != 0, "Null DataSet cannot use apply_field_more()."
    if field_name not in self:
        raise KeyError("DataSet has no field named `{}`.".format(field_name))
    # ``_apply_field`` tells apply_more to hand only this field's value to func.
    return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
def _add_apply_field(self, results, new_field_name, kwargs): | def _add_apply_field(self, results, new_field_name, kwargs): | ||||
""" | """ | ||||
将results作为加入到新的field中,field名称为new_field_name | 将results作为加入到新的field中,field名称为new_field_name | ||||
@@ -827,12 +847,73 @@ class DataSet(object): | |||||
is_target=extra_param.get("is_target", None), | is_target=extra_param.get("is_target", None), | ||||
ignore_type=extra_param.get("ignore_type", False)) | ignore_type=extra_param.get("ignore_type", False)) | ||||
def apply_more(self, func, modify_fields=True, **kwargs):
    """
    Pass every ``Instance`` of the ``DataSet`` to func and collect its return values. func may return results
    for one or more fields.

    .. note::
        Differences between ``apply_more`` and ``apply``:

        1. ``apply_more`` can return results for several fields, ``apply`` only for a single field;
        2. the return value of ``apply_more`` is a dict in which every key is a field name and every value
           is the computed result;
        3. ``apply_more`` modifies the fields of the ``DataSet`` by default, ``apply`` does not.

    :param callable func: receives an ``Instance`` of the ``DataSet`` (or, when the internal ``_apply_field``
        keyword is present, only the value of that field) and returns a dict whose keys are field names and
        whose values are the corresponding results
    :param bool modify_fields: whether to modify the fields of the ``DataSet`` with the results, default True
    :param optional kwargs: supports is_input, is_target, ignore_type

        1. is_input: bool, if True the modified fields are set as input
        2. is_target: bool, if True the modified fields are set as target
        3. ignore_type: bool, if True the ignore_type of the modified fields is set to true, i.e. their type
           is ignored

    :return Dict[str, list]: maps every key returned by func to the list of values collected over all
        instances, in iteration order
    :raises ApplyResultException: if func returns something that is not a dict, or if the keys it returns
        differ from those returned for the first instance
    """
    assert callable(func), "The func you provide is not callable."
    assert len(self) != 0, "Null DataSet cannot use apply_more()."

    # Decide once, outside the loop, what is handed to func.
    if "_apply_field" in kwargs:
        applied_field = kwargs["_apply_field"]

        def call(ins):
            return func(ins[applied_field])
    else:
        call = func

    idx = -1
    results = {}
    try:
        for idx, ins in enumerate(self._inner_iter()):
            res = call(ins)
            if not isinstance(res, dict):
                raise ApplyResultException("The result of func is not a dict", idx)
            if idx == 0:
                # The first result fixes the set of fields.
                for key, value in res.items():
                    results[key] = [value]
                continue
            # Check the key set before appending so that a mismatch never leaves half-extended lists.
            if res.keys() != results.keys():
                raise ApplyResultException("apply results have different fields", idx)
            for key, value in res.items():
                results[key].append(value)
    except Exception as e:
        if idx != -1:
            if isinstance(e, ApplyResultException):
                logger.error(e.msg)
            # The reported index is 0-based.
            logger.error("Exception happens at the `{}`th instance.".format(idx))
        raise

    # Identity check kept on purpose: only the literal ``True`` triggers modification.
    if modify_fields is True:
        for field, result in results.items():
            self._add_apply_field(result, field, kwargs)

    return results
def apply(self, func, new_field_name=None, **kwargs): | def apply(self, func, new_field_name=None, **kwargs): | ||||
""" | """ | ||||
将DataSet中每个instance传入到func中,并获取它的返回值. | 将DataSet中每个instance传入到func中,并获取它的返回值. | ||||
:param callable func: 参数是DataSet中的Instance | |||||
:param None,str new_field_name: 将func返回的内容放入到new_field_name这个field中,如果名称与已有的field相同,则覆 | |||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` | |||||
:param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||||
盖之前的field。如果为None则不创建新的field。 | 盖之前的field。如果为None则不创建新的field。 | ||||
:param optional kwargs: 支持输入is_input,is_target,ignore_type | :param optional kwargs: 支持输入is_input,is_target,ignore_type | ||||
@@ -844,21 +925,21 @@ class DataSet(object): | |||||
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 | :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 | ||||
""" | """ | ||||
assert callable(func), "The func you provide is not callable." | |||||
assert len(self) != 0, "Null DataSet cannot use apply()." | assert len(self) != 0, "Null DataSet cannot use apply()." | ||||
idx = -1 | idx = -1 | ||||
try: | try: | ||||
results = [] | results = [] | ||||
for idx, ins in enumerate(self._inner_iter()): | for idx, ins in enumerate(self._inner_iter()): | ||||
results.append(func(ins)) | |||||
if "_apply_field" in kwargs: | |||||
results.append(func(ins[kwargs["_apply_field"]])) | |||||
else: | |||||
results.append(func(ins)) | |||||
except BaseException as e: | except BaseException as e: | ||||
if idx != -1: | if idx != -1: | ||||
logger.error("Exception happens at the `{}`th instance.".format(idx)) | logger.error("Exception happens at the `{}`th instance.".format(idx)) | ||||
raise e | raise e | ||||
# results = [func(ins) for ins in self._inner_iter()] | |||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||||
if new_field_name is not None: | if new_field_name is not None: | ||||
self._add_apply_field(results, new_field_name, kwargs) | self._add_apply_field(results, new_field_name, kwargs) | ||||
@@ -933,6 +1014,8 @@ class DataSet(object): | |||||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | ||||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | ||||
train_set.collector.copy_from(self.collector) | |||||
dev_set.collector.copy_from(self.collector) | |||||
return train_set, dev_set | return train_set, dev_set | ||||
def save(self, path): | def save(self, path): | ||||
@@ -538,6 +538,18 @@ class BertModel(nn.Module): | |||||
raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') | raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') | ||||
model_type = 'BERT' | model_type = 'BERT' | ||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
if 'bert' not in key: | |||||
new_key = 'bert.' + key | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
old_keys = [] | old_keys = [] | ||||
new_keys = [] | new_keys = [] | ||||
for key in state_dict.keys(): | for key in state_dict.keys(): | ||||
@@ -51,7 +51,7 @@ def preprocess(text): | |||||
def to_sentence_list(text, split_long_sentence=False): | def to_sentence_list(text, split_long_sentence=False): | ||||
text = preprocess(text) | text = preprocess(text) | ||||
delimiter = set() | delimiter = set() | ||||
delimiter.update("。!?:;…、,(),;!?、,\"'") | |||||
delimiter.update("。!?:;…、,(),;!?、.\"'") | |||||
delimiter.add("……") | delimiter.add("……") | ||||
sent_list = [] | sent_list = [] | ||||
sent = [] | sent = [] | ||||
@@ -226,36 +226,36 @@ class TestCallback(unittest.TestCase): | |||||
callbacks=EarlyStopCallback(1), check_code_level=2) | callbacks=EarlyStopCallback(1), check_code_level=2) | ||||
trainer.train() | trainer.train() | ||||
@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_control_C():
    # Manual check of ControlC: quit each of the two training runs with Control+C; the check passes if
    # "Test failed!" is not shown at the end.
    from fastNLP import ControlC, Callback
    import time

    banner_open = "\n\n\n\n\n*************************"
    banner_close = "*************************\n\n\n\n\n"

    class Wait(Callback):
        def on_epoch_end(self):
            time.sleep(5)

    def announce(text):
        print(banner_open + text + banner_close)

    def run_training(flag):
        # A fresh Trainer is built for every run; only the ControlC argument differs between the two runs.
        runner = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                         batch_size=32, n_epochs=20, dev_data=data_set,
                         metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                         callbacks=[Wait(), ControlC(flag)], check_code_level=2)
        runner.train()

    data_set, model = prepare_env()

    announce("Test starts!")
    run_training(False)

    announce("Program goes on ...")
    run_training(True)

    announce("Test failed!")
@@ -3,6 +3,7 @@ import sys | |||||
import unittest | import unittest | ||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP.core.dataset import ApplyResultException | |||||
from fastNLP import FieldArray | from fastNLP import FieldArray | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
from fastNLP.io import CSVLoader | from fastNLP.io import CSVLoader | ||||
@@ -143,6 +144,42 @@ class TestDataSetMethods(unittest.TestCase): | |||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
ds.apply(modify_inplace) | ds.apply(modify_inplace) | ||||
def test_apply_more(self):
    """Check ``apply_more`` / ``apply_field_more`` on consistent and on inconsistent func results."""
    ds = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})

    def func_1(ins):
        return {"c": ins["a"] * 2, "d": ins["a"] ** 2}

    def func_2(value):
        return {"c": value * 3, "d": value ** 3}

    def func_err_1(ins):
        # Returns two fields for the first instance only, one field afterwards.
        if ins["a"] == 1:
            return {"e": ins["a"] * 2, "f": ins["a"] ** 2}
        return {"e": ins["a"] * 2}

    def func_err_2(value):
        # Same inconsistency as func_err_1, but operating on the field value.
        if value == 1:
            return {"e": value * 2, "f": value ** 2}
        return {"e": value * 2}

    # By default the results are written back into the DataSet.
    ds.apply_more(func_1)
    self.assertEqual(list(ds["c"]), [2, 4, 6])
    self.assertEqual(list(ds["d"]), [1, 4, 9])

    # With modify_fields=False the DataSet stays untouched and only the returned dict carries the results.
    res = ds.apply_field_more(func_2, "a", modify_fields=False)
    self.assertEqual(list(ds["c"]), [2, 4, 6])
    self.assertEqual(list(ds["d"]), [1, 4, 9])
    self.assertEqual(list(res["c"]), [3, 6, 9])
    self.assertEqual(list(res["d"]), [1, 8, 27])

    # The second instance (index 1) is the first one whose fields differ from the first result.
    with self.assertRaises(ApplyResultException) as cm:
        ds.apply_more(func_err_1)
    self.assertEqual(cm.exception.index, 1)

    with self.assertRaises(ApplyResultException) as cm:
        ds.apply_field_more(func_err_2, "a")
    self.assertEqual(cm.exception.index, 1)
def test_drop(self): | def test_drop(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | ||||
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | ||||
@@ -36,16 +36,37 @@ class TestCNClassificationPipe(unittest.TestCase): | |||||
class TestRunClassificationPipe(unittest.TestCase): | class TestRunClassificationPipe(unittest.TestCase): | ||||
def test_process_from_file(self): | def test_process_from_file(self): | ||||
data_set_dict = { | data_set_dict = { | ||||
'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, (6, 6, 6), (1176, 2), False), | |||||
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1166, 5), False), | |||||
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True), | |||||
'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False), | |||||
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False), | |||||
'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, (5, 4), (257, 4), False), | |||||
'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, (5, 14), (496, 14), False), | |||||
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False), | |||||
'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, (9, 9, 9), (1864, 9), False), | |||||
'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, (7, 6, 6), (452, 2), False), | |||||
'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2}, | |||||
False), | |||||
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5}, | |||||
False), | |||||
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, | |||||
{'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2}, | |||||
True), | |||||
'sst': ('test/data_for_tests/io/SST', SSTPipe, | |||||
{'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5}, | |||||
False), | |||||
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2}, | |||||
False), | |||||
'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, | |||||
{'train': 4, 'test': 5}, {'words': 257, 'target': 4}, | |||||
False), | |||||
'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, | |||||
{'train': 14, 'test': 5}, {'words': 496, 'target': 14}, | |||||
False), | |||||
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, | |||||
{'train': 6, 'dev': 6, 'test': 6}, | |||||
{'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2}, | |||||
False), | |||||
'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, | |||||
{'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9}, | |||||
False), | |||||
'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, | |||||
{'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2}, | |||||
False), | |||||
} | } | ||||
for k, v in data_set_dict.items(): | for k, v in data_set_dict.items(): | ||||
path, pipe, data_set, vocab, warns = v | path, pipe, data_set, vocab, warns = v | ||||
@@ -61,12 +82,12 @@ class TestRunClassificationPipe(unittest.TestCase): | |||||
self.assertTrue(isinstance(data_bundle, DataBundle)) | self.assertTrue(isinstance(data_bundle, DataBundle)) | ||||
self.assertEqual(len(data_set), data_bundle.num_dataset) | self.assertEqual(len(data_set), data_bundle.num_dataset) | ||||
for x, y in zip(data_set, data_bundle.iter_datasets()): | |||||
name, dataset = y | |||||
self.assertEqual(x, len(dataset)) | |||||
for name, dataset in data_bundle.iter_datasets(): | |||||
self.assertTrue(name in data_set.keys()) | |||||
self.assertEqual(data_set[name], len(dataset)) | |||||
self.assertEqual(len(vocab), data_bundle.num_vocab) | self.assertEqual(len(vocab), data_bundle.num_vocab) | ||||
for x, y in zip(vocab, data_bundle.iter_vocabs()): | |||||
name, vocabs = y | |||||
self.assertEqual(x, len(vocabs)) | |||||
for name, vocabs in data_bundle.iter_vocabs(): | |||||
self.assertTrue(name in vocab.keys()) | |||||
self.assertEqual(vocab[name], len(vocabs)) | |||||