diff --git a/fastNLP/core/dataloaders/utils.py b/fastNLP/core/dataloaders/utils.py index 06f09da3..4f8fa743 100644 --- a/fastNLP/core/dataloaders/utils.py +++ b/fastNLP/core/dataloaders/utils.py @@ -118,17 +118,22 @@ class OverfitDataLoader: 实现一个简单的迭代器来模拟实际的 dataloader,从给定的 dataloader 中取出部分数据,来让 Trainer 实现 overfit 的功能; """ - def __init__(self, dataloader, overfit_batches: int): + def __init__(self, dataloader, overfit_batches: int, batches=None): + # batches 参数是给重新初始化dataloader使用的 self.dataloader = dataloader # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作; - self.batches = [] - self.overfit_batches = int(overfit_batches) - - if self.overfit_batches > len(dataloader): - logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.") - - for idx, batch in enumerate(dataloader): - if idx < self.overfit_batches or self.overfit_batches <= -1: - self.batches.append(batch) + if batches is None: + self.batches = [] + self.overfit_batches = int(overfit_batches) + + if self.overfit_batches > len(dataloader): + logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.") + + for idx, batch in enumerate(dataloader): + if idx < self.overfit_batches or self.overfit_batches <= -1: + self.batches.append(batch) + else: + assert isinstance(batches, list) + self.batches = batches def __len__(self): return len(self.batches) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index fff8b5c2..0238a65d 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -445,7 +445,7 @@ class DataSet: "DataSet object has {} fields, but attempt to append an Instance object with {} fields." .format(len(self.field_arrays), len(instance.fields))) for name, field in instance.items(): - assert name in self.field_arrays + assert name in self.field_arrays, f'Field:`{name}` is not found in {self.field_arrays.keys()}' try: self.field_arrays[name].append(field) except Exception as e: