diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index dd493567..14803e56 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -79,6 +79,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
+import logging
 
 try:
     import fitlog
@@ -167,7 +168,11 @@ class Callback(object):
     @property
     def disabled(self):
         return self._disabled
-    
+
+    @property
+    def logger(self):
+        return getattr(self._trainer, 'logger', logging)
+
     def on_train_begin(self):
         """
         Called before the training process begins.
         """
@@ -316,21 +321,27 @@ class CallbackManager(Callback):
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
-        
+        self._env = env
         self.callbacks = []
-        if callbacks is not None:
-            if isinstance(callbacks, list):
-                if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                    self.callbacks.extend(callbacks)
-                else:
-                    obj = [not isinstance(cb, Callback) for cb in callbacks][0]
-                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        if callbacks:
+            self.callbacks += self.prepare_callbacks(callbacks)
+
+    def prepare_callbacks(self, callbacks):
+        if not callbacks:
+            return []
+        if isinstance(callbacks, list):
+            if all([isinstance(cb, Callback) for cb in callbacks]) is True:
+                pass  # do not extend self.callbacks here; the caller decides where the prepared list goes
             else:
-                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
-        
-        for env_name, env_val in env.items():
-            for callback in self.callbacks:
+                obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
+                raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        else:
+            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+
+        for env_name, env_val in self._env.items():
+            for callback in callbacks:
                 setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        return callbacks
 
     @_transfer
     def on_train_begin(self):
@@ -391,11 +402,12 @@ class CallbackManager(Callback):
 
 class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
+        super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
         is_master = env['trainer'].is_master
         self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
-        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
+        self.callbacks_all = self.prepare_callbacks(callbacks_all)
+        self.callbacks_master = self.prepare_callbacks(callbacks_master)
         self.callbacks = self.callbacks_all + self.callbacks_master
 
     def patch_callback(self, callbacks, disabled):
@@ -944,5 +956,21 @@ class EchoCallback(Callback):
 
 class TesterCallback(Callback):
-    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
-        self.tester = Tester(data, model)
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
+        # TODO: add compare & save best
+        super(TesterCallback, self).__init__()
+        self.tester = Tester(data, model,
+                             metrics=metrics, batch_size=batch_size,
+                             num_workers=num_workers, verbose=0)
+        self.score = None
+
+    def on_validation(self):
+        cur_score = self.tester.test()
+        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. \n- {}".format(
+            self.epoch, self.n_epochs, self.step, self.n_steps,
+            self.tester._format_eval_results(cur_score))
+        self.logger.info(eval_str)
+
+    def on_train_end(self):
+        self.logger.info('Evaluating at the end of training.')
+        self.on_validation()
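
Note on the callback.py hunks above: the new `logger` property falls back to the stdlib `logging` module whenever a callback has no trainer bound yet, and `prepare_callbacks` now centralizes list validation plus env-binding for both callback managers. A minimal sketch of the intended flow, assuming only the classes from this diff (the empty `env` dict and `MyCallback` are hypothetical, for illustration):

    from fastNLP.core.callback import Callback, CallbackManager

    class MyCallback(Callback):  # hypothetical subclass
        def on_train_begin(self):
            # resolves to trainer.logger once bound, else to the logging module
            self.logger.info('training begins')

    manager = CallbackManager(env={}, callbacks=[MyCallback()])
    # prepare_callbacks validates the list, binds every env entry onto each
    # callback as a '_'-prefixed attribute (e.g. '_trainer'), and returns
    # the same list so the caller decides where to register it.
    manager.callbacks += manager.prepare_callbacks([MyCallback()])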
- {}".format( + self.epoch, self.n_epochs, self.step, self.n_steps, + self.tester._format_eval_results(cur_socre)) + self.logger.info(eval_str) + + def on_train_end(self): + self.logger.info('Evaluate on training ends.') + self.on_validation() diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py index 700dcf38..260b93b0 100644 --- a/fastNLP/core/dist_trainer.py +++ b/fastNLP/core/dist_trainer.py @@ -11,7 +11,7 @@ import time from datetime import datetime, timedelta from .batch import DataSetIter, BatchIter -from .callback import DistCallbackManager, CallbackException +from .callback import DistCallbackManager, CallbackException, TesterCallback from .dataset import DataSet from .losses import _prepare_losser from .optimizer import Optimizer @@ -39,10 +39,13 @@ def get_local_rank(): class DistTrainer(): + """Distributed Trainer that support distributed and mixed precision training + """ def __init__(self, train_data, model, optimizer=None, loss=None, callbacks_all=None, callbacks_master=None, batch_size_per_gpu=8, n_epochs=1, num_data_workers=1, drop_last=False, + dev_data=None, metrics=None, update_every=1, print_every=10, validate_every=-1, save_every=-1, save_path=None, device='auto', fp16='', backend=None, init_method=None): @@ -107,6 +110,14 @@ class DistTrainer(): self.data_iterator = self._get_data_iter(self.train_data) self.n_steps = self._get_n_steps() + # for evaluation, only run eval on master proc + if dev_data and metrics: + cb = TesterCallback( + dev_data, model, metrics, + batch_size=batch_size_per_gpu, num_workers=num_data_workers) + self.callback_manager.callbacks_master += \ + self.callback_manager.prepare_callbacks([cb]) + # Setup logging dist.barrier() self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') @@ -261,9 +272,6 @@ class DistTrainer(): if ((self.validate_every > 0 and self.step % self.validate_every == 0) or (self.validate_every < 0 and self.step % len(data_iterator) == 0)): - eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. 
".format(epoch, self.n_epochs, self.step, - self.n_steps) - self.logger.info(eval_str) self.callback_manager.on_validation() dist.barrier() diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py index e36615dd..93d87407 100644 --- a/test/core/test_dist_trainer.py +++ b/test/core/test_dist_trainer.py @@ -13,6 +13,7 @@ import os import subprocess from argparse import ArgumentParser from fastNLP.core.callback import EchoCallback +from fastNLP import AccuracyMetric def prepare_fake_dataset(): mean = np.array([-3, -3]) @@ -106,15 +107,36 @@ class TestDistTrainer(unittest.TestCase): shutil.rmtree(self.save_path) def run3(self): + set_rng_seed(100) data_set, model = prepare_env() trainer = DistTrainer( - data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"), + data_set, model, optimizer=None, + loss=BCELoss(pred="predict", target="y"), n_epochs=3, print_every=50, callbacks_all=[EchoCallback('callbacks_all')], callbacks_master=[EchoCallback('callbacks_master')] ) trainer.train() + def run4(self): + set_rng_seed(100) + data_set, model = prepare_env() + + train_set, dev_set = data_set.split(0.3) + + model = NaiveClassifier(2, 1) + + trainer = DistTrainer( + train_set, model, optimizer=SGD(lr=0.1), + loss=BCELoss(pred="predict", target="y"), + batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set, + metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, + ) + trainer.train() + """ + # 应该正确运行 + """ + def run_dist(self, run_id): if torch.cuda.is_available(): ngpu = min(2, torch.cuda.device_count()) @@ -133,6 +155,8 @@ class TestDistTrainer(unittest.TestCase): def test_callback(self): self.run_dist(3) + def test_dev_data(self): + self.run_dist(4) if __name__ == '__main__': runner = TestDistTrainer()