@@ -1,5 +1,9 @@
 import time
-from datetime import timedelta, datetime
+from datetime import timedelta
+from datetime import datetime
+import warnings
+from collections import defaultdict
 
 import torch
 from tensorboardX import SummaryWriter
@@ -12,13 +16,17 @@ from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import _check_arg_dict_list
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _syn_model_data
+from fastNLP.core.utils import get_func_signature
 
 
 class Trainer(object):
     """Main Training Loop
 
     """
 
-    def __init__(self, train_data, model, n_epochs, batch_size, n_print,
+    def __init__(self, train_data, model, n_epochs=1, batch_size=32, print_every=-1,
                  dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
                  optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
                  evaluator=Evaluator(),
@@ -32,7 +40,7 @@ class Trainer(object):
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
         self.save_path = str(save_path)
-        self.n_print = int(n_print)
+        self.print_every = int(print_every)
 
         self.loss_func = self.model.loss if hasattr(self.model, "loss") else loss.get()
         self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())
@@ -51,7 +59,7 @@ class Trainer(object):
         self.step = 0
         self.start_time = None  # start timestamp
 
-        print(self.__dict__)
+        # print(self.__dict__)
 
     def train(self):
         """Start Training.
@@ -70,17 +78,16 @@ class Trainer(object):
         epoch = 1
         while epoch <= self.n_epochs:
 
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
-                                  use_cuda=self.use_cuda)
+            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler())
 
-            self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start, self.n_print)
+            self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
 
             if self.dev_data:
                 self.do_validation()
             self.save_model(self.model, 'training_model_' + self.start_time)
             epoch += 1
 
-    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, n_print, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
         """Training process in one epoch.
 
         kwargs should contain:
@@ -103,7 +110,7 @@ class Trainer(object):
                     self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-            if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
+            if self.print_every > 0 and self.step % self.print_every == 0:
                 end = time.time()
                 diff = timedelta(seconds=round(end - kwargs["start"]))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
@@ -197,9 +204,6 @@ def best_eval_result(self, metrics):
         return False
 
 
-from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _build_args
-
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
 
@@ -207,64 +211,209 @@ IGNORE_CHECK_LEVEL=0
 WARNING_CHECK_LEVEL=1
 STRICT_CHECK_LEVEL=2
 
 
-def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
-    # check the loss method
+def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
+    # check the get_loss method
+    model_name = model.__class__.__name__
     if not hasattr(model, 'get_loss'):
-        raise AttributeError("{} has to have a 'get_loss' function.".format(type(model)))
+        raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))
 
     batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
+        _syn_model_data(model, batch_x, batch_y)
+        # forward check
         if batch_count==0:
-            check_res = _check_arg_dict_list(model.forward, batch_x)
-            _info_str = ''
-            if len(check_res.missing)>0:
-                if check_level == WARNING_CHECK_LEVEL:
-                    for field_name in check_res.missing:
-                        if hasattr(dataset, field_name):
-                            _info_str += "{} "
-                    _info_str += "Missing argument: [{}] needed by '{}.forward' is not presented in the input.\n"
-                    _info_str += ""
-                    print("")
-            if len(check_res.unused)>0:
-                if check_level == WARNING_CHECK_LEVEL:
-                    _info_str += ""
+            _check_forward_error(model=model, model_func=model.forward, check_level=check_level,
+                                 batch_x=batch_x)
 
         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
+        assert isinstance(output, dict), "The return value of {}.forward() should be dict.".format(model_name)
 
+        # loss check
         if batch_count == 0:
-            _dict = _check_arg_dict_list(model.loss, [output, batch_y])
-            if len(_dict)!=0:
-                pass
-        loss_input = _build_args(model.loss, **output, **batch_y)
-        loss = model.loss(**loss_input)
-        if batch_count == 0:
-            if isinstance(loss, torch.Tensor):
-                pass
+            _check_loss_evaluate(model=model, model_func=model.get_loss, check_level=check_level,
+                                 output=output, batch_y=batch_y)
+        loss_input = _build_args(model.get_loss, **output, **batch_y)
+        loss = model.get_loss(**loss_input)
 
+        # check loss output
+        if batch_count == 0:
+            if not isinstance(loss, torch.Tensor):
+                raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but got {}.".
+                                 format(model_name, type(loss)))
+            if len(loss.size())!=0:
+                raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.Size([])".format(
+                    model_name, loss.size()))
         loss.backward()
-        if batch_count+1>=DEFAULT_CHECK_BATCH_SIZE:
+        model.zero_grad()
+        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
             break
+    if check_level > IGNORE_CHECK_LEVEL:
+        print('Finish checking training process.', flush=True)
 
-    dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     if dev_data is not None:
         if not hasattr(model, 'evaluate'):
-            raise AttributeError("If {} wants to do evaluation, {} has to have a 'evaluate' function. Or you can set"
-                                 "dev_data to 'None'."
-                                 .format(type(model), type(model)))
-        for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
-            refined_batch_x = _build_args(model.forward, **batch_x)
-            output = model(**refined_batch_x)
-            if batch_count == 0:
-                _dict = _check_arg_dict_list(model.evaluate, [output, batch_y])
-                if len(_dict)!=0:
-                    pass
+            raise AttributeError("{} has to have an 'evaluate' function to do evaluation. Or set "
+                                 "dev_data to 'None'."
+                                 .format(model_name))
+        outputs, truths = defaultdict(list), defaultdict(list)
+        dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
+        with torch.no_grad():
+            for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
+                _syn_model_data(model, batch_x, batch_y)
+                refined_batch_x = _build_args(model.forward, **batch_x)
+                output = model(**refined_batch_x)
+                for k, v in output.items():
+                    outputs[k].append(v)
+                for k, v in batch_y.items():
+                    truths[k].append(v)
+                if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
+                    break
+            _check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level,
+                                 output=outputs, batch_y=truths)
+        print("Finish checking evaluate process.", flush=True)
+
+def _check_forward_error(model, model_func, check_level, batch_x):
+    check_res = _check_arg_dict_list(model_func, batch_x)
+    _missing = ''
+    _unused = ''
+    signature_str = get_func_signature(model_func)
+    func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
+    if len(check_res.missing)!=0:
+        _missing = "Function {} misses {}, only provided with {}.\n".format(
+            func_signature, check_res.missing, list(batch_x.keys()))
+    if len(check_res.unused)!=0:
+        if len(check_res.unused) > 1:
+            _unused = "{} are not used ".format(check_res.unused)
+        else:
+            _unused = "{} is not used ".format(check_res.unused)
+        _unused += "in function {}.\n".format(func_signature)
+    if _missing:
+        if _unused and check_level == STRICT_CHECK_LEVEL:
+            _error_str = "(1).{} (2).{}".format(_missing, _unused)
+        else:
+            _error_str = _missing
+        # TODO: custom Error types may be needed here
+        raise TypeError(_error_str)
+    if _unused:
+        if check_level == STRICT_CHECK_LEVEL:
+            # TODO: custom Error types may be needed here
+            raise ValueError(_unused)
+        elif check_level == WARNING_CHECK_LEVEL:
+            warnings.warn(message=_unused)
+
+
+def _check_loss_evaluate(model, model_func, check_level, output, batch_y):
+    check_res = _check_arg_dict_list(model_func, [output, batch_y])
+    _missing = ''
+    _unused = ''
+    _duplicated = ''
+    signature_str = get_func_signature(model_func)
+    func_signature = "{}.{}(self, {})".format(model.__class__.__name__, model_func.__name__, signature_str[1:-1])
+    forward_func_signature = "{}.forward(self, {})".format(model.__class__.__name__, signature_str[1:-1])
+    model_name = model.__class__.__name__
+    if len(check_res.missing)>0:
+        _missing = "Function {} misses argument {}, only provided with {} (from {}) and " \
+                   "{}." \
+            .format(func_signature, check_res.missing,
+                    list(output.keys()), model_name,
+                    list(batch_y.keys()))
+    if len(check_res.unused)>0:
+        if len(check_res.unused) > 1:
+            _unused = "{} are not used ".format(check_res.unused)
+        else:
+            _unused = "{} is not used ".format(check_res.unused)
+        _unused += "in function {}.\n".format(func_signature)
+    if len(check_res.duplicated)>0:
+        if len(check_res.duplicated) > 1:
+            _duplicated = "Duplicated keys: {} are detected in function {}. Don't set {} as target and output " \
+                          "them in {} at the same time.\n".format(check_res.duplicated,
+                                                                  func_signature,
+                                                                  check_res.duplicated,
+                                                                  forward_func_signature)
+        else:
+            _duplicated = "Duplicated key: {} is detected in function {}. Don't set {} as target and output " \
+                          "it in {} at the same time.\n".format(check_res.duplicated,
+                                                                func_signature,
+                                                                check_res.duplicated,
+                                                                forward_func_signature)
+    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
+    if _number_errs > 0:
+        _error_str = ''
+        if _number_errs > 1:
+            count = 1
+            if _missing:
+                _error_str += '({}).{}'.format(count, _missing)
+                count += 1
+            if _duplicated:
+                _error_str += '({}).{}'.format(count, _duplicated)
+                count += 1
+            if _unused and check_level == STRICT_CHECK_LEVEL:
+                _error_str += '({}).{}'.format(count, _unused)
+        else:
+            if _unused:
+                if check_level == STRICT_CHECK_LEVEL:
+                    # TODO: custom Error types may be needed here
+                    _error_str = _unused
+                elif check_level == WARNING_CHECK_LEVEL:
+                    _unused = _unused.strip()
+                    warnings.warn(_unused)
+            else:
+                _error_str = _missing + _duplicated
+        if _error_str:
+            raise ValueError(_error_str)
+
+
+if __name__ == '__main__':
+    import torch
+    from torch import nn
+    from fastNLP.core.dataset import DataSet
+    import numpy as np
+
+    class Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(10, 2)
+
+        def forward(self, words, chars):
+            output = {}
+            output['prediction'] = torch.randn(3, 4)
+            output['words'] = words
+            return output
+
+        def get_loss(self, prediction, labels, words):
+            return torch.mean(self.fc1.weight)
+
+        def evaluate(self, prediction, labels, demo=2):
+            return 0
+
+    model = Model()
+
+    num_samples = 4
+    fake_data_dict = {'words': np.random.randint(num_samples, size=(4, 3)), 'chars': np.random.randn(num_samples, 6),
+                      'labels': np.random.randint(2, size=(num_samples,))}
+
+    dataset = DataSet(fake_data_dict)
+    dataset.set_input(words=True, chars=True)
+    dataset.set_target(labels=True)
+
+    # trainer = Trainer(dataset, model)
+
+    _check_code(dataset=dataset, model=model, dev_data=dataset, check_level=2)
+
+    # _check_forward_error(model=model, model_func=model.forward, check_level=1,
+    #                      batch_x=fake_data_dict)
+
+    # import inspect
+    # print(inspect.getfullargspec(model.forward))
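Note: the checkers in this diff lean entirely on fastNLP's _check_arg_dict_list, which is imported but never shown here. As a rough sketch of what it presumably computes, the stand-in below compares a function's signature against the keys offered by one or more dicts. The check_arg_dict_list name, the CheckRes container, and its missing/unused/duplicated fields are assumptions inferred from how _check_forward_error and _check_loss_evaluate consume the result, not taken from the fastNLP source.

# Hypothetical sketch, not the actual fastNLP implementation.
import inspect
from collections import namedtuple

# Assumed result shape: the checkers above read .missing, .unused and .duplicated.
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated'])


def check_arg_dict_list(func, args_dicts):
    """Compare func's parameters with the keys offered by one or more dicts."""
    if isinstance(args_dicts, dict):
        args_dicts = [args_dicts]
    spec = inspect.getfullargspec(func)
    params = [a for a in spec.args if a != 'self']
    n_defaults = len(spec.defaults) if spec.defaults else 0
    required = params[:len(params) - n_defaults]  # parameters without default values
    provided = [k for d in args_dicts for k in d]
    missing = [a for a in required if a not in provided]              # needed but absent
    unused = [k for k in provided if k not in params]                 # offered but never consumed
    duplicated = [k for k in set(provided) if provided.count(k) > 1]  # e.g. a forward output that is also a target
    return CheckRes(missing=missing, unused=unused, duplicated=duplicated)

With the toy Model from the __main__ block, check_arg_dict_list(model.evaluate, [output, batch_y]) would report 'words' as unused and nothing missing or duplicated, which is exactly the shape of result the STRICT_CHECK_LEVEL/WARNING_CHECK_LEVEL branches switch on.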