@@ -50,20 +50,20 @@ def prepare_callbacks(callbacks, progress_bar):
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
_callbacks += callbacks
has_no_progress = False
has_no_progress = True
for _callback in _callbacks:
if isinstance(_callback, ProgressCallback):
has_no_progress = True
if not has_no_progress:
has_no_progress = False
if has_no_progress and progress_bar is not None:
callback = choose_progress_callback(progress_bar)
if callback is not None:
_callbacks.append(callback)
elif progress_bar is not None and progress_bar != 'auto':
logger.warning(f"Since you have passed in ProgressBar callback, progress_bar will be ignored.")
has_no_progress = False
elif has_no_progress is False and progress_bar not in ('auto', None):
logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")
if has_no_progress and progress_bar is None:
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
"during training.")
if has_no_progress:
logger.rank_zero_warning("No progress bar is provided, there will be no progress output during training.")
return _callbacks
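For reference, a minimal standalone sketch of the selection logic in this hunk (editor's illustration, not the fastNLP implementation; `Callback`/`ProgressCallback` below are stand-in classes and `prepare_callbacks_sketch` is hypothetical): a progress callback is only added when the user has not supplied one, and an explicit `progress_bar` choice triggers a warning when it would be ignored.

import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("sketch")

class Callback:                      # stand-in for fastNLP's Callback base class
    pass

class ProgressCallback(Callback):    # stand-in for fastNLP's ProgressCallback
    def __init__(self, kind):
        self.kind = kind

def prepare_callbacks_sketch(callbacks, progress_bar):
    callbacks = list(callbacks or [])
    has_progress = any(isinstance(cb, ProgressCallback) for cb in callbacks)
    if not has_progress and progress_bar is not None:
        # the real code delegates this choice to choose_progress_callback()
        callbacks.append(ProgressCallback(progress_bar))
    elif has_progress and progress_bar not in ('auto', None):
        log.warning(f"A ProgressCallback was passed in, progress_bar={progress_bar} will be ignored.")
    if not has_progress and progress_bar is None:
        log.warning("No progress bar is provided, there will be no progress output during training.")
    return callbacks

print([type(cb).__name__ for cb in prepare_callbacks_sketch([], 'rich')])
print([type(cb).__name__ for cb in prepare_callbacks_sketch([ProgressCallback('rich')], 'tqdm')])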
@@ -87,17 +87,20 @@ class LoadBestModelCallback(HasMonitorCallback):
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)
def on_train_end(self, trainer):
logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...")
if self.real_save_folder:
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:
trainer.driver.barrier()
self._delete_folder()
trainer.driver.barrier()
if abs(self.monitor_value) != float('inf'): # if it is inf, it means training never ran.
if self.real_save_folder:
logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...")
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
logger.info(
f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:
trainer.driver.barrier()
self._delete_folder()
trainer.driver.barrier()
def _delete_folder(self):
if self.real_save_folder:
@@ -138,8 +138,6 @@ class PaddleTensorPadder(Padder):
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):
print(dtype)
tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
@@ -4,11 +4,11 @@ __all__ = [
]
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
import inspect
from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
# from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler
@@ -79,35 +79,30 @@ class TorchDataLoader(DataLoader):
if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)
if isinstance(collate_fn, str):
if collate_fn == 'auto':
if isinstance(dataset.dataset, DataSet): # a fastNLP DataSet is used
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch")
collate_fn = dataset.dataset.collator
collate_fn.set_backend(backend="torch")
else:
self._collate_fn = Collator(backend="torch")
collate_fn = Collator(backend="torch")
else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable):
if collate_fn is not default_collate:
self._collate_fn = collate_fn
else:
self._collate_fn = default_collate
self.cur_indices_batch = None
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)
self.cur_batch_indices = None
def __iter__(self):
# If there is neither an auto collator nor a custom collate_fn, fall back to the DataLoader's built-in collate_fn to batch the data.
# if len(self._collate_fn.get_collators()) == 0:
# self._collate_fn.add_collator(self.collate_fn)
self.collate_fn = indice_collate_wrapper(self._collate_fn)
self.collate_fn = indice_collate_wrapper(self.collate_fn)
for indices, data in super().__iter__():
self.cur_batch_indices = indices
yield data
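The `__iter__` change above relies on wrapping `collate_fn` so that every batch comes back as `(indices, batch)`. A self-contained sketch of that pattern, with no torch or fastNLP dependency (`TinyLoader` and `indice_collate` are hypothetical stand-ins for `TorchDataLoader` and `indice_collate_wrapper`):

def indice_collate(func):
    def _wrapped(tuple_data):            # receives [(idx, sample), ...]
        indices, samples = zip(*tuple_data)
        return list(indices), func(list(samples))
    return _wrapped

class TinyLoader:
    def __init__(self, data, batch_size, collate_fn=list):
        self.data, self.batch_size = data, batch_size
        self.collate_fn = indice_collate(collate_fn)
        self.cur_batch_indices = None    # mirrors TorchDataLoader.cur_batch_indices

    def __iter__(self):
        pairs = list(enumerate(self.data))
        for i in range(0, len(pairs), self.batch_size):
            indices, batch = self.collate_fn(pairs[i:i + self.batch_size])
            self.cur_batch_indices = indices   # cached before the batch is yielded
            yield batch

loader = TinyLoader(['a', 'b', 'c', 'd', 'e'], batch_size=2)
for batch in loader:
    print(loader.cur_batch_indices, batch)     # e.g. [0, 1] ['a', 'b']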
@@ -132,12 +127,26 @@ class TorchDataLoader(DataLoader):
form; its output will be used directly as the result.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return collator
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
def _get_collator(self):
"""
If collate_fn is a Collator object, return that object; otherwise return None.
:return:
"""
collator = None
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
collator = self.collate_fn.__wrapped__
elif isinstance(self.collate_fn, Collator):
collator = self.collate_fn
return collator
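A small illustrative sketch of the `__wrapped__` convention that `_get_collator` depends on (editor's example; `FakeCollator`, `mark_wrapped` and `get_collator` are stand-ins, not fastNLP APIs): the wrapper keeps a handle on the original collator so it can still be found and configured via `set_pad`/`set_ignore` after `collate_fn` has been wrapped.

class FakeCollator:                      # stand-in for fastNLP's Collator
    def __call__(self, samples):
        return samples

def mark_wrapped(wrapper, original):
    wrapper.__wrapped__ = original       # same convention as indice_collate_wrapper
    return wrapper

collator = FakeCollator()
wrapped = mark_wrapped(lambda pairs: ([], collator([])), collator)

def get_collator(collate_fn):            # mirrors TorchDataLoader._get_collator
    if hasattr(collate_fn, '__wrapped__') and isinstance(collate_fn.__wrapped__, FakeCollator):
        return collate_fn.__wrapped__
    return collate_fn if isinstance(collate_fn, FakeCollator) else None

assert get_collator(wrapped) is collator     # reachable through the wrapper
assert get_collator(collator) is collator    # unwrapped collator is returned as-is
assert get_collator(len) is None             # anything else yields None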
def set_ignore(self, *field_names) -> Collator:
"""
If some content is not desired in the output, it can be set here; the specified fields will be ignored in the batch output.
@@ -149,9 +158,10 @@ class TorchDataLoader(DataLoader):
If __getitem__ returns a Sequence, '_0' and '_1' can be used to refer to the 0th or 1st element of the sequence.
:return: the Collator itself
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_ignore(*field_names)
return collator
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
@@ -164,7 +174,8 @@ class TorchDataLoader(DataLoader):
return self.cur_batch_indices
def prepare_torch_dataloader(ds_or_db,
def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 16,
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +208,8 @@ def prepare_torch_dataloader(ds_or_db,
:param non_train_sampler: the Sampler used for non-'train' datasets, and for the second and later datasets when a Sequence is passed in
:param non_train_batch_size:
"""
from fastNLP.io.data_bundle import DataBundle
from fastNLP.io import DataBundle
if isinstance(ds_or_db, DataSet):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
@@ -208,7 +220,7 @@ def prepare_torch_dataloader(ds_or_db,
)
return dl
elif isinstance(ds_or_db, DataBundle):
elif isinstance(ds_or_db, DataBundle):
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
@@ -10,12 +10,25 @@ def indice_collate_wrapper(func):
:param func: the function to be wrapped
:return:
"""
if func.__name__ == '_indice_collate_wrapper': # already wrapped
return func
def wrapper(tuple_data):
def _indice_collate_wrapper(tuple_data): # functools.wraps cannot be used here, otherwise the check above could not detect the wrapper
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)
_indice_collate_wrapper.__wrapped__ = func # keep a reference to the wrapped function
return wrapper
return _indice_collate_wrapper
if __name__ == '__main__':
def demo(*args, **kwargs):
pass
d = indice_collate_wrapper(demo)
print(d.__name__)
print(d.__wrapped__)
@@ -8,6 +8,7 @@ __all__ = [
from collections import Counter
from typing import Any, Union, List, Callable
from ..log import logger
import numpy as np
@@ -21,7 +22,7 @@ class FieldArray:
try:
_content = list(_content)
except BaseException as e:
print(f"Cannot convert content(of type:{type(content)}) into list.")
logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
raise e
self.name = name
self.content = _content
@@ -87,7 +88,7 @@ class FieldArray:
try:
new_contents.append(cell.split(sep))
except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
@@ -111,7 +111,7 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
def warning_once(self, msg, *args, **kwargs):
"""
A warning message will only be warned once
Identical warning messages will only be warned once
:param msg:
:param args:
@@ -124,6 +124,22 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)
def rank_zero_warning(self, msg, *args, **kwargs):
"""
Only emit the warning on rank 0.
:param msg:
:param args:
:param kwargs:
:return:
"""
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
if msg not in self._warning_msgs:
if self.isEnabledFor(WARNING):
# kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)
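A standalone sketch of the behaviour `rank_zero_warning` adds (editor's illustration; only the `FASTNLP_GLOBAL_RANK` environment variable name is taken from the snippet above, the helper itself is hypothetical): warn on rank 0 only, and only once per distinct message.

import logging
import os

logging.basicConfig(level=logging.WARNING)
_logger = logging.getLogger("sketch")
_seen_msgs = set()

def rank_zero_warning_once(msg):
    if os.environ.get("FASTNLP_GLOBAL_RANK", "0") != "0":
        return                               # non-zero ranks stay silent
    if msg in _seen_msgs:
        return                               # identical messages are emitted only once
    _seen_msgs.add(msg)
    _logger.warning(msg)

rank_zero_warning_once("No progress bar is provided.")   # printed
rank_zero_warning_once("No progress bar is provided.")   # suppressed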
def warn(self, msg, *args, **kwargs):
if self.isEnabledFor(WARNING):
kwargs = self._add_rank_info(kwargs)
@@ -156,8 +156,9 @@ class FRichProgress(Progress, metaclass=Singleton):
super().stop_task(task_id)
super().remove_task(task_id)
self.refresh() # keep the bar from leaving residue on screen
if len(self._tasks) == 0:
super().stop()
# Commented out because repeated dataset apply() calls caused spurious line breaks; it was previously kept because the bar would not disappear after evaluate finished.
# if len(self._tasks) == 0:
# self.live.stop()
def start(self) -> None:
super().start()
@@ -15,6 +15,7 @@ from functools import wraps
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils.utils import Option
from fastNLP.core.utils.utils import _is_iterable
from .log import logger
import io
@@ -56,7 +57,7 @@ def _check_build_status(func):
if self.rebuild is False:
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
logger.warning("Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
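The `_check_build_status` hunk above swaps `print` for the logger inside a method decorator. A simplified, self-contained sketch of that decorator pattern (editor's example; `TinyVocab` and `check_build_status` are stand-ins, not the fastNLP `Vocabulary`):

import logging
from collections import Counter
from functools import wraps

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("sketch")

def check_build_status(func):
    @wraps(func)
    def _wrapper(self, *args, **kwargs):
        self.rebuild = True                      # any mutation invalidates the built index
        if self.max_size is not None and len(self.word_count) >= self.max_size:
            log.warning("Vocabulary has reached the max size %s when calling %s method. "
                        "Adding more words may cause unexpected behaviour of Vocabulary.",
                        self.max_size, func.__name__)
        return func(self, *args, **kwargs)
    return _wrapper

class TinyVocab:
    def __init__(self, max_size=None):
        self.word_count = Counter()
        self.max_size = max_size
        self.rebuild = False

    @check_build_status
    def add_word_lst(self, words):
        self.word_count.update(words)

v = TinyVocab(max_size=2)
v.add_word_lst(["a", "b"])
v.add_word_lst(["c"])        # triggers the max-size warning via the logger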
@@ -322,7 +323,7 @@ class Vocabulary(object):
for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
else:
raise RuntimeError("Only DataSet type is allowed.")
@@ -378,7 +379,7 @@ class Vocabulary(object):
try:
dataset.apply(construct_vocab)
except BaseException as e:
print("When processing the `{}` dataset, the following error occurred:".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
else:
raise TypeError("Only DataSet type is allowed.")
@@ -10,7 +10,7 @@ from typing import Union, List, Callable
from ..core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
# from ..core._logger import _logger
from fastNLP.core import logger
class DataBundle:
@@ -72,7 +72,7 @@ class DataBundle:
else:
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \
f'It should be one of {self.datasets.keys()}.'
print(error_msg)
logger.error(error_msg)
raise KeyError(error_msg)
def delete_dataset(self, name: str):
@@ -97,7 +97,7 @@ class DataBundle:
else:
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \
f'It should be one of {self.vocabs.keys()}.'
print(error_msg)
logger.error(error_msg)
raise KeyError(error_msg)
def delete_vocab(self, field_name: str):
@@ -117,7 +117,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
yield line_idx, res
except Exception as e:
if dropna:
print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
logger.error('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
sample = []
continue
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
@@ -132,5 +132,5 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
except Exception as e:
if dropna:
return
print('invalid instance ends at line: {}'.format(line_idx))
logger.error('invalid instance ends at line: {}'.format(line_idx))
raise e
@@ -29,6 +29,7 @@ import warnings
from .loader import Loader
from fastNLP.core.dataset import Instance, DataSet
from ...core import logger
# from ...core._logger import log
@@ -86,7 +87,8 @@ class CLSBaseLoader(Loader):
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
except Exception as e:
print(f'Load file `{path}` failed for `{e}`')
logger.error(f'Failed to load `{path}`.')
raise e
return ds | |||