From 04ad8e604e3372ee1b881d3993d8affb65543a61 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 5 Jan 2021 23:45:15 +0800 Subject: [PATCH] =?UTF-8?q?DataSet=20apply=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=E5=8F=AF=E4=BB=A5=E4=BC=A0=E5=85=A5use=5Ftqdm=E5=92=8Ctqdm=5Fd?= =?UTF-8?q?esc?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 29 +++++++++++++++++++++++++++-- fastNLP/io/data_bundle.py | 29 +++++++++++++++++++++++++++++ tests/core/test_dataset.py | 8 ++++++++ 3 files changed, 64 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index ec64d484..9005a8d6 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -371,6 +371,10 @@ from .field import SetInputOrTargetException from .instance import Instance from .utils import pretty_table_printer from .collate_fn import Collater +try: + from tqdm.auto import tqdm +except: + from .utils import _pseudo_tqdm as tqdm class ApplyResultException(Exception): @@ -860,6 +864,11 @@ class DataSet(object): 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 + + 4. use_tqdm: bool, 是否使用tqdm显示预处理进度 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 + :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 """ assert len(self) != 0, "Null DataSet cannot use apply_field()." @@ -887,6 +896,10 @@ class DataSet(object): 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 + 4. use_tqdm: bool, 是否使用tqdm显示预处理进度 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 + :return Dict[str:Field]: 返回一个字典 """ assert len(self) != 0, "Null DataSet cannot use apply_field()." @@ -949,6 +962,10 @@ class DataSet(object): 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 + 4. use_tqdm: bool, 是否使用tqdm显示预处理进度 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 + :return Dict[str:Field]: 返回一个字典 """ # 返回 dict , 检查是否一直相同 @@ -957,7 +974,9 @@ class DataSet(object): idx = -1 try: results = {} - for idx, ins in enumerate(self._inner_iter()): + for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, + desc=kwargs.get('tqdm_desc', ''), + leave=False, disable=not kwargs.get('use_tqdm', False)): if "_apply_field" in kwargs: res = func(ins[kwargs["_apply_field"]]) else: @@ -1001,6 +1020,10 @@ class DataSet(object): 2. is_target: bool, 如果为True则将 `new_field_name` 的field设置为target 3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 + + 4. use_tqdm: bool, 是否使用tqdm显示预处理进度 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 """ @@ -1009,7 +1032,9 @@ class DataSet(object): idx = -1 try: results = [] - for idx, ins in enumerate(self._inner_iter()): + for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, leave=False, + desc=kwargs.get('tqdm_desc', ''), + disable=not kwargs.get('use_tqdm', False)): if "_apply_field" in kwargs: results.append(func(ins[kwargs["_apply_field"]])) else: diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index e911a26f..8528ebf8 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -321,8 +321,15 @@ class DataBundle: 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 + + 4. use_tqdm: bool, 是否显示tqdm进度条 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 """ + tqdm_desc = kwargs.get('tqdm_desc', '') for name, dataset in self.datasets.items(): + if tqdm_desc != '': + kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' if dataset.has_field(field_name=field_name): dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs) elif not ignore_miss_dataset: @@ -350,10 +357,17 @@ class DataBundle: 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 + 4. use_tqdm: bool, 是否显示tqdm进度条 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 """ res = {} + tqdm_desc = kwargs.get('tqdm_desc', '') for name, dataset in self.datasets.items(): + if tqdm_desc != '': + kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' if dataset.has_field(field_name=field_name): res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs) elif not ignore_miss_dataset: @@ -376,8 +390,16 @@ class DataBundle: 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 + + 4. use_tqdm: bool, 是否显示tqdm进度条 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 + """ + tqdm_desc = kwargs.get('tqdm_desc', '') for name, dataset in self.datasets.items(): + if tqdm_desc != '': + kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' dataset.apply(func, new_field_name=new_field_name, **kwargs) return self @@ -399,10 +421,17 @@ class DataBundle: 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 + 4. use_tqdm: bool, 是否显示tqdm进度条 + + 5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 """ res = {} + tqdm_desc = kwargs.get('tqdm_desc', '') for name, dataset in self.datasets.items(): + if tqdm_desc!='': + kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs) return res diff --git a/tests/core/test_dataset.py b/tests/core/test_dataset.py index d0d08d97..e960dac7 100644 --- a/tests/core/test_dataset.py +++ b/tests/core/test_dataset.py @@ -136,6 +136,14 @@ class TestDataSetMethods(unittest.TestCase): ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True) # expect no exception raised + def test_apply_tqdm(self): + import time + ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) + def do_nothing(ins): + time.sleep(0.01) + ds.apply(do_nothing, use_tqdm=True) + ds.apply_field(do_nothing, field_name='x', use_tqdm=True) + def test_apply_cannot_modify_instance(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) def modify_inplace(instance):