Browse Source

1. 对callback中indices潜在None作出提示;2.DataSet支持通过List进行index

tags/v0.4.10
yh_cc 6 years ago
parent
commit
6d0b1ea716
2 changed files with 10 additions and 1 deletions
  1. +1
    -1
      fastNLP/core/callback.py
  2. +9
    -0
      fastNLP/core/dataset.py

+ 1
- 1
fastNLP/core/callback.py View File

@@ -168,7 +168,7 @@ class Callback(object):
:param dict batch_x: DataSet中被设置为input的field的batch。 :param dict batch_x: DataSet中被设置为input的field的batch。
:param dict batch_y: DataSet中被设置为target的field的batch。 :param dict batch_y: DataSet中被设置为target的field的batch。
:param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 :param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些
情况下可以帮助定位是哪个Sample导致了错误。
情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。
:return: :return:
""" """
pass pass


+ 9
- 0
fastNLP/core/dataset.py View File

@@ -365,6 +365,15 @@ class DataSet(object):
if idx not in self: if idx not in self:
raise KeyError("No such field called {} in DataSet.".format(idx)) raise KeyError("No such field called {} in DataSet.".format(idx))
return self.field_arrays[idx] return self.field_arrays[idx]
elif isinstance(idx, list):
dataset = DataSet()
for i in idx:
assert isinstance(i, int), "Only int index allowed."
instance = self[i]
dataset.append(instance)
for field_name, field in self.field_arrays.items():
dataset.field_arrays[field_name].to(field)
return dataset
else: else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))


Loading…
Cancel
Save