diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 0381d267..397a3ddb 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -56,8 +56,8 @@ class Batch(object): indices = self.idx_list[self.curidx:endidx] for field_name, field in self.dataset.get_fields(): - batch = field.get(indices) - if not field.tensorable: #TODO 修改 + batch = torch.from_numpy(field.get(indices)) + if not field.need_tensor: #TODO 修改 pass elif field.is_target: batch_y[field_name] = batch diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index a08a429c..e626ff26 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -40,6 +40,13 @@ class DataSet(object): assert name in self.field_arrays self.field_arrays[name].append(field) + def get_fields(self): + return self.field_arrays + + def __len__(self): + field = self.field_arrays.values()[0] + return len(field) + def get_length(self): """Fetch lengths of all fields in all instances in a dataset.