
trainer: check_code adjustments

yh yunfan committed 6 years ago · commit f7275339ff · tags/v0.2.0
1 changed file with 11 additions and 9 deletions

fastNLP/core/trainer.py (+11, -9)

@@ -11,9 +11,6 @@ from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.utils import _build_args
-from fastNLP.core.utils import _check_arg_dict_list
-
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _syn_model_data
@@ -23,8 +20,7 @@ class Trainer(object):
"""Main Training Loop

"""

def __init__(self, train_data, model, n_epochs, batch_size, n_print=1,
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1,
dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
evaluator=Evaluator(),
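With these defaults, a Trainer can now be constructed without spelling out every hyperparameter, and n_print is renamed to print_every (its -1 default presumably suppresses per-step printing). A minimal usage sketch under the signature above; train_set, dev_set, and my_model are hypothetical placeholders, and the import paths are assumed from the surrounding diff:

    from fastNLP.core.loss import Loss
    from fastNLP.core.optimizer import Optimizer
    from fastNLP.core.trainer import Trainer

    # train_set / dev_set are assumed to be fastNLP DataSet objects and
    # my_model a module exposing forward/get_loss/evaluate (placeholders).
    trainer = Trainer(train_data=train_set, model=my_model,
                      n_epochs=3, batch_size=32, print_every=-1,
                      dev_data=dev_set, use_cuda=False,
                      loss=Loss(None), save_path="./save",
                      optimizer=Optimizer("Adam", lr=0.001, weight_decay=0))
    trainer.train()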
@@ -210,13 +206,12 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2

-def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=1):
+def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
     # check the get_loss method
     model_name = model.__class__.__name__
     if not hasattr(model, 'get_loss'):
         raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))

-    batch_size = min(DEFAULT_CHECK_BATCH_SIZE, batch_size)
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
         if batch_count == 0:
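Two behavioral notes fall out of this hunk: the default check_level now reads as WARNING_CHECK_LEVEL instead of a bare 1, and dropping the min(...) line means a caller-supplied batch_size is used as-is rather than being capped at DEFAULT_CHECK_BATCH_SIZE. A hedged sketch of how a caller might invoke the check with an explicit level (train_set, dev_set, and my_model are placeholders):

    from fastNLP.core.trainer import _check_code, STRICT_CHECK_LEVEL

    # Run the pre-training sanity check at the strictest level so any
    # interface mismatch surfaces before a long training run starts.
    _check_code(dataset=train_set, model=my_model,
                dev_data=dev_set, check_level=STRICT_CHECK_LEVEL)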
@@ -236,8 +231,9 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No

         refined_batch_x = _build_args(model.forward, **batch_x)
         output = model(**refined_batch_x)
-
-        assert isinstance(output, dict), "The return value of {}.forward() should be dict.".format(model_name)
+        signature_str = get_func_signature(model.forward)
+        func_signature = '{}.forward(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
+        assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)

         # loss check
         if batch_count == 0:
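The assert message now carries the full call signature instead of just the class name, via get_func_signature (presumably imported from fastNLP.core.utils). A minimal sketch of the kind of helper this assumes, built on the standard inspect module, not necessarily the library's actual implementation:

    import inspect

    def get_func_signature(func):
        # For a bound method such as model.forward, inspect.signature
        # omits self, so this returns e.g. "(words, seq_len)".
        return str(inspect.signature(func))

The diff then strips the outer parentheses with signature_str[1:-1] and re-adds "self, " explicitly, producing e.g. "MyModel.forward(self, words, seq_len)" in the error message.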
@@ -287,6 +283,12 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
                 break
         _check_loss_evaluate(model=model, model_func=model.evaluate, check_level=check_level,
                              output=outputs, batch_y=truths)
+        refined_input = _build_args(model.evaluate, **outputs, **truths)
+        metrics = model.evaluate(**refined_input)
+        signature_str = get_func_signature(model.evaluate)
+        func_signature = '{}.evaluate(self, {})'.format(model.__class__.__name__, signature_str[1:-1])
+        assert isinstance(metrics, dict), "The return value of {} should be dict.". \
+            format(func_signature)
         if check_level > IGNORE_CHECK_LEVEL:
             print("Finish checking evaluate process.", flush=True)


