Browse Source

change Batch to torch.DataLoader

tags/v0.4.10
yunfan 6 years ago
parent
commit
9f5c00b17d
1 changed files with 73 additions and 1 deletions
  1. +73
    -1
      fastNLP/core/batch.py

+ 73
- 1
fastNLP/core/batch.py View File

@@ -12,8 +12,10 @@ from queue import Empty, Full
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.utils.data

from .sampler import RandomSampler
from .dataset import DataSet

# Module-level flag, False at import time.
# NOTE(review): presumably flipped by `_set_python_is_exit` at interpreter
# exit -- confirm against that function's body.
_python_is_exit = False

@@ -25,8 +27,78 @@ def _set_python_is_exit():

# Run `_set_python_is_exit` when the interpreter terminates normally.
atexit.register(_set_python_is_exit)

class DataSetGetter:
    """Index-addressable view over the input and target fields of a ``DataSet``.

    Instances provide ``__getitem__``/``__len__`` plus a matching
    :meth:`collate_fn`, which is the shape ``torch.utils.data.DataLoader``
    expects from a map-style dataset and its collate callback.

    :param dataset: the ``DataSet`` whose fields are served.
    :param as_numpy: when true, padded batches are left as returned by the
        field's padder instead of being passed through ``_to_tensor``.
    """

    def __init__(self, dataset: DataSet, as_numpy=False):
        self.dataset = dataset

        # Fields flagged ``is_input`` feed batch_x; fields flagged
        # ``is_target`` feed batch_y. A field may appear in both maps.
        input_fields = {}
        for name, field in dataset.get_all_fields().items():
            if field.is_input:
                input_fields[name] = field
        self.inputs = input_fields

        target_fields = {}
        for name, field in dataset.get_all_fields().items():
            if field.is_target:
                target_fields[name] = field
        self.targets = target_fields

        self.as_numpy = as_numpy

    def __getitem__(self, idx: int):
        """Return ``(idx, inputs, targets)`` for the sample at position ``idx``.

        ``inputs`` and ``targets`` map field name to ``field.get(idx)``.
        """
        sample_x = {}
        for name, field in self.inputs.items():
            sample_x[name] = field.get(idx)
        sample_y = {}
        for name, field in self.targets.items():
            sample_y[name] = field.get(idx)
        return idx, sample_x, sample_y

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)

    def collate_fn(self, batch: list):
        """Merge a list of ``__getitem__`` results into one batch.

        :param batch: list of ``(idx, inputs, targets)`` triples.
        :return: ``(indices, batch_x, batch_y)`` where ``indices`` lists the
            sample indices in order and the two dicts map each field name to
            its collated (and, where the field has a padder, padded) values.
        """
        indices = []
        columns_x = {name: [] for name in self.inputs}
        columns_y = {name: [] for name in self.targets}

        # Transpose: list of per-sample dicts -> dict of per-field lists.
        for sample_idx, sample_x, sample_y in batch:
            indices.append(sample_idx)
            for columns, sample in ((columns_x, sample_x), (columns_y, sample_y)):
                for name, value in sample.items():
                    columns[name].append(value)

        padded_x = self._pad_columns(columns_x, self.inputs)
        padded_y = self._pad_columns(columns_y, self.targets)
        return indices, padded_x, padded_y

    def _pad_columns(self, columns, fields):
        """Replace each per-field value list in ``columns`` with its batched form.

        Fields without a padder are wrapped with ``np.array``. Fields with a
        padder are padded and then, unless ``as_numpy`` is set, passed through
        ``_to_tensor``. ``columns`` is updated in place and returned.
        """
        # Only existing keys are reassigned, so iterating the dict here is safe.
        for name in columns:
            field = fields[name]
            values = columns[name]
            if field.padder is None:
                columns[name] = np.array(values)
                continue
            padded = field.padder(values, field_name=name, field_ele_dtype=field.dtype)
            columns[name] = padded if self.as_numpy else _to_tensor(padded, field.dtype)
        return columns


class Batch:
    """Iterate over a dataset in mini-batches via ``torch.utils.data.DataLoader``.

    Iterating yields ``(batch_x, batch_y)`` pairs. The dataset indices of the
    most recently yielded batch are available from :meth:`get_batch_indices`.

    :param dataset: dataset to draw batches from; it is wrapped in a
        ``DataSetGetter`` before being handed to the loader.
    :param batch_size: number of samples per batch.
    :param sampler: only inspected to decide shuffling -- the loader shuffles
        exactly when this is a ``RandomSampler`` instance. The sampler object
        itself is not forwarded to the loader.
    :param buffer_size: stored on the instance; not otherwise used by this class.
    :param as_numpy: forwarded to ``DataSetGetter``.
    :param num_workers: forwarded to the ``DataLoader``.
    :param pin_memory: forwarded to the ``DataLoader``.
    :param drop_last: forwarded to the ``DataLoader``; also taken into account
        when computing ``len(self)``.
    :param timeout: forwarded to the ``DataLoader``.
    :param worker_init_fn: forwarded to the ``DataLoader``.
    :param kwargs: accepted and ignored.
    """

    def __init__(self, dataset, batch_size, sampler=None, buffer_size=0, as_numpy=False,
                 num_workers=0, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None, **kwargs):
        dataset_getter = DataSetGetter(dataset, as_numpy)
        self.buffer_size = buffer_size
        self.cur_batch_indices = None
        # The loader discards a trailing partial batch when drop_last is set,
        # so the advertised length must not count it.
        self.num_batches = self._count_batches(len(dataset), batch_size, drop_last)
        shuffle = isinstance(sampler, RandomSampler)
        self.dataiter = torch.utils.data.DataLoader(
            dataset=dataset_getter, batch_size=batch_size, shuffle=shuffle,
            collate_fn=dataset_getter.collate_fn,
            num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

    @staticmethod
    def _count_batches(num_samples, batch_size, drop_last):
        """Return how many batches ``num_samples`` items produce.

        A trailing partial batch is counted unless ``drop_last`` is true.
        """
        full_batches, remainder = divmod(num_samples, batch_size)
        if remainder and not drop_last:
            return full_batches + 1
        return full_batches

    def __iter__(self):
        """Yield ``(batch_x, batch_y)`` for every batch the loader produces."""
        for indices, batch_x, batch_y in self.dataiter:
            # Record the indices before yielding so callers can query them
            # while they are processing this batch.
            self.cur_batch_indices = indices
            yield batch_x, batch_y

    def get_batch_indices(self):
        """Return the dataset indices of the most recently yielded batch.

        ``None`` until the first batch has been yielded.
        """
        return self.cur_batch_indices

    def __len__(self):
        """Return the number of batches one full iteration yields."""
        return self.num_batches


class Batch(object):
class Batch1(object):
"""
别名::class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`



Loading…
Cancel
Save