@@ -152,12 +152,12 @@ class PaddleFleetDriver(PaddleDriver):
         parallel_device: Optional[Union[List[str], str]],
         is_pull_by_paddle_run: bool = False,
         fp16: bool = False,
-        paddle_kwrags: Dict = {},
+        paddle_kwargs: Dict = None,
         **kwargs
     ):
         if USER_CUDA_VISIBLE_DEVICES not in os.environ:
-            raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.")
-        super(PaddleFleetDriver, self).__init__(model, fp16=fp16, paddle_kwrags=paddle_kwargs, **kwargs)
+            raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using fastNLP.")
+        super(PaddleFleetDriver, self).__init__(model, fp16=fp16, paddle_kwargs=paddle_kwargs, **kwargs)

         # If not started via launch, the user is required to pass in parallel_device
         if not is_pull_by_paddle_run:
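
Besides the spelling fix, changing the default from `paddle_kwrags: Dict = {}` to `paddle_kwargs: Dict = None` also removes a shared-mutable-default hazard. A minimal sketch of that hazard, using a hypothetical init_driver function rather than fastNLP code:

    # Hypothetical illustration of the mutable-default pitfall; not fastNLP code.
    def init_driver(paddle_kwargs: dict = {}):
        # Every no-argument call receives this one shared dict object.
        paddle_kwargs.setdefault("fleet_kwargs", {})
        return paddle_kwargs

    first = init_driver()
    first["fleet_kwargs"]["is_collective"] = False   # mutates the shared default
    second = init_driver()
    print(second)  # {'fleet_kwargs': {'is_collective': False}} -- state leaked

    # The pattern the new `= None` default enables instead:
    def init_driver_fixed(paddle_kwargs: dict = None):
        paddle_kwargs = {} if paddle_kwargs is None else paddle_kwargs
        return paddle_kwargs
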
|
|
@@ -195,17 +195,14 @@ class PaddleFleetDriver(PaddleDriver):
         self.world_size = None
         self.global_rank = 0
         self.gloo_rendezvous_dir = None

         # Other parameter settings for the distributed environment
-        paddle_kwargs = kwargs.get("paddle_kwargs", {})
-
-        self._fleet_kwargs = paddle_kwargs.get("fleet_kwargs", {})
+        self._fleet_kwargs = self._paddle_kwargs.get("fleet_kwargs", {})
         check_user_specific_params(self._fleet_kwargs, DataParallel.__init__, DataParallel.__name__)
         # Distributed-strategy settings for fleet.init; see the official PaddlePaddle documentation for details
         self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy())
         self.is_collective = self._fleet_kwargs.pop("is_collective", True)
         if not self.is_collective:
-            raise NotImplementedError("FastNLP only support `collective` for distributed training now.")
+            raise NotImplementedError("fastNLP only support `collective` for distributed training now.")
         self.role_maker = self._fleet_kwargs.pop("role_maker", None)

         self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error")
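
The rewritten lookup reads `fleet_kwargs` from `self._paddle_kwargs` (presumably stored by the parent `PaddleDriver.__init__` called above) instead of re-fetching `paddle_kwargs` out of `**kwargs`. For reference, a sketch of the `paddle_kwargs` a caller might pass to exercise these keys; the key names come straight from the hunk, while the concrete strategy option is only an assumed example:

    from paddle.distributed import fleet

    # check_user_specific_params validates these keys against DataParallel.__init__;
    # "strategy" is then read with .get(), while "is_collective" and "role_maker"
    # are popped off for fleet setup.
    strategy = fleet.DistributedStrategy()
    strategy.find_unused_parameters = True   # any DistributedStrategy option

    paddle_kwargs = {
        "fleet_kwargs": {
            "strategy": strategy,        # defaults to fleet.DistributedStrategy()
            "is_collective": True,       # False raises NotImplementedError
            # "role_maker": ...,         # optional custom role maker; defaults to None
        }
    }
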