|
|
@@ -56,7 +56,7 @@ def model_and_optimizers(request): |
|
|
|
num_labels=NormalClassificationTrainTorchConfig.num_labels, |
|
|
|
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension |
|
|
|
) |
|
|
|
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) |
|
|
|
trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.01) |
|
|
|
dataset = TorchNormalDataset_Classification( |
|
|
|
num_labels=NormalClassificationTrainTorchConfig.num_labels, |
|
|
|
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, |
|
|
@@ -112,7 +112,7 @@ def test_trainer_torch_without_evaluator( |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])]) # ("torch", 4), |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", 4), |
|
|
|
@pytest.mark.parametrize("fp16", [False, True]) |
|
|
|
@pytest.mark.parametrize("accumulation_steps", [1, 3]) |
|
|
|
@magic_argv_env_context |
|
|
@@ -152,7 +152,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( |
|
|
|
|
|
|
|
# 测试 accumulation_steps; |
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])]) |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) |
|
|
|
@pytest.mark.parametrize("accumulation_steps", [1, 3]) |
|
|
|
@magic_argv_env_context |
|
|
|
def test_trainer_torch_without_evaluator_accumulation_steps( |
|
|
@@ -186,7 +186,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.torch |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) |
|
|
|
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) |
|
|
|
@pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) |
|
|
|
@magic_argv_env_context |
|
|
|
def test_trainer_output_from_new_proc( |
|
|
@@ -196,8 +196,8 @@ def test_trainer_output_from_new_proc( |
|
|
|
output_from_new_proc, |
|
|
|
n_epochs=2, |
|
|
|
): |
|
|
|
std_msg = "test std msg trainer, std std std" |
|
|
|
err_msg = "test err msg trainer, err err, err" |
|
|
|
std_msg = "std_msg" |
|
|
|
err_msg = "err_msg" |
|
|
|
|
|
|
|
from fastNLP.core.log.logger import logger |
|
|
|
|
|
|
|