@@ -0,0 +1,86 @@
from collections import defaultdict

import torch


class Batch(object):
    """Iterate over a DataSet in mini-batches, padding every field to the
    longest length within the current batch."""

    def __init__(self, dataset, sampler, batch_size):
        self.dataset = dataset
        self.sampler = sampler
        self.batch_size = batch_size
        self.idx_list = None
        self.curidx = 0

    def __iter__(self):
        # re-sample the visiting order at the start of every epoch
        self.idx_list = self.sampler(self.dataset)
        self.curidx = 0
        self.lengths = self.dataset.get_length()
        return self
    def __next__(self):
        if self.curidx >= len(self.idx_list):
            raise StopIteration
        endidx = min(self.curidx + self.batch_size, len(self.idx_list))
        # follow the sampled order, not raw dataset order
        batch_idx = self.idx_list[self.curidx:endidx]
        # pad every field to the longest length within this batch
        padding_length = {field_name: max(field_length[i] for i in batch_idx)
                          for field_name, field_length in self.lengths.items()}
        batch_x, batch_y = defaultdict(list), defaultdict(list)
        for idx in batch_idx:
            x, y = self.dataset.to_tensor(idx, padding_length)
            for name, tensor in x.items():
                batch_x[name].append(tensor)
            for name, tensor in y.items():
                batch_y[name].append(tensor)
        # stack per-instance tensors into (batch_size, ...) batch tensors
        for batch in (batch_x, batch_y):
            for name, tensor_list in batch.items():
                batch[name] = torch.stack(tensor_list, dim=0)
        self.curidx = endidx
        return batch_x, batch_y
if __name__ == "__main__":
    # simple running example
    from field import TextField, LabelField
    from instance import Instance
    from dataset import DataSet

    texts = ["i am a cat",
             "this is a test of new batch",
             "haha"]
    labels = [0, 1, 0]

    # prepare vocabulary
    vocab = {}
    for text in texts:
        for token in text.split():
            if token not in vocab:
                vocab[token] = len(vocab)

    # prepare input dataset
    data = DataSet()
    for text, label in zip(texts, labels):
        x = TextField(text.split(), is_target=False)
        y = LabelField(label, is_target=True)
        ins = Instance(text=x, label=y)
        data.append(ins)

    # use vocabulary to index data
    data.index_field("text", vocab)

    # define a naive sequential sampler for the Batch class
    class SeqSampler:
        def __call__(self, dataset):
            return list(range(len(dataset)))

    # use Batch to iterate over the dataset
    batcher = Batch(data, SeqSampler(), 2)
    for epoch in range(3):
        for batch_x, batch_y in batcher:
            print(batch_x)
            print(batch_y)
            # do stuff
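
    # A minimal sketch (added for illustration, not part of the original file):
    # any callable that returns a permutation of the dataset's indices can
    # serve as a sampler, since Batch only consumes the returned index list.
    import random

    class RandomSampler:
        def __call__(self, dataset):
            indices = list(range(len(dataset)))
            random.shuffle(indices)  # visit instances in random order
            return indices

    shuffled = Batch(data, RandomSampler(), 2)
    for batch_x, batch_y in shuffled:
        # e.g. batch_x["text"] has shape (batch_size, longest_text_in_batch)
        print(batch_x["text"].size())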
@@ -0,0 +1,29 @@
from collections import defaultdict


class DataSet(list):
    """A list of Instance objects, with helpers for indexing and padding."""

    def __init__(self, name="", instances=None):
        super(DataSet, self).__init__()
        self.name = name
        if instances is not None:
            self.extend(instances)

    def index_all(self, vocab):
        for ins in self:
            ins.index_all(vocab)

    def index_field(self, field_name, vocab):
        for ins in self:
            ins.index_field(field_name, vocab)

    def to_tensor(self, idx: int, padding_length: dict):
        ins = self[idx]
        return ins.to_tensor(padding_length)

    def get_length(self):
        # collect, per field, the length of that field in every instance
        lengths = defaultdict(list)
        for ins in self:
            for field_name, field_length in ins.get_length().items():
                lengths[field_name].append(field_length)
        return lengths
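
if __name__ == "__main__":
    # A minimal sketch (added for illustration, not in the original file) of
    # how get_length() aggregates per-field lengths; TextField, LabelField and
    # Instance come from field.py and instance.py in this repository.
    from field import TextField, LabelField
    from instance import Instance

    data = DataSet(instances=[
        Instance(text=TextField("a b c".split(), is_target=False),
                 label=LabelField(0, is_target=True)),
        Instance(text=TextField("d e".split(), is_target=False),
                 label=LabelField(1, is_target=True)),
    ])
    # one list of lengths per field, e.g. {"text": [3, 2], "label": [1, 1]}
    print(data.get_length())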
@@ -0,0 +1,70 @@
import torch


class Field(object):
    """Base class for a single field (input or target) of an instance."""

    def __init__(self, is_target: bool):
        self.is_target = is_target

    def index(self, vocab):
        raise NotImplementedError

    def get_length(self):
        raise NotImplementedError

    def to_tensor(self, padding_length):
        raise NotImplementedError


class TextField(Field):
    def __init__(self, text: list, is_target):
        """
        :param list text: a list of tokens
        """
        super(TextField, self).__init__(is_target)
        self.text = text
        self._index = None

    def index(self, vocab):
        if self._index is None:
            self._index = [vocab[token] for token in self.text]
        else:
            raise RuntimeError("This text field has already been indexed.")
        return self._index

    def get_length(self):
        return len(self.text)

    def to_tensor(self, padding_length: int):
        if self._index is None:
            raise RuntimeError("Call index() before to_tensor().")
        pads = []
        if padding_length > self.get_length():
            # pad with 0 up to the longest length in the batch
            pads = [0] * (padding_length - self.get_length())
        # shape: (padding_length, )
        return torch.LongTensor(self._index + pads)


class LabelField(Field):
    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(is_target)
        self.label = label
        self._index = None

    def get_length(self):
        return 1

    def index(self, vocab):
        if self._index is None:
            self._index = vocab[self.label]
        return self._index

    def to_tensor(self, padding_length):
        if self._index is None:
            # fall back to the raw label when it has not been indexed
            return torch.LongTensor([self.label])
        else:
            return torch.LongTensor([self._index])


if __name__ == "__main__":
    tf = TextField("test the code".split(), is_target=False)
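    # A brief sketch (added for illustration) of the index -> pad -> tensor
    # flow; the toy vocabulary below is hypothetical. Note that the pad id 0
    # is also the id of "test" here, since the code always pads with 0.
    vocab = {"test": 0, "the": 1, "code": 2}
    tf.index(vocab)
    print(tf.to_tensor(5))  # tensor([0, 1, 2, 0, 0]) -- padded to length 5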
@@ -0,0 +1,38 @@
class Instance(object):
    """A single example, made up of named Field objects."""

    def __init__(self, **fields):
        self.fields = fields
        self.has_index = False
        self.indexes = {}

    def add_field(self, field_name, field):
        self.fields[field_name] = field

    def get_length(self):
        """return the length of every field, keyed by field name"""
        length = {name: field.get_length() for name, field in self.fields.items()}
        return length

    def index_field(self, field_name, vocab):
        """use `vocab` to index a certain field"""
        self.indexes[field_name] = self.fields[field_name].index(vocab)

    def index_all(self, vocab):
        """use `vocab` to index all fields"""
        if self.has_index:
            # already indexed: return the cached result instead of re-indexing
            return self.indexes
        indexes = {name: field.index(vocab) for name, field in self.fields.items()}
        self.indexes = indexes
        self.has_index = True
        return indexes

    def to_tensor(self, padding_length: dict):
        # split fields into inputs (x) and targets (y) by their is_target flag
        tensorX = {}
        tensorY = {}
        for name, field in self.fields.items():
            if field.is_target:
                tensorY[name] = field.to_tensor(padding_length[name])
            else:
                tensorX[name] = field.to_tensor(padding_length[name])
        return tensorX, tensorY
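
if __name__ == "__main__":
    # A minimal sketch (added for illustration, not in the original file)
    # showing how to_tensor() routes fields into inputs and targets by their
    # is_target flag; the tiny vocabulary below is made up.
    from field import TextField, LabelField

    ins = Instance(text=TextField("a b".split(), is_target=False),
                   label=LabelField(0, is_target=True))
    ins.index_field("text", {"a": 0, "b": 1})
    x, y = ins.to_tensor({"text": 4, "label": 1})
    print(x)  # {'text': tensor([0, 1, 0, 0])}
    print(y)  # {'label': tensor([0])}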