diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index dea27174..4925ac36 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,3 +1,4 @@ +import _pickle as pickle import numpy as np from fastNLP.core.fieldarray import FieldArray @@ -317,3 +318,12 @@ class DataSet(object): for header, content in zip(headers, contents): _dict[header].append(content) return cls(_dict) + + def save(self, path): + with open(path, 'wb') as f: + pickle.dump(self, f) + + @staticmethod + def load(self, path): + with open(path, 'rb') as f: + return pickle.load(f)