|
|
@@ -23,9 +23,9 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, |
|
|
|
elif driver in {"jittor"}: |
|
|
|
from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver |
|
|
|
return initialize_jittor_driver(driver, device, model, **kwargs) |
|
|
|
elif driver in {"paddle", "fleet"}: |
|
|
|
elif driver in {"paddle"}: |
|
|
|
from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver |
|
|
|
return initialize_paddle_driver(driver, device, model, **kwargs) |
|
|
|
else: |
|
|
|
raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale', " |
|
|
|
"'jittor', 'paddle', 'fleet'].") |
|
|
|
"'jittor', 'paddle'].") |