@@ -371,6 +371,10 @@ from .field import SetInputOrTargetException | |||||
from .instance import Instance | from .instance import Instance | ||||
from .utils import pretty_table_printer | from .utils import pretty_table_printer | ||||
from .collate_fn import Collater | from .collate_fn import Collater | ||||
try: | |||||
from tqdm.auto import tqdm | |||||
except: | |||||
from .utils import _pseudo_tqdm as tqdm | |||||
class ApplyResultException(Exception): | class ApplyResultException(Exception): | ||||
@@ -860,6 +864,11 @@ class DataSet(object): | |||||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | ||||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 | :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 | ||||
""" | """ | ||||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | assert len(self) != 0, "Null DataSet cannot use apply_field()." | ||||
@@ -887,6 +896,10 @@ class DataSet(object): | |||||
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
:return Dict[str:Field]: 返回一个字典 | :return Dict[str:Field]: 返回一个字典 | ||||
""" | """ | ||||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | assert len(self) != 0, "Null DataSet cannot use apply_field()." | ||||
@@ -949,6 +962,10 @@ class DataSet(object): | |||||
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
:return Dict[str:Field]: 返回一个字典 | :return Dict[str:Field]: 返回一个字典 | ||||
""" | """ | ||||
# 返回 dict , 检查是否一直相同 | # 返回 dict , 检查是否一直相同 | ||||
@@ -957,7 +974,9 @@ class DataSet(object): | |||||
idx = -1 | idx = -1 | ||||
try: | try: | ||||
results = {} | results = {} | ||||
for idx, ins in enumerate(self._inner_iter()): | |||||
for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, | |||||
desc=kwargs.get('tqdm_desc', ''), | |||||
leave=False, disable=not kwargs.get('use_tqdm', False)): | |||||
if "_apply_field" in kwargs: | if "_apply_field" in kwargs: | ||||
res = func(ins[kwargs["_apply_field"]]) | res = func(ins[kwargs["_apply_field"]]) | ||||
else: | else: | ||||
@@ -1001,6 +1020,10 @@ class DataSet(object): | |||||
2. is_target: bool, 如果为True则将 `new_field_name` 的field设置为target | 2. is_target: bool, 如果为True则将 `new_field_name` 的field设置为target | ||||
3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 | :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 | ||||
""" | """ | ||||
@@ -1009,7 +1032,9 @@ class DataSet(object): | |||||
idx = -1 | idx = -1 | ||||
try: | try: | ||||
results = [] | results = [] | ||||
for idx, ins in enumerate(self._inner_iter()): | |||||
for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, leave=False, | |||||
desc=kwargs.get('tqdm_desc', ''), | |||||
disable=not kwargs.get('use_tqdm', False)): | |||||
if "_apply_field" in kwargs: | if "_apply_field" in kwargs: | ||||
results.append(func(ins[kwargs["_apply_field"]])) | results.append(func(ins[kwargs["_apply_field"]])) | ||||
else: | else: | ||||
@@ -321,8 +321,15 @@ class DataBundle: | |||||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | ||||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否显示tqdm进度条 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
""" | """ | ||||
tqdm_desc = kwargs.get('tqdm_desc', '') | |||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if tqdm_desc != '': | |||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' | |||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs) | dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs) | ||||
elif not ignore_miss_dataset: | elif not ignore_miss_dataset: | ||||
@@ -350,10 +357,17 @@ class DataBundle: | |||||
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否显示tqdm进度条 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | ||||
""" | """ | ||||
res = {} | res = {} | ||||
tqdm_desc = kwargs.get('tqdm_desc', '') | |||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if tqdm_desc != '': | |||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' | |||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs) | res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs) | ||||
elif not ignore_miss_dataset: | elif not ignore_miss_dataset: | ||||
@@ -376,8 +390,16 @@ class DataBundle: | |||||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | ||||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否显示tqdm进度条 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
""" | """ | ||||
tqdm_desc = kwargs.get('tqdm_desc', '') | |||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if tqdm_desc != '': | |||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' | |||||
dataset.apply(func, new_field_name=new_field_name, **kwargs) | dataset.apply(func, new_field_name=new_field_name, **kwargs) | ||||
return self | return self | ||||
@@ -399,10 +421,17 @@ class DataBundle: | |||||
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 | 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 | ||||
4. use_tqdm: bool, 是否显示tqdm进度条 | |||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称 | |||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | ||||
""" | """ | ||||
res = {} | res = {} | ||||
tqdm_desc = kwargs.get('tqdm_desc', '') | |||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if tqdm_desc!='': | |||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`' | |||||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs) | res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs) | ||||
return res | return res | ||||
@@ -136,6 +136,14 @@ class TestDataSetMethods(unittest.TestCase): | |||||
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True) | ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True) | ||||
# expect no exception raised | # expect no exception raised | ||||
def test_apply_tqdm(self): | |||||
import time | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||||
def do_nothing(ins): | |||||
time.sleep(0.01) | |||||
ds.apply(do_nothing, use_tqdm=True) | |||||
ds.apply_field(do_nothing, field_name='x', use_tqdm=True) | |||||
def test_apply_cannot_modify_instance(self): | def test_apply_cannot_modify_instance(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
def modify_inplace(instance): | def modify_inplace(instance): | ||||