@@ -16,6 +16,7 @@ __all__ = [
     "ResultsMonitor",
     'HasMonitorCallback',
     "FitlogCallback",
+    "TimerCallback",

     # collators
     'Collator',
@@ -21,7 +21,9 @@ __all__ = [
     "ResultsMonitor",
     'HasMonitorCallback',
-    "FitlogCallback"
+    "FitlogCallback",
+    "TimerCallback"
 ]
@@ -37,4 +39,4 @@ from .torch_callbacks import *
 from .more_evaluate_callback import MoreEvaluateCallback
 from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
 from .fitlog_callback import FitlogCallback
+from .timer_callback import TimerCallback
@@ -171,7 +171,7 @@ class ResultsMonitor:
     @property
     def log_name(self) -> str:
         """
-        Used internally when printing information.
+        Used internally when printing information about the current class.

         :return:
         """
@@ -106,11 +106,11 @@ class LoadBestModelCallback(HasMonitorCallback):
     def on_train_end(self, trainer):
         if abs(self.monitor_value) != float('inf'):  # if it is still inf, evaluation never ran.
             if self.real_save_folder:
-                logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...")
+                logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
                 trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                    model_load_fn=self.model_load_fn)
             else:
-                logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
+                logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
                 self.buffer.seek(0)
                 trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
         if self.delete_after_after:
@@ -1,5 +1,4 @@
 import json
-import sys
 from typing import Union

 __all__ = [
@@ -16,8 +15,21 @@ from fastNLP.core.log import logger

 class ProgressCallback(HasMonitorCallback):
+    def __init__(self, monitor, larger_better, must_have_monitor=False):
+        super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better,
+                                               must_have_monitor=must_have_monitor)
+        self.best_monitor_epoch = -1
+        self.best_monitor_step = -1
+
+    def record_better_monitor(self, trainer):
+        self.best_monitor_step = trainer.global_forward_batches
+        self.best_monitor_epoch = trainer.cur_epoch_idx
+
     def on_train_end(self, trainer):
-        f_rich_progress.stop()
+        if self.best_monitor_epoch != -1:
+            msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
+                  f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            logger.info(msg)

     @property
     def name(self):  # name of the progress bar
@@ -97,6 +109,7 @@ class RichCallback(ProgressCallback):
                                           advance=None, completed=trainer.cur_epoch_idx, refresh=True)

     def on_train_end(self, trainer):
+        super(RichCallback, self).on_train_end(trainer)
         self.clear_tasks()

     def on_before_backward(self, trainer, outputs):
@@ -121,8 +134,8 @@ class RichCallback(ProgressCallback):
         text_style = ''
         characters = '-'
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
             if abs(self.monitor_value) != float('inf'):
                 rule_style = 'spring_green3'
                 text_style = '[bold]'
@@ -201,8 +214,8 @@ class RawTextCallback(ProgressCallback):
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
         text = ''
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
             if abs(self.monitor_value) != float('inf'):
                 text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
@@ -266,6 +279,7 @@ class TqdmCallback(ProgressCallback):
         self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True)

     def on_train_end(self, trainer):
+        super(TqdmCallback, self).on_train_end(trainer)
         self.clear_tasks()

     def on_before_backward(self, trainer, outputs):
@@ -287,8 +301,8 @@ class TqdmCallback(ProgressCallback):
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
         text = ''
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
             if abs(self.monitor_value) != float('inf'):
                 text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
@@ -0,0 +1,152 @@
+import time
+
+from .callback import Callback
+from ..log import logger
+
+__all__ = ['TimerCallback']
+
+
+class _Timer:
+    """Timer."""
+
+    def __init__(self, name):
+        self.name_ = name
+        self.elapsed_ = 0.0
+        self.started_ = False
+        self.start_time = time.time()
+
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, f'{self.name_} timer has already been started'
+        self.start_time = time.time()
+        self.started_ = True
+
+    def stop(self):
+        """Stop the timer."""
+        assert self.started_, f'{self.name_} timer is not started'
+        self.elapsed_ += (time.time() - self.start_time)
+        self.started_ = False
+
+    def reset(self):
+        """Reset timer."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+
+    def elapsed(self, reset=True):
+        """Calculate the elapsed time."""
+        started_ = self.started_
+        # If timing is in progress, stop it first.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time.
+        if reset:
+            self.reset()
+        # If timing was in progress, set it back running.
+        if started_:
+            self.start()
+        return elapsed_
+
+
+class Timers:
+    """Group of timers."""
+
+    def __init__(self):
+        self.timers = {}
+
+    def __call__(self, name):
+        if name not in self.timers:
+            self.timers[name] = _Timer(name)
+        return self.timers[name]
+
+    def __contains__(self, item):
+        return item in self.timers
+
+    def reset(self):
+        for timer in self.timers.values():
+            timer.reset()
+
+
+class TimerCallback(Callback):
+    """
+    This callback prints timing information gathered during training, e.g. how long
+    training took, how long evaluation took, the total elapsed time, etc.
+    """
+    def __init__(self, print_every=-1, time_ndigit=3):
+        """
+        :param print_every: when to print the timing information.
+
+            * *negative*: print once every abs(print_every) epochs;
+            * *0*: print only once, after the whole training run finishes;
+            * *positive*: print once every print_every steps;
+
+        :param time_ndigit: how many decimal places to keep
+        """
+        assert isinstance(print_every, int), "print_every must be an int number."
+        self.timers = Timers()
+        self.print_every = print_every
+        self.time_ndigit = time_ndigit
+
+    def on_train_begin(self, trainer):
+        self.timers('total').start()
+        self.timers('train').start()
+
+    def on_fetch_data_begin(self, trainer):
+        self.timers('fetch-data').start()
+
+    def on_fetch_data_end(self, trainer):
+        self.timers('fetch-data').stop()
+
+    def on_train_batch_begin(self, trainer, batch, indices):
+        self.timers('forward').start()
+
+    def on_before_backward(self, trainer, outputs):
+        self.timers('forward').stop()
+        self.timers('backward').start()
+
+    def on_after_backward(self, trainer):
+        self.timers('backward').stop()
+
+    def on_before_optimizers_step(self, trainer, optimizers):
+        self.timers('optimize').start()
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        self.timers('optimize').stop()
+
+    def on_evaluate_begin(self, trainer):
+        self.timers('train').stop()
+        self.timers('evaluate').start()
+
+    def on_evaluate_end(self, trainer, results):
+        self.timers('evaluate').stop()
+        self.timers('train').start()
+
+    def format_timer(self, reset=True):
+        line = ''
+        timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total']
+        for timer_name in timers:
+            if timer_name not in self.timers:
+                continue
+            timer = self.timers(timer_name)
+            elapsed = round(timer.elapsed(reset=reset), self.time_ndigit)
+            if elapsed != 0:
+                line = line + f', {timer_name}: {elapsed}s'
+        return line
+
+    def on_train_batch_end(self, trainer):
+        if self.print_every > 0 and trainer.global_forward_batches % self.print_every == 0:
+            line = self.format_timer()
+            logger.info(f"Running {self.print_every} batches{line}")
+
+    def on_train_epoch_end(self, trainer):
+        if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0:
+            line = self.format_timer()
+            logger.info(f"Running {abs(self.print_every)} epochs{line}")
+
+    def on_train_end(self, trainer):
+        if self.print_every == 0:
+            line = self.format_timer()
+            logger.info(f"Training finished{line}")
@@ -41,10 +41,12 @@ class TrainBatchLoop(Loop):
                 batch = next(dataloader)
                 indices = get_batch_indices()
             except StopIteration:
+                trainer.on_fetch_data_end()
                 break
+            trainer.on_fetch_data_end()

             try:
-                trainer.on_fetch_data_end()
                 batch = match_and_substitute_params(trainer.input_mapping, batch)
                 batch = trainer.move_data_to_device(batch)
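Note: moving trainer.on_fetch_data_end() out of the second try block keeps every on_fetch_data_begin() paired with an on_fetch_data_end(), even on the StopIteration path that ends the epoch. Without this, the new TimerCallback's 'fetch-data' timer would be left running after the final fetch, and its start() assertion would fail at the start of the next epoch.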
@@ -108,6 +108,9 @@ class TorchDataLoader(DataLoader):
         if not isinstance(dataset, _FDataSet):
             dataset = _FDataSet(dataset)

+        if num_workers > 0 and multiprocessing_context is None:
+            multiprocessing_context = 'fork'  # default to the 'fork' start method for worker processes
+
         if batch_sampler is not None:
             batch_size = 1
             shuffle = False
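A usage note on this new default: 'fork' is a Unix-only start method, and the keyword is presumably forwarded to torch.utils.data.DataLoader, so callers who need a different start method, e.g. 'spawn' when workers touch CUDA state, can still pass one explicitly. A hedged sketch, assuming TorchDataLoader is exported at the package top level and my_dataset exists:

    # Hypothetical sketch: override the new 'fork' default explicitly.
    from fastNLP import TorchDataLoader

    loader = TorchDataLoader(
        my_dataset,                       # assumed: an existing dataset
        batch_size=32,
        num_workers=4,
        multiprocessing_context='spawn',  # an explicit value bypasses the 'fork' default
    )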