@@ -12,8 +12,10 @@ from queue import Empty, Full

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.utils.data

from .sampler import RandomSampler
from .dataset import DataSet

# Set to True by the atexit hook below, so other parts of the code can
# detect that the Python interpreter is shutting down.
_python_is_exit = False

@@ -25,8 +27,78 @@ def _set_python_is_exit():
atexit.register(_set_python_is_exit)
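
# Why the flag exists (illustrative sketch, not part of this diff): a
# background producer feeding a bounded queue could otherwise block in
# `put()` forever while the interpreter is tearing down. All names in the
# sketch are hypothetical.
#
#     def _producer(batch_queue, batches):
#         for b in batches:
#             while True:
#                 try:
#                     batch_queue.put(b, timeout=1)
#                     break
#                 except Full:
#                     if _python_is_exit:
#                         return  # interpreter exiting; give up quietly
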
class DataSetGetter:
    """Adapts a fastNLP DataSet to the torch.utils.data.DataLoader protocol."""

    def __init__(self, dataset: DataSet, as_numpy=False):
        self.dataset = dataset
        # Split the fields once, by their input/target flags.
        self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
        self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
        self.as_numpy = as_numpy

    def __getitem__(self, idx: int):
        # Return the index along with the instance, so that collate_fn can
        # report which instances ended up in each batch.
        inputs = {n: f.get(idx) for n, f in self.inputs.items()}
        targets = {n: f.get(idx) for n, f in self.targets.items()}
        return idx, inputs, targets

    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch: list):
        # `batch` is a list of (idx, inputs, targets) triples produced by
        # __getitem__; gather them into per-field lists.
        batch_x = {n: [] for n in self.inputs.keys()}
        batch_y = {n: [] for n in self.targets.keys()}
        indices = []
        for idx, x, y in batch:
            indices.append(idx)
            for n, v in x.items():
                batch_x[n].append(v)
            for n, v in y.items():
                batch_y[n].append(v)

        def pad_batch(batch_dict, field_array):
            # Pad each field with its own padder; fields without a padder
            # fall back to a plain numpy array.
            for n, vlist in batch_dict.items():
                f = field_array[n]
                if f.padder is None:
                    batch_dict[n] = np.array(vlist)
                else:
                    data = f.padder(vlist, field_name=n, field_ele_dtype=f.dtype)
                    if not self.as_numpy:
                        data = _to_tensor(data, f.dtype)
                    batch_dict[n] = data
            return batch_dict

        return (indices,
                pad_batch(batch_x, self.inputs),
                pad_batch(batch_y, self.targets))
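
# `_to_tensor` is called in `collate_fn` above but its definition is not
# part of this hunk; it presumably lives elsewhere in this module. A
# minimal sketch of the assumed behaviour (dtype-based conversion of the
# padded numpy data to torch tensors):
#
#     def _to_tensor(batch, dtype):
#         try:
#             if dtype in (int, np.int8, np.int16, np.int32, np.int64):
#                 batch = torch.LongTensor(batch)
#             elif dtype in (float, np.float16, np.float32, np.float64):
#                 batch = torch.FloatTensor(batch)
#         except Exception:
#             pass  # leave non-numeric fields (e.g. strings) untouched
#         return batch
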
class Batch:
    def __init__(self, dataset, batch_size, sampler=None, buffer_size=0, as_numpy=False,
                 num_workers=0, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None, **kwargs):
        dataset_getter = DataSetGetter(dataset, as_numpy)
        self.buffer_size = buffer_size
        self.cur_batch_indices = None
        # Mirror DataLoader's batch count: with drop_last the final partial
        # batch is discarded, otherwise it is kept.
        self.num_batches = len(dataset) // batch_size
        if not drop_last and len(dataset) % batch_size != 0:
            self.num_batches += 1
        # fastNLP samplers are mapped onto DataLoader's shuffle flag.
        shuffle = isinstance(sampler, RandomSampler)
        self.dataiter = torch.utils.data.DataLoader(
            dataset=dataset_getter, batch_size=batch_size, shuffle=shuffle,
            collate_fn=dataset_getter.collate_fn,
            num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

    def __iter__(self):
        for indices, batch_x, batch_y in self.dataiter:
            self.cur_batch_indices = indices
            yield batch_x, batch_y

    def get_batch_indices(self):
        # Indices (into the DataSet) of the most recently yielded batch.
        return self.cur_batch_indices

    def __len__(self):
        return self.num_batches
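
# Illustrative usage (not part of this diff), assuming a DataSet whose
# fields were flagged via set_input/set_target; the data values here are
# made up:
#
#     ds = DataSet({'x': [[1, 2], [3, 4, 5], [6]], 'y': [0, 1, 0]})
#     ds.set_input('x')
#     ds.set_target('y')
#     batches = Batch(ds, batch_size=2, sampler=RandomSampler())
#     for batch_x, batch_y in batches:
#         # batch_x['x'] is a padded tensor; batch_y['y'] holds the targets
#         print(batches.get_batch_indices(), batch_x['x'], batch_y['y'])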
class Batch1(object):
    """
    Alias: :class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`