@@ -21,30 +21,24 @@ from fastNLP.core.dataset import DataSet as FDataSet | |||||
class _JittorDataset(Dataset): | class _JittorDataset(Dataset): | ||||
""" | """ | ||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset使用jittor的dataset | |||||
对用户传的dataset进行封装,以便JittorDataLoader能够支持使用自定义的dataset | |||||
""" | """ | ||||
def __init__(self, dataset) -> None: | def __init__(self, dataset) -> None: | ||||
super(_JittorDataset, self).__init__() | super(_JittorDataset, self).__init__() | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.total_len = len(dataset) | |||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
return (item, self.dataset[item]) | return (item, self.dataset[item]) | ||||
def __len__(self) -> int: | |||||
return len(self.dataset) | |||||
# def __getattr__(self, item): | |||||
# # jittor的Dataset没有的方法而用户的dataset存在且实现了getattribute方法,此时用户可以调用 | |||||
# try: | |||||
# self.dataset.__getattribute__(item) | |||||
# except Exception as e: | |||||
# raise e | |||||
class JittorDataLoader: | class JittorDataLoader: | ||||
""" | """ | ||||
提供给使用jittor框架的DataLoader函数,提供了auto_collate的功能, 支持实现了__getitem__和__len__的dataset | |||||
提供给使用jittor框架的DataLoader函数,其能够自动检测数据的类型并判断是否能够pad,若能会自动pad数据,默认pad_val=0; | |||||
用户可以调用set_pad方法来更改pad_val的值,也可以自定义针对某个field的callate_fn传入到set_field;若用户不想自动pad某个field, | |||||
则可以调用set_ignore来忽略对某个field的检测和pad。值得注意的是JittorDataLoader输入dataset只要是实现了__getitem__和__len__方法即可。 | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True, | def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True, | ||||
@@ -53,23 +47,36 @@ class JittorDataLoader: | |||||
collate_fn: Union[None, str, Callable] = "auto") -> None: | collate_fn: Union[None, str, Callable] = "auto") -> None: | ||||
""" | """ | ||||
:param dataset: 实现__getitem__和__len__的dataset | |||||
:param dataset: 实现``__getitem__``和``__len__``的dataset | |||||
:param batch_size: 批次大小 | :param batch_size: 批次大小 | ||||
:param shuffle: 是否打乱数据集 | :param shuffle: 是否打乱数据集 | ||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||||
:param num_workers: 进程的数量,当num_workers=0时不开启多进程 | |||||
:param buffer_size: | |||||
:param drop_last: 是否去掉最后一个不符合``batch_size``的数据 | |||||
:param num_workers: 进程的数量,当``num_workers=0``时不开启多进程 | |||||
:param buffer_size: 每个进程占用的内存空间,默认为512M。主要是配合num_workers使用,用户可以自定义每个进程的内存大小。 | |||||
:param stop_grad: | :param stop_grad: | ||||
:param keep_numpy_array: | |||||
:param endless: | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数 | |||||
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 | |||||
:param keep_numpy_array: 返回的数据是``np.array`类`型而不是``jittor.array``类型,默认为``False`` | |||||
:param endless: 是否让``JittorDataLoader``无限返回数据,也就是将dataset循环使用使得返回数据是没有限制的。默认为``False``. | |||||
:param collate_fn: 用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. | |||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; | |||||
第二点注意的是此时``JittorDataLoader``会调用默认的`callate_batch`函数对sampler到的数据进行简单打包,组成一个batch返回。` | |||||
* ``callate_fn="auto"``时,``JittorDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, | |||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, | |||||
可以调用``set_ignore``方法忽略某个field。 | |||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``JittorDataLoader``会调用传进来的callable函数对 | |||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 | |||||
""" | """ | ||||
# TODO 验证支持replacesampler (以后完成) | |||||
# TODO 验证支持replacesampler (以后完成) 增加Sampler | |||||
# 将内部dataset批次设置为1 | |||||
if isinstance(dataset, Dataset): | |||||
dataset.set_attrs(batch_size=1) | |||||
# FastNLP Datset, collate_fn not None | # FastNLP Datset, collate_fn not None | ||||
if isinstance(dataset, FDataSet) and collate_fn is None: | if isinstance(dataset, FDataSet) and collate_fn is None: | ||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | ||||
# 将所有dataset转为jittor类型的dataset | |||||
if not isinstance(dataset, _JittorDataset): | if not isinstance(dataset, _JittorDataset): | ||||
self.dataset = _JittorDataset(dataset) | self.dataset = _JittorDataset(dataset) | ||||
@@ -85,17 +92,13 @@ class JittorDataLoader: | |||||
else: | else: | ||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
elif isinstance(collate_fn, Callable): | elif isinstance(collate_fn, Callable): | ||||
if collate_fn is not collate_batch: | |||||
self.collate_fn = collate_fn | |||||
self.collate_fn = collate_fn | |||||
else: | else: | ||||
self.collate_fn = collate_batch | self.collate_fn = collate_batch | ||||
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | ||||
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, | num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, | ||||
keep_numpy_array=keep_numpy_array, endless=endless) | keep_numpy_array=keep_numpy_array, endless=endless) | ||||
# 将内部dataset批次设置为1 | |||||
if isinstance(self.dataset.dataset, Dataset): | |||||
self.dataset.dataset.set_attrs(batch_size=1) | |||||
self.cur_batch_indices = None | self.cur_batch_indices = None | ||||
@@ -108,12 +111,10 @@ class JittorDataLoader: | |||||
yield data | yield data | ||||
def __len__(self): | def __len__(self): | ||||
if self.dataset.drop_last: | |||||
return len(self.dataset) // self.dataset.batch_size | |||||
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 | |||||
return len(self.dataset) | |||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | ||||
pad_fn: Callable = None) -> "JittorDataLoader": | |||||
pad_fn: Callable = None) -> Collator: | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
@@ -132,14 +133,27 @@ class JittorDataLoader: | |||||
形式,输出将被直接作为结果输出。 | 形式,输出将被直接作为结果输出。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self.collate_fn, Collator): | |||||
self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, | |||||
backend=backend) | |||||
return self | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return collator | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | ||||
def set_ignore(self, *field_names) -> "JittorDataLoader": | |||||
def _get_collator(self): | |||||
""" | |||||
如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None | |||||
:return: | |||||
""" | |||||
collator = None | |||||
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): | |||||
collator = self.collate_fn.__wrapped__ | |||||
elif isinstance(self.collate_fn, Collator): | |||||
collator = self.collate_fn | |||||
return collator | |||||
def set_ignore(self, *field_names) -> Collator: | |||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Example:: | Example:: | ||||
@@ -151,9 +165,10 @@ class JittorDataLoader: | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self.collate_fn, Collator): | |||||
self.collate_fn.set_ignore(*field_names) | |||||
return self | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_ignore(*field_names) | |||||
return collator | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | ||||
@@ -23,7 +23,6 @@ class _MixDataset: | |||||
""" | """ | ||||
def __init__(self, datasets: list = None) -> None: | def __init__(self, datasets: list = None) -> None: | ||||
""" | """ | ||||
:param datasets: 数据集的列表 | :param datasets: 数据集的列表 | ||||
""" | """ | ||||
self.datasets = datasets | self.datasets = datasets | ||||
@@ -36,8 +35,9 @@ class _MixDataset: | |||||
def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | ||||
""" | """ | ||||
根据index索引获取数据 | |||||
:param idx: | |||||
:param idx: 整数类型的index或者列表 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(idx, int): | if isinstance(idx, int): | ||||
@@ -1,3 +1,31 @@ | |||||
""" | |||||
``PaddleDataLoader``是专门提供给``paddle``框架的``DataLoader``,其集成了``fastNLP``的``Collator``并对``paddle``的``DataLoader``进行了 | |||||
封装,使得其具备以下功能:1.``PaddleDataLoader``支持输入的dataset是无框架的,只要实现了``__getitem__``和``__len__``方法即可,当不使用``fastNLP``的 | |||||
``DataSet``时候也能够自动检测数据的类型并进行padding,只需要将``collate_fn="auto"``即可,例如:: | |||||
from fastNLP import PaddleDataLoader | |||||
class MyDataset: | |||||
def __init(self, data_lens=100): | |||||
self.data_lens = 100 | |||||
def __getitem__(self, item): | |||||
if item % 2 == 0: | |||||
return {'x':[101, 256, 453], 'y': 0} | |||||
else: | |||||
return {'x': [101, 200], 'y': 1} | |||||
def __len__(self): | |||||
return self.data_lens | |||||
dataset = MyDataset() | |||||
paddle_dl = PaddleDataLoader(dataset, collate_fn="auto") | |||||
for batch in paddle_dl: | |||||
... | |||||
2.当设置``collate_fn="auto"``时,``PaddleDataLoader``会调用fastNLP的Collator对数据进行自动pad处理,此时可以调用``set_pad``和``set_ignore``方法 | |||||
来设置field的pad_val或者忽略某个field的pad操作。 | |||||
.. note:: | |||||
当传入的dataset为fastNLP的DataSet时,collate_fn不能为None。默认可以是"auto"或者自定义callable函数。 | |||||
""" | |||||
__all__ = [ | __all__ = [ | ||||
'PaddleDataLoader', | 'PaddleDataLoader', | ||||
'prepare_paddle_dataloader' | 'prepare_paddle_dataloader' | ||||
@@ -23,7 +51,7 @@ from fastNLP.core.samplers import ReproducibleBatchSampler, RandomBatchSampler | |||||
class _PaddleDataset(Dataset): | class _PaddleDataset(Dataset): | ||||
""" | """ | ||||
对用户传的dataset进行封装,以便Fdataloader能够支持使用自定义的dataset使用paddle的dataloader | |||||
对用户传的dataset进行封装,以便PaddleDataLoader能够支持使用自定义的dataset | |||||
""" | """ | ||||
def __init__(self, dataset) -> None: | def __init__(self, dataset) -> None: | ||||
@@ -44,6 +72,10 @@ class _PaddleDataset(Dataset): | |||||
class PaddleDataLoader(DataLoader): | class PaddleDataLoader(DataLoader): | ||||
""" | |||||
提供给``paddle``框架使用的``DataLoader``函数,``PaddleDataLoader``提供了``Collator``的功能,用户可以通过设置``collate_fn="auto"``来 | |||||
使用,并可以配套使用``set_pad``和``set_ignore``方法设置p``ad_val``和忽略某个field的pad操作。 | |||||
""" | |||||
def __init__(self, dataset, feed_list=None, places=None, | def __init__(self, dataset, feed_list=None, places=None, | ||||
return_list: bool = True, batch_sampler=None, | return_list: bool = True, batch_sampler=None, | ||||
@@ -52,6 +84,51 @@ class PaddleDataLoader(DataLoader): | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
worker_init_fn: Callable = None, persistent_workers=False) -> None: | worker_init_fn: Callable = None, persistent_workers=False) -> None: | ||||
""" | |||||
:param dataset: 实现了__getitem__和__len__的数据容器 | |||||
:param feed_list: (list(Tensor)|tuple(Tensor)): feed Tensor list. | |||||
The Tensors should be created by :code:`paddle.static.data()`. | |||||
:attr:`feed_list` must be set if :attr:`return_list` is | |||||
False. Default None. | |||||
:param places: (list(Place)|tuple(Place)|list(str)|optional): a list of Place, | |||||
to put data onto, :attr:`places` can be None, if | |||||
:attr:`places` is None, default place(CPUPlace or CUDAPlace(0)) | |||||
will be used. Default None. If ``places`` is list of string, | |||||
the string in the list can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, | |||||
where ``x`` is the index of the GPUs. | |||||
:param return_list: whether the return value on each device is | |||||
presented as a list. If :attr:`return_list=False`, the return | |||||
value on each device would be a dict of str -> Tensor, where | |||||
the key of the dict is the name of each fed Tensors. If | |||||
:attr:`return_list=True`, the return value on each device would | |||||
be a list(Tensor). :attr:`return_list` can only be True | |||||
in dynamic graph mode. Default True. | |||||
:param batch_sampler: 实现了``__iter__``和``__len__``方法的实例化对象,它的功能是根据dataset生成数据indices并组成一个batch数据。 | |||||
:param batch_size: dataloader每次获得数据的批次大小 | |||||
:param shuffle: 是否将数据打乱,若``shuffle=True``则会将dataset打乱;若否则什么也不做。 | |||||
:param drop_last: 当``drop_last=True``时,``PaddleDataLoader``会扔掉最后一个不能组成``batch_size``大小的batch数据; | |||||
若``drop_last=False``, 则什么也不做。 | |||||
:param collate_fn:用来对从dataset取到的数据进行打包处理成batch的callable函数,其值应该为一下三个:``[None, "auto", callable]``. | |||||
* ``callate_fn=None``时,第一点值得注意的是此时传进来的datset不能为``fastNLP``的dataset,采用fastNLP的dataset时,``collate_fn``不能为``None``; | |||||
第二点注意的是此时``PaddleDataLoader``会调用默认的`default_collate_fn`函数对sampler到的数据进行简单打包,组成一个batch返回。` | |||||
* ``callate_fn="auto"``时,``PaddleDataLoader``会自动调用``fastNLP``自带的``Collator``,其会自动检测dataset的每个``field``, | |||||
并判断是否能够pad处理,若能则会自动进行pad操作,默认``pad_val=0``。若想要更改其值,可调用``set_pad``方法;若不想自动pad某个field, | |||||
可以调用``set_ignore``方法忽略某个field。 | |||||
* ``callate_fn=callable``时,callable函数是用户自定义的callate_fn函数,此时``PaddleDataLoader``会调用传进来的callable函数对 | |||||
数据进行打包处理并返回。值得注意的是用户自定义的callable函数的输入为batch,batch为list类型数据,其中batch的每一条数据都为dataset的一条数据。 | |||||
:param num_workers: 开启多进程的数量,当``num_workers=0``时不开启多进程 | |||||
:param use_buffer_reader: 是否开启buffer_reader。如果``use_buffer_reader=True``,那么``PaddleDataLoader``将会异步的预取下一个batch的 | |||||
数据,因此它将会加快数据传输的速度,但是将会占用更多的内存或者显存。默认值是``True``。如果``use_buffer_reader=False``,那么什么也不错 | |||||
:param use_shared_memory: 是否使用共享内存。当``use_shared_memory=True``时,将采用共享内存来加快将数据放进进程队列。建议仅当计算机上的 | |||||
共享空间足够大时。(例如Linux上的/dev/shm/空间足够大)共享内存仅在多进程模式(num_workers>0)下生效。 | |||||
:param timeout: 从子进程的输出队列获取数据的超时值 | |||||
:param worker_init_fn: init函数,如果不设置为None,则将会在每个子进程初始化时调用该函数。 | |||||
:param persistent_workers: | |||||
""" | |||||
# FastNLP Datset, collate_fn not None | # FastNLP Datset, collate_fn not None | ||||
if isinstance(dataset, FDataSet) and collate_fn is None: | if isinstance(dataset, FDataSet) and collate_fn is None: | ||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | ||||
@@ -173,7 +250,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
return_list: bool = True, | return_list: bool = True, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
train_batch_size: int = 1, shuffle: bool = False, | train_batch_size: int = 1, shuffle: bool = False, | ||||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = None, | |||||
drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto', | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
worker_init_fn: Callable = None, persistent_workers=False, | worker_init_fn: Callable = None, persistent_workers=False, | ||||
@@ -177,12 +177,13 @@ class TorchDataLoader(DataLoader): | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], | |||||
batch_size: int = 1, | |||||
shuffle: bool = True, | |||||
def prepare_torch_dataloader(ds_or_db, | |||||
train_batch_size: int = 16, | |||||
shuffle: bool = False, | |||||
sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto', | |||||
num_workers: int = 0, collate_fn: Union[Callable, str, None] = 'auto', | |||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | ||||
@@ -214,7 +215,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping | |||||
from fastNLP.io import DataBundle | from fastNLP.io import DataBundle | ||||
if isinstance(ds_or_db, DataSet): | if isinstance(ds_or_db, DataSet): | ||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=train_batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -227,7 +228,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -236,7 +237,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping | |||||
persistent_workers=persistent_workers, | persistent_workers=persistent_workers, | ||||
) | ) | ||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, | shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, | ||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
@@ -250,8 +251,11 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping | |||||
elif isinstance(ds_or_db, Sequence): | elif isinstance(ds_or_db, Sequence): | ||||
dl_bundle = [] | dl_bundle = [] | ||||
for idx, ds in enumerate(ds_or_db): | for idx, ds in enumerate(ds_or_db): | ||||
if idx > 0: | |||||
train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size | |||||
sampler = non_train_sampler if non_train_sampler else sampler | |||||
dl_bundle.append( | dl_bundle.append( | ||||
TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
TorchDataLoader(dataset=ds, batch_size=train_batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -265,7 +269,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.items(): | for name, ds in ds_or_db.items(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size, | |||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -274,8 +278,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping | |||||
persistent_workers=persistent_workers, | persistent_workers=persistent_workers, | ||||
) | ) | ||||
else: | else: | ||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler, | |||||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else train_batch_size, | |||||
shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler, | |||||
batch_sampler=batch_sampler, | batch_sampler=batch_sampler, | ||||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | ||||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | ||||
@@ -32,7 +32,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | ||||
if driver not in {"torch", "fairscale"}: | if driver not in {"torch", "fairscale"}: | ||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale'].") | |||||
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") | |||||
_could_use_device_num = torch.cuda.device_count() | _could_use_device_num = torch.cuda.device_count() | ||||
if isinstance(device, str): | if isinstance(device, str): | ||||
@@ -43,6 +43,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi | |||||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | ||||
device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] | device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] | ||||
elif device >= _could_use_device_num: | elif device >= _could_use_device_num: | ||||
print(device, _could_use_device_num) | |||||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | raise ValueError("The gpu device that parameter `device` specifies is not existed.") | ||||
else: | else: | ||||
device = torch.device(f"cuda:{device}") | device = torch.device(f"cuda:{device}") | ||||
@@ -11,6 +11,7 @@ __all__ = [ | |||||
from collections import Counter | from collections import Counter | ||||
from functools import partial | from functools import partial | ||||
from functools import wraps | from functools import wraps | ||||
from typing import List, Callable, Union | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.utils.utils import Option | from fastNLP.core.utils.utils import Option | ||||
@@ -20,6 +21,9 @@ import io | |||||
class VocabularyOption(Option): | class VocabularyOption(Option): | ||||
""" | |||||
""" | |||||
def __init__(self, | def __init__(self, | ||||
max_size=None, | max_size=None, | ||||
min_freq=None, | min_freq=None, | ||||
@@ -33,8 +37,11 @@ class VocabularyOption(Option): | |||||
) | ) | ||||
def _check_build_vocab(func): | |||||
r"""A decorator to make sure the indexing is built before used. | |||||
def _check_build_vocab(func: Callable): | |||||
r""" | |||||
A decorator to make sure the indexing is built before used. | |||||
:param func: 传入的callable函数 | |||||
""" | """ | ||||
@@ -48,7 +55,10 @@ def _check_build_vocab(func): | |||||
def _check_build_status(func): | def _check_build_status(func): | ||||
r"""A decorator to check whether the vocabulary updates after the last build. | |||||
r""" | |||||
A decorator to check whether the vocabulary updates after the last build. | |||||
:param func: 用户传入要修饰的callable函数 | |||||
""" | """ | ||||
@@ -69,27 +79,30 @@ class Vocabulary(object): | |||||
r""" | r""" | ||||
用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | 用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | ||||
from fastNLP.core import Vocabulary | |||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
word_list = "this is a word list".split() | word_list = "this is a word list".split() | ||||
# vocab更新自己的字典,输入为list列表 | |||||
vocab.update(word_list) | vocab.update(word_list) | ||||
vocab["word"] # str to int | vocab["word"] # str to int | ||||
vocab.to_word(5) # int to str | vocab.to_word(5) # int to str | ||||
""" | """ | ||||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||||
def __init__(self, max_size:int=None, min_freq:int=None, padding:str='<pad>', unknown:str='<unk>'): | |||||
r""" | r""" | ||||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||||
:param max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||||
若为 ``None`` , 则不限制大小. Default: ``None`` | 若为 ``None`` , 则不限制大小. Default: ``None`` | ||||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||||
:param min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | 若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | ||||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||||
:param padding: padding的字符. 如果设置为 ``None`` , | |||||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | 则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | ||||
Default: '<pad>' | Default: '<pad>' | ||||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||||
:param unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | 如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | ||||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | 为 ``None`` 的情况多在为label建立Vocabulary的情况. | ||||
Default: '<unk>' | Default: '<unk>' | ||||
""" | """ | ||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
@@ -121,45 +134,50 @@ class Vocabulary(object): | |||||
self._word2idx = value | self._word2idx = value | ||||
@_check_build_status | @_check_build_status | ||||
def update(self, word_lst, no_create_entry=False): | |||||
r"""依次增加序列中词在词典中的出现频率 | |||||
def update(self, word_lst: list, no_create_entry:bool=False): | |||||
r""" | |||||
依次增加序列中词在词典中的出现频率 | |||||
:param list word_lst: a list of strings | |||||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
:param word_lst: 列表形式的词语,如word_list=['I', 'am', 'a', 'Chinese'],列表中的每个词会计算出现频率并加入到词典中。 | |||||
:param no_create_entry: 如果词语来自于非训练集建议设置为True。 | |||||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | ||||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | ||||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | ||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | ||||
则这个词将认为是需要创建单独的vector的。 | 则这个词将认为是需要创建单独的vector的。 | ||||
""" | """ | ||||
self._add_no_create_entry(word_lst, no_create_entry) | self._add_no_create_entry(word_lst, no_create_entry) | ||||
self.word_count.update(word_lst) | self.word_count.update(word_lst) | ||||
return self | return self | ||||
@_check_build_status | @_check_build_status | ||||
def add(self, word, no_create_entry=False): | |||||
def add(self, word:str, no_create_entry:bool=False): | |||||
r""" | r""" | ||||
增加一个新词在词典中的出现频率 | 增加一个新词在词典中的出现频率 | ||||
:param str word: 新词 | |||||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
:param word: 要添加进字典的新词, word为一个字符串 | |||||
:param no_create_entry: 如果词语来自于非训练集建议设置为True。 | |||||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | 如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | ||||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | 的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | ||||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | 加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | ||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | 个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | ||||
则这个词将认为是需要创建单独的vector的。 | 则这个词将认为是需要创建单独的vector的。 | ||||
""" | """ | ||||
self._add_no_create_entry(word, no_create_entry) | self._add_no_create_entry(word, no_create_entry) | ||||
self.word_count[word] += 1 | self.word_count[word] += 1 | ||||
return self | return self | ||||
def _add_no_create_entry(self, word, no_create_entry): | |||||
def _add_no_create_entry(self, word:Union[str, List[str]], no_create_entry:bool): | |||||
r""" | r""" | ||||
在新加入word时,检查_no_create_word的设置。 | 在新加入word时,检查_no_create_word的设置。 | ||||
:param str List[str] word: | |||||
:param bool no_create_entry: | |||||
:param word: 要添加的新词或者是List类型的新词,如word='I'或者word=['I', 'am', 'a', 'Chinese']均可 | |||||
:param no_create_entry: 如果词语来自于非训练集建议设置为True。如果为True,则不会有这个词语创建一个单独的entry, | |||||
它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独的entry | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(word, str) or not _is_iterable(word): | if isinstance(word, str) or not _is_iterable(word): | ||||
word = [word] | word = [word] | ||||
@@ -170,32 +188,32 @@ class Vocabulary(object): | |||||
self._no_create_word.pop(w) | self._no_create_word.pop(w) | ||||
@_check_build_status | @_check_build_status | ||||
def add_word(self, word, no_create_entry=False): | |||||
def add_word(self, word:str, no_create_entry:bool=False): | |||||
r""" | r""" | ||||
增加一个新词在词典中的出现频率 | 增加一个新词在词典中的出现频率 | ||||
:param str word: 新词 | |||||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | |||||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | |||||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||||
则这个词将认为是需要创建单独的vector的。 | |||||
:param word: 要添加进字典的新词, word为一个字符串 | |||||
:param no_create_entry: 如果词语来自于非训练集建议设置为True。如果为True,则不会有这个词语创建一个单独的entry, | |||||
它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独的entry。如果这个word来自于dev或者test,一般设置为True, | |||||
如果来自与train一般设置为False。以下两种情况: 如果新加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary | |||||
中且并不是no_create_entry的,则还是会为这词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary | |||||
中且并不是no_create_entry的,则这个词将认为是需要创建单独的vector的。 | |||||
""" | """ | ||||
self.add(word, no_create_entry=no_create_entry) | self.add(word, no_create_entry=no_create_entry) | ||||
@_check_build_status | @_check_build_status | ||||
def add_word_lst(self, word_lst, no_create_entry=False): | |||||
def add_word_lst(self, word_lst: List[str], no_create_entry:bool=False): | |||||
r""" | r""" | ||||
依次增加序列中词在词典中的出现频率 | 依次增加序列中词在词典中的出现频率 | ||||
:param list[str] word_lst: 词的序列 | |||||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | |||||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | |||||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||||
则这个词将认为是需要创建单独的vector的。 | |||||
:param word_lst: 需要添加的新词的list序列,如word_lst=['I', 'am', 'a', 'Chinese'] | |||||
:param no_create_entry: 如果词语来自于非训练集建议设置为True。如果为True,则不会有这个词语创建一个单独的entry, | |||||
它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独的entry。如果这个word来自于dev或者test,一般设置为True, | |||||
如果来自与train一般设置为False。以下两种情况: 如果新加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary | |||||
中且并不是no_create_entry的,则还是会为这词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary | |||||
中且并不是no_create_entry的,则这个词将认为是需要创建单独的vector的。 | |||||
""" | """ | ||||
self.update(word_lst, no_create_entry=no_create_entry) | self.update(word_lst, no_create_entry=no_create_entry) | ||||
return self | return self | ||||
@@ -238,7 +256,7 @@ class Vocabulary(object): | |||||
return len(self._word2idx) | return len(self._word2idx) | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def __contains__(self, item): | |||||
def __contains__(self, item:str): | |||||
r""" | r""" | ||||
检查词是否被记录 | 检查词是否被记录 | ||||
@@ -247,7 +265,7 @@ class Vocabulary(object): | |||||
""" | """ | ||||
return item in self._word2idx | return item in self._word2idx | ||||
def has_word(self, w): | |||||
def has_word(self, w:str): | |||||
r""" | r""" | ||||
检查词是否被记录:: | 检查词是否被记录:: | ||||
@@ -255,7 +273,7 @@ class Vocabulary(object): | |||||
# equals to | # equals to | ||||
has_abc = 'abc' in vocab | has_abc = 'abc' in vocab | ||||
:param item: the word | |||||
:param item: 输入的str类型的词 | |||||
:return: ``True`` or ``False`` | :return: ``True`` or ``False`` | ||||
""" | """ | ||||
return self.__contains__(w) | return self.__contains__(w) | ||||
@@ -263,7 +281,7 @@ class Vocabulary(object): | |||||
@_check_build_vocab | @_check_build_vocab | ||||
def __getitem__(self, w): | def __getitem__(self, w): | ||||
r""" | r""" | ||||
To support usage like:: | |||||
支持从字典中直接得到词语的index,例如:: | |||||
vocab[w] | vocab[w] | ||||
""" | """ | ||||
@@ -275,15 +293,15 @@ class Vocabulary(object): | |||||
raise ValueError("word `{}` not in vocabulary".format(w)) | raise ValueError("word `{}` not in vocabulary".format(w)) | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||||
def index_dataset(self, *datasets, field_name:Union[List, str], new_field_name:Union[List, str, None]=None): | |||||
r""" | r""" | ||||
将DataSet中对应field的词转为数字,Example:: | 将DataSet中对应field的词转为数字,Example:: | ||||
# remember to use `field_name` | # remember to use `field_name` | ||||
vocab.index_dataset(train_data, dev_data, test_data, field_name='words') | vocab.index_dataset(train_data, dev_data, test_data, field_name='words') | ||||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||||
:param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||||
:param datasets: 其类型为:~fastNLP.core.Dataset或者List[~fastNLP.core.Dataset] 需要转index的一个或多个数据集 | |||||
:param field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||||
目前支持 ``str`` , ``List[str]`` | 目前支持 ``str`` , ``List[str]`` | ||||
:param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | :param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | ||||
Default: ``None``. | Default: ``None``. | ||||
@@ -334,17 +352,16 @@ class Vocabulary(object): | |||||
def _no_create_word_length(self): | def _no_create_word_length(self): | ||||
return len(self._no_create_word) | return len(self._no_create_word) | ||||
def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None): | |||||
def from_dataset(self, *datasets, field_name:Union[str,List[str]], no_create_entry_dataset=None): | |||||
r""" | r""" | ||||
使用dataset的对应field中词构建词典:: | 使用dataset的对应field中词构建词典:: | ||||
# remember to use `field_name` | # remember to use `field_name` | ||||
vocab.from_dataset(train_data1, train_data2, field_name='words') | |||||
vocab.from_dataset(train_data1, train_data2, field_name='words', no_create_entry_dataset=[test_data1, test_data2]) | |||||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||||
:param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . | |||||
构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构 | |||||
: ``str`` , ``List[str]`` | |||||
:param 其类型为:~fastNLP.core.Dataset或者List[~fastNLP.core.Dataset] 需要转index的一个或多个数据集 | |||||
:param field_name: 构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. | |||||
目前支持的field结构: ``str`` , ``List[str]`` | |||||
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认), 建议直接将非训练数据都传入到这个参数。该选项用在接下来的模型会使用pretrain | :param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认), 建议直接将非训练数据都传入到这个参数。该选项用在接下来的模型会使用pretrain | ||||
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | 的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | ||||
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | 中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | ||||
@@ -352,7 +369,8 @@ class Vocabulary(object): | |||||
finetune embedding的话,这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector, | finetune embedding的话,这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector, | ||||
而应该让它指向unk这个vector的值。所以只位于no_create_entry_dataset中的token,将首先从预训练的词表中寻找它的表示, | 而应该让它指向unk这个vector的值。所以只位于no_create_entry_dataset中的token,将首先从预训练的词表中寻找它的表示, | ||||
如果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。 | 如果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。 | ||||
:return self: | |||||
:return Vocabulary自身 | |||||
""" | """ | ||||
if isinstance(field_name, str): | if isinstance(field_name, str): | ||||
field_name = [field_name] | field_name = [field_name] | ||||
@@ -396,15 +414,16 @@ class Vocabulary(object): | |||||
dataset.apply(partial_construct_vocab, show_progress_bar=False) | dataset.apply(partial_construct_vocab, show_progress_bar=False) | ||||
return self | return self | ||||
def _is_word_no_create_entry(self, word): | |||||
def _is_word_no_create_entry(self, word:str): | |||||
r""" | r""" | ||||
判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 | 判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 | ||||
:param word: str | |||||
:return: bool | |||||
:param word: 输入的str类型的词语 | |||||
:return: bool值的判断结果 | |||||
""" | """ | ||||
return word in self._no_create_word | return word in self._no_create_word | ||||
def to_index(self, w): | |||||
def to_index(self, w:str): | |||||
r""" | r""" | ||||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | ||||
@@ -412,8 +431,8 @@ class Vocabulary(object): | |||||
# equals to | # equals to | ||||
index = vocab['abc'] | index = vocab['abc'] | ||||
:param str w: a word | |||||
:return int index: the number | |||||
:param w: 需要输入的词语 | |||||
:return 词语w对应的int类型的index | |||||
""" | """ | ||||
return self.__getitem__(w) | return self.__getitem__(w) | ||||
@@ -421,7 +440,7 @@ class Vocabulary(object): | |||||
@_check_build_vocab | @_check_build_vocab | ||||
def unknown_idx(self): | def unknown_idx(self): | ||||
r""" | r""" | ||||
unknown 对应的数字. | |||||
获得unknown 对应的数字. | |||||
""" | """ | ||||
if self.unknown is None: | if self.unknown is None: | ||||
return None | return None | ||||
@@ -431,14 +450,14 @@ class Vocabulary(object): | |||||
@_check_build_vocab | @_check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
r""" | r""" | ||||
padding 对应的数字 | |||||
获得padding 对应的数字 | |||||
""" | """ | ||||
if self.padding is None: | if self.padding is None: | ||||
return None | return None | ||||
return self._word2idx[self.padding] | return self._word2idx[self.padding] | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def to_word(self, idx): | |||||
def to_word(self, idx: int): | |||||
r""" | r""" | ||||
给定一个数字, 将其转为对应的词. | 给定一个数字, 将其转为对应的词. | ||||
@@ -461,7 +480,8 @@ class Vocabulary(object): | |||||
return self | return self | ||||
def __getstate__(self): | def __getstate__(self): | ||||
r"""Use to prepare data for pickle. | |||||
r""" | |||||
用来从pickle中加载data | |||||
""" | """ | ||||
len(self) # make sure vocab has been built | len(self) # make sure vocab has been built | ||||
@@ -471,7 +491,8 @@ class Vocabulary(object): | |||||
return state | return state | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
r"""Use to restore state from pickle. | |||||
r""" | |||||
支持pickle的保存,保存到pickle的data state | |||||
""" | """ | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
@@ -486,11 +507,11 @@ class Vocabulary(object): | |||||
for index in range(len(self._word2idx)): | for index in range(len(self._word2idx)): | ||||
yield self.to_word(index), index | yield self.to_word(index), index | ||||
def save(self, filepath): | |||||
def save(self, filepath: [str, io.StringIO]): | |||||
r""" | r""" | ||||
:param str,io.StringIO filepath: Vocabulary的储存路径 | |||||
:param filepath: Vocabulary的储存路径 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(filepath, io.IOBase): | if isinstance(filepath, io.IOBase): | ||||
assert filepath.writable() | assert filepath.writable() | ||||
@@ -522,10 +543,11 @@ class Vocabulary(object): | |||||
f.close() | f.close() | ||||
@staticmethod | @staticmethod | ||||
def load(filepath): | |||||
def load(filepath: Union[str,io.StringIO]): | |||||
r""" | r""" | ||||
从文件路径中加载数据 | |||||
:param str,io.StringIO filepath: Vocabulary的读取路径 | |||||
:param filepath: Vocabulary的读取路径 | |||||
:return: Vocabulary | :return: Vocabulary | ||||
""" | """ | ||||
if isinstance(filepath, io.IOBase): | if isinstance(filepath, io.IOBase): | ||||
@@ -0,0 +1,5 @@ | |||||
__all__ = [ | |||||
"LSTM", | |||||
] | |||||
from .lstm import LSTM |
@@ -0,0 +1,82 @@ | |||||
r"""undocumented | |||||
轻量封装的 Pytorch LSTM 模块. | |||||
可在 forward 时传入序列的长度, 自动对padding做合适的处理. | |||||
""" | |||||
__all__ = [ | |||||
"LSTM" | |||||
] | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.utils.rnn as rnn | |||||
class LSTM(nn.Module): | |||||
r""" | |||||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||||
为1; 且可以应对DataParallel中LSTM的使用问题。 | |||||
""" | |||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||||
bidirectional=False, bias=True): | |||||
r""" | |||||
:param input_size: 输入 `x` 的特征维度 | |||||
:param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 | |||||
:param num_layers: rnn的层数. Default: 1 | |||||
:param dropout: 层间dropout概率. Default: 0 | |||||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||||
:param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 | |||||
:(batch, seq, feature). Default: ``False`` | |||||
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` | |||||
""" | |||||
super(LSTM, self).__init__() | |||||
self.batch_first = batch_first | |||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||||
dropout=dropout, bidirectional=bidirectional) | |||||
self.init_param() | |||||
def init_param(self): | |||||
for name, param in self.named_parameters(): | |||||
if 'bias' in name: | |||||
# based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 | |||||
param.data.fill_(0) | |||||
n = param.size(0) | |||||
start, end = n // 4, n // 2 | |||||
param.data[start:end].fill_(1) | |||||
else: | |||||
nn.init.xavier_uniform_(param) | |||||
def forward(self, x, seq_len=None, h0=None, c0=None): | |||||
r""" | |||||
:param x: [batch, seq_len, input_size] 输入序列 | |||||
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | |||||
:param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
:param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||||
:return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 | |||||
和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. | |||||
""" | |||||
batch_size, max_len, _ = x.size() | |||||
if h0 is not None and c0 is not None: | |||||
hx = (h0, c0) | |||||
else: | |||||
hx = None | |||||
if seq_len is not None and not isinstance(x, rnn.PackedSequence): | |||||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||||
if self.batch_first: | |||||
x = x[sort_idx] | |||||
else: | |||||
x = x[:, sort_idx] | |||||
x = rnn.pack_padded_sequence(x, sort_lens.cpu(), batch_first=self.batch_first) | |||||
output, hx = self.lstm(x, hx) # -> [N,L,C] | |||||
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) | |||||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||||
if self.batch_first: | |||||
output = output[unsort_idx] | |||||
else: | |||||
output = output[:, unsort_idx] | |||||
hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx] | |||||
else: | |||||
output, hx = self.lstm(x, hx) | |||||
return output, hx |
@@ -74,7 +74,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
@@ -121,7 +121,7 @@ def test_model_checkpoint_callback_1( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "model-epoch_10" in all_saved_model_paths | assert "model-epoch_10" in all_saved_model_paths | ||||
assert "model-epoch_4-batch_123" in all_saved_model_paths | assert "model-epoch_4-batch_123" in all_saved_model_paths | ||||
@@ -144,7 +144,7 @@ def test_model_checkpoint_callback_1( | |||||
pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "model-epoch_9" in all_saved_model_paths | assert "model-epoch_9" in all_saved_model_paths | ||||
assert "model-last" in all_saved_model_paths | assert "model-last" in all_saved_model_paths | ||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
@@ -206,7 +206,7 @@ def test_model_checkpoint_callback_1( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("only_state_dict", [True]) | @pytest.mark.parametrize("only_state_dict", [True]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
def test_model_checkpoint_callback_2( | def test_model_checkpoint_callback_2( | ||||
@@ -259,7 +259,7 @@ def test_model_checkpoint_callback_2( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths | assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths | ||||
exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] | exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] | ||||
# ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; | ||||
@@ -299,7 +299,7 @@ def test_model_checkpoint_callback_2( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context(timeout=100) | @magic_argv_env_context(timeout=100) | ||||
@@ -347,7 +347,7 @@ def test_trainer_checkpoint_callback_1( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-epoch_7" in all_saved_model_paths | assert "trainer-epoch_7" in all_saved_model_paths | ||||
assert "trainer-epoch_4-batch_123" in all_saved_model_paths | assert "trainer-epoch_4-batch_123" in all_saved_model_paths | ||||
@@ -371,7 +371,7 @@ def test_trainer_checkpoint_callback_1( | |||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | ||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-last" in all_saved_model_paths | assert "trainer-last" in all_saved_model_paths | ||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
@@ -417,7 +417,7 @@ def test_trainer_checkpoint_callback_1( | |||||
n_epochs=13, | n_epochs=13, | ||||
output_from_new_proc="all" | output_from_new_proc="all" | ||||
) | ) | ||||
trainer.load(folder, only_state_dict=only_state_dict) | |||||
trainer.load_checkpoint(folder, only_state_dict=only_state_dict) | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -489,7 +489,7 @@ def test_load_state(model_and_optimizers): | |||||
callbacks=callbacks, | callbacks=callbacks, | ||||
output_from_new_proc="all" | output_from_new_proc="all" | ||||
) | ) | ||||
trainer.load(folder=epoch_2_path) | |||||
trainer.load_checkpoint(folder=epoch_2_path) | |||||
with Capturing() as output: | with Capturing() as output: | ||||
trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) | trainer.run(num_eval_sanity_batch=0, num_train_batch_per_epoch=2) | ||||
@@ -503,7 +503,7 @@ def test_load_state(model_and_optimizers): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | # 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | ||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@pytest.mark.skip("Skip transformers test for now.") | @pytest.mark.skip("Skip transformers test for now.") | ||||
@@ -675,7 +675,7 @@ def test_trainer_checkpoint_callback_2( | |||||
# 检查生成保存模型文件的数量是不是正确的; | # 检查生成保存模型文件的数量是不是正确的; | ||||
if version == 0: | if version == 0: | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-epoch_1-batch_200" in all_saved_model_paths | assert "trainer-epoch_1-batch_200" in all_saved_model_paths | ||||
epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] | epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] | ||||
@@ -695,7 +695,7 @@ def test_trainer_checkpoint_callback_2( | |||||
pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") | ||||
# all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} | ||||
if driver == "torch": | |||||
if not isinstance(device, list): | |||||
assert "trainer-last" in all_saved_model_paths | assert "trainer-last" in all_saved_model_paths | ||||
aLL_topk_folders = [] | aLL_topk_folders = [] | ||||
for each_folder_name in all_saved_model_paths: | for each_folder_name in all_saved_model_paths: | ||||
@@ -740,7 +740,7 @@ def test_trainer_checkpoint_callback_2( | |||||
output_mapping=bert_output_mapping, | output_mapping=bert_output_mapping, | ||||
metrics={"acc": acc}, | metrics={"acc": acc}, | ||||
) | ) | ||||
trainer.load(folder, model_load_fn=model_load_fn) | |||||
trainer.load_checkpoint(folder, model_load_fn=model_load_fn) | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -72,7 +72,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [4, 5]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("save_folder", ['save_models', None]) | @pytest.mark.parametrize("save_folder", ['save_models', None]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -98,7 +98,7 @@ def model_and_optimizers(request): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -183,7 +183,7 @@ def test_model_more_evaluate_callback_1( | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) | |||||
@pytest.mark.parametrize("version", [0, 1]) | @pytest.mark.parametrize("version", [0, 1]) | ||||
@pytest.mark.parametrize("only_state_dict", [True, False]) | @pytest.mark.parametrize("only_state_dict", [True, False]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -256,7 +256,7 @@ def test_trainer_checkpoint_callback_1( | |||||
evaluate_fn='train_step' | evaluate_fn='train_step' | ||||
) | ) | ||||
folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) | folder = path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).joinpath(folder) | ||||
trainer.load(folder, only_state_dict=only_state_dict) | |||||
trainer.load_checkpoint(folder, only_state_dict=only_state_dict) | |||||
trainer.run() | trainer.run() | ||||
trainer.driver.barrier() | trainer.driver.barrier() | ||||
@@ -85,7 +85,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
): | ): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model, | model=model, | ||||
driver="torch_ddp", | |||||
driver="torch", | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
@@ -73,7 +73,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||||
): | ): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
model=model, | model=model, | ||||
driver="torch_ddp", | |||||
driver="torch", | |||||
device=None, | device=None, | ||||
optimizers=optimizers, | optimizers=optimizers, | ||||
train_dataloader=train_dataloader, | train_dataloader=train_dataloader, | ||||
@@ -318,7 +318,7 @@ def test_torch_distributed_launch_2(version): | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | |||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch", [0, 1])]) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_wo_auto_param_call( | def test_torch_wo_auto_param_call( | ||||
driver, | driver, | ||||
@@ -4,6 +4,7 @@ from datasets import Dataset as HfDataset | |||||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | ||||
from fastNLP.core.dataset import DataSet as Fdataset | from fastNLP.core.dataset import DataSet as Fdataset | ||||
from fastNLP.core.collators import Collator | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | ||||
if _NEED_IMPORT_JITTOR: | if _NEED_IMPORT_JITTOR: | ||||
from jittor.dataset import Dataset | from jittor.dataset import Dataset | ||||
@@ -53,9 +54,9 @@ class TestJittor: | |||||
jtl.set_ignore("y") | jtl.set_ignore("y") | ||||
for batch in jtl: | for batch in jtl: | ||||
assert batch['x'].size() == (16, 4) | assert batch['x'].size() == (16, 4) | ||||
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2) | |||||
jtl1 = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2) | |||||
for batch in jtl1: | |||||
print(batch) | |||||
def test_huggingface_datasets(self): | def test_huggingface_datasets(self): | ||||
@@ -79,4 +80,11 @@ class TestJittor: | |||||
for idx, batch in enumerate(dataset): | for idx, batch in enumerate(dataset): | ||||
print(idx, batch.shape) | print(idx, batch.shape) | ||||
for idx, batch in enumerate(dataset): | for idx, batch in enumerate(dataset): | ||||
print(idx, batch.shape) | |||||
print(idx, batch.shape) | |||||
def test_jittor_get_backend(self): | |||||
collate_bacth = Collator(backend='auto') | |||||
dl = MyDataset() | |||||
dl = dl.set_attrs(collate_batch=collate_bacth, batch_size=256) | |||||
for batch in dl: | |||||
print(batch) |
@@ -4,11 +4,12 @@ import numpy as np | |||||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.collators import Collator | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
from paddle.io import Dataset | |||||
from paddle.io import Dataset, DataLoader | |||||
import paddle | import paddle | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
@@ -61,3 +62,32 @@ class TestPaddle: | |||||
fdl1.set_ignore('label') | fdl1.set_ignore('label') | ||||
for batch in fdl1: | for batch in fdl1: | ||||
assert batch['image'].shape == [4, 10, 5] | assert batch['image'].shape == [4, 10, 5] | ||||
def test_get_backend(self): | |||||
ds = RandomDataset() | |||||
collate_fn = Collator(backend='auto') | |||||
paddle_dl = DataLoader(ds, collate_fn=collate_fn) | |||||
for batch in paddle_dl: | |||||
print(batch) | |||||
def test_v4(self): | |||||
from paddle.io import DataLoader | |||||
from fastNLP import Collator | |||||
from paddle.io import Dataset | |||||
import paddle | |||||
class PaddleRandomMaxDataset(Dataset): | |||||
def __init__(self, num_samples, num_features): | |||||
self.x = paddle.randn((num_samples, num_features)) | |||||
self.y = self.x.argmax(axis=-1) | |||||
def __len__(self): | |||||
return len(self.x) | |||||
def __getitem__(self, item): | |||||
return {"x": self.x[item], "y": self.y[item]} | |||||
ds = PaddleRandomMaxDataset(100, 2) | |||||
dl = DataLoader(ds, places=None, collate_fn=Collator(), batch_size=4) | |||||
for batch in dl: | |||||
print(batch) |
@@ -112,3 +112,19 @@ class TestFdl: | |||||
seq_ds = prepare_torch_dataloader(sequence) | seq_ds = prepare_torch_dataloader(sequence) | ||||
assert isinstance(seq_ds[0], TorchDataLoader) | assert isinstance(seq_ds[0], TorchDataLoader) | ||||
assert isinstance(seq_ds[1], TorchDataLoader) | assert isinstance(seq_ds[1], TorchDataLoader) | ||||
def test_get_backend(self): | |||||
from fastNLP.core.collators import Collator | |||||
from torch.utils.data import DataLoader, Dataset | |||||
class MyDatset(DataSet): | |||||
def __len__(self): | |||||
return 1000 | |||||
def __getitem__(self, item): | |||||
return [[1, 0], [1], [1, 2, 4]], [1, 0] | |||||
collate_batch = Collator(backend='auto') | |||||
dl = DataLoader(MyDatset(), collate_fn=collate_batch) | |||||
for batch in dl: | |||||
print(batch) |
@@ -626,9 +626,9 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
@@ -644,7 +644,7 @@ class TestSaveLoad: | |||||
rank=self.driver2.global_rank, | rank=self.driver2.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -736,9 +736,9 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
self.driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | batch_sampler = BatchSampler(dataset=self.dataset, batch_size=2) | ||||
@@ -752,7 +752,7 @@ class TestSaveLoad: | |||||
self.dataset, | self.dataset, | ||||
batch_sampler=batch_sampler | batch_sampler=batch_sampler | ||||
) | ) | ||||
load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = self.driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -615,16 +615,16 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = DataLoader( | dataloader = DataLoader( | ||||
dataset=dataset, | dataset=dataset, | ||||
batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | batch_sampler=ReproduceBatchSampler(BatchSampler(dataset, batch_size=2, shuffle=True), 2, False) | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -697,9 +697,9 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[paddle.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
@@ -709,7 +709,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
dataset, | dataset, | ||||
batch_sampler=batch_sampler | batch_sampler=batch_sampler | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -648,7 +648,7 @@ class TestSaveLoad: | |||||
# 保存状态 | # 保存状态 | ||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_bucketedbatchsampler( | dataloader = dataloader_with_bucketedbatchsampler( | ||||
@@ -663,7 +663,7 @@ class TestSaveLoad: | |||||
rank=driver2.global_rank, | rank=driver2.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -754,9 +754,9 @@ class TestSaveLoad: | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
if only_state_dict: | if only_state_dict: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
else: | else: | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) | ||||
@@ -765,7 +765,7 @@ class TestSaveLoad: | |||||
rank=driver2.global_rank, | rank=driver2.global_rank, | ||||
pad=True | pad=True | ||||
) | ) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -37,28 +37,6 @@ def test_get_single_device(driver, device): | |||||
driver = initialize_torch_driver(driver, device, model) | driver = initialize_torch_driver(driver, device, model) | ||||
assert isinstance(driver, TorchSingleDriver) | assert isinstance(driver, TorchSingleDriver) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize( | |||||
"device", | |||||
[0, 1] | |||||
) | |||||
@pytest.mark.parametrize( | |||||
"driver", | |||||
["torch_ddp"] | |||||
) | |||||
@magic_argv_env_context | |||||
def test_get_ddp_2(driver, device): | |||||
""" | |||||
测试 ddp 多卡的初始化情况,但传入了单个 gpu | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
assert isinstance(driver, TorchDDPDriver) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
@@ -66,7 +44,7 @@ def test_get_ddp_2(driver, device): | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
["torch", "torch_ddp"] | |||||
["torch"] | |||||
) | ) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_get_ddp(driver, device): | def test_get_ddp(driver, device): | ||||
@@ -79,21 +57,6 @@ def test_get_ddp(driver, device): | |||||
assert isinstance(driver, TorchDDPDriver) | assert isinstance(driver, TorchDDPDriver) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize( | |||||
("driver", "device"), | |||||
[("torch_ddp", "cpu")] | |||||
) | |||||
def test_get_ddp_cpu(driver, device): | |||||
""" | |||||
测试试图在 cpu 上初始化分布式训练的情况 | |||||
""" | |||||
model = TorchNormalModel_Classification_1(64, 10) | |||||
with pytest.raises(ValueError): | |||||
driver = initialize_torch_driver(driver, device, model) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
@@ -101,7 +64,7 @@ def test_get_ddp_cpu(driver, device): | |||||
) | ) | ||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"driver", | "driver", | ||||
["torch", "torch_ddp"] | |||||
["torch"] | |||||
) | ) | ||||
def test_device_out_of_range(driver, device): | def test_device_out_of_range(driver, device): | ||||
""" | """ | ||||
@@ -595,12 +595,12 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.state_dict() | sampler_states = dataloader.batch_sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) | dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
# TODO optimizer 的 state_dict 总是为空 | # TODO optimizer 的 state_dict 总是为空 | ||||
@@ -664,12 +664,12 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): | |||||
sampler_states = dataloader.batch_sampler.sampler.state_dict() | sampler_states = dataloader.batch_sampler.sampler.state_dict() | ||||
save_states = {"num_consumed_batches": num_consumed_batches} | save_states = {"num_consumed_batches": num_consumed_batches} | ||||
driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
driver1.save_checkpoint(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) | |||||
# 加载 | # 加载 | ||||
# 更改 batch_size | # 更改 batch_size | ||||
dataloader = dataloader_with_randomsampler(dataset, 2, True, False) | dataloader = dataloader_with_randomsampler(dataset, 2, True, False) | ||||
load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
load_states = driver2.load_checkpoint(Path(path), dataloader, only_state_dict, should_load_model=True) | |||||
replaced_loader = load_states.pop("dataloader") | replaced_loader = load_states.pop("dataloader") | ||||
# 1. 检查 optimizer 的状态 | # 1. 检查 optimizer 的状态 | ||||
@@ -7,8 +7,9 @@ from fastNLP import Vocabulary, DataSet, Instance | |||||
from fastNLP.embeddings.torch.char_embedding import LSTMCharEmbedding, CNNCharEmbedding | from fastNLP.embeddings.torch.char_embedding import LSTMCharEmbedding, CNNCharEmbedding | ||||
@pytest.mark.torch | |||||
class TestCharEmbed: | class TestCharEmbed: | ||||
@pytest.mark.test | |||||
# @pytest.mark.test | |||||
def test_case_1(self): | def test_case_1(self): | ||||
ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) | ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) | ||||
vocab = Vocabulary().from_dataset(ds, field_name='words') | vocab = Vocabulary().from_dataset(ds, field_name='words') | ||||
@@ -18,7 +19,7 @@ class TestCharEmbed: | |||||
y = embed(x) | y = embed(x) | ||||
assert tuple(y.size()) == (2, 3, 3) | assert tuple(y.size()) == (2, 3, 3) | ||||
@pytest.mark.test | |||||
# @pytest.mark.test | |||||
def test_case_2(self): | def test_case_2(self): | ||||
ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) | ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])]) | ||||
vocab = Vocabulary().from_dataset(ds, field_name='words') | vocab = Vocabulary().from_dataset(ds, field_name='words') | ||||