diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 9dce426b..7f5624d8 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -168,7 +168,7 @@ class Callback(object): :param dict batch_x: DataSet中被设置为input的field的batch。 :param dict batch_y: DataSet中被设置为target的field的batch。 :param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 - 情况下可以帮助定位是哪个Sample导致了错误。 + 情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。 :return: """ pass diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index b506dfae..38623caa 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -365,6 +365,15 @@ class DataSet(object): if idx not in self: raise KeyError("No such field called {} in DataSet.".format(idx)) return self.field_arrays[idx] + elif isinstance(idx, list): + dataset = DataSet() + for i in idx: + assert isinstance(i, int), "Only int index allowed." + instance = self[i] + dataset.append(instance) + for field_name, field in self.field_arrays.items(): + dataset.field_arrays[field_name].to(field) + return dataset else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))