Browse Source

添加了对tensor_to_numeric reduce 参数的测试

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
28eb1a5836
3 changed files with 42 additions and 13 deletions
  1. +1
    -1
      fastNLP/core/drivers/torch_driver/torch_driver.py
  2. +22
    -7
      tests/core/drivers/paddle_driver/test_single_device.py
  3. +19
    -5
      tests/core/drivers/torch_driver/test_single_device.py

+ 1
- 1
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH:
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import RandomSampler as TorchRandomSampler from torch.utils.data import RandomSampler as TorchRandomSampler
_reduces = { _reduces = {
'sum': torch.max,
'sum': torch.sum,
'min': torch.min, 'min': torch.min,
'max': torch.max, 'max': torch.max,
'mean': torch.mean 'mean': torch.mean


+ 22
- 7
tests/core/drivers/paddle_driver/test_single_device.py View File

@@ -75,12 +75,12 @@ class TestPaddleDriverFunctions:
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现
""" """
dataloader = DataLoader(PaddleNormalDataset()) dataloader = DataLoader(PaddleNormalDataset())
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# batch_size 和 batch_sampler 均为 None 的情形 # batch_size 和 batch_sampler 均为 None 的情形
dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) dataloader = DataLoader(PaddleNormalDataset(), batch_size=None)
with pytest.raises(ValueError): with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# 创建torch的dataloader # 创建torch的dataloader
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
@@ -88,7 +88,7 @@ class TestPaddleDriverFunctions:
batch_size=32, shuffle=True batch_size=32, shuffle=True
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")


@pytest.mark.torchpaddle @pytest.mark.torchpaddle
def test_check_dataloader_legality_in_test(self): def test_check_dataloader_legality_in_test(self):
@@ -100,7 +100,7 @@ class TestPaddleDriverFunctions:
"train": DataLoader(PaddleNormalDataset()), "train": DataLoader(PaddleNormalDataset()),
"test":DataLoader(PaddleNormalDataset()) "test":DataLoader(PaddleNormalDataset())
} }
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# batch_size 和 batch_sampler 均为 None 的情形 # batch_size 和 batch_sampler 均为 None 的情形
dataloader = { dataloader = {
@@ -108,12 +108,12 @@ class TestPaddleDriverFunctions:
"test":DataLoader(PaddleNormalDataset(), batch_size=None) "test":DataLoader(PaddleNormalDataset(), batch_size=None)
} }
with pytest.raises(ValueError): with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# 传入的不是 dict ,应该报错 # 传入的不是 dict ,应该报错
dataloader = DataLoader(PaddleNormalDataset()) dataloader = DataLoader(PaddleNormalDataset())
with pytest.raises(ValueError): with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# 创建 torch 的 dataloader # 创建 torch 的 dataloader
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
@@ -126,7 +126,7 @@ class TestPaddleDriverFunctions:
) )
dataloader = {"train": train_loader, "test": test_loader} dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError): with pytest.raises(ValueError):
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader")


@pytest.mark.paddle @pytest.mark.paddle
def test_tensor_to_numeric(self): def test_tensor_to_numeric(self):
@@ -182,6 +182,21 @@ class TestPaddleDriverFunctions:
assert r == d.tolist() assert r == d.tolist()
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()


@pytest.mark.paddle
def test_tensor_to_numeric_reduce(self):
tensor = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

res_max = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="max")
res_min = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="min")
res_sum = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="sum")
res_mean = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="mean")

assert res_max == 6
assert res_min == 1
assert res_sum == 21
assert res_mean == 3.5


@pytest.mark.paddle @pytest.mark.paddle
def test_set_model_mode(self): def test_set_model_mode(self):
""" """


+ 19
- 5
tests/core/drivers/torch_driver/test_single_device.py View File

@@ -117,7 +117,7 @@ class TestTorchDriverFunctions:
测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现
""" """
dataloader = DataLoader(TorchNormalDataset()) dataloader = DataLoader(TorchNormalDataset())
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# 创建 paddle 的 dataloader # 创建 paddle 的 dataloader
dataloader = paddle.io.DataLoader( dataloader = paddle.io.DataLoader(
@@ -125,7 +125,7 @@ class TestTorchDriverFunctions:
batch_size=32, shuffle=True batch_size=32, shuffle=True
) )
with pytest.raises(ValueError): with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")


@pytest.mark.torchpaddle @pytest.mark.torchpaddle
def test_check_dataloader_legality_in_test(self): def test_check_dataloader_legality_in_test(self):
@@ -137,12 +137,12 @@ class TestTorchDriverFunctions:
"train": DataLoader(TorchNormalDataset()), "train": DataLoader(TorchNormalDataset()),
"test": DataLoader(TorchNormalDataset()) "test": DataLoader(TorchNormalDataset())
} }
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# 传入的不是 dict,应该报错 # 传入的不是 dict,应该报错
dataloader = DataLoader(TorchNormalDataset()) dataloader = DataLoader(TorchNormalDataset())
with pytest.raises(ValueError): with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")


# 创建 paddle 的 dataloader # 创建 paddle 的 dataloader
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
@@ -155,7 +155,7 @@ class TestTorchDriverFunctions:
) )
dataloader = {"train": train_loader, "test": test_loader} dataloader = {"train": train_loader, "test": test_loader}
with pytest.raises(ValueError): with pytest.raises(ValueError):
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False)
TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader")


@pytest.mark.torch @pytest.mark.torch
def test_tensor_to_numeric(self): def test_tensor_to_numeric(self):
@@ -211,6 +211,20 @@ class TestTorchDriverFunctions:
assert r == d.tolist() assert r == d.tolist()
assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist()


@pytest.mark.torch
def test_tensor_to_numeric_reduce(self):
tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

res_max = TorchSingleDriver.tensor_to_numeric(tensor, reduce="max")
res_min = TorchSingleDriver.tensor_to_numeric(tensor, reduce="min")
res_sum = TorchSingleDriver.tensor_to_numeric(tensor, reduce="sum")
res_mean = TorchSingleDriver.tensor_to_numeric(tensor, reduce="mean")

assert res_max == 6
assert res_min == 1
assert res_sum == 21
assert res_mean == 3.5

@pytest.mark.torch @pytest.mark.torch
def test_set_model_mode(self): def test_set_model_mode(self):
""" """


Loading…
Cancel
Save