class DataSet(object):
    """Column-wise container of Instances: each field is stored as one FieldArray.

    Reconstructed from a whitespace-mangled patch; the truncated
    ``__getattribute__`` override at the end of the original hunk is not
    visible and therefore not reproduced here.
    """

    def __init__(self, instance=None):
        """
        :param instance: an Instance, or a list of Instances, to bootstrap
            the dataset with; None creates an empty dataset.
        """
        # BUG FIX: always create the column store first. The original only
        # did so in the `else` branch, so DataSet(instance=...) crashed with
        # AttributeError inside append().
        self.field_arrays = {}
        if instance is not None:
            self._convert_ins(instance)

    def _convert_ins(self, ins_list):
        """Append one Instance, or every Instance in a list."""
        if isinstance(ins_list, list):
            for ins in ins_list:
                self.append(ins)
        else:
            # BUG FIX: the original referenced the undefined name `ins` here
            # (NameError); the single instance is `ins_list` itself.
            self.append(ins_list)

    def append(self, ins):
        """Append one Instance, distributing its fields into the FieldArrays.

        :param ins: an Instance whose field names must match the existing
            columns (unless the dataset is still empty).
        """
        if len(self.field_arrays) == 0:
            # First instance defines the schema.
            # BUG FIX: Instance stores its data in `.fields` (plural); the
            # original read the non-existent `.field` attribute.
            for name, field in ins.fields.items():
                self.field_arrays[name] = FieldArray(name, [field])
        else:
            assert len(self.field_arrays) == len(ins.fields)
            for name, field in ins.fields.items():
                assert name in self.field_arrays
                self.field_arrays[name].append(field)

    def get_length(self):
        """Fetch lengths of all fields in all instances in a dataset.

        TODO: not yet implemented in this revision (stub in the patch).
        """
        pass

    def shuffle(self):
        # TODO: not yet implemented in this revision (stub in the patch).
        pass

    def split(self, ratio, shuffle=True):
        """Train/dev splitting.

        :param ratio: float, proportion assigned to the dev set.
        :param shuffle: bool, whether to shuffle before splitting.

        TODO: not yet implemented in this revision (stub in the patch).
        """
        pass

    def rename_field(self, old_name, new_name):
        """Rename a field (column); returns self for chaining.

        :raises KeyError: if `old_name` is not an existing field.
        """
        if old_name in self.field_arrays:
            self.field_arrays[new_name] = self.field_arrays.pop(old_name)
        else:
            # Same exception type as before, now with a useful message.
            raise KeyError("no such field: {}".format(old_name))
        return self

    def set_is_target(self, **fields):
        """Set the `is_target` flag of the named fields.

        :param fields: field-name -> bool pairs.
        :raises KeyError: if a name is not an existing field.
        """
        for name, val in fields.items():
            if name in self.field_arrays:
                assert isinstance(val, bool)
                self.field_arrays[name].is_target = val
            else:
                raise KeyError("no such field: {}".format(name))
        return self

    def set_need_tensor(self, **kwargs):
        """Set the `need_tensor` flag of the named fields.

        :param kwargs: field-name -> bool pairs.
        :raises KeyError: if a name is not an existing field.
        """
        for name, val in kwargs.items():
            if name in self.field_arrays:
                assert isinstance(val, bool)
                self.field_arrays[name].need_tensor = val
            else:
                raise KeyError("no such field: {}".format(name))
        return self
class Field(object):
    """Base class for one field value of one Instance.

    Reconstructed from a whitespace-mangled patch; only the members visible
    in the hunks are reproduced (IndexField's body is not visible and is
    therefore omitted).
    """

    def __init__(self, content, is_target: bool):
        """
        :param content: the raw payload of the field.
        :param is_target: bool, whether this field is a training target.
        """
        self.is_target = is_target
        self.content = content

    def index(self, vocab):
        """Create an index field from a vocabulary; subclasses must override."""
        raise NotImplementedError

    def __repr__(self):
        return self.content.__repr__()


class TextField(Field):
    def __init__(self, text, is_target):
        """
        :param text: list of strings
        :param is_target: bool
        """
        super(TextField, self).__init__(text, is_target)


class LabelField(Field):
    """A field holding a single classification label.

    When the label is a string it needs an indexing step to become an int;
    when it is already an int it can be used directly.
    """

    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(label, is_target)


class SeqLabelField(Field):
    """A field holding a sequence of labels (e.g. one tag per token)."""

    def __init__(self, label_seq, is_target=True):
        super(SeqLabelField, self).__init__(label_seq, is_target)


class CharTextField(Field):
    """Character-level text field — not implemented in this revision."""

    def __init__(self, text, max_word_len, is_target=False):
        # BUG FIX: the original called super().__init__(is_target), which no
        # longer matches Field(content, is_target) and raised TypeError
        # before the intended NotImplementedError; the assignments after the
        # raise were unreachable dead code and are dropped.
        # TODO: implement character-level fields.
        raise NotImplementedError("CharTextField is not implemented yet")
class FieldArray(object):
    """One column of a DataSet: every Instance's value for a single field,
    stored as numpy arrays so batches can be padded efficiently.
    """

    def __init__(self, name, content, padding_val=0, is_target=True, need_tensor=True):
        """
        :param name: str, the field name.
        :param content: list of per-instance values; each is converted to np.ndarray.
        :param padding_val: int, value used to right-pad shorter sequences.
        :param is_target: bool, whether this field is a training target.
        :param need_tensor: bool, whether this field should become a tensor.
        """
        self.name = name
        self.data = [self._convert_np(val) for val in content]
        self.padding_val = padding_val
        self.is_target = is_target
        self.need_tensor = need_tensor

    def _convert_np(self, val):
        """Return `val` as an np.ndarray (no copy if it already is one).

        BUG FIX: the original tested isinstance(val, np.array); np.array is a
        factory *function*, not a type, so isinstance raised TypeError. The
        correct type is np.ndarray.
        """
        if isinstance(val, np.ndarray):
            return val
        return np.array(val)

    def append(self, val):
        """Append one instance's value for this field."""
        self.data.append(self._convert_np(val))

    def get(self, idxes):
        """Fetch one row or a padded batch.

        :param idxes: an int (returns the single stored array) or a list of
            ints (returns an int32 matrix right-padded with `padding_val`).
        """
        if isinstance(idxes, int):
            return self.data[idxes]
        elif isinstance(idxes, list):
            # NOTE(review): assumes every element is a 1-d sequence; scalar
            # (0-d) fields would fail on `.shape[0]` — confirm with callers.
            batch_size = len(idxes)
            lengths = [(i, self.data[i].shape[0]) for i in idxes]
            max_len = max(length for _, length in lengths)
            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
            for row, (idx, length) in enumerate(lengths):
                # Slicing to :length handles full-length rows too, so the
                # original's separate length == max_len branch is unneeded.
                array[row, :length] = self.data[idx]
            return array

    def __len__(self):
        """Number of instances stored in this column."""
        return len(self.data)


class Instance(object):
    """A single example: a mapping from field names to field values.

    Reconstructed from a whitespace-mangled patch; only the members visible
    in the hunks are reproduced.
    """

    def __init__(self, **fields):
        self.fields = fields

    def add_field(self, field_name, field):
        """Add or overwrite one field; returns self for chaining."""
        self.fields[field_name] = field
        return self

    def rename_field(self, old_name, new_name):
        """Rename a field; returns self for chaining.

        :raises KeyError: if `old_name` does not exist.
        """
        if old_name in self.fields:
            self.fields[new_name] = self.fields.pop(old_name)
        else:
            raise KeyError("error, no such field: {}".format(old_name))
        return self

    def __setitem__(self, name, field):
        return self.add_field(name, field)

    def __repr__(self):
        return self.fields.__repr__()