Browse Source

1.Trainer训练train增加一个on_exception判断,是否抛出异常; 2.Fitlog默认不记录loss,因为loss实在比较占硬盘

tags/v0.4.10
yh_cc 5 years ago
parent
commit
1a3dcd3dde
2 changed files with 20 additions and 7 deletions
  1. +14
    -5
      fastNLP/core/callback.py
  2. +6
    -2
      fastNLP/core/trainer.py

+ 14
- 5
fastNLP/core/callback.py View File

@@ -435,7 +435,7 @@ class FitlogCallback(Callback):
"""
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`

该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入
该callback将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。
@@ -444,15 +444,18 @@ class FitlogCallback(Callback):
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过
dict的方式传入。如果仅传入DataSet, 则被命名为test
:param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test`
:param int verbose: 是否在终端打印内容,0不打印
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。
:param int verbose: 是否在终端打印evaluation的结果,0不打印。
:param bool log_exception: fitlog是否记录发生的exception信息
"""

def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False):
super().__init__()
self.datasets = {}
self.testers = {}
self._log_exception = log_exception
assert isinstance(log_loss_every, int) and log_loss_every>=0
if tester is not None:
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed."
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data."
@@ -472,6 +475,8 @@ class FitlogCallback(Callback):
raise TypeError("data receives dict[DataSet] or DataSet object.")

self.verbose = verbose
self._log_loss_every = log_loss_every
self._avg_loss = 0

def on_train_begin(self):
if (len(self.datasets)>0 or len(self.testers)>0 ) and self.trainer.dev_data is None:
@@ -485,7 +490,11 @@ class FitlogCallback(Callback):
fitlog.add_progress(total_steps=self.n_steps)

def on_backward_begin(self, loss):
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch)
if self._log_loss_every>0:
self._avg_loss += loss.item()
if self.step%self._log_loss_every==0:
fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch)
self._avg_loss = 0

def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
if better_result:
@@ -513,7 +522,7 @@ class FitlogCallback(Callback):
def on_exception(self, exception):
fitlog.finish(status=1)
if self._log_exception:
fitlog.add_other(str(exception), name='except_info')
fitlog.add_other(repr(exception), name='except_info')


class LRScheduler(Callback):


+ 6
- 2
fastNLP/core/trainer.py View File

@@ -494,12 +494,14 @@ class Trainer(object):
self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)
def train(self, load_best_model=True):
def train(self, load_best_model=True, on_exception='ignore'):
"""
使用该函数使Trainer开始训练。

:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现
最好的模型参数。
:param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。
支持'ignore'与'raise': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出。
:return dict: 返回一个字典类型的数据,
内含以下内容::

@@ -527,8 +529,10 @@ class Trainer(object):
self.callback_manager.on_train_begin()
self._train()
self.callback_manager.on_train_end()
except (CallbackException, KeyboardInterrupt) as e:
except (CallbackException, KeyboardInterrupt, Exception) as e:
self.callback_manager.on_exception(e)
if on_exception=='raise':
raise e
if self.dev_data is not None and hasattr(self, 'best_dev_perf'):
print(


Loading…
Cancel
Save