Browse Source

Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.5.5
yh_cc 5 years ago
parent
commit
dce1a73f53
8 changed files with 224 additions and 56 deletions
  1. +18
    -9
      fastNLP/core/batch.py
  2. +6
    -0
      fastNLP/core/collect_fn.py
  3. +105
    -22
      fastNLP/core/dataset.py
  4. +12
    -0
      fastNLP/modules/encoder/bert.py
  5. +1
    -1
      reproduction/multi-criteria-cws/data-prepare.py
  6. +8
    -8
      test/core/test_callbacks.py
  7. +37
    -0
      test/core/test_dataset.py
  8. +37
    -16
      test/io/pipe/test_classification.py

+ 18
- 9
fastNLP/core/batch.py View File

@@ -69,13 +69,20 @@ class DataSetGetter:


def may_to_tensor(data): def may_to_tensor(data):
dtype, dim = _get_ele_type_and_dim(data) dtype, dim = _get_ele_type_and_dim(data)
print(dtype, type(dtype))
# print(dtype, type(dtype), str(dtype))
if not self.as_numpy: if not self.as_numpy:
try: try:
data, flag = _to_tensor(data, dtype) data, flag = _to_tensor(data, dtype)
except TypeError as e: except TypeError as e:
logger.error(f"Field {n} cannot be converted to torch.tensor.") logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e raise e
# if torch.is_tensor(data):
# str_dtype = str(dtype)
# if 'float' in str_dtype:
# data = data.float()
# elif 'int' in str_dtype:
# data = data.long()
# print(data.dtype)
return data return data


def pad(batch_dict): def pad(batch_dict):
@@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype):
if field_dtype is not None and isinstance(field_dtype, type)\ if field_dtype is not None and isinstance(field_dtype, type)\
and issubclass(field_dtype, Number) \ and issubclass(field_dtype, Number) \
and not isinstance(batch, torch.Tensor): and not isinstance(batch, torch.Tensor):
if issubclass(field_dtype, np.floating):
new_batch = torch.as_tensor(batch).float() # 默认使用float32
elif issubclass(field_dtype, np.integer):
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制
else:
new_batch = torch.as_tensor(batch)
return new_batch, True
new_batch = torch.as_tensor(batch)
flag = True
else: else:
return batch, False
new_batch = batch
flag = False
if torch.is_tensor(new_batch):
if 'float' in new_batch.dtype.__repr__():
new_batch = new_batch.float()
elif 'int' in new_batch.dtype.__repr__():
new_batch = new_batch.long()
return new_batch, flag
except Exception as e: except Exception as e:
raise e raise e

+ 6
- 0
fastNLP/core/collect_fn.py View File

@@ -118,6 +118,12 @@ class Collector:
def outputs(self): def outputs(self):
return self.output2fn.keys() return self.output2fn.keys()


def copy_from(self, col):
    """
    Overwrite this collector's state with copies taken from ``col``.

    ``fns``, ``input2fn`` and ``output2fn`` are each duplicated with their
    own ``.copy()`` method, so the two collectors no longer share these
    containers afterwards (the contained objects themselves are not copied).

    :param Collector col: the collector whose state is copied
    :raises AssertionError: if ``col`` is not a ``Collector``
    """
    assert isinstance(col, Collector)
    self.fns = col.fns.copy()
    self.input2fn = col.input2fn.copy()
    self.output2fn = col.output2fn.copy()
    # NOTE(review): presumably resets cached fn -> input/output lookups so
    # they are rebuilt from the copied state -- confirm against _clear_fn2io.
    self._clear_fn2io()


class CollectFn: class CollectFn:
def __init__(self): def __init__(self):


+ 105
- 22
fastNLP/core/dataset.py View File

@@ -284,7 +284,7 @@


""" """
__all__ = [ __all__ = [
"DataSet"
"DataSet",
] ]


import _pickle as pickle import _pickle as pickle
@@ -305,6 +305,12 @@ from .utils import pretty_table_printer
from .collect_fn import Collector from .collect_fn import Collector




class ApplyResultException(Exception):
    """
    Error describing an unusable result returned by a function given to
    ``apply_more``.

    :param str msg: description of the problem; also stored as ``self.msg``
    :param index: position of the instance that triggered the problem,
        or ``None`` when it is not known
    """

    def __init__(self, msg, index=None):
        super().__init__(msg)
        self.msg = msg
        self.index = index  # marks which instance ran into the problem
class DataSet(object): class DataSet(object):
""" """
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset` fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset`
@@ -569,6 +575,7 @@ class DataSet(object):
:param str field_name: 需要删除的field的名称. :param str field_name: 需要删除的field的名称.
""" """
self.field_arrays.pop(field_name) self.field_arrays.pop(field_name)
self.collector.drop_field(field_name)
return self return self


def copy_field(self, field_name, new_field_name): def copy_field(self, field_name, new_field_name):
@@ -641,6 +648,7 @@ class DataSet(object):
if field_name in self.field_arrays: if field_name in self.field_arrays:
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
self.field_arrays[new_field_name].name = new_field_name self.field_arrays[new_field_name].name = new_field_name
self.collector.rename_field(field_name, new_field_name)
else: else:
raise KeyError("DataSet has no field named {}.".format(field_name)) raise KeyError("DataSet has no field named {}.".format(field_name))
return self return self
@@ -778,23 +786,35 @@ class DataSet(object):
assert len(self) != 0, "Null DataSet cannot use apply_field()." assert len(self) != 0, "Null DataSet cannot use apply_field()."
if field_name not in self: if field_name not in self:
raise KeyError("DataSet has no field named `{}`.".format(field_name)) raise KeyError("DataSet has no field named `{}`.".format(field_name))
results = []
idx = -1
try:
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)


if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)
def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
    """
    Pass the field named `field_name` of every ``Instance`` in the ``DataSet``
    to ``func`` and collect what it returns. ``func`` may return results for
    one or several fields.

    .. note::
        For the difference between ``apply_field_more`` and ``apply_field``,
        see the description of ``apply_more`` versus ``apply`` in
        :meth:`~fastNLP.DataSet.apply_more`.

    :param callable func: called with the value of `field_name` of each
        ``Instance``; must return a dict whose keys are field names and whose
        values are the corresponding results
    :param str field_name: which field is handed to ``func``
    :param bool modify_fields: whether the results are written back into the
        fields of the ``DataSet``; defaults to True
    :param optional kwargs: accepts is_input, is_target, ignore_type

        1. is_input: bool, if True the modified fields are marked as input

        2. is_target: bool, if True the modified fields are marked as target

        3. ignore_type: bool, if True ``ignore_type`` is set on the modified
           fields, so their type is ignored

    :return Dict[str, list]: whatever ``apply_more`` returns -- a dict mapping
        each field name to the list of per-instance results
    :raises KeyError: if the ``DataSet`` has no field named `field_name`
    """
    assert len(self) != 0, "Null DataSet cannot use apply_field_more()."
    if field_name not in self:
        raise KeyError("DataSet has no field named `{}`.".format(field_name))
    # The field name travels through kwargs so apply_more knows to hand func
    # the field value instead of the whole instance.
    return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
def _add_apply_field(self, results, new_field_name, kwargs): def _add_apply_field(self, results, new_field_name, kwargs):
""" """
将results作为加入到新的field中,field名称为new_field_name 将results作为加入到新的field中,field名称为new_field_name
@@ -827,12 +847,73 @@ class DataSet(object):
is_target=extra_param.get("is_target", None), is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False)) ignore_type=extra_param.get("ignore_type", False))


def apply_more(self, func, modify_fields=True, **kwargs):
    """
    Pass every ``Instance`` of the ``DataSet`` to ``func`` and collect what it
    returns. ``func`` may return results for one or several fields.

    .. note::
        Differences between ``apply_more`` and ``apply``:

        1. ``apply_more`` can return results for several fields, ``apply``
           only for a single one;
        2. the value returned by ``func`` must be a dict; in each key-value
           pair the key is a field name and the value is the computed result;
        3. ``apply_more`` modifies the fields of the ``DataSet`` by default,
           ``apply`` does not.

    :param callable func: called with each ``Instance`` (or, when the internal
        ``_apply_field`` keyword is present, with that field's value); must
        return a dict with the same set of keys for every instance
    :param bool modify_fields: whether the results are written back into the
        fields of the ``DataSet``; defaults to True
    :param optional kwargs: accepts is_input, is_target, ignore_type

        1. is_input: bool, if True the modified fields are marked as input

        2. is_target: bool, if True the modified fields are marked as target

        3. ignore_type: bool, if True ``ignore_type`` is set on the modified
           fields, so their type is ignored

    :return Dict[str, list]: maps each field name returned by ``func`` to the
        list of its per-instance results, in iteration order
    :raises ApplyResultException: if ``func`` returns something that is not a
        dict, or the set of returned keys differs between instances
    """
    assert callable(func), "The func you provide is not callable."
    assert len(self) != 0, "Null DataSet cannot use apply_more()."

    # Loop-invariant: decide once whether func receives a single field value
    # or the whole instance.
    use_field = "_apply_field" in kwargs
    results = {}
    idx = -1
    try:
        for idx, ins in enumerate(self._inner_iter()):
            res = func(ins[kwargs["_apply_field"]]) if use_field else func(ins)
            if not isinstance(res, dict):
                raise ApplyResultException("The result of func is not a dict", idx)
            if idx == 0:
                # The first instance fixes the set of fields.
                results = {key: [value] for key, value in res.items()}
            elif res.keys() != results.keys():
                # An extra or a missing field is equally an error.
                raise ApplyResultException("apply results have different fields", idx)
            else:
                for key, value in res.items():
                    results[key].append(value)
    except Exception as e:
        if idx != -1:
            if isinstance(e, ApplyResultException):
                logger.error(e.msg)
            logger.error("Exception happens at the `{}`th instance.".format(idx))
        raise

    if modify_fields is True:
        for field, result in results.items():
            self._add_apply_field(result, field, kwargs)

    return results

def apply(self, func, new_field_name=None, **kwargs): def apply(self, func, new_field_name=None, **kwargs):
""" """
将DataSet中每个instance传入到func中,并获取它的返回值. 将DataSet中每个instance传入到func中,并获取它的返回值.


:param callable func: 参数是DataSet中的Instance
:param None,str new_field_name: 将func返回的内容放入到new_field_name这个field中,如果名称与已有的field相同,则覆
:param callable func: 参数是 ``DataSet`` 中的 ``Instance``
:param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。 盖之前的field。如果为None则不创建新的field。
:param optional kwargs: 支持输入is_input,is_target,ignore_type :param optional kwargs: 支持输入is_input,is_target,ignore_type


@@ -844,21 +925,21 @@ class DataSet(object):
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
""" """
assert callable(func), "The func you provide is not callable."
assert len(self) != 0, "Null DataSet cannot use apply()." assert len(self) != 0, "Null DataSet cannot use apply()."
idx = -1 idx = -1
try: try:
results = [] results = []
for idx, ins in enumerate(self._inner_iter()): for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
if "_apply_field" in kwargs:
results.append(func(ins[kwargs["_apply_field"]]))
else:
results.append(func(ins))
except BaseException as e: except BaseException as e:
if idx != -1: if idx != -1:
logger.error("Exception happens at the `{}`th instance.".format(idx)) logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e raise e


# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))

if new_field_name is not None: if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs) self._add_apply_field(results, new_field_name, kwargs)


@@ -933,6 +1014,8 @@ class DataSet(object):
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) train_set.field_arrays[field_name].to(self.field_arrays[field_name])
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) dev_set.field_arrays[field_name].to(self.field_arrays[field_name])


train_set.collector.copy_from(self.collector)
dev_set.collector.copy_from(self.collector)
return train_set, dev_set return train_set, dev_set


def save(self, path): def save(self, path):


+ 12
- 0
fastNLP/modules/encoder/bert.py View File

@@ -538,6 +538,18 @@ class BertModel(nn.Module):
raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')


model_type = 'BERT' model_type = 'BERT'
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'bert' not in key:
new_key = 'bert.' + key
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

old_keys = [] old_keys = []
new_keys = [] new_keys = []
for key in state_dict.keys(): for key in state_dict.keys():


+ 1
- 1
reproduction/multi-criteria-cws/data-prepare.py View File

@@ -51,7 +51,7 @@ def preprocess(text):
def to_sentence_list(text, split_long_sentence=False): def to_sentence_list(text, split_long_sentence=False):
text = preprocess(text) text = preprocess(text)
delimiter = set() delimiter = set()
delimiter.update("。!?:;…、,(),;!?、,\"'")
delimiter.update("。!?:;…、,(),;!?、.\"'")
delimiter.add("……") delimiter.add("……")
sent_list = [] sent_list = []
sent = [] sent = []


+ 8
- 8
test/core/test_callbacks.py View File

@@ -226,36 +226,36 @@ class TestCallback(unittest.TestCase):
callbacks=EarlyStopCallback(1), check_code_level=2) callbacks=EarlyStopCallback(1), check_code_level=2)
trainer.train() trainer.train()


@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_control_C(): def test_control_C():
# 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试 # 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试
from fastNLP import ControlC, Callback from fastNLP import ControlC, Callback
import time import time
line1 = "\n\n\n\n\n*************************" line1 = "\n\n\n\n\n*************************"
line2 = "*************************\n\n\n\n\n" line2 = "*************************\n\n\n\n\n"
class Wait(Callback): class Wait(Callback):
def on_epoch_end(self): def on_epoch_end(self):
time.sleep(5) time.sleep(5)
data_set, model = prepare_env() data_set, model = prepare_env()
print(line1 + "Test starts!" + line2) print(line1 + "Test starts!" + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set, batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(False)], check_code_level=2) callbacks=[Wait(), ControlC(False)], check_code_level=2)
trainer.train() trainer.train()
print(line1 + "Program goes on ..." + line2) print(line1 + "Program goes on ..." + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set, batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True, metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(True)], check_code_level=2) callbacks=[Wait(), ControlC(True)], check_code_level=2)
trainer.train() trainer.train()
print(line1 + "Test failed!" + line2) print(line1 + "Test failed!" + line2)






+ 37
- 0
test/core/test_dataset.py View File

@@ -3,6 +3,7 @@ import sys
import unittest import unittest


from fastNLP import DataSet from fastNLP import DataSet
from fastNLP.core.dataset import ApplyResultException
from fastNLP import FieldArray from fastNLP import FieldArray
from fastNLP import Instance from fastNLP import Instance
from fastNLP.io import CSVLoader from fastNLP.io import CSVLoader
@@ -143,6 +144,42 @@ class TestDataSetMethods(unittest.TestCase):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
ds.apply(modify_inplace) ds.apply(modify_inplace)


def test_apply_more(self):
    """
    Check ``apply_more`` / ``apply_field_more``: results are written to the
    DataSet by default, left untouched with ``modify_fields=False``, and an
    ``ApplyResultException`` is raised when the returned fields differ
    between instances.
    """
    T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})
    func_1 = lambda x: {"c": x["a"] * 2, "d": x["a"] ** 2}
    func_2 = lambda x: {"c": x * 3, "d": x ** 3}

    def func_err_1(x):
        # Returns two fields for the first instance and only one afterwards.
        if x["a"] == 1:
            return {"e": x["a"] * 2, "f": x["a"] ** 2}
        else:
            return {"e": x["a"] * 2}

    def func_err_2(x):
        # Same inconsistency, but receiving the field value directly.
        if x == 1:
            return {"e": x * 2, "f": x ** 2}
        else:
            return {"e": x * 2}

    T.apply_more(func_1)
    self.assertEqual(list(T["c"]), [2, 4, 6])
    self.assertEqual(list(T["d"]), [1, 4, 9])

    # modify_fields=False must leave the existing "c" and "d" fields alone.
    res = T.apply_field_more(func_2, "a", modify_fields=False)
    self.assertEqual(list(T["c"]), [2, 4, 6])
    self.assertEqual(list(T["d"]), [1, 4, 9])
    self.assertEqual(list(res["c"]), [3, 6, 9])
    self.assertEqual(list(res["d"]), [1, 8, 27])

    # NOTE(review): the original indentation was lost; the prints are placed
    # after each with-block so that they are reachable -- confirm.
    with self.assertRaises(ApplyResultException) as e:
        T.apply_more(func_err_1)
    print(e)

    with self.assertRaises(ApplyResultException) as e:
        T.apply_field_more(func_err_2, "a")
    print(e)

def test_drop(self): def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)


+ 37
- 16
test/io/pipe/test_classification.py View File

@@ -36,16 +36,37 @@ class TestCNClassificationPipe(unittest.TestCase):
class TestRunClassificationPipe(unittest.TestCase): class TestRunClassificationPipe(unittest.TestCase):
def test_process_from_file(self): def test_process_from_file(self):
data_set_dict = { data_set_dict = {
'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, (6, 6, 6), (1176, 2), False),
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1166, 5), False),
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True),
'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False),
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False),
'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, (5, 4), (257, 4), False),
'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, (5, 14), (496, 14), False),
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False),
'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, (9, 9, 9), (1864, 9), False),
'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, (7, 6, 6), (452, 2), False),
'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
False),
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe,
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
False),
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe,
{'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
True),
'sst': ('test/data_for_tests/io/SST', SSTPipe,
{'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
False),
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe,
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
False),
'ag': ('test/data_for_tests/io/ag', AGsNewsPipe,
{'train': 4, 'test': 5}, {'words': 257, 'target': 4},
False),
'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe,
{'train': 14, 'test': 5}, {'words': 496, 'target': 14},
False),
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
{'train': 6, 'dev': 6, 'test': 6},
{'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
False),
'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe,
{'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
False),
'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
{'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
False),
} }
for k, v in data_set_dict.items(): for k, v in data_set_dict.items():
path, pipe, data_set, vocab, warns = v path, pipe, data_set, vocab, warns = v
@@ -61,12 +82,12 @@ class TestRunClassificationPipe(unittest.TestCase):


self.assertTrue(isinstance(data_bundle, DataBundle)) self.assertTrue(isinstance(data_bundle, DataBundle))
self.assertEqual(len(data_set), data_bundle.num_dataset) self.assertEqual(len(data_set), data_bundle.num_dataset)
for x, y in zip(data_set, data_bundle.iter_datasets()):
name, dataset = y
self.assertEqual(x, len(dataset))
for name, dataset in data_bundle.iter_datasets():
self.assertTrue(name in data_set.keys())
self.assertEqual(data_set[name], len(dataset))


self.assertEqual(len(vocab), data_bundle.num_vocab) self.assertEqual(len(vocab), data_bundle.num_vocab)
for x, y in zip(vocab, data_bundle.iter_vocabs()):
name, vocabs = y
self.assertEqual(x, len(vocabs))
for name, vocabs in data_bundle.iter_vocabs():
self.assertTrue(name in vocab.keys())
self.assertEqual(vocab[name], len(vocabs))



Loading…
Cancel
Save