|
|
@@ -384,12 +384,17 @@ class DistributedPipeline(Pipeline): |
|
|
|
preprocessor: Union[Preprocessor, List[Preprocessor]] = None, |
|
|
|
auto_collate=True, |
|
|
|
**kwargs): |
|
|
|
super().__init__(model=model, preprocessor=preprocessor, kwargs=kwargs) |
|
|
|
# DistributedPipeline uses classmethod to initialize model |
|
|
|
# without calling super().__init__ method |
|
|
|
self.preprocessor = preprocessor |
|
|
|
self._model_prepare = False |
|
|
|
self._model_prepare_lock = Lock() |
|
|
|
self._auto_collate = auto_collate |
|
|
|
|
|
|
|
self.model_dir = self.model.model_dir |
|
|
|
if os.path.exists(model): |
|
|
|
self.model_dir = model |
|
|
|
else: |
|
|
|
self.model_dir = snapshot_download(model) |
|
|
|
self.cfg = read_config(self.model_dir) |
|
|
|
self.world_size = self.cfg.model.world_size |
|
|
|
self.model_pool = None |
|
|
|