Browse Source

修改torch fdl冲突

tags/v1.0.0alpha
MorningForest 2 years ago
parent
commit
74a4b3c356
12 changed files with 111 additions and 64 deletions
  1. +8
    -8
      fastNLP/core/callbacks/callback_manager.py
  2. +14
    -11
      fastNLP/core/callbacks/load_best_model_callback.py
  3. +0
    -2
      fastNLP/core/collators/padders/paddle_padder.py
  4. +39
    -27
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  5. +15
    -2
      fastNLP/core/dataloaders/utils.py
  6. +3
    -2
      fastNLP/core/dataset/field.py
  7. +17
    -1
      fastNLP/core/log/logger.py
  8. +3
    -2
      fastNLP/core/utils/rich_progress.py
  9. +4
    -3
      fastNLP/core/vocabulary.py
  10. +3
    -3
      fastNLP/io/data_bundle.py
  11. +2
    -2
      fastNLP/io/file_reader.py
  12. +3
    -1
      fastNLP/io/loader/classification.py

+ 8
- 8
fastNLP/core/callbacks/callback_manager.py View File

@@ -50,20 +50,20 @@ def prepare_callbacks(callbacks, progress_bar):
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`")
_callbacks += callbacks

has_no_progress = False
has_no_progress = True
for _callback in _callbacks:
if isinstance(_callback, ProgressCallback):
has_no_progress = True
if not has_no_progress:
has_no_progress = False
if has_no_progress and progress_bar is not None:
callback = choose_progress_callback(progress_bar)
if callback is not None:
_callbacks.append(callback)
elif progress_bar is not None and progress_bar != 'auto':
logger.warning(f"Since you have passed in ProgressBar callback, progress_bar will be ignored.")
has_no_progress = False
elif has_no_progress is False and progress_bar not in ('auto', None):
logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")

if has_no_progress and progress_bar is None:
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output "
"during training.")
if has_no_progress:
logger.rank_zero_warning("No progress bar is provided, there will have no progress output during training.")

return _callbacks



+ 14
- 11
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -87,17 +87,20 @@ class LoadBestModelCallback(HasMonitorCallback):
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)

def on_train_end(self, trainer):
logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...")
if self.real_save_folder:
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:
trainer.driver.barrier()
self._delete_folder()
trainer.driver.barrier()
if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。
if self.real_save_folder:
logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...")
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
logger.info(
f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:
trainer.driver.barrier()
self._delete_folder()
trainer.driver.barrier()

def _delete_folder(self):
if self.real_save_folder:


+ 0
- 2
fastNLP/core/collators/padders/paddle_padder.py View File

@@ -138,8 +138,6 @@ class PaddleTensorPadder(Padder):

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):
print(dtype)
tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])


+ 39
- 27
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -4,11 +4,11 @@ __all__ = [
]

from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List
import inspect

from fastNLP.core.dataset import DataSet
from fastNLP.core.collators import Collator
from fastNLP.core.dataloaders.utils import indice_collate_wrapper
# from fastNLP.io.data_bundle import DataBundle
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler

@@ -79,35 +79,30 @@ class TorchDataLoader(DataLoader):
if sampler is None and batch_sampler is None:
sampler = RandomSampler(dataset, shuffle=shuffle)
shuffle=False

super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)
if isinstance(collate_fn, str):
if collate_fn == 'auto':
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset
self._collate_fn = dataset.dataset.collator
self._collate_fn.set_backend(backend="torch")
collate_fn = dataset.dataset.collator
collate_fn.set_backend(backend="torch")
else:
self._collate_fn = Collator(backend="torch")
collate_fn = Collator(backend="torch")
else:
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
elif isinstance(collate_fn, Callable):
if collate_fn is not default_collate:
self._collate_fn = collate_fn
else:
self._collate_fn = default_collate

self.cur_indices_batch = None
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn,
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
multiprocessing_context=multiprocessing_context, generator=generator,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers)

self.cur_batch_indices = None

def __iter__(self):
# 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。
# if len(self._collate_fn.get_collators()) == 0:
# self._collate_fn.add_collator(self.collate_fn)
self.collate_fn = indice_collate_wrapper(self._collate_fn)
self.collate_fn = indice_collate_wrapper(self.collate_fn)
for indices, data in super().__iter__():
self.cur_batch_indices = indices
yield data
@@ -132,12 +127,26 @@ class TorchDataLoader(DataLoader):
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
return collator
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")

def _get_collator(self):
    """
    Return the fastNLP ``Collator`` behind ``self.collate_fn``, if any.

    Handles both cases: ``collate_fn`` wrapped by ``indice_collate_wrapper``
    (the original is recorded on ``__wrapped__``) and a bare ``Collator``.

    :return: the ``Collator`` instance, or ``None`` when ``collate_fn`` is not one.
    """
    collator = None
    if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator):
        collator = self.collate_fn.__wrapped__
    elif isinstance(self.collate_fn, Collator):
        collator = self.collate_fn
    return collator

def set_ignore(self, *field_names) -> Collator:
"""
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
@@ -149,9 +158,10 @@ class TorchDataLoader(DataLoader):
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: 返回 Collator 自身
"""
if isinstance(self._collate_fn, Collator):
self._collate_fn.set_ignore(*field_names)
return self._collate_fn
collator = self._get_collator()
if isinstance(collator, Collator):
collator.set_ignore(*field_names)
return collator
else:
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")

@@ -164,7 +174,8 @@ class TorchDataLoader(DataLoader):
return self.cur_batch_indices


def prepare_torch_dataloader(ds_or_db,

def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 16,
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -197,7 +208,8 @@ def prepare_torch_dataloader(ds_or_db,
:param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler
:param non_train_batch_size:
"""
from fastNLP.io.data_bundle import DataBundle

from fastNLP.io import DataBundle
if isinstance(ds_or_db, DataSet):
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
@@ -208,7 +220,7 @@ def prepare_torch_dataloader(ds_or_db,
)
return dl

elif isinstance(ds_or_db, DataBundle):
elif type(ds_or_db, DataBundle):
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:


+ 15
- 2
fastNLP/core/dataloaders/utils.py View File

@@ -10,12 +10,25 @@ def indice_collate_wrapper(func):
:param func: 需要修饰的函数
:return:
"""
if func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了
return func

def wrapper(tuple_data):
def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到
indice, ins_list = [], []
for idx, ins in tuple_data:
indice.append(idx)
ins_list.append(ins)
return indice, func(ins_list)
_indice_collate_wrapper.__wrapped__ = func # 记录对应的

return wrapper
return _indice_collate_wrapper


# Manual smoke test: verify the wrapper keeps the stable name
# '_indice_collate_wrapper' (used for double-wrap detection) and records
# the original function on __wrapped__.
if __name__ == '__main__':
    def demo(*args, **kwargs):
        pass

    d = indice_collate_wrapper(demo)

    print(d.__name__)
    print(d.__wrapped__)

+ 3
- 2
fastNLP/core/dataset/field.py View File

@@ -8,6 +8,7 @@ __all__ = [

from collections import Counter
from typing import Any, Union, List, Callable
from ..log import logger

import numpy as np

@@ -21,7 +22,7 @@ class FieldArray:
try:
_content = list(_content)
except BaseException as e:
print(f"Cannot convert content(of type:{type(content)}) into list.")
logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
raise e
self.name = name
self.content = _content
@@ -87,7 +88,7 @@ class FieldArray:
try:
new_contents.append(cell.split(sep))
except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)



+ 17
- 1
fastNLP/core/log/logger.py View File

@@ -111,7 +111,7 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):

def warning_once(self, msg, *args, **kwargs):
"""
通过 warning 内容只会 warning 一次
相同的 warning 内容只会 warning 一次

:param msg:
:param args:
@@ -124,6 +124,22 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)

def rank_zero_warning(self, msg, *args, **kwargs):
    """
    Emit the warning only on rank 0, and only once per distinct ``msg``
    (deduplicated via ``self._warning_msgs``).

    :param msg: warning message; also used as the deduplication key.
    :param args: positional arguments forwarded to ``_log``.
    :param kwargs: keyword arguments forwarded to ``_log``.
    :return:
    """
    # Treat an unset FASTNLP_GLOBAL_RANK as rank 0 (single-process runs).
    if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
        if msg not in self._warning_msgs:
            if self.isEnabledFor(WARNING):
                # kwargs = self._add_rank_info(kwargs)
                self._log(WARNING, msg, args, **kwargs)
                self._warning_msgs.add(msg)

def warn(self, msg, *args, **kwargs):
if self.isEnabledFor(WARNING):
kwargs = self._add_rank_info(kwargs)


+ 3
- 2
fastNLP/core/utils/rich_progress.py View File

@@ -156,8 +156,9 @@ class FRichProgress(Progress, metaclass=Singleton):
super().stop_task(task_id)
super().remove_task(task_id)
self.refresh() # 使得bar不残留
if len(self._tasks) == 0:
super().stop()
# 这里需要注释掉的原因是由于,在dataset多次apply的过程中会出现自动换行的问题。以前保留这个的原因应该是由于evaluate结束bar不消失。
# if len(self._tasks) == 0:
# self.live.stop()

def start(self) -> None:
super().start()


+ 4
- 3
fastNLP/core/vocabulary.py View File

@@ -15,6 +15,7 @@ from functools import wraps
from fastNLP.core.dataset import DataSet
from fastNLP.core.utils.utils import Option
from fastNLP.core.utils.utils import _is_iterable
from .log import logger
import io


@@ -56,7 +57,7 @@ def _check_build_status(func):
if self.rebuild is False:
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
logger.warning("Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
@@ -322,7 +323,7 @@ class Vocabulary(object):
for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
else:
raise RuntimeError("Only DataSet type is allowed.")
@@ -378,7 +379,7 @@ class Vocabulary(object):
try:
dataset.apply(construct_vocab)
except BaseException as e:
print("When processing the `{}` dataset, the following error occurred:".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
else:
raise TypeError("Only DataSet type is allowed.")


+ 3
- 3
fastNLP/io/data_bundle.py View File

@@ -10,7 +10,7 @@ from typing import Union, List, Callable

from ..core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
# from ..core._logger import _logger
from fastNLP.core import logger


class DataBundle:
@@ -72,7 +72,7 @@ class DataBundle:
else:
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \
f'It should be one of {self.datasets.keys()}.'
print(error_msg)
logger.error(error_msg)
raise KeyError(error_msg)

def delete_dataset(self, name: str):
@@ -97,7 +97,7 @@ class DataBundle:
else:
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \
f'It should be one of {self.vocabs.keys()}.'
print(error_msg)
logger.error(error_msg)
raise KeyError(error_msg)

def delete_vocab(self, field_name: str):


+ 2
- 2
fastNLP/io/file_reader.py View File

@@ -117,7 +117,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
yield line_idx, res
except Exception as e:
if dropna:
print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
logger.error('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
sample = []
continue
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
@@ -132,5 +132,5 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True):
except Exception as e:
if dropna:
return
print('invalid instance ends at line: {}'.format(line_idx))
logger.error('invalid instance ends at line: {}'.format(line_idx))
raise e

+ 3
- 1
fastNLP/io/loader/classification.py View File

@@ -29,6 +29,7 @@ import warnings

from .loader import Loader
from fastNLP.core.dataset import Instance, DataSet
from ...core import logger


# from ...core._logger import log
@@ -86,7 +87,8 @@ class CLSBaseLoader(Loader):
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
except Exception as e:
print(f'Load file `{path}` failed for `{e}`')
logger.error(f'Fail to load `{path}`.')
raise e
return ds




Loading…
Cancel
Save