@@ -11,9 +11,6 @@ from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _build_args
-from fastNLP.core.utils import _syn_model_data
@@ -23,8 +20,7 @@ class Trainer(object):
     """Main Training Loop
 
     """
 
-    def __init__(self, train_data, model, n_epochs, batch_size, n_print=1,
+    def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1,
                  dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
                  optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
                  evaluator=Evaluator(),
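# Note on the hunk above: `n_print` is renamed to `print_every`, and every
# argument after `model` now carries a default. A minimal sketch of a call
# site against the new signature -- `train_data`, `dev_data`, and `model` are
# placeholders built elsewhere, and the idea that print_every=-1 disables
# progress printing is an assumption, not stated in this diff:
trainer = Trainer(train_data=train_data, model=model,
                  n_epochs=3, batch_size=32,           # now the defaults
                  print_every=-1,                      # renamed from n_print
                  dev_data=dev_data, use_cuda=False,
                  loss=Loss(None), save_path="./save",
                  optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
                  evaluator=Evaluator())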
@@ -210,13 +206,12 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
 
-def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
+def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
     # check the get_loss method
     model_name = model.__class__.__name__
     if not hasattr(model, 'get_loss'):
         raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))
 
     batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
         if batch_count == 0:
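# The checks in _check_code spell out the contract a model must meet: a
# `get_loss` method must exist, and (per the hunks below) `forward` and
# `evaluate` must both return dicts. A minimal skeleton that satisfies all
# three -- the field names `word_seq`, `pred`, and `label_seq` and the MSE
# loss are illustrative placeholders, not part of this diff:
import torch

class ToyModel(torch.nn.Module):
    def forward(self, word_seq):
        # must return a dict; its keys feed the get_loss/evaluate checks
        return {"pred": word_seq.float().mean(dim=-1)}

    def get_loss(self, pred, label_seq):
        # without this method, _check_code raises AttributeError
        return torch.nn.functional.mse_loss(pred, label_seq.float())

    def evaluate(self, pred, label_seq):
        # must also return a dict (asserted in the last hunk below)
        return {"mse": float(torch.nn.functional.mse_loss(pred, label_seq.float()))}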
@@ -236,8 +231,9 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
 
         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
 
-        assert isinstance(output, dict), "The return value of {}.forward() should be dict.".format(model_name)
+        signature_str = get_func_signature(model.forward)
+        func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
+        assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
 
         # loss check
         if batch_count == 0:
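# The new assertion leans on two helpers whose bodies are not shown in this
# diff: get_func_signature, which renders a function's parameter list as a
# string, and _build_args, which filters a kwargs dict down to what the
# function accepts. A standard-library approximation of what they appear to
# do -- an assumption about fastNLP internals, not the actual implementation:
import inspect

def get_func_signature_approx(func):
    # str(inspect.signature(...)) keeps the surrounding parentheses, hence
    # the [1:-1] slice in the hunk above; for a bound method `self` is
    # already omitted, e.g. get_func_signature_approx(model.forward) -> "(word_seq)"
    return str(inspect.signature(func))

def build_args_approx(func, **kwargs):
    # keep only the keyword arguments that `func` actually declares
    params = inspect.signature(func).parameters
    return {name: value for name, value in kwargs.items() if name in params}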
@@ -287,6 +283,12 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
             break
 
-    _check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level,
-                         output=outputs, batch_y=truths)
+    refined_input = _build_args(model.evaluate, **outputs, **truths)
+    metrics = model.evaluate(**refined_input)
+    signature_str = get_func_signature(model.evaluate)
+    func_signature = '{}.evaluate(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
+    assert isinstance(metrics, dict), "The return value of {} should be dict.". \
+        format(func_signature)
+    if check_level > IGNORE_CHECK_LEVEL:
+        print("Finish checking evaluate process.", flush=True)
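# Net effect of the reworked assertions: failures now echo the full method
# signature instead of just the class name, so the user sees exactly which
# arguments were in play, e.g. (messages follow the format calls above):
#   old: The return value of ToyModel.forward() should be dict.
#   new: The return value of ToyModel.forward(self, word_seq) should be dict.
# Running the checker against the placeholder model from earlier -- `dataset`
# and `dev_data` are assumed to be prepared fastNLP DataSet objects:
_check_code(dataset, ToyModel(), batch_size=32, dev_data=dev_data,
            check_level=WARNING_CHECK_LEVEL)  # any level > IGNORE_CHECK_LEVEL prints the "Finish checking" lines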