| @@ -0,0 +1,86 @@ | |||
| from collections import defaultdict | |||
| import torch | |||
| class Batch(object): | |||
| def __init__(self, dataset, sampler, batch_size): | |||
| self.dataset = dataset | |||
| self.sampler = sampler | |||
| self.batch_size = batch_size | |||
| self.idx_list = None | |||
| self.curidx = 0 | |||
| def __iter__(self): | |||
| self.idx_list = self.sampler(self.dataset) | |||
| self.curidx = 0 | |||
| self.lengths = self.dataset.get_length() | |||
| return self | |||
| def __next__(self): | |||
| if self.curidx >= len(self.idx_list): | |||
| raise StopIteration | |||
| else: | |||
| endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | |||
| padding_length = {field_name : max(field_length[self.curidx: endidx]) | |||
| for field_name, field_length in self.lengths.items()} | |||
| batch_x, batch_y = defaultdict(list), defaultdict(list) | |||
| for idx in range(self.curidx, endidx): | |||
| x, y = self.dataset.to_tensor(idx, padding_length) | |||
| for name, tensor in x.items(): | |||
| batch_x[name].append(tensor) | |||
| for name, tensor in y.items(): | |||
| batch_y[name].append(tensor) | |||
| for batch in (batch_x, batch_y): | |||
| for name, tensor_list in batch.items(): | |||
| print(name, " ", tensor_list) | |||
| batch[name] = torch.stack(tensor_list, dim=0) | |||
| self.curidx += endidx | |||
| return batch_x, batch_y | |||
| if __name__ == "__main__": | |||
| """simple running example | |||
| """ | |||
| from field import TextField, LabelField | |||
| from instance import Instance | |||
| from dataset import DataSet | |||
| texts = ["i am a cat", | |||
| "this is a test of new batch", | |||
| "haha" | |||
| ] | |||
| labels = [0, 1, 0] | |||
| # prepare vocabulary | |||
| vocab = {} | |||
| for text in texts: | |||
| for tokens in text.split(): | |||
| if tokens not in vocab: | |||
| vocab[tokens] = len(vocab) | |||
| # prepare input dataset | |||
| data = DataSet() | |||
| for text, label in zip(texts, labels): | |||
| x = TextField(text.split(), False) | |||
| y = LabelField(label, is_target=True) | |||
| ins = Instance(text=x, label=y) | |||
| data.append(ins) | |||
| # use vocabulary to index data | |||
| data.index_field("text", vocab) | |||
| # define naive sampler for batch class | |||
| class SeqSampler: | |||
| def __call__(self, dataset): | |||
| return list(range(len(dataset))) | |||
| # use bacth to iterate dataset | |||
| batcher = Batch(data, SeqSampler(), 2) | |||
| for epoch in range(3): | |||
| for batch_x, batch_y in batcher: | |||
| print(batch_x) | |||
| print(batch_y) | |||
| # do stuff | |||
| @@ -0,0 +1,29 @@ | |||
| from collections import defaultdict | |||
| class DataSet(list): | |||
| def __init__(self, name="", instances=None): | |||
| list.__init__([]) | |||
| self.name = name | |||
| if instances is not None: | |||
| self.extend(instances) | |||
| def index_all(self, vocab): | |||
| for ins in self: | |||
| ins.index_all(vocab) | |||
| def index_field(self, field_name, vocab): | |||
| for ins in self: | |||
| ins.index_field(field_name, vocab) | |||
| def to_tensor(self, idx: int, padding_length: dict): | |||
| ins = self[idx] | |||
| return ins.to_tensor(padding_length) | |||
| def get_length(self): | |||
| lengths = defaultdict(list) | |||
| for ins in self: | |||
| for field_name, field_length in ins.get_length().items(): | |||
| lengths[field_name].append(field_length) | |||
| return lengths | |||
| @@ -0,0 +1,70 @@ | |||
| import torch | |||
| class Field(object): | |||
| def __init__(self, is_target: bool): | |||
| self.is_target = is_target | |||
| def index(self, vocab): | |||
| pass | |||
| def get_length(self): | |||
| pass | |||
| def to_tensor(self, padding_length): | |||
| pass | |||
| class TextField(Field): | |||
| def __init__(self, text: list, is_target): | |||
| """ | |||
| :param list text: | |||
| """ | |||
| super(TextField, self).__init__(is_target) | |||
| self.text = text | |||
| self._index = None | |||
| def index(self, vocab): | |||
| if self._index is None: | |||
| self._index = [vocab[c] for c in self.text] | |||
| else: | |||
| print('error') | |||
| return self._index | |||
| def get_length(self): | |||
| return len(self.text) | |||
| def to_tensor(self, padding_length: int): | |||
| pads = [] | |||
| if self._index is None: | |||
| print('error') | |||
| if padding_length > self.get_length(): | |||
| pads = [0 for i in range(padding_length - self.get_length())] | |||
| # (length, ) | |||
| return torch.LongTensor(self._index + pads) | |||
| class LabelField(Field): | |||
| def __init__(self, label, is_target=True): | |||
| super(LabelField, self).__init__(is_target) | |||
| self.label = label | |||
| self._index = None | |||
| def get_length(self): | |||
| return 1 | |||
| def index(self, vocab): | |||
| if self._index is None: | |||
| self._index = vocab[self.label] | |||
| else: | |||
| pass | |||
| return self._index | |||
| def to_tensor(self, padding_length): | |||
| if self._index is None: | |||
| return torch.LongTensor([self.label]) | |||
| else: | |||
| return torch.LongTensor([self._index]) | |||
| if __name__ == "__main__": | |||
| tf = TextField("test the code".split()) | |||
| @@ -0,0 +1,38 @@ | |||
| class Instance(object): | |||
| def __init__(self, **fields): | |||
| self.fields = fields | |||
| self.has_index = False | |||
| self.indexes = {} | |||
| def add_field(self, field_name, field): | |||
| self.fields[field_name] = field | |||
| def get_length(self): | |||
| length = {name : field.get_length() for name, field in self.fields.items()} | |||
| return length | |||
| def index_field(self, field_name, vocab): | |||
| """use `vocab` to index certain field | |||
| """ | |||
| self.indexes[field_name] = self.fields[field_name].index(vocab) | |||
| def index_all(self, vocab): | |||
| """use `vocab` to index all fields | |||
| """ | |||
| if self.has_index: | |||
| print("error") | |||
| return self.indexes | |||
| indexes = {name : field.index(vocab) for name, field in self.fields.items()} | |||
| self.indexes = indexes | |||
| return indexes | |||
| def to_tensor(self, padding_length: dict): | |||
| tensorX = {} | |||
| tensorY = {} | |||
| for name, field in self.fields.items(): | |||
| if field.is_target: | |||
| tensorY[name] = field.to_tensor(padding_length[name]) | |||
| else: | |||
| tensorX[name] = field.to_tensor(padding_length[name]) | |||
| return tensorX, tensorY | |||