diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 21325b5c..17d65d54 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: from torch.optim import Optimizer from torch.utils.data import RandomSampler as TorchRandomSampler _reduces = { - 'sum': torch.max, + 'sum': torch.sum, 'min': torch.min, 'max': torch.max, 'mean': torch.mean diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 9b7a8560..3c2d7e27 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -75,12 +75,12 @@ class TestPaddleDriverFunctions: 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 """ dataloader = DataLoader(PaddleNormalDataset()) - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # batch_size 和 batch_sampler 均为 None 的情形 dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建torch的dataloader dataloader = torch.utils.data.DataLoader( @@ -88,7 +88,7 @@ class TestPaddleDriverFunctions: batch_size=32, shuffle=True ) with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.torchpaddle def test_check_dataloader_legality_in_test(self): @@ -100,7 +100,7 @@ class TestPaddleDriverFunctions: "train": DataLoader(PaddleNormalDataset()), "test":DataLoader(PaddleNormalDataset()) } - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # batch_size 和 batch_sampler 均为 None 的情形 dataloader = { @@ -108,12 +108,12 @@ class TestPaddleDriverFunctions: "test":DataLoader(PaddleNormalDataset(), batch_size=None) } with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 传入的不是 dict ,应该报错 dataloader = DataLoader(PaddleNormalDataset()) with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建 torch 的 dataloader train_loader = torch.utils.data.DataLoader( @@ -126,7 +126,7 @@ class TestPaddleDriverFunctions: ) dataloader = {"train": train_loader, "test": test_loader} with pytest.raises(ValueError): - PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.paddle def test_tensor_to_numeric(self): @@ -182,6 +182,21 @@ class TestPaddleDriverFunctions: assert r == d.tolist() assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() + @pytest.mark.paddle + def test_tensor_to_numeric_reduce(self): + tensor = paddle.to_tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + res_max = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="max") + res_min = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="min") + res_sum = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="sum") + res_mean = PaddleSingleDriver.tensor_to_numeric(tensor, reduce="mean") + + assert res_max == 6 + assert res_min == 1 + assert res_sum == 21 + assert res_mean == 3.5 + + @pytest.mark.paddle def test_set_model_mode(self): """ diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 7839e1c9..4d92b05a 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -117,7 +117,7 @@ class TestTorchDriverFunctions: 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 """ dataloader = DataLoader(TorchNormalDataset()) - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建 paddle 的 dataloader dataloader = paddle.io.DataLoader( @@ -125,7 +125,7 @@ class TestTorchDriverFunctions: batch_size=32, shuffle=True ) with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.torchpaddle def test_check_dataloader_legality_in_test(self): @@ -137,12 +137,12 @@ class TestTorchDriverFunctions: "train": DataLoader(TorchNormalDataset()), "test": DataLoader(TorchNormalDataset()) } - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 传入的不是 dict,应该报错 dataloader = DataLoader(TorchNormalDataset()) with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") # 创建 paddle 的 dataloader train_loader = paddle.io.DataLoader( @@ -155,7 +155,7 @@ class TestTorchDriverFunctions: ) dataloader = {"train": train_loader, "test": test_loader} with pytest.raises(ValueError): - TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader") @pytest.mark.torch def test_tensor_to_numeric(self): @@ -211,6 +211,20 @@ class TestTorchDriverFunctions: assert r == d.tolist() assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() + @pytest.mark.torch + def test_tensor_to_numeric_reduce(self): + tensor = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + res_max = TorchSingleDriver.tensor_to_numeric(tensor, reduce="max") + res_min = TorchSingleDriver.tensor_to_numeric(tensor, reduce="min") + res_sum = TorchSingleDriver.tensor_to_numeric(tensor, reduce="sum") + res_mean = TorchSingleDriver.tensor_to_numeric(tensor, reduce="mean") + + assert res_max == 6 + assert res_min == 1 + assert res_sum == 21 + assert res_mean == 3.5 + @pytest.mark.torch def test_set_model_mode(self): """