|
@@ -9,6 +9,7 @@ import shutil |
|
|
|
|
|
|
|
|
from tensorboardX import SummaryWriter |
|
|
from tensorboardX import SummaryWriter |
|
|
import torch |
|
|
import torch |
|
|
|
|
|
from torch import nn |
|
|
|
|
|
|
|
|
from fastNLP.core.batch import Batch |
|
|
from fastNLP.core.batch import Batch |
|
|
from fastNLP.core.loss import Loss |
|
|
from fastNLP.core.loss import Loss |
|
@@ -21,12 +22,13 @@ from fastNLP.core.utils import _check_arg_dict_list |
|
|
from fastNLP.core.utils import _build_args |
|
|
from fastNLP.core.utils import _build_args |
|
|
from fastNLP.core.utils import _syn_model_data |
|
|
from fastNLP.core.utils import _syn_model_data |
|
|
from fastNLP.core.utils import get_func_signature |
|
|
from fastNLP.core.utils import get_func_signature |
|
|
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
|
|
|
|
|
class Trainer(object): |
|
|
class Trainer(object): |
|
|
"""Main Training Loop |
|
|
"""Main Training Loop |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, |
|
|
|
|
|
|
|
|
def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, |
|
|
dev_data=None, use_cuda=False, save_path="./save", |
|
|
dev_data=None, use_cuda=False, save_path="./save", |
|
|
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, |
|
|
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, |
|
|
**kwargs): |
|
|
**kwargs): |
|
@@ -35,6 +37,8 @@ class Trainer(object): |
|
|
self.train_data = train_data |
|
|
self.train_data = train_data |
|
|
self.dev_data = dev_data # If None, No validation. |
|
|
self.dev_data = dev_data # If None, No validation. |
|
|
self.model = model |
|
|
self.model = model |
|
|
|
|
|
self.losser = losser |
|
|
|
|
|
self.metrics = metrics |
|
|
self.n_epochs = int(n_epochs) |
|
|
self.n_epochs = int(n_epochs) |
|
|
self.batch_size = int(batch_size) |
|
|
self.batch_size = int(batch_size) |
|
|
self.use_cuda = bool(use_cuda) |
|
|
self.use_cuda = bool(use_cuda) |
|
@@ -43,23 +47,22 @@ class Trainer(object): |
|
|
self.validate_every = int(validate_every) |
|
|
self.validate_every = int(validate_every) |
|
|
self._best_accuracy = 0 |
|
|
self._best_accuracy = 0 |
|
|
|
|
|
|
|
|
if need_check_code: |
|
|
|
|
|
_check_code(dataset=train_data, model=model, dev_data=dev_data) |
|
|
|
|
|
|
|
|
|
|
|
model_name = model.__class__.__name__ |
|
|
|
|
|
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name) |
|
|
|
|
|
self.loss_func = self.model.get_loss |
|
|
|
|
|
|
|
|
# TODO check loss与metrics的类型 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# TODO self._best_accuracy不能表现出当前的metric多种的情况 |
|
|
|
|
|
|
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
if isinstance(optimizer, torch.optim.Optimizer): |
|
|
self.optimizer = optimizer |
|
|
self.optimizer = optimizer |
|
|
else: |
|
|
else: |
|
|
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) |
|
|
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) |
|
|
|
|
|
|
|
|
assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name) |
|
|
|
|
|
self.evaluator = self.model.evaluate |
|
|
|
|
|
|
|
|
|
|
|
if self.dev_data is not None: |
|
|
if self.dev_data is not None: |
|
|
self.tester = Tester(model=self.model, |
|
|
self.tester = Tester(model=self.model, |
|
|
data=self.dev_data, |
|
|
data=self.dev_data, |
|
|
|
|
|
metrics=self.metrics, |
|
|
batch_size=self.batch_size, |
|
|
batch_size=self.batch_size, |
|
|
use_cuda=self.use_cuda) |
|
|
use_cuda=self.use_cuda) |
|
|
|
|
|
|
|
@@ -71,6 +74,38 @@ class Trainer(object): |
|
|
|
|
|
|
|
|
# print(self.__dict__) |
|
|
# print(self.__dict__) |
|
|
|
|
|
|
|
|
|
|
|
def _check_params(self, train_data, model, losser, metrics=[], n_epochs=3, batch_size=32, print_every=-1,
                  validate_every=-1, dev_data=None, use_cuda=False, save_path="./save",
                  optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
                  **kwargs):
    """Validate the arguments used to set up training.

    Only ``train_data``, ``model``, ``losser``, ``metrics``, ``dev_data`` and
    ``need_check_code`` are inspected. The remaining parameters are accepted
    but not checked here.

    :param train_data: training set; must be a ``DataSet``.
    :param model: network to train; must be a ``torch.nn.Module``.
    :param losser: ``None``, a ``Loss`` instance, or a ``Loss`` subclass.
        A subclass is instantiated with no arguments and the instance is
        stored on ``self.losser``.
    :param metrics: metrics used on ``dev_data``; must be non-empty exactly
        when ``dev_data`` is given. The default list is never mutated here.
    :param dev_data: optional validation set.
    :param need_check_code: when true, run ``_check_code`` on the inputs.
    :raises TypeError: if ``train_data``, ``model`` or ``losser`` has the
        wrong type.
    :raises ValueError: if ``metrics`` and ``dev_data`` are not supplied
        together.
    """
    if not isinstance(train_data, DataSet):
        raise TypeError("The type of train_data must be fastNLP.DataSet, got {}.".format(type(train_data)))
    if not isinstance(model, nn.Module):
        raise TypeError("The type of model must be torch.nn.Module, got {}.".format(type(model)))

    # NOTE(review): `Loss` is the only loss type this file imports, so it is
    # used as the expected base class in place of the original `None`
    # placeholder (which made isinstance() itself raise) -- confirm.
    if losser is not None:
        is_loss_class = isinstance(losser, type) and issubclass(losser, Loss)
        if not (is_loss_class or isinstance(losser, Loss)):
            got = losser if isinstance(losser, type) else type(losser)
            raise TypeError(f"The type of losser must be `Loss`, got {got}.")

    # check metrics and dev_data: they only make sense together
    if (not metrics) and dev_data is not None:
        raise ValueError("No metric for dev_data evaluation.")
    if metrics and (dev_data is None):
        raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

    # check loss: the caller handed over an uninitialised loss class, so
    # build the instance once and use it from here on.
    if isinstance(losser, type):
        losser = losser()
        self.losser = losser

    if need_check_code:
        # Pass the instantiated loss so that _check_code calls it on
        # (output, batch_y) rather than invoking the class constructor.
        _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(self): |
|
|
def train(self): |
|
|
"""Start Training. |
|
|
"""Start Training. |
|
|
|
|
|
|
|
@@ -171,6 +206,9 @@ class Trainer(object): |
|
|
def data_forward(self, network, x):
    """Run one forward pass of ``network`` on the batch input ``x``.

    :param network: callable model exposing a ``forward`` method.
    :param x: mapping of input fields; it is passed through ``_build_args``
        against ``network.forward`` before the call (presumably to keep only
        the arguments ``forward`` accepts -- verify against ``_build_args``).
    :return: the dict returned by ``network``.
    :raises TypeError: if ``network`` returns anything other than a dict.
    """
    forward_kwargs = _build_args(network.forward, **x)
    prediction = network(**forward_kwargs)
    if isinstance(prediction, dict):
        return prediction
    # Name the offending forward() in the message so the user can find it.
    raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(prediction)}.")
|
|
|
|
|
|
|
|
def grad_backward(self, loss): |
|
|
def grad_backward(self, loss): |
|
@@ -231,11 +269,11 @@ IGNORE_CHECK_LEVEL = 0 |
|
|
WARNING_CHECK_LEVEL = 1 |
|
|
WARNING_CHECK_LEVEL = 1 |
|
|
STRICT_CHECK_LEVEL = 2 |
|
|
STRICT_CHECK_LEVEL = 2 |
|
|
|
|
|
|
|
|
def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL): |
|
|
|
|
|
|
|
|
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, |
|
|
|
|
|
dev_data=None, |
|
|
|
|
|
check_level=WARNING_CHECK_LEVEL): |
|
|
# check get_loss 方法 |
|
|
# check get_loss 方法 |
|
|
model_name = model.__class__.__name__ |
|
|
model_name = model.__class__.__name__ |
|
|
if not hasattr(model, 'get_loss'): |
|
|
|
|
|
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name)) |
|
|
|
|
|
|
|
|
|
|
|
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) |
|
|
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) |
|
|
for batch_count, (batch_x, batch_y) in enumerate(batch): |
|
|
for batch_count, (batch_x, batch_y) in enumerate(batch): |
|
@@ -248,23 +286,26 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
output = model(**refined_batch_x) |
|
|
output = model(**refined_batch_x) |
|
|
func_signature = get_func_signature(model.forward) |
|
|
func_signature = get_func_signature(model.forward) |
|
|
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) |
|
|
|
|
|
|
|
|
if not isinstance(output, dict): |
|
|
|
|
|
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`.") |
|
|
|
|
|
|
|
|
# loss check |
|
|
# loss check |
|
|
if batch_count == 0: |
|
|
|
|
|
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level, |
|
|
|
|
|
output=output, batch_y=batch_y) |
|
|
|
|
|
loss_input = _build_args(model.get_loss, **output, **batch_y) |
|
|
|
|
|
loss = model.get_loss(**loss_input) |
|
|
|
|
|
|
|
|
if isinstance(losser, type): # 这种情况,用户传的是losser.CE这种未初始化的loss |
|
|
|
|
|
# 需要保证output与batch_y是无歧义的? |
|
|
|
|
|
# (1) output和batch_y长度为1 |
|
|
|
|
|
# (2) output和batch_y的key是和losser接受的完全一致 |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
loss = losser(output, batch_y) |
|
|
|
|
|
|
|
|
# check loss output |
|
|
# check loss output |
|
|
if batch_count == 0: |
|
|
if batch_count == 0: |
|
|
if not isinstance(loss, torch.Tensor): |
|
|
if not isinstance(loss, torch.Tensor): |
|
|
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.". |
|
|
|
|
|
format(model_name, type(loss))) |
|
|
|
|
|
|
|
|
raise ValueError("The return value of {} should be torch.Tensor, but got {}.". |
|
|
|
|
|
format(type(losser), type(loss))) |
|
|
if len(loss.size())!=0: |
|
|
if len(loss.size())!=0: |
|
|
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format( |
|
|
|
|
|
model_name, loss.size() |
|
|
|
|
|
|
|
|
raise ValueError("The size of return value of {} is {}, should be torch.size([])".format( |
|
|
|
|
|
type(losser), loss.size() |
|
|
)) |
|
|
)) |
|
|
loss.backward() |
|
|
loss.backward() |
|
|
model.zero_grad() |
|
|
model.zero_grad() |
|
@@ -272,26 +313,29 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
break |
|
|
break |
|
|
|
|
|
|
|
|
if dev_data is not None: |
|
|
if dev_data is not None: |
|
|
if not hasattr(model, 'evaluate'): |
|
|
|
|
|
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set" |
|
|
|
|
|
"dev_data to 'None'." |
|
|
|
|
|
.format(model_name)) |
|
|
|
|
|
outputs, truths = defaultdict(list), defaultdict(list) |
|
|
outputs, truths = defaultdict(list), defaultdict(list) |
|
|
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) |
|
|
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) |
|
|
|
|
|
# TODO 这里修改为使用tester |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
with torch.no_grad(): |
|
|
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): |
|
|
for batch_count, (batch_x, batch_y) in enumerate(dev_batch): |
|
|
_syn_model_data(model, batch_x, batch_y) |
|
|
_syn_model_data(model, batch_x, batch_y) |
|
|
|
|
|
|
|
|
if hasattr(model, 'predict'): |
|
|
if hasattr(model, 'predict'): |
|
|
|
|
|
if not callable(model.predict): |
|
|
|
|
|
raise TypeError(f"{get_func_signature(model.predict)} must be callable to be used " |
|
|
|
|
|
f"for evaluation.") |
|
|
refined_batch_x = _build_args(model.predict, **batch_x) |
|
|
refined_batch_x = _build_args(model.predict, **batch_x) |
|
|
prev_func = model.predict |
|
|
prev_func = model.predict |
|
|
output = prev_func(**refined_batch_x) |
|
|
output = prev_func(**refined_batch_x) |
|
|
func_signature = get_func_signature(model.predict) |
|
|
|
|
|
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature) |
|
|
|
|
|
else: |
|
|
else: |
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
refined_batch_x = _build_args(model.forward, **batch_x) |
|
|
prev_func = model.forward |
|
|
prev_func = model.forward |
|
|
output = prev_func(**refined_batch_x) |
|
|
output = prev_func(**refined_batch_x) |
|
|
|
|
|
func_signature = get_func_signature(prev_func) |
|
|
|
|
|
if not isinstance(output, dict): |
|
|
|
|
|
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(output)}`") |
|
|
for k, v in output.items(): |
|
|
for k, v in output.items(): |
|
|
outputs[k].append(v) |
|
|
outputs[k].append(v) |
|
|
for k, v in batch_y.items(): |
|
|
for k, v in batch_y.items(): |
|
@@ -299,16 +343,15 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No |
|
|
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: |
|
|
if batch_count+1>DEFAULT_CHECK_NUM_BATCH: |
|
|
break |
|
|
break |
|
|
for k, v in outputs.items(): |
|
|
for k, v in outputs.items(): |
|
|
outputs[k] = itertools.chain(*v) |
|
|
|
|
|
|
|
|
outputs[k] = tuple(itertools.chain(*v)) |
|
|
for k, v in truths.items(): |
|
|
for k, v in truths.items(): |
|
|
truths[k] = itertools.chain(*v) |
|
|
|
|
|
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level, |
|
|
|
|
|
output=outputs, batch_y=truths) |
|
|
|
|
|
refined_input = _build_args(model.evaluate, **outputs, **truths) |
|
|
|
|
|
metrics = model.evaluate(**refined_input) |
|
|
|
|
|
func_signature = get_func_signature(model.evaluate) |
|
|
|
|
|
assert isinstance(metrics, dict), "The return value of {} should be dict.". \ |
|
|
|
|
|
format(func_signature) |
|
|
|
|
|
|
|
|
truths[k] = tuple(itertools.chain(*v)) |
|
|
|
|
|
#TODO 这里需要根据新版的metrics做修改,另外这里需要捕获来自metric的报错,因为需要指导用户debug |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_forward_error(model_func, check_level, batch_x): |
|
|
def _check_forward_error(model_func, check_level, batch_x): |
|
|