From eb0e563fecf8c29949d58a4c4374269ea2e3be2a Mon Sep 17 00:00:00 2001
From: yhcc
Date: Wed, 22 Jun 2022 16:34:41 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8DOverfitBatches=E8=A2=AB?=
 =?UTF-8?q?=E6=9B=BF=E6=8D=A2=E7=9A=84=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/dataloaders/utils.py | 25 +++++++++++++++----------
 fastNLP/core/dataset/dataset.py   |  2 +-
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py
index 06f09da3..4f8fa743 100644
--- a/fastNLP/core/dataloaders/utils.py
+++ b/fastNLP/core/dataloaders/utils.py
@@ -118,17 +118,22 @@ class OverfitDataLoader:
     实现一个简单的迭代器来模拟实际的 dataloader,从给定的 dataloader 中取出部分数据,来让 Trainer 实现 overfit 的功能;
     """
 
-    def __init__(self, dataloader, overfit_batches: int):
+    def __init__(self, dataloader, overfit_batches: int, batches=None):
+        # batches 参数是给重新初始化dataloader使用的
         self.dataloader = dataloader  # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作;
-        self.batches = []
-        self.overfit_batches = int(overfit_batches)
-
-        if self.overfit_batches > len(dataloader):
-            logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
-
-        for idx, batch in enumerate(dataloader):
-            if idx < self.overfit_batches or self.overfit_batches <= -1:
-                self.batches.append(batch)
+        if batches is None:
+            self.batches = []
+            self.overfit_batches = int(overfit_batches)
+
+            if self.overfit_batches > len(dataloader):
+                logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
+
+            for idx, batch in enumerate(dataloader):
+                if idx < self.overfit_batches or self.overfit_batches <= -1:
+                    self.batches.append(batch)
+        else:
+            assert isinstance(batches, list)
+            self.batches = batches
 
     def __len__(self):
         return len(self.batches)
diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
index fff8b5c2..0238a65d 100644
--- a/fastNLP/core/dataset/dataset.py
+++ b/fastNLP/core/dataset/dataset.py
@@ -445,7 +445,7 @@ class DataSet:
                     "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
                         .format(len(self.field_arrays), len(instance.fields)))
             for name, field in instance.items():
-                assert name in self.field_arrays
+                assert name in self.field_arrays, f'Field:`{name}` is not found in {self.field_arrays.keys()}'
                 try:
                     self.field_arrays[name].append(field)
                 except Exception as e: