|
|
@@ -0,0 +1,81 @@ |
|
|
|
import pytest |
|
|
|
|
|
|
|
from fastNLP.core.collators import AutoCollator |
|
|
|
from fastNLP.core.collators.collator import _MultiCollator |
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
|
|
|
|
|
|
|
|
class TestCollator: |
|
|
|
|
|
|
|
@pytest.mark.parametrize('as_numpy', [True, False]) |
|
|
|
def test_auto_collator(self, as_numpy): |
|
|
|
""" |
|
|
|
测试auto_collator的auto_pad功能 |
|
|
|
|
|
|
|
:param as_numpy: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, |
|
|
|
'y': [0, 1, 1, 0] * 100}) |
|
|
|
collator = AutoCollator(as_numpy=as_numpy) |
|
|
|
collator.set_input('x', 'y') |
|
|
|
bucket_data = [] |
|
|
|
data = [] |
|
|
|
for i in range(len(dataset)): |
|
|
|
data.append(dataset[i]) |
|
|
|
if len(data) == 40: |
|
|
|
bucket_data.append(data) |
|
|
|
data = [] |
|
|
|
results = [] |
|
|
|
for bucket in bucket_data: |
|
|
|
res = collator(bucket) |
|
|
|
assert res['x'].shape == (40, 5) |
|
|
|
assert res['y'].shape == (40,) |
|
|
|
results.append(res) |
|
|
|
|
|
|
|
def test_auto_collator_v1(self): |
|
|
|
""" |
|
|
|
测试auto_collator的set_pad_val和set_pad_val功能 |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, |
|
|
|
'y': [0, 1, 1, 0] * 100}) |
|
|
|
collator = AutoCollator(as_numpy=False) |
|
|
|
collator.set_input('x') |
|
|
|
collator.set_pad_val('x', val=-1) |
|
|
|
collator.set_as_numpy(True) |
|
|
|
bucket_data = [] |
|
|
|
data = [] |
|
|
|
for i in range(len(dataset)): |
|
|
|
data.append(dataset[i]) |
|
|
|
if len(data) == 40: |
|
|
|
bucket_data.append(data) |
|
|
|
data = [] |
|
|
|
for bucket in bucket_data: |
|
|
|
res = collator(bucket) |
|
|
|
print(res) |
|
|
|
|
|
|
|
def test_multicollator(self): |
|
|
|
""" |
|
|
|
测试multicollator功能 |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, |
|
|
|
'y': [0, 1, 1, 0] * 100}) |
|
|
|
collator = AutoCollator(as_numpy=False) |
|
|
|
multi_collator = _MultiCollator(collator) |
|
|
|
multi_collator.set_as_numpy(as_numpy=True) |
|
|
|
multi_collator.set_pad_val('x', val=-1) |
|
|
|
multi_collator.set_input('x') |
|
|
|
bucket_data = [] |
|
|
|
data = [] |
|
|
|
for i in range(len(dataset)): |
|
|
|
data.append(dataset[i]) |
|
|
|
if len(data) == 40: |
|
|
|
bucket_data.append(data) |
|
|
|
data = [] |
|
|
|
for bucket in bucket_data: |
|
|
|
res = multi_collator(bucket) |
|
|
|
print(res) |