Browse Source

修改batch, 新增pipeline和processor的接口

tags/v0.2.0
yh 5 years ago
parent
commit
fcf5af93d8
4 changed files with 51 additions and 27 deletions
  1. +0
    -0
      fastNLP/api/__init__.py
  2. +23
    -0
      fastNLP/api/pipeline.py
  3. +15
    -0
      fastNLP/api/processor.py
  4. +13
    -27
      fastNLP/core/batch.py

+ 0
- 0
fastNLP/api/__init__.py View File


+ 23
- 0
fastNLP/api/pipeline.py View File

@@ -0,0 +1,23 @@
from fastNLP.api.processor import Processor



class Pipeline:
    """A sequential container of :class:`Processor` objects.

    Processors are applied to a dataset one after another, in the order
    in which they were added; each processor's output dataset is fed to
    the next one.
    """

    def __init__(self):
        # Ordered list of Processor instances to apply.
        self.pipeline = []

    def add_processor(self, processor):
        """Append a processor to the end of the pipeline.

        :param processor: a Processor instance (enforced by assertion).
        """
        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
        # BUGFIX: the original stored bare processors but iterated the list
        # with ``for proc_name, proc in self.pipeline`` in process(), which
        # tuple-unpacks a non-tuple and crashes at runtime. Store and iterate
        # plain processors consistently (the unused ``processor_name`` local
        # is dropped).
        self.pipeline.append(processor)

    def process(self, dataset):
        """Run every registered processor over ``dataset`` in order.

        :param dataset: the dataset handed to the first processor; each
            processor returns the (possibly new) dataset for the next one.
        :return: the dataset produced by the last processor.
        """
        assert len(self.pipeline) != 0, "You need to add some processor first."

        for proc in self.pipeline:
            dataset = proc(dataset)

        return dataset

    def __call__(self, *args, **kwargs):
        # Calling the pipeline is a shorthand for process().
        return self.process(*args, **kwargs)

+ 15
- 0
fastNLP/api/processor.py View File

@@ -0,0 +1,15 @@


class Processor:
    """Base class for dataset field processors.

    A processor reads from ``field_name`` and writes its result to
    ``new_added_field_name``; when no target field is supplied the
    source field is overwritten in place.
    """

    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
        # Default the output field to the input field when none is given.
        self.new_added_field_name = (
            field_name if new_added_field_name is None else new_added_field_name
        )

    def process(self):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def __call__(self, *args, **kwargs):
        # Invoking the processor delegates straight to process().
        return self.process(*args, **kwargs)

+ 13
- 27
fastNLP/core/batch.py View File

@@ -51,34 +51,20 @@ class Batch(object):
raise StopIteration
else:
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
batch_idxes = self.idx_list[self.curidx: endidx]
padding_length = {field_name: max([field_length[idx] for idx in batch_idxes])
for field_name, field_length in self.lengths.items()}
batch_x, batch_y = defaultdict(list), defaultdict(list)

# transform index to tensor and do padding for sequences
batch = []
for idx in batch_idxes:
x, y = self.dataset.to_tensor(idx, padding_length)
batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y))

if self.sort_in_batch:
batch = sorted(batch, key=lambda x: x[0], reverse=True)

for _, x, y in batch:
for name, tensor in x.items():
batch_x[name].append(tensor)
for name, tensor in y.items():
batch_y[name].append(tensor)

# combine instances to form a batch
for batch in (batch_x, batch_y):
for name, tensor_list in batch.items():
if self.use_cuda:
batch[name] = torch.stack(tensor_list, dim=0).cuda()
else:
batch[name] = torch.stack(tensor_list, dim=0)
batch_x, batch_y = {}, {}

indices = self.idx_list[self.curidx:endidx]

for field_name, field in self.dataset.get_fields():
batch = field.get(indices)
if not field.tensorable: #TODO 修改
pass
elif field.is_target:
batch_y[field_name] = batch
else:
batch_x[field_name] = batch

self.curidx = endidx

return batch_x, batch_y


Loading…
Cancel
Save