@@ -2,10 +2,12 @@ import random
import sys
from collections import defaultdict
from copy import deepcopy
import numpy as np
from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.fieldarray import FieldArray
_READERS = {}
@@ -14,43 +16,29 @@ class DataSet(object):
    """
    def __init__(self, fields=None):
        """
        """
        pass
    def index_all(self, vocab):
        for ins in self:
            ins.index_all(vocab)
        return self
    def index_field(self, field_name, vocab):
        if isinstance(field_name, str):
            field_list = [field_name]
            vocab_list = [vocab]
    def __init__(self, instance=None):
        if instance is not None:
            self._convert_ins(instance)
        else:
            classes = (list, tuple)
            assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab)
            field_list = field_name
            vocab_list = vocab
        for name, vocabs in zip(field_list, vocab_list):
            for ins in self:
                ins.index_field(name, vocabs)
        return self
    def to_tensor(self, idx: int, padding_length: dict):
        """Convert an instance in a dataset to tensor.
        self.field_arrays = {}
        :param idx: int, the index of the instance in the dataset.
        :param padding_length: int
        :return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
                tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
    def _convert_ins(self, ins_list):
        if isinstance(ins_list, list):
            for ins in ins_list:
                self.append(ins)
        else:
            self.append(ins_list)
        """
        ins = self[idx]
        return ins.to_tensor(padding_length, self.origin_len)
    def append(self, ins):
        # no field yet: create one FieldArray per field of the first instance
        if len(self.field_arrays) == 0:
            for name, field in ins.fields.items():
                self.field_arrays[name] = FieldArray(name, [field])
        else:
            assert len(self.field_arrays) == len(ins.fields)
            for name, field in ins.fields.items():
                assert name in self.field_arrays
                self.field_arrays[name].append(field)
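# Usage sketch for the append() flow above (an assumption: instances hold
# already-indexed field values; the "word_seq"/"label" names are illustrative):
ins = Instance(word_seq=[3, 4, 5], label=[1])
ds = DataSet()
ds.append(ins)   # the first append creates one FieldArray per field
ds.append(ins)   # later appends extend the existing FieldArrays
assert len(ds.field_arrays["word_seq"]) == 2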
    def get_length(self):
        """Fetch lengths of all fields in all instances in a dataset.
@@ -59,15 +47,10 @@ class DataSet(object):
            The list contains lengths of this field in all instances.
        """
        lengths = defaultdict(list)
        for ins in self:
            for field_name, field_length in ins.get_length().items():
                lengths[field_name].append(field_length)
        return lengths
        pass
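# The docstring above promises a mapping from field name to the per-instance
# lengths of that field; the body is still a stub at this point. One possible
# shape of the result (illustrative values only):
#     {"word_seq": [5, 9, 3], "label": [1, 1, 1]}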
    def shuffle(self):
        random.shuffle(self)
        return self
        pass
    def split(self, ratio, shuffle=True):
        """Train/dev splitting
@@ -78,58 +61,37 @@ class DataSet(object):
            dev_set: a DataSet object, representing the validation set
        """
        assert 0 < ratio < 1
        if shuffle:
            self.shuffle()
        split_idx = int(len(self) * ratio)
        dev_set = deepcopy(self)
        train_set = deepcopy(self)
        del train_set[:split_idx]
        del dev_set[split_idx:]
        return train_set, dev_set
        pass
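# Usage sketch for split() as documented above ("ds" and the 0.8 ratio are
# illustrative); both returned objects are DataSet instances:
train_set, dev_set = ds.split(0.8)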
    def rename_field(self, old_name, new_name):
        """rename a field
        """
        for ins in self:
            ins.rename_field(old_name, new_name)
        if old_name in self.field_arrays:
            self.field_arrays[new_name] = self.field_arrays.pop(old_name)
        else:
            raise KeyError
        return self
    def set_target(self, **fields):
    def set_is_target(self, **fields):
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. | """Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged. | ||||
        :param key-value pairs for field-name and `is_target` value(True, False or None).
        :param key-value pairs for field-name and `is_target` value(True, False).
        """
        for ins in self:
            ins.set_target(**fields)
        for name, val in fields.items():
            if name in self.field_arrays:
                assert isinstance(val, bool)
                self.field_arrays[name].is_target = val
            else:
                raise KeyError
        return self
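# Usage sketch for rename_field() and set_is_target() above: both return the
# DataSet itself, so calls chain (the field names are illustrative):
ds.rename_field("label", "label_seq").set_is_target(label_seq=True, word_seq=False)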
    def update_vocab(self, **name_vocab):
        """using certain field data to update vocabulary.
        e.g. ::
            # update word vocab and label vocab seperately
            dataset.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
        """
        for field_name, vocab in name_vocab.items():
            for ins in self:
                vocab.update(ins[field_name].contents())
        return self
    def set_origin_len(self, origin_field, origin_len_name=None):
        """make dataset tensor output contain origin_len field.
        e.g. ::
            # output "word_seq_origin_len", lengths based on "word_seq" field
            dataset.set_origin_len("word_seq")
        """
        if origin_field is None:
            self.origin_len = None
        else:
            self.origin_len = (origin_field + "_origin_len", origin_field) \
                if origin_len_name is None else (origin_len_name, origin_field)
    def set_need_tensor(self, **kwargs):
        for name, val in kwargs.items():
            if name in self.field_arrays:
                assert isinstance(val, bool)
                self.field_arrays[name].need_tensor = val
            else:
                raise KeyError
        return self
    def __getattribute__(self, name):
@@ -7,10 +7,9 @@ class Field(object):
    """
    def __init__(self, name, is_target: bool):
        self.name = name
    def __init__(self, content, is_target: bool):
        self.is_target = is_target
        self.content = None
        self.content = content
    def index(self, vocab):
        """create index field
@@ -29,23 +28,15 @@ class Field(object):
        raise NotImplementedError
    def __repr__(self):
        return self.contents().__repr__()
    def new(self, *args, **kwargs):
        return self.__class__(*args, **kwargs, is_target=self.is_target)
        return self.content.__repr__()
class TextField(Field):
    def __init__(self, name, text, is_target):
    def __init__(self, text, is_target):
        """
        :param text: list of strings
        :param is_target: bool
        """
        super(TextField, self).__init__(name, is_target)
        self.content = text
    def index(self, vocab):
        idx_field = IndexField(self.name+'_idx', self.content, vocab, self.is_target)
        return idx_field
        super(TextField, self).__init__(text, is_target)
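# Quick sketch of the slimmed-down Field API above (values are illustrative):
# a Field now just holds its raw value in `content` plus the `is_target` flag,
# and subclasses such as TextField only forward their arguments.
f = TextField(["only", "a", "test"], is_target=False)
assert f.content == ["only", "a", "test"]
assert repr(f) == repr(["only", "a", "test"])   # __repr__ delegates to content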
class IndexField(Field):
@@ -82,75 +73,19 @@ class LabelField(Field):
    """
    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(is_target)
        self.label = label
        self._index = None
        super(LabelField, self).__init__(label, is_target)
    def get_length(self):
        """Fetch the length of the label field.
        :return length: int, the length of the label, always 1.
        """
        return 1
    def index(self, vocab):
        if self._index is None:
            if isinstance(self.label, str):
                self._index = vocab[self.label]
        return self._index
    def to_tensor(self, padding_length):
        if self._index is None:
            if isinstance(self.label, int):
                return torch.tensor(self.label)
            elif isinstance(self.label, str):
                raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
            else:
                raise RuntimeError(
                    "Not support type for LabelField. Expect str or int, got {}.".format(type(self.label)))
        else:
            return torch.LongTensor([self._index])
    def contents(self):
        return [self.label]
class SeqLabelField(Field):
    def __init__(self, label_seq, is_target=True):
        super(SeqLabelField, self).__init__(is_target)
        self.label_seq = label_seq
        self._index = None
    def get_length(self):
        return len(self.label_seq)
    def index(self, vocab):
        if self._index is None:
            self._index = [vocab[c] for c in self.label_seq]
        return self._index
    def to_tensor(self, padding_length):
        pads = [0] * (padding_length - self.get_length())
        if self._index is None:
            if self.get_length() == 0:
                return torch.LongTensor(pads)
            elif isinstance(self.label_seq[0], int):
                return torch.LongTensor(self.label_seq + pads)
            elif isinstance(self.label_seq[0], str):
                raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
            else:
                raise RuntimeError(
                    "Not support type for SeqLabelField. Expect str or int, got {}.".format(type(self.label)))
        else:
            return torch.LongTensor(self._index + pads)
    def contents(self):
        return self.label_seq.copy()
        super(SeqLabelField, self).__init__(label_seq, is_target)
class CharTextField(Field):
    def __init__(self, text, max_word_len, is_target=False):
        super(CharTextField, self).__init__(is_target)
        self.text = text
        # TODO
        raise NotImplementedError
        self.max_word_len = max_word_len
        self._index = []
@@ -0,0 +1,39 @@
import torch
import numpy as np
class FieldArray(object):
    def __init__(self, name, content, padding_val=0, is_target=True, need_tensor=True):
        self.name = name
        # every field value is stored as a numpy array
        self.data = [self._convert_np(val) for val in content]
        self.padding_val = padding_val
        self.is_target = is_target
        self.need_tensor = need_tensor
    def _convert_np(self, val):
        if not isinstance(val, np.ndarray):
            return np.array(val)
        else:
            return val
    def append(self, val):
        self.data.append(self._convert_np(val))
    def get(self, idxes):
        # single index -> the stored array; list of indices -> a padded 2-D batch
        if isinstance(idxes, int):
            return self.data[idxes]
        elif isinstance(idxes, list):
            id_list = np.array(idxes)
            batch_size = len(id_list)
            # pad every selected sequence to the longest one with padding_val
            len_list = [(i, self.data[i].shape[0]) for i in id_list]
            _, max_len = max(len_list, key=lambda x: x[1])
            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)
            for i, (idx, length) in enumerate(len_list):
                if length == max_len:
                    array[i] = self.data[idx]
                else:
                    array[i][:length] = self.data[idx]
            return array
    def __len__(self):
        return len(self.data)
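# Usage sketch for FieldArray.get() above (the values are illustrative): a list
# of indices comes back as a single zero-padded integer array.
fa = FieldArray("word_seq", content=[[1, 2, 3], [4, 5]], padding_val=0)
print(fa.get(0))        # -> array([1, 2, 3])
print(fa.get([0, 1]))   # -> [[1 2 3]
                        #     [4 5 0]]  shape (2, 3), the shorter row is padded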
@@ -7,8 +7,6 @@ class Instance(object):
    def __init__(self, **fields):
        self.fields = fields
        self.has_index = False
        self.indexes = {}
    def add_field(self, field_name, field):
        self.fields[field_name] = field
@@ -17,8 +15,6 @@ class Instance(object):
    def rename_field(self, old_name, new_name):
        if old_name in self.fields:
            self.fields[new_name] = self.fields.pop(old_name)
            if old_name in self.indexes:
                self.indexes[new_name] = self.indexes.pop(old_name)
        else:
            raise KeyError("error, no such field: {}".format(old_name))
        return self
@@ -38,53 +34,5 @@ class Instance(object):
    def __setitem__(self, name, field):
        return self.add_field(name, field)
    def get_length(self):
        """Fetch the length of all fields in the instance.
        :return length: dict of (str: int), which means (field name: field length).
        """
        length = {name: field.get_length() for name, field in self.fields.items()}
        return length
    def index_field(self, field_name, vocab):
        """use `vocab` to index certain field
        """
        self.indexes[field_name] = self.fields[field_name].index(vocab)
        return self
    def index_all(self, vocab):
        """use `vocab` to index all fields
        """
        if self.has_index:
            print("error")
            return self.indexes
        indexes = {name: field.index(vocab) for name, field in self.fields.items()}
        self.indexes = indexes
        return indexes
    def to_tensor(self, padding_length: dict, origin_len=None):
        """Convert instance to tensor.
        :param padding_length: dict of (str: int), which means (field name: padding_length of this field)
        :return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
                tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
                If is_target is False for all fields, tensor_y would be an empty dict.
        """
        tensor_x = {}
        tensor_y = {}
        for name, field in self.fields.items():
            if field.is_target is True:
                tensor_y[name] = field.to_tensor(padding_length[name])
            elif field.is_target is False:
                tensor_x[name] = field.to_tensor(padding_length[name])
            else:
                # is_target is None
                continue
        if origin_len is not None:
            name, field_name = origin_len
            tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
        return tensor_x, tensor_y
    def __repr__(self):
        return self.fields.__repr__()