|
|
@@ -69,17 +69,17 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ |
|
|
|
if not isinstance(device, List): |
|
|
|
return PaddleSingleDriver(model, device, **kwargs) |
|
|
|
else: |
|
|
|
logger.rank_zero_warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" |
|
|
|
"`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" |
|
|
|
"`driver` as `PaddleFleetDriver`.") |
|
|
|
logger.info("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" |
|
|
|
"`PaddleFleetDriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" |
|
|
|
"`driver` as `fleet`.") |
|
|
|
return PaddleFleetDriver(model, device, **kwargs) |
|
|
|
elif driver == "fleet": |
|
|
|
if not isinstance(device, List): |
|
|
|
if device == "cpu": |
|
|
|
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") |
|
|
|
logger.rank_zero_warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" |
|
|
|
logger.info("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" |
|
|
|
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " |
|
|
|
"choose `paddle` driver.") |
|
|
|
"`driver` as `paddle`.") |
|
|
|
return PaddleFleetDriver(model, [device], **kwargs) |
|
|
|
else: |
|
|
|
return PaddleFleetDriver(model, device, **kwargs) |