diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index b557038b..9c20c2a6 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -1,5 +1,5 @@ -import _pickle +import torch class API: @@ -11,6 +11,6 @@ class API: pass def load(self, name): - _dict = _pickle.load(name) + _dict = torch.load(name) self.pipeline = _dict['pipeline'] self.model = _dict['model']