diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 7f0c858b..7090ea01 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -69,13 +69,12 @@ class DataSetGetter:
         def may_to_tensor(data):
             dtype, dim = _get_ele_type_and_dim(data)
-            print(dtype, type(dtype))
             if not self.as_numpy:
                 try:
                     data, flag = _to_tensor(data, dtype)
                 except TypeError as e:
                     logger.error(f"Field {n} cannot be converted to torch.tensor.")
                     raise e
             return data

         def pad(batch_dict):
@@ -293,14 +292,16 @@ def _to_tensor(batch, field_dtype):
         if field_dtype is not None and isinstance(field_dtype, type)\
                 and issubclass(field_dtype, Number) \
                 and not isinstance(batch, torch.Tensor):
-            if issubclass(field_dtype, np.floating):
-                new_batch = torch.as_tensor(batch).float()  # float32 by default
-            elif issubclass(field_dtype, np.integer):
-                new_batch = torch.as_tensor(batch).long()  # reuses the memory, avoiding a copy
-            else:
-                new_batch = torch.as_tensor(batch)
-            return new_batch, True
+            new_batch = torch.as_tensor(batch)
+            flag = True
         else:
-            return batch, False
+            new_batch = batch
+            flag = False
+        if torch.is_tensor(new_batch):
+            if 'float' in str(new_batch.dtype):
+                new_batch = new_batch.float()
+            elif 'int' in str(new_batch.dtype):
+                new_batch = new_batch.long()
+        return new_batch, flag
     except Exception as e:
         raise e
diff --git a/fastNLP/core/collect_fn.py b/fastNLP/core/collect_fn.py
index 14add06f..29f19e2c 100644
--- a/fastNLP/core/collect_fn.py
+++ b/fastNLP/core/collect_fn.py
@@ -118,6 +118,12 @@ class Collector:
     def outputs(self):
         return self.output2fn.keys()

+    def copy_from(self, col):
+        assert isinstance(col, Collector)
+        self.fns = col.fns.copy()
+        self.input2fn = col.input2fn.copy()
+        self.output2fn = col.output2fn.copy()
+        self._clear_fn2io()

 class CollectFn:
     def __init__(self):
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index b13eab76..74c0023c 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -284,7 +284,7 @@
 """
 __all__ = [
-    "DataSet"
+    "DataSet",
 ]

 import _pickle as pickle
@@ -305,6 +305,12 @@ from .utils import pretty_table_printer
 from .collect_fn import Collector

+class ApplyResultException(Exception):
+    def __init__(self, msg, index=None):
+        super().__init__(msg)
+        self.msg = msg
+        self.index = index  # the index of the instance at which the problem occurred
+
 class DataSet(object):
     """
     fastNLP's data container. See the documentation of :mod:`fastNLP.core.dataset` for detailed usage.
@@ -569,6 +575,7 @@ class DataSet(object):
         :param str field_name: the name of the field to delete.
         """
         self.field_arrays.pop(field_name)
+        self.collector.drop_field(field_name)
         return self

     def copy_field(self, field_name, new_field_name):
@@ -641,6 +648,7 @@ class DataSet(object):
         if field_name in self.field_arrays:
             self.field_arrays[new_field_name] = self.field_arrays.pop(field_name)
             self.field_arrays[new_field_name].name = new_field_name
+            self.collector.rename_field(field_name, new_field_name)
         else:
             raise KeyError("DataSet has no field named {}.".format(field_name))
         return self
@@ -778,23 +786,35 @@ class DataSet(object):
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
        if field_name not in self:
            raise KeyError("DataSet has no field named `{}`.".format(field_name))
-        results = []
-        idx = -1
-        try:
-            for idx, ins in enumerate(self._inner_iter()):
-                results.append(func(ins[field_name]))
-        except Exception as e:
-            if idx != -1:
-                logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1))
-            raise e
-        if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
+        return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)

-        if new_field_name is not None:
-            self._add_apply_field(results, new_field_name, kwargs)
+    def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
+        """
+        Pass the value of the field named `field_name` of each ``Instance`` in the ``DataSet`` to func
+        and collect its return value. func may return results for one or more fields.
+
+        .. note::
+            For the difference between ``apply_field_more`` and ``apply_field``, see the notes on the
+            difference between ``apply_more`` and ``apply`` in :meth:`~fastNLP.DataSet.apply_more`.
+
+        :param callable func: receives the value of `field_name` of each ``Instance`` in the ``DataSet``
+            and returns a dict whose keys are field names and whose values are the corresponding results
+        :param str field_name: the field whose values are passed to func.
+        :param bool modify_fields: whether to modify the `Field` s of the `DataSet` with the results; defaults to True
+        :param optional kwargs: supports is_input, is_target and ignore_type

-        return results
+            1. is_input: bool, if True, set the modified fields as input
+
+            2. is_target: bool, if True, set the modified fields as target
+
+            3. ignore_type: bool, if True, set ignore_type of the modified fields to True, ignoring their types
+        :return Dict[str, List[Any]]: a dictionary mapping each returned field name to the list of results
+        """
+        assert len(self) != 0, "Null DataSet cannot use apply_field()."
+        if field_name not in self:
+            raise KeyError("DataSet has no field named `{}`.".format(field_name))
+        return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)
+
     def _add_apply_field(self, results, new_field_name, kwargs):
         """
        Add `results` to the DataSet as a new field named `new_field_name`.
@@ -827,12 +847,73 @@ class DataSet(object):
                              is_target=extra_param.get("is_target", None),
                              ignore_type=extra_param.get("ignore_type", False))

+    def apply_more(self, func, modify_fields=True, **kwargs):
+        """
+        Pass each ``Instance`` in the ``DataSet`` to func and collect its return value. func may return
+        results for one or more fields.
+
+        .. note::
+            Differences between ``apply_more`` and ``apply``:
+
+            1. ``apply_more`` can return results for multiple fields, while ``apply`` can only return results for a single field;
+
+            2. ``apply_more`` returns a dict in which each key is a field name and each value is the corresponding result;
+
+            3. ``apply_more`` modifies the fields of the ``DataSet`` by default, while ``apply`` does not.
+
+        :param callable func: receives an ``Instance`` of the ``DataSet`` and returns a dict whose keys
+            are field names and whose values are the corresponding results
+        :param bool modify_fields: whether to modify the ``Field`` s of the ``DataSet`` with the results; defaults to True
+        :param optional kwargs: supports is_input, is_target and ignore_type
+
+            1. is_input: bool, if True, set the modified fields as input
+
+            2. is_target: bool, if True, set the modified fields as target
+
+            3. ignore_type: bool, if True, set ignore_type of the modified fields to True, ignoring their types
+
+        :return Dict[str, List[Any]]: a dictionary mapping each returned field name to the list of results
+        """
+        # func returns a dict; check that its keys stay the same across all instances
+        assert callable(func), "The func you provide is not callable."
+        assert len(self) != 0, "Null DataSet cannot use apply()."
+        idx = -1
+        try:
+            results = {}
+            for idx, ins in enumerate(self._inner_iter()):
+                if "_apply_field" in kwargs:
+                    res = func(ins[kwargs["_apply_field"]])
+                else:
+                    res = func(ins)
+                if not isinstance(res, dict):
+                    raise ApplyResultException("The result of func is not a dict", idx)
+                if idx == 0:
+                    for key, value in res.items():
+                        results[key] = [value]
+                else:
+                    for key, value in res.items():
+                        if key not in results:
+                            raise ApplyResultException("apply results have different fields", idx)
+                        results[key].append(value)
+                    if len(res) != len(results):
+                        raise ApplyResultException("apply results have different fields", idx)
+        except Exception as e:
+            if idx != -1:
+                if isinstance(e, ApplyResultException):
+                    logger.error(e.msg)
+                logger.error("Exception happens at the `{}`th instance.".format(idx))
+            raise e
+
+        if modify_fields is True:
+            for field, result in results.items():
+                self._add_apply_field(result, field, kwargs)
+
+        return results
+
     def apply(self, func, new_field_name=None, **kwargs):
         """
         Pass each instance of the DataSet to func and collect its return value.

-        :param callable func: the argument is an Instance of the DataSet
-        :param None,str new_field_name: put the content returned by func into the field new_field_name; if the name is the same as that of an existing field, the new content
+        :param callable func: the argument is an ``Instance`` of the ``DataSet``
+        :param None,str new_field_name: put the content returned by func into the field `new_field_name`; if the name is the same as that of an existing field, the new content
             overwrites the existing field. If None, no new field is created.
         :param optional kwargs: supports is_input, is_target and ignore_type

@@ -844,21 +925,21 @@ class DataSet(object):
         :return List[Any]: the elements are the return values of func, so the length of the list equals the length of the DataSet
         """
+        assert callable(func), "The func you provide is not callable."
         assert len(self) != 0, "Null DataSet cannot use apply()."
         idx = -1
         try:
             results = []
             for idx, ins in enumerate(self._inner_iter()):
-                results.append(func(ins))
+                if "_apply_field" in kwargs:
+                    results.append(func(ins[kwargs["_apply_field"]]))
+                else:
+                    results.append(func(ins))
         except BaseException as e:
             if idx != -1:
                 logger.error("Exception happens at the `{}`th instance.".format(idx))
             raise e

-        # results = [func(ins) for ins in self._inner_iter()]
-        if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
-
         if new_field_name is not None:
             self._add_apply_field(results, new_field_name, kwargs)

@@ -933,6 +1014,8 @@ class DataSet(object):
             train_set.field_arrays[field_name].to(self.field_arrays[field_name])
             dev_set.field_arrays[field_name].to(self.field_arrays[field_name])

+        train_set.collector.copy_from(self.collector)
+        dev_set.collector.copy_from(self.collector)
         return train_set, dev_set

     def save(self, path):
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index c3f2fa8b..8f79a59f 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -538,6 +538,18 @@ class BertModel(nn.Module):
                 raise RuntimeError(f'Cannot load parameters through `state_dict` variable.')
         model_type = 'BERT'

+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            if 'bert' not in key:
+                new_key = 'bert.' + key
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
         old_keys = []
         new_keys = []
         for key in state_dict.keys():
diff --git a/reproduction/multi-criteria-cws/data-prepare.py b/reproduction/multi-criteria-cws/data-prepare.py
index 1d6e89b5..2c28e3b6 100644
--- a/reproduction/multi-criteria-cws/data-prepare.py
+++ b/reproduction/multi-criteria-cws/data-prepare.py
@@ -51,7 +51,7 @@ def preprocess(text):
 def to_sentence_list(text, split_long_sentence=False):
     text = preprocess(text)
     delimiter = set()
-    delimiter.update("。!?:;…、,(),;!?、,\"'")
+    delimiter.update("。!?:;…、,(),;!?、.\"'")
     delimiter.add("……")
     sent_list = []
     sent = []
diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py
index 11cf6704..e756040c 100644
--- a/test/core/test_callbacks.py
+++ b/test/core/test_callbacks.py
@@ -226,36 +226,36 @@ class TestCallback(unittest.TestCase):
                           callbacks=EarlyStopCallback(1), check_code_level=2)
         trainer.train()
-
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
 def test_control_C():
     # Tests ControlC: exit with Control+C during the two training runs; the test passes if "Test failed!" is never printed at the end
     from fastNLP import ControlC, Callback
     import time
-
+
     line1 = "\n\n\n\n\n*************************"
     line2 = "*************************\n\n\n\n\n"
-
+
     class Wait(Callback):
         def on_epoch_end(self):
             time.sleep(5)
-
+
     data_set, model = prepare_env()
-
+
     print(line1 + "Test starts!" + line2)
     trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                       batch_size=32, n_epochs=20, dev_data=data_set,
                       metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                       callbacks=[Wait(), ControlC(False)], check_code_level=2)
     trainer.train()
-
+
     print(line1 + "Program goes on ..." + line2)
-
+
     trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
                       batch_size=32, n_epochs=20, dev_data=data_set,
                       metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=True,
                       callbacks=[Wait(), ControlC(True)], check_code_level=2)
     trainer.train()
-
+
     print(line1 + "Test failed!" + line2)
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index e05148a6..d048191f 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -3,6 +3,7 @@ import sys
 import unittest

 from fastNLP import DataSet
+from fastNLP.core.dataset import ApplyResultException
 from fastNLP import FieldArray
 from fastNLP import Instance
 from fastNLP.io import CSVLoader
@@ -143,6 +144,42 @@ class TestDataSetMethods(unittest.TestCase):
         with self.assertRaises(TypeError):
             ds.apply(modify_inplace)

+    def test_apply_more(self):
+
+        T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]})
+        func_1 = lambda x: {"c": x["a"] * 2, "d": x["a"] ** 2}
+        func_2 = lambda x: {"c": x * 3, "d": x ** 3}
+
+        def func_err_1(x):
+            if x["a"] == 1:
+                return {"e": x["a"] * 2, "f": x["a"] ** 2}
+            else:
+                return {"e": x["a"] * 2}
+
+        def func_err_2(x):
+            if x == 1:
+                return {"e": x * 2, "f": x ** 2}
+            else:
+                return {"e": x * 2}
+
+        T.apply_more(func_1)
+        self.assertEqual(list(T["c"]), [2, 4, 6])
+        self.assertEqual(list(T["d"]), [1, 4, 9])
+
+        res = T.apply_field_more(func_2, "a", modify_fields=False)
+        self.assertEqual(list(T["c"]), [2, 4, 6])
+        self.assertEqual(list(T["d"]), [1, 4, 9])
+        self.assertEqual(list(res["c"]), [3, 6, 9])
+        self.assertEqual(list(res["d"]), [1, 8, 27])
+
+        with self.assertRaises(ApplyResultException) as e:
+            T.apply_more(func_err_1)
+        print(e.exception)
+
+        with self.assertRaises(ApplyResultException) as e:
+            T.apply_field_more(func_err_2, "a")
+        print(e.exception)
+
     def test_drop(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
         ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True)
diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py
index 987327d4..e8081590 100644
--- a/test/io/pipe/test_classification.py
+++ b/test/io/pipe/test_classification.py
@@ -36,16 +36,37 @@ class TestCNClassificationPipe(unittest.TestCase):
 class TestRunClassificationPipe(unittest.TestCase):
     def test_process_from_file(self):
         data_set_dict = {
-            'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe, (6, 6, 6), (1176, 2), False),
-            'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe, (6, 6, 6), (1166, 5), False),
-            'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe, (5, 5, 5), (139, 2), True),
-            'sst': ('test/data_for_tests/io/SST', SSTPipe, (6, 354, 6), (232, 5), False),
-            'imdb': ('test/data_for_tests/io/imdb', IMDBPipe, (6, 6, 6), (1670, 2), False),
-            'ag': ('test/data_for_tests/io/ag', AGsNewsPipe, (5, 4), (257, 4), False),
-            'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe, (5, 14), (496, 14), False),
-            'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe, (6, 6, 6), (529, 1296, 1483, 2), False),
-            'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe, (9, 9, 9), (1864, 9), False),
-            'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe, (7, 6, 6), (452, 2), False),
+            'yelp.p': ('test/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
+                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
+                       False),
+            'yelp.f': ('test/data_for_tests/io/yelp_review_full', YelpFullPipe,
+                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
+                       False),
+            'sst-2': ('test/data_for_tests/io/SST-2', SST2Pipe,
+                      {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
+                      True),
+            'sst': ('test/data_for_tests/io/SST', SSTPipe,
+                    {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
+                    False),
+            'imdb': ('test/data_for_tests/io/imdb', IMDBPipe,
+                     {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
+                     False),
+            'ag': ('test/data_for_tests/io/ag', AGsNewsPipe,
+                   {'train': 4, 'test': 5}, {'words': 257, 'target': 4},
+                   False),
+            'dbpedia': ('test/data_for_tests/io/dbpedia', DBPediaPipe,
+                        {'train': 14, 'test': 5}, {'words': 496, 'target': 14},
+                        False),
+            'ChnSentiCorp': ('test/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
+                             {'train': 6, 'dev': 6, 'test': 6},
+                             {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
+                             False),
+            'Chn-THUCNews': ('test/data_for_tests/io/THUCNews', THUCNewsPipe,
+                             {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
+                             False),
+            'Chn-WeiboSenti100k': ('test/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
+                                   {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
+                                   False),
         }
         for k, v in data_set_dict.items():
             path, pipe, data_set, vocab, warns = v
@@ -61,12 +82,12 @@ class TestRunClassificationPipe(unittest.TestCase):
             self.assertTrue(isinstance(data_bundle, DataBundle))
             self.assertEqual(len(data_set), data_bundle.num_dataset)
-            for x, y in zip(data_set, data_bundle.iter_datasets()):
-                name, dataset = y
-                self.assertEqual(x, len(dataset))
+            for name, dataset in data_bundle.iter_datasets():
+                self.assertIn(name, data_set)
+                self.assertEqual(data_set[name], len(dataset))

             self.assertEqual(len(vocab), data_bundle.num_vocab)
-            for x, y in zip(vocab, data_bundle.iter_vocabs()):
-                name, vocabs = y
-                self.assertEqual(x, len(vocabs))
+            for name, vocabs in data_bundle.iter_vocabs():
+                self.assertIn(name, vocab)
+                self.assertEqual(vocab[name], len(vocabs))
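
Usage sketch for the new `DataSet.apply_more` / `apply_field_more` API introduced in this patch. This is a minimal example based on the docstrings and `test_apply_more` above; the field names ("a", "double", "square", "cube") are illustrative only.

    from fastNLP import DataSet

    ds = DataSet({"a": [1, 2, 3]})

    # apply_more: func receives each Instance and returns a dict of per-field results;
    # with the default modify_fields=True, the returned fields are written back into the DataSet.
    ds.apply_more(lambda ins: {"double": ins["a"] * 2, "square": ins["a"] ** 2})

    # apply_field_more: func receives only the value of the named field;
    # modify_fields=False returns the results without touching the DataSet.
    res = ds.apply_field_more(lambda x: {"cube": x ** 3}, field_name="a", modify_fields=False)
    print(res["cube"])  # [1, 8, 27]

If func returns dicts with inconsistent keys across instances (as in `func_err_1` above), both methods raise ApplyResultException, whose `index` attribute records the offending instance.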