Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
92befbecdd
13 changed files with 99 additions and 126 deletions
  1. +3
    -3
      fastNLP/core/callbacks/has_monitor_callback.py
  2. +1
    -1
      fastNLP/core/callbacks/topk_saver.py
  3. +1
    -1
      fastNLP/core/controllers/evaluator.py
  4. +2
    -2
      fastNLP/core/controllers/trainer.py
  5. +72
    -104
      fastNLP/core/dataset/dataset.py
  6. +2
    -2
      fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py
  7. +4
    -4
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  8. +1
    -1
      fastNLP/core/drivers/paddle_driver/utils.py
  9. +2
    -2
      fastNLP/core/drivers/torch_driver/torch_driver.py
  10. +2
    -2
      fastNLP/core/drivers/torch_driver/utils.py
  11. +6
    -2
      fastNLP/core/log/logger.py
  12. +2
    -1
      fastNLP/core/utils/cache_results.py
  13. +1
    -1
      fastNLP/core/utils/utils.py

+ 3
- 3
fastNLP/core/callbacks/has_monitor_callback.py View File

@@ -78,11 +78,11 @@ class MonitorUtility:
return monitor_value
# 第一次运行
if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
logger.rank_zero_warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as "
f"{list(results.keys())}), we use the `{use_monitor}` as the monitor.", once=True)
# 检测到此次和上次不同。
elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
logger.rank_zero_warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
f"The expected monitor is:`{self.monitor}`, last used monitor is:"
f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
f"customized monitor function when the evaluation results are varying between validation.")


+ 1
- 1
fastNLP/core/callbacks/topk_saver.py View File

@@ -33,7 +33,7 @@ class Saver:
:param kwargs: 更多需要传递给 Trainer.save() 或者 Trainer.save_model() 接口的参数。
"""
if folder is None:
logger.warning(
logger.rank_zero_warning(
"Parameter `folder` is None, and we will use the current work directory to find and load your model.")
folder = Path.cwd()
folder = Path(folder)


+ 1
- 1
fastNLP/core/controllers/evaluator.py View File

@@ -221,7 +221,7 @@ class Evaluator:
@evaluate_batch_loop.setter
def evaluate_batch_loop(self, loop: Loop):
if self.evaluate_batch_step_fn is not None:
logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
logger.rank_zero_warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored "
"when the `evaluate_batch_loop` is also customized.")
self._evaluate_batch_loop = loop



+ 2
- 2
fastNLP/core/controllers/trainer.py View File

@@ -305,7 +305,7 @@ class Trainer(TrainerEventTrigger):
else:
if self.driver.is_distributed():
if catch_KeyboardInterrupt:
logger.warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device "
logger.rank_zero_warning("Parameter `catch_KeyboardInterrupt` can only be False when you are using multi-device "
"driver. And we are gonna to set it to False.")
catch_KeyboardInterrupt = False

@@ -535,7 +535,7 @@ class Trainer(TrainerEventTrigger):
_not_called_callback_fns.append(each_callback_fn)

if check_mode:
logger.warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
f"callback_fns: {_not_called_callback_fns}, but it seems that"
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.")
# 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass


+ 72
- 104
fastNLP/core/dataset/dataset.py View File

@@ -9,22 +9,18 @@ __all__ = [
import _pickle as pickle
from copy import deepcopy
from typing import Optional, List, Callable, Union, Dict, Any, Mapping
from functools import partial
from types import LambdaType
import sys
import time

import numpy as np
from threading import Thread

try:
import multiprocessing as mp
except:
pass

from .field import FieldArray
from .instance import Instance
from fastNLP.core.utils.utils import pretty_table_printer
from fastNLP.core.utils.utils import pretty_table_printer, deprecated
from fastNLP.core.collators import Collator
from fastNLP.core.utils.rich_progress import f_rich_progress
from fastNLP.core.log import logger
from ..log import logger


class ApplyResultException(Exception):
@@ -35,14 +31,13 @@ class ApplyResultException(Exception):


def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True,
pipe=None, desc: str = None) -> list:
desc: str = None) -> list:
"""
对数据集进行处理封装函数,以便多进程使用

:param ds: 数据集
:param _apply_field: 需要处理数据集的field_name
:param func: 用户自定义的func
:param pipe: 管道
:param desc: 进度条的描述字符
:param show_progress_bar: 是否展示子进程进度条
:return:
@@ -60,8 +55,6 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
results.append(func(ins[_apply_field]))
else:
results.append(func(ins))
if pipe is not None:
pipe.send([idx + 1])
if show_progress_bar:
f_rich_progress.update(pg_main, advance=1)

@@ -75,31 +68,36 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
return results


def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None:
def _multi_proc(ds, _apply_field, func, counter, queue):
"""
多进程下显示主进程的进度条
对数据集进行处理封装函数,以便多进程使用

:param parent: 进程管道
:param total_len: 数据集总长度
:param desc: 进度条描述符
:param show_progress_bar: 是否展示进度条
:param ds: 数据集
:param _apply_field: 需要处理数据集的field_name
:param func: 用户自定义的func
:param counter: 计数器
:param queue: 多进程时,将结果输入到这个 queue 中
:return:
"""
desc = desc if desc else "Main"

main_pro = f_rich_progress.add_task(description=desc, total=total_len, visible=show_progress_bar)
# pb_main = tqdm(total=total_len, desc=desc, position=0)
nums = 0
while True:
msg = parent.recv()[0]
if msg is not None:
f_rich_progress.update(main_pro, advance=1)
nums += 1

if nums == total_len:
break
f_rich_progress.destroy_task(main_pro)
# pb_main.close()
idx = -1
import contextlib
with contextlib.redirect_stdout(None): # 避免打印触发 rich 的锁
logger.set_stdout(stdout='raw')
results = []
try:
for idx, ins in enumerate(ds):
if _apply_field is not None:
res = func(ins[_apply_field])
else:
res = func(ins)
results.append(res)
with counter.get_lock():
counter.value += 1
except BaseException as e:
if idx != -1:
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e
queue.put(pickle.dumps(results))


class DataSet:
@@ -114,7 +112,7 @@ class DataSet:
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。
"""
self.field_arrays = {}
self._collator = Collator(backend="numpy")
self._collator = Collator()
if data is not None:
if isinstance(data, Dict):
length_set = set()
@@ -127,7 +125,6 @@ class DataSet:
for ins in data:
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
self.append(ins)

else:
raise ValueError("data only be dict or list type.")

@@ -263,7 +260,7 @@ class DataSet:
try:
self.field_arrays[name].append(field)
except Exception as e:
print(f"Cannot append to field:{name}.")
logger.error(f"Cannot append to field:{name}.")
raise e

def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None:
@@ -469,9 +466,7 @@ class DataSet:

except Exception as e:
if idx != -1:
if isinstance(e, ApplyResultException):
print(e.msg)
print("Exception happens at the `{}`th instance.".format(idx + 1))
logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
raise e

if modify_fields is True:
@@ -490,18 +485,19 @@ class DataSet:
:param show_progress_bar: 是否展示progress进度条,默认为展示
:param progress_desc: 进度条的描述字符,默认为'Main
"""
if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>":
raise ("Lambda function does not support multiple processes, please set `num_proc=0`.")
if num_proc>1 and sys.platform in ('win32', 'msys', 'cygwin'):
raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`")

if num_proc == 0:
if num_proc < 2:
results = _apply_single(ds=self, _apply_field=_apply_field, func=func,
desc=progress_desc, show_progress_bar=show_progress_bar)
else:
# TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2
results = []
if num_proc > len(self):
num_proc = len(self)
print(
f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}."
)
import multiprocessing as mp
ctx = mp.get_context('fork')
num_proc = min(num_proc, len(self))
# 划分数据集
shard_len = len(self) // num_proc
num_left_sample = len(self) % num_proc
@@ -511,24 +507,32 @@ class DataSet:
end = shard_len + int(_i<num_left_sample) + start
shard_data.append(self[start:end])
start = end
# 配置管道,线程以实现 main progress 能够实时更新。
parent, child = mp.Pipe()
main_thread = Thread(target=_progress_bar, args=(parent, len(self), progress_desc,
show_progress_bar))
partial_single_map = partial(_apply_single, _apply_field=_apply_field, func=func,
pipe=child, show_progress_bar=False)
# 开启进程池,线程
main_thread.start()
pool = mp.Pool(processes=num_proc)
pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
for proc_id, ds in enumerate(shard_data)]
pool.close()
pool.join()
main_thread.join()

for async_result in pool_outs:
data = async_result.get()
results.extend(data)
# 配置共享参数,线程以实现 main progress 能够实时更新。
counter = ctx.Value('i', 0, lock=True)
pool = []
queues = []
results = []
for i in range(num_proc):
queue = ctx.SimpleQueue()
proc = ctx.Process(target=_multi_proc, args=(shard_data[i], _apply_field, func, counter, queue))
proc.start()
pool.append(proc)
queues.append(queue)

total_len = len(self)
task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar)
last_count = -1
while counter.value < total_len or last_count == -1:
while counter.value == last_count:
time.sleep(0.1)
advance = counter.value - last_count
last_count = counter.value
f_rich_progress.update(task_id, advance=advance, refresh=True)

for idx, proc in enumerate(pool):
results.extend(pickle.loads(queues[idx].get()))
proc.join()
f_rich_progress.destroy_task(task_id)
return results

def apply_more(self, func: Callable = None, modify_fields: bool = True,
@@ -552,8 +556,7 @@ class DataSet:
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称
:return Dict[str:Field]: 返回一个字典
"""
# 返回 dict , 检查是否一直相同
assert callable(func), "The func you provide is not callable."
assert callable(func), "The func is not callable."
assert len(self) != 0, "Null DataSet cannot use apply()."
assert num_proc >= 0, "num_proc must >= 0"
idx = -1
@@ -577,9 +580,7 @@ class DataSet:

except Exception as e:
if idx != -1:
if isinstance(e, ApplyResultException):
print(e.msg)
print("Exception happens at the `{}`th instance.".format(idx + 1))
logger.error("Exception happens at the `{}`th instance.".format(idx + 1))
raise e

if modify_fields is True:
@@ -665,8 +666,7 @@ class DataSet:
np.random.shuffle(all_indices)
split = int(ratio * len(self))
if split == 0:
error_msg = f'Dev DataSet has {split} instance after split.'
print(error_msg)
error_msg = f'Dev DataSet has `{split}` instance after split.'
raise IndexError(error_msg)
dev_indices = all_indices[:split]
train_indices = all_indices[split:]
@@ -776,35 +776,3 @@ class DataSet:
if self._collator is None:
self._collator = Collator()
return self._collator


if __name__ == '__main__':
# from fastNLP import DataSet

# if __name__=='__main__':
# data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
# data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2, show_progress_bar=True)

import multiprocess as mp
# from fastNLP.core.dataset.dataset import _apply_single, _progress_bar
from functools import partial
from threading import Thread

shard_data = [DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100}),
DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})]
parent, chid = mp.Pipe()
partial_single_map = partial(_apply_single, _apply_field='x', func=lambda x: len(x),
pipe=chid, show_progress_bar=False)
thread = Thread(target=_progress_bar, args=(parent, 400, 'main'))
thread.start()
pool = mp.Pool(processes=6)
pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds})
for proc_id, ds in enumerate(shard_data)]
pool.close()
pool.join()
thread.join()
results = []
for async_result in pool_outs:
data = async_result.get()
results.extend(data)
print(results)

+ 2
- 2
fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py View File

@@ -69,7 +69,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
if not isinstance(device, List):
return PaddleSingleDriver(model, device, **kwargs)
else:
logger.warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use"
logger.rank_zero_warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use"
"`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter"
"`driver` as `PaddleFleetDriver`.")
return PaddleFleetDriver(model, device, **kwargs)
@@ -77,7 +77,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
if not isinstance(device, List):
if device == "cpu":
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.")
logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will"
logger.rank_zero_warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will"
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should "
"choose `paddle` driver.")
return PaddleFleetDriver(model, [device], **kwargs)


+ 4
- 4
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -72,7 +72,7 @@ class PaddleDriver(Driver):
:param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。
"""
if set_to_none:
logger.warning_once("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
logger.rank_zero_warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.")
for optimizer in self.optimizers:
optimizer.clear_grad()

@@ -233,7 +233,7 @@ class PaddleDriver(Driver):
if dataloader_args.batch_size is not None:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
else: # 有可能 batch_size 为 None,就只有损失精度了
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
@@ -243,7 +243,7 @@ class PaddleDriver(Driver):
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else:
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
else:
raise RuntimeError(
@@ -306,7 +306,7 @@ class PaddleDriver(Driver):
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")

# 4. 恢复 sampler 的状态;


+ 1
- 1
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -51,7 +51,7 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) ->
seed = int(seed)

if not (min_seed_value <= seed <= max_seed_value):
logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for "
logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for "
"you.")

# rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")


+ 2
- 2
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -197,7 +197,7 @@ class TorchDriver(Driver):
if dataloader_args.batch_size is not None:
num_consumed_batches = num_consumed_batches * dataloader_args.batch_size
else: # 有可能 batch_size 为 None,就只有损失精度了
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")
num_consumed_batches = sampler_states['num_consumed_samples']
sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches]
@@ -207,7 +207,7 @@ class TorchDriver(Driver):
sampler_states['num_consumed_samples'] = sampler.num_replicas * dataloader_args.batch_size \
* num_consumed_batches
else:
logger.warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
logger.rank_zero_warning("fastNLP cannot get batch_size, we have to save based on `num_consumed_samples`, "
"it may cause missing some samples when reload.")

states['sampler_states'] = sampler_states


+ 2
- 2
fastNLP/core/drivers/torch_driver/utils.py View File

@@ -60,7 +60,7 @@ def torch_seed_everything(seed: Optional[int] = None, workers: bool = False) ->
seed = int(seed)

if not (min_seed_value <= seed <= max_seed_value):
logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.")
logger.rank_zero_warning("Your seed value is two big or two small for numpy, we will choose a random seed for you.")

seed = _select_seed_randomly(min_seed_value, max_seed_value)

@@ -162,7 +162,7 @@ def _build_fp16_env(dummy=False):
if not torch.cuda.is_available():
raise RuntimeError("No cuda")
if torch.cuda.get_device_capability(0)[0] < 7:
logger.warning(
logger.rank_zero_warning(
"NOTE: your device does NOT support faster training with fp16, "
"please switch to FP32 which is likely to be faster"
)


+ 6
- 2
fastNLP/core/log/logger.py View File

@@ -124,18 +124,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)

def rank_zero_warning(self, msg, *args, **kwargs):
def rank_zero_warning(self, msg, *args, once=False, **kwargs):
"""
只在 rank 0 上 warning 。

:param msg:
:param args:
:param once: 是否只 warning 一次
:param kwargs:
:return:
"""
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
if once and msg in self._warning_msgs:
return
if self.isEnabledFor(WARNING):
# kwargs = self._add_rank_info(kwargs)
kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)

def warn(self, msg, *args, **kwargs):
@@ -302,6 +305,7 @@ def _set_stdout_handler(_logger, stdout='raw', level='INFO'):
break
if stream_handler is not None:
_logger.removeHandler(stream_handler)
del stream_handler

# Stream Handler
if stdout == 'raw':


+ 2
- 1
fastNLP/core/utils/cache_results.py View File

@@ -15,6 +15,7 @@ __all__ = [

from fastNLP.core.log.logger import logger
from fastNLP.core.log.highlighter import ColorHighlighter
from .utils import _get_fun_msg


class FuncCallVisitor(ast.NodeVisitor):
@@ -306,7 +307,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec
if verbose == 1:
logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time))
if check_hash and old_hash_code != new_hash_code:
logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The "
logger.warning(f"The function {_get_fun_msg(func)} is different from its last cache (Save on {save_time}). The "
f"difference may caused by the sourcecode change.",
extra={'highlighter': ColorHighlighter('red')})
refresh_flag = False


+ 1
- 1
fastNLP/core/utils/utils.py View File

@@ -239,7 +239,7 @@ def check_user_specific_params(user_params: Dict, fn: Callable):
fn_arg_names = get_fn_arg_names(fn)
for arg_name, arg_value in user_params.items():
if arg_name not in fn_arg_names:
logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.")
logger.rank_zero_warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.")
return user_params




Loading…
Cancel
Save