
[update] distributed trainer

tags/v0.4.10
yunfan committed 6 years ago · commit 606d63a5a4
4 changed files with 218 additions and 109 deletions:
  1. fastNLP/core/callback.py (+52, -2)
  2. fastNLP/core/dist_trainer.py (+96, -73)
  3. fastNLP/core/trainer.py (+30, -27)
  4. test/core/test_dist_trainer.py (+40, -7)

fastNLP/core/callback.py (+52, -2)

@@ -100,7 +100,8 @@ class Callback(object):
    def __init__(self):
        super(Callback, self).__init__()
        self._trainer = None  # reassigned inside Trainer
+        self._disabled = False

    @property
    def trainer(self):
        """
@@ -158,6 +159,14 @@ class Callback(object):
    def batch_per_epoch(self):
        """Number of batches in each epoch; this property is only available after on_epoch_begin."""
        return self._trainer.batch_per_epoch
+
+    @property
+    def is_master(self):
+        return self._trainer.is_master()
+
+    @property
+    def disabled(self):
+        return self._disabled

    def on_train_begin(self):
        """
@@ -289,6 +298,8 @@ def _transfer(func):
    def wrapper(manager, *arg):
        returns = []
        for callback in manager.callbacks:
+            if callback.disabled:
+                continue
            returns.append(getattr(callback, func.__name__)(*arg))
        return returns
@@ -320,7 +331,7 @@ class CallbackManager(Callback):
        for env_name, env_val in env.items():
            for callback in self.callbacks:
                setattr(callback, '_' + env_name, env_val)  # Callback.trainer

    @_transfer
    def on_train_begin(self):
        pass
@@ -378,6 +389,24 @@ class CallbackManager(Callback):
        pass


+class DistCallbackManager(CallbackManager):
+    def __init__(self, env, callbacks_all=None, callbacks_master=None):
+        assert 'trainer' in env
+        is_master = env['trainer'].is_master
+        self.patch_callback(callbacks_master, disabled=not is_master)
+        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
+        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
+        self.callbacks = self.callbacks_all + self.callbacks_master
+
+    def patch_callback(self, callbacks, disabled):
+        if not callbacks:
+            return
+        if not isinstance(callbacks, (list, tuple)):
+            callbacks = [callbacks]
+        for cb in callbacks:
+            cb._disabled = disabled
+
+
class GradientClipCallback(Callback):
    """
    Alias: :class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
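The pattern introduced here is worth spelling out: master-only callbacks are not removed from the dispatch list, they stay in it with `_disabled` set, and the `_transfer` wrapper skips them at call time on every non-master rank. A minimal standalone sketch of that disable-and-skip idea, with illustrative names that are not the fastNLP API:

# Minimal sketch of the disable-and-skip dispatch used by DistCallbackManager.
# `rank` stands in for trainer.is_master; everything else is hypothetical.

class TinyCallback:
    def __init__(self, name):
        self.name = name
        self.disabled = False          # set True on non-master ranks

    def on_epoch_end(self):
        print('{} ran'.format(self.name))

def dispatch(callbacks, event):
    # mirror _transfer: skip any callback whose disabled flag is set
    return [getattr(cb, event)() for cb in callbacks if not cb.disabled]

rank = 1                               # pretend this process is not the master
everywhere = TinyCallback('runs_on_all_ranks')
master_only = TinyCallback('runs_on_master_only')
master_only.disabled = (rank != 0)     # what patch_callback does

dispatch([everywhere, master_only], 'on_epoch_end')  # only the first prints on rank 1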
@@ -415,6 +444,9 @@ class GradientClipCallback(Callback):
    def on_backward_end(self):
        if self.step%self.update_every==0:
            if self.parameters is None:
+                if getattr(self.trainer, 'fp16', default=''):
+                    from apex import amp
+                    self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
                self.clip_fun(self.model.parameters(), self.clip_value)
            else:
                self.clip_fun(self.parameters, self.clip_value)
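Clipping `amp.master_params(self.optimizer)` follows apex's documented advice to clip the FP32 master copies when training with amp. Two caveats about the hunk as rendered: the built-in `getattr` takes its default positionally, not as a `default=` keyword, and the fp16 branch falls through to the existing `clip_fun(self.model.parameters(), ...)` call as well. A hedged, cleaned-up sketch of the same idea (not the committed code):

# Hedged sketch of fp16-aware gradient clipping, not GradientClipCallback itself.
# clip_fun would typically be torch.nn.utils.clip_grad_norm_ or clip_grad_value_.

def clip_after_backward(trainer, model, optimizer, clip_fun, clip_value):
    if getattr(trainer, 'fp16', ''):          # getattr's default is positional
        from apex import amp                  # requires NVIDIA apex
        clip_fun(amp.master_params(optimizer), clip_value)
    else:
        clip_fun(model.parameters(), clip_value)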
@@ -896,3 +928,21 @@ class EarlyStopError(CallbackException):
    def __init__(self, msg):
        super(EarlyStopError, self).__init__(msg)
+
+
+class EchoCallback(Callback):
+    def __init__(self, name, out=sys.stdout):
+        super(EchoCallback, self).__init__()
+        self.name = name
+        self.out = out
+
+    def __getattribute__(self, item):
+        if item.startswith('on_'):
+            print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
+                  file=self.out)
+        return super(EchoCallback, self).__getattribute__(item)
+
+
+class TesterCallback(Callback):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
+        self.tester = Tester(data, model)
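EchoCallback overrides `__getattribute__` so that every access to an `on_*` hook logs the callback name and the process id, which is how the new test can tell which ranks actually ran a callback (it assumes `sys` and `os` are imported at module level in callback.py). The tracing trick on its own, as a small runnable sketch:

# Standalone sketch of the __getattribute__ tracing trick EchoCallback relies on.
import os

class Traced:
    def __getattribute__(self, item):
        if item.startswith('on_'):
            print('{} accessed at pid {}'.format(item, os.getpid()))
        return super(Traced, self).__getattribute__(item)

    def on_batch_end(self):
        pass

Traced().on_batch_end()   # logs the access, then the method runs as usual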

fastNLP/core/dist_trainer.py (+96, -73)

@@ -11,7 +11,7 @@ import time
from datetime import datetime, timedelta

from .batch import DataSetIter, BatchIter
-from .callback import CallbackManager, CallbackException
+from .callback import DistCallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
@@ -39,18 +39,36 @@ def get_local_rank():


class DistTrainer():
-    def __init__(self, model, train_data, optimizer, loss, callbacks=None,
+    def __init__(self, train_data, model, optimizer=None, loss=None,
+                 callbacks_all=None, callbacks_master=None,
                 batch_size_per_gpu=8, n_epochs=1,
-                 num_workers=1, drop_last=False,
+                 num_data_workers=1, drop_last=False,
                 update_every=1, print_every=10, validate_every=-1,
-                 save_every=-1, save_path=None,
-                 logging_level=logging.INFO,
-                 fp16='', backend='nccl', init_method=None):
+                 save_every=-1, save_path=None, device='auto',
+                 fp16='', backend=None, init_method=None):
+
+        assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
+        if device == 'auto':
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        if backend is None:
+            backend = 'nccl' if device == 'cuda' else 'gloo'
+
+        # init distributed
+        if device == 'cuda':
+            torch.cuda.set_device(get_local_rank())
+            self.device = torch.device("cuda", get_local_rank())
+        else:
+            self.device = torch.device(device)
+
+        dist.init_process_group(backend=backend, init_method=init_method)
+        self.world_size = dist.get_world_size()
+        self.rank = dist.get_rank()  # unique id for each process
+
        self.model = model
        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)
-        self.num_workers = int(num_workers)
+        self.num_data_workers = int(num_data_workers)
        self.drop_last = drop_last
        self.update_every = int(update_every)
        self.print_every = int(print_every)
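The constructor now derives device and backend instead of requiring CUDA: NCCL when the processes own GPUs, Gloo for CPU-only runs, and the process group is initialised up front so rank and world size are available for everything that follows. A minimal sketch of the same per-process setup under `torch.distributed.launch`, which at this PyTorch vintage passes a `--local_rank` argument to each worker and fills in the rendezvous environment variables (standalone example, not fastNLP code):

# Per-process distributed setup, roughly what DistTrainer.__init__ now does.
# Run with: python -m torch.distributed.launch --nproc_per_node 2 this_file.py
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)  # injected by the launcher
local_rank = parser.parse_args().local_rank

device = 'cuda' if torch.cuda.is_available() else 'cpu'
backend = 'nccl' if device == 'cuda' else 'gloo'          # NCCL needs GPUs

if device == 'cuda':
    torch.cuda.set_device(local_rank)                     # one GPU per process
    device = torch.device('cuda', local_rank)
else:
    device = torch.device('cpu')

dist.init_process_group(backend=backend)                  # init_method defaults to env://
print('rank {}/{} on {}'.format(dist.get_rank(), dist.get_world_size(), device))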
@@ -62,16 +80,13 @@ class DistTrainer():
        self.init_method = init_method
        self.backend = backend
        self.local_rank = get_local_rank()
-        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
        self._forward_func = model.forward
+        self.callback_manager = DistCallbackManager(
+            env={"trainer": self}, callbacks_all=callbacks_all,
+            callbacks_master=callbacks_master)

-        assert torch.cuda.is_available(), "Distributed Trainer requires cuda to be enabled."
-        # init distributed
-        torch.cuda.set_device(self.local_rank)
-        self.device = torch.device("cuda", self.local_rank)
-        dist.init_process_group(backend=self.backend, init_method=self.init_method)
        model.to(self.device)
-        optimizer = self.get_optimizer(optimizer)
+        optimizer = self._get_optimizer(optimizer)

        # init fp16, must before DataParallel init
        if len(self.fp16):
@@ -81,51 +96,48 @@ class DistTrainer():
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
+            assert device == 'cuda', "Amp requires cuda device"
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)

        # init DataParallel
        self.model = DDP(model, device_ids=[self.local_rank],
                         output_device=self.local_rank)
        self.optimizer = optimizer
-        self.world_size = dist.get_world_size()
-        self.rank = dist.get_rank()  # unique id for each process
        self.sampler = DistributedSampler(self.train_data)
-        self.data_iterator = self.get_data_iter(self.train_data)
-        self.n_steps = self.get_n_steps()
+        self.data_iterator = self._get_data_iter(self.train_data)
+        self.n_steps = self._get_n_steps()

        # Setup logging
+        dist.barrier()
+        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
+        if self.save_path:
+            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
+        else:
+            self.cp_save_path = None
+
+        # use INFO in the master, WARN for others
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S',
-                            level=logging_level)
+                            level=logging.INFO if self.is_master else logging.WARN)
        self.logger = logging.getLogger(__name__)
-        self.logger.info("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
+        self.logger.info("Setup Distributed Trainer")
+        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
-        if self.is_master:
-            self.logger.info('Total epochs: %d'% self.n_epochs)
-            self.logger.info('Total steps: %d'% self.n_steps)
-            self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
-            self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
-            self.logger.info('Total num of samples: %d'% len(self.train_data))
-            self.logger.info("Num of callbacks: {}".format(len(self.callback_manager.callbacks)))
-            self.logger.info(
-                "Use callbacks: {}".format([repr(cb) for cb in self.callback_manager.callbacks]))
-
-            # only master process save model
-            if self.save_path:
-                self.save_path = os.path.join(
-                    self.save_path,
-                    datetime.now().strftime('%m_%d_%y-%H_%M_%S')+'-'+str(os.getpid()))
+        self.logger.info("Num of processes: {}".format(self.world_size))
+        self.logger.info("Use device: {}".format(device))
+        self.logger.info("Training with fp16: {}, optimization level: {}".format(
+            len(self.fp16) > 0, self.fp16 if self.fp16 else None))

-    def get_n_steps(self):
+    def _get_n_steps(self):
        batch_size = self.world_size * self.batch_size_per_gpu
        return (len(self.train_data) // batch_size + int(
            len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs

-    def get_data_iter(self, dataset):
+    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
            return DataSetIter(
                dataset=dataset, batch_size=self.batch_size_per_gpu,
-                num_workers=self.num_workers, sampler=self.sampler,
+                num_workers=self.num_data_workers, sampler=self.sampler,
                drop_last=self.drop_last
            )
        elif isinstance(dataset, BatchIter):
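The comment "init fp16, must before DataParallel init" reflects apex's documented ordering: `amp.initialize` has to wrap the model and optimizer before the model goes into `DistributedDataParallel`. A condensed sketch of that ordering, assuming apex is installed, the process group is already initialised, and the helper name is illustrative:

# Ordering sketch: move to the local GPU, amp.initialize first, DDP wrap second.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_dist_fp16(model, optimizer, local_rank, fp16_opt_level=''):
    model = model.to(torch.device('cuda', local_rank))
    if fp16_opt_level:                      # e.g. 'O1'
        from apex import amp                # requires NVIDIA apex
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    return model, optimizer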
@@ -133,7 +145,7 @@ class DistTrainer():
        else:
            raise TypeError("train_data type {} not support".format(type(dataset)))

-    def get_optimizer(self, optimizer):
+    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
@@ -148,37 +160,50 @@ class DistTrainer():
        return self.rank == 0

    def train(self, on_exception='auto'):
-        start_time = time.time()
-        results = {}
-        if self.n_epochs <= 0:
-            if self.is_master:
-                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
-            results['seconds'] = 0.
-            return results
-
-        if self.is_master:
+        try:
            self.logger.info("###### Training epochs started ######")
+            self.logger.info('Total epochs: %d'% self.n_epochs)
+            self.logger.info('Total steps: %d'% self.n_steps)
+            self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
+            self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
+            self.logger.info('Total num of samples: %d'% len(self.train_data))
+            self.logger.info("Num of callbacks for all workers: {}".format(
+                len(self.callback_manager.callbacks_all)))
+            self.logger.info("Num of callbacks for master workers: {}".format(
+                len(self.callback_manager.callbacks_master)))
+            self.logger.info("Callbacks for all workers: {}".format(
+                [repr(cb) for cb in self.callback_manager.callbacks_all]))
+            self.logger.info("Callbacks for master workers: {}".format(
+                [repr(cb) for cb in self.callback_manager.callbacks_master]))

+            start_time = time.time()
+            results = {}
+            if self.n_epochs <= 0:
+                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
+                results['seconds'] = 0.
+                return results

-        try:
-            self.callback_manager.on_train_begin()
-            self._train()
-            self.callback_manager.on_train_end()
-
-        except BaseException as e:
-            self.callback_manager.on_exception(e)
-            if on_exception == 'auto':
-                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
+            try:
+                self.callback_manager.on_train_begin()
+                self._train()
+                self.callback_manager.on_train_end()
+
+            except BaseException as e:
+                self.callback_manager.on_exception(e)
+                if on_exception == 'auto':
+                    if not isinstance(e, (CallbackException, KeyboardInterrupt)):
+                        raise e
+                    else:
+                        self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
+                elif on_exception == 'raise':
                    raise e
-                else:
-                    self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
-            elif on_exception == 'raise':
-                raise e

-        results['seconds'] = round(time.time() - start_time, 2)
-        if self.is_master:
+            results['seconds'] = round(time.time() - start_time, 2)
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
-        return results
+            return results
+        finally:
+            self.close()

    def _train(self):
        if self.fp16:
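Two behavioural changes are folded into `train()` here: the whole run sits in an outer `try/finally` so `self.close()` (and with it `dist.destroy_process_group()`) runs even when training fails, and the `on_exception='auto'` policy still swallows `CallbackException`/`KeyboardInterrupt` while re-raising everything else. A stripped-down sketch of that control flow, with illustrative names rather than the trainer itself:

# Control-flow sketch of train(): always clean up, selectively re-raise.
class StopEarly(Exception):          # stand-in for CallbackException
    pass

def run(body, cleanup, on_exception='auto'):
    try:
        try:
            body()
        except BaseException as e:
            if on_exception == 'raise':
                raise
            if on_exception == 'auto' and not isinstance(e, (StopEarly, KeyboardInterrupt)):
                raise
            print('Caught {}, ignored.'.format(type(e).__name__))
    finally:
        cleanup()                    # e.g. dist.destroy_process_group()

def demo_body():
    raise StopEarly('early stop requested')

run(demo_body, cleanup=lambda: print('closed'))   # prints the ignore message, then 'closed'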
@@ -187,7 +212,7 @@ class DistTrainer():
        self.step = 0
        self.epoch = 0
        self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
-                    leave=False, dynamic_ncols=True, disable=not self.is_master)
+                        leave=False, dynamic_ncols=True, disable=not self.is_master)
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
@@ -238,18 +263,17 @@ class DistTrainer():
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                self.n_steps)
-                    if self.is_master:
-                        self.logger.info(eval_str)
+                    self.logger.info(eval_str)
                    self.callback_manager.on_validation()
                    dist.barrier()

-                if self.save_path and \
+                if self.cp_save_path and \
                        self.save_every > 0 and \
                        self.step % self.save_every == 0:
                    self.save_check_point()

            # ================= mini-batch end ==================== #
-            if self.save_path and self.save_every < 0:
+            if self.save_every < 0 and self.cp_save_path:
                self.save_check_point()
            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
@@ -287,16 +311,15 @@ class DistTrainer():
        return loss.mean()

    def save_check_point(self, only_params=False):
+        # only master save models
        if self.is_master:
-            if not os.path.exists(self.save_path):
-                os.makedirs(self.save_path)
-            path = os.path.join(self.save_path, 'checkpoint-{}.bin'.format(self.step))
+            os.makedirs(self.cp_save_path, exist_ok=True)
+            path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
            self.logger.info("Save checkpoint to {}".format(path))
            model_to_save = self.model.module
            if only_params:
                model_to_save = model_to_save.state_dict()
            torch.save(model_to_save, path)
+        dist.barrier()

    def close(self):
        dist.destroy_process_group()
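Checkpointing is now gated on the master rank and followed by `dist.barrier()`, so non-master workers wait until the file is on disk before continuing; what gets saved is `self.model.module`, i.e. the model unwrapped from DDP. A hedged sketch of the same pattern with illustrative names, assuming the default process group is already initialised:

# Master-only checkpointing with a barrier, mirroring save_check_point().
import os
import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, save_dir, step, only_params=False):
    if dist.get_rank() == 0:                  # only the master writes
        os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, 'checkpoint-{}.bin'.format(step))
        to_save = ddp_model.module            # unwrap DistributedDataParallel
        if only_params:
            to_save = to_save.state_dict()
        torch.save(to_save, path)
    dist.barrier()                            # every rank waits for the write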

fastNLP/core/trainer.py (+30, -27)

@@ -431,13 +431,13 @@ class Trainer(object):
        super(Trainer, self).__init__()
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

        # check update every
        assert update_every >= 1, "update_every must be no less than 1."
        self.update_every = int(update_every)
@@ -447,7 +447,7 @@ class Trainer(object):
            raise ValueError("save_path can only be None or `str`.")

        # prepare evaluate
        metrics = _prepare_metrics(metrics)

        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
@@ -546,7 +546,7 @@ class Trainer(object):
                self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
        else:
            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

        self.use_tqdm = use_tqdm
        self.pbar = None
        self.print_every = abs(self.print_every)
@@ -558,10 +558,10 @@ class Trainer(object):
                             batch_size=self.batch_size,
                             device=None,  # device is handled above
                             verbose=0)

        self.step = 0
        self.start_time = None  # start timestamp

        self.callback_manager = CallbackManager(env={"trainer": self},
                                                callbacks=callbacks)
@@ -597,7 +597,7 @@ class Trainer(object):
        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
        start_time = time.time()
        print("training epochs started " + self.start_time, flush=True)

        try:
            self.callback_manager.on_train_begin()
            self._train()
@@ -610,7 +610,7 @@ class Trainer(object):
                    raise e
            elif on_exception == 'raise':
                raise e

        if self.dev_data is not None and self.best_dev_perf is not None:
            print(
                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -628,9 +628,9 @@ class Trainer(object):
        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results

    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -656,21 +656,21 @@ class Trainer(object):
                # negative sampling; replace unknown; re-weight batch_y
                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                prediction = self._data_forward(self.model, batch_x)

                # edit prediction
                self.callback_manager.on_loss_begin(batch_y, prediction)
                loss = self._compute_loss(prediction, batch_y).mean()
                avg_loss += loss.item()
                loss = loss / self.update_every

                # Is loss NaN or inf? requires_grad = False
                self.callback_manager.on_backward_begin(loss)
                self._grad_backward(loss)
                self.callback_manager.on_backward_end()

                self._update()
                self.callback_manager.on_step_end()

                if self.step % self.print_every == 0:
                    avg_loss = float(avg_loss) / self.print_every
                    if self.use_tqdm:
@@ -684,7 +684,7 @@ class Trainer(object):
                        pbar.set_postfix_str(print_output)
                    avg_loss = 0
                self.callback_manager.on_batch_end()

                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                        and self.dev_data is not None:
@@ -693,20 +693,20 @@ class Trainer(object):
                                                      self.n_steps) + \
                               self.tester._format_eval_results(eval_res)
                    pbar.write(eval_str + '\n')

            # ================= mini-batch end ==================== #

            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
        # =============== epochs end =================== #
        pbar.close()
        self.pbar = None
        # ============ tqdm end ============== #

    def _do_validation(self, epoch, step):
        self.callback_manager.on_valid_begin()
        res = self.tester.test()

        is_better_eval = False
        if self._better_eval_result(res):
            if self.save_path is not None:
@@ -721,7 +721,7 @@ class Trainer(object):

        # get validation results; adjust optimizer
        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)

        return res

    def _mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.

@@ -733,14 +733,14 @@ class Trainer(object):
            model.eval()
        else:
            model.train()

    def _update(self):
        """Perform weight update on a model.

        """
        if self.step % self.update_every == 0:
            self.optimizer.step()

    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)
        y = network(**x)
@@ -748,7 +748,7 @@ class Trainer(object):
            raise TypeError(
                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
        return y

    def _grad_backward(self, loss):
        """Compute gradient with link rules.

@@ -759,7 +759,7 @@ class Trainer(object):
        if (self.step-1) % self.update_every == 0:
            self.model.zero_grad()
        loss.backward()

    def _compute_loss(self, predict, truth):
        """Compute loss given prediction and ground truth.

@@ -768,7 +768,7 @@ class Trainer(object):
        :return: a scalar
        """
        return self.losser(predict, truth)

    def _save_model(self, model, model_name, only_param=False):
        """ Save a state_dict or model that carries no device info.

        :param model:
@@ -791,7 +791,7 @@ class Trainer(object):
            model.cpu()
            torch.save(model, model_path)
            model.to(self._model_device)

    def _load_model(self, model, model_name, only_param=False):
        # return a bool indicating whether the model was reloaded successfully
        if self.save_path is not None:
@@ -809,7 +809,7 @@ class Trainer(object):
        else:
            return False
        return True

    def _better_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.

@@ -835,6 +835,9 @@ class Trainer(object):
                is_better = False
        return is_better

+    @property
+    def is_master(self):
+        return True
+

DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
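Giving the single-process `Trainer` an `is_master` property that is always `True` lets the same callback run under either trainer: rank-0-only work can be guarded with `self.is_master` without caring which trainer is driving. A small illustrative callback relying on the property added in this commit (hypothetical, not part of fastNLP):

# Hypothetical callback that works under both Trainer and DistTrainer.
from fastNLP.core.callback import Callback

class MasterOnlyReport(Callback):
    def on_train_end(self):
        if self.is_master:    # always True in Trainer; rank 0 only in DistTrainer
            print('training finished, reporting from the master process')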


test/core/test_dist_trainer.py (+40, -7)

@@ -4,7 +4,7 @@ import numpy as np
import torch.cuda
from fastNLP import DataSet
from fastNLP import Instance
-from fastNLP import CrossEntropyLoss
+from fastNLP import CrossEntropyLoss, BCELoss
from fastNLP import SGD
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
from fastNLP.models.base_model import NaiveClassifier
@@ -12,6 +12,7 @@ import shutil
import os
import subprocess
from argparse import ArgumentParser
+from fastNLP.core.callback import EchoCallback


def prepare_fake_dataset():
    mean = np.array([-3, -3])
@@ -36,6 +37,26 @@ def prepare_fake_dataset2(*args, size=100):
def set_rng_seed(seed):
    np.random.seed(seed)

+
+def prepare_env():
+    def prepare_fake_dataset():
+        mean = np.array([-3, -3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+        mean = np.array([3, 3])
+        cov = np.array([[1, 0], [0, 1]])
+        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+        return data_set
+
+    data_set = prepare_fake_dataset()
+    data_set.set_input("x")
+    data_set.set_target("y")
+    model = NaiveClassifier(2, 1)
+    return data_set, model
+
class TestDistTrainer(unittest.TestCase):
    save_path = './save_cp'


@@ -84,23 +105,35 @@ class TestDistTrainer(unittest.TestCase):
        if trainer.is_master and os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)

+    def run3(self):
+        data_set, model = prepare_env()
+        trainer = DistTrainer(
+            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
+            n_epochs=3, print_every=50,
+            callbacks_all=[EchoCallback('callbacks_all')],
+            callbacks_master=[EchoCallback('callbacks_master')]
+        )
+        trainer.train()
+
    def run_dist(self, run_id):
        if torch.cuda.is_available():
-            ngpu = min(4, torch.cuda.device_count())
+            ngpu = min(2, torch.cuda.device_count())
            path = __file__
            cmd = ['python', '-m', 'torch.distributed.launch',
                   '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
            print(' '.join(cmd))
-            retcode = subprocess.call(cmd)
-            if retcode:
-                raise RuntimeError('subprocess got non-zero exit status %d' % retcode)
+            subprocess.check_call(cmd, timeout=60.0)

-    def test1(self):
+    def test_normal_run(self):
        self.run_dist(1)

-    def test2(self):
+    def test_fp16(self):
        self.run_dist(2)

+    def test_callback(self):
+        self.run_dist(3)
+

if __name__ == '__main__':
    runner = TestDistTrainer()
    parser = ArgumentParser()
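Each `test_*` method re-launches this same file under `torch.distributed.launch` and dispatches on the `--test` id; `test_callback` (run id 3) is the one that exercises `callbacks_all` versus `callbacks_master` through EchoCallback. A sketch of running it by hand, assuming two visible GPUs; the expected output shape is indicated in comments and the pids are made up:

# Manual equivalent of TestDistTrainer.test_callback (assumes 2 visible GPUs).
# The launcher starts one process per GPU; EchoCallback then prints which
# hooks fired in which process.
import subprocess

cmd = ['python', '-m', 'torch.distributed.launch', '--nproc_per_node', '2',
       'test/core/test_dist_trainer.py', '--test', '3']
subprocess.check_call(cmd)
# Rough shape of the output (pids illustrative):
#   callbacks_all.on_train_begin has been called at pid: 11111
#   callbacks_all.on_train_begin has been called at pid: 11112
#   callbacks_master.on_train_begin has been called at pid: 11111   <- master only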

