
change Batch to torch.DataLoader

tags/v0.4.10
yunfan, 6 years ago
parent commit 9f5c00b17d
1 changed file with 73 additions and 1 deletion:

fastNLP/core/batch.py (+73, -1)

@@ -12,8 +12,10 @@ from queue import Empty, Full
 import numpy as np
 import torch
 import torch.multiprocessing as mp
+import torch.utils.data

 from .sampler import RandomSampler
+from .dataset import DataSet

 _python_is_exit = False


@@ -25,8 +27,78 @@ def _set_python_is_exit():


 atexit.register(_set_python_is_exit)


+class DataSetGetter:
+    """Adapts a fastNLP DataSet to the map-style interface DataLoader expects."""
+    def __init__(self, dataset: DataSet, as_numpy=False):
+        self.dataset = dataset
+        self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
+        self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
+        self.as_numpy = as_numpy
+
+    def __getitem__(self, idx: int):
+        inputs = {n: f.get(idx) for n, f in self.inputs.items()}
+        targets = {n: f.get(idx) for n, f in self.targets.items()}
+        return idx, inputs, targets
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collate_fn(self, batch: list):
+        batch_x = {n: [] for n in self.inputs.keys()}
+        batch_y = {n: [] for n in self.targets.keys()}
+        indices = []
+        for idx, x, y in batch:
+            indices.append(idx)
+            for n, v in x.items():
+                batch_x[n].append(v)
+            for n, v in y.items():
+                batch_y[n].append(v)
+
+        def pad_batch(batch_dict, field_array):
+            for n, vlist in batch_dict.items():
+                f = field_array[n]
+                if f.padder is None:
+                    batch_dict[n] = np.array(vlist)
+                else:
+                    data = f.padder(vlist, field_name=n, field_ele_dtype=f.dtype)
+                    if not self.as_numpy:
+                        data = _to_tensor(data, f.dtype)
+                    batch_dict[n] = data
+            return batch_dict
+
+        return (indices,
+                pad_batch(batch_x, self.inputs),
+                pad_batch(batch_y, self.targets))
+
+
+class Batch:
+    def __init__(self, dataset, batch_size, sampler=None, buffer_size=0, as_numpy=False,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None, **kwargs):
+        dataset_getter = DataSetGetter(dataset, as_numpy)
+        self.buffer_size = buffer_size
+        self.cur_batch_indices = None
+        # ceil(len(dataset) / batch_size)
+        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
+        shuffle = isinstance(sampler, RandomSampler)
+        self.dataiter = torch.utils.data.DataLoader(
+            dataset=dataset_getter, batch_size=batch_size, shuffle=shuffle,
+            collate_fn=dataset_getter.collate_fn,
+            num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last,
+            timeout=timeout, worker_init_fn=worker_init_fn)
+
+    def __iter__(self):
+        for indices, batch_x, batch_y in self.dataiter:
+            self.cur_batch_indices = indices
+            yield batch_x, batch_y
+
+    def get_batch_indices(self):
+        return self.cur_batch_indices
+
+    def __len__(self):
+        return self.num_batches



-class Batch(object):
+class Batch1(object):
     """
     Alias: :class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`
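
Taken together, DataSetGetter adapts a fastNLP DataSet to the map-style dataset interface that torch.utils.data.DataLoader expects (__getitem__/__len__), its collate_fn pads each field with that field's padder, and the new Batch becomes a thin wrapper that delegates batching, worker processes, and pinned memory to DataLoader. A minimal usage sketch of the new class, assuming a 0.4.x-style DataSet whose fields have been flagged with set_input/set_target (the field names and values here are illustrative):

from fastNLP import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler

ds = DataSet({"words": [[1, 2, 3], [4, 5]], "label": [0, 1]})
ds.set_input("words")   # collected into batch_x
ds.set_target("label")  # collected into batch_y

# Passing a RandomSampler is what enables shuffling, since Batch
# derives shuffle=isinstance(sampler, RandomSampler).
data_iter = Batch(ds, batch_size=2, sampler=RandomSampler())
for batch_x, batch_y in data_iter:
    print(batch_x["words"], batch_y["label"])
    print(data_iter.get_batch_indices())  # indices gathered by collate_fn

Note one consequence of the design: any sampler other than RandomSampler is only type-checked, not consulted, so iteration order is decided entirely by DataLoader.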



