|
|
@@ -118,17 +118,22 @@ class OverfitDataLoader: |
|
|
|
实现一个简单的迭代器来模拟实际的 dataloader,从给定的 dataloader 中取出部分数据,来让 Trainer 实现 overfit 的功能; |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dataloader, overfit_batches: int): |
|
|
|
def __init__(self, dataloader, overfit_batches: int, batches=None): |
|
|
|
# batches 参数是给重新初始化dataloader使用的 |
|
|
|
self.dataloader = dataloader # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作; |
|
|
|
self.batches = [] |
|
|
|
self.overfit_batches = int(overfit_batches) |
|
|
|
|
|
|
|
if self.overfit_batches > len(dataloader): |
|
|
|
logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.") |
|
|
|
|
|
|
|
for idx, batch in enumerate(dataloader): |
|
|
|
if idx < self.overfit_batches or self.overfit_batches <= -1: |
|
|
|
self.batches.append(batch) |
|
|
|
if batches is None: |
|
|
|
self.batches = [] |
|
|
|
self.overfit_batches = int(overfit_batches) |
|
|
|
|
|
|
|
if self.overfit_batches > len(dataloader): |
|
|
|
logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.") |
|
|
|
|
|
|
|
for idx, batch in enumerate(dataloader): |
|
|
|
if idx < self.overfit_batches or self.overfit_batches <= -1: |
|
|
|
self.batches.append(batch) |
|
|
|
else: |
|
|
|
assert isinstance(batches, list) |
|
|
|
self.batches = batches |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.batches) |
|
|
|