| @@ -69,17 +69,17 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ | |||||
| if not isinstance(device, List): | if not isinstance(device, List): | ||||
| return PaddleSingleDriver(model, device, **kwargs) | return PaddleSingleDriver(model, device, **kwargs) | ||||
| else: | else: | ||||
| logger.rank_zero_warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||||
| "`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" | |||||
| "`driver` as `PaddleFleetDriver`.") | |||||
| logger.info("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||||
| "`PaddleFleetDriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" | |||||
| "`driver` as `fleet`.") | |||||
| return PaddleFleetDriver(model, device, **kwargs) | return PaddleFleetDriver(model, device, **kwargs) | ||||
| elif driver == "fleet": | elif driver == "fleet": | ||||
| if not isinstance(device, List): | if not isinstance(device, List): | ||||
| if device == "cpu": | if device == "cpu": | ||||
| raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") | raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") | ||||
| logger.rank_zero_warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||||
| logger.info("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||||
| "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | ||||
| "choose `paddle` driver.") | |||||
| "`driver` as `paddle`.") | |||||
| return PaddleFleetDriver(model, [device], **kwargs) | return PaddleFleetDriver(model, [device], **kwargs) | ||||
| else: | else: | ||||
| return PaddleFleetDriver(model, device, **kwargs) | return PaddleFleetDriver(model, device, **kwargs) | ||||