|
@@ -48,7 +48,7 @@ class Trainer(object): |
|
|
:param use_cuda: |
|
|
:param use_cuda: |
|
|
:param str save_path: file path to save models |
|
|
:param str save_path: file path to save models |
|
|
:param Optimizer optimizer: an optimizer object |
|
|
:param Optimizer optimizer: an optimizer object |
|
|
:param int check_code_level: level of FastNLP code checker. 0: ignore. 1: warning. 2: strict. |
|
|
|
|
|
|
|
|
:param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict. |
|
|
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one |
|
|
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one |
|
|
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets |
|
|
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets |
|
|
smaller, add a `-` character in front of the string. For example |
|
|
smaller, add a `-` character in front of the string. For example |
|
@@ -91,7 +91,7 @@ class Trainer(object): |
|
|
|
|
|
|
|
|
if check_code_level > -1: |
|
|
if check_code_level > -1: |
|
|
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, |
|
|
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, |
|
|
check_level=check_code_level) |
|
|
|
|
|
|
|
|
metric_key=metric_key, check_level=check_code_level) |
|
|
|
|
|
|
|
|
self.train_data = train_data |
|
|
self.train_data = train_data |
|
|
self.dev_data = dev_data # If None, No validation. |
|
|
self.dev_data = dev_data # If None, No validation. |
|
@@ -294,7 +294,7 @@ DEFAULT_CHECK_NUM_BATCH = 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, |
|
|
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, |
|
|
dev_data=None, |
|
|
|
|
|
|
|
|
dev_data=None, metric_key=None, |
|
|
check_level=0): |
|
|
check_level=0): |
|
|
# check get_loss 方法 |
|
|
# check get_loss 方法 |
|
|
model_devcie = model.parameters().__next__().device |
|
|
model_devcie = model.parameters().__next__().device |
|
@@ -340,7 +340,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ |
|
|
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, |
|
|
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, |
|
|
batch_size=batch_size, verbose=-1) |
|
|
batch_size=batch_size, verbose=-1) |
|
|
evaluate_results = tester.test() |
|
|
evaluate_results = tester.test() |
|
|
# TODO 这里需要检查是否返回来的值是否是合理的 |
|
|
|
|
|
|
|
|
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_eval_results(metrics, metric_key, metric_list): |
|
|
def _check_eval_results(metrics, metric_key, metric_list): |
|
|