diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 7b04d8ad..3966c6fc 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -50,20 +50,20 @@ def prepare_callbacks(callbacks, progress_bar): raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") _callbacks += callbacks - has_no_progress = False + has_no_progress = True for _callback in _callbacks: if isinstance(_callback, ProgressCallback): - has_no_progress = True - if not has_no_progress: + has_no_progress = False + if has_no_progress and progress_bar is not None: callback = choose_progress_callback(progress_bar) if callback is not None: _callbacks.append(callback) - elif progress_bar is not None and progress_bar != 'auto': - logger.warning(f"Since you have passed in ProgressBar callback, progress_bar will be ignored.") + has_no_progress = False + elif has_no_progress is False and progress_bar not in ('auto', None): + logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.") - if has_no_progress and progress_bar is None: - rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " - "during training.") + if has_no_progress: + logger.rank_zero_warning("No progress bar is provided, there will have no progress output during training.") return _callbacks diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 32534d2a..227c80c4 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -87,17 +87,20 @@ class LoadBestModelCallback(HasMonitorCallback): trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) def on_train_end(self, trainer): - logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...") - if 
self.real_save_folder: - trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, - model_load_fn=self.model_load_fn) - else: - self.buffer.seek(0) - trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) - if self.delete_after_after: - trainer.driver.barrier() - self._delete_folder() - trainer.driver.barrier() + if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 + if self.real_save_folder: + logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...") + trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, + model_load_fn=self.model_load_fn) + else: + logger.info( + f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") + self.buffer.seek(0) + trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) + if self.delete_after_after: + trainer.driver.barrier() + self._delete_folder() + trainer.driver.barrier() def _delete_folder(self): if self.real_save_folder: diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index f4ae0300..5432b17a 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -138,8 +138,6 @@ class PaddleTensorPadder(Padder): shapes = [field.shape for field in batch_field] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] - if isinstance(dtype, np.dtype): - print(dtype) tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 1616fb85..923f6415 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -4,11 +4,11 @@ __all__ = [ ] from typing 
import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List +import inspect from fastNLP.core.dataset import DataSet from fastNLP.core.collators import Collator from fastNLP.core.dataloaders.utils import indice_collate_wrapper -# from fastNLP.io.data_bundle import DataBundle from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler @@ -79,35 +79,30 @@ class TorchDataLoader(DataLoader): if sampler is None and batch_sampler is None: sampler = RandomSampler(dataset, shuffle=shuffle) shuffle=False - - super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, - batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, - pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, - multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, - persistent_workers=persistent_workers) if isinstance(collate_fn, str): if collate_fn == 'auto': if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset - self._collate_fn = dataset.dataset.collator - self._collate_fn.set_backend(backend="torch") + collate_fn = dataset.dataset.collator + collate_fn.set_backend(backend="torch") else: - self._collate_fn = Collator(backend="torch") + collate_fn = Collator(backend="torch") else: raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") - elif isinstance(collate_fn, Callable): - if collate_fn is not default_collate: - self._collate_fn = collate_fn - else: - self._collate_fn = default_collate - self.cur_indices_batch = None + super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, + multiprocessing_context=multiprocessing_context, 
generator=generator, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers) + + self.cur_batch_indices = None def __iter__(self): # 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 # if len(self._collate_fn.get_collators()) == 0: # self._collate_fn.add_collator(self.collate_fn) - self.collate_fn = indice_collate_wrapper(self._collate_fn) + self.collate_fn = indice_collate_wrapper(self.collate_fn) for indices, data in super().__iter__(): self.cur_batch_indices = indices yield data @@ -132,12 +127,26 @@ class TorchDataLoader(DataLoader): 形式,输出将被直接作为结果输出。 :return: 返回 Collator 自身 """ - if isinstance(self._collate_fn, Collator): - self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) - return self._collate_fn + collator = self._get_collator() + if isinstance(collator, Collator): + collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) + return collator else: raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") + def _get_collator(self): + """ + 如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None + + :return: + """ + collator = None + if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): + collator = self.collate_fn.__wrapped__ + elif isinstance(self.collate_fn, Collator): + collator = self.collate_fn + return collator + def set_ignore(self, *field_names) -> Collator: """ 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 @@ -149,9 +158,10 @@ class TorchDataLoader(DataLoader): __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 :return: 返回 Collator 自身 """ - if isinstance(self._collate_fn, Collator): - self._collate_fn.set_ignore(*field_names) - return self._collate_fn + collator = self._get_collator() + if isinstance(collator, Collator): + collator.set_ignore(*field_names) + return collator else: raise ValueError(f"Only when the 
collate_fn is a fastNLP Collator, set_ignore() is allowed.") @@ -164,7 +174,8 @@ class TorchDataLoader(DataLoader): return self.cur_batch_indices -def prepare_torch_dataloader(ds_or_db, + +def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 16, shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, @@ -197,7 +208,8 @@ def prepare_torch_dataloader(ds_or_db, :param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler :param non_train_batch_size: """ - from fastNLP.io.data_bundle import DataBundle + + from fastNLP.io import DataBundle if isinstance(ds_or_db, DataSet): dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, @@ -208,7 +220,7 @@ def prepare_torch_dataloader(ds_or_db, ) return dl - elif isinstance(ds_or_db, DataBundle): + elif isinstance(ds_or_db, DataBundle): dl_bundle = {} for name, ds in ds_or_db.iter_datasets(): if 'train' in name: diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py index 2305cebe..0ee496c5 100644 --- a/fastNLP/core/dataloaders/utils.py +++ b/fastNLP/core/dataloaders/utils.py @@ -10,12 +10,25 @@ def indice_collate_wrapper(func): :param func: 需要修饰的函数 :return: """ + if func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了 + return func - def wrapper(tuple_data): + def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到 indice, ins_list = [], [] for idx, ins in tuple_data: indice.append(idx) ins_list.append(ins) return indice, func(ins_list) + _indice_collate_wrapper.__wrapped__ = func # 记录对应的 - return wrapper \ No newline at end of file + return _indice_collate_wrapper + + +if __name__ == '__main__': + def demo(*args, **kwargs): + pass + + d = indice_collate_wrapper(demo) + + print(d.__name__) + print(d.__wrapped__) 
\ No newline at end of file diff --git a/fastNLP/core/dataset/field.py b/fastNLP/core/dataset/field.py index 42ba700e..cbe064e6 100644 --- a/fastNLP/core/dataset/field.py +++ b/fastNLP/core/dataset/field.py @@ -8,6 +8,7 @@ __all__ = [ from collections import Counter from typing import Any, Union, List, Callable +from ..log import logger import numpy as np @@ -21,7 +22,7 @@ class FieldArray: try: _content = list(_content) except BaseException as e: - print(f"Cannot convert content(of type:{type(content)}) into list.") + logger.error(f"Cannot convert content(of type:{type(content)}) into list.") raise e self.name = name self.content = _content @@ -87,7 +88,7 @@ class FieldArray: try: new_contents.append(cell.split(sep)) except Exception as e: - print(f"Exception happens when process value in index {index}.") + logger.error(f"Exception happens when process value in index {index}.") raise e return self._after_process(new_contents, inplace=inplace) diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index bdfc299f..bbc1e8e1 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -111,7 +111,7 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): def warning_once(self, msg, *args, **kwargs): """ - 通过 warning 内容只会 warning 一次 + 相同的 warning 内容只会 warning 一次 :param msg: :param args: @@ -124,6 +124,22 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): self._log(WARNING, msg, args, **kwargs) self._warning_msgs.add(msg) + def rank_zero_warning(self, msg, *args, **kwargs): + """ + 只在 rank 0 上 warning 。 + + :param msg: + :param args: + :param kwargs: + :return: + """ + if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': + if msg not in self._warning_msgs: + if self.isEnabledFor(WARNING): + # kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) + self._warning_msgs.add(msg) + def warn(self, msg, *args, **kwargs): if self.isEnabledFor(WARNING): kwargs = self._add_rank_info(kwargs) diff --git 
a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index e7b95d9c..4799765f 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -156,8 +156,9 @@ class FRichProgress(Progress, metaclass=Singleton): super().stop_task(task_id) super().remove_task(task_id) self.refresh() # 使得bar不残留 - if len(self._tasks) == 0: - super().stop() + # 这里需要注释掉的原因是由于,在dataset多次apply的过程中会出现自动换行的问题。以前保留这个的原因应该是由于evaluate结束bar不消失。 + # if len(self._tasks) == 0: + # self.live.stop() def start(self) -> None: super().start() diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 4591a959..9fe8d3c8 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -15,6 +15,7 @@ from functools import wraps from fastNLP.core.dataset import DataSet from fastNLP.core.utils.utils import Option from fastNLP.core.utils.utils import _is_iterable +from .log import logger import io @@ -56,7 +57,7 @@ def _check_build_status(func): if self.rebuild is False: self.rebuild = True if self.max_size is not None and len(self.word_count) >= self.max_size: - print("[Warning] Vocabulary has reached the max size {} when calling {} method. " + logger.warning("Vocabulary has reached the max size {} when calling {} method. " "Adding more words may cause unexpected behaviour of Vocabulary. 
".format( self.max_size, func.__name__)) return func(self, *args, **kwargs) @@ -322,7 +323,7 @@ class Vocabulary(object): for f_n, n_f_n in zip(field_name, new_field_name): dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) except Exception as e: - print("When processing the `{}` dataset, the following error occurred.".format(idx)) + logger.error("When processing the `{}` dataset, the following error occurred.".format(idx)) raise e else: raise RuntimeError("Only DataSet type is allowed.") @@ -378,7 +379,7 @@ class Vocabulary(object): try: dataset.apply(construct_vocab) except BaseException as e: - print("When processing the `{}` dataset, the following error occurred:".format(idx)) + logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) raise e else: raise TypeError("Only DataSet type is allowed.") diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index a14439ce..5a0dc78d 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -10,7 +10,7 @@ from typing import Union, List, Callable from ..core.dataset import DataSet from fastNLP.core.vocabulary import Vocabulary -# from ..core._logger import _logger +from fastNLP.core import logger class DataBundle: @@ -72,7 +72,7 @@ class DataBundle: else: error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ f'It should be one of {self.datasets.keys()}.' - print(error_msg) + logger.error(error_msg) raise KeyError(error_msg) def delete_dataset(self, name: str): @@ -97,7 +97,7 @@ class DataBundle: else: error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ f'It should be one of {self.vocabs.keys()}.' 
- print(error_msg) + logger.error(error_msg) raise KeyError(error_msg) def delete_vocab(self, field_name: str): diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 43460d19..9181bc06 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -117,7 +117,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): yield line_idx, res except Exception as e: if dropna: - print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) + logger.error('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) sample = [] continue raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) @@ -132,5 +132,5 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): except Exception as e: if dropna: return - print('invalid instance ends at line: {}'.format(line_idx)) + logger.error('invalid instance ends at line: {}'.format(line_idx)) raise e diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index 0b1a670b..be04c2cc 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -29,6 +29,7 @@ import warnings from .loader import Loader from fastNLP.core.dataset import Instance, DataSet +from ...core import logger # from ...core._logger import log @@ -86,7 +87,8 @@ class CLSBaseLoader(Loader): if raw_words: ds.append(Instance(raw_words=raw_words, target=target)) except Exception as e: - print(f'Load file `{path}` failed for `{e}`') + logger.error(f'Fail to load `{path}`.') + raise e return ds