|
- import pytest
- from typing import Mapping
-
- from fastNLP.core.dataloaders import MixDataLoader
- from fastNLP import DataSet
- from fastNLP.core.collators import Collator
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
-
- if _NEED_IMPORT_TORCH:
- import torch
- from torch.utils.data import default_collate, SequentialSampler, RandomSampler
-
# Toy datasets shared by every test in this module. Each one has a
# variable-length integer sequence field "x" and a scalar label field "y".
# The "y" values are disjoint across the three datasets, which lets a test
# tell from a label which dataset a sample was drawn from.

# d1: 3 distinct samples repeated 10 times (30 samples), longest "x" has 4 items.
d1 = DataSet({
    "x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10,
    "y": [1, 0, 1] * 10,
})

# d2: 3 distinct samples repeated 10 times (30 samples), longest "x" has 3 items.
d2 = DataSet({
    "x": [[101, 201], [201, 301, 401], [100]] * 10,
    "y": [20, 10, 10] * 10,
})

# d3: 3 distinct samples repeated 100 times (300 samples), longest "x" has 4 items.
d3 = DataSet({
    "x": [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100,
    "y": [100, 100, 200] * 100,
})
-
-
def test_pad_val(tensor, val=0):
    """Check that every row of a batch ends in real data or in the pad value.

    A row is accepted when its last element is strictly positive (taken to be
    an unpadded row) or equal to ``val``. Only the last element of each row is
    inspected.

    Args:
        tensor: a ``torch.Tensor`` or a nested list; it is iterated row by row
            and each row is indexed with ``[-1]``.
        val: the padding value a padded row is expected to end with.

    Returns:
        ``True`` if every row is accepted (also for an empty batch), ``False``
        as soon as one row ends in a non-positive value other than ``val``.

    Raises:
        IndexError: if a row is empty.
    """
    if isinstance(tensor, torch.Tensor):
        tensor = tensor.tolist()
    return all(row[-1] > 0 or row[-1] == val for row in tensor)


# Despite its ``test_`` prefix this is a helper, not a test case. Without this
# flag pytest collects it and errors out looking for a fixture named "tensor".
test_pad_val.__test__ = False
-
-
class TestMixDataLoader:
    """Batch-order, padding and sizing checks for ``MixDataLoader``.

    There is one test method per ``mode`` ('sequential', 'mix', 'polling').
    Each method varies ``collate_fn``, ``sampler`` and ``ds_ratio`` in turn
    over the module-level datasets ``d1``, ``d2`` and ``d3``.
    """

    # First three rows of an in-order, zero-padded batch of each dataset.
    _D1_HEAD = [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
    _D2_HEAD = [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
    _D3_HEAD = [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]

    # Maps a "y" label to the dataset it comes from (labels are disjoint).
    _LABEL_OWNER = {1: 'd1', 0: 'd1', 20: 'd2', 10: 'd2', 100: 'd3', 200: 'd3'}

    @staticmethod
    def _collate_to_lists(batch):
        """Custom collate_fn: gather 'x' and 'y' into plain, unpadded lists."""
        new_batch = {'x': [], 'y': []}
        for ins in batch:
            new_batch['x'].append(ins['x'])
            new_batch['y'].append(ins['y'])
        return new_batch

    @staticmethod
    def _pad_collators():
        """Build one Collator per dataset, padding 'x' with -1, -2 and -3."""
        return {'d1': Collator(backend='auto').set_pad("x", -1),
                'd2': Collator(backend='auto').set_pad("x", -2),
                'd3': Collator(backend='auto').set_pad("x", -3)}

    @staticmethod
    def _mixed_samplers():
        """Build per-dataset samplers: sequential for d1 and d2, random for d3."""
        return {'d1': SequentialSampler(d1),
                'd2': SequentialSampler(d2),
                'd3': RandomSampler(d3)}

    @classmethod
    def _count_samples(cls, dataloader, max_batches, error_msg):
        """Count how many yielded samples come from d1, d2 and d3.

        Samples are attributed by their 'y' label. A ``ValueError(error_msg)``
        is raised once the loader yields a batch with index >= ``max_batches``.

        Returns:
            A ``(d1_len, d2_len, d3_len)`` tuple.
        """
        counts = {'d1': 0, 'd2': 0, 'd3': 0}
        for idx, batch in enumerate(dataloader):
            for label in batch['y'].tolist():
                owner = cls._LABEL_OWNER.get(label)
                if owner is not None:
                    counts[owner] += 1
            if idx >= max_batches:
                raise ValueError(error_msg)
        return counts['d1'], counts['d2'], counts['d3']

    @staticmethod
    def _assert_head(batch, head):
        """Assert the first three rows of batch['x'] equal ``head``, width included."""
        assert batch['x'][:3].tolist() == head
        assert batch['x'].shape[1] == len(head[0])

    def test_sequential_init(self):
        """'sequential' mode: the assertions expect d1 batches, then d2, then d3."""
        datasets = {'d1': d1, 'd2': d2, 'd3': d3}
        wide = torch.Size([16, 4])    # expected shape of a full d1 / d3 batch
        narrow = torch.Size([16, 3])  # expected shape of a full d2 batch

        # drop_last=True, collate_fn='auto'
        dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True)
        for idx, batch in enumerate(dl):
            # batch 0 is d1, batch 1 is d2, every later batch is d3
            assert batch['x'].shape == (narrow if idx == 1 else wide)
            assert test_pad_val(batch['x'], val=0)

        # collate_fn is a callable
        dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=self._collate_to_lists,
                            drop_last=True)
        for idx, batch in enumerate(dl1):
            if idx == 0:
                # d1
                assert [1, 2] in batch['x']
            elif idx == 1:
                # d2
                assert [101, 201] in batch['x']
            else:
                # d3
                assert [1000, 2000] in batch['x']
            assert 'x' in batch and 'y' in batch

        # collate_fn is a dict: each dataset gets its own pad value
        dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=self._pad_collators(),
                            drop_last=True)
        for idx, batch in enumerate(dl2):
            if idx == 0:
                assert test_pad_val(batch['x'], val=-1)
                assert batch['x'].shape == wide
            elif idx == 1:
                assert test_pad_val(batch['x'], val=-2)
                assert batch['x'].shape == narrow
            else:
                assert test_pad_val(batch['x'], val=-3)
                assert batch['x'].shape == wide

        # sampler is a str
        dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True)
        dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True)
        for idx, batch in enumerate(dl3):
            if idx == 0:
                # d1
                assert batch['x'][:3].tolist() == self._D1_HEAD
                assert batch['x'].shape == wide
            if idx == 1:
                # d2
                assert batch['x'][:3].tolist() == self._D2_HEAD
                assert batch['x'].shape == narrow
            if idx == 2:
                # d3
                assert batch['x'][:3].tolist() == self._D3_HEAD
            if idx > 1:
                # d3
                assert batch['x'].shape == wide
            assert test_pad_val(batch['x'], val=0)

        # NOTE(review): the `!=` checks below can fail by chance, since a random
        # order may start with the same three rows as the sequential order.
        # Confirm whether the 'rand' sampler is seeded.
        for idx, batch in enumerate(dl4):
            if idx == 0:
                # d1
                assert batch['x'][:3].tolist() != self._D1_HEAD
                assert batch['x'].shape == wide
            if idx == 1:
                # d2
                assert batch['x'][:3].tolist() != self._D2_HEAD
                assert batch['x'].shape == narrow
            if idx == 2:
                # d3
                assert batch['x'][:3].tolist() != self._D3_HEAD
            if idx > 1:
                # d3
                assert batch['x'].shape == wide
            assert test_pad_val(batch['x'], val=0)

        # sampler is a dict
        dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=self._mixed_samplers(),
                            drop_last=True)
        for idx, batch in enumerate(dl5):
            if idx == 0:
                # d1
                assert batch['x'][:3].tolist() == self._D1_HEAD
                assert batch['x'].shape == wide
            if idx == 1:
                # d2
                assert batch['x'][:3].tolist() == self._D2_HEAD
                assert batch['x'].shape == narrow
            if idx > 1:
                # d3 is sampled randomly, so only its shape is checked
                assert batch['x'].shape == wide
            assert test_pad_val(batch['x'], val=0)

        # ds_ratio is 'truncate_to_least': exactly one full batch per dataset is expected
        dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least',
                            drop_last=True)
        for idx, batch in enumerate(dl6):
            if idx == 0:
                # d1
                assert batch['x'][:3].tolist() == self._D1_HEAD
                assert batch['x'].shape == wide
            if idx == 1:
                # d2
                assert batch['x'][:3].tolist() == self._D2_HEAD
                assert batch['x'].shape == narrow
            if idx == 2:
                # d3
                assert batch['x'][:3].tolist() == self._D3_HEAD
                assert batch['x'].shape == wide
            assert test_pad_val(batch['x'], val=0)
            if idx > 2:
                raise ValueError("ds_ratio: 'truncate_to_least' error")

        # ds_ratio is 'pad_to_most': 18 batches are expected from each dataset
        dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True)
        for idx, batch in enumerate(dl7):
            if idx < 18 or 36 <= idx < 54:
                # d1 (first 18 batches) or d3 (last 18 batches)
                assert batch['x'].shape == wide
            elif 18 <= idx < 36:
                # d2
                assert batch['x'].shape == narrow
            assert test_pad_val(batch['x'], val=0)
            if idx >= 54:
                raise ValueError("ds_ratio: 'pad_to_most' error")

        # ds_ratio is a Dict[str, float]
        ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
        dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
        for idx, batch in enumerate(dl8):
            if idx < 1 or 4 <= idx < 41:
                # d1 (batch 0) or d3 (batches 4..40)
                assert batch['x'].shape == wide
            elif 1 <= idx < 4:
                # d2
                assert batch['x'].shape == narrow
            assert test_pad_val(batch['x'], val=0)
            if idx >= 41:
                raise ValueError("ds_ratio: Dict[str, float] error")

        ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
        dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
        for idx, batch in enumerate(dl9):
            # No d1 batch is expected here: the first batch is already checked as d2.
            if idx < 1:
                # d2
                assert batch['x'].shape == narrow
            elif idx < 19:
                # d3
                assert batch['x'].shape == wide
            assert test_pad_val(batch['x'], val=0)
            if idx >= 19:
                raise ValueError("ds_ratio: Dict[str, float] error")

    def test_mix(self):
        """'mix' mode: check padding, batch-count bounds and per-dataset sample counts."""
        datasets = {'d1': d1, 'd2': d2, 'd3': d3}
        dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
        for idx, batch in enumerate(dl):
            assert test_pad_val(batch['x'], val=0)
            if idx >= 22:
                raise ValueError("out of range")

        # collate_fn is a callable
        dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=self._collate_to_lists,
                            drop_last=True)
        for idx, batch in enumerate(dl1):
            assert isinstance(batch['x'], list)
            assert test_pad_val(batch['x'], val=0)
            if idx >= 22:
                raise ValueError("out of range")

        # a per-dataset collate_fn dict is expected to be rejected in 'mix' mode
        with pytest.raises(ValueError):
            MixDataLoader(datasets=datasets, mode='mix', collate_fn=self._pad_collators())

        # sampler is a str
        dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
        for idx, batch in enumerate(dl3):
            assert test_pad_val(batch['x'], val=0)
            if idx >= 22:
                raise ValueError("out of range")
        dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
        for idx, batch in enumerate(dl4):
            assert test_pad_val(batch['x'], val=0)
            if idx >= 22:
                raise ValueError("out of range")

        # sampler is a dict
        dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=self._mixed_samplers(), drop_last=True)
        for idx, batch in enumerate(dl5):
            assert test_pad_val(batch['x'], val=0)
            if idx >= 22:
                raise ValueError("out of range")

        # ds_ratio is 'truncate_to_least'
        dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least')
        d1_len, d2_len, d3_len = self._count_samples(
            dl6, 6, "ds_ratio 为 'truncate_to_least'出错了")
        assert d1_len == d2_len == d3_len == 30

        # ds_ratio is 'pad_to_most'
        dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most')
        d1_len, d2_len, d3_len = self._count_samples(
            dl7, 57, "ds_ratio 为 'pad_to_most'出错了")
        assert d1_len == d2_len == d3_len == 300

        # ds_ratio is a Dict[str, float]
        ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
        dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
        d1_len, d2_len, d3_len = self._count_samples(
            dl8, 44, "ds_ratio 为 'Dict'出错了")
        assert d1_len == 30
        assert d2_len == 60
        assert d3_len == 600

        ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
        dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
        # Only the batch-count bound is enforced here; the counts are not asserted.
        self._count_samples(dl9, 21, "ds_ratio 为 'Dict'出错了")

    def test_polling(self):
        """'polling' mode: the assertions expect batches to rotate d1, d2, d3, d1, ..."""
        datasets = {'d1': d1, 'd2': d2, 'd3': d3}
        dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18)
        for idx, batch in enumerate(dl):
            if idx in (0, 3):
                # d1
                self._assert_head(batch, self._D1_HEAD)
            elif idx in (1, 4):
                # d2
                self._assert_head(batch, self._D2_HEAD)
            elif idx == 2 or 4 < idx <= 20:
                # d3 (the only dataset left once d1 and d2 are exhausted)
                self._assert_head(batch, self._D3_HEAD)
            if idx > 20:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)

        # collate_fn is a callable
        dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=self._collate_to_lists,
                            batch_size=18)
        for idx, batch in enumerate(dl1):
            if idx in (0, 3):
                # d1
                assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]]
            elif idx in (1, 4):
                # d2
                assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]]
            elif idx == 2 or 4 < idx <= 20:
                # d3
                assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]]
            if idx > 20:
                raise ValueError("out of range")

        # collate_fn is a dict: each dataset gets its own pad value
        dl_padded = MixDataLoader(datasets=datasets, mode='polling', collate_fn=self._pad_collators(),
                                  batch_size=18)
        for idx, batch in enumerate(dl_padded):
            if idx in (0, 3):
                # d1
                assert test_pad_val(batch['x'], val=-1)
                self._assert_head(batch, [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]])
            elif idx in (1, 4):
                # d2
                assert test_pad_val(batch['x'], val=-2)
                self._assert_head(batch, [[101, 201, -2], [201, 301, 401], [100, -2, -2]])
            elif idx == 2 or 4 < idx <= 20:
                # d3
                assert test_pad_val(batch['x'], val=-3)
                self._assert_head(batch, [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]])
            if idx > 20:
                raise ValueError("out of range")

        # sampler is a str
        dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18)
        dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18)
        for idx, batch in enumerate(dl2):
            if idx in (0, 3):
                # d1
                self._assert_head(batch, self._D1_HEAD)
            elif idx in (1, 4):
                # d2
                self._assert_head(batch, self._D2_HEAD)
            elif idx == 2 or 4 < idx <= 20:
                # d3
                self._assert_head(batch, self._D3_HEAD)
            if idx > 20:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)
        for idx, batch in enumerate(dl3):
            # Random order: only the padded width identifies the dataset.
            if idx in (0, 3):
                # d1
                assert batch['x'].shape[1] == 4
            elif idx in (1, 4):
                # d2
                assert batch['x'].shape[1] == 3
            elif idx == 2 or 4 < idx <= 20:
                # d3
                assert batch['x'].shape[1] == 4
            if idx > 20:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)

        # sampler is a dict
        dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=self._mixed_samplers(),
                            batch_size=18)
        for idx, batch in enumerate(dl4):
            if idx in (0, 3):
                # d1
                self._assert_head(batch, self._D1_HEAD)
            elif idx in (1, 4):
                # d2
                self._assert_head(batch, self._D2_HEAD)
            elif idx == 2 or 4 < idx <= 20:
                # d3 is sampled randomly, so only its width is checked
                assert batch['x'].shape[1] == 4
            if idx > 20:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)

        # ds_ratio is 'truncate_to_least'
        dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
        for idx, batch in enumerate(dl5):
            if idx in (0, 3):
                # d1
                self._assert_head(batch, self._D1_HEAD)
            elif idx in (1, 4):
                # d2
                self._assert_head(batch, self._D2_HEAD)
            elif idx in (2, 5):
                # d3
                self._assert_head(batch, self._D3_HEAD)
            if idx > 5:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)

        # ds_ratio is 'pad_to_most': a strict d1, d2, d3 rotation is expected throughout
        dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
        for idx, batch in enumerate(dl6):
            turn = idx % 3
            if turn == 0:
                # d1
                self._assert_head(batch, self._D1_HEAD)
            elif turn == 1:
                # d2
                self._assert_head(batch, self._D2_HEAD)
            else:
                # d3
                self._assert_head(batch, self._D3_HEAD)
            if idx >= 51:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)

        # ds_ratio is a Dict[str, float]
        ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
        dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
        for idx, batch in enumerate(dl7):
            if idx in (0, 3):
                # d1
                self._assert_head(batch, self._D1_HEAD)
            elif idx in (1, 4, 6, 8):
                # d2
                self._assert_head(batch, self._D2_HEAD)
            elif idx in (2, 5, 7) or idx > 8:
                # d3
                self._assert_head(batch, self._D3_HEAD)
            if idx > 39:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)

        ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
        dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
        for idx, batch in enumerate(dl8):
            if idx == 0:
                # d1
                self._assert_head(batch, self._D1_HEAD)
            elif idx == 1:
                # d2
                self._assert_head(batch, self._D2_HEAD)
            else:
                # d3
                self._assert_head(batch, self._D3_HEAD)
            if idx > 18:
                raise ValueError("out of range")
            assert test_pad_val(batch['x'], val=0)
|