diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index a6c6cde6..464a6446 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -406,7 +406,18 @@ class DataSet(object): else: raise ValueError("data only be dict or list type.") - self.collater = Collater() + self._collater = Collater() + + @property + def collater(self): + if self._collater is None: + self._collater = Collater() + return self._collater + + @collater.setter + def collater(self, value): + assert isinstance(value, Collater) + self._collater = value def __contains__(self, item): return item in self.field_arrays