Browse Source

Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.5.5
yh_cc 5 years ago
parent
commit
dce1a73f53
8 changed files with 224 additions and 56 deletions
  1. +18
    -9
      fastNLP/core/batch.py
  2. +6
    -0
      fastNLP/core/collect_fn.py
  3. +105
    -22
      fastNLP/core/dataset.py
  4. +12
    -0
      fastNLP/modules/encoder/bert.py
  5. +1
    -1
      reproduction/multi-criteria-cws/data-prepare.py
  6. +8
    -8
      test/core/test_callbacks.py
  7. +37
    -0
      test/core/test_dataset.py
  8. +37
    -16
      test/io/pipe/test_classification.py

+ 18
- 9
fastNLP/core/batch.py View File

@@ -69,13 +69,20 @@ class DataSetGetter:

def may_to_tensor(data):
dtype, dim = _get_ele_type_and_dim(data)
print(dtype, type(dtype))
# print(dtype, type(dtype), str(dtype))
if not self.as_numpy:
try:
data, flag = _to_tensor(data, dtype)
except TypeError as e:
logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e
# if torch.is_tensor(data):
# str_dtype = str(dtype)
# if 'float' in str_dtype:
# data = data.float()
# elif 'int' in str_dtype:
# data = data.long()
# print(data.dtype)
return data

def pad(batch_dict):
@@ -293,14 +300,16 @@ def _to_tensor(batch, field_dtype):
if field_dtype is not None and isinstance(field_dtype, type)\
and issubclass(field_dtype, Number) \
and not isinstance(batch, torch.Tensor):
if issubclass(field_dtype, np.floating):
new_batch = torch.as_tensor(batch).float() # 默认使用float32
elif issubclass(field_dtype, np.integer):
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制
else:
new_batch = torch.as_tensor(batch)
return new_batch, True
new_batch = torch.as_tensor(batch)
flag = True
else:
return batch, False
new_batch = batch
flag = False
if torch.is_tensor(new_batch):
if 'float' in new_batch.dtype.__repr__():
new_batch = new_batch.float()
elif 'int' in new_batch.dtype.__repr__():
new_batch = new_batch.long()
return new_batch, flag
except Exception as e:
raise e

+ 6
- 0
fastNLP/core/collect_fn.py View File

@@ -118,6 +118,12 @@ class Collector:
def outputs(self):
return self.output2fn.keys()

def copy_from(self, col):
    """Take over the collect-function registry of another ``Collector``.

    ``fns``, ``input2fn`` and ``output2fn`` are each replaced by the result
    of calling ``.copy()`` on the matching attribute of ``col``
    (presumably shallow dict copies -- confirm against ``Collector.__init__``),
    after which ``self._clear_fn2io()`` is invoked.

    :param Collector col: the collector whose registrations are copied.
    """
    assert isinstance(col, Collector)
    # Same order as the attributes are listed: fns, input2fn, output2fn.
    for attr_name in ("fns", "input2fn", "output2fn"):
        setattr(self, attr_name, getattr(col, attr_name).copy())
    self._clear_fn2io()

class CollectFn:
def __init__(self):


+ 105
- 22
fastNLP/core/dataset.py View File

@@ -284,7 +284,7 @@

"""
__all__ = [
"DataSet"
"DataSet",
]

import _pickle as pickle
@@ -305,6 +305,12 @@ from .utils import pretty_table_printer
from .collect_fn import Collector


class ApplyResultException(Exception):
    """Error describing a bad result produced while applying a function.

    :param str msg: human-readable description; also passed to ``Exception``.
    :param index: position of the instance at which the problem was hit,
        or ``None`` when no position is given.
    """

    def __init__(self, msg, index=None):
        super().__init__(msg)
        # Marks which data item ran into the problem.
        self.index = index
        self.msg = msg
class DataSet(object):
"""
fastNLP的数据容器,详细的使用方法见文档 :mod:`fastNLP.core.dataset`
@@ -569,6 +575,7 @@ class DataSet(object):
:param str field_name: 需要删除的field的名称.
"""
self.field_arrays.pop(field_name)
self.collector.drop_field(field_name)
return self

def copy_field(self, field_name, new_field_name):
@@ -641,6 +648,7 @@ class DataSet(object):
if field_name in self.field_arrays:
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
self.field_arrays[new_field_name].name = new_field_name
self.collector.rename_field(field_name, new_field_name)
else:
raise KeyError("DataSet has no field named {}.".format(field_name))
return self
@@ -778,23 +786,35 @@ class DataSet(object):
assert len(self) != 0, "Null DataSet cannot use apply_field()."
if field_name not in self:
raise KeyError("DataSet has no field named `{}`.".format(field_name))
results = []
idx = -1
try:
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)

if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)
def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
    """
    Pass the value of the field `field_name` of every ``Instance`` in the ``DataSet`` to ``func`` and collect
    its return values. ``func`` may return results for one or several fields.

    .. note::
        For the difference between ``apply_field_more`` and ``apply_field`` see the description of
        ``apply_more`` versus ``apply`` in :meth:`~fastNLP.DataSet.apply_more`.

    :param callable func: receives the value of `field_name` for one instance and returns a dict whose keys
        are field names and whose values are the corresponding results
    :param str field_name: which field is handed to ``func``.
    :param bool modify_fields: whether the results are written back into the ``DataSet``, default True
    :param optional kwargs: forwarded to :meth:`apply_more`; is_input, is_target and ignore_type are supported

        1. is_input: bool, if True the modified fields are set as input

        2. is_target: bool, if True the modified fields are set as target

        3. ignore_type: bool, if True the modified fields get ignore_type set to true, ignoring their type

    :return: whatever :meth:`apply_more` returns for this call
    """
    assert len(self) != 0, "Null DataSet cannot use apply_field()."
    if field_name in self:
        # Delegate; `_apply_field` tells apply_more to hand only this field's value to func.
        return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
    raise KeyError(f"DataSet has no field named `{field_name}`.")
def _add_apply_field(self, results, new_field_name, kwargs):
"""
将results作为加入到新的field中,field名称为new_field_name
@@ -827,12 +847,73 @@ class DataSet(object):
is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False))

def apply_more(self, func, modify_fields=True, **kwargs):
    """
    Pass every ``Instance`` of the ``DataSet`` to ``func`` and collect its return values. ``func`` may return
    results for one or several fields.

    .. note::
        Differences between ``apply_more`` and ``apply``:

        1. ``apply_more`` can return results for several fields, ``apply`` only for a single field;
        2. the return value of ``apply_more`` is a dict in which each key is a field name and each value is
           the computed result;
        3. ``apply_more`` modifies the fields of the ``DataSet`` by default, ``apply`` does not.

    :param callable func: receives an ``Instance`` of the ``DataSet`` (or, when the keyword ``_apply_field`` is
        given, only the value of that field) and must return a dict mapping field names to results
    :param bool modify_fields: whether the results are written back into the ``DataSet``, default True
    :param optional kwargs: is_input, is_target and ignore_type are supported

        1. is_input: bool, if True the modified fields are set as input

        2. is_target: bool, if True the modified fields are set as target

        3. ignore_type: bool, if True the modified fields get ignore_type set to true, ignoring their type

    :return dict: maps every key returned by ``func`` to the list of values produced for the instances, in
        iteration order
    :raises ApplyResultException: if ``func`` returns something that is not a dict, or if the set of keys
        differs from the one returned for the first instance
    """
    # The result is a dict; verify that its keys stay the same for every instance.
    assert callable(func), "The func you provide is not callable."
    assert len(self) != 0, "Null DataSet cannot use apply()."
    position = -1
    try:
        collected = {}
        for position, instance in enumerate(self._inner_iter()):
            argument = instance[kwargs["_apply_field"]] if "_apply_field" in kwargs else instance
            outcome = func(argument)
            if not isinstance(outcome, dict):
                raise ApplyResultException("The result of func is not a dict", position)
            if position == 0:
                # The first instance defines which fields are expected.
                collected = {name: [value] for name, value in outcome.items()}
                continue
            for name, value in outcome.items():
                if name not in collected:
                    raise ApplyResultException("apply results have different fields", position)
                collected[name].append(value)
            # No unknown key was seen above, so a size mismatch means a key is missing.
            if len(outcome) != len(collected):
                raise ApplyResultException("apply results have different fields", position)
    except Exception as err:
        if position != -1:
            if isinstance(err, ApplyResultException):
                logger.error(err.msg)
            logger.error("Exception happens at the `{}`th instance.".format(position))
        raise err
    if modify_fields is True:
        for name, values in collected.items():
            self._add_apply_field(values, name, kwargs)
    return collected

def apply(self, func, new_field_name=None, **kwargs):
"""
将DataSet中每个instance传入到func中,并获取它的返回值.

:param callable func: 参数是DataSet中的Instance
:param None,str new_field_name: 将func返回的内容放入到new_field_name这个field中,如果名称与已有的field相同,则覆
:param callable func: 参数是 ``DataSet`` 中的 ``Instance``
:param None,str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。
:param optional kwargs: 支持输入is_input,is_target,ignore_type

@@ -844,21 +925,21 @@ class DataSet(object):
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
"""
assert callable(func), "The func you provide is not callable."
assert len(self) != 0, "Null DataSet cannot use apply()."
idx = -1
try:
results = []
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
if "_apply_field" in kwargs:
results.append(func(ins[kwargs["_apply_field"]]))
else:
results.append(func(ins))
except BaseException as e:
if idx != -1:
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e

# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))

if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)

@@ -933,6 +1014,8 @@ class DataSet(object):
train_set.field_arrays[field_name].to(self.field_arrays[field_name])
dev_set.field_arrays[field_name].to(self.field_arrays[field_name])

train_set.collector.copy_from(self.collector)
dev_set.collector.copy_from(self.collector)
return train_set, dev_set

def save(self, path):


+ 12
- 0
fastNLP/modules/encoder/bert.py View File

@@ -538,6 +538,18 @@ class BertModel(nn.Module):
raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')

model_type = 'BERT'
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'bert' not in key:
new_key = 'bert.' + key
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

old_keys = []
new_keys = []
for key in state_dict.keys():


+ 1
- 1
reproduction/multi-criteria-cws/data-prepare.py View File

@@ -51,7 +51,7 @@ def preprocess(text):
def to_sentence_list(text, split_long_sentence=False):
text = preprocess(text)
delimiter = set()
delimiter.update("。!?:;…、,(),;!?、,\"'")
delimiter.update("。!?:;…、,(),;!?、.\"'")
delimiter.add("……")
sent_list = []
sent = []


+ 8
- 8
test/core/test_callbacks.py View File

@@ -226,36 +226,36 @@ class TestCallback(unittest.TestCase):
callbacks=EarlyStopCallback(1), check_code_level=2)
trainer.train()

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_control_C():
# 用于测试 ControlC , 再两次训练时用 Control+C 进行退出,如果最后不显示 "Test failed!" 则通过测试
from fastNLP import ControlC, Callback
import time
line1 = "\n\n\n\n\n*************************"
line2 = "*************************\n\n\n\n\n"
class Wait(Callback):
def on_epoch_end(self):
time.sleep(5)
data_set, model = prepare_env()
print(line1 + "Test starts!" + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(False)], check_code_level=2)
trainer.train()
print(line1 + "Program goes on ..." + line2)
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
callbacks=[Wait(), ControlC(True)], check_code_level=2)
trainer.train()
print(line1 + "Test failed!" + line2)




+ 37
- 0
test/core/test_dataset.py View File

@@ -3,6 +3,7 @@ import sys
import unittest

from fastNLP import DataSet
from fastNLP.core.dataset import ApplyResultException
from fastNLP import FieldArray
from fastNLP import Instance
from fastNLP.io import CSVLoader
@@ -143,6 +144,42 @@ class TestDataSetMethods(unittest.TestCase):
with self.assertRaises(TypeError):
ds.apply(modify_inplace)

def test_apply_more(self):
    """Check ``apply_more`` / ``apply_field_more`` on success and on mismatching result keys."""
    T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})

    def func_1(x):
        # Receives the whole instance.
        return {"c": x["a"] * 2, "d": x["a"] ** 2}

    def func_2(x):
        # Receives only the value of field "a".
        return {"c": x * 3, "d": x ** 3}

    def func_err_1(x):
        # Returns two keys for the first instance and only one afterwards.
        if x["a"] == 1:
            return {"e": x["a"] * 2, "f": x["a"] ** 2}
        else:
            return {"e": x["a"] * 2}

    def func_err_2(x):
        # Same mismatch as func_err_1, but operating on the raw field value.
        if x == 1:
            return {"e": x * 2, "f": x ** 2}
        else:
            return {"e": x * 2}

    T.apply_more(func_1)
    self.assertEqual(list(T["c"]), [2, 4, 6])
    self.assertEqual(list(T["d"]), [1, 4, 9])

    # With modify_fields=False the DataSet must stay untouched.
    res = T.apply_field_more(func_2, "a", modify_fields=False)
    self.assertEqual(list(T["c"]), [2, 4, 6])
    self.assertEqual(list(T["d"]), [1, 4, 9])
    self.assertEqual(list(res["c"]), [3, 6, 9])
    self.assertEqual(list(res["d"]), [1, 8, 27])

    # The key mismatch first shows up at the second instance (0-based index 1).
    with self.assertRaises(ApplyResultException) as ctx:
        T.apply_more(func_err_1)
    self.assertEqual(ctx.exception.index, 1)

    with self.assertRaises(ApplyResultException) as ctx:
        T.apply_field_more(func_err_2, "a")
    self.assertEqual(ctx.exception.index, 1)

def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)


+ 37
- 16
test/io/pipe/test_classification.py View File

@@ -36,16 +36,37 @@ class TestCNClassificationPipe(unittest.TestCase):
class TestRunClassificationPipe(unittest.TestCase):
def test_process_from_file(self):
data_set_dict = {
'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, (6, 6, 6), (1176, 2), False),
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1166, 5), False),
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True),
'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False),
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False),
'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, (5, 4), (257, 4), False),
'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, (5, 14), (496, 14), False),
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False),
'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, (9, 9, 9), (1864, 9), False),
'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, (7, 6, 6), (452, 2), False),
'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
False),
'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe,
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
False),
'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe,
{'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
True),
'sst': ('test/data_for_tests/io/SST', SSTPipe,
{'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
False),
'imdb': ('test/data_for_tests/io/imdb', IMDBPipe,
{'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
False),
'ag': ('test/data_for_tests/io/ag', AGsNewsPipe,
{'train': 4, 'test': 5}, {'words': 257, 'target': 4},
False),
'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe,
{'train': 14, 'test': 5}, {'words': 496, 'target': 14},
False),
'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
{'train': 6, 'dev': 6, 'test': 6},
{'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
False),
'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe,
{'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
False),
'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
{'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
False),
}
for k, v in data_set_dict.items():
path, pipe, data_set, vocab, warns = v
@@ -61,12 +82,12 @@ class TestRunClassificationPipe(unittest.TestCase):

self.assertTrue(isinstance(data_bundle, DataBundle))
self.assertEqual(len(data_set), data_bundle.num_dataset)
for x, y in zip(data_set, data_bundle.iter_datasets()):
name, dataset = y
self.assertEqual(x, len(dataset))
for name, dataset in data_bundle.iter_datasets():
self.assertTrue(name in data_set.keys())
self.assertEqual(data_set[name], len(dataset))

self.assertEqual(len(vocab), data_bundle.num_vocab)
for x, y in zip(vocab, data_bundle.iter_vocabs()):
name, vocabs = y
self.assertEqual(x, len(vocabs))
for name, vocabs in data_bundle.iter_vocabs():
self.assertTrue(name in vocab.keys())
self.assertEqual(vocab[name], len(vocabs))


Loading…
Cancel
Save