@@ -34,7 +34,7 @@ class EvaluateBatchLoop(Loop): | |||||
except BaseException as e: | except BaseException as e: | ||||
if callable(getattr(dataloader, 'get_batch_indices', None)): | if callable(getattr(dataloader, 'get_batch_indices', None)): | ||||
indices = dataloader.get_batch_indices() | indices = dataloader.get_batch_indices() | ||||
logger.debug(f"The following exception happens when running on samples: {indices}") | |||||
logger.error(f"Exception happens when evaluating on samples: {indices}") | |||||
raise e | raise e | ||||
self.batch_step_fn(evaluator, batch) | self.batch_step_fn(evaluator, batch) | ||||
@@ -32,7 +32,7 @@ class TrainBatchLoop(Loop): | |||||
break | break | ||||
except BaseException as e: | except BaseException as e: | ||||
if indices and not isinstance(e, EarlyStopException): | if indices and not isinstance(e, EarlyStopException): | ||||
logger.debug(f"The following exception happens when running on samples: {indices}") | |||||
logger.error(f"Exception happens when running on samples: {indices}") | |||||
raise e | raise e | ||||
trainer.on_train_batch_begin(batch, indices) | trainer.on_train_batch_begin(batch, indices) | ||||
@@ -514,7 +514,7 @@ class Trainer(TrainerEventTrigger): | |||||
else: | else: | ||||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | ||||
if self.evaluator is not None and num_eval_sanity_batch > 0: | |||||
if self.evaluator is not None and num_eval_sanity_batch != 0: | |||||
logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") | logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") | ||||
self.on_sanity_check_begin() | self.on_sanity_check_begin() | ||||
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | ||||
@@ -1,7 +1,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'print' | 'print' | ||||
] | ] | ||||
from logging import INFO | |||||
from .logger import logger | from .logger import logger | ||||
@@ -22,4 +22,6 @@ def print(*args, sep=' ', end='\n', file=None, flush=False): | |||||
:return: | :return: | ||||
""" | """ | ||||
line = sep.join(map(str, args)) | line = sep.join(map(str, args)) | ||||
logger.info(line) | |||||
if logger.isEnabledFor(INFO): | |||||
kwargs = logger._add_rank_info({}) | |||||
logger._log(INFO, line, args, **kwargs) |
@@ -84,7 +84,7 @@ class Metric: | |||||
def _sync_get_metric(self, get_metric): | def _sync_get_metric(self, get_metric): | ||||
@functools.wraps(get_metric) | @functools.wraps(get_metric) | ||||
def _wrap_get_metric(*args, **kwargs): | def _wrap_get_metric(*args, **kwargs): | ||||
assert self._updated, f"You have to call `{self.__class__.__name__}` update() function before calling " \ | |||||
assert self._updated, f"You have to call `{self.__class__.__name__}'s update() function before calling " \ | |||||
f"get_metric()." | f"get_metric()." | ||||
with self.sync(recover=True, aggregate=self.aggregate_when_get_metric): | with self.sync(recover=True, aggregate=self.aggregate_when_get_metric): | ||||
results = get_metric(*args, **kwargs) | results = get_metric(*args, **kwargs) | ||||
@@ -366,17 +366,22 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
首先按照 sample 的长度排序,然后按照 batch_size*num_batch_per_bucket 为一个桶的大小,sample 只会在这个桶内进行组合,这样 | |||||
每个 batch 中的 padding 数量会比较少 (因为桶内的数据的长度都接近)。 | |||||
首先按照 ``sample`` 的长度排序,然后按照 batch_size*num_batch_per_bucket 为一个桶的大小,``sample`` 只会在这个桶内进行组 | |||||
合,这样每个 ``batch`` 中的 ``padding`` 数量会比较少 (因为桶内的数据的长度都接近)。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | :param dataset: 实现了 __len__ 方法的数据容器。 | ||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||||
如果否则使用 len() 函数得到每个 sample 中这个 field 的长度。 | |||||
:param length: 每条数据的长度。 | |||||
* 为 ``List[int]`` 时 | |||||
应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量; | |||||
* 为 ``str`` 时 | |||||
仅当传入的 ``dataset`` 是 :class:`fastNLP.DataSet` 时,允许传入 `str` ,该 `str` 将被认为是 ``dataset`` 中的 | |||||
``field`` 。若 field 中的元素为 ``int``,则认为该值是 sample 的长度;若不为 ``int`` ,则尝试使用 ``len`` 方法 | |||||
获取该 ``field`` 中每个元素的长度。 | |||||
:param batch_size: 每个 batch 的大小 | :param batch_size: 每个 batch 的大小 | ||||
:param num_batch_per_bucket: 多少个 batch 组成一个桶,数据只会在一个桶内进行 shuffle 。 | |||||
:param shuffle: 如果为 True,将不进行 shuffle,实际上数据会以从长到短的方式输出。 | |||||
:param drop_last: 如果最后一个 batch 的 sample 数量无法凑齐 batch_size 这么多,是否需要丢掉。 | |||||
:param num_batch_per_bucket: 多少个 ``batch`` 组成一个桶,数据只会在一个桶内进行 ``shuffle`` 。 | |||||
:param shuffle: 如果为 True,将不进行 ``shuffle``,实际上数据会以从长到短的方式输出。 | |||||
:param drop_last: 如果最后一个 `batch` 的 ``sample`` 数量无法凑齐 ``batch_size`` 这么多,是否需要丢掉。 | |||||
:param seed: 设置的随机数种子 | :param seed: 设置的随机数种子 | ||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
@@ -386,10 +391,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
length = list(map(len, length)) | length = list(map(len, length)) | ||||
else: | else: | ||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||||
"the length parameter can only be List[int]" | |||||
types = set(map(type, length)) | |||||
assert isinstance(length, list) and len(types)==1 and types.pop()==int, \ | |||||
"When the dataset is not fastNLP.DataSet, the length parameter can only be List[int]" | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||||
assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \ | |||||
f"`length`({len(length)}) should be equal." | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | ||||
@@ -55,6 +55,7 @@ class ReproducibleSampler: | |||||
class RandomSampler(ReproducibleSampler): | class RandomSampler(ReproducibleSampler): | ||||
def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
随机顺序的 Sampler 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器 | :param dataset: 实现了 __len__ 方法的数据容器 | ||||
:param shuffle: 是否在每次 iterate 的时候打乱顺序。 | :param shuffle: 是否在每次 iterate 的时候打乱顺序。 | ||||
@@ -169,9 +170,8 @@ class RandomSampler(ReproducibleSampler): | |||||
def set_epoch(self, epoch: int) -> None: | def set_epoch(self, epoch: int) -> None: | ||||
self.epoch = epoch | self.epoch = epoch | ||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
def set_distributed(self, num_replicas:int, rank:int, pad:bool=True): | |||||
""" | """ | ||||
该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用; | |||||
:param num_replicas: | :param num_replicas: | ||||
:param rank: | :param rank: | ||||
@@ -215,7 +215,7 @@ class RandomSampler(ReproducibleSampler): | |||||
class SequentialSampler(RandomSampler): | class SequentialSampler(RandomSampler): | ||||
def __init__(self, dataset, **kwargs): | def __init__(self, dataset, **kwargs): | ||||
""" | """ | ||||
按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 | |||||
按照顺序读取 ``dataset`` 。在多卡情况下,间隔读取,例如,在两卡情况下,卡 0 取 ``[0,2,4,..]``, 卡1取 ``[1,3,5...]`` 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | :param dataset: 实现了 __len__ 方法的数据容器。 | ||||
:param kwargs: | :param kwargs: | ||||
@@ -285,13 +285,20 @@ class SequentialSampler(RandomSampler): | |||||
class SortedSampler(SequentialSampler): | class SortedSampler(SequentialSampler): | ||||
def __init__(self, dataset, length:Union[str, List], **kwargs): | def __init__(self, dataset, length:Union[str, List], **kwargs): | ||||
""" | """ | ||||
将 dataset 中的数据根据 length 从长到短进行迭代。在多卡情况下,由于padding 最后一个 sample 可能是最长的那个 sample。 | |||||
将 ``dataset`` 中的数据根据 ``length`` 从长到短进行迭代。在多卡情况下,由于 ``padding`` , 最后一个 ``sample`` 可能是最长 | |||||
的那个 ``sample`` 。 | |||||
:param dataset: 实现了 __len__ 方法的数据容器。 | :param dataset: 实现了 __len__ 方法的数据容器。 | ||||
:param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 | |||||
DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 | |||||
:param seed: 设置的随机数种子 | |||||
:param kwargs: fastNLP 保留使用 | |||||
:param length: 每条数据的长度。 | |||||
* 为 ``List[int]`` 时 | |||||
应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量; | |||||
* 为 ``str`` 时 | |||||
仅当传入的 ``dataset`` 是 :class:`fastNLP.DataSet` 时,允许传入 `str` ,该 `str` 将被认为是 ``dataset`` 中的 | |||||
``field`` 。若 field 中的元素为 ``int``,则认为该值是 sample 的长度;若不为 ``int`` ,则尝试使用 ``len`` 方法 | |||||
获取该 ``field`` 中每个元素的长度。 | |||||
:param seed: 设置的随机数种子。 | |||||
:param kwargs: fastNLP 保留使用。 | |||||
""" | """ | ||||
super().__init__(dataset=dataset, **kwargs) | super().__init__(dataset=dataset, **kwargs) | ||||
if isinstance(dataset, DataSet) and isinstance(length, str): | if isinstance(dataset, DataSet) and isinstance(length, str): | ||||
@@ -299,8 +306,9 @@ class SortedSampler(SequentialSampler): | |||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
length = list(map(len, length)) | length = list(map(len, length)) | ||||
else: | else: | ||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||||
"the length parameter can only be List[int]" | |||||
types = set(map(type, length)) | |||||
assert isinstance(length, list) and len(types)==1 and types.pop()==int, \ | |||||
"When the dataset is not fastNLP.DataSet, the length parameter can only be List[int]" | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | ||||