
[update] distributed trainer

tags/v0.4.10
yunfan committed 6 years ago
commit 606d63a5a4
4 changed files with 218 additions and 109 deletions
  1. fastNLP/core/callback.py (+52, -2)
  2. fastNLP/core/dist_trainer.py (+96, -73)
  3. fastNLP/core/trainer.py (+30, -27)
  4. test/core/test_dist_trainer.py (+40, -7)

fastNLP/core/callback.py (+52, -2)

@@ -100,7 +100,8 @@ class Callback(object):
def __init__(self):
super(Callback, self).__init__()
self._trainer = None  # reassigned inside Trainer
self._disabled = False

@property
def trainer(self):
"""
@@ -158,6 +159,14 @@ class Callback(object):
def batch_per_epoch(self):
"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。"""
return self._trainer.batch_per_epoch

@property
def is_master(self):
return self._trainer.is_master

@property
def disabled(self):
return self._disabled
def on_train_begin(self):
"""
@@ -289,6 +298,8 @@ def _transfer(func):
def wrapper(manager, *arg):
returns = []
for callback in manager.callbacks:
if callback.disabled:
continue
returns.append(getattr(callback, func.__name__)(*arg))
return returns
@@ -320,7 +331,7 @@ class CallbackManager(Callback):
for env_name, env_val in env.items():
for callback in self.callbacks:
setattr(callback, '_' + env_name, env_val) # Callback.trainer
@_transfer
def on_train_begin(self):
pass
@@ -378,6 +389,24 @@ class CallbackManager(Callback):
pass


class DistCallbackManager(CallbackManager):
def __init__(self, env, callbacks_all=None, callbacks_master=None):
assert 'trainer' in env
is_master = env['trainer'].is_master
self.patch_callback(callbacks_master, disabled=not is_master)
self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
self.callbacks = self.callbacks_all + self.callbacks_master

def patch_callback(self, callbacks, disabled):
if not callbacks:
return
if not isinstance(callbacks, (list, tuple)):
callbacks = [callbacks]
for cb in callbacks:
cb._disabled = disabled


class GradientClipCallback(Callback):
"""
Alias: :class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
@@ -415,6 +444,9 @@ class GradientClipCallback(Callback):
def on_backward_end(self):
if self.step%self.update_every==0:
if self.parameters is None:
if getattr(self.trainer, 'fp16', ''):
from apex import amp
self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
else:
self.clip_fun(self.model.parameters(), self.clip_value)
else:
self.clip_fun(self.parameters, self.clip_value)
@@ -896,3 +928,21 @@ class EarlyStopError(CallbackException):
def __init__(self, msg):
super(EarlyStopError, self).__init__(msg)


class EchoCallback(Callback):
def __init__(self, name, out=sys.stdout):
super(EchoCallback, self).__init__()
self.name = name
self.out = out

def __getattribute__(self, item):
if item.startswith('on_'):
print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
file=self.out)
return super(EchoCallback, self).__getattribute__(item)


class TesterCallback(Callback):
def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
self.tester = Tester(data, model)
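As a quick orientation for the new callback split, here is a minimal usage sketch; it is not part of this commit, and `PrintCallback` / `FakeTrainer` are invented stand-ins for a real callback and a `DistTrainer`:

```python
from fastNLP.core.callback import Callback, DistCallbackManager

class PrintCallback(Callback):
    """Hypothetical callback used only for this illustration."""
    def on_train_begin(self):
        print('on_train_begin fired on this rank')

class FakeTrainer:
    """Hypothetical stand-in exposing only what DistCallbackManager reads."""
    def __init__(self, rank):
        self.rank = rank

    @property
    def is_master(self):
        return self.rank == 0

manager = DistCallbackManager(
    env={'trainer': FakeTrainer(rank=1)},    # pretend we are a non-master worker
    callbacks_all=[PrintCallback()],         # dispatched on every worker
    callbacks_master=[PrintCallback()],      # patched with _disabled=True on this rank
)
manager.on_train_begin()  # the _transfer wrapper skips the disabled master-only callback
```

On rank 0 both callbacks would fire, since nothing gets disabled there.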

fastNLP/core/dist_trainer.py (+96, -73)

@@ -11,7 +11,7 @@ import time
from datetime import datetime, timedelta

from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .callback import DistCallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
@@ -39,18 +39,36 @@ def get_local_rank():


class DistTrainer():
def __init__(self, model, train_data, optimizer, loss, callbacks=None,
def __init__(self, train_data, model, optimizer=None, loss=None,
callbacks_all=None, callbacks_master=None,
batch_size_per_gpu=8, n_epochs=1,
num_workers=1, drop_last=False,
num_data_workers=1, drop_last=False,
update_every=1, print_every=10, validate_every=-1,
save_every=-1, save_path=None,
logging_level=logging.INFO,
fp16='', backend='nccl', init_method=None):
save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None):

assert device in ['auto', 'cuda', 'cpu'], "Please set device to one of ['auto', 'cuda', 'cpu']"
if device == 'auto':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if backend is None:
backend = 'nccl' if device == 'cuda' else 'gloo'

# init distributed
if device == 'cuda':
torch.cuda.set_device(get_local_rank())
self.device = torch.device("cuda", get_local_rank())
else:
self.device = torch.device(device)

dist.init_process_group(backend=backend, init_method=init_method)
self.world_size = dist.get_world_size()
self.rank = dist.get_rank() # unique id for each process

self.model = model
self.train_data = train_data
self.batch_size_per_gpu = int(batch_size_per_gpu)
self.n_epochs = int(n_epochs)
self.num_workers = int(num_workers)
self.num_data_workers = int(num_data_workers)
self.drop_last = drop_last
self.update_every = int(update_every)
self.print_every = int(print_every)
@@ -62,16 +80,13 @@ class DistTrainer():
self.init_method = init_method
self.backend = backend
self.local_rank = get_local_rank()
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
self._forward_func = model.forward
self.callback_manager = DistCallbackManager(
env={"trainer": self}, callbacks_all=callbacks_all,
callbacks_master=callbacks_master)

assert torch.cuda.is_available(), "Distributed Trainer requires cuda to be enabled."
# init distributed
torch.cuda.set_device(self.local_rank)
self.device = torch.device("cuda", self.local_rank)
dist.init_process_group(backend=self.backend, init_method=self.init_method)
model.to(self.device)
optimizer = self.get_optimizer(optimizer)
optimizer = self._get_optimizer(optimizer)

# init fp16, must before DataParallel init
if len(self.fp16):
@@ -81,51 +96,48 @@ class DistTrainer():
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
assert device == 'cuda', "Amp requires cuda device"
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)

# init DataParallel
self.model = DDP(model, device_ids=[self.local_rank],
output_device=self.local_rank)
self.optimizer = optimizer
self.world_size = dist.get_world_size()
self.rank = dist.get_rank() # unique id for each process
self.sampler = DistributedSampler(self.train_data)
self.data_iterator = self.get_data_iter(self.train_data)
self.n_steps = self.get_n_steps()
self.data_iterator = self._get_data_iter(self.train_data)
self.n_steps = self._get_n_steps()

# Setup logging
dist.barrier()
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
if self.save_path:
self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
else:
self.cp_save_path = None

# use INFO in the master, WARN for others
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging_level)
level=logging.INFO if self.is_master else logging.WARN)
self.logger = logging.getLogger(__name__)
self.logger.info("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
self.logger.info("Setup Distributed Trainer")
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
if self.is_master:
self.logger.info('Total epochs: %d'% self.n_epochs)
self.logger.info('Total steps: %d'% self.n_steps)
self.logger.info('Num instances per GPU: %d' % self.batch_size_per_gpu)
self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
self.logger.info('Total num of samples: %d'% len(self.train_data))
self.logger.info("Num of callbacks: {}".format(len(self.callback_manager.callbacks)))
self.logger.info(
"Use callbacks: {}".format([repr(cb) for cb in self.callback_manager.callbacks]))

# only master process save model
if self.save_path:
self.save_path = os.path.join(
self.save_path,
datetime.now().strftime('%m_%d_%y-%H_%M_%S')+'-'+str(os.getpid()))
self.logger.info("Num of processes: {}".format(self.world_size))
self.logger.info("Use device: {}".format(device))
self.logger.info("Training with fp16: {}, optimization level: {}".format(
len(self.fp16) > 0, self.fp16 if self.fp16 else None))

def get_n_steps(self):
def _get_n_steps(self):
batch_size = self.world_size * self.batch_size_per_gpu
return (len(self.train_data) // batch_size + int(
len(self.train_data) % batch_size != 0) * int(self.drop_last == 0)) * self.n_epochs

def get_data_iter(self, dataset):
def _get_data_iter(self, dataset):
if isinstance(dataset, DataSet):
return DataSetIter(
dataset=dataset, batch_size=self.batch_size_per_gpu,
num_workers=self.num_workers, sampler=self.sampler,
num_workers=self.num_data_workers, sampler=self.sampler,
drop_last=self.drop_last
)
elif isinstance(dataset, BatchIter):
@@ -133,7 +145,7 @@ class DistTrainer():
else:
raise TypeError("train_data type {} is not supported".format(type(dataset)))

def get_optimizer(self, optimizer):
def _get_optimizer(self, optimizer):
if isinstance(optimizer, torch.optim.Optimizer):
return optimizer
elif isinstance(optimizer, Optimizer):
@@ -148,37 +160,50 @@ class DistTrainer():
return self.rank == 0

def train(self, on_exception='auto'):
start_time = time.time()
results = {}
if self.n_epochs <= 0:
if self.is_master:
self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
results['seconds'] = 0.
return results

if self.is_master:
try:
self.logger.info("###### Training epochs started ######")
self.logger.info('Total epochs: %d'% self.n_epochs)
self.logger.info('Total steps: %d'% self.n_steps)
self.logger.info('Num instances per GPU: %d' % self.batch_size_per_gpu)
self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
self.logger.info('Total num of samples: %d'% len(self.train_data))
self.logger.info("Num of callbacks for all workers: {}".format(
len(self.callback_manager.callbacks_all)))
self.logger.info("Num of callbacks for master workers: {}".format(
len(self.callback_manager.callbacks_master)))
self.logger.info("Callbacks for all workers: {}".format(
[repr(cb) for cb in self.callback_manager.callbacks_all]))
self.logger.info("Callbacks for master workers: {}".format(
[repr(cb) for cb in self.callback_manager.callbacks_master]))

start_time = time.time()
results = {}
if self.n_epochs <= 0:
self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
results['seconds'] = 0.
return results

try:
self.callback_manager.on_train_begin()
self._train()
self.callback_manager.on_train_end()

except BaseException as e:
self.callback_manager.on_exception(e)
if on_exception == 'auto':
if not isinstance(e, (CallbackException, KeyboardInterrupt)):
try:
self.callback_manager.on_train_begin()
self._train()
self.callback_manager.on_train_end()

except BaseException as e:
self.callback_manager.on_exception(e)
if on_exception == 'auto':
if not isinstance(e, (CallbackException, KeyboardInterrupt)):
raise e
else:
self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
elif on_exception == 'raise':
raise e
else:
self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
elif on_exception == 'raise':
raise e

results['seconds'] = round(time.time() - start_time, 2)
if self.is_master:
results['seconds'] = round(time.time() - start_time, 2)
self.logger.info("###### Train finished ######")
self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
return results
return results
finally:
self.close()

def _train(self):
if self.fp16:
@@ -187,7 +212,7 @@ class DistTrainer():
self.step = 0
self.epoch = 0
self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
leave=False, dynamic_ncols=True, disable=not self.is_master)
leave=False, dynamic_ncols=True, disable=not self.is_master)
pbar = self.pbar
avg_loss = 0
data_iterator = self.data_iterator
@@ -238,18 +263,17 @@ class DistTrainer():
(self.validate_every < 0 and self.step % len(data_iterator) == 0)):
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
self.n_steps)
if self.is_master:
self.logger.info(eval_str)
self.logger.info(eval_str)
self.callback_manager.on_validation()
dist.barrier()

if self.save_path and \
if self.cp_save_path and \
self.save_every > 0 and \
self.step % self.save_every == 0:
self.save_check_point()

# ================= mini-batch end ==================== #
if self.save_path and self.save_every < 0:
if self.save_every < 0 and self.cp_save_path:
self.save_check_point()
# lr decay; early stopping
self.callback_manager.on_epoch_end()
@@ -287,16 +311,15 @@ class DistTrainer():
return loss.mean()

def save_check_point(self, only_params=False):
# only master save models
if self.is_master:
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
path = os.path.join(self.save_path, 'checkpoint-{}.bin'.format(self.step))
os.makedirs(self.cp_save_path, exist_ok=True)
path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
self.logger.info("Save checkpoint to {}".format(path))
model_to_save = self.model.module
if only_params:
model_to_save = model_to_save.state_dict()
torch.save(model_to_save, path)
dist.barrier()

def close(self):
dist.destroy_process_group()
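For context, a hedged sketch of how the reworked constructor is meant to be driven; it is not part of the diff, mirrors the `run3` test case added below, and must be started through `torch.distributed.launch` rather than run directly. The toy data and model follow the test file; the default-optimizer behaviour for `optimizer=None` is assumed from that test:

```python
# launch with: python -m torch.distributed.launch --nproc_per_node 2 this_script.py
from fastNLP import DataSet, Instance, BCELoss
from fastNLP.core.callback import EchoCallback
from fastNLP.core.dist_trainer import DistTrainer
from fastNLP.models.base_model import NaiveClassifier

# toy binary-classification data, as in the test file's prepare_env()
data_set = DataSet([Instance(x=[0.0, 1.0], y=[0.0]), Instance(x=[1.0, 0.0], y=[1.0])] * 50)
data_set.set_input('x')
data_set.set_target('y')

trainer = DistTrainer(
    data_set, NaiveClassifier(2, 1),
    optimizer=None,                             # as in run3; a default optimizer is presumably created
    loss=BCELoss(pred='predict', target='y'),
    n_epochs=3, print_every=50,
    callbacks_all=[EchoCallback('all')],        # run on every worker
    callbacks_master=[EchoCallback('master')],  # run only on rank 0
    device='auto',                              # 'cuda' + nccl if available, else 'cpu' + gloo
)
trainer.train()
```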

fastNLP/core/trainer.py (+30, -27)

@@ -431,13 +431,13 @@ class Trainer(object):
super(Trainer, self).__init__()
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")
# check update every
assert update_every >= 1, "update_every must be no less than 1."
self.update_every = int(update_every)
@@ -447,7 +447,7 @@ class Trainer(object):
raise ValueError("save_path can only be None or `str`.")
# prepare evaluate
metrics = _prepare_metrics(metrics)
# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
@@ -546,7 +546,7 @@ class Trainer(object):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
else:
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
self.use_tqdm = use_tqdm
self.pbar = None
self.print_every = abs(self.print_every)
@@ -558,10 +558,10 @@ class Trainer(object):
batch_size=self.batch_size,
device=None,  # device is already handled above
verbose=0)
self.step = 0
self.start_time = None # start timestamp
self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)

@@ -597,7 +597,7 @@ class Trainer(object):
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
start_time = time.time()
print("training epochs started " + self.start_time, flush=True)
try:
self.callback_manager.on_train_begin()
self._train()
@@ -610,7 +610,7 @@ class Trainer(object):
raise e
elif on_exception == 'raise':
raise e
if self.dev_data is not None and self.best_dev_perf is not None:
print(
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -628,9 +628,9 @@ class Trainer(object):
finally:
pass
results['seconds'] = round(time.time() - start_time, 2)
return results
def _train(self):
if not self.use_tqdm:
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -656,21 +656,21 @@ class Trainer(object):
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
prediction = self._data_forward(self.model, batch_x)
# edit prediction
self.callback_manager.on_loss_begin(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y).mean()
avg_loss += loss.item()
loss = loss / self.update_every
# Is loss NaN or inf? requires_grad = False
self.callback_manager.on_backward_begin(loss)
self._grad_backward(loss)
self.callback_manager.on_backward_end()
self._update()
self.callback_manager.on_step_end()
if self.step % self.print_every == 0:
avg_loss = float(avg_loss) / self.print_every
if self.use_tqdm:
@@ -684,7 +684,7 @@ class Trainer(object):
pbar.set_postfix_str(print_output)
avg_loss = 0
self.callback_manager.on_batch_end()
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
and self.dev_data is not None:
@@ -693,20 +693,20 @@ class Trainer(object):
self.n_steps) + \
self.tester._format_eval_results(eval_res)
pbar.write(eval_str + '\n')
# ================= mini-batch end ==================== #
# lr decay; early stopping
self.callback_manager.on_epoch_end()
# =============== epochs end =================== #
pbar.close()
self.pbar = None
# ============ tqdm end ============== #
def _do_validation(self, epoch, step):
self.callback_manager.on_valid_begin()
res = self.tester.test()
is_better_eval = False
if self._better_eval_result(res):
if self.save_path is not None:
@@ -721,7 +721,7 @@ class Trainer(object):
# get validation results; adjust optimizer
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
return res
def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.

@@ -733,14 +733,14 @@ class Trainer(object):
model.eval()
else:
model.train()
def _update(self):
"""Perform weight update on a model.

"""
if self.step % self.update_every == 0:
self.optimizer.step()
def _data_forward(self, network, x):
x = _build_args(self._forward_func, **x)
y = network(**x)
@@ -748,7 +748,7 @@ class Trainer(object):
raise TypeError(
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
return y
def _grad_backward(self, loss):
"""Compute gradient with link rules.

@@ -759,7 +759,7 @@ class Trainer(object):
if (self.step-1) % self.update_every == 0:
self.model.zero_grad()
loss.backward()
def _compute_loss(self, predict, truth):
"""Compute loss given prediction and ground truth.

@@ -768,7 +768,7 @@ class Trainer(object):
:return: a scalar
"""
return self.losser(predict, truth)
def _save_model(self, model, model_name, only_param=False):
""" 存储不含有显卡信息的state_dict或model
:param model:
@@ -791,7 +791,7 @@ class Trainer(object):
model.cpu()
torch.save(model, model_path)
model.to(self._model_device)
def _load_model(self, model, model_name, only_param=False):
# returns a bool indicating whether the model was successfully reloaded
if self.save_path is not None:
@@ -809,7 +809,7 @@ class Trainer(object):
else:
return False
return True
def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.

@@ -835,6 +835,9 @@ class Trainer(object):
is_better = False
return is_better

@property
def is_master(self):
return True

DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
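The visible functional change to the single-process Trainer above is the new `is_master` property, which always returns True; it lets the same callback code run unchanged under either trainer. A minimal sketch (the callback name is made up):

```python
from fastNLP.core.callback import Callback

class MasterOnlyReport(Callback):
    """Hypothetical callback: only reports where is_master holds."""
    def on_epoch_end(self):
        # always True under the plain Trainer; only on rank 0 under DistTrainer
        if self.is_master:
            print('reached step {}'.format(self.step))
```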


test/core/test_dist_trainer.py (+40, -7)

@@ -4,7 +4,7 @@ import numpy as np
import torch.cuda
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import CrossEntropyLoss
from fastNLP import CrossEntropyLoss, BCELoss
from fastNLP import SGD
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
from fastNLP.models.base_model import NaiveClassifier
@@ -12,6 +12,7 @@ import shutil
import os
import subprocess
from argparse import ArgumentParser
from fastNLP.core.callback import EchoCallback

def prepare_fake_dataset():
mean = np.array([-3, -3])
@@ -36,6 +37,26 @@ def prepare_fake_dataset2(*args, size=100):
def set_rng_seed(seed):
np.random.seed(seed)

def prepare_env():
def prepare_fake_dataset():
mean = np.array([-3, -3])
cov = np.array([[1, 0], [0, 1]])
class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

mean = np.array([3, 3])
cov = np.array([[1, 0], [0, 1]])
class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
return data_set

data_set = prepare_fake_dataset()
data_set.set_input("x")
data_set.set_target("y")
model = NaiveClassifier(2, 1)
return data_set, model

class TestDistTrainer(unittest.TestCase):
save_path = './save_cp'

@@ -84,23 +105,35 @@ class TestDistTrainer(unittest.TestCase):
if trainer.is_master and os.path.exists(self.save_path):
shutil.rmtree(self.save_path)

def run3(self):
data_set, model = prepare_env()
trainer = DistTrainer(
data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
n_epochs=3, print_every=50,
callbacks_all=[EchoCallback('callbacks_all')],
callbacks_master=[EchoCallback('callbacks_master')]
)
trainer.train()

def run_dist(self, run_id):
if torch.cuda.is_available():
ngpu = min(4, torch.cuda.device_count())
ngpu = min(2, torch.cuda.device_count())
path = __file__
cmd = ['python', '-m', 'torch.distributed.launch',
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
print(' '.join(cmd))
retcode = subprocess.call(cmd)
if retcode:
raise RuntimeError('subprocess got non-zero exit status %d' % retcode)
subprocess.check_call(cmd, timeout=60.0)

def test1(self):
def test_normal_run(self):
self.run_dist(1)

def test2(self):
def test_fp16(self):
self.run_dist(2)

def test_callback(self):
self.run_dist(3)


if __name__ == '__main__':
runner = TestDistTrainer()
parser = ArgumentParser()
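For reference, launching the new callback case by hand is roughly equivalent to what `run_dist(3)` does above; the `--test` flag is presumably parsed by the ArgumentParser set up in the truncated `__main__` block:

```python
import subprocess
import sys

# same command run_dist(3) builds: 2 processes, test case 3 (callbacks_all / callbacks_master)
cmd = [sys.executable, '-m', 'torch.distributed.launch',
       '--nproc_per_node', '2',
       'test/core/test_dist_trainer.py', '--test', '3']
print(' '.join(cmd))
subprocess.check_call(cmd, timeout=60.0)
```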

