@@ -59,6 +59,7 @@ __all__ = [ | |||||
# drivers | # drivers | ||||
"TorchSingleDriver", | "TorchSingleDriver", | ||||
"TorchDDPDriver", | "TorchDDPDriver", | ||||
"DeepSpeedDriver", | |||||
"PaddleSingleDriver", | "PaddleSingleDriver", | ||||
"PaddleFleetDriver", | "PaddleFleetDriver", | ||||
"JittorSingleDriver", | "JittorSingleDriver", | ||||
@@ -3,6 +3,7 @@ __all__ = [ | |||||
'TorchDriver', | 'TorchDriver', | ||||
"TorchSingleDriver", | "TorchSingleDriver", | ||||
"TorchDDPDriver", | "TorchDDPDriver", | ||||
"DeepSpeedDriver", | |||||
"PaddleDriver", | "PaddleDriver", | ||||
"PaddleSingleDriver", | "PaddleSingleDriver", | ||||
"PaddleFleetDriver", | "PaddleFleetDriver", | ||||
@@ -14,7 +15,7 @@ __all__ = [ | |||||
'optimizer_state_to_device' | 'optimizer_state_to_device' | ||||
] | ] | ||||
from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, torch_seed_everything, optimizer_state_to_device | |||||
from .torch_driver import TorchDriver, TorchSingleDriver, TorchDDPDriver, DeepSpeedDriver, torch_seed_everything, optimizer_state_to_device | |||||
from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver | from .jittor_driver import JittorDriver, JittorMPIDriver, JittorSingleDriver | ||||
from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything | from .paddle_driver import PaddleDriver, PaddleFleetDriver, PaddleSingleDriver, paddle_seed_everything | ||||
from .driver import Driver | from .driver import Driver | ||||
@@ -1,6 +1,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'TorchDDPDriver', | 'TorchDDPDriver', | ||||
'TorchSingleDriver', | 'TorchSingleDriver', | ||||
'DeepSpeedDriver', | |||||
'TorchDriver', | 'TorchDriver', | ||||
'torch_seed_everything', | 'torch_seed_everything', | ||||
'optimizer_state_to_device' | 'optimizer_state_to_device' | ||||
@@ -10,6 +11,7 @@ from .ddp import TorchDDPDriver | |||||
# todo 实现 fairscale 后再将 fairscale 导入到这里; | # todo 实现 fairscale 后再将 fairscale 导入到这里; | ||||
from .single_device import TorchSingleDriver | from .single_device import TorchSingleDriver | ||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from .deepspeed import DeepSpeedDriver | |||||
from .utils import torch_seed_everything, optimizer_state_to_device | from .utils import torch_seed_everything, optimizer_state_to_device | ||||