@@ -79,6 +79,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
+import logging
 
 try:
     import fitlog
@@ -167,7 +168,11 @@ class Callback(object):
     @property
     def disabled(self):
         return self._disabled
 
+    @property
+    def logger(self):
+        return getattr(self._trainer, 'logger', logging)
+
     def on_train_begin(self):
         """
         Called before the training process begins.
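The new `logger` property resolves to the owning trainer's logger once the callback is bound, and falls back to the stdlib `logging` module before that. A minimal sketch (not part of this patch) of what that enables:

```python
# Hedged sketch: a user callback can log whether or not a trainer has
# injected `_trainer` yet, because getattr(None, 'logger', logging)
# returns the stdlib logging module.
import logging
from fastNLP.core.callback import Callback

logging.basicConfig(level=logging.INFO)

class VerboseCallback(Callback):
    def on_train_begin(self):
        # inside a trainer this is trainer.logger; standalone it is `logging`
        self.logger.info("training started")

cb = VerboseCallback()
cb.logger.info("usable even before a trainer attaches")  # falls back to logging
```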
@@ -316,21 +321,27 @@ class CallbackManager(Callback):
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
-
+        self._env = env
         self.callbacks = []
-        if callbacks is not None:
-            if isinstance(callbacks, list):
-                if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                    self.callbacks.extend(callbacks)
-                else:
-                    obj = [not isinstance(cb, Callback) for cb in callbacks][0]
-                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        if callbacks:
+            self.callbacks += self.prepare_callbacks(callbacks)
+
+    def prepare_callbacks(self, callbacks):
+        if not callbacks:
+            return []
+        if isinstance(callbacks, list):
+            if all([isinstance(cb, Callback) for cb in callbacks]) is True:
+                pass
             else:
-                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+                obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
+                raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        else:
+            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
 
-        for env_name, env_val in env.items():
-            for callback in self.callbacks:
+        for env_name, env_val in self._env.items():
+            for callback in callbacks:
                 setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        return callbacks
 
     @_transfer
     def on_train_begin(self):
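Storing `env` on the instance and factoring the validation/injection logic into `prepare_callbacks` lets callbacks be registered after construction with the same checks applied. A hedged sketch of the injection contract (the dummy trainer is illustrative, not from the patch):

```python
# Every key in `env` becomes an underscore-prefixed attribute on each
# callback: env['trainer'] -> callback._trainer, and so on.
from fastNLP.core.callback import Callback, CallbackManager

class DummyTrainer:          # illustrative stand-in for a real Trainer
    pass

cb = Callback()
CallbackManager(env={'trainer': DummyTrainer()}, callbacks=[cb])
assert isinstance(cb._trainer, DummyTrainer)  # injected by prepare_callbacks
```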
@@ -391,11 +402,12 @@ class CallbackManager(Callback):
 
 class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
+        super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
         is_master = env['trainer'].is_master
         self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
-        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
+        self.callbacks_all = self.prepare_callbacks(callbacks_all)
+        self.callbacks_master = self.prepare_callbacks(callbacks_master)
         self.callbacks = self.callbacks_all + self.callbacks_master
 
     def patch_callback(self, callbacks, disabled):
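With `prepare_callbacks` on the base class, `DistCallbackManager` no longer builds throwaway `CallbackManager` instances; it validates and binds each list directly. An illustrative sketch (the fake trainer only mimics the `.is_master` attribute the constructor asserts on):

```python
from types import SimpleNamespace
from fastNLP.core.callback import DistCallbackManager, EchoCallback

fake_trainer = SimpleNamespace(is_master=True)   # stand-in, not a real DistTrainer
mgr = DistCallbackManager(
    env={'trainer': fake_trainer},
    callbacks_all=[EchoCallback('everyone')],    # run in every process
    callbacks_master=[EchoCallback('rank-0')],   # disabled on non-master ranks
)
assert len(mgr.callbacks) == 2                   # all + master, concatenated
```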
@@ -944,5 +956,21 @@ class EchoCallback(Callback): | |||||
class TesterCallback(Callback): | class TesterCallback(Callback): | ||||
def __init__(self, data, model, metrics, batch_size=16, num_workers=None): | |||||
self.tester = Tester(data, model) | |||||
def __init__(self, data, model, metrics, batch_size=16, num_workers=None):\ | |||||
#TODO add compare & save best | |||||
super(TesterCallback, self).__init__() | |||||
self.tester = Tester(data, model, | |||||
metrics=metrics, batch_size=batch_size, | |||||
num_workers=num_workers, verbose=0) | |||||
self.score = None | |||||
def on_validation(self): | |||||
cur_socre = self.tester.test() | |||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format( | |||||
self.epoch, self.n_epochs, self.step, self.n_steps, | |||||
self.tester._format_eval_results(cur_socre)) | |||||
self.logger.info(eval_str) | |||||
def on_train_end(self): | |||||
self.logger.info('Evaluate on training ends.') | |||||
self.on_validation() |
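`TesterCallback` now owns a fully configured `Tester` and reports through `self.logger`, with one final evaluation pass at `on_train_end`. Hand-wiring it looks roughly like the sketch below; `dev_set` and `model` are placeholders for any evaluation `DataSet` and trained model, and the metric fields are illustrative:

```python
from fastNLP import AccuracyMetric
from fastNLP.core.callback import TesterCallback

def make_eval_callback(dev_set, model):
    # dev_set: a DataSet with 'predict'/'y' fields; model: any trained model
    return TesterCallback(
        dev_set, model,
        metrics=AccuracyMetric(pred="predict", target="y"),
        batch_size=16, num_workers=1)   # batch_size mirrors the default above
```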
@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timedelta
 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException
+from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
@@ -39,10 +39,13 @@ def get_local_rank(): | |||||
class DistTrainer(): | class DistTrainer(): | ||||
"""Distributed Trainer that support distributed and mixed precision training | |||||
""" | |||||
def __init__(self, train_data, model, optimizer=None, loss=None, | def __init__(self, train_data, model, optimizer=None, loss=None, | ||||
callbacks_all=None, callbacks_master=None, | callbacks_all=None, callbacks_master=None, | ||||
batch_size_per_gpu=8, n_epochs=1, | batch_size_per_gpu=8, n_epochs=1, | ||||
num_data_workers=1, drop_last=False, | num_data_workers=1, drop_last=False, | ||||
dev_data=None, metrics=None, | |||||
update_every=1, print_every=10, validate_every=-1, | update_every=1, print_every=10, validate_every=-1, | ||||
save_every=-1, save_path=None, device='auto', | save_every=-1, save_path=None, device='auto', | ||||
fp16='', backend=None, init_method=None): | fp16='', backend=None, init_method=None): | ||||
@@ -107,6 +110,14 @@ class DistTrainer(): | |||||
self.data_iterator = self._get_data_iter(self.train_data) | self.data_iterator = self._get_data_iter(self.train_data) | ||||
self.n_steps = self._get_n_steps() | self.n_steps = self._get_n_steps() | ||||
# for evaluation, only run eval on master proc | |||||
if dev_data and metrics: | |||||
cb = TesterCallback( | |||||
dev_data, model, metrics, | |||||
batch_size=batch_size_per_gpu, num_workers=num_data_workers) | |||||
self.callback_manager.callbacks_master += \ | |||||
self.callback_manager.prepare_callbacks([cb]) | |||||
# Setup logging | # Setup logging | ||||
dist.barrier() | dist.barrier() | ||||
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') | self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') | ||||
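From the user's side, evaluation is now switched on simply by passing `dev_data` and `metrics`; the trainer appends the resulting `TesterCallback` to `callbacks_master`, so only the master process evaluates. A hedged sketch mirroring the new test added below (argument values are illustrative):

```python
from fastNLP import AccuracyMetric, BCELoss, SGD
from fastNLP.core.dist_trainer import DistTrainer

def train_with_eval(train_set, dev_set, model):
    # dev_data + metrics together enable master-side evaluation
    trainer = DistTrainer(
        train_set, model, optimizer=SGD(lr=0.1),
        loss=BCELoss(pred="predict", target="y"),
        dev_data=dev_set,
        metrics=AccuracyMetric(pred="predict", target="y"),
        validate_every=-1)                 # -1: validate at each epoch end
    trainer.train()
```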
@@ -261,9 +272,6 @@ class DistTrainer(): | |||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | ||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)): | (self.validate_every < 0 and self.step % len(data_iterator) == 0)): | ||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
self.n_steps) | |||||
self.logger.info(eval_str) | |||||
self.callback_manager.on_validation() | self.callback_manager.on_validation() | ||||
dist.barrier() | dist.barrier() | ||||
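The ad-hoc log line is removed because `on_validation` callbacks (including the new `TesterCallback`) now own the reporting. The trigger condition itself is unchanged: a positive `validate_every` validates every that many update steps, a negative one validates at each epoch boundary. Extracted as a standalone predicate purely for clarity (illustrative, not part of the patch):

```python
def should_validate(step, validate_every, steps_per_epoch):
    # > 0: every `validate_every` update steps; < 0: at each epoch boundary
    return ((validate_every > 0 and step % validate_every == 0) or
            (validate_every < 0 and step % steps_per_epoch == 0))
```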
@@ -13,6 +13,7 @@ import os
 import subprocess
 from argparse import ArgumentParser
 from fastNLP.core.callback import EchoCallback
+from fastNLP import AccuracyMetric
 
 def prepare_fake_dataset():
     mean = np.array([-3, -3])
@@ -106,15 +107,36 @@ class TestDistTrainer(unittest.TestCase):
             shutil.rmtree(self.save_path)
 
     def run3(self):
+        set_rng_seed(100)
         data_set, model = prepare_env()
         trainer = DistTrainer(
-            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
+            data_set, model, optimizer=None,
+            loss=BCELoss(pred="predict", target="y"),
             n_epochs=3, print_every=50,
             callbacks_all=[EchoCallback('callbacks_all')],
             callbacks_master=[EchoCallback('callbacks_master')]
         )
         trainer.train()
 
+    def run4(self):
+        set_rng_seed(100)
+        data_set, model = prepare_env()
+        train_set, dev_set = data_set.split(0.3)
+        model = NaiveClassifier(2, 1)
+        trainer = DistTrainer(
+            train_set, model, optimizer=SGD(lr=0.1),
+            loss=BCELoss(pred="predict", target="y"),
+            batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
+            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+        )
+        trainer.train()
+        # should run without errors
+
     def run_dist(self, run_id):
         if torch.cuda.is_available():
             ngpu = min(2, torch.cuda.device_count())
@@ -133,6 +155,8 @@ class TestDistTrainer(unittest.TestCase):
     def test_callback(self):
         self.run_dist(3)
 
+    def test_dev_data(self):
+        self.run_dist(4)
+
 if __name__ == '__main__':
     runner = TestDistTrainer()
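For context, `run_dist` (its body is outside these hunks) re-launches this script under `torch.distributed.launch`. A hedged sketch of what that presumably looks like; the `--test` flag is assumed to be this script's own argument, parsed via the `ArgumentParser` imported above:

```python
# Hedged sketch of the launcher logic; run_dist's real body is not in this diff.
import os
import subprocess
import torch

def run_dist_sketch(run_id):
    ngpu = min(2, torch.cuda.device_count())
    cmd = ['python', '-m', 'torch.distributed.launch',
           '--nproc_per_node', str(ngpu),
           os.path.abspath(__file__), '--test', str(run_id)]  # --test: assumed flag
    subprocess.check_call(cmd)
```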