Browse Source

fix bugs: 1. action batchify 2. Trainer optimizer position 3. Trainer.pad

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
7557e2d065
2 changed files with 8 additions and 7 deletions
  1. +1
    -1
      fastNLP/action/action.py
  2. +7
    -6
      fastNLP/action/trainer.py

+ 1
- 1
fastNLP/action/action.py View File

@@ -67,5 +67,5 @@ class Batchifier(object):
         if len(batch) == self.batch_size:
             yield batch
             batch = []
-    if len(batch) < self.batch_size and self.drop_last is False:
+    if 0 < len(batch) < self.batch_size and self.drop_last is False:
         yield batch

+ 7
- 6
fastNLP/action/trainer.py View File

@@ -1,11 +1,11 @@
 import _pickle
+import os
+from datetime import timedelta
+from time import time

 import numpy as np
 import torch
 import torch.nn as nn
-import os
-from time import time
-from datetime import timedelta

 from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
@@ -77,16 +77,17 @@ class BaseTrainer(Action):


# main training epochs # main training epochs
iterations = len(data_train) // self.batch_size iterations = len(data_train) // self.batch_size
self.define_optimizer()

for epoch in range(1, self.n_epochs + 1): for epoch in range(1, self.n_epochs + 1):


# turn on network training mode; define optimizer; prepare batch iterator # turn on network training mode; define optimizer; prepare batch iterator
self.mode(test=False) self.mode(test=False)
self.define_optimizer()
self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))


# training iterations in one epoch # training iterations in one epoch
for step in range(iterations): for step in range(iterations):
batch_x, batch_y = self.batchify(data_train)
batch_x, batch_y = self.batchify(data_train) # pad ?


prediction = self.data_forward(network, batch_x) prediction = self.data_forward(network, batch_x)


@@ -212,7 +213,7 @@ class BaseTrainer(Action):
         max_length = max([len(x) for x in batch])
         for idx, sample in enumerate(batch):
             if len(sample) < max_length:
-                batch[idx] = sample + [fill * (max_length - len(sample))]
+                batch[idx] = sample + ([fill] * (max_length - len(sample)))
         return batch


     def best_eval_result(self, validator):


Loading…
Cancel
Save