Browse Source

修改torch fdl冲突

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
74a4b3c356
12 changed files with 111 additions and 64 deletions
  1. +8
    -8
      fastNLP/core/callbacks/callback_manager.py
  2. +14
    -11
      fastNLP/core/callbacks/load_best_model_callback.py
  3. +0
    -2
      fastNLP/core/collators/padders/paddle_padder.py
  4. +39
    -27
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  5. +15
    -2
      fastNLP/core/dataloaders/utils.py
  6. +3
    -2
      fastNLP/core/dataset/field.py
  7. +17
    -1
      fastNLP/core/log/logger.py
  8. +3
    -2
      fastNLP/core/utils/rich_progress.py
  9. +4
    -3
      fastNLP/core/vocabulary.py
  10. +3
    -3
      fastNLP/io/data_bundle.py
  11. +2
    -2
      fastNLP/io/file_reader.py
  12. +3
    -1
      fastNLP/io/loader/classification.py

+ 8
- 8
fastNLP/core/callbacks/callback_manager.py View File

@@ -50,20 +50,20 @@ def prepare_callbacks(callbacks, progress_bar):
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
_callbacks += callbacks _callbacks += callbacks


has_no_progress = False
has_no_progress = True
for _callback in _callbacks: for _callback in _callbacks:
if isinstance(_callback, ProgressCallback): if isinstance(_callback, ProgressCallback):
has_no_progress = True
if not has_no_progress:
has_no_progress = False
if has_no_progress and progress_bar is not None:
callback = choose_progress_callback(progress_bar) callback = choose_progress_callback(progress_bar)
if callback is not None: if callback is not None:
_callbacks.append(callback) _callbacks.append(callback)
elif progress_bar is not None and progress_bar != 'auto':
logger.warning(f"Since you have passed in ProgressBar callback, progress_bar will be ignored.")
has_no_progress = False
elif has_no_progress is False and progress_bar not in ('auto', None):
logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")


if has_no_progress and progress_bar is None:
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
"during training.")
if has_no_progress:
logger.rank_zero_warning("No progress bar is provided, there will have no progress output during training.")


return _callbacks return _callbacks




+ 14
- 11
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -87,17 +87,20 @@ class LoadBestModelCallback(HasMonitorCallback):
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)


def on_train_end(self, trainer): def on_train_end(self, trainer):
logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...")
if self.real_save_folder:
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:
trainer.driver.barrier()
self._delete_folder()
trainer.driver.barrier()
if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。
if self.real_save_folder:
logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...")
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
logger.info(
f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:
trainer.driver.barrier()
self._delete_folder()
trainer.driver.barrier()


def _delete_folder(self): def _delete_folder(self):
if self.real_save_folder: if self.real_save_folder:


+ 0
- 2
fastNLP/core/collators/padders/paddle_padder.py View File

@@ -138,8 +138,6 @@ class PaddleTensorPadder(Padder):


shapes = [field.shape for field in batch_field] shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):
print(dtype)
tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
for i, field in enumerate(batch_field): for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) slices = (i, ) + tuple(slice(0, s) for s in shapes[i])


+ 39
- 27
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -4,11 +4,11 @@ __all__ = [
] ]


from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
import inspect


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper from fastNLP.core.dataloaders.utils import indice_collate_wrapper
# from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler


@@ -79,35 +79,30 @@ class TorchDataLoader(DataLoader):
if sampler is None and batch_sampler is None: if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle) sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False shuffle=False

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)
if isinstance(collate_fn, str): if isinstance(collate_fn, str):
if collate_fn == 'auto': if collate_fn == 'auto':
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch")
collate_fn = dataset.dataset.collator
collate_fn.set_backend(backend="torch")
else: else:
self._collate_fn = Collator(backend="torch")
collate_fn = Collator(backend="torch")
else: else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable):
if collate_fn is not default_collate:
self._collate_fn = collate_fn
else:
self._collate_fn = default_collate


self.cur_indices_batch = None
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)

self.cur_batch_indices = None


def __iter__(self): def __iter__(self):
# 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 # 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。
# if len(self._collate_fn.get_collators()) == 0: # if len(self._collate_fn.get_collators()) == 0:
# self._collate_fn.add_collator(self.collate_fn) # self._collate_fn.add_collator(self.collate_fn)
self.collate_fn = indice_collate_wrapper(self._collate_fn)
self.collate_fn = indice_collate_wrapper(self.collate_fn)
for indices, data in super().__iter__(): for indices, data in super().__iter__():
self.cur_batch_indices = indices self.cur_batch_indices = indices
yield data yield data
@@ -132,12 +127,26 @@ class TorchDataLoader(DataLoader):
形式,输出将被直接作为结果输出。 形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身 :return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return collator
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")


def _get_collator(self):
"""
Return the Collator behind ``self.collate_fn``, or None if there is none.

If ``self.collate_fn`` has a ``__wrapped__`` attribute that is a Collator, that wrapped
object is returned; otherwise ``self.collate_fn`` itself is returned when it is a Collator.

:return: the Collator object, or None when neither check matches
"""
collator = None
# Look through a wrapper first: a wrapping callable exposes the original as ``__wrapped__``.
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
collator = self.collate_fn.__wrapped__
elif isinstance(self.collate_fn, Collator):
collator = self.collate_fn
return collator

def set_ignore(self, *field_names) -> Collator: def set_ignore(self, *field_names) -> Collator:
""" """
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
@@ -149,9 +158,10 @@ class TorchDataLoader(DataLoader):
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: 返回 Collator 自身 :return: 返回 Collator 自身
""" """
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_ignore(*field_names)
return collator
else: else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")


@@ -164,7 +174,8 @@ class TorchDataLoader(DataLoader):
return self.cur_batch_indices return self.cur_batch_indices




def prepare_torch_dataloader(ds_or_db,

def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 16, batch_size: int = 16,
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +208,8 @@ def prepare_torch_dataloader(ds_or_db,
:param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler :param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler
:param non_train_batch_size: :param non_train_batch_size:
""" """
from fastNLP.io.data_bundle import DataBundle

from fastNLP.io import DataBundle
if isinstance(ds_or_db, DataSet): if isinstance(ds_or_db, DataSet):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
@@ -208,7 +220,7 @@ def prepare_torch_dataloader(ds_or_db,
) )
return dl return dl


elif isinstance(ds_or_db, DataBundle):
elif type(ds_or_db, DataBundle):
dl_bundle = {} dl_bundle = {}
for name, ds in ds_or_db.iter_datasets(): for name, ds in ds_or_db.iter_datasets():
if 'train' in name: if 'train' in name:


+ 15
- 2
fastNLP/core/dataloaders/utils.py View File

@@ -10,12 +10,25 @@ def indice_collate_wrapper(func):
:param func: 需要修饰的函数 :param func: 需要修饰的函数
:return: :return:
""" """
if func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了
return func


def wrapper(tuple_data):
def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到
indice, ins_list = [], [] indice, ins_list = [], []
for idx, ins in tuple_data: for idx, ins in tuple_data:
indice.append(idx) indice.append(idx)
ins_list.append(ins) ins_list.append(ins)
return indice, func(ins_list) return indice, func(ins_list)
_indice_collate_wrapper.__wrapped__ = func # 记录对应的


return wrapper
return _indice_collate_wrapper


# Manual check: wrap a dummy function with indice_collate_wrapper and print the
# wrapper's ``__name__`` and ``__wrapped__`` attributes.
if __name__ == '__main__':
def demo(*args, **kwargs):
pass

d = indice_collate_wrapper(demo)

print(d.__name__)
print(d.__wrapped__)

+ 3
- 2
fastNLP/core/dataset/field.py View File

@@ -8,6 +8,7 @@ __all__ = [


from collections import Counter from collections import Counter
from typing import Any, Union, List, Callable from typing import Any, Union, List, Callable
from ..log import logger


import numpy as np import numpy as np


@@ -21,7 +22,7 @@ class FieldArray:
try: try:
_content = list(_content) _content = list(_content)
except BaseException as e: except BaseException as e:
print(f"Cannot convert content(of type:{type(content)}) into list.")
logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
raise e raise e
self.name = name self.name = name
self.content = _content self.content = _content
@@ -87,7 +88,7 @@ class FieldArray:
try: try:
new_contents.append(cell.split(sep)) new_contents.append(cell.split(sep))
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)




+ 17
- 1
fastNLP/core/log/logger.py View File

@@ -111,7 +111,7 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):


def warning_once(self, msg, *args, **kwargs): def warning_once(self, msg, *args, **kwargs):
""" """
通过 warning 内容只会 warning 一次
相同的 warning 内容只会 warning 一次


:param msg: :param msg:
:param args: :param args:
@@ -124,6 +124,22 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
self._log(WARNING, msg, args, **kwargs) self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg) self._warning_msgs.add(msg)


def rank_zero_warning(self, msg, *args, **kwargs):
"""
Log a warning on rank 0 only.

The message is only considered when the environment variable named by
``FASTNLP_GLOBAL_RANK`` is ``'0'`` (or is unset, since ``'0'`` is the default).
A ``msg`` that is already recorded in ``self._warning_msgs`` is not logged again.

:param msg: the warning message; also the value recorded in ``self._warning_msgs``
:param args: forwarded to ``self._log``
:param kwargs: forwarded to ``self._log``
:return:
"""
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
# Skip messages that have already been recorded.
if msg not in self._warning_msgs:
if self.isEnabledFor(WARNING):
# kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)

def warn(self, msg, *args, **kwargs): def warn(self, msg, *args, **kwargs):
if self.isEnabledFor(WARNING): if self.isEnabledFor(WARNING):
kwargs = self._add_rank_info(kwargs) kwargs = self._add_rank_info(kwargs)


+ 3
- 2
fastNLP/core/utils/rich_progress.py View File

@@ -156,8 +156,9 @@ class FRichProgress(Progress, metaclass=Singleton):
super().stop_task(task_id) super().stop_task(task_id)
super().remove_task(task_id) super().remove_task(task_id)
self.refresh() # 使得bar不残留 self.refresh() # 使得bar不残留
if len(self._tasks) == 0:
super().stop()
# 这里需要注释掉的原因是由于,在dataset多次apply的过程中会出现自动换行的问题。以前保留这个的原因应该是由于evaluate结束bar不消失。
# if len(self._tasks) == 0:
# self.live.stop()


def start(self) -> None: def start(self) -> None:
super().start() super().start()


+ 4
- 3
fastNLP/core/vocabulary.py View File

@@ -15,6 +15,7 @@ from functools import wraps
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.utils.utils import Option from fastNLP.core.utils.utils import Option
from fastNLP.core.utils.utils import _is_iterable from fastNLP.core.utils.utils import _is_iterable
from .log import logger
import io import io




@@ -56,7 +57,7 @@ def _check_build_status(func):
if self.rebuild is False: if self.rebuild is False:
self.rebuild = True self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size: if self.max_size is not None and len(self.word_count) >= self.max_size:
print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
logger.warning("Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__)) self.max_size, func.__name__))
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
@@ -322,7 +323,7 @@ class Vocabulary(object):
for f_n, n_f_n in zip(field_name, new_field_name): for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e: except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e raise e
else: else:
raise RuntimeError("Only DataSet type is allowed.") raise RuntimeError("Only DataSet type is allowed.")
@@ -378,7 +379,7 @@ class Vocabulary(object):
try: try:
dataset.apply(construct_vocab) dataset.apply(construct_vocab)
except BaseException as e: except BaseException as e:
print("When processing the `{}` dataset, the following error occurred:".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e raise e
else: else:
raise TypeError("Only DataSet type is allowed.") raise TypeError("Only DataSet type is allowed.")


+ 3
- 3
fastNLP/io/data_bundle.py View File

@@ -10,7 +10,7 @@ from typing import Union, List, Callable


from ..core.dataset import DataSet from ..core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.vocabulary import Vocabulary
# from ..core._logger import _logger
from fastNLP.core import logger




class DataBundle: class DataBundle:
@@ -72,7 +72,7 @@ class DataBundle:
else: else:
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ error_msg = f'DataBundle do NOT have DataSet named {name}. ' \
f'It should be one of {self.datasets.keys()}.' f'It should be one of {self.datasets.keys()}.'
print(error_msg)
logger.error(error_msg)
raise KeyError(error_msg) raise KeyError(error_msg)


def delete_dataset(self, name: str): def delete_dataset(self, name: str):
@@ -97,7 +97,7 @@ class DataBundle:
else: else:
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \
f'It should be one of {self.vocabs.keys()}.' f'It should be one of {self.vocabs.keys()}.'
print(error_msg)
logger.error(error_msg)
raise KeyError(error_msg) raise KeyError(error_msg)


def delete_vocab(self, field_name: str): def delete_vocab(self, field_name: str):


+ 2
- 2
fastNLP/io/file_reader.py View File

@@ -117,7 +117,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
yield line_idx, res yield line_idx, res
except Exception as e: except Exception as e:
if dropna: if dropna:
print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
logger.error('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
sample = [] sample = []
continue continue
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
@@ -132,5 +132,5 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
except Exception as e: except Exception as e:
if dropna: if dropna:
return return
print('invalid instance ends at line: {}'.format(line_idx))
logger.error('invalid instance ends at line: {}'.format(line_idx))
raise e raise e

+ 3
- 1
fastNLP/io/loader/classification.py View File

@@ -29,6 +29,7 @@ import warnings


from .loader import Loader from .loader import Loader
from fastNLP.core.dataset import Instance, DataSet from fastNLP.core.dataset import Instance, DataSet
from ...core import logger




# from ...core._logger import log # from ...core._logger import log
@@ -86,7 +87,8 @@ class CLSBaseLoader(Loader):
if raw_words: if raw_words:
ds.append(Instance(raw_words=raw_words, target=target)) ds.append(Instance(raw_words=raw_words, target=target))
except Exception as e: except Exception as e:
print(f'Load file `{path}` failed for `{e}`')
logger.error(f'Fail to load `{path}`.')
raise e
return ds return ds






Loading…
Cancel
Save