|
|
@@ -435,7 +435,7 @@ class FitlogCallback(Callback): |
|
|
|
""" |
|
|
|
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` |
|
|
|
|
|
|
|
该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 |
|
|
|
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 |
|
|
|
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 |
|
|
|
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 |
|
|
|
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 |
|
|
@@ -444,15 +444,18 @@ class FitlogCallback(Callback): |
|
|
|
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 |
|
|
|
dict的方式传入。如果仅传入DataSet, 则被命名为test |
|
|
|
:param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test` |
|
|
|
:param int verbose: 是否在终端打印内容,0不打印 |
|
|
|
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 |
|
|
|
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 |
|
|
|
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 |
|
|
|
:param bool log_exception: fitlog是否记录发生的exception信息 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, data=None, tester=None, verbose=0, log_exception=False): |
|
|
|
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): |
|
|
|
super().__init__() |
|
|
|
self.datasets = {} |
|
|
|
self.testers = {} |
|
|
|
self._log_exception = log_exception |
|
|
|
assert isinstance(log_loss_every, int) and log_loss_every>=0 |
|
|
|
if tester is not None: |
|
|
|
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." |
|
|
|
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." |
|
|
@@ -472,6 +475,8 @@ class FitlogCallback(Callback): |
|
|
|
raise TypeError("data receives dict[DataSet] or DataSet object.") |
|
|
|
|
|
|
|
self.verbose = verbose |
|
|
|
self._log_loss_every = log_loss_every |
|
|
|
self._avg_loss = 0 |
|
|
|
|
|
|
|
def on_train_begin(self): |
|
|
|
if (len(self.datasets)>0 or len(self.testers)>0 ) and self.trainer.dev_data is None: |
|
|
@@ -485,7 +490,11 @@ class FitlogCallback(Callback): |
|
|
|
fitlog.add_progress(total_steps=self.n_steps) |
|
|
|
|
|
|
|
def on_backward_begin(self, loss): |
|
|
|
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch) |
|
|
|
if self._log_loss_every>0: |
|
|
|
self._avg_loss += loss.item() |
|
|
|
if self.step%self._log_loss_every==0: |
|
|
|
fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch) |
|
|
|
self._avg_loss = 0 |
|
|
|
|
|
|
|
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): |
|
|
|
if better_result: |
|
|
@@ -513,7 +522,7 @@ class FitlogCallback(Callback): |
|
|
|
def on_exception(self, exception): |
|
|
|
fitlog.finish(status=1) |
|
|
|
if self._log_exception: |
|
|
|
fitlog.add_other(str(exception), name='except_info') |
|
|
|
fitlog.add_other(repr(exception), name='except_info') |
|
|
|
|
|
|
|
|
|
|
|
class LRScheduler(Callback): |
|
|
|