@@ -1,10 +1,3 @@ | |||||
# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import | |||||
set_env_on_import() | |||||
# 再设置 backend 相关 | |||||
from fastNLP.envs.set_backend import _set_backend | |||||
_set_backend() | |||||
from fastNLP.envs import * | |||||
from fastNLP.core import Trainer, Evaluator | from fastNLP.core import Trainer, Evaluator |
@@ -1,8 +1,11 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'Collator', | |||||
'Collator' | |||||
] | ] | ||||
from typing import List, Union, Dict, Callable, Sequence, Mapping | from typing import List, Union, Dict, Callable, Sequence, Mapping | ||||
import os | |||||
import sys | |||||
import inspect | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .padders.get_padder import get_padder | from .padders.get_padder import get_padder | ||||
@@ -13,18 +16,76 @@ from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch | |||||
pack_batch_sequence | pack_batch_sequence | ||||
sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | ||||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] | |||||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | |||||
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend | |||||
def _get_backend(): | |||||
""" | |||||
当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: | |||||
(1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 | |||||
某个 backend 的 dataloader 。 | |||||
(2)如果方式(1)没找,则通过分析 sys.modules 中的内容进行寻找。 | |||||
如果都没有找到则返回 numpy 。 | |||||
:return: | |||||
""" | |||||
def _check_module(module): | |||||
""" | |||||
检查该 module 是否含有 某个 backend 的特征 | |||||
:param module: module 对象 | |||||
:return: | |||||
""" | |||||
catch_backend = [] | |||||
try: | |||||
file = module.__file__ | |||||
for backend in CHECK_BACKEND: | |||||
if f'{os.sep}site-packages{os.sep}{backend}' in file: | |||||
catch_backend = [backend, file] | |||||
except: | |||||
pass | |||||
return catch_backend | |||||
currentframe = inspect.currentframe() | |||||
# 方式(1) | |||||
catch_backend = [] | |||||
for i in range(100): | |||||
currentframe = currentframe.f_back | |||||
if currentframe is not None: | |||||
module = inspect.getmodule(currentframe) | |||||
if module is not None: | |||||
catch_backend = _check_module(module) | |||||
if len(catch_backend): # 主要捕获到一个就结束吧 | |||||
break | |||||
else: | |||||
break | |||||
if len(catch_backend): | |||||
logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") | |||||
return catch_backend[0] | |||||
# 方式 (2) | |||||
for key, module in sys.modules.items(): | |||||
catch_backend = _check_module(module) | |||||
if catch_backend: | |||||
break | |||||
if len(catch_backend): | |||||
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") | |||||
return catch_backend[0] | |||||
return 'numpy' | |||||
class Collator: | class Collator: | ||||
def __init__(self, backend='torch'): | |||||
def __init__(self, backend='auto'): | |||||
""" | """ | ||||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | ||||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | ||||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | ||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。 | |||||
若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | |||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad | |||||
的数据返回一定是 list 。 | |||||
""" | """ | ||||
self.unpack_batch_func = None | self.unpack_batch_func = None | ||||
self.pack_batch_func = None | self.pack_batch_func = None | ||||
@@ -73,7 +134,7 @@ class Collator: | |||||
else: | else: | ||||
self.unpack_batch_func = unpack_batch_mapping | self.unpack_batch_func = unpack_batch_mapping | ||||
self.pack_batch_func = lambda x:x | self.pack_batch_func = lambda x:x | ||||
# 在这里用ignore_field过滤掉 | |||||
if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 | if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 | ||||
unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) | unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) | ||||
else: | else: | ||||
@@ -81,6 +142,9 @@ class Collator: | |||||
pad_batch = {} | pad_batch = {} | ||||
if len(self.padders)==0: # 第一次运行,准备 padder | if len(self.padders)==0: # 第一次运行,准备 padder | ||||
if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。 | |||||
self.backend = _get_backend() | |||||
for key in unpack_batch.keys(): | for key in unpack_batch.keys(): | ||||
if key not in self.input_fields and key not in self.ignore_fields: | if key not in self.input_fields and key not in self.ignore_fields: | ||||
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | ||||
@@ -104,7 +168,7 @@ class Collator: | |||||
return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | ||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto', | |||||
pad_fn:Callable=None) -> "Collator": | pad_fn:Callable=None) -> "Collator": | ||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
@@ -114,10 +178,11 @@ class Collator: | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | ||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | ||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | ||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | ||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, | |||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 | |||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | ||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | ||||
形式,输出将被直接作为结果输出。 | 形式,输出将被直接作为结果输出。 | ||||
@@ -158,8 +223,8 @@ class Collator: | |||||
""" | """ | ||||
设置可以 pad 的 field 默认 pad 为什么类型的 tensor | 设置可以 pad 的 field 默认 pad 为什么类型的 tensor | ||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||||
若为 None ,则不进行 padding 。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], | |||||
若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
assert backend in SUPPORTED_BACKENDS | assert backend in SUPPORTED_BACKENDS | ||||
@@ -181,7 +246,7 @@ class Collator: | |||||
if field_name in self.input_fields: | if field_name in self.input_fields: | ||||
self.input_fields.pop(field_name) | self.input_fields.pop(field_name) | ||||
logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") | logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") | ||||
self.padders.pop(field_name, None) # 如果有的话,将它的 padder 扔掉。 | |||||
self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 | |||||
self.ignore_fields.add(field_name) | self.ignore_fields.add(field_name) | ||||
return self | return self | ||||
@@ -190,6 +255,9 @@ class Collator: | |||||
# | # | ||||
# from abc import ABCMeta, abstractmethod | # from abc import ABCMeta, abstractmethod | ||||
# from typing import Any, Dict, List, Callable, Union, Tuple | # from typing import Any, Dict, List, Callable, Union, Tuple | ||||
@@ -1,4 +1,7 @@ | |||||
from typing import List, Union, Dict, Callable, Sequence, Mapping | from typing import List, Union, Dict, Callable, Sequence, Mapping | ||||
import os | |||||
import sys | |||||
import inspect | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from .padders.get_padder import get_padder | from .padders.get_padder import get_padder | ||||
@@ -9,18 +12,76 @@ from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch | |||||
pack_batch_sequence | pack_batch_sequence | ||||
sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | ||||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] | |||||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] | |||||
CHECK_BACKEND = ['torch', 'jittor', 'paddle'] # backend 为 auto 时 检查是否是这些 backend | |||||
def _get_backend(): | |||||
""" | |||||
当 Collator 的 backend 为 None 的时候如何,通过这个函数自动判定其 backend 。判断方法主要为以下两个: | |||||
(1)尝试通过向上寻找当前 collator 的 callee 对象,根据 callee 对象寻找。然后使用 '/site-packages/{backend}' 来寻找是否是 | |||||
某个 backend 的 dataloader 。 | |||||
(2)如果方式(1)没找,则通过分析 sys.modules 中的内容进行寻找。 | |||||
如果都没有找到则返回 numpy 。 | |||||
:return: | |||||
""" | |||||
def _check_module(module): | |||||
""" | |||||
检查该 module 是否含有 某个 backend 的特征 | |||||
:param module: module 对象 | |||||
:return: | |||||
""" | |||||
catch_backend = [] | |||||
try: | |||||
file = module.__file__ | |||||
for backend in CHECK_BACKEND: | |||||
if f'{os.sep}site-packages{os.sep}{backend}' in file: | |||||
catch_backend = [backend, file] | |||||
except: | |||||
pass | |||||
return catch_backend | |||||
currentframe = inspect.currentframe() | |||||
# 方式(1) | |||||
catch_backend = [] | |||||
for i in range(100): | |||||
currentframe = currentframe.f_back | |||||
if currentframe is not None: | |||||
module = inspect.getmodule(currentframe) | |||||
if module is not None: | |||||
catch_backend = _check_module(module) | |||||
if len(catch_backend): # 主要捕获到一个就结束吧 | |||||
break | |||||
else: | |||||
break | |||||
if len(catch_backend): | |||||
logger.debug(f"Find a file named:{catch_backend[1]} from stack contain backend:{catch_backend[0]}.") | |||||
return catch_backend[0] | |||||
# 方式 (2) | |||||
for key, module in sys.modules.items(): | |||||
catch_backend = _check_module(module) | |||||
if catch_backend: | |||||
break | |||||
if len(catch_backend): | |||||
logger.debug(f"Find a file named:{catch_backend[1]} from sys.modules contain backend:{catch_backend[0]}.") | |||||
return catch_backend[0] | |||||
return 'numpy' | |||||
class Collator: | class Collator: | ||||
def __init__(self, backend='torch'): | |||||
def __init__(self, backend='auto'): | |||||
""" | """ | ||||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | ||||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | ||||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | ||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。 | |||||
若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', auto, None]。 | |||||
若为 'auto' ,则在进行 pad 的时候会根据调用的环境决定其 backend 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad | |||||
的数据返回一定是 list 。 | |||||
""" | """ | ||||
self.unpack_batch_func = None | self.unpack_batch_func = None | ||||
self.pack_batch_func = None | self.pack_batch_func = None | ||||
@@ -77,6 +138,9 @@ class Collator: | |||||
pad_batch = {} | pad_batch = {} | ||||
if len(self.padders)==0: # 第一次运行,准备 padder | if len(self.padders)==0: # 第一次运行,准备 padder | ||||
if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。 | |||||
self.backend = _get_backend() | |||||
for key in unpack_batch.keys(): | for key in unpack_batch.keys(): | ||||
if key not in self.input_fields and key not in self.ignore_fields: | if key not in self.input_fields and key not in self.ignore_fields: | ||||
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | ||||
@@ -100,7 +164,7 @@ class Collator: | |||||
return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | ||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto', | |||||
pad_fn:Callable=None) -> "Collator": | pad_fn:Callable=None) -> "Collator": | ||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
@@ -110,10 +174,11 @@ class Collator: | |||||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | ||||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | ||||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | ||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 | |||||
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 | |||||
无意义。 | |||||
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | ||||
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, | |||||
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 | |||||
:param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, | |||||
torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 | |||||
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | ||||
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | ||||
形式,输出将被直接作为结果输出。 | 形式,输出将被直接作为结果输出。 | ||||
@@ -154,8 +219,8 @@ class Collator: | |||||
""" | """ | ||||
设置可以 pad 的 field 默认 pad 为什么类型的 tensor | 设置可以 pad 的 field 默认 pad 为什么类型的 tensor | ||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||||
若为 None ,则不进行 padding 。 | |||||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw', 'auto', None], | |||||
若为 auto ,则在进行 pad 的时候会根据调用的环境决定其 backend 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
assert backend in SUPPORTED_BACKENDS | assert backend in SUPPORTED_BACKENDS | ||||
@@ -27,6 +27,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
:param field_name: 方便报错的。 | :param field_name: 方便报错的。 | ||||
:return: | :return: | ||||
""" | """ | ||||
logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field)) | logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field)) | ||||
if pad_val is None: | if pad_val is None: | ||||
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") | logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") | ||||
@@ -84,25 +85,25 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
try: | try: | ||||
if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] | if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] | ||||
if backend == 'raw': | if backend == 'raw': | ||||
return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return RawNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'numpy': | elif backend == 'numpy': | ||||
return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'torch': | elif backend == 'torch': | ||||
return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | ||||
if backend == 'raw': | if backend == 'raw': | ||||
return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return RawSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'numpy': | elif backend == 'numpy': | ||||
return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'torch': | elif backend == 'torch': | ||||
return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
if depth == 1 and shape_len != 0: | if depth == 1 and shape_len != 0: | ||||
if backend == 'numpy': | if backend == 'numpy': | ||||
return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'torch': | elif backend == 'torch': | ||||
return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
if shape_len != 0 and depth>1: | if shape_len != 0 and depth>1: | ||||
msg = "Does not support pad tensor under nested list. If you need this, please report." | msg = "Does not support pad tensor under nested list. If you need this, please report." | ||||
@@ -112,7 +113,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return NullPadder() | return NullPadder() | ||||
except DtypeError as e: | except DtypeError as e: | ||||
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ | |||||
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ | |||||
"information please set logger's level to DEBUG." | "information please set logger's level to DEBUG." | ||||
if must_pad: | if must_pad: | ||||
raise type(e)(msg=msg) | raise type(e)(msg=msg) | ||||
@@ -1,6 +1,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'NumpyNumberPadder', | 'NumpyNumberPadder', | ||||
'NumpySequencePadder', | 'NumpySequencePadder', | ||||
"NumpyTensorPadder" | |||||
] | ] | ||||
from numbers import Number | from numbers import Number | ||||
@@ -14,7 +15,7 @@ from .exceptions import * | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
if not is_number_or_numpy_number(ele_dtype): | |||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers but get `{ele_dtype}`.") | f"or numpy numbers but get `{ele_dtype}`.") | ||||
@@ -29,7 +30,14 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class NumpyNumberPadder(Padder): | class NumpyNumberPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -39,7 +47,14 @@ class NumpyNumberPadder(Padder): | |||||
class NumpySequencePadder(Padder): | class NumpySequencePadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -49,13 +64,13 @@ class NumpySequencePadder(Padder): | |||||
class NumpyTensorPadder(Padder): | class NumpyTensorPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | """ | ||||
pad 类似于 [np.array([3, 4], np.array([1])] 的 field | pad 类似于 [np.array([3, 4], np.array([1])] 的 field | ||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | """ | ||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -14,6 +14,13 @@ class Padder: | |||||
class NullPadder(Padder): | class NullPadder(Padder): | ||||
def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | ||||
""" | |||||
不进行任何 检查 与 pad 的空 padder 。 | |||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
""" | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
def __call__(self, batch_field): | def __call__(self, batch_field): | ||||
@@ -1,25 +1,35 @@ | |||||
from .padder import Padder | from .padder import Padder | ||||
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array | |||||
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number | |||||
from .exceptions import * | from .exceptions import * | ||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
if is_number(ele_dtype): | |||||
if dtype is None: | |||||
dtype = ele_dtype | |||||
elif not is_number(dtype): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " | |||||
f"get `{dtype}`.") | |||||
else: | |||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"but get `{ele_dtype}`.") | |||||
f"or numpy numbers but get `{ele_dtype}`.") | |||||
if dtype is None: | |||||
dtype = ele_dtype | |||||
else: | |||||
if not is_number_or_numpy_number(dtype): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||||
f"or numpy numbers but get `{dtype}`.") | |||||
dtype = dtype | |||||
return dtype | return dtype | ||||
class RawNumberPadder(Padder): | class RawNumberPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -32,7 +42,14 @@ class RawNumberPadder(Padder): | |||||
class RawSequencePadder(Padder): | class RawSequencePadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -37,7 +37,7 @@ def is_torch_tensor(dtype): | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | ||||
@@ -47,20 +47,27 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
f"or torch.dtype but get `{dtype}`.") | f"or torch.dtype but get `{dtype}`.") | ||||
dtype = number_to_torch_dtype_dict.get(dtype, dtype) | dtype = number_to_torch_dtype_dict.get(dtype, dtype) | ||||
else: | else: | ||||
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||||
dtype = ele_dtype | |||||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||||
elif is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||||
if ele_dtype is not None: | |||||
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||||
dtype = ele_dtype | |||||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||||
elif is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||||
return dtype | return dtype | ||||
class TorchNumberPadder(Padder): | class TorchNumberPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
# 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -70,7 +77,14 @@ class TorchNumberPadder(Padder): | |||||
class TorchSequencePadder(Padder): | class TorchSequencePadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -81,13 +95,13 @@ class TorchSequencePadder(Padder): | |||||
class TorchTensorPadder(Padder): | class TorchTensorPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | """ | ||||
目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 | 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 | ||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | """ | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -96,8 +110,6 @@ class TorchTensorPadder(Padder): | |||||
def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
shapes = [field.shape for field in batch_field] | shapes = [field.shape for field in batch_field] | ||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | ||||
if isinstance(dtype, np.dtype): | |||||
print(dtype) | |||||
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | ||||
for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
@@ -10,8 +10,7 @@ def is_torch_tensor_dtype(dtype) -> bool: | |||||
""" | """ | ||||
返回当前 dtype 是否是 torch 的 dtype 类型 | 返回当前 dtype 是否是 torch 的 dtype 类型 | ||||
:param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 | |||||
:param dtype: 类似与 torch.ones(3).dtype | |||||
:return: | :return: | ||||
""" | """ | ||||
try: | try: | ||||
@@ -86,12 +86,12 @@ class TorchDataLoader(DataLoader): | |||||
if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | ||||
self._collate_fn = dataset.dataset.collator | self._collate_fn = dataset.dataset.collator | ||||
self._collate_fn.set_backend(backend="torch") | |||||
self._collate_fn.set_backend() | |||||
# if collate_fn is not None and collate_fn is not default_collate: | # if collate_fn is not None and collate_fn is not default_collate: | ||||
# # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 | # # 防止ddp重新初始化时候将torch dataloader的默认collate加进来 | ||||
# self._collate_fn.add_collator(collate_fn) | # self._collate_fn.add_collator(collate_fn) | ||||
else: | else: | ||||
self._collate_fn = Collator(backend='torch') | |||||
self._collate_fn = Collator() | |||||
else: | else: | ||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
elif isinstance(collate_fn, Callable): | elif isinstance(collate_fn, Callable): | ||||
@@ -162,6 +162,7 @@ class TorchDataLoader(DataLoader): | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | ||||
batch_size: int = 1, | batch_size: int = 1, | ||||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | ||||
@@ -759,8 +759,7 @@ class DataSet: | |||||
dict_ = {key: value.content for key, value in self.field_arrays.items()} | dict_ = {key: value.content for key, value in self.field_arrays.items()} | ||||
return pd.DataFrame.from_dict(dict_) | return pd.DataFrame.from_dict(dict_) | ||||
# TODO 应该有返回值的吧 | |||||
def to_csv(self, path: str) -> None: | |||||
def to_csv(self, path: str): | |||||
""" | """ | ||||
将dataset保存为csv文件 | 将dataset保存为csv文件 | ||||
@@ -769,7 +768,7 @@ class DataSet: | |||||
""" | """ | ||||
df = self.to_pandas() | df = self.to_pandas() | ||||
df.to_csv(path, encoding="utf-8") | |||||
return df.to_csv(path, encoding="utf-8") | |||||
def set_ignore(self, *field_names) -> None: | def set_ignore(self, *field_names) -> None: | ||||
""" | """ | ||||
@@ -14,7 +14,11 @@ __all__ = [ | |||||
from .env import * | from .env import * | ||||
from .set_env_on_import import set_env_on_import | from .set_env_on_import import set_env_on_import | ||||
from .set_backend import dump_fastnlp_backend | |||||
# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 | |||||
set_env_on_import() | |||||
from .set_backend import dump_fastnlp_backend, _set_backend | |||||
# 再设置 backend 相关 | |||||
_set_backend() | |||||
from .imports import * | from .imports import * | ||||
from .utils import _module_available, get_gpu_count | from .utils import _module_available, get_gpu_count | ||||
from .distributed import * | from .distributed import * |
@@ -5,9 +5,9 @@ import operator | |||||
from fastNLP.envs.env import FASTNLP_BACKEND | from fastNLP.envs.env import FASTNLP_BACKEND | ||||
from fastNLP.envs.utils import _module_available, _compare_version | from fastNLP.envs.utils import _module_available, _compare_version | ||||
from fastNLP.envs.set_backend import SUPPORT_BACKENDS | |||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||||
backend = os.environ.get(FASTNLP_BACKEND, 'all') | backend = os.environ.get(FASTNLP_BACKEND, 'all') | ||||
if backend == 'all': | if backend == 'all': | ||||
need_import = SUPPORT_BACKENDS | need_import = SUPPORT_BACKENDS | ||||
@@ -1,7 +1,3 @@ | |||||
""" | |||||
这个文件用于自动以及手动设置某些环境变量的,该文件中的set_env()函数会在 fastNLP 被 import 的时候在set_env_on_import之后运行。可以 | |||||
用于设置某些必要的环境变量。同时用户在使用时set_env()修改环境变量时,也应该保证set_env()函数在所有其它代码之前被运行。 | |||||
""" | |||||
import os | import os | ||||
import json | import json | ||||
import sys | import sys | ||||
@@ -9,9 +5,12 @@ from collections import defaultdict | |||||
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | ||||
from fastNLP.envs.imports import SUPPORT_BACKENDS | |||||
from fastNLP.envs.utils import _module_available, get_gpu_count | from fastNLP.envs.utils import _module_available, get_gpu_count | ||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||||
def _set_backend(): | def _set_backend(): | ||||
""" | """ | ||||
根据环境变量或者默认配置文件设置 backend 。 | 根据环境变量或者默认配置文件设置 backend 。 | ||||
@@ -179,11 +178,11 @@ def dump_fastnlp_backend(default:bool = False, backend=None): | |||||
os.makedirs(os.path.dirname(env_path), exist_ok=True) | os.makedirs(os.path.dirname(env_path), exist_ok=True) | ||||
envs = {} | envs = {} | ||||
assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now." | |||||
if backend is None: | if backend is None: | ||||
if FASTNLP_BACKEND in os.environ: | if FASTNLP_BACKEND in os.environ: | ||||
envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] | ||||
else: | else: | ||||
assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now." | |||||
envs[FASTNLP_BACKEND] = backend | envs[FASTNLP_BACKEND] = backend | ||||
if len(envs): | if len(envs): | ||||
with open(env_path, 'w', encoding='utf8') as f: | with open(env_path, 'w', encoding='utf8') as f: | ||||
@@ -1,97 +0,0 @@ | |||||
r"""undocumented""" | |||||
__all__ = [ | |||||
"CWSLoader" | |||||
] | |||||
import glob | |||||
import os | |||||
import random | |||||
import shutil | |||||
import time | |||||
from .loader import Loader | |||||
from fastNLP.core.dataset import DataSet, Instance | |||||
class CWSLoader(Loader): | |||||
r""" | |||||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||||
Example:: | |||||
上海 浦东 开发 与 法制 建设 同步 | |||||
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) | |||||
... | |||||
该Loader读取后的DataSet具有如下的结构 | |||||
.. csv-table:: | |||||
:header: "raw_words" | |||||
"上海 浦东 开发 与 法制 建设 同步" | |||||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||||
"..." | |||||
""" | |||||
def __init__(self, dataset_name: str = None): | |||||
r""" | |||||
:param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None | |||||
""" | |||||
super().__init__() | |||||
datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} | |||||
if dataset_name in datanames: | |||||
self.dataset_name = datanames[dataset_name] | |||||
else: | |||||
self.dataset_name = None | |||||
def _load(self, path: str): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
ds.append(Instance(raw_words=line)) | |||||
return ds | |||||
def download(self, dev_ratio=0.1, re_download=False) -> str: | |||||
r""" | |||||
如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, | |||||
2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 | |||||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||||
:return: str | |||||
""" | |||||
if self.dataset_name is None: | |||||
return '' | |||||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||||
modify_time = 0 | |||||
for filepath in glob.glob(os.path.join(data_dir, '*')): | |||||
modify_time = os.stat(filepath).st_mtime | |||||
break | |||||
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 | |||||
shutil.rmtree(data_dir) | |||||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||||
if not os.path.exists(os.path.join(data_dir, 'dev.txt')): | |||||
if dev_ratio > 0: | |||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||||
try: | |||||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ | |||||
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: | |||||
for line in f: | |||||
if random.random() < dev_ratio: | |||||
f2.write(line) | |||||
else: | |||||
f1.write(line) | |||||
os.remove(os.path.join(data_dir, 'train.txt')) | |||||
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) | |||||
finally: | |||||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||||
return data_dir |
@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
class TestNumpyNumberPadder: | class TestNumpyNumberPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = NumpyNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
assert isinstance(padder(a), np.ndarray) | assert isinstance(padder(a), np.ndarray) | ||||
assert (padder(a) == np.array(a)).sum() == 3 | assert (padder(a) == np.array(a)).sum() == 3 | ||||
@@ -17,7 +17,7 @@ class TestNumpyNumberPadder: | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestNumpySequencePadder: | class TestNumpySequencePadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = np.shape(a) | shape = np.shape(a) | ||||
@@ -27,18 +27,18 @@ class TestNumpySequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
class TestNumpyTensorPadder: | class TestNumpyTensorPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3).dtype, dtype=int) | |||||
a = [np.zeros(3), np.zeros(2), np.zeros(0)] | a = [np.zeros(3), np.zeros(2), np.zeros(0)] | ||||
a = padder(a) | a = padder(a) | ||||
shape = np.shape(a) | shape = np.shape(a) | ||||
@@ -68,15 +68,15 @@ class TestNumpyTensorPadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long) | |||||
@@ -59,9 +59,9 @@ class TestpaddleTensorPadder: | |||||
shape = a.shape | shape = a.shape | ||||
assert isinstance(a, paddle.Tensor) | assert isinstance(a, paddle.Tensor) | ||||
assert tuple(shape) == (3, 3, 2) | assert tuple(shape) == (3, 3, 2) | ||||
b = paddle.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | [[0, 0], [0, 0], [-1, -1]], | ||||
[[0, 0], [-1, -1], [-1, -1]]]) | |||||
[[0, 0], [-1, -1], [-1, -1]]], dtype='in') | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 1))] | a = [paddle.zeros((3, 2)), paddle.zeros((2, 2)), paddle.zeros((1, 1))] | ||||
@@ -69,7 +69,7 @@ class TestpaddleTensorPadder: | |||||
shape = a.shape | shape = a.shape | ||||
assert isinstance(a, paddle.Tensor) | assert isinstance(a, paddle.Tensor) | ||||
assert tuple(shape) == (3, 3, 2) | assert tuple(shape) == (3, 3, 2) | ||||
b = paddle.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | [[0, 0], [0, 0], [-1, -1]], | ||||
[[0, -1], [-1, -1], [-1, -1]]]) | [[0, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
@@ -80,7 +80,7 @@ class TestpaddleTensorPadder: | |||||
shape = a.shape | shape = a.shape | ||||
assert isinstance(a, paddle.Tensor) | assert isinstance(a, paddle.Tensor) | ||||
assert tuple(shape) == (3, 3, 2) | assert tuple(shape) == (3, 3, 2) | ||||
b = paddle.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | [[0, 0], [0, 0], [-1, -1]], | ||||
[[-1, -1], [-1, -1], [-1, -1]]]) | [[-1, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
@@ -91,17 +91,17 @@ class TestpaddleTensorPadder: | |||||
shape = a.shape | shape = a.shape | ||||
assert isinstance(a, paddle.Tensor) | assert isinstance(a, paddle.Tensor) | ||||
assert tuple(shape) == (3, 3, 2) | assert tuple(shape) == (3, 3, 2) | ||||
b = paddle.FloatTensor([[[0, 0], [0, 0], [0, 0]], | |||||
b = paddle.to_tensor([[[0, 0], [0, 0], [0, 0]], | |||||
[[0, 0], [0, 0], [-1, -1]], | [[0, 0], [0, 0], [-1, -1]], | ||||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||||
[[-1, -1], [-1, -1], [-1, -1]]], dtype='float32') | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | padder = paddleTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | ||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | padder = paddleTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | ||||
padder = paddleTensorPadder(ele_dtype=paddle.long, dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype=int, dtype=paddle.long, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype='int64', dtype=int, pad_val=-1) | |||||
padder = paddleTensorPadder(ele_dtype=int, dtype='int64', pad_val=-1) | |||||
@@ -7,14 +7,14 @@ from fastNLP.core.collators.padders.exceptions import DtypeError | |||||
class TestRawNumberPadder: | class TestRawNumberPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = RawNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
assert padder(a) == a | assert padder(a) == a | ||||
class TestRawSequencePadder: | class TestRawSequencePadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = np.shape(a) | shape = np.shape(a) | ||||
@@ -24,6 +24,6 @@ class TestRawSequencePadder: | |||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) |
@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestTorchNumberPadder: | class TestTorchNumberPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = TorchNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
t_a = padder(a) | t_a = padder(a) | ||||
assert isinstance(t_a, torch.Tensor) | assert isinstance(t_a, torch.Tensor) | ||||
@@ -22,7 +22,7 @@ class TestTorchNumberPadder: | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestTorchSequencePadder: | class TestTorchSequencePadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -32,20 +32,20 @@ class TestTorchSequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.int8, dtype=None) | |||||
a = padder([[1], [2, 322]]) | a = padder([[1], [2, 322]]) | ||||
assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | ||||
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(2).dtype, dtype=None) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestTorchTensorPadder: | class TestTorchTensorPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) | |||||
a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] | a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -74,7 +74,7 @@ class TestTorchTensorPadder: | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | [[0, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) | |||||
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] | a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -85,7 +85,7 @@ class TestTorchTensorPadder: | |||||
[[-1, -1], [-1, -1], [-1, -1]]]) | [[-1, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=None) | |||||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -97,11 +97,11 @@ class TestTorchTensorPadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long) | |||||
@@ -65,6 +65,7 @@ def model_and_optimizers(): | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) # , ("torch", 6), ("torch", [6, 7]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]]) | ||||
@pytest.mark.torch | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_event_trigger( | def test_trainer_event_trigger( | ||||
model_and_optimizers: TrainerParameters, | model_and_optimizers: TrainerParameters, | ||||
@@ -7,16 +7,16 @@ from tests.helpers.utils import magic_argv_env_context | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator(): | def test_trainer_torch_without_evaluator(): | ||||
@Trainer.on(Events.ON_TRAIN_EPOCH_BEGIN(every=10)) | |||||
@Trainer.on(Events.on_train_epoch_begin(every=10)) | |||||
def fn1(trainer): | def fn1(trainer): | ||||
pass | pass | ||||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) | |||||
@Trainer.on(Events.on_train_batch_begin(every=10)) | |||||
def fn2(trainer, batch, indices): | def fn2(trainer, batch, indices): | ||||
pass | pass | ||||
with pytest.raises(AssertionError): | with pytest.raises(AssertionError): | ||||
@Trainer.on(Events.ON_TRAIN_BATCH_BEGIN(every=10)) | |||||
@Trainer.on(Events.on_train_batch_begin(every=10)) | |||||
def fn3(trainer, batch): | def fn3(trainer, batch): | ||||
pass | pass | ||||
@@ -25,8 +25,8 @@ class TrainPaddleConfig: | |||||
@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])]) | @pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])]) | ||||
# @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True), | |||||
RichCallback(5)]]) | |||||
@pytest.mark.parametrize("callbacks", [[RichCallback(5)]]) | |||||
@pytest.mark.paddle | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_paddle( | def test_trainer_paddle( | ||||
driver, | driver, | ||||
@@ -98,6 +98,7 @@ def model_and_optimizers(request): | |||||
# 测试一下普通的情况; | # 测试一下普通的情况; | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) | ||||
@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) | ||||
@pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | @pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) | ||||
@@ -133,6 +134,7 @@ def test_trainer_torch_with_evaluator( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) | ||||
@pytest.mark.parametrize("fp16", [True, False]) | @pytest.mark.parametrize("fp16", [True, False]) | ||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | @pytest.mark.parametrize("accumulation_steps", [1, 3]) | ||||
@@ -76,6 +76,7 @@ def model_and_optimizers(request): | |||||
# 测试一下 cpu; | # 测试一下 cpu; | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu")]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_trainer_torch_without_evaluator( | def test_trainer_torch_without_evaluator( | ||||
@@ -107,6 +108,7 @@ def test_trainer_torch_without_evaluator( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])]) # ("torch", 4), | @pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])]) # ("torch", 4), | ||||
@pytest.mark.parametrize("fp16", [False, True]) | @pytest.mark.parametrize("fp16", [False, True]) | ||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | @pytest.mark.parametrize("accumulation_steps", [1, 3]) | ||||
@@ -146,6 +148,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||||
# 测试 accumulation_steps; | # 测试 accumulation_steps; | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])]) | @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])]) | ||||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | @pytest.mark.parametrize("accumulation_steps", [1, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -179,6 +182,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( | |||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | @pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | ||||
@pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) | @pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -242,6 +246,7 @@ def test_trainer_output_from_new_proc( | |||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | @pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | ||||
@pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | @pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
@@ -294,6 +299,7 @@ def test_torch_distributed_launch_1(version): | |||||
subprocess.check_call(command) | subprocess.check_call(command) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("version", [0, 1, 2, 3]) | @pytest.mark.parametrize("version", [0, 1, 2, 3]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_distributed_launch_2(version): | def test_torch_distributed_launch_2(version): | ||||
@@ -307,6 +313,7 @@ def test_torch_distributed_launch_2(version): | |||||
subprocess.check_call(command) | subprocess.check_call(command) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | @pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_torch_wo_auto_param_call( | def test_torch_wo_auto_param_call( | ||||
@@ -10,7 +10,7 @@ class Test_WrapDataLoader: | |||||
all_sanity_batches = [4, 20, 100] | all_sanity_batches = [4, 20, 100] | ||||
for sanity_batches in all_sanity_batches: | for sanity_batches in all_sanity_batches: | ||||
data = NormalIterator(num_of_data=1000) | data = NormalIterator(num_of_data=1000) | ||||
wrapper = _TruncatedDataLoader(num_batches=sanity_batches) | |||||
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches) | |||||
dataloader = iter(wrapper(dataloader=data)) | dataloader = iter(wrapper(dataloader=data)) | ||||
mark = 0 | mark = 0 | ||||
while True: | while True: | ||||
@@ -31,7 +31,7 @@ class Test_WrapDataLoader: | |||||
for sanity_batches in all_sanity_batches: | for sanity_batches in all_sanity_batches: | ||||
dataset = TorchNormalDataset(num_of_data=1000) | dataset = TorchNormalDataset(num_of_data=1000) | ||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | ||||
wrapper = _TruncatedDataLoader(num_batches=sanity_batches) | |||||
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | |||||
dataloader = wrapper(dataloader) | dataloader = wrapper(dataloader) | ||||
dataloader = iter(dataloader) | dataloader = iter(dataloader) | ||||
all_supposed_running_data_num = 0 | all_supposed_running_data_num = 0 | ||||
@@ -54,7 +54,7 @@ class Test_WrapDataLoader: | |||||
for sanity_batches in all_sanity_batches: | for sanity_batches in all_sanity_batches: | ||||
dataset = TorchNormalDataset(num_of_data=1000) | dataset = TorchNormalDataset(num_of_data=1000) | ||||
dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | dataloader = DataLoader(dataset, batch_size=bs, shuffle=True) | ||||
wrapper = _TruncatedDataLoader(num_batches=sanity_batches) | |||||
wrapper = _TruncatedDataLoader(dataloader, num_batches=sanity_batches) | |||||
dataloader = wrapper(dataloader) | dataloader = wrapper(dataloader) | ||||
length.append(len(dataloader)) | length.append(len(dataloader)) | ||||
assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) | assert length == reduce(lambda x, y: x+y, [all_sanity_batches for _ in range(len(bses))]) |
@@ -1,12 +1,16 @@ | |||||
import pytest | import pytest | ||||
from jittor.dataset import Dataset | |||||
import jittor | |||||
import numpy as np | import numpy as np | ||||
from datasets import Dataset as HfDataset | from datasets import Dataset as HfDataset | ||||
from datasets import load_dataset | from datasets import load_dataset | ||||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | ||||
from fastNLP.core.dataset import DataSet as Fdataset | from fastNLP.core.dataset import DataSet as Fdataset | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
from jittor.dataset import Dataset | |||||
import jittor | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||||
class MyDataset(Dataset): | class MyDataset(Dataset): | ||||
@@ -25,7 +29,7 @@ class MyDataset(Dataset): | |||||
# def __len__(self): | # def __len__(self): | ||||
# return self.dataset_len | # return self.dataset_len | ||||
@pytest.mark.jittor | |||||
class TestJittor: | class TestJittor: | ||||
def test_v1(self): | def test_v1(self): | ||||
@@ -1,13 +1,18 @@ | |||||
import unittest | |||||
import pytest | |||||
import os | import os | ||||
import numpy as np | import numpy as np | ||||
import jittor as jt # 将 jittor 引入 | |||||
from jittor import nn, Module # 引入相关的模块 | |||||
from jittor import init | |||||
from jittor.dataset import MNIST | |||||
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | ||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor as jt # 将 jittor 引入 | |||||
from jittor import nn, Module # 引入相关的模块 | |||||
from jittor import init | |||||
from jittor.dataset import MNIST | |||||
else: | |||||
from fastNLP.core.utils.dummy_class import DummyClass as Module | |||||
class Model (Module): | class Model (Module): | ||||
@@ -39,7 +44,8 @@ class Model (Module): | |||||
x = self.fc2 (x) | x = self.fc2 (x) | ||||
return x | return x | ||||
class SingleDeviceTestCase(unittest.TestCase): | |||||
@pytest.mark.jittor | |||||
class TestSingleDevice: | |||||
def test_on_gpu_without_fp16(self): | def test_on_gpu_without_fp16(self): | ||||
# TODO get_dataloader | # TODO get_dataloader | ||||
@@ -82,7 +88,7 @@ class SingleDeviceTestCase(unittest.TestCase): | |||||
total_acc += acc | total_acc += acc | ||||
total_num += batch_size | total_num += batch_size | ||||
acc = acc / batch_size | acc = acc / batch_size | ||||
self.assertGreater(total_acc / total_num, 0.95) | |||||
assert total_acc / total_num > 0.95 | |||||
def test_on_cpu_without_fp16(self): | def test_on_cpu_without_fp16(self): | ||||
@@ -18,6 +18,7 @@ from tests.helpers.utils import magic_argv_env_context | |||||
import paddle | import paddle | ||||
import paddle.distributed as dist | import paddle.distributed as dist | ||||
@pytest.mark.paddle | |||||
class TestDistUtilsTools: | class TestDistUtilsTools: | ||||
""" | """ | ||||
测试一些工具函数 | 测试一些工具函数 | ||||
@@ -78,6 +79,7 @@ class TestDistUtilsTools: | |||||
assert res["string"] == paddle_dict["string"] | assert res["string"] == paddle_dict["string"] | ||||
@pytest.mark.paddle | |||||
class TestAllGatherAndBroadCast: | class TestAllGatherAndBroadCast: | ||||
@classmethod | @classmethod | ||||
@@ -38,6 +38,7 @@ def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, out | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.paddle | |||||
class TestFleetDriverFunction: | class TestFleetDriverFunction: | ||||
""" | """ | ||||
测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | 测试 PaddleFleetDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 | ||||
@@ -145,6 +146,7 @@ class TestFleetDriverFunction: | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.paddle | |||||
class TestSetDistReproDataloader: | class TestSetDistReproDataloader: | ||||
@classmethod | @classmethod | ||||
@@ -517,6 +519,8 @@ class TestSetDistReproDataloader: | |||||
# 测试 save 和 load 相关的功能 | # 测试 save 和 load 相关的功能 | ||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.paddle | |||||
class TestSaveLoad: | class TestSaveLoad: | ||||
""" | """ | ||||
测试多卡情况下 save 和 load 相关函数的表现 | 测试多卡情况下 save 和 load 相关函数的表现 | ||||
@@ -8,12 +8,14 @@ from tests.helpers.utils import magic_argv_env_context | |||||
import paddle | import paddle | ||||
@pytest.mark.paddle | |||||
def test_incorrect_driver(): | def test_incorrect_driver(): | ||||
model = PaddleNormalModel_Classification_1(2, 100) | model = PaddleNormalModel_Classification_1(2, 100) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver("torch", 0, model) | driver = initialize_paddle_driver("torch", 0, model) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
["cpu", "gpu:0", 0] | ["cpu", "gpu:0", 0] | ||||
@@ -31,6 +33,7 @@ def test_get_single_device(driver, device): | |||||
driver = initialize_paddle_driver(driver, device, model) | driver = initialize_paddle_driver(driver, device, model) | ||||
assert isinstance(driver, PaddleSingleDriver) | assert isinstance(driver, PaddleSingleDriver) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[0, 1, [1]] | [0, 1, [1]] | ||||
@@ -50,6 +53,7 @@ def test_get_fleet_2(driver, device): | |||||
assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[[0, 2, 3], -1] | [[0, 2, 3], -1] | ||||
@@ -69,6 +73,7 @@ def test_get_fleet(driver, device): | |||||
assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
("driver", "device"), | ("driver", "device"), | ||||
[("fleet", "cpu")] | [("fleet", "cpu")] | ||||
@@ -82,6 +87,7 @@ def test_get_fleet_cpu(driver, device): | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver(driver, device, model) | driver = initialize_paddle_driver(driver, device, model) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize( | @pytest.mark.parametrize( | ||||
"device", | "device", | ||||
[-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | [-2, [0, get_gpu_count() + 1, 3], [-2], get_gpu_count() + 1] | ||||
@@ -97,4 +103,4 @@ def test_device_out_of_range(driver, device): | |||||
""" | """ | ||||
model = PaddleNormalModel_Classification_1(2, 100) | model = PaddleNormalModel_Classification_1(2, 100) | ||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
driver = initialize_paddle_driver(driver, device, model) | |||||
driver = initialize_paddle_driver(driver, device, model) |
@@ -29,6 +29,7 @@ class TestPaddleDriverFunctions: | |||||
model = PaddleNormalModel_Classification_1(10, 32) | model = PaddleNormalModel_Classification_1(10, 32) | ||||
self.driver = PaddleSingleDriver(model, device="cpu") | self.driver = PaddleSingleDriver(model, device="cpu") | ||||
@pytest.mark.torchpaddle | |||||
def test_check_single_optimizer_legality(self): | def test_check_single_optimizer_legality(self): | ||||
""" | """ | ||||
测试传入单个 optimizer 时的表现 | 测试传入单个 optimizer 时的表现 | ||||
@@ -45,6 +46,7 @@ class TestPaddleDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_optimizers_legality(self): | def test_check_optimizers_legality(self): | ||||
""" | """ | ||||
测试传入 optimizer list 的表现 | 测试传入 optimizer list 的表现 | ||||
@@ -65,6 +67,7 @@ class TestPaddleDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_dataloader_legality_in_train(self): | def test_check_dataloader_legality_in_train(self): | ||||
""" | """ | ||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | ||||
@@ -85,6 +88,7 @@ class TestPaddleDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
""" | """ | ||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | ||||
@@ -122,6 +126,7 @@ class TestPaddleDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
@pytest.mark.paddle | |||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
""" | """ | ||||
测试 tensor_to_numeric 函数 | 测试 tensor_to_numeric 函数 | ||||
@@ -175,6 +180,7 @@ class TestPaddleDriverFunctions: | |||||
assert r == d.tolist() | assert r == d.tolist() | ||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | ||||
@pytest.mark.paddle | |||||
def test_set_model_mode(self): | def test_set_model_mode(self): | ||||
""" | """ | ||||
测试 set_model_mode 函数 | 测试 set_model_mode 函数 | ||||
@@ -187,6 +193,7 @@ class TestPaddleDriverFunctions: | |||||
with pytest.raises(AssertionError): | with pytest.raises(AssertionError): | ||||
self.driver.set_model_mode("test") | self.driver.set_model_mode("test") | ||||
@pytest.mark.paddle | |||||
def test_move_model_to_device_cpu(self): | def test_move_model_to_device_cpu(self): | ||||
""" | """ | ||||
测试 move_model_to_device 函数 | 测试 move_model_to_device 函数 | ||||
@@ -194,6 +201,7 @@ class TestPaddleDriverFunctions: | |||||
PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") | ||||
assert self.driver.model.linear1.weight.place.is_cpu_place() | assert self.driver.model.linear1.weight.place.is_cpu_place() | ||||
@pytest.mark.paddle | |||||
def test_move_model_to_device_gpu(self): | def test_move_model_to_device_gpu(self): | ||||
""" | """ | ||||
测试 move_model_to_device 函数 | 测试 move_model_to_device 函数 | ||||
@@ -202,6 +210,7 @@ class TestPaddleDriverFunctions: | |||||
assert self.driver.model.linear1.weight.place.is_gpu_place() | assert self.driver.model.linear1.weight.place.is_gpu_place() | ||||
assert self.driver.model.linear1.weight.place.gpu_device_id() == 0 | assert self.driver.model.linear1.weight.place.gpu_device_id() == 0 | ||||
@pytest.mark.paddle | |||||
def test_worker_init_function(self): | def test_worker_init_function(self): | ||||
""" | """ | ||||
测试 worker_init_function | 测试 worker_init_function | ||||
@@ -210,6 +219,7 @@ class TestPaddleDriverFunctions: | |||||
# TODO:正确性 | # TODO:正确性 | ||||
PaddleSingleDriver.worker_init_function(0) | PaddleSingleDriver.worker_init_function(0) | ||||
@pytest.mark.paddle | |||||
def test_set_deterministic_dataloader(self): | def test_set_deterministic_dataloader(self): | ||||
""" | """ | ||||
测试 set_deterministic_dataloader | 测试 set_deterministic_dataloader | ||||
@@ -219,6 +229,7 @@ class TestPaddleDriverFunctions: | |||||
dataloader = DataLoader(PaddleNormalDataset()) | dataloader = DataLoader(PaddleNormalDataset()) | ||||
self.driver.set_deterministic_dataloader(dataloader) | self.driver.set_deterministic_dataloader(dataloader) | ||||
@pytest.mark.paddle | |||||
def test_set_sampler_epoch(self): | def test_set_sampler_epoch(self): | ||||
""" | """ | ||||
测试 set_sampler_epoch | 测试 set_sampler_epoch | ||||
@@ -228,6 +239,7 @@ class TestPaddleDriverFunctions: | |||||
dataloader = DataLoader(PaddleNormalDataset()) | dataloader = DataLoader(PaddleNormalDataset()) | ||||
self.driver.set_sampler_epoch(dataloader, 0) | self.driver.set_sampler_epoch(dataloader, 0) | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize("batch_size", [16]) | @pytest.mark.parametrize("batch_size", [16]) | ||||
@pytest.mark.parametrize("shuffle", [True, False]) | @pytest.mark.parametrize("shuffle", [True, False]) | ||||
@pytest.mark.parametrize("drop_last", [True, False]) | @pytest.mark.parametrize("drop_last", [True, False]) | ||||
@@ -253,6 +265,7 @@ class TestPaddleDriverFunctions: | |||||
assert res.batch_size == batch_size | assert res.batch_size == batch_size | ||||
assert res.drop_last == drop_last | assert res.drop_last == drop_last | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize("batch_size", [16]) | @pytest.mark.parametrize("batch_size", [16]) | ||||
@pytest.mark.parametrize("shuffle", [True, False]) | @pytest.mark.parametrize("shuffle", [True, False]) | ||||
@pytest.mark.parametrize("drop_last", [True, False]) | @pytest.mark.parametrize("drop_last", [True, False]) | ||||
@@ -281,6 +294,7 @@ class TestPaddleDriverFunctions: | |||||
assert res.batch_size == batch_size | assert res.batch_size == batch_size | ||||
assert res.drop_last == drop_last | assert res.drop_last == drop_last | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize("batch_size", [16]) | @pytest.mark.parametrize("batch_size", [16]) | ||||
@pytest.mark.parametrize("shuffle", [True, False]) | @pytest.mark.parametrize("shuffle", [True, False]) | ||||
@pytest.mark.parametrize("drop_last", [True, False]) | @pytest.mark.parametrize("drop_last", [True, False]) | ||||
@@ -311,6 +325,7 @@ class TestPaddleDriverFunctions: | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.paddle | |||||
class TestSingleDeviceFunction: | class TestSingleDeviceFunction: | ||||
""" | """ | ||||
测试其它函数的测试例 | 测试其它函数的测试例 | ||||
@@ -345,6 +360,7 @@ class TestSingleDeviceFunction: | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.paddle | |||||
class TestSetDistReproDataloader: | class TestSetDistReproDataloader: | ||||
""" | """ | ||||
专门测试 set_dist_repro_dataloader 函数的类 | 专门测试 set_dist_repro_dataloader 函数的类 | ||||
@@ -541,6 +557,7 @@ def prepare_test_save_load(): | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | ||||
return driver1, driver2, dataloader | return driver1, driver2, dataloader | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
def test_save_and_load_model(prepare_test_save_load, only_state_dict): | def test_save_and_load_model(prepare_test_save_load, only_state_dict): | ||||
""" | """ | ||||
@@ -570,6 +587,7 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
rank_zero_rm(path + ".pdiparams.info") | rank_zero_rm(path + ".pdiparams.info") | ||||
rank_zero_rm(path + ".pdmodel") | rank_zero_rm(path + ".pdmodel") | ||||
@pytest.mark.paddle | |||||
# @pytest.mark.parametrize("only_state_dict", ([True, False])) | # @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@pytest.mark.parametrize("only_state_dict", ([True])) | @pytest.mark.parametrize("only_state_dict", ([True])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
@@ -650,6 +668,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
# @pytest.mark.parametrize("only_state_dict", ([True, False])) | # @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
# TODO 在有迭代且使用了paddle.jit.save的时候会引发段错误,注释掉任意一段都不会出错 | # TODO 在有迭代且使用了paddle.jit.save的时候会引发段错误,注释掉任意一段都不会出错 | ||||
# 但无法在单独的文件中复现 | # 但无法在单独的文件中复现 | ||||
@pytest.mark.paddle | |||||
@pytest.mark.parametrize("only_state_dict", ([True])) | @pytest.mark.parametrize("only_state_dict", ([True])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | def test_save_and_load_with_randomsampler(only_state_dict, fp16): | ||||
@@ -1,3 +1,4 @@ | |||||
import os | |||||
import pytest | import pytest | ||||
from fastNLP.core.drivers.paddle_driver.utils import ( | from fastNLP.core.drivers.paddle_driver.utils import ( | ||||
@@ -23,12 +24,14 @@ from tests.helpers.datasets.paddle_data import PaddleNormalDataset | |||||
("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), | ("3,6,7,8", "6,7,8", "gpu:2", str, "gpu:1"), | ||||
) | ) | ||||
) | ) | ||||
@pytest.mark.paddle | |||||
def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): | def test_get_device_from_visible_str(user_visible_devices, cuda_visible_devices, device, output_type, correct): | ||||
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices | os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices | ||||
os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices | os.environ["USER_CUDA_VISIBLE_DEVICES"] = user_visible_devices | ||||
res = get_device_from_visible(device, output_type) | res = get_device_from_visible(device, output_type) | ||||
assert res == correct | assert res == correct | ||||
@pytest.mark.paddle | |||||
def test_replace_batch_sampler(): | def test_replace_batch_sampler(): | ||||
dataset = PaddleNormalDataset(10) | dataset = PaddleNormalDataset(10) | ||||
dataloader = DataLoader(dataset, batch_size=32) | dataloader = DataLoader(dataset, batch_size=32) | ||||
@@ -42,6 +45,7 @@ def test_replace_batch_sampler(): | |||||
assert len(replaced_loader.dataset) == len(dataset) | assert len(replaced_loader.dataset) == len(dataset) | ||||
assert replaced_loader.batch_sampler.batch_size == 16 | assert replaced_loader.batch_sampler.batch_size == 16 | ||||
@pytest.mark.paddle | |||||
def test_replace_sampler(): | def test_replace_sampler(): | ||||
dataset = PaddleNormalDataset(10) | dataset = PaddleNormalDataset(10) | ||||
dataloader = DataLoader(dataset, batch_size=32) | dataloader = DataLoader(dataset, batch_size=32) | ||||
@@ -1,31 +0,0 @@ | |||||
import sys | |||||
sys.path.append("../../../../") | |||||
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
import torch | |||||
device = [0, 1] | |||||
torch_model = TorchNormalModel_Classification_1(10, 10) | |||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||||
device = [torch.device(i) for i in device] | |||||
driver = TorchDDPDriver( | |||||
model=torch_model, | |||||
parallel_device=device, | |||||
fp16=False | |||||
) | |||||
driver.set_optimizers(torch_opt) | |||||
driver.setup() | |||||
print("-----------first--------------") | |||||
device = [0, 2] | |||||
torch_model = TorchNormalModel_Classification_1(10, 10) | |||||
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) | |||||
device = [torch.device(i) for i in device] | |||||
driver = TorchDDPDriver( | |||||
model=torch_model, | |||||
parallel_device=device, | |||||
fp16=False | |||||
) | |||||
driver.set_optimizers(torch_opt) | |||||
driver.setup() |
@@ -1,4 +1,5 @@ | |||||
import os | import os | ||||
import pytest | |||||
import torch | import torch | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
@@ -62,6 +62,7 @@ class TestTorchDriverFunctions: | |||||
model = TorchNormalModel_Classification_1(10, 32) | model = TorchNormalModel_Classification_1(10, 32) | ||||
self.driver = TorchSingleDriver(model, device="cpu") | self.driver = TorchSingleDriver(model, device="cpu") | ||||
@pytest.mark.torchpaddle | |||||
def test_check_single_optimizer_legality(self): | def test_check_single_optimizer_legality(self): | ||||
""" | """ | ||||
测试传入单个 optimizer 时的表现 | 测试传入单个 optimizer 时的表现 | ||||
@@ -81,6 +82,7 @@ class TestTorchDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
self.driver.set_optimizers(optimizer) | self.driver.set_optimizers(optimizer) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_optimizers_legality(self): | def test_check_optimizers_legality(self): | ||||
""" | """ | ||||
测试传入 optimizer list 的表现 | 测试传入 optimizer list 的表现 | ||||
@@ -104,6 +106,7 @@ class TestTorchDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
self.driver.set_optimizers(optimizers) | self.driver.set_optimizers(optimizers) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_dataloader_legality_in_train(self): | def test_check_dataloader_legality_in_train(self): | ||||
""" | """ | ||||
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 | ||||
@@ -119,6 +122,7 @@ class TestTorchDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) | ||||
@pytest.mark.torchpaddle | |||||
def test_check_dataloader_legality_in_test(self): | def test_check_dataloader_legality_in_test(self): | ||||
""" | """ | ||||
测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 | ||||
@@ -148,6 +152,7 @@ class TestTorchDriverFunctions: | |||||
with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) | ||||
@pytest.mark.torch | |||||
def test_tensor_to_numeric(self): | def test_tensor_to_numeric(self): | ||||
""" | """ | ||||
测试 tensor_to_numeric 函数 | 测试 tensor_to_numeric 函数 | ||||
@@ -201,6 +206,7 @@ class TestTorchDriverFunctions: | |||||
assert r == d.tolist() | assert r == d.tolist() | ||||
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() | ||||
@pytest.mark.torch | |||||
def test_set_model_mode(self): | def test_set_model_mode(self): | ||||
""" | """ | ||||
测试set_model_mode函数 | 测试set_model_mode函数 | ||||
@@ -213,6 +219,7 @@ class TestTorchDriverFunctions: | |||||
with pytest.raises(AssertionError): | with pytest.raises(AssertionError): | ||||
self.driver.set_model_mode("test") | self.driver.set_model_mode("test") | ||||
@pytest.mark.torch | |||||
def test_move_model_to_device_cpu(self): | def test_move_model_to_device_cpu(self): | ||||
""" | """ | ||||
测试move_model_to_device函数 | 测试move_model_to_device函数 | ||||
@@ -220,6 +227,7 @@ class TestTorchDriverFunctions: | |||||
TorchSingleDriver.move_model_to_device(self.driver.model, "cpu") | TorchSingleDriver.move_model_to_device(self.driver.model, "cpu") | ||||
assert self.driver.model.linear1.weight.device.type == "cpu" | assert self.driver.model.linear1.weight.device.type == "cpu" | ||||
@pytest.mark.torch | |||||
def test_move_model_to_device_gpu(self): | def test_move_model_to_device_gpu(self): | ||||
""" | """ | ||||
测试move_model_to_device函数 | 测试move_model_to_device函数 | ||||
@@ -228,6 +236,7 @@ class TestTorchDriverFunctions: | |||||
assert self.driver.model.linear1.weight.device.type == "cuda" | assert self.driver.model.linear1.weight.device.type == "cuda" | ||||
assert self.driver.model.linear1.weight.device.index == 0 | assert self.driver.model.linear1.weight.device.index == 0 | ||||
@pytest.mark.torch | |||||
def test_worker_init_function(self): | def test_worker_init_function(self): | ||||
""" | """ | ||||
测试worker_init_function | 测试worker_init_function | ||||
@@ -236,6 +245,7 @@ class TestTorchDriverFunctions: | |||||
# TODO:正确性 | # TODO:正确性 | ||||
TorchSingleDriver.worker_init_function(0) | TorchSingleDriver.worker_init_function(0) | ||||
@pytest.mark.torch | |||||
def test_set_deterministic_dataloader(self): | def test_set_deterministic_dataloader(self): | ||||
""" | """ | ||||
测试set_deterministic_dataloader | 测试set_deterministic_dataloader | ||||
@@ -245,6 +255,7 @@ class TestTorchDriverFunctions: | |||||
dataloader = DataLoader(TorchNormalDataset()) | dataloader = DataLoader(TorchNormalDataset()) | ||||
self.driver.set_deterministic_dataloader(dataloader) | self.driver.set_deterministic_dataloader(dataloader) | ||||
@pytest.mark.torch | |||||
def test_set_sampler_epoch(self): | def test_set_sampler_epoch(self): | ||||
""" | """ | ||||
测试set_sampler_epoch | 测试set_sampler_epoch | ||||
@@ -254,6 +265,7 @@ class TestTorchDriverFunctions: | |||||
dataloader = DataLoader(TorchNormalDataset()) | dataloader = DataLoader(TorchNormalDataset()) | ||||
self.driver.set_sampler_epoch(dataloader, 0) | self.driver.set_sampler_epoch(dataloader, 0) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("batch_size", [16]) | @pytest.mark.parametrize("batch_size", [16]) | ||||
@pytest.mark.parametrize("shuffle", [True, False]) | @pytest.mark.parametrize("shuffle", [True, False]) | ||||
@pytest.mark.parametrize("drop_last", [True, False]) | @pytest.mark.parametrize("drop_last", [True, False]) | ||||
@@ -279,6 +291,7 @@ class TestTorchDriverFunctions: | |||||
assert res.batch_size == batch_size | assert res.batch_size == batch_size | ||||
assert res.drop_last == drop_last | assert res.drop_last == drop_last | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("batch_size", [16]) | @pytest.mark.parametrize("batch_size", [16]) | ||||
@pytest.mark.parametrize("shuffle", [True, False]) | @pytest.mark.parametrize("shuffle", [True, False]) | ||||
@pytest.mark.parametrize("drop_last", [True, False]) | @pytest.mark.parametrize("drop_last", [True, False]) | ||||
@@ -300,6 +313,7 @@ class TestTorchDriverFunctions: | |||||
assert res.batch_size == batch_size | assert res.batch_size == batch_size | ||||
assert res.drop_last == drop_last | assert res.drop_last == drop_last | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("batch_size", [16]) | @pytest.mark.parametrize("batch_size", [16]) | ||||
@pytest.mark.parametrize("shuffle", [True, False]) | @pytest.mark.parametrize("shuffle", [True, False]) | ||||
@pytest.mark.parametrize("drop_last", [True, False]) | @pytest.mark.parametrize("drop_last", [True, False]) | ||||
@@ -325,6 +339,7 @@ class TestTorchDriverFunctions: | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.torch | |||||
class TestSingleDeviceFunction: | class TestSingleDeviceFunction: | ||||
""" | """ | ||||
测试其它函数的测试例 | 测试其它函数的测试例 | ||||
@@ -359,6 +374,7 @@ class TestSingleDeviceFunction: | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.torch | |||||
class TestSetDistReproDataloader: | class TestSetDistReproDataloader: | ||||
""" | """ | ||||
专门测试 set_dist_repro_dataloader 函数的类 | 专门测试 set_dist_repro_dataloader 函数的类 | ||||
@@ -534,6 +550,7 @@ def prepare_test_save_load(): | |||||
driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) | ||||
return driver1, driver2, dataloader | return driver1, driver2, dataloader | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
def test_save_and_load_model(prepare_test_save_load, only_state_dict): | def test_save_and_load_model(prepare_test_save_load, only_state_dict): | ||||
""" | """ | ||||
@@ -555,6 +572,7 @@ def test_save_and_load_model(prepare_test_save_load, only_state_dict): | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | ||||
@@ -623,6 +641,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
@pytest.mark.torch | |||||
@pytest.mark.parametrize("only_state_dict", ([True, False])) | @pytest.mark.parametrize("only_state_dict", ([True, False])) | ||||
@pytest.mark.parametrize("fp16", ([True, False])) | @pytest.mark.parametrize("fp16", ([True, False])) | ||||
def test_save_and_load_with_randomsampler(only_state_dict, fp16): | def test_save_and_load_with_randomsampler(only_state_dict, fp16): | ||||
@@ -1,4 +1,4 @@ | |||||
import unittest | |||||
import pytest | |||||
from fastNLP.modules.mix_modules.mix_module import MixModule | from fastNLP.modules.mix_modules.mix_module import MixModule | ||||
from fastNLP.core.drivers.torch_paddle_driver.torch_paddle_driver import TorchPaddleDriver | from fastNLP.core.drivers.torch_paddle_driver.torch_paddle_driver import TorchPaddleDriver | ||||
@@ -56,10 +56,11 @@ class MixMNISTModel(MixModule): | |||||
def test_step(self, x): | def test_step(self, x): | ||||
return self.forward(x) | return self.forward(x) | ||||
class TestMNIST(unittest.TestCase): | |||||
@pytest.mark.torchpaddle | |||||
class TestMNIST: | |||||
@classmethod | @classmethod | ||||
def setUpClass(self): | |||||
def setup_class(self): | |||||
self.train_dataset = paddle.vision.datasets.MNIST(mode='train') | self.train_dataset = paddle.vision.datasets.MNIST(mode='train') | ||||
self.test_dataset = paddle.vision.datasets.MNIST(mode='test') | self.test_dataset = paddle.vision.datasets.MNIST(mode='test') | ||||
@@ -70,7 +71,7 @@ class TestMNIST(unittest.TestCase): | |||||
self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True) | self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True) | ||||
def setUp(self): | |||||
def setup_method(self): | |||||
model = MixMNISTModel() | model = MixMNISTModel() | ||||
self.torch_loss_func = torch.nn.CrossEntropyLoss() | self.torch_loss_func = torch.nn.CrossEntropyLoss() | ||||
@@ -118,4 +119,4 @@ class TestMNIST(unittest.TestCase): | |||||
correct += 1 | correct += 1 | ||||
acc = correct / len(self.test_dataset) | acc = correct / len(self.test_dataset) | ||||
self.assertGreater(acc, 0.85) | |||||
assert acc > 0.85 |
@@ -49,12 +49,12 @@ def test_accuracy_single(): | |||||
# 测试 单机多卡情况下的Accuracy | # 测试 单机多卡情况下的Accuracy | ||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
def test_accuracy_ddp(): | |||||
launcher = FleetLauncher(devices=[0, 1]) | |||||
launcher.launch() | |||||
role = role_maker.PaddleCloudRoleMaker(is_collective=True) | |||||
fleet.init(role) | |||||
if fleet.is_server(): | |||||
pass | |||||
elif fleet.is_worker(): | |||||
print(os.getenv("PADDLE_TRAINER_ID")) | |||||
# def test_accuracy_ddp(): | |||||
# launcher = FleetLauncher(devices=[0, 1]) | |||||
# launcher.launch() | |||||
# role = role_maker.PaddleCloudRoleMaker(is_collective=True) | |||||
# fleet.init(role) | |||||
# if fleet.is_server(): | |||||
# pass | |||||
# elif fleet.is_worker(): | |||||
# print(os.getenv("PADDLE_TRAINER_ID")) |
@@ -1,26 +0,0 @@ | |||||
from fastNLP.core.metrics.metric import Metric | |||||
from collections import defaultdict | |||||
from functools import partial | |||||
import unittest | |||||
class MyMetric(Metric): | |||||
def __init__(self, backend='auto', | |||||
aggregate_when_get_metric: bool = False): | |||||
super(MyMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||||
self.tp = defaultdict(partial(self.register_element, aggregate_method='sum')) | |||||
def update(self, item): | |||||
self.tp['1'] += item | |||||
class TestMetric(unittest.TestCase): | |||||
def test_va1(self): | |||||
my = MyMetric() | |||||
my.update(1) | |||||
print(my.tp['1']) |
@@ -29,6 +29,8 @@ class TestUnrepeatedSampler: | |||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
@pytest.mark.parametrize('shuffle', [False, True]) | @pytest.mark.parametrize('shuffle', [False, True]) | ||||
def test_multi(self, num_replicas, num_of_data, shuffle): | def test_multi(self, num_replicas, num_of_data, shuffle): | ||||
if num_replicas > num_of_data: | |||||
pytest.skip("num_replicas > num_of_data") | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replicas): | for i in range(num_replicas): | ||||
@@ -53,6 +55,8 @@ class TestUnrepeatedSortedSampler: | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, num_replicas, num_of_data): | def test_multi(self, num_replicas, num_of_data): | ||||
if num_replicas > num_of_data: | |||||
pytest.skip("num_replicas > num_of_data") | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replicas): | for i in range(num_replicas): | ||||
@@ -84,6 +88,8 @@ class TestUnrepeatedSequentialSampler: | |||||
@pytest.mark.parametrize('num_replicas', [2, 3]) | @pytest.mark.parametrize('num_replicas', [2, 3]) | ||||
@pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) | ||||
def test_multi(self, num_replicas, num_of_data): | def test_multi(self, num_replicas, num_of_data): | ||||
if num_replicas > num_of_data: | |||||
pytest.skip("num_replicas > num_of_data") | |||||
data = DatasetWithVaryLength(num_of_data=num_of_data) | data = DatasetWithVaryLength(num_of_data=num_of_data) | ||||
samplers = [] | samplers = [] | ||||
for i in range(num_replicas): | for i in range(num_replicas): | ||||
@@ -1,29 +1,16 @@ | |||||
import time | |||||
import os | import os | ||||
import pytest | import pytest | ||||
from subprocess import Popen, PIPE | |||||
import subprocess | |||||
from io import StringIO | from io import StringIO | ||||
import sys | import sys | ||||
from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
from tests.helpers.common.utils import check_time_elapse | |||||
from fastNLP.core import rank_zero_rm | from fastNLP.core import rank_zero_rm | ||||
def get_subprocess_results(cmd): | def get_subprocess_results(cmd): | ||||
pipe = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) | |||||
output, err = pipe.communicate() | |||||
if output: | |||||
output = output.decode('utf8') | |||||
else: | |||||
output = '' | |||||
if err: | |||||
err = err.decode('utf8') | |||||
else: | |||||
err = '' | |||||
res = output + err | |||||
return res | |||||
output = subprocess.check_output(cmd, shell=True) | |||||
return output.decode('utf8') | |||||
class Capturing(list): | class Capturing(list): | ||||
@@ -48,12 +35,12 @@ class TestCacheResults: | |||||
try: | try: | ||||
@cache_results(cache_fp) | @cache_results(cache_fp) | ||||
def demo(): | def demo(): | ||||
time.sleep(1) | |||||
print("¥") | |||||
return 1 | return 1 | ||||
res = demo() | res = demo() | ||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo() | res = demo() | ||||
assert '¥' not in output[0] | |||||
finally: | finally: | ||||
rank_zero_rm(cache_fp) | rank_zero_rm(cache_fp) | ||||
@@ -63,12 +50,13 @@ class TestCacheResults: | |||||
try: | try: | ||||
@cache_results(cache_fp, _refresh=True) | @cache_results(cache_fp, _refresh=True) | ||||
def demo(): | def demo(): | ||||
time.sleep(1.5) | |||||
print("¥") | |||||
return 1 | return 1 | ||||
res = demo() | res = demo() | ||||
with check_time_elapse(1, op='ge'): | |||||
with Capturing() as output: | |||||
res = demo() | res = demo() | ||||
assert '¥' in output[0] | |||||
finally: | finally: | ||||
rank_zero_rm(cache_fp) | rank_zero_rm(cache_fp) | ||||
@@ -77,19 +65,21 @@ class TestCacheResults: | |||||
try: | try: | ||||
@cache_results(cache_fp) | @cache_results(cache_fp) | ||||
def demo(): | def demo(): | ||||
time.sleep(2) | |||||
print('¥') | |||||
return 1 | return 1 | ||||
with check_time_elapse(1, op='gt'): | |||||
with Capturing() as output: | |||||
res = demo() | res = demo() | ||||
assert '¥' in output[0] | |||||
@cache_results(cache_fp) | @cache_results(cache_fp) | ||||
def demo(): | def demo(): | ||||
time.sleep(2) | |||||
print('¥') | |||||
return 1 | return 1 | ||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo() | res = demo() | ||||
assert '¥' not in output[0] | |||||
finally: | finally: | ||||
rank_zero_rm('demo.pkl') | rank_zero_rm('demo.pkl') | ||||
@@ -98,27 +88,28 @@ class TestCacheResults: | |||||
try: | try: | ||||
@cache_results(cache_fp) | @cache_results(cache_fp) | ||||
def demo(): | def demo(): | ||||
time.sleep(2) | |||||
print('¥') | |||||
return 1 | return 1 | ||||
with check_time_elapse(1, op='gt'): | |||||
with Capturing() as output: | |||||
res = demo() | res = demo() | ||||
assert '¥' in output[0] | |||||
@cache_results(cache_fp) | @cache_results(cache_fp) | ||||
def demo(): | def demo(): | ||||
time.sleep(1) | |||||
print('¥¥') | |||||
return 1 | return 1 | ||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo() | |||||
assert 'is different from its last cache' in output[0] | |||||
with Capturing() as output: | |||||
res = demo() | |||||
assert 'different' in output[0] | |||||
assert '¥' not in output[0] | |||||
# 关闭check_hash应该不warning的 | # 关闭check_hash应该不warning的 | ||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo(_check_hash=0) | |||||
assert 'is different from its last cache' not in output[0] | |||||
with Capturing() as output: | |||||
res = demo(_check_hash=0) | |||||
assert 'different' not in output[0] | |||||
assert '¥' not in output[0] | |||||
finally: | finally: | ||||
rank_zero_rm('demo.pkl') | rank_zero_rm('demo.pkl') | ||||
@@ -128,28 +119,29 @@ class TestCacheResults: | |||||
try: | try: | ||||
@cache_results(cache_fp, _check_hash=False) | @cache_results(cache_fp, _check_hash=False) | ||||
def demo(): | def demo(): | ||||
time.sleep(2) | |||||
print('¥') | |||||
return 1 | return 1 | ||||
with check_time_elapse(1, op='gt'): | |||||
res = demo() | |||||
with Capturing() as output: | |||||
res = demo(_check_hash=0) | |||||
assert '¥' in output[0] | |||||
@cache_results(cache_fp, _check_hash=False) | @cache_results(cache_fp, _check_hash=False) | ||||
def demo(): | def demo(): | ||||
time.sleep(1) | |||||
print('¥¥') | |||||
return 1 | return 1 | ||||
# 默认不会check | # 默认不会check | ||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo() | |||||
assert 'is different from its last cache' not in output[0] | |||||
with Capturing() as output: | |||||
res = demo() | |||||
assert 'different' not in output[0] | |||||
assert '¥' not in output[0] | |||||
# check也可以 | # check也可以 | ||||
with check_time_elapse(1, op='lt'): | |||||
with Capturing() as output: | |||||
res = demo(_check_hash=True) | |||||
assert 'is different from its last cache' in output[0] | |||||
with Capturing() as output: | |||||
res = demo(_check_hash=True) | |||||
assert 'different' in output[0] | |||||
assert '¥' not in output[0] | |||||
finally: | finally: | ||||
rank_zero_rm('demo.pkl') | rank_zero_rm('demo.pkl') | ||||
@@ -159,22 +151,22 @@ class TestCacheResults: | |||||
cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
test_type = 'func_refer_fun_change' | test_type = 'func_refer_fun_change' | ||||
try: | try: | ||||
with check_time_elapse(3, op='gt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert "¥" in res | |||||
# 引用的function没有变化 | # 引用的function没有变化 | ||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'Read cache from' in res | |||||
assert 'is different from its last cache' not in res | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert "¥" not in res | |||||
assert 'Read' in res | |||||
assert 'different' not in res | |||||
# 引用的function有变化 | # 引用的function有变化 | ||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'is different from its last cache' in res | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
res = get_subprocess_results(cmd) | |||||
assert "¥" not in res | |||||
assert 'different' in res | |||||
finally: | finally: | ||||
rank_zero_rm(cache_fp) | rank_zero_rm(cache_fp) | ||||
@@ -184,22 +176,21 @@ class TestCacheResults: | |||||
cache_fp = 'demo.pkl' | cache_fp = 'demo.pkl' | ||||
test_type = 'refer_class_method_change' | test_type = 'refer_class_method_change' | ||||
try: | try: | ||||
with check_time_elapse(3, op='gt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert "¥" in res | |||||
# 引用的class没有变化 | # 引用的class没有变化 | ||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'Read cache from' in res | |||||
assert 'is different from its last cache' not in res | |||||
# 引用的class有变化 | |||||
with check_time_elapse(2, op='lt'): | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'is different from its last cache' in res | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'Read' in res | |||||
assert 'different' not in res | |||||
assert "¥" not in res | |||||
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
res = get_subprocess_results(cmd) | |||||
assert 'different' in res | |||||
assert "¥" not in res | |||||
finally: | finally: | ||||
rank_zero_rm(cache_fp) | rank_zero_rm(cache_fp) | ||||
@@ -278,8 +269,8 @@ if __name__ == '__main__': | |||||
@cache_results(cache_fp) | @cache_results(cache_fp) | ||||
def demo_refer_other_func(): | def demo_refer_other_func(): | ||||
time.sleep(3) | |||||
b = demo() | b = demo() | ||||
print("¥") | |||||
return b | return b | ||||
res = demo_refer_other_func() | res = demo_refer_other_func() | ||||
@@ -296,7 +287,7 @@ if __name__ == '__main__': | |||||
# pdb.set_trace() | # pdb.set_trace() | ||||
@cache_results(cache_fp) | @cache_results(cache_fp) | ||||
def demo_func(): | def demo_func(): | ||||
time.sleep(3) | |||||
print("¥") | |||||
b = demo.demo() | b = demo.demo() | ||||
return b | return b | ||||
@@ -1,4 +1,3 @@ | |||||
import unittest | |||||
import pytest | import pytest | ||||
import paddle | import paddle | ||||
@@ -12,21 +11,21 @@ from fastNLP.core.utils.paddle_utils import paddle_to, paddle_move_data_to_devic | |||||
############################################################################ | ############################################################################ | ||||
@pytest.mark.paddle | @pytest.mark.paddle | ||||
class PaddleToDeviceTestCase(unittest.TestCase): | |||||
class TestPaddleToDevice: | |||||
def test_case(self): | def test_case(self): | ||||
tensor = paddle.rand((4, 5)) | tensor = paddle.rand((4, 5)) | ||||
res = paddle_to(tensor, "gpu") | res = paddle_to(tensor, "gpu") | ||||
self.assertTrue(res.place.is_gpu_place()) | |||||
self.assertEqual(res.place.gpu_device_id(), 0) | |||||
assert res.place.is_gpu_place() | |||||
assert res.place.gpu_device_id() == 0 | |||||
res = paddle_to(tensor, "cpu") | res = paddle_to(tensor, "cpu") | ||||
self.assertTrue(res.place.is_cpu_place()) | |||||
assert res.place.is_cpu_place() | |||||
res = paddle_to(tensor, "gpu:2") | res = paddle_to(tensor, "gpu:2") | ||||
self.assertTrue(res.place.is_gpu_place()) | |||||
self.assertEqual(res.place.gpu_device_id(), 2) | |||||
assert res.place.is_gpu_place() | |||||
assert res.place.gpu_device_id() == 2 | |||||
res = paddle_to(tensor, "gpu:1") | res = paddle_to(tensor, "gpu:1") | ||||
self.assertTrue(res.place.is_gpu_place()) | |||||
self.assertEqual(res.place.gpu_device_id(), 1) | |||||
assert res.place.is_gpu_place() | |||||
assert res.place.gpu_device_id() == 1 | |||||
############################################################################ | ############################################################################ | ||||
# | # | ||||
@@ -34,22 +33,22 @@ class PaddleToDeviceTestCase(unittest.TestCase): | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
class PaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
class TestPaddleMoveDataToDevice: | |||||
def check_gpu(self, tensor, idx): | def check_gpu(self, tensor, idx): | ||||
""" | """ | ||||
检查张量是否在指定的设备上的工具函数 | 检查张量是否在指定的设备上的工具函数 | ||||
""" | """ | ||||
self.assertTrue(tensor.place.is_gpu_place()) | |||||
self.assertEqual(tensor.place.gpu_device_id(), idx) | |||||
assert tensor.place.is_gpu_place() | |||||
assert tensor.place.gpu_device_id() == idx | |||||
def check_cpu(self, tensor): | def check_cpu(self, tensor): | ||||
""" | """ | ||||
检查张量是否在cpu上的工具函数 | 检查张量是否在cpu上的工具函数 | ||||
""" | """ | ||||
self.assertTrue(tensor.place.is_cpu_place()) | |||||
assert tensor.place.is_cpu_place() | |||||
def test_tensor_transfer(self): | def test_tensor_transfer(self): | ||||
""" | """ | ||||
@@ -82,22 +81,22 @@ class PaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | ||||
res = paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | res = paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_cpu(r) | self.check_cpu(r) | ||||
res = paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | res = paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 0) | self.check_gpu(r, 0) | ||||
res = paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | res = paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
@@ -109,22 +108,22 @@ class PaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] | ||||
paddle_tuple = tuple(paddle_list) | paddle_tuple = tuple(paddle_list) | ||||
res = paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | res = paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_cpu(r) | self.check_cpu(r) | ||||
res = paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | res = paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 0) | self.check_gpu(r, 0) | ||||
res = paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | res = paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
@@ -145,57 +144,57 @@ class PaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
} | } | ||||
res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_gpu(res["tensor"], 0) | self.check_gpu(res["tensor"], 0) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_gpu(t, 0) | self.check_gpu(t, 0) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_gpu(t, 0) | self.check_gpu(t, 0) | ||||
self.check_gpu(res["dict"]["tensor"], 0) | self.check_gpu(res["dict"]["tensor"], 0) | ||||
res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device="cpu") | res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device="cpu") | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_gpu(res["tensor"], 0) | self.check_gpu(res["tensor"], 0) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_gpu(t, 0) | self.check_gpu(t, 0) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_gpu(t, 0) | self.check_gpu(t, 0) | ||||
self.check_gpu(res["dict"]["tensor"], 0) | self.check_gpu(res["dict"]["tensor"], 0) | ||||
res = paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | res = paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_gpu(res["tensor"], 1) | self.check_gpu(res["tensor"], 1) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_gpu(t, 1) | self.check_gpu(t, 1) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_gpu(t, 1) | self.check_gpu(t, 1) | ||||
self.check_gpu(res["dict"]["tensor"], 1) | self.check_gpu(res["dict"]["tensor"], 1) | ||||
res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_cpu(res["tensor"]) | self.check_cpu(res["tensor"]) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_cpu(t) | self.check_cpu(t) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_cpu(t) | self.check_cpu(t) | ||||
self.check_cpu(res["dict"]["tensor"]) | self.check_cpu(res["dict"]["tensor"]) |
@@ -1,5 +1,3 @@ | |||||
import unittest | |||||
import paddle | import paddle | ||||
import pytest | import pytest | ||||
import torch | import torch | ||||
@@ -12,9 +10,8 @@ from fastNLP.core.utils.torch_paddle_utils import torch_paddle_move_data_to_devi | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
# @pytest.mark.paddle | |||||
# @pytest.mark.torch | |||||
class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
@pytest.mark.torchpaddle | |||||
class TestTorchPaddleMoveDataToDevice: | |||||
def check_gpu(self, tensor, idx): | def check_gpu(self, tensor, idx): | ||||
""" | """ | ||||
@@ -22,17 +19,17 @@ class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
""" | """ | ||||
if isinstance(tensor, paddle.Tensor): | if isinstance(tensor, paddle.Tensor): | ||||
self.assertTrue(tensor.place.is_gpu_place()) | |||||
self.assertEqual(tensor.place.gpu_device_id(), idx) | |||||
assert tensor.place.is_gpu_place() | |||||
assert tensor.place.gpu_device_id() == idx | |||||
elif isinstance(tensor, torch.Tensor): | elif isinstance(tensor, torch.Tensor): | ||||
self.assertTrue(tensor.is_cuda) | |||||
self.assertEqual(tensor.device.index, idx) | |||||
assert tensor.is_cuda | |||||
assert tensor.device.index == idx | |||||
def check_cpu(self, tensor): | def check_cpu(self, tensor): | ||||
if isinstance(tensor, paddle.Tensor): | if isinstance(tensor, paddle.Tensor): | ||||
self.assertTrue(tensor.place.is_cpu_place()) | |||||
assert tensor.place.is_cpu_place() | |||||
elif isinstance(tensor, torch.Tensor): | elif isinstance(tensor, torch.Tensor): | ||||
self.assertFalse(tensor.is_cuda) | |||||
assert not tensor.is_cuda | |||||
def test_tensor_transfer(self): | def test_tensor_transfer(self): | ||||
""" | """ | ||||
@@ -63,7 +60,6 @@ class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
self.check_cpu(res) | self.check_cpu(res) | ||||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device=None) | res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:0", data_device=None) | ||||
print(res.device) | |||||
self.check_gpu(res, 0) | self.check_gpu(res, 0) | ||||
res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:1", data_device=None) | res = torch_paddle_move_data_to_device(torch_tensor, device="gpu:1", data_device=None) | ||||
@@ -85,22 +81,22 @@ class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)] | paddle_list = [paddle.rand((6, 4, 2)) for i in range(5)] + [torch.rand((6, 4, 2)) for i in range(5)] | ||||
res = torch_paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | res = torch_paddle_move_data_to_device(paddle_list, device=None, data_device="gpu:1") | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
res = torch_paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | res = torch_paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1") | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_cpu(r) | self.check_cpu(r) | ||||
res = torch_paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | res = torch_paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 0) | self.check_gpu(r, 0) | ||||
res = torch_paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | res = torch_paddle_move_data_to_device(paddle_list, device="gpu:1", data_device="cpu") | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
@@ -112,22 +108,22 @@ class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + [torch.rand((6, 4, 2)) for i in range(5)] | paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)] + [torch.rand((6, 4, 2)) for i in range(5)] | ||||
paddle_tuple = tuple(paddle_list) | paddle_tuple = tuple(paddle_list) | ||||
res = torch_paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | res = torch_paddle_move_data_to_device(paddle_tuple, device=None, data_device="gpu:1") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
res = torch_paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | res = torch_paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_cpu(r) | self.check_cpu(r) | ||||
res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None) | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 0) | self.check_gpu(r, 0) | ||||
res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | res = torch_paddle_move_data_to_device(paddle_tuple, device="gpu:1", data_device="cpu") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for r in res: | for r in res: | ||||
self.check_gpu(r, 1) | self.check_gpu(r, 1) | ||||
@@ -151,57 +147,57 @@ class TorchPaddleMoveDataToDeviceTestCase(unittest.TestCase): | |||||
} | } | ||||
res = torch_paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | res = torch_paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None) | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_gpu(res["torch_tensor"], 0) | self.check_gpu(res["torch_tensor"], 0) | ||||
self.check_gpu(res["paddle_tensor"], 0) | self.check_gpu(res["paddle_tensor"], 0) | ||||
self.assertIsInstance(res["torch_list"], list) | |||||
assert isinstance(res["torch_list"], list) | |||||
for t in res["torch_list"]: | for t in res["torch_list"]: | ||||
self.check_gpu(t, 0) | self.check_gpu(t, 0) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_gpu(t, 0) | self.check_gpu(t, 0) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_gpu(t, 0) | self.check_gpu(t, 0) | ||||
self.check_gpu(res["dict"]["torch_tensor"], 0) | self.check_gpu(res["dict"]["torch_tensor"], 0) | ||||
self.check_gpu(res["dict"]["paddle_tensor"], 0) | self.check_gpu(res["dict"]["paddle_tensor"], 0) | ||||
res = torch_paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | res = torch_paddle_move_data_to_device(paddle_dict, device=None, data_device="gpu:1") | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_gpu(res["torch_tensor"], 1) | self.check_gpu(res["torch_tensor"], 1) | ||||
self.check_gpu(res["paddle_tensor"], 1) | self.check_gpu(res["paddle_tensor"], 1) | ||||
self.assertIsInstance(res["torch_list"], list) | |||||
assert isinstance(res["torch_list"], list) | |||||
for t in res["torch_list"]: | for t in res["torch_list"]: | ||||
self.check_gpu(t, 1) | self.check_gpu(t, 1) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_gpu(t, 1) | self.check_gpu(t, 1) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_gpu(t, 1) | self.check_gpu(t, 1) | ||||
self.check_gpu(res["dict"]["torch_tensor"], 1) | self.check_gpu(res["dict"]["torch_tensor"], 1) | ||||
self.check_gpu(res["dict"]["paddle_tensor"], 1) | self.check_gpu(res["dict"]["paddle_tensor"], 1) | ||||
res = torch_paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | res = torch_paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0") | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_cpu(res["torch_tensor"]) | self.check_cpu(res["torch_tensor"]) | ||||
self.check_cpu(res["paddle_tensor"]) | self.check_cpu(res["paddle_tensor"]) | ||||
self.assertIsInstance(res["torch_list"], list) | |||||
assert isinstance(res["torch_list"], list) | |||||
for t in res["torch_list"]: | for t in res["torch_list"]: | ||||
self.check_cpu(t) | self.check_cpu(t) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_cpu(t) | self.check_cpu(t) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_cpu(t) | self.check_cpu(t) | ||||
self.check_cpu(res["dict"]["torch_tensor"]) | self.check_cpu(res["dict"]["torch_tensor"]) | ||||
@@ -3,11 +3,11 @@ from contextlib import contextmanager | |||||
@contextmanager | @contextmanager | ||||
def check_time_elapse(seconds, op='lt'): | |||||
def check_time_elapse(seconds:float, op='lt'): | |||||
""" | """ | ||||
检测某一段程序所花费的时间,是否 op 给定的seconds | 检测某一段程序所花费的时间,是否 op 给定的seconds | ||||
:param int seconds: | |||||
:param seconds: | |||||
:param str op: | :param str op: | ||||
:return: | :return: | ||||
""" | """ | ||||
@@ -15,19 +15,15 @@ def check_time_elapse(seconds, op='lt'): | |||||
yield | yield | ||||
end = time.time() | end = time.time() | ||||
if op == 'lt': | if op == 'lt': | ||||
assert end-start < seconds | |||||
assert end-start < seconds, (end-start, seconds) | |||||
elif op == 'gt': | elif op == 'gt': | ||||
assert end-start > seconds | |||||
assert end-start > seconds, (end-start, seconds) | |||||
elif op == 'eq': | elif op == 'eq': | ||||
assert end - start == seconds | |||||
assert end - start == seconds, (end-start, seconds) | |||||
elif op == 'le': | elif op == 'le': | ||||
assert end - start <= seconds | |||||
assert end - start <= seconds, (end-start, seconds) | |||||
elif op == 'ge': | elif op == 'ge': | ||||
assert end - start >= seconds | |||||
assert end - start >= seconds, (end-start, seconds) | |||||
else: | else: | ||||
raise ValueError("Only supports lt,gt,eq,le,ge.") | raise ValueError("Only supports lt,gt,eq,le,ge.") | ||||
@@ -26,9 +26,9 @@ class Paddle2TorchTestCase(unittest.TestCase): | |||||
检查张量设备和梯度情况的工具函数 | 检查张量设备和梯度情况的工具函数 | ||||
""" | """ | ||||
self.assertIsInstance(tensor, torch.Tensor) | |||||
self.assertEqual(tensor.device, torch.device(device)) | |||||
self.assertEqual(tensor.requires_grad, requires_grad) | |||||
assert isinstance(tensor, torch.Tensor) | |||||
assert tensor.device == torch.device(device) | |||||
assert tensor.requires_grad == requires_grad | |||||
def test_gradient(self): | def test_gradient(self): | ||||
""" | """ | ||||
@@ -39,7 +39,7 @@ class Paddle2TorchTestCase(unittest.TestCase): | |||||
y = paddle2torch(x) | y = paddle2torch(x) | ||||
z = 3 * (y ** 2) | z = 3 * (y ** 2) | ||||
z.sum().backward() | z.sum().backward() | ||||
self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30]) | |||||
assert y.grad.tolist() == [6, 12, 18, 24, 30] | |||||
def test_tensor_transfer(self): | def test_tensor_transfer(self): | ||||
""" | """ | ||||
@@ -66,12 +66,12 @@ class Paddle2TorchTestCase(unittest.TestCase): | |||||
paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)] | paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)] | ||||
res = paddle2torch(paddle_list) | res = paddle2torch(paddle_list) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_torch_tensor(t, "cuda:1", False) | self.check_torch_tensor(t, "cuda:1", False) | ||||
res = paddle2torch(paddle_list, target_device="cpu", no_gradient=False) | res = paddle2torch(paddle_list, target_device="cpu", no_gradient=False) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_torch_tensor(t, "cpu", True) | self.check_torch_tensor(t, "cpu", True) | ||||
@@ -83,7 +83,7 @@ class Paddle2TorchTestCase(unittest.TestCase): | |||||
paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)] | paddle_list = [paddle.rand((6, 4, 2)).cuda(1) for i in range(10)] | ||||
paddle_tuple = tuple(paddle_list) | paddle_tuple = tuple(paddle_list) | ||||
res = paddle2torch(paddle_tuple) | res = paddle2torch(paddle_tuple) | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for t in res: | for t in res: | ||||
self.check_torch_tensor(t, "cuda:1", False) | self.check_torch_tensor(t, "cuda:1", False) | ||||
@@ -103,15 +103,15 @@ class Paddle2TorchTestCase(unittest.TestCase): | |||||
"string": "test string" | "string": "test string" | ||||
} | } | ||||
res = paddle2torch(paddle_dict) | res = paddle2torch(paddle_dict) | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_torch_tensor(res["tensor"], "cuda:0", False) | self.check_torch_tensor(res["tensor"], "cuda:0", False) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_torch_tensor(t, "cuda:0", False) | self.check_torch_tensor(t, "cuda:0", False) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_torch_tensor(t, "cuda:0", False) | self.check_torch_tensor(t, "cuda:0", False) | ||||
self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", False) | self.check_torch_tensor(res["dict"]["tensor"], "cuda:0", False) | ||||
@@ -130,24 +130,24 @@ class Torch2PaddleTestCase(unittest.TestCase): | |||||
检查得到的paddle张量设备和梯度情况的工具函数 | 检查得到的paddle张量设备和梯度情况的工具函数 | ||||
""" | """ | ||||
self.assertIsInstance(tensor, paddle.Tensor) | |||||
assert isinstance(tensor, paddle.Tensor) | |||||
if device == "cpu": | if device == "cpu": | ||||
self.assertTrue(tensor.place.is_cpu_place()) | |||||
assert tensor.place.is_cpu_place() | |||||
elif device.startswith("gpu"): | elif device.startswith("gpu"): | ||||
paddle_device = paddle.device._convert_to_place(device) | paddle_device = paddle.device._convert_to_place(device) | ||||
self.assertTrue(tensor.place.is_gpu_place()) | |||||
assert tensor.place.is_gpu_place() | |||||
if hasattr(tensor.place, "gpu_device_id"): | if hasattr(tensor.place, "gpu_device_id"): | ||||
# paddle中,有两种Place | # paddle中,有两种Place | ||||
# paddle.fluid.core.Place是创建Tensor时使用的类型 | # paddle.fluid.core.Place是创建Tensor时使用的类型 | ||||
# 有函数gpu_device_id获取设备 | # 有函数gpu_device_id获取设备 | ||||
self.assertEqual(tensor.place.gpu_device_id(), paddle_device.get_device_id()) | |||||
assert tensor.place.gpu_device_id() == paddle_device.get_device_id() | |||||
else: | else: | ||||
# 通过_convert_to_place得到的是paddle.CUDAPlace | # 通过_convert_to_place得到的是paddle.CUDAPlace | ||||
# 通过get_device_id获取设备 | # 通过get_device_id获取设备 | ||||
self.assertEqual(tensor.place.get_device_id(), paddle_device.get_device_id()) | |||||
assert tensor.place.get_device_id() == paddle_device.get_device_id() | |||||
else: | else: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
self.assertEqual(tensor.stop_gradient, stop_gradient) | |||||
assert tensor.stop_gradient == stop_gradient | |||||
def test_gradient(self): | def test_gradient(self): | ||||
""" | """ | ||||
@@ -158,7 +158,7 @@ class Torch2PaddleTestCase(unittest.TestCase): | |||||
y = torch2paddle(x) | y = torch2paddle(x) | ||||
z = 3 * (y ** 2) | z = 3 * (y ** 2) | ||||
z.sum().backward() | z.sum().backward() | ||||
self.assertListEqual(y.grad.tolist(), [6, 12, 18, 24, 30]) | |||||
assert y.grad.tolist() == [6, 12, 18, 24, 30] | |||||
def test_tensor_transfer(self): | def test_tensor_transfer(self): | ||||
""" | """ | ||||
@@ -185,12 +185,12 @@ class Torch2PaddleTestCase(unittest.TestCase): | |||||
torch_list = [torch.rand(6, 4, 2) for i in range(10)] | torch_list = [torch.rand(6, 4, 2) for i in range(10)] | ||||
res = torch2paddle(torch_list) | res = torch2paddle(torch_list) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_paddle_tensor(t, "cpu", True) | self.check_paddle_tensor(t, "cpu", True) | ||||
res = torch2paddle(torch_list, target_device="gpu:1", no_gradient=False) | res = torch2paddle(torch_list, target_device="gpu:1", no_gradient=False) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_paddle_tensor(t, "gpu:1", False) | self.check_paddle_tensor(t, "gpu:1", False) | ||||
@@ -202,7 +202,7 @@ class Torch2PaddleTestCase(unittest.TestCase): | |||||
torch_list = [torch.rand(6, 4, 2) for i in range(10)] | torch_list = [torch.rand(6, 4, 2) for i in range(10)] | ||||
torch_tuple = tuple(torch_list) | torch_tuple = tuple(torch_list) | ||||
res = torch2paddle(torch_tuple, target_device="cpu") | res = torch2paddle(torch_tuple, target_device="cpu") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for t in res: | for t in res: | ||||
self.check_paddle_tensor(t, "cpu", True) | self.check_paddle_tensor(t, "cpu", True) | ||||
@@ -222,15 +222,15 @@ class Torch2PaddleTestCase(unittest.TestCase): | |||||
"string": "test string" | "string": "test string" | ||||
} | } | ||||
res = torch2paddle(torch_dict) | res = torch2paddle(torch_dict) | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_paddle_tensor(res["tensor"], "cpu", True) | self.check_paddle_tensor(res["tensor"], "cpu", True) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_paddle_tensor(t, "cpu", True) | self.check_paddle_tensor(t, "cpu", True) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_paddle_tensor(t, "cpu", True) | self.check_paddle_tensor(t, "cpu", True) | ||||
self.check_paddle_tensor(res["dict"]["tensor"], "cpu", True) | self.check_paddle_tensor(res["dict"]["tensor"], "cpu", True) | ||||
@@ -249,12 +249,12 @@ class Jittor2TorchTestCase(unittest.TestCase): | |||||
检查得到的torch张量的工具函数 | 检查得到的torch张量的工具函数 | ||||
""" | """ | ||||
self.assertIsInstance(tensor, torch.Tensor) | |||||
assert isinstance(tensor, torch.Tensor) | |||||
if device == "cpu": | if device == "cpu": | ||||
self.assertFalse(tensor.is_cuda) | |||||
assert not tensor.is_cuda | |||||
else: | else: | ||||
self.assertEqual(tensor.device, torch.device(device)) | |||||
self.assertEqual(tensor.requires_grad, requires_grad) | |||||
assert tensor.device == torch.device(device) | |||||
assert tensor.requires_grad == requires_grad | |||||
def test_var_transfer(self): | def test_var_transfer(self): | ||||
""" | """ | ||||
@@ -281,12 +281,12 @@ class Jittor2TorchTestCase(unittest.TestCase): | |||||
jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)] | jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)] | ||||
res = jittor2torch(jittor_list) | res = jittor2torch(jittor_list) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_torch_tensor(t, "cpu", True) | self.check_torch_tensor(t, "cpu", True) | ||||
res = jittor2torch(jittor_list, target_device="cuda:1", no_gradient=False) | res = jittor2torch(jittor_list, target_device="cuda:1", no_gradient=False) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_torch_tensor(t, "cuda:1", True) | self.check_torch_tensor(t, "cuda:1", True) | ||||
@@ -298,7 +298,7 @@ class Jittor2TorchTestCase(unittest.TestCase): | |||||
jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)] | jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)] | ||||
jittor_tuple = tuple(jittor_list) | jittor_tuple = tuple(jittor_list) | ||||
res = jittor2torch(jittor_tuple, target_device="cpu") | res = jittor2torch(jittor_tuple, target_device="cpu") | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for t in res: | for t in res: | ||||
self.check_torch_tensor(t, "cpu", True) | self.check_torch_tensor(t, "cpu", True) | ||||
@@ -318,15 +318,15 @@ class Jittor2TorchTestCase(unittest.TestCase): | |||||
"string": "test string" | "string": "test string" | ||||
} | } | ||||
res = jittor2torch(jittor_dict) | res = jittor2torch(jittor_dict) | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_torch_tensor(res["tensor"], "cpu", True) | self.check_torch_tensor(res["tensor"], "cpu", True) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_torch_tensor(t, "cpu", True) | self.check_torch_tensor(t, "cpu", True) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_torch_tensor(t, "cpu", True) | self.check_torch_tensor(t, "cpu", True) | ||||
self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | self.check_torch_tensor(res["dict"]["tensor"], "cpu", True) | ||||
@@ -345,8 +345,8 @@ class Torch2JittorTestCase(unittest.TestCase): | |||||
检查得到的Jittor Var梯度情况的工具函数 | 检查得到的Jittor Var梯度情况的工具函数 | ||||
""" | """ | ||||
self.assertIsInstance(var, jittor.Var) | |||||
self.assertEqual(var.requires_grad, requires_grad) | |||||
assert isinstance(var, jittor.Var) | |||||
assert var.requires_grad == requires_grad | |||||
def test_gradient(self): | def test_gradient(self): | ||||
""" | """ | ||||
@@ -357,7 +357,7 @@ class Torch2JittorTestCase(unittest.TestCase): | |||||
y = torch2jittor(x) | y = torch2jittor(x) | ||||
z = 3 * (y ** 2) | z = 3 * (y ** 2) | ||||
grad = jittor.grad(z, y) | grad = jittor.grad(z, y) | ||||
self.assertListEqual(grad.tolist(), [6.0, 12.0, 18.0, 24.0, 30.0]) | |||||
assert grad.tolist() == [6.0, 12.0, 18.0, 24.0, 30.0] | |||||
def test_tensor_transfer(self): | def test_tensor_transfer(self): | ||||
""" | """ | ||||
@@ -384,12 +384,12 @@ class Torch2JittorTestCase(unittest.TestCase): | |||||
torch_list = [torch.rand((6, 4, 2)) for i in range(10)] | torch_list = [torch.rand((6, 4, 2)) for i in range(10)] | ||||
res = torch2jittor(torch_list) | res = torch2jittor(torch_list) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_jittor_var(t, False) | self.check_jittor_var(t, False) | ||||
res = torch2jittor(torch_list, no_gradient=False) | res = torch2jittor(torch_list, no_gradient=False) | ||||
self.assertIsInstance(res, list) | |||||
assert isinstance(res, list) | |||||
for t in res: | for t in res: | ||||
self.check_jittor_var(t, True) | self.check_jittor_var(t, True) | ||||
@@ -401,7 +401,7 @@ class Torch2JittorTestCase(unittest.TestCase): | |||||
torch_list = [torch.rand((6, 4, 2)) for i in range(10)] | torch_list = [torch.rand((6, 4, 2)) for i in range(10)] | ||||
torch_tuple = tuple(torch_list) | torch_tuple = tuple(torch_list) | ||||
res = torch2jittor(torch_tuple) | res = torch2jittor(torch_tuple) | ||||
self.assertIsInstance(res, tuple) | |||||
assert isinstance(res, tuple) | |||||
for t in res: | for t in res: | ||||
self.check_jittor_var(t, False) | self.check_jittor_var(t, False) | ||||
@@ -421,15 +421,15 @@ class Torch2JittorTestCase(unittest.TestCase): | |||||
"string": "test string" | "string": "test string" | ||||
} | } | ||||
res = torch2jittor(torch_dict) | res = torch2jittor(torch_dict) | ||||
self.assertIsInstance(res, dict) | |||||
assert isinstance(res, dict) | |||||
self.check_jittor_var(res["tensor"], False) | self.check_jittor_var(res["tensor"], False) | ||||
self.assertIsInstance(res["list"], list) | |||||
assert isinstance(res["list"], list) | |||||
for t in res["list"]: | for t in res["list"]: | ||||
self.check_jittor_var(t, False) | self.check_jittor_var(t, False) | ||||
self.assertIsInstance(res["int"], int) | |||||
self.assertIsInstance(res["string"], str) | |||||
self.assertIsInstance(res["dict"], dict) | |||||
self.assertIsInstance(res["dict"]["list"], list) | |||||
assert isinstance(res["int"], int) | |||||
assert isinstance(res["string"], str) | |||||
assert isinstance(res["dict"], dict) | |||||
assert isinstance(res["dict"]["list"], list) | |||||
for t in res["dict"]["list"]: | for t in res["dict"]["list"]: | ||||
self.check_jittor_var(t, False) | self.check_jittor_var(t, False) | ||||
self.check_jittor_var(res["dict"]["tensor"], False) | self.check_jittor_var(res["dict"]["tensor"], False) |
@@ -1,4 +1,4 @@ | |||||
import unittest | |||||
import pytest | |||||
import os | import os | ||||
from itertools import chain | from itertools import chain | ||||
@@ -18,9 +18,9 @@ from fastNLP.core import rank_zero_rm | |||||
# | # | ||||
############################################################################ | ############################################################################ | ||||
class TestMixModule(MixModule): | |||||
class MixModuleForTest(MixModule): | |||||
def __init__(self): | def __init__(self): | ||||
super(TestMixModule, self).__init__() | |||||
super(MixModuleForTest, self).__init__() | |||||
self.torch_fc1 = torch.nn.Linear(10, 10) | self.torch_fc1 = torch.nn.Linear(10, 10) | ||||
self.torch_softmax = torch.nn.Softmax(0) | self.torch_softmax = torch.nn.Softmax(0) | ||||
@@ -33,9 +33,9 @@ class TestMixModule(MixModule): | |||||
self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3) | self.paddle_conv2d1 = paddle.nn.Conv2D(10, 10, 3) | ||||
self.paddle_tensor = paddle.ones((4, 4)) | self.paddle_tensor = paddle.ones((4, 4)) | ||||
class TestTorchModule(torch.nn.Module): | |||||
class TorchModuleForTest(torch.nn.Module): | |||||
def __init__(self): | def __init__(self): | ||||
super(TestTorchModule, self).__init__() | |||||
super(TorchModuleForTest, self).__init__() | |||||
self.torch_fc1 = torch.nn.Linear(10, 10) | self.torch_fc1 = torch.nn.Linear(10, 10) | ||||
self.torch_softmax = torch.nn.Softmax(0) | self.torch_softmax = torch.nn.Softmax(0) | ||||
@@ -43,9 +43,9 @@ class TestTorchModule(torch.nn.Module): | |||||
self.torch_tensor = torch.ones(3, 3) | self.torch_tensor = torch.ones(3, 3) | ||||
self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) | self.torch_param = torch.nn.Parameter(torch.ones(4, 4)) | ||||
class TestPaddleModule(paddle.nn.Layer): | |||||
class PaddleModuleForTest(paddle.nn.Layer): | |||||
def __init__(self): | def __init__(self): | ||||
super(TestPaddleModule, self).__init__() | |||||
super(PaddleModuleForTest, self).__init__() | |||||
self.paddle_fc1 = paddle.nn.Linear(10, 10) | self.paddle_fc1 = paddle.nn.Linear(10, 10) | ||||
self.paddle_softmax = paddle.nn.Softmax(0) | self.paddle_softmax = paddle.nn.Softmax(0) | ||||
@@ -53,13 +53,14 @@ class TestPaddleModule(paddle.nn.Layer): | |||||
self.paddle_tensor = paddle.ones((4, 4)) | self.paddle_tensor = paddle.ones((4, 4)) | ||||
class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
@pytest.mark.torchpaddle | |||||
class TestTorchPaddleMixModule: | |||||
def setUp(self): | |||||
def setup_method(self): | |||||
self.model = TestMixModule() | |||||
self.torch_model = TestTorchModule() | |||||
self.paddle_model = TestPaddleModule() | |||||
self.model = MixModuleForTest() | |||||
self.torch_model = TorchModuleForTest() | |||||
self.paddle_model = PaddleModuleForTest() | |||||
def test_to(self): | def test_to(self): | ||||
""" | """ | ||||
@@ -110,7 +111,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
for value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()): | for value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()): | ||||
params.append(value) | params.append(value) | ||||
self.assertEqual(len(params), len(mix_params)) | |||||
assert len(params) == len(mix_params) | |||||
def test_named_parameters(self): | def test_named_parameters(self): | ||||
""" | """ | ||||
@@ -126,7 +127,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
for name, value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()): | for name, value in chain(self.torch_model.named_parameters(), self.paddle_model.named_parameters()): | ||||
param_names.append(name) | param_names.append(name) | ||||
self.assertListEqual(sorted(param_names), sorted(mix_param_names)) | |||||
assert sorted(param_names) == sorted(mix_param_names) | |||||
def test_torch_named_parameters(self): | def test_torch_named_parameters(self): | ||||
""" | """ | ||||
@@ -142,7 +143,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
for name, value in self.torch_model.named_parameters(): | for name, value in self.torch_model.named_parameters(): | ||||
param_names.append(name) | param_names.append(name) | ||||
self.assertListEqual(sorted(param_names), sorted(mix_param_names)) | |||||
assert sorted(param_names) == sorted(mix_param_names) | |||||
def test_paddle_named_parameters(self): | def test_paddle_named_parameters(self): | ||||
""" | """ | ||||
@@ -158,7 +159,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
for name, value in self.paddle_model.named_parameters(): | for name, value in self.paddle_model.named_parameters(): | ||||
param_names.append(name) | param_names.append(name) | ||||
self.assertListEqual(sorted(param_names), sorted(mix_param_names)) | |||||
assert sorted(param_names) == sorted(mix_param_names) | |||||
def test_torch_state_dict(self): | def test_torch_state_dict(self): | ||||
""" | """ | ||||
@@ -167,7 +168,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
torch_dict = self.torch_model.state_dict() | torch_dict = self.torch_model.state_dict() | ||||
mix_dict = self.model.state_dict(backend="torch") | mix_dict = self.model.state_dict(backend="torch") | ||||
self.assertListEqual(sorted(torch_dict.keys()), sorted(mix_dict.keys())) | |||||
assert sorted(torch_dict.keys()) == sorted(mix_dict.keys()) | |||||
def test_paddle_state_dict(self): | def test_paddle_state_dict(self): | ||||
""" | """ | ||||
@@ -177,7 +178,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
mix_dict = self.model.state_dict(backend="paddle") | mix_dict = self.model.state_dict(backend="paddle") | ||||
# TODO 测试程序会显示passed后显示paddle的异常退出信息 | # TODO 测试程序会显示passed后显示paddle的异常退出信息 | ||||
self.assertListEqual(sorted(paddle_dict.keys()), sorted(mix_dict.keys())) | |||||
assert sorted(paddle_dict.keys()) == sorted(mix_dict.keys()) | |||||
def test_state_dict(self): | def test_state_dict(self): | ||||
""" | """ | ||||
@@ -188,7 +189,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
mix_dict = self.model.state_dict() | mix_dict = self.model.state_dict() | ||||
# TODO 测试程序会显示passed后显示paddle的异常退出信息 | # TODO 测试程序会显示passed后显示paddle的异常退出信息 | ||||
self.assertListEqual(sorted(all_dict.keys()), sorted(mix_dict.keys())) | |||||
assert sorted(all_dict.keys()) == sorted(mix_dict.keys()) | |||||
def test_load_state_dict(self): | def test_load_state_dict(self): | ||||
""" | """ | ||||
@@ -196,7 +197,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
""" | """ | ||||
state_dict = self.model.state_dict() | state_dict = self.model.state_dict() | ||||
new_model = TestMixModule() | |||||
new_model = MixModuleForTest() | |||||
new_model.load_state_dict(state_dict) | new_model.load_state_dict(state_dict) | ||||
new_state_dict = new_model.state_dict() | new_state_dict = new_model.state_dict() | ||||
@@ -205,7 +206,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
for name, value in new_state_dict.items(): | for name, value in new_state_dict.items(): | ||||
new_state_dict[name] = value.tolist() | new_state_dict[name] = value.tolist() | ||||
self.assertDictEqual(state_dict, new_state_dict) | |||||
# self.assertDictEqual(state_dict, new_state_dict) | |||||
def test_save_and_load_state_dict(self): | def test_save_and_load_state_dict(self): | ||||
""" | """ | ||||
@@ -214,7 +215,7 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
path = "model" | path = "model" | ||||
try: | try: | ||||
self.model.save_state_dict_to_file(path) | self.model.save_state_dict_to_file(path) | ||||
new_model = TestMixModule() | |||||
new_model = MixModuleForTest() | |||||
new_model.load_state_dict_from_file(path) | new_model.load_state_dict_from_file(path) | ||||
state_dict = self.model.state_dict() | state_dict = self.model.state_dict() | ||||
@@ -225,49 +226,49 @@ class TorchPaddleMixModuleTestCase(unittest.TestCase): | |||||
for name, value in new_state_dict.items(): | for name, value in new_state_dict.items(): | ||||
new_state_dict[name] = value.tolist() | new_state_dict[name] = value.tolist() | ||||
self.assertDictEqual(state_dict, new_state_dict) | |||||
# self.assertDictEqual(state_dict, new_state_dict) | |||||
finally: | finally: | ||||
rank_zero_rm(path) | rank_zero_rm(path) | ||||
def if_device_correct(self, device): | def if_device_correct(self, device): | ||||
self.assertEqual(self.model.torch_fc1.weight.device, self.torch_model.torch_fc1.weight.device) | |||||
self.assertEqual(self.model.torch_conv2d1.weight.device, self.torch_model.torch_fc1.bias.device) | |||||
self.assertEqual(self.model.torch_conv2d1.bias.device, self.torch_model.torch_conv2d1.bias.device) | |||||
self.assertEqual(self.model.torch_tensor.device, self.torch_model.torch_tensor.device) | |||||
self.assertEqual(self.model.torch_param.device, self.torch_model.torch_param.device) | |||||
assert self.model.torch_fc1.weight.device == self.torch_model.torch_fc1.weight.device | |||||
assert self.model.torch_conv2d1.weight.device == self.torch_model.torch_fc1.bias.device | |||||
assert self.model.torch_conv2d1.bias.device == self.torch_model.torch_conv2d1.bias.device | |||||
assert self.model.torch_tensor.device == self.torch_model.torch_tensor.device | |||||
assert self.model.torch_param.device == self.torch_model.torch_param.device | |||||
if device == "cpu": | if device == "cpu": | ||||
self.assertTrue(self.model.paddle_fc1.weight.place.is_cpu_place()) | |||||
self.assertTrue(self.model.paddle_fc1.bias.place.is_cpu_place()) | |||||
self.assertTrue(self.model.paddle_conv2d1.weight.place.is_cpu_place()) | |||||
self.assertTrue(self.model.paddle_conv2d1.bias.place.is_cpu_place()) | |||||
self.assertTrue(self.model.paddle_tensor.place.is_cpu_place()) | |||||
assert self.model.paddle_fc1.weight.place.is_cpu_place() | |||||
assert self.model.paddle_fc1.bias.place.is_cpu_place() | |||||
assert self.model.paddle_conv2d1.weight.place.is_cpu_place() | |||||
assert self.model.paddle_conv2d1.bias.place.is_cpu_place() | |||||
assert self.model.paddle_tensor.place.is_cpu_place() | |||||
elif device.startswith("cuda"): | elif device.startswith("cuda"): | ||||
self.assertTrue(self.model.paddle_fc1.weight.place.is_gpu_place()) | |||||
self.assertTrue(self.model.paddle_fc1.bias.place.is_gpu_place()) | |||||
self.assertTrue(self.model.paddle_conv2d1.weight.place.is_gpu_place()) | |||||
self.assertTrue(self.model.paddle_conv2d1.bias.place.is_gpu_place()) | |||||
self.assertTrue(self.model.paddle_tensor.place.is_gpu_place()) | |||||
self.assertEqual(self.model.paddle_fc1.weight.place.gpu_device_id(), self.paddle_model.paddle_fc1.weight.place.gpu_device_id()) | |||||
self.assertEqual(self.model.paddle_fc1.bias.place.gpu_device_id(), self.paddle_model.paddle_fc1.bias.place.gpu_device_id()) | |||||
self.assertEqual(self.model.paddle_conv2d1.weight.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.weight.place.gpu_device_id()) | |||||
self.assertEqual(self.model.paddle_conv2d1.bias.place.gpu_device_id(), self.paddle_model.paddle_conv2d1.bias.place.gpu_device_id()) | |||||
self.assertEqual(self.model.paddle_tensor.place.gpu_device_id(), self.paddle_model.paddle_tensor.place.gpu_device_id()) | |||||
assert self.model.paddle_fc1.weight.place.is_gpu_place() | |||||
assert self.model.paddle_fc1.bias.place.is_gpu_place() | |||||
assert self.model.paddle_conv2d1.weight.place.is_gpu_place() | |||||
assert self.model.paddle_conv2d1.bias.place.is_gpu_place() | |||||
assert self.model.paddle_tensor.place.is_gpu_place() | |||||
assert self.model.paddle_fc1.weight.place.gpu_device_id() == self.paddle_model.paddle_fc1.weight.place.gpu_device_id() | |||||
assert self.model.paddle_fc1.bias.place.gpu_device_id() == self.paddle_model.paddle_fc1.bias.place.gpu_device_id() | |||||
assert self.model.paddle_conv2d1.weight.place.gpu_device_id() == self.paddle_model.paddle_conv2d1.weight.place.gpu_device_id() | |||||
assert self.model.paddle_conv2d1.bias.place.gpu_device_id() == self.paddle_model.paddle_conv2d1.bias.place.gpu_device_id() | |||||
assert self.model.paddle_tensor.place.gpu_device_id() == self.paddle_model.paddle_tensor.place.gpu_device_id() | |||||
else: | else: | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def if_training_correct(self, training): | def if_training_correct(self, training): | ||||
self.assertEqual(self.model.torch_fc1.training, training) | |||||
self.assertEqual(self.model.torch_softmax.training, training) | |||||
self.assertEqual(self.model.torch_conv2d1.training, training) | |||||
assert self.model.torch_fc1.training == training | |||||
assert self.model.torch_softmax.training == training | |||||
assert self.model.torch_conv2d1.training == training | |||||
self.assertEqual(self.model.paddle_fc1.training, training) | |||||
self.assertEqual(self.model.paddle_softmax.training, training) | |||||
self.assertEqual(self.model.paddle_conv2d1.training, training) | |||||
assert self.model.paddle_fc1.training == training | |||||
assert self.model.paddle_softmax.training == training | |||||
assert self.model.paddle_conv2d1.training == training | |||||
############################################################################ | ############################################################################ | ||||
@@ -311,10 +312,11 @@ class MixMNISTModel(MixModule): | |||||
return torch_out | return torch_out | ||||
class TestMNIST(unittest.TestCase): | |||||
@pytest.mark.torchpaddle | |||||
class TestMNIST: | |||||
@classmethod | @classmethod | ||||
def setUpClass(self): | |||||
def setup_class(self): | |||||
self.train_dataset = paddle.vision.datasets.MNIST(mode='train') | self.train_dataset = paddle.vision.datasets.MNIST(mode='train') | ||||
self.test_dataset = paddle.vision.datasets.MNIST(mode='test') | self.test_dataset = paddle.vision.datasets.MNIST(mode='test') | ||||
@@ -325,7 +327,7 @@ class TestMNIST(unittest.TestCase): | |||||
self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True) | self.dataloader = DataLoader(self.train_dataset, batch_size=100, shuffle=True) | ||||
def setUp(self): | |||||
def setup_method(self): | |||||
self.model = MixMNISTModel().to("cuda") | self.model = MixMNISTModel().to("cuda") | ||||
self.torch_loss_func = torch.nn.CrossEntropyLoss() | self.torch_loss_func = torch.nn.CrossEntropyLoss() | ||||
@@ -353,7 +355,7 @@ class TestMNIST(unittest.TestCase): | |||||
self.paddle_opt.clear_grad() | self.paddle_opt.clear_grad() | ||||
else: | else: | ||||
self.assertLess(epoch_loss / (batch + 1), 0.3) | |||||
assert epoch_loss / (batch + 1) < 0.3 | |||||
# 开始测试 | # 开始测试 | ||||
correct = 0 | correct = 0 | ||||
@@ -367,7 +369,7 @@ class TestMNIST(unittest.TestCase): | |||||
correct += 1 | correct += 1 | ||||
acc = correct / len(self.test_dataset) | acc = correct / len(self.test_dataset) | ||||
self.assertGreater(acc, 0.85) | |||||
assert acc > 0.85 | |||||
############################################################################ | ############################################################################ | ||||
# | # | ||||