From ff6d99bcb2699170e5fbec1db8ab52911b0e58be Mon Sep 17 00:00:00 2001
From: yunfan
Date: Fri, 9 Nov 2018 20:12:06 +0800
Subject: [PATCH] add dataset support for sampler, update batch

---
Review notes (this section is ignored by `git am`):
- DataSet.__len__ no longer indexes `dict.values()`, which raises TypeError
  on Python 3; it also returns 0 when there are no field arrays.
- Batch iterates `get_fields().items()`, because get_fields() returns the
  field mapping itself and iterating it directly yields only the names.
- NOTE(review): torch.from_numpy() still runs before the need_tensor check;
  presumably it should only run for fields that need a tensor -- confirm
  against the rest of the method, which is outside this hunk.
- Leading whitespace was reconstructed and the post-image ids on the
  `index` lines predate these adjustments; regenerate before applying.

 fastNLP/core/batch.py   |  6 +++---
 fastNLP/core/dataset.py | 13 +++++++++++++
 2 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 0381d267..397a3ddb 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -56,8 +56,8 @@ class Batch(object):
             indices = self.idx_list[self.curidx:endidx]
 
-            for field_name, field in self.dataset.get_fields():
-                batch = field.get(indices)
-                if not field.tensorable: #TODO 修改
+            for field_name, field in self.dataset.get_fields().items():
+                batch = torch.from_numpy(field.get(indices))
+                if not field.need_tensor:  # TODO: revise
                     pass
                 elif field.is_target:
                     batch_y[field_name] = batch
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index a08a429c..e626ff26 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -40,6 +40,19 @@ class DataSet(object):
         assert name in self.field_arrays
         self.field_arrays[name].append(field)
 
+    def get_fields(self):
+        """Return the mapping that holds this dataset's field arrays."""
+        return self.field_arrays
+
+    def __len__(self):
+        """Return the length of the first field array, or 0 if there are none."""
+        if not self.field_arrays:
+            return 0
+        # dict views cannot be indexed on Python 3, so take the first value
+        # through an iterator instead of ``.values()[0]``.
+        first_field = next(iter(self.field_arrays.values()))
+        return len(first_field)
+
     def get_length(self):
         """Fetch lengths of all fields in all instances in a dataset.
 