Browse Source

add dataset support for sampler, update batch

tags/v0.2.0
yunfan 5 years ago
parent
commit
ff6d99bcb2
2 changed files with 9 additions and 2 deletions
  1. +2
    -2
      fastNLP/core/batch.py
  2. +7
    -0
      fastNLP/core/dataset.py

+ 2
- 2
fastNLP/core/batch.py View File

@@ -56,8 +56,8 @@ class Batch(object):
indices = self.idx_list[self.curidx:endidx]

for field_name, field in self.dataset.get_fields():
batch = field.get(indices)
if not field.tensorable: #TODO 修改
batch = torch.from_numpy(field.get(indices))
if not field.need_tensor: #TODO 修改
pass
elif field.is_target:
batch_y[field_name] = batch


+ 7
- 0
fastNLP/core/dataset.py View File

@@ -40,6 +40,13 @@ class DataSet(object):
assert name in self.field_arrays
self.field_arrays[name].append(field)

def get_fields(self):
return self.field_arrays

def __len__(self):
field = self.field_arrays.values()[0]
return len(field)

def get_length(self):
"""Fetch lengths of all fields in all instances in a dataset.



Loading…
Cancel
Save