|
@@ -54,6 +54,7 @@ __all__ = [ |
|
|
"GradientClipCallback", |
|
|
"GradientClipCallback", |
|
|
"EarlyStopCallback", |
|
|
"EarlyStopCallback", |
|
|
"TensorboardCallback", |
|
|
"TensorboardCallback", |
|
|
|
|
|
"FitlogCallback", |
|
|
"LRScheduler", |
|
|
"LRScheduler", |
|
|
"ControlC", |
|
|
"ControlC", |
|
|
|
|
|
|
|
@@ -65,6 +66,7 @@ import os |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
from copy import deepcopy |
|
|
from copy import deepcopy |
|
|
|
|
|
|
|
|
try: |
|
|
try: |
|
|
from tensorboardX import SummaryWriter |
|
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
|
|
@@ -81,6 +83,7 @@ try: |
|
|
except: |
|
|
except: |
|
|
pass |
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Callback(object): |
|
|
class Callback(object): |
|
|
""" |
|
|
""" |
|
|
别名::class:`fastNLP.Callback` :class:`fastNLP.core.callback.Callback` |
|
|
别名::class:`fastNLP.Callback` :class:`fastNLP.core.callback.Callback` |
|
@@ -431,14 +434,13 @@ class EarlyStopCallback(Callback): |
|
|
else: |
|
|
else: |
|
|
raise exception # 抛出陌生Error |
|
|
raise exception # 抛出陌生Error |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FitlogCallback(Callback): |
|
|
class FitlogCallback(Callback): |
|
|
""" |
|
|
""" |
|
|
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` |
|
|
|
|
|
|
|
|
|
|
|
该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 |
|
|
该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 |
|
|
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 |
|
|
|
|
|
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 |
|
|
|
|
|
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 |
|
|
|
|
|
|
|
|
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 |
|
|
|
|
|
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 |
|
|
|
|
|
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 |
|
|
|
|
|
|
|
|
:param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 |
|
|
:param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 |
|
|
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 |
|
|
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 |
|
@@ -447,7 +449,9 @@ class FitlogCallback(Callback): |
|
|
:param int verbose: 是否在终端打印内容,0不打印 |
|
|
:param int verbose: 是否在终端打印内容,0不打印 |
|
|
:param bool log_exception: fitlog是否记录发生的exception信息 |
|
|
:param bool log_exception: fitlog是否记录发生的exception信息 |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
# 还没有被导出到 fastNLP 层 |
|
|
|
|
|
# 别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` |
|
|
|
|
|
|
|
|
def __init__(self, data=None, tester=None, verbose=0, log_exception=False): |
|
|
def __init__(self, data=None, tester=None, verbose=0, log_exception=False): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.datasets = {} |
|
|
self.datasets = {} |
|
@@ -460,7 +464,7 @@ class FitlogCallback(Callback): |
|
|
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." |
|
|
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." |
|
|
setattr(tester, 'verbose', 0) |
|
|
setattr(tester, 'verbose', 0) |
|
|
self.testers['test'] = tester |
|
|
self.testers['test'] = tester |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(data, dict): |
|
|
if isinstance(data, dict): |
|
|
for key, value in data.items(): |
|
|
for key, value in data.items(): |
|
|
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." |
|
|
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." |
|
@@ -470,23 +474,23 @@ class FitlogCallback(Callback): |
|
|
self.datasets['test'] = data |
|
|
self.datasets['test'] = data |
|
|
else: |
|
|
else: |
|
|
raise TypeError("data receives dict[DataSet] or DataSet object.") |
|
|
raise TypeError("data receives dict[DataSet] or DataSet object.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.verbose = verbose |
|
|
self.verbose = verbose |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_train_begin(self):
    """
    Set up the extra evaluation testers and register the run length with fitlog.

    For every entry in ``self.datasets`` a ``Tester`` is built with the
    callback's model, batch size and the trainer's metrics, and stored in
    ``self.testers`` under the same key. Afterwards the total number of
    training steps is passed to ``fitlog.add_progress``.

    :raises RuntimeError: if extra datasets or testers were supplied but the
        trainer has no dev data (``self.trainer.dev_data is None``).
    """
    # Extra evaluation data is only meaningful when the trainer also has dev data.
    if (self.datasets or self.testers) and self.trainer.dev_data is None:
        raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.")

    # Iterating an empty dict is a no-op, so no separate emptiness guard is needed.
    for key, data in self.datasets.items():
        self.testers[key] = Tester(data=data, model=self.model, batch_size=self.batch_size,
                                   metrics=self.trainer.metrics, verbose=0)

    # NOTE(review): nesting was reconstructed from a whitespace-stripped diff; this call
    # is assumed to run unconditionally (not only when datasets exist) -- confirm.
    fitlog.add_progress(total_steps=self.n_steps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_backward_begin(self, loss):
    """
    Record the current training loss in fitlog.

    :param loss: the loss of the current step; must provide ``.item()``
        (presumably a scalar ``torch.Tensor`` -- confirm against the Trainer).
    """
    # The loss is logged under the name 'loss', tagged with the current step and epoch.
    fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): |
|
|
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): |
|
|
if better_result: |
|
|
if better_result: |
|
|
eval_result = deepcopy(eval_result) |
|
|
eval_result = deepcopy(eval_result) |
|
@@ -494,11 +498,11 @@ class FitlogCallback(Callback): |
|
|
eval_result['epoch'] = self.epoch |
|
|
eval_result['epoch'] = self.epoch |
|
|
fitlog.add_best_metric(eval_result) |
|
|
fitlog.add_best_metric(eval_result) |
|
|
fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch) |
|
|
fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch) |
|
|
if len(self.testers)>0: |
|
|
|
|
|
|
|
|
if len(self.testers) > 0: |
|
|
for key, tester in self.testers.items(): |
|
|
for key, tester in self.testers.items(): |
|
|
try: |
|
|
try: |
|
|
eval_result = tester.test() |
|
|
eval_result = tester.test() |
|
|
if self.verbose!=0: |
|
|
|
|
|
|
|
|
if self.verbose != 0: |
|
|
self.pbar.write("Evaluation on DataSet {}:".format(key)) |
|
|
self.pbar.write("Evaluation on DataSet {}:".format(key)) |
|
|
self.pbar.write(tester._format_eval_results(eval_result)) |
|
|
self.pbar.write(tester._format_eval_results(eval_result)) |
|
|
fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) |
|
|
fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) |
|
@@ -506,10 +510,10 @@ class FitlogCallback(Callback): |
|
|
fitlog.add_best_metric(eval_result, name=key) |
|
|
fitlog.add_best_metric(eval_result, name=key) |
|
|
except Exception: |
|
|
except Exception: |
|
|
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) |
|
|
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_train_end(self):
    """
    Call ``fitlog.finish()`` with its default arguments when training ends.

    NOTE(review): presumably this marks the fitlog run as finished normally;
    ``on_exception`` calls ``fitlog.finish(status=1)`` instead.
    """
    fitlog.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def on_exception(self, exception): |
|
|
def on_exception(self, exception): |
|
|
fitlog.finish(status=1) |
|
|
fitlog.finish(status=1) |
|
|
if self._log_exception: |
|
|
if self._log_exception: |
|
|