Browse Source

add magic iter in dataset

tags/v0.2.0^2
yunfan 5 years ago
parent
commit
ad3c5b6ef0
4 changed files with 41 additions and 23 deletions
  1. +23
    -21
      fastNLP/core/dataset.py
  2. +16
    -0
      fastNLP/core/utils.py
  3. +1
    -1
      fastNLP/modules/encoder/char_embedding.py
  4. +1
    -1
      test/core/test_dataset.py

+ 23
- 21
fastNLP/core/dataset.py View File

@@ -26,24 +26,6 @@ class DataSet(object):
However, it stores data in a different way: Field-first, Instance-second.

"""

class DataSetIter(object):
    """Sequential iterator over the instances of a dataset.

    Walks the dataset by integer index and returns one instance per step
    (indexing the dataset returns a copy, per the comment below).
    """

    def __init__(self, data_set, idx=-1, **fields):
        """
        :param data_set: object to iterate; must support ``len()`` and
            integer indexing.
        :param int idx: index just before the first one to yield; the first
            ``__next__`` call returns ``data_set[idx + 1]``.
        :param fields: extra keyword arguments; stored but not used here.
        """
        self.data_set = data_set
        self.idx = idx
        self.fields = fields

    def __iter__(self):
        # Iterator protocol: an iterator must return itself from __iter__,
        # otherwise list(it) / a for-loop over the iterator object fails.
        return self

    def __next__(self):
        self.idx += 1
        if self.idx >= len(self.data_set):
            raise StopIteration
        # this returns a copy
        return self.data_set[self.idx]

    def __repr__(self):
        # NOTE(review): assumes data_set also exposes get_fields() and
        # per-field-name indexing (true for DataSet); plain sequences only
        # support __next__ above.
        return "\n".join(['{}: {}'.format(name, repr(self.data_set[name][self.idx])) for name
                          in self.data_set.get_fields().keys()])

def __init__(self, data=None):
"""

@@ -72,7 +54,27 @@ class DataSet(object):
return item in self.field_arrays

def __iter__(self):
    """Yield every instance in the dataset, in index order.

    Implemented as an inner generator so iteration state lives in the
    generator object itself rather than on a shared iterator instance.
    (The scraped diff had the superseded ``return self.DataSetIter(self)``
    line fused in before the generator, making it dead code — removed.)
    """
    def iter_func():
        for idx in range(len(self)):
            # indexing the dataset returns a copy of the instance at idx
            yield self[idx]
    return iter_func()

def _inner_iter(self):
    """Iterate the dataset without copying instances.

    Yields lightweight per-row views that fetch a single field straight
    from the underlying field arrays on demand, which avoids building a
    full instance copy per step as plain ``__iter__`` does.
    """
    owner = self

    class _InstanceView:
        """Read-only, field-addressable view of one row of the dataset."""

        def __init__(self, dataset, idx):
            self.dataset = dataset
            self.idx = idx

        def __getitem__(self, item):
            assert self.idx < len(self.dataset), "index:{} out of range".format(self.idx)
            assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
            return self.dataset.field_arrays[item][self.idx]

        def __repr__(self):
            return repr(self.dataset[self.idx])

    return (_InstanceView(owner, pos) for pos in range(len(owner)))

def __getitem__(self, idx):
"""Fetch Instance(s) at the `idx` position(s) in the dataset.
@@ -232,7 +234,7 @@ class DataSet(object):
:param str new_field_name: If not None, results of the function will be stored as a new field.
:return results: if new_field_name is not passed, returned values of the function over all instances.
"""
results = [func(ins) for ins in self]
results = [func(ins) for ins in self._inner_iter()]
if new_field_name is not None:
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
@@ -248,7 +250,7 @@ class DataSet(object):
return results

def drop(self, func):
    """Remove, in place, every instance for which ``func`` returns True.

    :param func: callable taking one instance view and returning a truthy
        value when that instance should be dropped.

    (The scraped diff fused the pre-change ``for ins in self`` line with the
    committed ``self._inner_iter()`` line; only the committed one is kept.)
    """
    # Keep the instances func does NOT select; _inner_iter avoids copying.
    results = [ins for ins in self._inner_iter() if not func(ins)]
    # Rebuild every field's content from the surviving rows.
    for name in self.field_arrays:
        self.field_arrays[name].content = [ins[name] for ins in results]



+ 16
- 0
fastNLP/core/utils.py View File

@@ -382,3 +382,19 @@ def seq_lens_to_masks(seq_lens, float=True):
raise NotImplemented
else:
raise NotImplemented


def seq_mask(seq_len, max_len):
    """Create sequence mask.

    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
    :param max_len: int, the maximum sequence length in a batch.
    :return mask: torch tensor of shape [batch_size, max_len], True/1 where
        the position index is smaller than the sequence length.
    """
    if not isinstance(seq_len, torch.Tensor):
        seq_len = torch.LongTensor(seq_len)
    lens_col = seq_len.long().view(-1, 1)                 # [batch_size, 1]
    positions = torch.arange(
        start=0, end=max_len, dtype=torch.long, device=lens_col.device
    ).view(1, -1)                                         # [1, max_len]
    # position < length  <=>  length > position (broadcast comparison)
    return lens_col > positions                           # [batch_size, max_len]


+ 1
- 1
fastNLP/modules/encoder/char_embedding.py View File

@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module):
# [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
y = torch.squeeze(y, 2)
# [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
y = F.tanh(y)
y = torch.tanh(y)
y, __ = torch.max(y, 2)
# [batch_size*sent_length, feature_maps[i]]
feats.append(y)


+ 1
- 1
test/core/test_dataset.py View File

@@ -130,4 +130,4 @@ class TestDataSetIter(unittest.TestCase):
def test__repr__(self):
    """repr of an iterated instance lists its fields one per line."""
    ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
    # 'instance' instead of 'iter' — the original name shadowed the builtin.
    for instance in ds:
        # Fields render on separate lines, hence the embedded newline.
        # (The scraped diff also kept the superseded single-line expectation;
        # only the committed multi-line one is correct.)
        self.assertEqual(instance.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")

Loading…
Cancel
Save