diff --git a/fastNLP/api/__init__.py b/fastNLP/api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py
new file mode 100644
index 00000000..b5c4cc7a
--- /dev/null
+++ b/fastNLP/api/pipeline.py
@@ -0,0 +1,23 @@
+from fastNLP.api.processor import Processor
+
+
+
+class Pipeline:
+    def __init__(self):
+        self.pipeline = []
+
+    def add_processor(self, processor):
+        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
+        processor_name = type(processor)
+        self.pipeline.append((processor_name, processor))
+
+    def process(self, dataset):
+        assert len(self.pipeline)!=0, "You need to add some processor first."
+
+        for proc_name, proc in self.pipeline:
+            dataset = proc(dataset)
+
+        return dataset
+
+    def __call__(self, *args, **kwargs):
+        return self.process(*args, **kwargs)
\ No newline at end of file
diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py
new file mode 100644
index 00000000..793cfe10
--- /dev/null
+++ b/fastNLP/api/processor.py
@@ -0,0 +1,15 @@
+
+
+class Processor:
+    def __init__(self, field_name, new_added_field_name):
+        self.field_name = field_name
+        if new_added_field_name is None:
+            self.new_added_field_name = field_name
+        else:
+            self.new_added_field_name = new_added_field_name
+
+    def process(self):
+        pass
+
+    def __call__(self, *args, **kwargs):
+        return self.process(*args, **kwargs)
\ No newline at end of file
diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index b55ae3dd..0381d267 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -51,34 +51,20 @@ class Batch(object):
             raise StopIteration
         else:
             endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            batch_idxes = self.idx_list[self.curidx: endidx]
-            padding_length = {field_name: max([field_length[idx] for idx in batch_idxes])
-                              for field_name, field_length in self.lengths.items()}
-            batch_x, batch_y = defaultdict(list), defaultdict(list)
-
-            # transform index to tensor and do padding for sequences
-            batch = []
-            for idx in batch_idxes:
-                x, y = self.dataset.to_tensor(idx, padding_length)
-                batch.append((self.lengths[self.sort_key][idx] if self.sort_in_batch else None, x, y))
-
-            if self.sort_in_batch:
-                batch = sorted(batch, key=lambda x: x[0], reverse=True)
-
-            for _, x, y in batch:
-                for name, tensor in x.items():
-                    batch_x[name].append(tensor)
-                for name, tensor in y.items():
-                    batch_y[name].append(tensor)
-
-            # combine instances to form a batch
-            for batch in (batch_x, batch_y):
-                for name, tensor_list in batch.items():
-                    if self.use_cuda:
-                        batch[name] = torch.stack(tensor_list, dim=0).cuda()
-                    else:
-                        batch[name] = torch.stack(tensor_list, dim=0)
+            batch_x, batch_y = {}, {}
+
+            indices = self.idx_list[self.curidx:endidx]
+
+            for field_name, field in self.dataset.get_fields():
+                batch = field.get(indices)
+                if not field.tensorable: #TODO 修改
+                    pass
+                elif field.is_target:
+                    batch_y[field_name] = batch
+                else:
+                    batch_x[field_name] = batch
 
             self.curidx = endidx
+
             return batch_x, batch_y
 