|
@@ -1,11 +1,11 @@ |
|
|
import _pickle |
|
|
import _pickle |
|
|
|
|
|
import os |
|
|
|
|
|
from datetime import timedelta |
|
|
|
|
|
from time import time |
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn as nn |
|
|
import os |
|
|
|
|
|
from time import time |
|
|
|
|
|
from datetime import timedelta |
|
|
|
|
|
|
|
|
|
|
|
from fastNLP.action.action import Action |
|
|
from fastNLP.action.action import Action |
|
|
from fastNLP.action.action import RandomSampler, Batchifier |
|
|
from fastNLP.action.action import RandomSampler, Batchifier |
|
@@ -77,16 +77,17 @@ class BaseTrainer(Action): |
|
|
|
|
|
|
|
|
# main training epochs |
|
|
# main training epochs |
|
|
iterations = len(data_train) // self.batch_size |
|
|
iterations = len(data_train) // self.batch_size |
|
|
|
|
|
self.define_optimizer() |
|
|
|
|
|
|
|
|
for epoch in range(1, self.n_epochs + 1): |
|
|
for epoch in range(1, self.n_epochs + 1): |
|
|
|
|
|
|
|
|
# turn on network training mode; define optimizer; prepare batch iterator |
|
|
# turn on network training mode; define optimizer; prepare batch iterator |
|
|
self.mode(test=False) |
|
|
self.mode(test=False) |
|
|
self.define_optimizer() |
|
|
|
|
|
self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) |
|
|
self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) |
|
|
|
|
|
|
|
|
# training iterations in one epoch |
|
|
# training iterations in one epoch |
|
|
for step in range(iterations): |
|
|
for step in range(iterations): |
|
|
batch_x, batch_y = self.batchify(data_train) |
|
|
|
|
|
|
|
|
batch_x, batch_y = self.batchify(data_train) # pad ? |
|
|
|
|
|
|
|
|
prediction = self.data_forward(network, batch_x) |
|
|
prediction = self.data_forward(network, batch_x) |
|
|
|
|
|
|
|
@@ -212,7 +213,7 @@ class BaseTrainer(Action): |
|
|
max_length = max([len(x) for x in batch]) |
|
|
max_length = max([len(x) for x in batch]) |
|
|
for idx, sample in enumerate(batch): |
|
|
for idx, sample in enumerate(batch): |
|
|
if len(sample) < max_length: |
|
|
if len(sample) < max_length: |
|
|
batch[idx] = sample + [fill * (max_length - len(sample))] |
|
|
|
|
|
|
|
|
batch[idx] = sample + ([fill] * (max_length - len(sample))) |
|
|
return batch |
|
|
return batch |
|
|
|
|
|
|
|
|
def best_eval_result(self, validator): |
|
|
def best_eval_result(self, validator): |
|
|