@@ -205,6 +205,8 @@ class TopkSaver(ResultsMonitor, Saver):
     def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model',
                  only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True,
                  **kwargs):
+        if topk is None:
+            topk = 0
         ResultsMonitor.__init__(self, monitor, larger_better)
         Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)
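A short note on why the guard helps (whether TopkSaver compares `topk` numerically downstream is an assumption here): `None` cannot be ordered against an int in Python 3, and `0` is already the signature's default, so normalizing early keeps callers that pass `topk=None` from tripping later comparisons. A minimal standalone sketch:

```python
# Minimal sketch (standalone illustration, not TopkSaver internals).
# In Python 3, ordering comparisons between None and an int raise TypeError:
topk = None
try:
    topk > 0
except TypeError as e:
    print(e)  # '>' not supported between instances of 'NoneType' and 'int'

# The patch normalizes None to the existing default before any such use:
topk = 0 if topk is None else topk
```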
@@ -134,7 +134,11 @@ class JittorTensorPadder(Padder):
                                 f"it must have tolist() method.")
         shapes = [field.shape for field in batch_field]
-        max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+        if len(batch_field) < 2:
+            max_shape = [len(batch_field)] + list(shapes[0])
+        else:
+            max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
         # if dtype is not None:
         #     tensor = jittor.full(max_shape, pad_val, dtype=dtype)
         # else:
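All four framework padders (Jittor, Numpy, Paddle, Torch) and `get_shape` below receive the same guard, because `zip(*shapes)` yields 1-tuples when the batch holds a single field, and unpacking a 1-tuple into `max()` becomes `max(3)`, which fails. A minimal standalone sketch (plain Python, not the library code itself):

```python
shapes = [(3, 5)]  # a batch with a single field
try:
    max_shape = [len(shapes)] + [max(*_) for _ in zip(*shapes)]
except TypeError as e:
    print(e)  # 'int' object is not iterable -- max(*(3,)) is just max(3)

# The guarded computation used by the patch:
if len(shapes) < 2:
    max_shape = [len(shapes)] + list(shapes[0])
else:
    max_shape = [len(shapes)] + [max(*_) for _ in zip(*shapes)]
print(max_shape)  # [1, 3, 5]
```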
@@ -97,7 +97,11 @@ class NumpyTensorPadder(Padder):
                                 f"it must have tolist() method.")
         shapes = [field.shape for field in batch_field]
-        max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+        if len(batch_field) < 2:
+            max_shape = [len(batch_field)] + list(shapes[0])
+        else:
+            max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
         array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
         for i, field in enumerate(batch_field):
             slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
@@ -140,7 +140,11 @@ class PaddleTensorPadder(Padder):
                                 f"it must have tolist() method.")
         shapes = [field.shape for field in batch_field]
-        max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+        if len(batch_field) < 2:
+            max_shape = [len(batch_field)] + list(shapes[0])
+        else:
+            max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
         if isinstance(batch_field[0], paddle.Tensor):
             array = paddle.full(max_shape, fill_value=pad_val, dtype=dtype)
         else:
@@ -132,7 +132,11 @@ class TorchTensorPadder(Padder):
                                 f"it must have tolist() method.")
         shapes = [field.shape for field in batch_field]
-        max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
+        if len(batch_field) < 2:
+            max_shape = [len(batch_field)] + list(shapes[0])
+        else:
+            max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
         tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype, device=device)
         for i, field in enumerate(batch_field):
             slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
@@ -32,7 +32,11 @@ def get_shape(batch_field:List, shape=None):
         if isinstance(batch_field[0], Sequence):
             for _field in batch_field:
                 shapes.append(get_shape(_field, _shape))
-            max_shape = [max(_) for _ in zip(*shapes)]
+            if len(shapes) == 1:
+                max_shape = shapes[0]
+            else:
+                max_shape = [max(_) for _ in zip(*shapes)]
             return max_shape
     except IndexError:  # empty shape
         pass
@@ -618,9 +618,9 @@ class Trainer(TrainerEventTrigger):
                 if not catch_KeyboardInterrupt:
                     raise e
             except RuntimeError as e:
-                if 'torch' in self.driver_name.lower():  # if the driver is torch, check find_unused_parameters
+                if 'torch' in self.driver_name.lower() and len(e.args) > 0:  # if the driver is torch, check find_unused_parameters
                     if 'find_unused_parameters' in e.args[0]:
-                        logger.error("You may need to pass `torch_ddp_kwargs={'find_unused_parameters': True}` in the "
+                        logger.error("You may need to pass `torch_kwargs={'ddp_kwargs':{'find_unused_parameters': True}}` in the "
                                      "Trainer initialization to avoid this error.")
                 self.driver.on_exception()
                 self.on_exception(e)
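The added `len(e.args) > 0` check matters because a `RuntimeError` may be raised with no arguments at all; in that case `e.args[0]` raises `IndexError` inside the handler and hides the original error. A minimal sketch of the failure mode:

```python
try:
    raise RuntimeError()  # no message attached
except RuntimeError as e:
    # Without the length check, e.args[0] would raise IndexError here.
    if len(e.args) > 0 and 'find_unused_parameters' in e.args[0]:
        print("hint: pass find_unused_parameters=True")
```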
@@ -249,7 +249,7 @@ class PaddleDataLoader(DataLoader):
 def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                               return_list: bool = True,
                               batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
-                              train_batch_size: int = 1, shuffle: bool = False,
+                              batch_size: int = 1, shuffle: bool = False,
                               drop_last: bool = False, collate_fn: Union[Callable, str, None] = 'auto',
                               num_workers: int = 0, use_buffer_reader: bool = True,
                               use_shared_memory: bool = True, timeout: int = 0,
@@ -259,7 +259,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
     from fastNLP.io.data_bundle import DataBundle
     if isinstance(ds_or_db, Dataset):
         dl = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, return_list=return_list,
-                              batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
+                              batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
                               drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                               use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                               timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
@@ -270,7 +270,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
             if 'train' in name:
                 dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
                                                    return_list=return_list,
-                                                   batch_sampler=batch_sampler, batch_size=train_batch_size,
+                                                   batch_sampler=batch_sampler, batch_size=batch_size,
                                                    shuffle=shuffle,
                                                    drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                                    use_shared_memory=use_shared_memory,
@@ -292,7 +292,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
         ds_seq = []
         for ds in ds_or_db:
             dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
-                                  batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
+                                  batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
                                   drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                   use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                   timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers)
@@ -304,7 +304,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
         for name, ds in ds_or_db.items():
             if 'train' in name:
                 dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list,
-                                      batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle,
+                                      batch_sampler=batch_sampler, batch_size=batch_size, shuffle=shuffle,
                                       drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers,
                                       use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader,
                                       timeout=timeout, worker_init_fn=worker_init_fn,
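With the rename, callers pass `batch_size` instead of `train_batch_size`. A hedged usage sketch (the top-level import path is an assumption; the toy dataset is illustrative only):

```python
from fastNLP import DataSet
from fastNLP import prepare_paddle_dataloader  # top-level export assumed

train_ds = DataSet({'x': [[1, 2], [3, 4, 5]], 'y': [0, 1]})
dl = prepare_paddle_dataloader(train_ds, batch_size=32, shuffle=True)  # was train_batch_size=...
```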
@@ -109,6 +109,9 @@ def _get_backend():
         if len(available_backends) == 1:
             backend = available_backends.pop()
             logger.debug(f"Get Dataloader backend:{backend} from sys.modules.")
+        elif len(available_backends) > 1:
+            raise RuntimeError("Fail to detect dataloader backend automatically, because multiple backends:"
+                               f"{available_backends} have been imported.")
         else:
             raise RuntimeError("Fail to detect dataloader backend automatically, please set it manually.")
     return backend
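For context, a simplified sketch of what the detection now does (the backend list and function body here are assumptions, not the library's exact code): pick a framework silently only when exactly one supported backend is already imported, and refuse to guess otherwise.

```python
import sys

SUPPORTED = ('torch', 'paddle', 'jittor')  # assumed set of backends

def detect_backend() -> str:
    available = {name for name in SUPPORTED if name in sys.modules}
    if len(available) == 1:
        return available.pop()
    if len(available) > 1:
        raise RuntimeError(f"Multiple backends imported: {available}; set the backend explicitly.")
    raise RuntimeError("No backend detected; set it manually.")
```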
@@ -78,9 +78,14 @@ class TorchDataLoader(DataLoader):
         if not isinstance(dataset, _FDataSet):
             dataset = _FDataSet(dataset)
-        if sampler is None and batch_sampler is None:
+        if batch_sampler is not None:
+            batch_size = 1
+            shuffle = False
+            sampler = None
+        elif sampler is None:
             sampler = RandomSampler(dataset, shuffle=shuffle)
             shuffle = False
         if isinstance(collate_fn, str):
             if collate_fn == 'auto':
                 if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is being wrapped
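The reordered branches mirror `torch.utils.data.DataLoader`'s rule that `batch_sampler` is mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`, so those arguments are forced back to their defaults whenever a `batch_sampler` is supplied. A plain-torch sketch of the constraint:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(10).float())
bs = BatchSampler(SequentialSampler(ds), batch_size=4, drop_last=False)

DataLoader(ds, batch_sampler=bs)                  # fine: batch_size=1, shuffle=False, sampler=None
# DataLoader(ds, batch_size=4, batch_sampler=bs)  # ValueError: mutually exclusive options
```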
@@ -179,7 +184,7 @@ class TorchDataLoader(DataLoader):
 def prepare_torch_dataloader(ds_or_db,
-                             train_batch_size: int = 16,
+                             batch_size: int = 16,
                              shuffle: bool = False,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
@@ -215,7 +220,7 @@ def prepare_torch_dataloader(ds_or_db,
     from fastNLP.io import DataBundle
     if isinstance(ds_or_db, DataSet):
-        dl = TorchDataLoader(dataset=ds_or_db, batch_size=train_batch_size,
+        dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size,
                              shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                              num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                              drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -228,7 +233,7 @@ def prepare_torch_dataloader(ds_or_db,
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
                                                   shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -237,7 +242,7 @@ def prepare_torch_dataloader(ds_or_db,
                                                   persistent_workers=persistent_workers,
                                                   )
             else:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else train_batch_size,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
                                                   shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler,
                                                   batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
@@ -252,10 +257,10 @@ def prepare_torch_dataloader(ds_or_db,
         dl_bundle = []
         for idx, ds in enumerate(ds_or_db):
             if idx > 0:
-                train_batch_size = non_train_batch_size if non_train_batch_size else train_batch_size
+                batch_size = non_train_batch_size if non_train_batch_size else batch_size
                 sampler = non_train_sampler if non_train_sampler else sampler
             dl_bundle.append(
-                TorchDataLoader(dataset=ds, batch_size=train_batch_size,
+                TorchDataLoader(dataset=ds, batch_size=batch_size,
                                 shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                 num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                 drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -269,7 +274,7 @@ def prepare_torch_dataloader(ds_or_db,
         dl_bundle = {}
         for name, ds in ds_or_db.items():
             if 'train' in name:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=train_batch_size,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size,
                                                   shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
@@ -278,7 +283,7 @@ def prepare_torch_dataloader(ds_or_db,
                                                   persistent_workers=persistent_workers,
                                                   )
             else:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else train_batch_size,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
                                                   shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler,
                                                   batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
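The torch helper follows the same rename; the train split uses `batch_size` and the remaining splits fall back to `non_train_batch_size` when it is given. A hedged usage sketch (the top-level imports and the 'train'/'dev' split names are assumptions, not taken from this diff):

```python
from fastNLP import DataSet, prepare_torch_dataloader  # top-level exports assumed

bundle = {'train': DataSet({'x': [[1, 2], [3]], 'y': [0, 1]}),
          'dev':   DataSet({'x': [[4, 5, 6]], 'y': [1]})}
dls = prepare_torch_dataloader(bundle, batch_size=32, non_train_batch_size=64, shuffle=True)
train_dl, dev_dl = dls['train'], dls['dev']
```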
@@ -497,7 +497,7 @@ class DataSet:
         :param progress_desc: description text of the progress bar; defaults to 'Main
         """
         if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>":
-            raise ("Lambda function does not support multiple processes, please set `num_proc=0`.")
+            raise TypeError("Lambda function does not support multiple processes, please set `num_proc=0`.")
         if num_proc>1 and sys.platform in ('win32', 'msys', 'cygwin'):
             raise RuntimeError("Your platform does not support multiprocessing with fork, please set `num_proc=0`")
@@ -6,7 +6,7 @@ __all__ = [
 import math
 from copy import deepcopy
-from typing import Dict, Union, List
+from typing import Dict, Union, List, Sequence
 from itertools import chain
 import numpy as np
@@ -390,17 +390,20 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
             length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
+            self.length = np.array(length, dtype=int)
+            self.sorted_indices = np.argsort(self.length)[::-1]  # sorted by length, from longest to shortest
         else:
-            types = set(map(type, length))
-            assert isinstance(length, list) and len(types)==1 and types.pop()==int, \
-                "When the dataset is not fastNLP.DataSet, the length parameter can only be List[int]"
+            try:
+                self.length = np.array(length, dtype=int)
+                self.sorted_indices = np.argsort(length)[::-1]
+            except BaseException as e:
+                logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")
+        assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \
+                                            f"`length`({len(length)}) should be equal."
+        assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length."
         self.dataset = dataset
-        self.length = np.array(length, dtype=int)  # indices ordered from longest to shortest
-        self.sorted_indices = np.argsort(self.length)[::-1]  # sorted by length, from longest to shortest
         self.batch_size = batch_size
         self.num_batch_per_bucket = num_batch_per_bucket
@@ -5,7 +5,7 @@ __all__ = [
     "SequentialSampler"
 ]
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Sequence
 import math
 import numpy as np
@@ -305,12 +305,18 @@ class SortedSampler(SequentialSampler):
             length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
+            self.length = np.array(length, dtype=int)
+            self.sorted_indices = np.argsort(self.length)[::-1]  # sorted by length, from longest to shortest
         else:
-            types = set(map(type, length))
-            assert isinstance(length, list) and len(types)==1 and types.pop()==int, \
-                "When the dataset is not fastNLP.DataSet, the length parameter can only be List[int]"
-            assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
+            try:
+                self.length = np.array(length, dtype=int)
+                self.sorted_indices = np.argsort(length)[::-1]
+            except BaseException as e:
+                logger.error(f"Cannot use {self.__class__.__name__} as length, since it is not sortable.")
+        assert len(length) == len(dataset), f"The length of `dataset`({len(dataset)}) and " \
+                                            f"`length`({len(length)}) should be equal."
+        assert len(self.sorted_indices) == len(dataset), "The indices and dataset should have equal length."
-        self.length = np.array(length, dtype=int)  # indices ordered from longest to shortest
-        self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # sorted by length, from longest to shortest
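Both samplers now build `self.length` and `self.sorted_indices` inside each branch and validate by simply attempting the conversion, so any sequence that numpy can turn into an integer array is accepted, not only `List[int]`. A standalone sketch of the core computation:

```python
import numpy as np

length = [7, 3, 9, 5]                      # per-sample lengths, any np-convertible sequence
length = np.array(length, dtype=int)
sorted_indices = np.argsort(length)[::-1]  # sample indices, longest first
print(sorted_indices)                      # [2 0 3 1]
```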
@@ -184,7 +184,7 @@ class TestDataSetMethods:
         ds.apply(lambda ins: len(ins["y"]), new_field_name="y", progress_bar=None)
         assert ds.field_arrays["y"].content[0] == 2
-        res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
+        res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len")
         assert (isinstance(res, list) and len(res) > 0) == True
         assert res[0] == 4
@@ -377,7 +377,7 @@ class TestDataSetMethods:
     def test_apply_proc(self):
         data = DataSet({'x': ['xxxxas1w xw zxw xz', 'xxxxas1w xw zxw xz'] * 100, 'y': [0, 1] * 100})
-        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=2)
+        data.apply_field(lambda x: len(x), field_name='x', new_field_name='len_x', num_proc=0)
 class TestFieldArrayInit:
@@ -87,16 +87,10 @@ class TestPaddleMoveDataToDevice:
         """
         paddle_tensor = paddle.rand((3, 4, 5)).cpu()
-        res = paddle_move_data_to_device(paddle_tensor, device=None, data_device=None)
+        res = paddle_move_data_to_device(paddle_tensor, device=None)
         self.check_cpu(res)
-        res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device=None)
-        self.check_gpu(res, 0)
-        res = paddle_move_data_to_device(paddle_tensor, device="gpu:0", data_device="cpu")
-        self.check_gpu(res, 0)
-        res = paddle_move_data_to_device(paddle_tensor, device=None, data_device="gpu:0")
+        res = paddle_move_data_to_device(paddle_tensor, device="gpu:0")
         self.check_gpu(res, 0)
     def test_list_transfer(self):
@@ -106,12 +100,12 @@ class TestPaddleMoveDataToDevice:
         paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
-        res = paddle_move_data_to_device(paddle_list, device="cpu", data_device="gpu:1")
+        res = paddle_move_data_to_device(paddle_list, device="cpu")
         assert isinstance(res, list)
         for r in res:
             self.check_cpu(r)
-        res = paddle_move_data_to_device(paddle_list, device="gpu:0", data_device=None)
+        res = paddle_move_data_to_device(paddle_list, device="gpu:0")
         assert isinstance(res, list)
         for r in res:
             self.check_gpu(r, 0)
@@ -124,12 +118,12 @@ class TestPaddleMoveDataToDevice:
         paddle_list = [paddle.rand((6, 4, 2)) for i in range(10)]
         paddle_tuple = tuple(paddle_list)
-        res = paddle_move_data_to_device(paddle_tuple, device="cpu", data_device="gpu:1")
+        res = paddle_move_data_to_device(paddle_tuple, device="cpu")
         assert isinstance(res, tuple)
         for r in res:
             self.check_cpu(r)
-        res = paddle_move_data_to_device(paddle_tuple, device="gpu:0", data_device=None)
+        res = paddle_move_data_to_device(paddle_tuple, device="gpu:0")
         assert isinstance(res, tuple)
         for r in res:
             self.check_gpu(r, 0)
@@ -150,7 +144,7 @@ class TestPaddleMoveDataToDevice:
             "string": "test string"
         }
-        res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device=None)
+        res = paddle_move_data_to_device(paddle_dict, device="gpu:0")
         assert isinstance(res, dict)
         self.check_gpu(res["tensor"], 0)
         assert isinstance(res["list"], list)
@@ -164,7 +158,7 @@ class TestPaddleMoveDataToDevice:
             self.check_gpu(t, 0)
         self.check_gpu(res["dict"]["tensor"], 0)
-        res = paddle_move_data_to_device(paddle_dict, device="gpu:0", data_device="cpu")
+        res = paddle_move_data_to_device(paddle_dict, device="gpu:0")
         assert isinstance(res, dict)
         self.check_gpu(res["tensor"], 0)
         assert isinstance(res["list"], list)
@@ -178,7 +172,7 @@ class TestPaddleMoveDataToDevice:
             self.check_gpu(t, 0)
         self.check_gpu(res["dict"]["tensor"], 0)
-        res = paddle_move_data_to_device(paddle_dict, device="cpu", data_device="gpu:0")
+        res = paddle_move_data_to_device(paddle_dict, device="cpu")
         assert isinstance(res, dict)
         self.check_cpu(res["tensor"])
         assert isinstance(res["list"], list)
@@ -56,13 +56,13 @@ class TestPaddle2Torch:
         res = paddle2torch(paddle_tensor)
         self.check_torch_tensor(res, "cpu", not paddle_tensor.stop_gradient)
-        res = paddle2torch(paddle_tensor, target_device="cuda:2", no_gradient=None)
+        res = paddle2torch(paddle_tensor, device="cuda:2", no_gradient=None)
         self.check_torch_tensor(res, "cuda:2", not paddle_tensor.stop_gradient)
-        res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=True)
+        res = paddle2torch(paddle_tensor, device="cuda:1", no_gradient=True)
         self.check_torch_tensor(res, "cuda:1", False)
-        res = paddle2torch(paddle_tensor, target_device="cuda:1", no_gradient=False)
+        res = paddle2torch(paddle_tensor, device="cuda:1", no_gradient=False)
         self.check_torch_tensor(res, "cuda:1", True)
     def test_list_transfer(self):
@@ -76,7 +76,7 @@ class TestPaddle2Torch:
         for t in res:
             self.check_torch_tensor(t, "cuda:1", False)
-        res = paddle2torch(paddle_list, target_device="cpu", no_gradient=False)
+        res = paddle2torch(paddle_list, device="cpu", no_gradient=False)
         assert isinstance(res, list)
         for t in res:
             self.check_torch_tensor(t, "cpu", True)
@@ -176,13 +176,13 @@ class TestTorch2Paddle:
         res = torch2paddle(torch_tensor)
         self.check_paddle_tensor(res, "cpu", True)
-        res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=None)
+        res = torch2paddle(torch_tensor, device="gpu:2", no_gradient=None)
         self.check_paddle_tensor(res, "gpu:2", True)
-        res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=True)
+        res = torch2paddle(torch_tensor, device="gpu:2", no_gradient=True)
         self.check_paddle_tensor(res, "gpu:2", True)
-        res = torch2paddle(torch_tensor, target_device="gpu:2", no_gradient=False)
+        res = torch2paddle(torch_tensor, device="gpu:2", no_gradient=False)
         self.check_paddle_tensor(res, "gpu:2", False)
     def test_tensor_list_transfer(self):
@@ -196,7 +196,7 @@ class TestTorch2Paddle:
         for t in res:
             self.check_paddle_tensor(t, "cpu", True)
-        res = torch2paddle(torch_list, target_device="gpu:1", no_gradient=False)
+        res = torch2paddle(torch_list, device="gpu:1", no_gradient=False)
         assert isinstance(res, list)
         for t in res:
             self.check_paddle_tensor(t, "gpu:1", False)
@@ -208,7 +208,7 @@ class TestTorch2Paddle:
         torch_list = [torch.rand(6, 4, 2) for i in range(10)]
         torch_tuple = tuple(torch_list)
-        res = torch2paddle(torch_tuple, target_device="cpu")
+        res = torch2paddle(torch_tuple, device="cpu")
         assert isinstance(res, tuple)
         for t in res:
             self.check_paddle_tensor(t, "cpu", True)
@@ -249,6 +249,7 @@ class TestTorch2Paddle:
 #
 ############################################################################
+@pytest.mark.torchjittor
 class TestJittor2Torch:
     def check_torch_tensor(self, tensor, device, requires_grad):
@@ -272,13 +273,13 @@ class TestJittor2Torch:
         res = jittor2torch(jittor_var)
         self.check_torch_tensor(res, "cpu", True)
-        res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=None)
+        res = jittor2torch(jittor_var, device="cuda:2", no_gradient=None)
         self.check_torch_tensor(res, "cuda:2", True)
-        res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=True)
+        res = jittor2torch(jittor_var, device="cuda:2", no_gradient=True)
         self.check_torch_tensor(res, "cuda:2", False)
-        res = jittor2torch(jittor_var, target_device="cuda:2", no_gradient=False)
+        res = jittor2torch(jittor_var, device="cuda:2", no_gradient=False)
         self.check_torch_tensor(res, "cuda:2", True)
     def test_var_list_transfer(self):
@@ -292,7 +293,7 @@ class TestJittor2Torch:
         for t in res:
             self.check_torch_tensor(t, "cpu", True)
-        res = jittor2torch(jittor_list, target_device="cuda:1", no_gradient=False)
+        res = jittor2torch(jittor_list, device="cuda:1", no_gradient=False)
         assert isinstance(res, list)
         for t in res:
             self.check_torch_tensor(t, "cuda:1", True)
@@ -304,7 +305,7 @@ class TestJittor2Torch:
         jittor_list = [jittor.rand((6, 4, 2)) for i in range(10)]
         jittor_tuple = tuple(jittor_list)
-        res = jittor2torch(jittor_tuple, target_device="cpu")
+        res = jittor2torch(jittor_tuple, device="cpu")
         assert isinstance(res, tuple)
         for t in res:
             self.check_torch_tensor(t, "cpu", True)
@@ -345,6 +346,7 @@ class TestJittor2Torch:
 #
 ############################################################################
+@pytest.mark.torchjittor
 class TestTorch2Jittor:
     def check_jittor_var(self, var, requires_grad):
@@ -4,4 +4,5 @@ markers =
     paddle
     paddledist
     jittor
-    torchpaddle
+    torchpaddle
+    torchjittor
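With the new `torchjittor` marker registered, the cross-framework jittor/torch conversion tests can be selected or excluded on their own via pytest's marker expression. A small sketch (assumption: run from the repository root):

```python
import pytest

# Equivalent to `pytest -m torchjittor` on the command line.
pytest.main(["-m", "torchjittor"])
```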