@@ -111,6 +111,7 @@ class Evaluator: | |||
分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集; | |||
* *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数; | |||
* *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数; | |||
* *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。 | |||
""" | |||
@@ -134,6 +135,8 @@ class Evaluator: | |||
self.device = device | |||
self.verbose = verbose | |||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) | |||
if evaluate_batch_step_fn is not None: | |||
_check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn') | |||
self.evaluate_batch_step_fn = evaluate_batch_step_fn | |||
@@ -141,10 +144,23 @@ class Evaluator: | |||
self.input_mapping = input_mapping | |||
self.output_mapping = output_mapping | |||
# check dataloader | |||
if not isinstance(dataloaders, dict): | |||
if kwargs.get('check_dataloader_legality', True): | |||
try: | |||
self.driver.check_dataloader_legality(dataloader=dataloaders) | |||
except TypeError as e: | |||
logger.error("`dataloaders` is invalid.") | |||
raise e | |||
dataloaders = {None: dataloaders} | |||
self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) | |||
else: | |||
if kwargs.get('check_dataloader_legality', True): | |||
for key, dataloader in dataloaders.items(): | |||
try: | |||
self.driver.check_dataloader_legality(dataloader=dataloader) | |||
except TypeError as e: | |||
logger.error(f"The dataloader named:{key} is invalid.") | |||
raise e | |||
self.driver.setup() | |||
self.driver.barrier() | |||
@@ -333,7 +349,7 @@ class Evaluator: | |||
@evaluate_batch_loop.setter | |||
def evaluate_batch_loop(self, loop: Loop): | |||
if self.evaluate_batch_step_fn is not None: | |||
if getattr(self, 'evaluate_step_fn', None) is not None: | |||
logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " | |||
"when the `evaluate_batch_loop` is also customized.") | |||
self._evaluate_batch_loop = loop | |||
@@ -304,6 +304,7 @@ class Trainer(TrainerEventTrigger): | |||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 | |||
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 | |||
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Evaluator`` 中。与 output_mapping 互斥。 | |||
* *check_dataloader_legality* -- 是否检查 ``DataLoader`` 是否合法,默认为 ``True`` 。 | |||
.. note:: | |||
``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; | |||
@@ -463,6 +464,14 @@ class Trainer(TrainerEventTrigger): | |||
self.driver.setup() | |||
self.driver.barrier() | |||
# check train_dataloader | |||
if kwargs.get('check_dataloader_legality', True): | |||
try: | |||
self.driver.check_dataloader_legality(dataloader=train_dataloader) | |||
except TypeError as e: | |||
logger.error("`train_dataloader` is invalid.") | |||
raise e | |||
use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) | |||
if use_dist_sampler: | |||
_dist_sampler = "dist" | |||
@@ -482,7 +491,8 @@ class Trainer(TrainerEventTrigger): | |||
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, | |||
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, | |||
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), | |||
progress_bar=progress_bar) | |||
progress_bar=progress_bar, | |||
check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) | |||
if train_fn is not None and not isinstance(train_fn, str): | |||
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") | |||
@@ -10,8 +10,7 @@ __all__ = [ | |||
"prepare_dataloader" | |||
] | |||
from .mix_dataloader import MixDataLoader | |||
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | |||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | |||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader | |||
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader | |||
from .prepare_dataloader import prepare_dataloader |
@@ -1,6 +1,8 @@ | |||
__all__ = [ | |||
"TorchDataLoader", | |||
"prepare_torch_dataloader" | |||
"prepare_torch_dataloader", | |||
"MixDataLoader" | |||
] | |||
from .fdl import TorchDataLoader, prepare_torch_dataloader | |||
from .mix_dataloader import MixDataLoader |
@@ -168,6 +168,7 @@ from fastNLP.core.collators import Collator | |||
from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress | |||
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress | |||
from ..log import logger | |||
from fastNLP.core.utils.dummy_class import DummyClass | |||
progress_bars = { | |||
@@ -231,7 +232,8 @@ def _multi_proc(ds, _apply_field, func, counter, queue): | |||
""" | |||
idx = -1 | |||
import contextlib | |||
with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁 | |||
null = DummyClass() | |||
with contextlib.redirect_stdout(null): # 避免打印触发 rich 的锁 | |||
logger.set_stdout(stdout='raw') | |||
results = [] | |||
try: | |||
@@ -597,7 +599,8 @@ class DataSet: | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。 | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_desc: 进度条的描述字符,默认为 ``Processing``; | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
@@ -631,8 +634,14 @@ class DataSet: | |||
:param field_name: 传入func的是哪个field。 | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 | |||
:return Dict[str:Field]: 返回一个字典 | |||
""" | |||
@@ -672,7 +681,13 @@ class DataSet: | |||
progress_bar: str = 'rich', _apply_field: str = None, | |||
progress_desc: str = 'Main') -> list: | |||
""" | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | |||
:param _apply_field: 需要传进去func的数据集的field_name | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
@@ -744,7 +759,13 @@ class DataSet: | |||
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:return Dict[str:Field]: 返回一个字典 | |||
@@ -789,7 +810,13 @@ class DataSet: | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: progress bar 显示的值,默认为空。 | |||
""" | |||
@@ -175,6 +175,14 @@ class Driver(ABC): | |||
raise NotImplementedError( | |||
"Each specific driver should implemented its own `_check_optimizer_legality` function.") | |||
def check_dataloader_legality(self, dataloader): | |||
""" | |||
检测 DataLoader 是否合法,如果不合法,会 raise TypeError 。 | |||
:param dataloder: | |||
:return: | |||
""" | |||
def set_optimizers(self, optimizers=None): | |||
r""" | |||
trainer 会调用该函数将用户传入的 optimizers 挂载到 driver 实例上; | |||
@@ -13,6 +13,7 @@ if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from jittor import Module | |||
from jittor.optim import Optimizer | |||
from jittor.dataset import Dataset | |||
_reduces = { | |||
'max': jt.max, | |||
@@ -52,21 +53,11 @@ class JittorDriver(Driver): | |||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||
@staticmethod | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
def check_dataloader_legality(self, dataloader): | |||
# 在fastnlp中实现了JittorDataLoader | |||
# TODO: 是否允许传入Dataset? | |||
if is_train: | |||
if not isinstance(dataloader, JittorDataLoader): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.") | |||
else: | |||
if not isinstance(dataloader, Dict): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||
else: | |||
for each_dataloader in dataloader.values(): | |||
if not isinstance(each_dataloader, JittorDataLoader): | |||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' " | |||
f"type, not {type(each_dataloader)}.") | |||
if not isinstance(dataloader, Dataset): | |||
raise TypeError(f"{Dataset} is expected, instead of `{type(dataloader)}`") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
@@ -94,29 +94,9 @@ class PaddleDriver(Driver): | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
@staticmethod | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
if is_train: | |||
if not isinstance(dataloader, DataLoader): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'paddle.io.DataLoader' type, not {type(dataloader)}.") | |||
# TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; | |||
if isinstance(dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
if dataloader.batch_sampler is None and dataloader.batch_size is None: | |||
raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") | |||
else: | |||
if not isinstance(dataloader, Dict): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||
else: | |||
for each_dataloader in dataloader.values(): | |||
if not isinstance(each_dataloader, DataLoader): | |||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'paddle.io.DataLoader' " | |||
f"type, not {type(each_dataloader)}.") | |||
if isinstance(each_dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
if each_dataloader.batch_sampler is None and each_dataloader.batch_size is None: | |||
raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " | |||
f"`batch_sampler` and `batch_size` should be set.") | |||
def check_dataloader_legality(self, dataloader): | |||
if not isinstance(dataloader, DataLoader): | |||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
@@ -91,26 +91,9 @@ class TorchDriver(Driver): | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
@staticmethod | |||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
if is_train: | |||
if not isinstance(dataloader, DataLoader): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.") | |||
# todo 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; | |||
if isinstance(dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
else: | |||
if not isinstance(dataloader, Dict): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||
else: | |||
for each_dataloader in dataloader.values(): | |||
if not isinstance(each_dataloader, DataLoader): | |||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'DataLoader' " | |||
f"type, not {type(each_dataloader)}.") | |||
if isinstance(each_dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
def check_dataloader_legality(self, dataloader): | |||
if not isinstance(dataloader, DataLoader): | |||
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
@@ -3,3 +3,9 @@ __all__ = [] | |||
class DummyClass: | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
def __getattr__(self, item): | |||
return lambda *args, **kwargs: ... | |||
def __call__(self, *args, **kwargs): | |||
pass |
@@ -4,7 +4,8 @@ __all__ = [ | |||
import uuid | |||
import sys | |||
from ...envs.imports import _module_available, _compare_version | |||
from ...envs.utils import _module_available, _compare_version, _get_version | |||
from ...envs import get_global_rank | |||
from .utils import is_notebook | |||
from ..log import logger | |||
@@ -82,8 +83,10 @@ class TqdmProgress(metaclass=Singleton): | |||
:param kwargs: | |||
:return: | |||
""" | |||
assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \ | |||
f"To use tqdm, tqdm>=4.57 is needed." | |||
if not _module_available('tqdm'): | |||
raise ModuleNotFoundError("Package tqdm is not installed.") | |||
elif not _compare_version('tqdm', operator.ge, '4.57'): | |||
raise RuntimeError(f"Package tqdm>=4.57 is needed, instead of {_get_version('tqdm')}.") | |||
from .rich_progress import f_rich_progress | |||
assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop." | |||
@@ -26,6 +26,25 @@ def _module_available(module_path: str) -> bool: | |||
return False | |||
def _get_version(package, use_base_version: bool = False): | |||
try: | |||
pkg = importlib.import_module(package) | |||
except (ModuleNotFoundError, DistributionNotFound): | |||
return False | |||
try: | |||
if hasattr(pkg, "__version__"): | |||
pkg_version = Version(pkg.__version__) | |||
else: | |||
# try pkg_resources to infer version | |||
pkg_version = Version(pkg_resources.get_distribution(package).version) | |||
except TypeError: | |||
# this is mocked by Sphinx, so it should return True to generate all summaries | |||
return True | |||
if use_base_version: | |||
pkg_version = Version(pkg_version.base_version) | |||
return pkg_version | |||
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: | |||
"""Compare package version with some requirements. | |||
@@ -231,7 +231,13 @@ class DataBundle: | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | |||
:param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
@@ -260,10 +266,16 @@ class DataBundle: | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param str field_name: 传入func的是哪个field。 | |||
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。 | |||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
@@ -292,8 +304,14 @@ class DataBundle: | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称 | |||
""" | |||
@@ -316,8 +334,14 @@ class DataBundle: | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param num_proc: 使用进程的数量。 | |||
.. note:: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, | |||
``func`` 函数中的打印将不会输出。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称 | |||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
@@ -382,4 +406,3 @@ class DataBundle: | |||
for name, vocab in self.vocabs.items(): | |||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||
return _str | |||
@@ -4,6 +4,7 @@ import pytest | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException | |||
from fastNLP import logger | |||
class TestDataSetInit: | |||
@@ -379,6 +380,14 @@ class TestDataSetMethods: | |||
data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}) | |||
data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=0) | |||
def test_apply_more_proc(self): | |||
def func(x): | |||
print("x") | |||
logger.info("demo") | |||
return len(x) | |||
data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}) | |||
data.apply_field(func, field_name='x', new_field_name='len_x', num_proc=2) | |||
class TestFieldArrayInit: | |||
""" | |||