diff --git a/fastNLP/action/action.py b/fastNLP/action/action.py
index ea12a37e..2bc08b75 100644
--- a/fastNLP/action/action.py
+++ b/fastNLP/action/action.py
@@ -67,5 +67,5 @@ class Batchifier(object):
             if len(batch) == self.batch_size:
                 yield batch
                 batch = []
-        if len(batch) < self.batch_size and self.drop_last is False:
+        if 0 < len(batch) < self.batch_size and self.drop_last is False:
             yield batch
diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py
index e9c2b7bc..45049e9d 100644
--- a/fastNLP/action/trainer.py
+++ b/fastNLP/action/trainer.py
@@ -1,11 +1,11 @@
 import _pickle
+import os
+from datetime import timedelta
+from time import time
 
 import numpy as np
 import torch
 import torch.nn as nn
-import os
-from time import time
-from datetime import timedelta
 
 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
@@ -77,16 +77,17 @@ class BaseTrainer(Action):
 
         # main training epochs
         iterations = len(data_train) // self.batch_size
+        self.define_optimizer()
+
         for epoch in range(1, self.n_epochs + 1):
 
             # turn on network training mode; define optimizer; prepare batch iterator
             self.mode(test=False)
-            self.define_optimizer()
             self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))
 
             # training iterations in one epoch
             for step in range(iterations):
-                batch_x, batch_y = self.batchify(data_train)
+                batch_x, batch_y = self.batchify(data_train)  # pad ?
 
                 prediction = self.data_forward(network, batch_x)
 
@@ -212,7 +213,7 @@ class BaseTrainer(Action):
         max_length = max([len(x) for x in batch])
         for idx, sample in enumerate(batch):
             if len(sample) < max_length:
-                batch[idx] = sample + [fill * (max_length - len(sample))]
+                batch[idx] = sample + ([fill] * (max_length - len(sample)))
         return batch
 
     def best_eval_result(self, validator):