From ad3c5b6ef02947bb718382538d22c3407625acf5 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Mon, 3 Dec 2018 21:54:22 +0800
Subject: [PATCH] add magic iter in dataset

---
 fastNLP/core/dataset.py                   | 44 ++++++++++++-----------
 fastNLP/core/utils.py                     | 16 +++++++++
 fastNLP/modules/encoder/char_embedding.py |  2 +-
 test/core/test_dataset.py                 |  2 +-
 4 files changed, 41 insertions(+), 23 deletions(-)

diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 40ea0aab..dea27174 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -26,24 +26,6 @@ class DataSet(object):
         However, it stores data in a different way: Field-first, Instance-second.
 
     """
-
-    class DataSetIter(object):
-        def __init__(self, data_set, idx=-1, **fields):
-            self.data_set = data_set
-            self.idx = idx
-            self.fields = fields
-
-        def __next__(self):
-            self.idx += 1
-            if self.idx >= len(self.data_set):
-                raise StopIteration
-            # this returns a copy
-            return self.data_set[self.idx]
-
-        def __repr__(self):
-            return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name
-                              in self.data_set.get_fields().keys()])
-
     def __init__(self, data=None):
         """
 
@@ -72,7 +54,27 @@ class DataSet(object):
         return item in self.field_arrays
 
     def __iter__(self):
-        return self.DataSetIter(self)
+        def iter_func():
+            for idx in range(len(self)):
+                yield self[idx]
+        return iter_func()
+
+    def _inner_iter(self):
+        class Iter_ptr:
+            def __init__(self, dataset, idx):
+                self.dataset = dataset
+                self.idx = idx
+            def __getitem__(self, item):
+                assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx)
+                assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
+                return self.dataset.field_arrays[item][self.idx]
+            def __repr__(self):
+                return self.dataset[self.idx].__repr__()
+
+        def inner_iter_func():
+            for idx in range(len(self)):
+                yield Iter_ptr(self, idx)
+        return inner_iter_func()
 
     def __getitem__(self, idx):
         """Fetch Instance(s) at the `idx` position(s) in the dataset.
@@ -232,7 +234,7 @@ class DataSet(object):
         :param str new_field_name: If not None, results of the function will be stored as a new field.
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
-        results = [func(ins) for ins in self]
+        results = [func(ins) for ins in self._inner_iter()]
        if new_field_name is not None:
             if new_field_name in self.field_arrays:
                 # overwrite the field, keep same attributes
@@ -248,7 +250,7 @@ class DataSet(object):
         return results
 
     def drop(self, func):
-        results = [ins for ins in self if not func(ins)]
+        results = [ins for ins in self._inner_iter() if not func(ins)]
         for name, old_field in self.field_arrays.items():
             self.field_arrays[name].content = [ins[name] for ins in results]
 
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 6c101890..abe7889c 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -382,3 +382,19 @@ def seq_lens_to_masks(seq_lens, float=True):
             raise NotImplemented
     else:
         raise NotImplemented
+
+
+def seq_mask(seq_len, max_len):
+    """Create sequence mask.
+
+    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
+    :param max_len: int, the maximum sequence length in a batch.
+    :return mask: torch.ByteTensor, [batch_size, max_len]
+
+    """
+    if not isinstance(seq_len, torch.Tensor):
+        seq_len = torch.LongTensor(seq_len)
+    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
+    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
+    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
+
diff --git a/fastNLP/modules/encoder/char_embedding.py b/fastNLP/modules/encoder/char_embedding.py
index 1ca3b5ba..249a73ad 100644
--- a/fastNLP/modules/encoder/char_embedding.py
+++ b/fastNLP/modules/encoder/char_embedding.py
@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module):
             # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
             y = torch.squeeze(y, 2)
             # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
-            y = F.tanh(y)
+            y = torch.tanh(y)
             y, __ = torch.max(y, 2)
             # [batch_size*sent_length, feature_maps[i]]
             feats.append(y)
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index fa3e1ea3..8ca2ed86 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -130,4 +130,4 @@ class TestDataSetIter(unittest.TestCase):
     def test__repr__(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         for iter in ds:
-            self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4], 'y': [5, 6]}")
+            self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")
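
A quick usage sketch of what this patch enables, assuming fastNLP at this
revision is importable. The toy data mirrors the DataSet built in the test
above, and the mask dtype (uint8) is what torch.gt returns on the PyTorch
releases current in late 2018:

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.utils import seq_mask

    ds = DataSet({"x": [[1, 2, 3, 4]] * 3, "y": [[5, 6]] * 3})

    # __iter__ is now a plain generator; each step yields a copy of one instance.
    for ins in ds:
        print(ins["x"], ins["y"])  # [1, 2, 3, 4] [5, 6]

    # apply() and drop() run over _inner_iter() instead, whose Iter_ptr
    # objects read single fields in place rather than copying whole instances.
    ds.apply(lambda ins: len(ins["x"]), new_field_name="x_len")

    # seq_mask turns per-sequence lengths into a [batch_size, max_len] mask.
    mask = seq_mask([4, 2, 3], max_len=4)
    # tensor([[1, 1, 1, 1],
    #         [1, 1, 0, 0],
    #         [1, 1, 1, 0]], dtype=torch.uint8)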