diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 109d4fe9..9aab146c 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -12,8 +12,10 @@ from queue import Empty, Full
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+import torch.utils.data
 
 from .sampler import RandomSampler
+from .dataset import DataSet
 
 _python_is_exit = False
 
@@ -25,8 +27,78 @@ def _set_python_is_exit():
 
 atexit.register(_set_python_is_exit)
 
 
+class DataSetGetter:
+    def __init__(self, dataset: DataSet, as_numpy=False):
+        self.dataset = dataset
+        self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
+        self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
+        self.as_numpy = as_numpy
+
+    def __getitem__(self, idx: int):
+        inputs = {n:f.get(idx) for n, f in self.inputs.items()}
+        targets = {n:f.get(idx) for n, f in self.targets.items()}
+        return idx, inputs, targets
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collate_fn(self, batch: list):
+        batch_x = {n:[] for n in self.inputs.keys()}
+        batch_y = {n:[] for n in self.targets.keys()}
+        indices = []
+        for idx, x, y in batch:
+            indices.append(idx)
+            for n, v in x.items():
+                batch_x[n].append(v)
+            for n, v in y.items():
+                batch_y[n].append(v)
+
+        def pad_batch(batch_dict, field_array):
+            for n, vlist in batch_dict.items():
+                f = field_array[n]
+                if f.padder is None:
+                    batch_dict[n] = np.array(vlist)
+                else:
+                    data = f.padder(vlist, field_name=n, field_ele_dtype=f.dtype)
+                    if not self.as_numpy:
+                        data = _to_tensor(data, f.dtype)
+                    batch_dict[n] = data
+            return batch_dict
+
+        return (indices,
+                pad_batch(batch_x, self.inputs),
+                pad_batch(batch_y, self.targets))
+
+
+class Batch:
+    def __init__(self, dataset, batch_size, sampler=None, buffer_size=0, as_numpy=False,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None, **kwargs):
+
+        dataset_getter = DataSetGetter(dataset, as_numpy)
+        self.buffer_size = buffer_size
+        self.cur_batch_indices = None
+        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
+        shuffle = isinstance(sampler, RandomSampler)
+        self.dataiter = torch.utils.data.DataLoader(
+            dataset=dataset_getter, batch_size=batch_size, shuffle=shuffle,
+            collate_fn=dataset_getter.collate_fn,
+            num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last,
+            timeout=timeout, worker_init_fn=worker_init_fn)
+
+    def __iter__(self):
+        for indices, batch_x, batch_y in self.dataiter:
+            self.cur_batch_indices = indices
+            yield batch_x, batch_y
+
+    def get_batch_indices(self):
+        return self.cur_batch_indices
+
+    def __len__(self):
+        return self.num_batches
+
-class Batch(object):
+class Batch1(object):
     """
     Alias: :class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`
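
The new `Batch` delegates batching to `torch.utils.data.DataLoader`: `DataSetGetter` adapts a fastNLP `DataSet` to the map-style dataset interface, and its `collate_fn` applies each field's padder before optionally converting the padded arrays to tensors. Below is a minimal usage sketch, not part of the diff; it assumes fastNLP's usual `DataSet` API (`set_input`/`set_target`) and the default field padders:

```python
from fastNLP import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler

# Toy dataset: variable-length token-id sequences with integer labels.
ds = DataSet({"words": [[1, 2, 3], [4, 5], [6, 7, 8, 9]],
              "label": [0, 1, 0]})
ds.set_input("words")   # input fields are collected into batch_x
ds.set_target("label")  # target fields are collected into batch_y

# Passing a RandomSampler makes the new Batch set shuffle=True on the
# underlying DataLoader; num_workers etc. are forwarded unchanged.
batches = Batch(ds, batch_size=2, sampler=RandomSampler(), num_workers=0)

for batch_x, batch_y in batches:
    # "words" is padded to the batch max length by the field's padder
    # and converted to a torch.Tensor (as_numpy=False is the default).
    print(batch_x["words"].shape, batch_y["label"])
    print(batches.get_batch_indices())  # dataset indices of this batch
```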