|
|
@@ -908,12 +908,17 @@ class DataSet(object): |
|
|
|
:param bool shuffle: 在split前是否shuffle一下 |
|
|
|
:return: [ :class:`~fastNLP.DataSet` , :class:`~fastNLP.DataSet` ] |
|
|
|
""" |
|
|
|
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' |
|
|
|
assert isinstance(ratio, float) |
|
|
|
assert 0 < ratio < 1 |
|
|
|
all_indices = [_ for _ in range(len(self))] |
|
|
|
if shuffle: |
|
|
|
np.random.shuffle(all_indices) |
|
|
|
split = int(ratio * len(self)) |
|
|
|
if split == 0: |
|
|
|
error_msg = f'Dev DataSet has {split} instance after split.' |
|
|
|
logger.error(error_msg) |
|
|
|
raise IndexError(error_msg) |
|
|
|
dev_indices = all_indices[:split] |
|
|
|
train_indices = all_indices[split:] |
|
|
|
dev_set = DataSet() |
|
|
|