From 6d0b1ea716ec70b0f17096ac77ed23a87a64cfdf Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 20 May 2019 23:10:52 +0800 Subject: [PATCH] =?UTF-8?q?1.=20=E5=AF=B9callback=E4=B8=ADindices=E6=BD=9C?= =?UTF-8?q?=E5=9C=A8None=E4=BD=9C=E5=87=BA=E6=8F=90=E7=A4=BA;2.DataSet?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=80=9A=E8=BF=87List=E8=BF=9B=E8=A1=8Cindex?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 2 +- fastNLP/core/dataset.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 9dce426b..7f5624d8 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -168,7 +168,7 @@ class Callback(object): :param dict batch_x: DataSet中被设置为input的field的batch。 :param dict batch_y: DataSet中被设置为target的field的batch。 :param list(int) indices: 这次采样使用到的indices,可以通过DataSet[indices]获取出这个batch采出的Instance,在一些 - 情况下可以帮助定位是哪个Sample导致了错误。 + 情况下可以帮助定位是哪个Sample导致了错误。仅在Trainer的prefetch为False时可用。 :return: """ pass diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index b506dfae..38623caa 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -365,6 +365,15 @@ class DataSet(object): if idx not in self: raise KeyError("No such field called {} in DataSet.".format(idx)) return self.field_arrays[idx] + elif isinstance(idx, list): + dataset = DataSet() + for i in idx: + assert isinstance(i, int), "Only int index allowed." + instance = self[i] + dataset.append(instance) + for field_name, field in self.field_arrays.items(): + dataset.field_arrays[field_name].to(field) + return dataset else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))