@@ -0,0 +1,86 @@
from collections import defaultdict

import torch


class Batch(object):
    """Iterate over a DataSet in mini-batches, using a sampler to decide the
    instance order and padding each field to the longest sequence in the
    current batch."""

    def __init__(self, dataset, sampler, batch_size):
        self.dataset = dataset
        self.sampler = sampler
        self.batch_size = batch_size
        self.idx_list = None
        self.curidx = 0

    def __iter__(self):
        self.idx_list = self.sampler(self.dataset)
        self.curidx = 0
        self.lengths = self.dataset.get_length()
        return self

    def __next__(self):
        if self.curidx >= len(self.idx_list):
            raise StopIteration
        endidx = min(self.curidx + self.batch_size, len(self.idx_list))
        batch_idxes = self.idx_list[self.curidx:endidx]
        # pad every field to the longest instance in this batch
        padding_length = {field_name: max(field_length[idx] for idx in batch_idxes)
                          for field_name, field_length in self.lengths.items()}
        batch_x, batch_y = defaultdict(list), defaultdict(list)
        for idx in batch_idxes:
            x, y = self.dataset.to_tensor(idx, padding_length)
            for name, tensor in x.items():
                batch_x[name].append(tensor)
            for name, tensor in y.items():
                batch_y[name].append(tensor)
        # stack per-instance tensors into (batch_size, ...) batch tensors
        for batch in (batch_x, batch_y):
            for name, tensor_list in batch.items():
                batch[name] = torch.stack(tensor_list, dim=0)
        self.curidx = endidx
        return batch_x, batch_y
if __name__ == "__main__":
    """simple running example
    """
    from field import TextField, LabelField
    from instance import Instance
    from dataset import DataSet

    texts = ["i am a cat",
             "this is a test of new batch",
             "haha"
             ]
    labels = [0, 1, 0]

    # prepare vocabulary
    vocab = {}
    for text in texts:
        for token in text.split():
            if token not in vocab:
                vocab[token] = len(vocab)

    # prepare input dataset
    data = DataSet()
    for text, label in zip(texts, labels):
        x = TextField(text.split(), is_target=False)
        y = LabelField(label, is_target=True)
        ins = Instance(text=x, label=y)
        data.append(ins)

    # use the vocabulary to index the data
    data.index_field("text", vocab)

    # define a naive sequential sampler for the Batch class
    class SeqSampler:
        def __call__(self, dataset):
            return list(range(len(dataset)))

    # use Batch to iterate over the dataset
    batcher = Batch(data, SeqSampler(), 2)
    for epoch in range(3):
        for batch_x, batch_y in batcher:
            print(batch_x)
            print(batch_y)
            # do stuff
@@ -0,0 +1,29 @@
from collections import defaultdict


class DataSet(list):
    """A list of Instance objects, with helpers for indexing fields and
    converting single instances to tensors."""

    def __init__(self, name="", instances=None):
        super(DataSet, self).__init__()
        self.name = name
        if instances is not None:
            self.extend(instances)

    def index_all(self, vocab):
        for ins in self:
            ins.index_all(vocab)

    def index_field(self, field_name, vocab):
        for ins in self:
            ins.index_field(field_name, vocab)

    def to_tensor(self, idx: int, padding_length: dict):
        ins = self[idx]
        return ins.to_tensor(padding_length)

    def get_length(self):
        # field_name -> list of lengths, one entry per instance
        lengths = defaultdict(list)
        for ins in self:
            for field_name, field_length in ins.get_length().items():
                lengths[field_name].append(field_length)
        return lengths
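
# A minimal, self-contained sketch of how DataSet is meant to be used, assuming
# the TextField/LabelField and Instance classes from the field and instance
# modules below; the toy vocabulary and field names are made up for illustration
# and this demo block is not part of the original file.
if __name__ == "__main__":
    from field import TextField, LabelField
    from instance import Instance

    vocab = {"hello": 0, "world": 1}
    ins = Instance(text=TextField("hello world".split(), is_target=False),
                   label=LabelField(1, is_target=True))
    data = DataSet(name="toy", instances=[ins])
    data.index_field("text", vocab)
    print(data.get_length())                      # field lengths, e.g. {'text': [2], 'label': [1]}
    print(data.to_tensor(0, {"text": 4, "label": 1}))  # (input tensors, target tensors)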
@@ -0,0 +1,70 @@
import torch


class Field(object):
    """Base class for a single field of an instance."""

    def __init__(self, is_target: bool):
        self.is_target = is_target

    def index(self, vocab):
        raise NotImplementedError

    def get_length(self):
        raise NotImplementedError

    def to_tensor(self, padding_length):
        raise NotImplementedError


class TextField(Field):
    def __init__(self, text: list, is_target):
        """
        :param list text: a list of tokens
        """
        super(TextField, self).__init__(is_target)
        self.text = text
        self._index = None

    def index(self, vocab):
        if self._index is None:
            self._index = [vocab[c] for c in self.text]
        # if the field has already been indexed, reuse the cached result
        return self._index

    def get_length(self):
        return len(self.text)

    def to_tensor(self, padding_length: int):
        if self._index is None:
            raise RuntimeError("TextField must be indexed before calling to_tensor")
        pads = []
        if padding_length > self.get_length():
            pads = [0 for _ in range(padding_length - self.get_length())]
        # (padding_length, )
        return torch.LongTensor(self._index + pads)


class LabelField(Field):
    def __init__(self, label, is_target=True):
        super(LabelField, self).__init__(is_target)
        self.label = label
        self._index = None

    def get_length(self):
        return 1

    def index(self, vocab):
        if self._index is None:
            self._index = vocab[self.label]
        return self._index

    def to_tensor(self, padding_length):
        if self._index is None:
            # the label is assumed to already be an integer
            return torch.LongTensor([self.label])
        else:
            return torch.LongTensor([self._index])
if __name__ == "__main__":
    tf = TextField("test the code".split(), is_target=False)
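    # A minimal continuation of the demo above, sketching how a TextField is
    # indexed and padded; the toy vocabulary is made up for illustration and
    # this part is not in the original file.
    toy_vocab = {"test": 0, "the": 1, "code": 2}
    print(tf.index(toy_vocab))    # [0, 1, 2]
    print(tf.get_length())        # 3
    print(tf.to_tensor(5))        # tensor([0, 1, 2, 0, 0]), padded to length 5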
@@ -0,0 +1,38 @@
class Instance(object):
    """A single example, made up of named Field objects."""

    def __init__(self, **fields):
        self.fields = fields
        self.has_index = False
        self.indexes = {}

    def add_field(self, field_name, field):
        self.fields[field_name] = field

    def get_length(self):
        length = {name: field.get_length() for name, field in self.fields.items()}
        return length

    def index_field(self, field_name, vocab):
        """use `vocab` to index a certain field
        """
        self.indexes[field_name] = self.fields[field_name].index(vocab)

    def index_all(self, vocab):
        """use `vocab` to index all fields
        """
        if self.has_index:
            # already indexed: reuse the cached indexes
            return self.indexes
        indexes = {name: field.index(vocab) for name, field in self.fields.items()}
        self.indexes = indexes
        self.has_index = True
        return indexes

    def to_tensor(self, padding_length: dict):
        tensorX = {}
        tensorY = {}
        for name, field in self.fields.items():
            if field.is_target:
                tensorY[name] = field.to_tensor(padding_length[name])
            else:
                tensorX[name] = field.to_tensor(padding_length[name])
        return tensorX, tensorY
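
# A minimal, self-contained sketch of the Instance API, assuming the TextField
# and LabelField classes from the field module above; the toy vocabulary is made
# up for illustration and this demo block is not part of the original file.
if __name__ == "__main__":
    from field import TextField, LabelField

    vocab = {"a": 0, "b": 1}
    ins = Instance(text=TextField(["a", "b", "a"], is_target=False),
                   label=LabelField(0, is_target=True))
    print(ins.get_length())          # {'text': 3, 'label': 1}
    ins.index_field("text", vocab)
    x, y = ins.to_tensor({"text": 4, "label": 1})
    print(x)                         # {'text': tensor([0, 1, 0, 0])}
    print(y)                         # {'label': tensor([0])}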