diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index c73e3fef..e1964d99 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -14,17 +14,11 @@ class DataSet(list): """ - def __init__(self, name="", instances=None): + def __init__(self, fields=None): """ - :param name: str, the name of the dataset. (default: "") - :param instances: list of Instance objects. (default: None) """ - list.__init__([]) - self.name = name - self.origin_len = None - if instances is not None: - self.extend(instances) + pass def index_all(self, vocab): for ins in self: diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index 1c5e7425..48e451f6 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -1,4 +1,5 @@ import torch +import numpy as np class Field(object): @@ -6,61 +7,69 @@ class Field(object): """ - def __init__(self, is_target: bool): + def __init__(self, name, is_target: bool): + self.name = name self.is_target = is_target + self.content = None def index(self, vocab): + """create index field + """ raise NotImplementedError - def get_length(self): - raise NotImplementedError - - def to_tensor(self, padding_length): - raise NotImplementedError + def __len__(self): + """number of samples + """ + assert self.content is not None + return len(self.content) - def contents(self): + def to_tensor(self, id_list): + """convert batch of index to tensor + """ raise NotImplementedError class TextField(Field): - def __init__(self, text, is_target): + def __init__(self, name, text, is_target): """ :param text: list of strings :param is_target: bool """ - super(TextField, self).__init__(is_target) - self.text = text - self._index = None + super(TextField, self).__init__(name, is_target) + self.content = text def index(self, vocab): - if self._index is None: - self._index = [vocab[c] for c in self.text] - else: - raise RuntimeError("Replicate indexing of this field.") - return self._index - - def get_length(self): - """Fetch the length of the text field. - - :return length: int, the length of the text. - - """ - return len(self.text) - - def to_tensor(self, padding_length: int): - """Convert text field to tensor. - - :param padding_length: int - :return tensor: torch.LongTensor, of shape [padding_length, ] - """ - pads = [] - if self._index is None: - raise RuntimeError("Indexing not done before to_tensor in TextField.") - if padding_length > self.get_length(): - pads = [0] * (padding_length - self.get_length()) - return torch.LongTensor(self._index + pads) - - def contents(self): - return self.text.copy() + idx_field = IndexField(self.name+'_idx', self.content, vocab, self.is_target) + return idx_field + + +class IndexField(Field): + def __init__(self, name, content, vocab, is_target): + super(IndexField, self).__init__(name, is_target) + self.content = [] + self.padding_idx = vocab.padding_idx + for sent in content: + idx = vocab.index_sent(sent) + if isinstance(idx, list): + idx = torch.Tensor(idx) + elif isinstance(idx, np.array): + idx = torch.from_numpy(idx) + elif not isinstance(idx, torch.Tensor): + raise ValueError + self.content.append(idx) + + def to_tensor(self, id_list, sort_within_batch=False): + max_len = max(id_list) + batch_size = len(id_list) + tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long) + len_list = [(i, self.content[i].size(0)) for i in id_list] + if sort_within_batch: + len_list = sorted(len_list, key=lambda x: x[1], reverse=True) + for i, (idx, length) in enumerate(len_list): + if length == max_len: + tensor[i] = self.content[idx] + else: + tensor[i][:length] = self.content[idx] + return tensor class LabelField(Field): """The Field representing a single label. Can be a string or integer.