Browse Source

Fix three bugs: 1. Batchifier could yield an empty final batch (guard now requires `0 < len(batch)`); 2. the Trainer re-defined the optimizer inside every epoch instead of once before training; 3. Trainer.pad built `[fill * n]` (one element) instead of `[fill] * n` (n fill elements).

tags/v0.1.0
FengZiYjun 6 years ago
parent
commit
7557e2d065
2 changed files with 8 additions and 7 deletions
  1. +1
    -1
      fastNLP/action/action.py
  2. +7
    -6
      fastNLP/action/trainer.py

+ 1
- 1
fastNLP/action/action.py View File

@@ -67,5 +67,5 @@ class Batchifier(object):
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) < self.batch_size and self.drop_last is False:
if 0 < len(batch) < self.batch_size and self.drop_last is False:
yield batch

+ 7
- 6
fastNLP/action/trainer.py View File

@@ -1,11 +1,11 @@
import _pickle
import os
from datetime import timedelta
from time import time

import numpy as np
import torch
import torch.nn as nn
import os
from time import time
from datetime import timedelta

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -77,16 +77,17 @@ class BaseTrainer(Action):

# main training epochs
iterations = len(data_train) // self.batch_size
self.define_optimizer()

for epoch in range(1, self.n_epochs + 1):

# turn on network training mode; define optimizer; prepare batch iterator
self.mode(test=False)
self.define_optimizer()
self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))

# training iterations in one epoch
for step in range(iterations):
batch_x, batch_y = self.batchify(data_train)
batch_x, batch_y = self.batchify(data_train) # pad ?

prediction = self.data_forward(network, batch_x)

@@ -212,7 +213,7 @@ class BaseTrainer(Action):
max_length = max([len(x) for x in batch])
for idx, sample in enumerate(batch):
if len(sample) < max_length:
batch[idx] = sample + [fill * (max_length - len(sample))]
batch[idx] = sample + ([fill] * (max_length - len(sample)))
return batch

def best_eval_result(self, validator):


Loading…
Cancel
Save