Browse Source

添加了collators测试用例

tags/v1.0.0alpha
MorningForest 2 years ago
parent
commit
65c621db78
1 changed files with 81 additions and 0 deletions
  1. +81
    -0
      tests/core/collators/test_collator.py

+ 81
- 0
tests/core/collators/test_collator.py View File

@@ -0,0 +1,81 @@
import pytest

from fastNLP.core.collators import AutoCollator
from fastNLP.core.collators.collator import _MultiCollator
from fastNLP.core.dataset import DataSet


class TestCollator:

@pytest.mark.parametrize('as_numpy', [True, False])
def test_auto_collator(self, as_numpy):
"""
测试auto_collator的auto_pad功能

:param as_numpy:
:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=as_numpy)
collator.set_input('x', 'y')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
results = []
for bucket in bucket_data:
res = collator(bucket)
assert res['x'].shape == (40, 5)
assert res['y'].shape == (40,)
results.append(res)

def test_auto_collator_v1(self):
"""
测试auto_collator的set_pad_val和set_pad_val功能

:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
collator.set_input('x')
collator.set_pad_val('x', val=-1)
collator.set_as_numpy(True)
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = collator(bucket)
print(res)

def test_multicollator(self):
"""
测试multicollator功能

:return:
"""
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
'y': [0, 1, 1, 0] * 100})
collator = AutoCollator(as_numpy=False)
multi_collator = _MultiCollator(collator)
multi_collator.set_as_numpy(as_numpy=True)
multi_collator.set_pad_val('x', val=-1)
multi_collator.set_input('x')
bucket_data = []
data = []
for i in range(len(dataset)):
data.append(dataset[i])
if len(data) == 40:
bucket_data.append(data)
data = []
for bucket in bucket_data:
res = multi_collator(bucket)
print(res)

Loading…
Cancel
Save