Browse Source

修复OverfitBatches被替换的问题

tags/v1.0.0alpha
yhcc 2 years ago
parent
commit
eb0e563fec
2 changed files with 16 additions and 11 deletions
  1. +15
    -10
      fastNLP/core/dataloaders/utils.py
  2. +1
    -1
      fastNLP/core/dataset/dataset.py

+ 15
- 10
fastNLP/core/dataloaders/utils.py View File

@@ -118,17 +118,22 @@ class OverfitDataLoader:
     实现一个简单的迭代器来模拟实际的 dataloader,从给定的 dataloader 中取出部分数据,来让 Trainer 实现 overfit 的功能;
     """


-    def __init__(self, dataloader, overfit_batches: int):
+    def __init__(self, dataloader, overfit_batches: int, batches=None):
+        # batches 参数是给重新初始化dataloader使用的
         self.dataloader = dataloader # 需要将实际的 dataloader 挂载到该对象上,从而应付一些对于实际的 dataloader 的操作;
-        self.batches = []
-        self.overfit_batches = int(overfit_batches)
-
-        if self.overfit_batches > len(dataloader):
-            logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
-
-        for idx, batch in enumerate(dataloader):
-            if idx < self.overfit_batches or self.overfit_batches <= -1:
-                self.batches.append(batch)
+        if batches is None:
+            self.batches = []
+            self.overfit_batches = int(overfit_batches)
+
+            if self.overfit_batches > len(dataloader):
+                logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")
+
+            for idx, batch in enumerate(dataloader):
+                if idx < self.overfit_batches or self.overfit_batches <= -1:
+                    self.batches.append(batch)
+        else:
+            assert isinstance(batches, list)
+            self.batches = batches


     def __len__(self):
         return len(self.batches)


+ 1
- 1
fastNLP/core/dataset/dataset.py View File

@@ -445,7 +445,7 @@ class DataSet:
 "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
 .format(len(self.field_arrays), len(instance.fields)))
 for name, field in instance.items():
-    assert name in self.field_arrays
+    assert name in self.field_arrays, f'Field:`{name}` is not found in {self.field_arrays.keys()}'
     try:
         self.field_arrays[name].append(field)
     except Exception as e:


Loading…
Cancel
Save