
update dataset

tags/v0.2.0
yunfan committed 5 years ago, commit 0cbbfd5221
4 changed files with 92 additions and 208 deletions
  1. fastNLP/core/dataset.py  (+44, -82)
  2. fastNLP/core/field.py  (+9, -74)
  3. fastNLP/core/fieldarray.py  (+39, -0)
  4. fastNLP/core/instance.py  (+0, -52)

fastNLP/core/dataset.py  (+44, -82)

@@ -2,10 +2,12 @@ import random
import sys
from collections import defaultdict
from copy import deepcopy
import numpy as np

from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.fieldarray import FieldArray

_READERS = {}

@@ -14,43 +16,29 @@ class DataSet(object):

"""

def __init__(self, fields=None):
"""

"""
pass

def index_all(self, vocab):
for ins in self:
ins.index_all(vocab)
return self

def index_field(self, field_name, vocab):
if isinstance(field_name, str):
field_list = [field_name]
vocab_list = [vocab]
def __init__(self, instance=None):
if instance is not None:
self._convert_ins(instance)
else:
classes = (list, tuple)
assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab)
field_list = field_name
vocab_list = vocab

for name, vocabs in zip(field_list, vocab_list):
for ins in self:
ins.index_field(name, vocabs)
return self

def to_tensor(self, idx: int, padding_length: dict):
"""Convert an instance in a dataset to tensor.
self.field_arrays = {}

:param idx: int, the index of the instance in the dataset.
:param padding_length: int
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
def _convert_ins(self, ins_list):
if isinstance(ins_list, list):
for ins in ins_list:
self.append(ins)
else:
self.append(ins_list)

"""
ins = self[idx]
return ins.to_tensor(padding_length, self.origin_len)
def append(self, ins):
# no field
if len(self.field_arrays) == 0:
for name, field in ins.fields.items():
self.field_arrays[name] = FieldArray(name, [field])
else:
assert len(self.field_arrays) == len(ins.fields)
for name, field in ins.fields.items():
assert name in self.field_arrays
self.field_arrays[name].append(field)

def get_length(self):
"""Fetch lengths of all fields in all instances in a dataset.
@@ -59,15 +47,10 @@ class DataSet(object):
The list contains lengths of this field in all instances.

"""
lengths = defaultdict(list)
for ins in self:
for field_name, field_length in ins.get_length().items():
lengths[field_name].append(field_length)
return lengths
pass

def shuffle(self):
random.shuffle(self)
return self
pass

def split(self, ratio, shuffle=True):
"""Train/dev splitting
@@ -78,58 +61,37 @@ class DataSet(object):
dev_set: a DataSet object, representing the validation set

"""
assert 0 < ratio < 1
if shuffle:
self.shuffle()
split_idx = int(len(self) * ratio)
dev_set = deepcopy(self)
train_set = deepcopy(self)
del train_set[:split_idx]
del dev_set[split_idx:]
return train_set, dev_set
pass

def rename_field(self, old_name, new_name):
"""rename a field
"""
for ins in self:
ins.rename_field(old_name, new_name)
if old_name in self.field_arrays:
self.field_arrays[new_name] = self.field_arrays.pop(old_name)
else:
raise KeyError
return self

def set_target(self, **fields):
def set_is_target(self, **fields):
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged.

:param key-value pairs for field-name and `is_target` value(True, False or None).
:param key-value pairs for field-name and `is_target` value(True, False).
"""
for ins in self:
ins.set_target(**fields)
for name, val in fields.items():
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_target = val
else:
raise KeyError
return self

def update_vocab(self, **name_vocab):
"""using certain field data to update vocabulary.

e.g. ::

# update word vocab and label vocab separately
dataset.update_vocab(word_seq=word_vocab, label_seq=label_vocab)
"""
for field_name, vocab in name_vocab.items():
for ins in self:
vocab.update(ins[field_name].contents())
return self

def set_origin_len(self, origin_field, origin_len_name=None):
"""make dataset tensor output contain origin_len field.

e.g. ::

# output "word_seq_origin_len", lengths based on "word_seq" field
dataset.set_origin_len("word_seq")
"""
if origin_field is None:
self.origin_len = None
else:
self.origin_len = (origin_field + "_origin_len", origin_field) \
if origin_len_name is None else (origin_len_name, origin_field)
def set_need_tensor(self, **kwargs):
for name, val in kwargs.items():
if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].need_tensor = val
else:
raise KeyError
return self

def __getattribute__(self, name):
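
Taken together, this hunk turns DataSet from a list of Instance objects into a column store: one FieldArray per field name, with append fanning each incoming Instance's fields out into the matching columns and set_is_target / set_need_tensor flipping per-column flags. Below is a minimal, self-contained sketch of that column-wise append; ColumnStore and instance_fields are illustrative stand-ins, not fastNLP API.

# Column-wise append, mirroring the new DataSet.append above.
class ColumnStore:
    def __init__(self):
        self.field_arrays = {}                    # field name -> list of values

    def append(self, instance_fields):
        if not self.field_arrays:
            # the first instance defines the schema
            for name, value in instance_fields.items():
                self.field_arrays[name] = [value]
        else:
            # later instances must carry exactly the same field names
            assert instance_fields.keys() == self.field_arrays.keys()
            for name, value in instance_fields.items():
                self.field_arrays[name].append(value)

store = ColumnStore()
store.append({"word_seq": [2, 4, 6], "label": 0})
store.append({"word_seq": [1, 3], "label": 1})
print(store.field_arrays["word_seq"])             # [[2, 4, 6], [1, 3]]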


fastNLP/core/field.py  (+9, -74)

@@ -7,10 +7,9 @@ class Field(object):

"""

def __init__(self, name, is_target: bool):
self.name = name
def __init__(self, content, is_target: bool):
self.is_target = is_target
self.content = None
self.content = content

def index(self, vocab):
"""create index field
@@ -29,23 +28,15 @@ class Field(object):
raise NotImplementedError

def __repr__(self):
return self.contents().__repr__()

def new(self, *args, **kwargs):
return self.__class__(*args, **kwargs, is_target=self.is_target)
return self.content.__repr__()

class TextField(Field):
def __init__(self, name, text, is_target):
def __init__(self, text, is_target):
"""
:param text: list of strings
:param is_target: bool
"""
super(TextField, self).__init__(name, is_target)
self.content = text

def index(self, vocab):
idx_field = IndexField(self.name+'_idx', self.content, vocab, self.is_target)
return idx_field
super(TextField, self).__init__(text, is_target)


class IndexField(Field):
@@ -82,75 +73,19 @@ class LabelField(Field):

"""
def __init__(self, label, is_target=True):
super(LabelField, self).__init__(is_target)
self.label = label
self._index = None
super(LabelField, self).__init__(label, is_target)

def get_length(self):
"""Fetch the length of the label field.

:return length: int, the length of the label, always 1.
"""
return 1

def index(self, vocab):
if self._index is None:
if isinstance(self.label, str):
self._index = vocab[self.label]
return self._index

def to_tensor(self, padding_length):
if self._index is None:
if isinstance(self.label, int):
return torch.tensor(self.label)
elif isinstance(self.label, str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
raise RuntimeError(
"Not support type for LabelField. Expect str or int, got {}.".format(type(self.label)))
else:
return torch.LongTensor([self._index])

def contents(self):
return [self.label]

class SeqLabelField(Field):
def __init__(self, label_seq, is_target=True):
super(SeqLabelField, self).__init__(is_target)
self.label_seq = label_seq
self._index = None

def get_length(self):
return len(self.label_seq)

def index(self, vocab):
if self._index is None:
self._index = [vocab[c] for c in self.label_seq]
return self._index

def to_tensor(self, padding_length):
pads = [0] * (padding_length - self.get_length())
if self._index is None:
if self.get_length() == 0:
return torch.LongTensor(pads)
elif isinstance(self.label_seq[0], int):
return torch.LongTensor(self.label_seq + pads)
elif isinstance(self.label_seq[0], str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
raise RuntimeError(
"Not support type for SeqLabelField. Expect str or int, got {}.".format(type(self.label)))
else:
return torch.LongTensor(self._index + pads)

def contents(self):
return self.label_seq.copy()
super(SeqLabelField, self).__init__(label_seq, is_target)


class CharTextField(Field):
def __init__(self, text, max_word_len, is_target=False):
super(CharTextField, self).__init__(is_target)
self.text = text
# TODO
raise NotImplementedError
self.max_word_len = max_word_len
self._index = []



fastNLP/core/fieldarray.py  (+39, -0)

@@ -0,0 +1,39 @@
import torch
import numpy as np

class FieldArray(object):
def __init__(self, name, content, padding_val=0, is_target=True, need_tensor=True):
self.name = name
self.data = [self._convert_np(val) for val in content]
self.padding_val = padding_val
self.is_target = is_target
self.need_tensor = need_tensor

def _convert_np(self, val):
if not isinstance(val, np.ndarray):
return np.array(val)
else:
return val

def append(self, val):
self.data.append(self._convert_np(val))

def get(self, idxes):
if isinstance(idxes, int):
return self.data[idxes]
elif isinstance(idxes, list):
id_list = np.array(idxes)
batch_size = len(id_list)
len_list = [(i, self.data[i].shape[0]) for i in id_list]
_, max_len = max(len_list, key=lambda x: x[1])
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int32)

for i, (idx, length) in enumerate(len_list):
if length == max_len:
array[i] = self.data[idx]
else:
array[i][:length] = self.data[idx]
return array

def __len__(self):
return len(self.data)
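
The new FieldArray keeps every value as a numpy array and, given a list of indices, get() pads the selected rows to the length of the longest one with padding_val (0 by default) and returns a single 2-D int32 batch. A standalone numpy sketch of that padding behaviour; pad_batch is an illustrative helper, not part of fastNLP.

import numpy as np

def pad_batch(data, idxes, padding_val=0):
    # Select the requested rows and pad them to a common length,
    # mirroring the list branch of FieldArray.get above.
    rows = [np.asarray(data[i]) for i in idxes]
    max_len = max(row.shape[0] for row in rows)
    batch = np.full((len(rows), max_len), padding_val, dtype=np.int32)
    for i, row in enumerate(rows):
        batch[i, :row.shape[0]] = row
    return batch

print(pad_batch([[1, 2, 3], [4, 5], [6]], idxes=[0, 1, 2]))
# [[1 2 3]
#  [4 5 0]
#  [6 0 0]]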

fastNLP/core/instance.py  (+0, -52)

@@ -7,8 +7,6 @@ class Instance(object):

def __init__(self, **fields):
self.fields = fields
self.has_index = False
self.indexes = {}

def add_field(self, field_name, field):
self.fields[field_name] = field
@@ -17,8 +15,6 @@ class Instance(object):
def rename_field(self, old_name, new_name):
if old_name in self.fields:
self.fields[new_name] = self.fields.pop(old_name)
if old_name in self.indexes:
self.indexes[new_name] = self.indexes.pop(old_name)
else:
raise KeyError("error, no such field: {}".format(old_name))
return self
@@ -38,53 +34,5 @@ class Instance(object):
def __setitem__(self, name, field):
return self.add_field(name, field)

def get_length(self):
"""Fetch the length of all fields in the instance.

:return length: dict of (str: int), which means (field name: field length).

"""
length = {name: field.get_length() for name, field in self.fields.items()}
return length

def index_field(self, field_name, vocab):
"""use `vocab` to index certain field
"""
self.indexes[field_name] = self.fields[field_name].index(vocab)
return self

def index_all(self, vocab):
"""use `vocab` to index all fields
"""
if self.has_index:
print("error")
return self.indexes
indexes = {name: field.index(vocab) for name, field in self.fields.items()}
self.indexes = indexes
return indexes

def to_tensor(self, padding_length: dict, origin_len=None):
"""Convert instance to tensor.

:param padding_length: dict of (str: int), which means (field name: padding_length of this field)
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ])
If is_target is False for all fields, tensor_y would be an empty dict.
"""
tensor_x = {}
tensor_y = {}
for name, field in self.fields.items():
if field.is_target is True:
tensor_y[name] = field.to_tensor(padding_length[name])
elif field.is_target is False:
tensor_x[name] = field.to_tensor(padding_length[name])
else:
# is_target is None
continue
if origin_len is not None:
name, field_name = origin_len
tensor_x[name] = torch.LongTensor([self.fields[field_name].get_length()])
return tensor_x, tensor_y

def __repr__(self):
return self.fields.__repr__()
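
With indexing and to_tensor removed, Instance is reduced to a named container of Field objects plus add_field / rename_field and item access. A short usage sketch, again assuming this revision of fastNLP is importable.

from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance

ins = Instance(word_seq=TextField(["a", "toy", "sentence"], is_target=False),
               label=LabelField("demo", is_target=True))
ins.rename_field("label", "target")
print(ins["word_seq"])   # access is a plain lookup in the fields dict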
