
[update] distributed trainer, add evaluation part

tags/v0.4.10
yunfan committed 329a18976f, 6 years ago
3 changed files with 82 additions and 22 deletions
  1. +45 -17  fastNLP/core/callback.py
  2. +12 -4   fastNLP/core/dist_trainer.py
  3. +25 -1   test/core/test_dist_trainer.py

+45 -17  fastNLP/core/callback.py

@@ -79,6 +79,7 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 from .dataset import DataSet
 from .tester import Tester
+import logging

 try:
     import fitlog
@@ -167,7 +168,11 @@ class Callback(object):
     @property
     def disabled(self):
         return self._disabled
+
+    @property
+    def logger(self):
+        return getattr(self._trainer, 'logger', logging)

     def on_train_begin(self):
         """
         Called before the training procedure begins.
@@ -316,21 +321,27 @@ class CallbackManager(Callback):
         """
         super(CallbackManager, self).__init__()
         # set attribute of trainer environment
+        self._env = env
         self.callbacks = []
-        if callbacks is not None:
-            if isinstance(callbacks, list):
-                if all([isinstance(cb, Callback) for cb in callbacks]) is True:
-                    self.callbacks.extend(callbacks)
-                else:
-                    obj = [not isinstance(cb, Callback) for cb in callbacks][0]
-                    raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
-            else:
-                raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
-        for env_name, env_val in env.items():
-            for callback in self.callbacks:
-                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        if callbacks:
+            self.callbacks += self.prepare_callbacks(callbacks)
+
+    def prepare_callbacks(self, callbacks):
+        if not callbacks:
+            return []
+        if isinstance(callbacks, list):
+            if all([isinstance(cb, Callback) for cb in callbacks]) is True:
+                self.callbacks.extend(callbacks)
+            else:
+                obj = [not isinstance(cb, Callback) for cb in callbacks][0]
+                raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+        else:
+            raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+
+        for env_name, env_val in self._env.items():
+            for callback in callbacks:
+                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
+        return callbacks

     @_transfer
     def on_train_begin(self):
@@ -391,11 +402,12 @@ class CallbackManager(Callback):


 class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
+        super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
         is_master = env['trainer'].is_master
         self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
-        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
+        self.callbacks_all = self.prepare_callbacks(callbacks_all)
+        self.callbacks_master = self.prepare_callbacks(callbacks_master)
         self.callbacks = self.callbacks_all + self.callbacks_master

     def patch_callback(self, callbacks, disabled):
@@ -944,5 +956,21 @@ class EchoCallback(Callback):


 class TesterCallback(Callback):
     def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
-        self.tester = Tester(data, model)
+        # TODO add compare & save best
+        super(TesterCallback, self).__init__()
+        self.tester = Tester(data, model,
+                             metrics=metrics, batch_size=batch_size,
+                             num_workers=num_workers, verbose=0)
+        self.score = None
+
+    def on_validation(self):
+        cur_score = self.tester.test()
+        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
+            self.epoch, self.n_epochs, self.step, self.n_steps,
+            self.tester._format_eval_results(cur_score))
+        self.logger.info(eval_str)
+
+    def on_train_end(self):
+        self.logger.info('Evaluation at the end of training.')
+        self.on_validation()
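Side note (not part of the commit): a small self-contained sketch of the two ideas this file adds. prepare_callbacks() binds every entry of the trainer environment onto each callback via setattr(callback, '_' + name, value), and the new Callback.logger property falls back to the stdlib logging module when the bound trainer has no logger of its own. ToyCallback and the free-standing prepare_callbacks below are illustrative stand-ins, not fastNLP classes.

    # illustration only -- mimics the binding and logger-fallback pattern above
    import logging

    class ToyCallback:
        def __init__(self):
            self._trainer = None          # filled in by prepare_callbacks

        @property
        def logger(self):
            # same fallback as the new Callback.logger property
            return getattr(self._trainer, 'logger', logging)

    def prepare_callbacks(env, callbacks):
        # bind every env entry (e.g. 'trainer') as a '_'-prefixed attribute
        for env_name, env_val in env.items():
            for cb in callbacks:
                setattr(cb, '_' + env_name, env_val)
        return callbacks

    cb = ToyCallback()
    prepare_callbacks({'trainer': None}, [cb])
    cb.logger.warning("no trainer logger bound, so this goes through the logging module")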

+12 -4  fastNLP/core/dist_trainer.py

@@ -11,7 +11,7 @@ import time
 from datetime import datetime, timedelta

 from .batch import DataSetIter, BatchIter
-from .callback import DistCallbackManager, CallbackException
+from .callback import DistCallbackManager, CallbackException, TesterCallback
 from .dataset import DataSet
 from .losses import _prepare_losser
 from .optimizer import Optimizer
@@ -39,10 +39,13 @@ def get_local_rank():


 class DistTrainer():
+    """Distributed Trainer that supports distributed and mixed precision training
+    """
     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
                  num_data_workers=1, drop_last=False,
+                 dev_data=None, metrics=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_every=-1, save_path=None, device='auto',
                  fp16='', backend=None, init_method=None):
@@ -107,6 +110,14 @@ class DistTrainer():
         self.data_iterator = self._get_data_iter(self.train_data)
         self.n_steps = self._get_n_steps()

+        # for evaluation, only run eval on master proc
+        if dev_data and metrics:
+            cb = TesterCallback(
+                dev_data, model, metrics,
+                batch_size=batch_size_per_gpu, num_workers=num_data_workers)
+            self.callback_manager.callbacks_master += \
+                self.callback_manager.prepare_callbacks([cb])
+
         # Setup logging
         dist.barrier()
         self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
@@ -261,9 +272,6 @@ class DistTrainer():

                 if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                         (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
-                                                                                self.n_steps)
-                    self.logger.info(eval_str)
                     self.callback_manager.on_validation()
                     dist.barrier()
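Taken together (a usage sketch, not code from this commit): when both dev_data and metrics are passed, DistTrainer builds a TesterCallback internally and appends it to callbacks_master, so evaluation runs only on the master process whenever on_validation() fires. The call below mirrors run4 in the test file further down; train_set, dev_set and model are placeholders for an existing DataSet pair and model.

    from fastNLP import AccuracyMetric
    from fastNLP.core.losses import BCELoss
    from fastNLP.core.optimizer import SGD
    from fastNLP.core.dist_trainer import DistTrainer

    # train_set, dev_set and model are assumed to exist already
    trainer = DistTrainer(
        train_set, model, optimizer=SGD(lr=0.1),
        loss=BCELoss(pred="predict", target="y"),
        batch_size_per_gpu=32, n_epochs=3,
        dev_data=dev_set,                                    # new in this commit
        metrics=AccuracyMetric(pred="predict", target="y"),  # new in this commit
        validate_every=-1,  # <0: evaluate at the end of every epoch
    )
    trainer.train()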




+25 -1  test/core/test_dist_trainer.py

@@ -13,6 +13,7 @@ import os
 import subprocess
 from argparse import ArgumentParser
 from fastNLP.core.callback import EchoCallback
+from fastNLP import AccuracyMetric


 def prepare_fake_dataset():
     mean = np.array([-3, -3])
@@ -106,15 +107,36 @@ class TestDistTrainer(unittest.TestCase):
             shutil.rmtree(self.save_path)

     def run3(self):
+        set_rng_seed(100)
         data_set, model = prepare_env()
         trainer = DistTrainer(
-            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
+            data_set, model, optimizer=None,
+            loss=BCELoss(pred="predict", target="y"),
             n_epochs=3, print_every=50,
             callbacks_all=[EchoCallback('callbacks_all')],
             callbacks_master=[EchoCallback('callbacks_master')]
         )
         trainer.train()

+    def run4(self):
+        set_rng_seed(100)
+        data_set, model = prepare_env()
+
+        train_set, dev_set = data_set.split(0.3)
+
+        model = NaiveClassifier(2, 1)
+
+        trainer = DistTrainer(
+            train_set, model, optimizer=SGD(lr=0.1),
+            loss=BCELoss(pred="predict", target="y"),
+            batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
+            metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+        )
+        trainer.train()
+        """
+        # should run correctly
+        """
+
     def run_dist(self, run_id):
         if torch.cuda.is_available():
             ngpu = min(2, torch.cuda.device_count())
@@ -133,6 +155,8 @@ class TestDistTrainer(unittest.TestCase):
     def test_callback(self):
         self.run_dist(3)

+    def test_dev_data(self):
+        self.run_dist(4)
+

 if __name__ == '__main__':
     runner = TestDistTrainer()
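The __main__ block is truncated in this view; judging from the ArgumentParser import and run_dist(run_id) above, each run* method is executed in child processes under PyTorch's distributed launcher. An equivalent manual invocation of the new evaluation test would presumably look like the line below (the --test flag and path are assumptions, not shown in this diff):

    python -m torch.distributed.launch --nproc_per_node=2 test/core/test_dist_trainer.py --test 4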

