@@ -81,7 +81,8 @@ class MoreEvaluateCallback(HasMonitorCallback):
**kwargs):
super(MoreEvaluateCallback, self).__init__(watch_monitor, watch_monitor_larger_better,
must_have_monitor=False)
+ if watch_monitor is not None and evaluate_every == -1:  # discard evaluate_every
+ evaluate_every = None
if watch_monitor is None and evaluate_every is None:
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
if watch_monitor is not None and evaluate_every is not None:
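The intent here appears to be that, when a watch_monitor is supplied, the default evaluate_every=-1 is treated as unset so the monitor alone triggers evaluation. A minimal sketch of how the arguments presumably resolve (the function name and the last branch's behaviour are illustrative, not the callback's actual internals):

# Sketch: presumed interaction of watch_monitor and evaluate_every (illustrative only).
def resolve_evaluate_trigger(watch_monitor=None, evaluate_every=-1):
    if watch_monitor is not None and evaluate_every == -1:
        evaluate_every = None  # drop the default; the monitor drives evaluation
    if watch_monitor is None and evaluate_every is None:
        raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
    if watch_monitor is not None and evaluate_every is not None:
        pass  # both set explicitly; handled by the branch shown above (its body is not in this hunk)
    return evaluate_every

resolve_evaluate_trigger(watch_monitor='loss')   # -> None, monitor-driven
resolve_evaluate_trigger(evaluate_every=100)     # -> 100, step-driven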
@@ -176,8 +176,8 @@ class Collator:
self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:])))  # sort so that _0, _1 keep their order
try:
for key, padder in self.padders.items():
batch = unpack_batch.get(key)
pad_batch[key] = padder(batch)
except BaseException as e:
try:
logger.error(f"The following exception happens when try to pad the `{key}` field with padder:{padder}:")
@@ -3,7 +3,8 @@ __all__ = [
'EleDtypeUnsupportedError',
'EleDtypeDtypeConversionError',
'DtypeUnsupportedError',
- "DtypeError"
+ "DtypeError",
+ "NoProperPadderError"
]
@@ -22,6 +23,12 @@ class DtypeError(BaseException):
self.msg = msg
+ class NoProperPadderError(BaseException):
+ def __init__(self, msg, *args):
+ super(NoProperPadderError, self).__init__(msg, *args)
+ self.msg = msg
class EleDtypeUnsupportedError(DtypeError):
"""
Raised when the element type in the batch cannot be padded at all.
@@ -49,8 +49,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
f"information please set logger's level to DEBUG."
if must_pad:
raise InconsistencyError(msg)
- logger.debug(msg)
- return NullPadder()
+ raise NoProperPadderError(msg)
# Next, check whether all elements have the same shape
shape_lens = set([len(v[0]) for v in catalog.values()])
@@ -60,8 +59,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
f"information please set logger's level to DEBUG."
if must_pad:
raise InconsistencyError(msg)
- logger.debug(msg)
- return NullPadder()
+ raise NoProperPadderError(msg)
# Next, check whether all elements have the same type
try:
@@ -74,8 +72,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
f"information please set logger's level to DEBUG."
if must_pad:
raise InconsistencyError(msg)
- logger.debug(msg)
- return NullPadder()
+ raise NoProperPadderError(msg)
depth = depths.pop()
shape_len = shape_lens.pop()
@@ -131,8 +128,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
msg = "Does not support pad tensor under nested list. If you need this, please report."
if must_pad:
raise RuntimeError(msg)
- logger.debug(msg)
- return NullPadder()
+ raise NoProperPadderError(msg)
except DtypeError as e:
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
@@ -140,8 +136,9 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
if must_pad:
logger.error(msg)
raise type(e)(msg=msg)
- logger.debug(msg)
- return NullPadder()
+ except NoProperPadderError as e:
+ logger.debug(f"{e.msg}")
except BaseException as e:
raise e
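Taken together, these hunks replace the repeated logger.debug(msg); return NullPadder() fallbacks inside get_padder with a single internal signal: each branch that cannot find a suitable padder now raises NoProperPadderError, and one handler near the end of the function logs the message; presumably the function then still falls through to returning NullPadder(). A simplified sketch of the pattern (consistent_types and build_padder are invented stand-ins; InconsistencyError, NullPadder and logger are the names used in the diff):

# Simplified sketch of get_padder's control flow after this change.
def get_padder_sketch(batch_field, must_pad=False):
    try:
        if not consistent_types(batch_field):       # hypothetical check standing in for the real ones
            msg = "Field has inconsistent element types."
            if must_pad:
                raise InconsistencyError(msg)
            raise NoProperPadderError(msg)          # replaces the scattered debug-and-return blocks
        return build_padder(batch_field)            # hypothetical: construct the real padder
    except NoProperPadderError as e:
        logger.debug(f"{e.msg}")                    # the one place the fallback is logged
    except BaseException as e:
        raise e
    return NullPadder()                             # presumed common fallback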
@@ -188,7 +188,7 @@ def fill_tensor(batch_field, padded_batch, dtype):
padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype))
elif padded_batch.ndim == 4:
try:  # should be an image, so this should work directly.
- padded_batch = np.array(batch_field)
+ padded_batch = jittor.Var(batch_field)
except:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
@@ -175,7 +175,7 @@ def fill_tensor(batch_field, padded_batch, dtype):
padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype)
elif padded_batch.ndim == 4:
try:  # should be an image, so this should work directly.
- padded_batch = np.array(batch_field)
+ padded_batch = torch.tensor(batch_field)
except:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
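In both backends the 4-d fast path previously rebound padded_batch to a plain numpy array, so the caller could receive an ndarray instead of a framework tensor; the fix constructs the result with the backend's own constructor. A rough illustration for the torch variant (values made up; the jittor case is analogous with jittor.Var):

# Rough before/after illustration for the 4-d (image-like) fast path, torch backend.
import numpy as np
import torch

batch_field = [[[[0, 1], [2, 3]]], [[[4, 5], [6, 7]]]]   # already rectangular 4-d data
old_result = np.array(batch_field)        # before: silently becomes a numpy ndarray
new_result = torch.tensor(batch_field)    # after: stays a torch.Tensor, as the caller expects
assert isinstance(new_result, torch.Tensor) and new_result.shape == (2, 1, 2, 2)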
@@ -203,7 +203,7 @@ def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = Fa
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
collate_fn: Union[None, str, Callable] = "auto",
- non_train_batch_size: int = 16) \
+ non_train_batch_size: int = None) \
-> Union[Sequence[JittorDataLoader], Dict[str, JittorDataLoader], JittorDataLoader]:
"""
``prepare_jittor_dataloader`` converts one or more input datasets into ``JittorDataloader`` objects at once; see :class: `~fastNLP.core.dataloaders.JittorDataLoader` for details.
@@ -254,7 +254,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
num_workers: int = 0, use_buffer_reader: bool = True,
use_shared_memory: bool = True, timeout: int = 0,
worker_init_fn: Callable = None, persistent_workers=False,
- non_train_batch_size: int = 16) \
+ non_train_batch_size: int = None) \
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]:
"""
``prepare_paddle_dataloader`` converts one or more input datasets into ``PaddleDataloader`` objects at once; see :class: `~fastNLP.core.dataloaders.PaddleDataLoader` for details.
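The default of non_train_batch_size changes from 16 to None in the jittor, paddle and (below) torch prepare helpers. Presumably None now means "fall back to batch_size for the non-train splits", so callers who only set batch_size no longer get a silent 16 on their evaluation dataloaders. A hedged sketch of that resolution rule (not the actual implementation):

# Sketch: presumed handling of the new non_train_batch_size default.
def resolve_non_train_batch_size(batch_size: int = 16, non_train_batch_size: int = None) -> int:
    # None (the new default) inherits the training batch size;
    # an explicit value still overrides it for the non-train dataloaders.
    return batch_size if non_train_batch_size is None else non_train_batch_size

assert resolve_non_train_batch_size(batch_size=32) == 32
assert resolve_non_train_batch_size(batch_size=32, non_train_batch_size=8) == 8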
@@ -6,7 +6,6 @@ from typing import Union, Callable
import os
import sys
- from ..samplers import RandomBatchSampler, RandomSampler
from .torch_dataloader import prepare_torch_dataloader
from .paddle_dataloader import prepare_paddle_dataloader
from .jittor_dataloader import prepare_jittor_dataloader
@@ -16,7 +15,7 @@ from ..log import logger
def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
- seed: int = 0, backend: str = 'auto'):
+ backend: str = 'auto'):
"""
Automatically creates a suitable ``DataLoader`` object. For example, if the current environment is detected to be ``torch``, a ``TorchDataLoader`` is returned; if it is ``paddle``,
a ``PaddleDataLoader`` is returned. If more parameters need to be customized, use the corresponding ``prepare`` function directly, e.g.
@@ -43,7 +42,6 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro
* When it is ``None``
the default ``collate_fn`` of each framework's DataLoader is used.
:param num_workers: how many worker processes are used to fetch the data.
- :param seed: the random seed to use.
:param backend: currently supports the four options ``["auto", "torch", "paddle", "jittor"]``.
* When it is ``auto``
@@ -61,18 +59,14 @@ def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, dro
if backend == 'auto':
backend = _get_backend()
if backend == 'torch':
- batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
- drop_last=drop_last, seed=seed)
- return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
- num_workers=num_workers, shuffle=False, sampler=None)
+ return prepare_torch_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
+ num_workers=num_workers, shuffle=shuffle, sampler=None,
+ batch_size=batch_size)
elif backend == 'paddle':
- batch_sampler = RandomBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle,
- drop_last=drop_last, seed=seed)
- return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=batch_sampler, collate_fn=collate_fn,
- num_workers=num_workers)
+ return prepare_paddle_dataloader(ds_or_db=dataset, batch_sampler=None, collate_fn=collate_fn,
+ num_workers=num_workers, batch_size=batch_size, shuffle=shuffle)
elif backend == 'jittor':
- sampler = RandomSampler(dataset=dataset, shuffle=shuffle, seed=seed)
- prepare_jittor_dataloader(ds_or_db=dataset, sampler=sampler, collate_fn=collate_fn,
+ prepare_jittor_dataloader(ds_or_db=dataset, sampler=None, collate_fn=collate_fn,
num_workers=num_workers, batch_size=batch_size, shuffle=shuffle,
drop_last=drop_last)
else:
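With seed and the RandomBatchSampler/RandomSampler plumbing gone, prepare_dataloader no longer builds its own reproducible samplers; it simply forwards batch_size/shuffle/drop_last and lets the framework-specific helpers (and ultimately each framework's DataLoader) handle shuffling. A condensed sketch of the resulting dispatch (arguments abbreviated; the final raise is an assumption since the else branch is outside the hunk, and note that the jittor branch in the hunk above still calls prepare_jittor_dataloader without returning its result):

# Condensed sketch of the simplified dispatch in prepare_dataloader.
def prepare_dataloader_sketch(dataset, batch_size=16, shuffle=False, drop_last=False,
                              collate_fn='auto', num_workers=0, backend='auto'):
    if backend == 'auto':
        backend = _get_backend()                    # detect torch / paddle / jittor
    if backend == 'torch':
        return prepare_torch_dataloader(ds_or_db=dataset, batch_size=batch_size, shuffle=shuffle,
                                        collate_fn=collate_fn, num_workers=num_workers)
    elif backend == 'paddle':
        return prepare_paddle_dataloader(ds_or_db=dataset, batch_size=batch_size, shuffle=shuffle,
                                         collate_fn=collate_fn, num_workers=num_workers)
    elif backend == 'jittor':
        return prepare_jittor_dataloader(ds_or_db=dataset, batch_size=batch_size, shuffle=shuffle,
                                         drop_last=drop_last, collate_fn=collate_fn,
                                         num_workers=num_workers)
    raise ValueError(f"Unsupported backend: {backend}")  # assumption: unknown backends are rejected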
@@ -222,7 +222,7 @@ def prepare_torch_dataloader(ds_or_db,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
persistent_workers: bool = False,
non_train_sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
- non_train_batch_size: int = 16) \
+ non_train_batch_size: int = None) \
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]:
"""
``prepare_torch_dataloader`` converts one or more input datasets into ``TorchDataloader`` objects at once; see :class: `~fastNLP.core.dataloaders.TorchDataLoader` for details.
@@ -254,13 +254,13 @@ def prepare_torch_dataloader(ds_or_db,
:param non_train_batch_size: batch size of the ``TorchDataLoader`` for non-'train' datasets; defaults to ``16`` and only takes effect when batch_sampler is None.
:param shuffle: whether to shuffle the dataset; defaults to ``False``.
:param sampler: an instantiated object implementing __len__() and __iter__(), whose __iter__() method returns one dataset index at a time;
defaults to None; when it is not None, the shuffle argument has no effect.
:param non_train_sampler: sampler for the non-'train' datasets; an instantiated object implementing __len__() and __iter__(), whose __iter__() method returns one dataset index at a time;
defaults to None; when it is not None, the shuffle argument has no effect.
:param batch_sampler: an instantiated object implementing __len__() and __iter__(), whose __iter__() method returns a List of dataset indices each time;
defaults to None; when it is not None, the batch_size, sampler and shuffle arguments all have no effect.
:param num_workers: when ``num_workers > 0``, ``TorchDataLoader`` starts num_workers worker subprocesses to process the data, which can speed up
data processing but also consumes a lot of memory; when ``num_workers=0`` no subprocess is started. Defaults to ``0``.
:param collate_fn: a Callable used to pack a batch of data fetched from the dataset; its value should be one of ``[None, "auto", Callable]``.
* When collate_fn is ``None``, note that the dataset passed in must not be of type :class:`~fastNLP.core.dataset.DataSet`; when collate_fn is ``None``,
@@ -273,7 +273,7 @@ def prepare_torch_dataloader(ds_or_db,
:param pin_memory: if ``True``, ``TorchDataLoader`` copies the data tensors into CUDA pinned memory before returning them.
:param drop_last: when ``drop_last=True``, ``TorchDataLoader`` drops the last batch if its length is smaller than ``batch_size``;
when ``drop_last=False``, that batch is returned as well. Defaults to ``False``.
:param timeout: timeout for fetching data from the worker subprocesses' output queue.
:param worker_init_fn: init function; if not None, it is called during the initialization of each worker subprocess.
:param multiprocessing_context: the multiprocessing context.
@@ -46,7 +46,6 @@ class DummyFRichProgress:
class FRichProgress(Progress, metaclass=Singleton):
def new_progess(self, *columns: Union[str, ProgressColumn],
- console: Optional[Console] = None,
# auto_refresh is turned off here to avoid starting a separate thread, and also to avoid constant refreshing when in pdb
auto_refresh: bool = False,
refresh_per_second: float = 10,
@@ -81,7 +80,7 @@ class FRichProgress(Progress, metaclass=Singleton):
self.expand = expand
self.live = Live(
- console=console or get_console(),
+ console=get_console(),
auto_refresh=auto_refresh,
refresh_per_second=refresh_per_second,
transient=transient,
@@ -92,6 +91,12 @@ class FRichProgress(Progress, metaclass=Singleton):
self.get_time = get_time or self.console.get_time
self.print = self.console.print
self.log = self.console.log
+ self.auto_refresh = auto_refresh
+ self.transient = transient
+ self.redirect_stdout = redirect_stdout
+ self.redirect_stderr = redirect_stderr
+ self.refresh_per_second = refresh_per_second
+ self._need_renew_live = False
return self
@@ -125,7 +130,19 @@ class FRichProgress(Progress, metaclass=Singleton):
from .tqdm_progress import f_tqdm_progress
assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop."
- if self.live._started is False:
+ # if the Live needs to be replaced, it should be because it was swapped out during destroy
+ if self._need_renew_live:
+ self.live = Live(
+ console=get_console(),
+ auto_refresh=self.auto_refresh,
+ refresh_per_second=self.refresh_per_second,
+ transient=self.transient,
+ redirect_stdout=self.redirect_stdout,
+ redirect_stderr=self.redirect_stderr,
+ get_renderable=self.get_renderable,
+ )
+ self._need_renew_live = False
+ if not self.live.is_started:
self.start()
post_desc = fields.pop('post_desc', '')
return super().add_task(description=description,
@@ -155,6 +172,8 @@ class FRichProgress(Progress, metaclass=Singleton):
setattr(self.live.console, 'line', lambda *args,**kwargs:...)
self.live.stop()
setattr(self.live.console, 'line', old_line)
+ # in jupyter the Live object needs to be replaced, otherwise nothing is printed afterwards
+ self._need_renew_live = True if is_notebook() else False
def start(self) -> None:
super().start()
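So the progress-bar fix stores the Live constructor arguments in new_progess, flags the singleton with _need_renew_live when it is stopped inside a notebook, and rebuilds the Live object the next time a task is added. A stripped-down sketch of that lifecycle (ProgressSketch is an invented stand-in for FRichProgress; the rich Live/get_console calls are the same ones used in the hunks):

# Stripped-down sketch of the "renew Live in notebooks" lifecycle.
from rich import get_console
from rich.live import Live

class ProgressSketch:
    def __init__(self, auto_refresh=False, transient=True):
        self.auto_refresh = auto_refresh
        self.transient = transient
        self._need_renew_live = False
        self.live = Live(console=get_console(), auto_refresh=auto_refresh, transient=transient)

    def stop(self, in_notebook: bool):
        self.live.stop()
        # a Live stopped in jupyter will not render again, so remember to rebuild it
        self._need_renew_live = in_notebook

    def add_task(self):
        if self._need_renew_live:
            self.live = Live(console=get_console(), auto_refresh=self.auto_refresh,
                             transient=self.transient)
            self._need_renew_live = False
        if not self.live.is_started:
            self.live.start()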
@@ -65,7 +65,7 @@ class SequenceGeneratorModel(nn.Module):
if tgt_seq_len is not None:
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
- loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
+ loss = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:])
return {'loss': loss}
def evaluate_step(self, src_tokens, src_seq_len=None):
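The same one-token shift is applied to Seq2SeqModel below: the logits at position t are now scored against the gold token at position t+1, which is the usual teacher-forcing alignment when the decoder input includes the target's start token. A shape-level sketch of the two formulations (random tensors for illustration only; the -100 padding mask from the hunk is omitted):

# Shape-level sketch of the shifted cross-entropy.
import torch
import torch.nn.functional as F

batch, tgt_len, vocab = 2, 5, 7
pred = torch.randn(batch, tgt_len, vocab)                # decoder logits for positions 0..tgt_len-1
tgt_tokens = torch.randint(0, vocab, (batch, tgt_len))   # gold tokens; position 0 is the start token

loss_old = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)                 # position t scored vs token t
loss_new = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:])  # position t scored vs token t+1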
@@ -59,7 +59,7 @@ class Seq2SeqModel(nn.Module):
if tgt_seq_len is not None:
mask = seq_len_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1))
tgt_tokens = tgt_tokens.masked_fill(mask.eq(0), -100)
- loss = F.cross_entropy(pred.transpose(1, 2), tgt_tokens)
+ loss = F.cross_entropy(pred[:, :-1].transpose(1, 2), tgt_tokens[:, 1:])
return {'loss': loss}
def prepare_state(self, src_tokens, src_seq_len=None):
@@ -368,13 +368,13 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
_tokens = _tokens.view(batch_size, num_beams * (num_beams + 1))
next_tokens = _tokens.gather(dim=1, index=ids)  # (batch_size, 2*num_beams)
- from_which_beam = ids // (num_beams + 1)  # (batch_size, 2*num_beams)
+ from_which_beam = torch.floor(ids / (num_beams + 1)).long()  # (batch_size, 2*num_beams)
else:
scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
_scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
- from_which_beam = ids // vocab_size  # (batch_size, 2*num_beams)
+ from_which_beam = torch.floor(ids / vocab_size).long()  # (batch_size, 2*num_beams)
next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)
# Next, assemble the results for the next step.
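Replacing // with torch.floor(... / ...).long() presumably avoids PyTorch's deprecation warning about floor division on tensors; for the non-negative index tensors involved here the result is unchanged. A quick equivalence check (torch.div with rounding_mode='floor' shown as the other common, warning-free spelling):

# Quick check: the new expression matches integer floor division for non-negative ids.
import torch

vocab_size = 4
ids = torch.tensor([[0, 3, 7, 12], [5, 9, 11, 23]])

from_which_beam = torch.floor(ids / vocab_size).long()        # form used in the hunk
alt = torch.div(ids, vocab_size, rounding_mode='floor')       # warning-free integer spelling
assert torch.equal(from_which_beam, alt)
next_tokens = ids % vocab_size                                 # unchanged in the hunk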
@@ -318,6 +318,13 @@ class TestCollator:
with pytest.raises(KeyError):
collator.set_pad((1, 2))
+ @pytest.mark.torch
+ def test_torch_4d(self):
+ collator = Collator(backend='torch')
+ data = [{'x': [[[0,1], [2,3]]]}, {'x': [[[0,1]]]}]
+ output = collator(data)
+ assert output['x'].size() == (2, 1, 2, 2)
@pytest.mark.torch
def test_torch_dl():
@@ -2,6 +2,7 @@ import pytest
from fastNLP import prepare_dataloader
from fastNLP import DataSet
+ from fastNLP.io import DataBundle
@pytest.mark.torch
@@ -10,4 +11,18 @@ def test_torch():
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
dl = prepare_dataloader(ds, batch_size=2, shuffle=True)
for batch in dl:
assert isinstance(batch['x'], torch.Tensor)
+ @pytest.mark.torch
+ def test_torch_data_bundle():
+ import torch
+ ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
+ dl = DataBundle()
+ dl.set_dataset(dataset=ds, name='train')
+ dl.set_dataset(dataset=ds, name='test')
+ dls = prepare_dataloader(dl, batch_size=2, shuffle=True)
+ for dl in dls.values():
+ for batch in dl:
+ assert isinstance(batch['x'], torch.Tensor)
+ assert batch['x'].size(0) == 2