|
|
@@ -33,7 +33,8 @@ class ActionRecognitionPipeline(Pipeline): |
|
|
|
'cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
self.infer_model = BaseVideoModel(cfg=self.cfg).to(self.device) |
|
|
|
self.infer_model.eval() |
|
|
|
self.infer_model.load_state_dict(torch.load(model_path)['model_state']) |
|
|
|
self.infer_model.load_state_dict( |
|
|
|
torch.load(model_path, map_location=self.device)['model_state']) |
|
|
|
self.label_mapping = self.cfg.label_mapping |
|
|
|
logger.info('load model done') |
|
|
|
|
|
|
|