@@ -79,6 +79,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
+import logging
 
 try:
     import fitlog
@@ -167,7 +168,11 @@ class Callback(object):
     @property
     def disabled(self):
         return self._disabled
+
+    @property
+    def logger(self):
+        return getattr(self._trainer, 'logger', logging)
 
     def on_train_begin(self):
         """
         Called before the training process begins.
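A note on the new `logger` property: if the callback is not yet bound to a trainer, or the trainer has no `logger` attribute, `getattr` falls back to the stdlib `logging` module, so `self.logger.info(...)` is always safe to call from a hook. A minimal sketch of the fallback, assuming `Callback.__init__` leaves `_trainer` as None:

    import logging
    from fastNLP.core.callback import Callback

    cb = Callback()              # nothing bound yet, so cb._trainer is None
    assert cb.logger is logging  # getattr(None, 'logger', logging)
    cb.logger.info("messages go through the stdlib logging module")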
@@ -316,21 +321,27 @@ class CallbackManager(Callback):
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
+        self._env = env
 
         self.callbacks = []
-        if callbacks is not None:
-            if isinstance(callbacks, list):
-                if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                    self.callbacks.extend(callbacks)
-                else:
-                    obj = [not isinstance(cb, Callback) for cb in callbacks][0]
-                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
-            else:
-                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
-
-        for env_name, env_val in env.items():
-            for callback in self.callbacks:
-                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        if callbacks:
+            self.callbacks += self.prepare_callbacks(callbacks)
+
+    def prepare_callbacks(self, callbacks):
+        if not callbacks:
+            return []
+        if isinstance(callbacks, list):
+            if all([isinstance(cb, Callback) for cb in callbacks]) is True:
+                pass
+            else:
+                obj = [not isinstance(cb, Callback) for cb in callbacks][0]
+                raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        else:
+            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+
+        for env_name, env_val in self._env.items():
+            for callback in callbacks:
+                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        return callbacks
 
     @_transfer
     def on_train_begin(self):
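For review context, a hedged sketch of the `prepare_callbacks` contract seen from the caller's side; `MyCallback` is a hypothetical stand-in and the env values are illustrative:

    from fastNLP.core.callback import Callback, CallbackManager

    class MyCallback(Callback):
        pass

    # prepare_callbacks type-checks the list, copies every env entry onto each
    # callback as an underscore-prefixed attribute (so self._trainer resolves
    # inside the hooks), and returns the same, now-bound list.
    manager = CallbackManager(env={'trainer': None}, callbacks=[MyCallback()])
    late = manager.prepare_callbacks([MyCallback()])  # bind callbacks added later
    manager.callbacks += late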
@@ -391,11 +402,12 @@ class CallbackManager(Callback):
 class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
         super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
         is_master = env['trainer'].is_master
         self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
-        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
+        self.callbacks_all = self.prepare_callbacks(callbacks_all)
+        self.callbacks_master = self.prepare_callbacks(callbacks_master)
         self.callbacks = self.callbacks_all + self.callbacks_master
 
     def patch_callback(self, callbacks, disabled):
@@ -944,5 +956,21 @@ class EchoCallback(Callback):
 
 class TesterCallback(Callback):
     def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
-        self.tester = Tester(data, model)
+        # TODO add compare & save best
+        super(TesterCallback, self).__init__()
+        self.tester = Tester(data, model,
+                             metrics=metrics, batch_size=batch_size,
+                             num_workers=num_workers, verbose=0)
+        self.score = None
+
+    def on_validation(self):
+        cur_score = self.tester.test()
+        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
+            self.epoch, self.n_epochs, self.step, self.n_steps,
+            self.tester._format_eval_results(cur_score))
+        self.logger.info(eval_str)
+
+    def on_train_end(self):
+        self.logger.info('Evaluation at the end of training.')
+        self.on_validation()
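The hunks below are in the distributed trainer module. For intuition, a hedged standalone equivalent of what one `TesterCallback.on_validation` call does; `dev_data`, `model`, and `metric` are assumed to exist already:

    from fastNLP import Tester

    tester = Tester(dev_data, model, metrics=metric,
                    batch_size=16, num_workers=None, verbose=0)
    res = tester.test()                      # dict of metric name -> results
    print(tester._format_eval_results(res))  # the string the callback logs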
@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timedelta
 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException
+from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
@@ -39,10 +39,13 @@ def get_local_rank():
 class DistTrainer():
     """Distributed Trainer that supports distributed and mixed precision training
     """
     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
                  num_data_workers=1, drop_last=False,
+                 dev_data=None, metrics=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_every=-1, save_path=None, device='auto',
                  fp16='', backend=None, init_method=None):
@@ -107,6 +110,14 @@ class DistTrainer():
         self.data_iterator = self._get_data_iter(self.train_data)
         self.n_steps = self._get_n_steps()
 
+        # for evaluation, only run eval on master proc
+        if dev_data and metrics:
+            cb = TesterCallback(
+                dev_data, model, metrics,
+                batch_size=batch_size_per_gpu, num_workers=num_data_workers)
+            self.callback_manager.callbacks_master += \
+                self.callback_manager.prepare_callbacks([cb])
+
         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
@@ -261,9 +272,6 @@ class DistTrainer():
                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                         (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                                self.n_steps)
-                    self.logger.info(eval_str)
                     self.callback_manager.on_validation()
                     dist.barrier()
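The deleted lines are not lost: the log message now comes from `TesterCallback.on_validation`. The trigger condition itself is unchanged; spelled out as a hedged helper with illustrative names:

    def should_validate(step, validate_every, steps_per_epoch):
        # validate_every > 0: evaluate every `validate_every` steps;
        # validate_every < 0: evaluate once per epoch, on its last batch.
        if validate_every > 0:
            return step % validate_every == 0
        return step % steps_per_epoch == 0

The remaining hunks update the distributed-trainer tests accordingly.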
@@ -13,6 +13,7 @@ import os
 import subprocess
 from argparse import ArgumentParser
 from fastNLP.core.callback import EchoCallback
+from fastNLP import AccuracyMetric
 
 def prepare_fake_dataset():
     mean = np.array([-3, -3])
@@ -106,15 +107,36 @@ class TestDistTrainer(unittest.TestCase):
         shutil.rmtree(self.save_path)
 
     def run3(self):
         set_rng_seed(100)
         data_set, model = prepare_env()
         trainer = DistTrainer(
-            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
+            data_set, model, optimizer=None,
+            loss=BCELoss(pred="predict", target="y"),
             n_epochs=3, print_every=50,
             callbacks_all=[EchoCallback('callbacks_all')],
             callbacks_master=[EchoCallback('callbacks_master')]
         )
         trainer.train()
 
+    def run4(self):
+        set_rng_seed(100)
+        data_set, model = prepare_env()
+        train_set, dev_set = data_set.split(0.3)
+        model = NaiveClassifier(2, 1)
+        trainer = DistTrainer(
+            train_set, model, optimizer=SGD(lr=0.1),
+            loss=BCELoss(pred="predict", target="y"),
+            batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
+            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+        )
+        trainer.train()
+        # should run correctly
+
     def run_dist(self, run_id):
         if torch.cuda.is_available():
             ngpu = min(2, torch.cuda.device_count())
@@ -133,6 +155,8 @@ class TestDistTrainer(unittest.TestCase):
     def test_callback(self):
         self.run_dist(3)
 
+    def test_dev_data(self):
+        self.run_dist(4)
+
 if __name__ == '__main__':
     runner = TestDistTrainer()