diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 874d0ad9..cf3b158c 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -250,6 +250,14 @@ class Callback(object):
         :return:
         """
         pass
+
+    def on_validation(self):
+        """
+        If validation is enabled on the Trainer, this callback is invoked each time validation is about to run.
+
+        :return:
+        """
+        pass
 
     def on_epoch_end(self):
         """
@@ -352,6 +360,10 @@ class CallbackManager(Callback):
     @_transfer
     def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
         pass
+
+    @_transfer
+    def on_validation(self):
+        pass
 
     @_transfer
     def on_epoch_end(self):
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
new file mode 100644
index 00000000..1d782733
--- /dev/null
+++ b/fastNLP/core/dist_trainer.py
@@ -0,0 +1,302 @@
+import torch
+import torch.cuda
+import torch.optim
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+import os
+from tqdm import tqdm
+import logging
+import time
+from datetime import datetime, timedelta
+
+from .batch import DataSetIter, BatchIter
+from .callback import CallbackManager, CallbackException
+from .dataset import DataSet
+from .losses import _prepare_losser
+from .optimizer import Optimizer
+from .utils import _build_args
+from .utils import _move_dict_value_to_device
+from .utils import _get_func_signature
+
+__all__ = [
+    'get_local_rank',
+    'DistTrainer',
+]
+
+
+def get_local_rank():
+    if 'LOCAL_RANK' in os.environ:
+        return int(os.environ['LOCAL_RANK'])
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument('--local_rank', type=int)
+    args, _ = parser.parse_known_args()
+    if args.local_rank is not None:  # rank 0 is falsy, so compare against None explicitly
+        os.environ['LOCAL_RANK'] = str(args.local_rank)  # cache it for subsequent calls to this function
+        return args.local_rank
+    raise RuntimeError('Please launch the script with "python -m torch.distributed.launch train_script.py"')
+
+
+class DistTrainer():
+    """Distributed trainer: launch one process per GPU and wrap the model in DistributedDataParallel."""
+    def __init__(self, model, train_data, optimizer, loss, callbacks=None,
+                 batch_size_per_gpu=8, n_epochs=1,
+                 num_workers=1, drop_last=False,
+                 update_every=1, print_every=10, validate_every=-1,
+                 save_every=-1, save_path=None,
+                 logging_level=logging.INFO,
+                 fp16='', backend='nccl', init_method=None):
+        self.model = model
+        self.train_data = train_data
+        self.batch_size_per_gpu = int(batch_size_per_gpu)
+        self.n_epochs = int(n_epochs)
+        self.num_workers = int(num_workers)
+        self.drop_last = drop_last
+        self.update_every = int(update_every)
+        self.print_every = int(print_every)
+        self.validate_every = int(validate_every)
+        self.save_every = int(save_every)
+        self.save_path = save_path
+        self.losser = _prepare_losser(loss)
+        self.fp16 = fp16
+        self.init_method = init_method
+        self.backend = backend
+        self.local_rank = get_local_rank()
+        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
+        self._forward_func = model.forward
+
+        assert torch.cuda.is_available(), "Distributed Trainer requires cuda to be enabled."
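+        # The steps below assume one process per GPU, as created by
+        # torch.distributed.launch; order matters: pin the device first,
+        # then create the process group, initialize amp (if requested),
+        # and only then wrap the model in DistributedDataParallel.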
+        # init distributed
+        torch.cuda.set_device(self.local_rank)
+        self.device = torch.device("cuda", self.local_rank)
+        dist.init_process_group(backend=self.backend, init_method=self.init_method)
+        model.to(self.device)
+        optimizer = self.get_optimizer(optimizer)
+
+        # init fp16; must happen before the DistributedDataParallel wrapper is created
+        if len(self.fp16):
+            assert isinstance(self.fp16, str), "fp16 must be an Apex AMP optimization level: 'O0', 'O1', 'O2' or 'O3'"
+            try:
+                from apex import amp
+            except ImportError:
+                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
+            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
+            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
+
+        # init DistributedDataParallel
+        self.model = DDP(model, device_ids=[self.local_rank],
+                         output_device=self.local_rank)
+        self.optimizer = optimizer
+        self.world_size = dist.get_world_size()
+        self.rank = dist.get_rank()  # unique id for each process
+        self.sampler = DistributedSampler(self.train_data)
+        self.data_iterator = self.get_data_iter(self.train_data)
+        self.n_steps = self.get_n_steps()
+
+        # Setup logging
+        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                            datefmt='%m/%d/%Y %H:%M:%S',
+                            level=logging_level)
+        self.logger = logging.getLogger(__name__)
+        self.logger.info("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
+            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
+        if self.is_master:
+            self.logger.info('Total epochs: %d' % self.n_epochs)
+            self.logger.info('Total steps: %d' % self.n_steps)
+            self.logger.info('Num instances per GPU: %d' % self.batch_size_per_gpu)
+            self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
+            self.logger.info('Total num of samples: %d' % len(self.train_data))
+            self.logger.info("Num of callbacks: {}".format(len(self.callback_manager.callbacks)))
+            self.logger.info(
+                "Use callbacks: {}".format([repr(cb) for cb in self.callback_manager.callbacks]))
+
+        # only the master process saves the model
+        if self.save_path:
+            self.save_path = os.path.join(
+                self.save_path,
+                datetime.now().strftime('%m_%d_%y-%H_%M_%S') + '-' + str(os.getpid()))
+
+    def get_n_steps(self):
+        batch_size = self.world_size * self.batch_size_per_gpu
+        if self.drop_last:
+            steps_per_epoch = len(self.train_data) // batch_size
+        else:  # ceil division: count the final incomplete batch
+            steps_per_epoch = (len(self.train_data) + batch_size - 1) // batch_size
+        return steps_per_epoch * self.n_epochs
+
+    def get_data_iter(self, dataset):
+        if isinstance(dataset, DataSet):
+            return DataSetIter(
+                dataset=dataset, batch_size=self.batch_size_per_gpu,
+                num_workers=self.num_workers, sampler=self.sampler,
+                drop_last=self.drop_last
+            )
+        elif isinstance(dataset, BatchIter):
+            return dataset
+        else:
+            raise TypeError("train_data type {} is not supported".format(type(dataset)))
+
+    def get_optimizer(self, optimizer):
+        if isinstance(optimizer, torch.optim.Optimizer):
+            return optimizer
+        elif isinstance(optimizer, Optimizer):
+            return optimizer.construct_from_pytorch(self.model.parameters())
+        elif optimizer is None:
+            return torch.optim.Adam(self.model.parameters(), lr=4e-3)
+        else:
+            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
+
+    @property
+    def is_master(self):
+        return self.rank == 0
+
+    def train(self, on_exception='auto'):
+        start_time = time.time()
+        results = {}
+        if self.n_epochs <= 0:
+            if self.is_master:
+                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
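+            # nothing to train; return immediately with zero elapsed time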
+            results['seconds'] = 0.
+            return results
+
+        if self.is_master:
+            self.logger.info("###### Training epochs started ######")
+
+        try:
+            self.callback_manager.on_train_begin()
+            self._train()
+            self.callback_manager.on_train_end()
+
+        except BaseException as e:
+            self.callback_manager.on_exception(e)
+            if on_exception == 'auto':
+                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
+                    raise e
+                else:
+                    self.logger.info('Caught {}, ignored.'.format(e.__class__.__name__))
+            elif on_exception == 'raise':
+                raise e
+
+        results['seconds'] = round(time.time() - start_time, 2)
+        if self.is_master:
+            self.logger.info("###### Train finished ######")
+            self.logger.info('Total train time: {} seconds.'.format(results['seconds']))
+        return results
+
+    def _train(self):
+        if self.fp16:
+            # checks already done in __init__()
+            from apex import amp
+        self.step = 0
+        self.epoch = 0
+        self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
+                         leave=False, dynamic_ncols=True, disable=not self.is_master)
+        pbar = self.pbar
+        avg_loss = 0
+        data_iterator = self.data_iterator
+        self.model.zero_grad()
+        for epoch in range(1, self.n_epochs + 1):
+            self.epoch = epoch
+            self.sampler.set_epoch(epoch)  # re-shuffle each process's data partition every epoch
+            pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
+            # early stopping
+            self.callback_manager.on_epoch_begin()
+            for batch_x, batch_y in data_iterator:
+                self.model.train()
+                self.step += 1
+                _move_dict_value_to_device(batch_x, batch_y, device=self.device)
+                indices = data_iterator.get_batch_indices()
+                # negative sampling; replace unknown; re-weight batch_y
+                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
+                prediction = self._data_forward(self.model, batch_x)
+
+                # edit prediction
+                self.callback_manager.on_loss_begin(batch_y, prediction)
+                loss = self._compute_loss(prediction, batch_y)
+                avg_loss += loss.item()
+
+                # Is loss NaN or inf? requires_grad = False
+                self.callback_manager.on_backward_begin(loss)
+
+                if self.fp16:
+                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+                        scaled_loss.backward()
+                else:
+                    loss.backward()
+
+                self.callback_manager.on_backward_end()
+
+                self._update()
+                self.callback_manager.on_step_end()
+
+                if self.step % self.print_every == 0:
+                    avg_loss = float(avg_loss) / self.print_every
+                    print_output = "loss:{:<6.5f}".format(avg_loss)
+                    pbar.update(self.print_every)
+                    pbar.set_postfix_str(print_output)
+                    avg_loss = 0
+
+                self.callback_manager.on_batch_end()
+
+                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
+                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
+                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
+                                                                                self.n_steps)
+                    if self.is_master:
+                        self.logger.info(eval_str)
+                    self.callback_manager.on_validation()
+                    dist.barrier()
+
+                if self.save_path and \
+                        self.save_every > 0 and \
+                        self.step % self.save_every == 0:
+                    self.save_check_point()
+
+            # ================= mini-batch end ==================== #
+            if self.save_path and self.save_every < 0:
+                self.save_check_point()
+            # lr decay; early stopping
+            self.callback_manager.on_epoch_end()
+        # =============== epochs end =================== #
+        pbar.close()
+        self.pbar = None
+        # ============ tqdm end ============== #
+
+    def _update(self):
+        """Perform weight update on a model.
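+        The optimizer steps only once every ``update_every`` batches; combined
+        with the loss scaling in _compute_loss(), this implements gradient accumulation.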
+
+        """
+        if self.step % self.update_every == 0:
+            self.optimizer.step()
+            self.model.zero_grad()
+
+    def _data_forward(self, network, x):
+        x = _build_args(self._forward_func, **x)
+        y = network(**x)
+        if not isinstance(y, dict):
+            raise TypeError(
+                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
+        return y
+
+    def _compute_loss(self, predict, truth):
+        """Compute loss given prediction and ground truth.
+
+        :param predict: prediction dict, produced by model.forward
+        :param truth: ground truth dict, produced by batch_y
+        :return: a scalar
+        """
+        loss = self.losser(predict, truth)
+        if self.update_every > 1:
+            loss = loss / self.update_every
+        return loss.mean()
+
+    def save_check_point(self, only_params=False):
+        if self.is_master:
+            if not os.path.exists(self.save_path):
+                os.makedirs(self.save_path)
+            path = os.path.join(self.save_path, 'checkpoint-{}.bin'.format(self.step))
+            self.logger.info("Save checkpoint to {}".format(path))
+            model_to_save = self.model.module  # unwrap DDP before saving
+            if only_params:
+                model_to_save = model_to_save.state_dict()
+            torch.save(model_to_save, path)
+        dist.barrier()
+
+    def close(self):
+        dist.destroy_process_group()
diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py
new file mode 100644
index 00000000..59be35c6
--- /dev/null
+++ b/test/core/test_dist_trainer.py
@@ -0,0 +1,110 @@
+import unittest
+
+import numpy as np
+import torch.cuda
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import CrossEntropyLoss
+from fastNLP import SGD
+from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
+from fastNLP.models.base_model import NaiveClassifier
+import shutil
+import os
+import subprocess
+from argparse import ArgumentParser
+
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B])
+    return data_set
+
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=size, dtype=np.int64)
+    data = {'y': ys}
+    for arg in args:
+        data[arg] = np.random.randn(size, 5)
+    return DataSet(data=data)
+
+def set_rng_seed(seed):
+    np.random.seed(seed)
+
+class TestDistTrainer(unittest.TestCase):
+    save_path = './save_cp'
+
+    def run1(self):
+        # test distributed training
+        print('local rank', get_local_rank())
+        set_rng_seed(100)
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+
+        model = NaiveClassifier(2, 2)
+
+        trainer = DistTrainer(
+            model=model, train_data=data_set, optimizer=SGD(lr=0.1),
+            loss=CrossEntropyLoss(pred="predict", target="y"),
+            batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
+        )
+        trainer.train()
+        # expected to run without errors
+        if trainer.is_master and os.path.exists(self.save_path):
+            shutil.rmtree(self.save_path)
+
+    def run2(self):
+        # test fp16 with distributed training
+        print('local rank', get_local_rank())
+        set_rng_seed(100)
+        data_set = prepare_fake_dataset()
+        data_set.set_input("x", flag=True)
+        data_set.set_target("y", flag=True)
+
+        model = NaiveClassifier(2, 2)
+
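+        # same as run1, except mixed-precision training is enabled via apex amp ('O1')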
+        trainer = DistTrainer(
+            model=model, train_data=data_set, optimizer=SGD(lr=0.1),
+            loss=CrossEntropyLoss(pred="predict", target="y"),
+            batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
+            fp16='O1'
+        )
+        trainer.train()
+        # expected to run without errors
+        if trainer.is_master and os.path.exists(self.save_path):
+            shutil.rmtree(self.save_path)
+
+    def run_dist(self, run_id):
+        if torch.cuda.is_available():
+            ngpu = min(4, torch.cuda.device_count())
+            path = __file__
+            cmd = ['python', '-m', 'torch.distributed.launch',
+                   '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
+            print(' '.join(cmd))
+            retcode = subprocess.call(cmd)
+            if retcode:
+                raise RuntimeError('subprocess got non-zero exit status %d' % retcode)
+
+    def test1(self):
+        self.run_dist(1)
+
+    def test2(self):
+        self.run_dist(2)
+
+if __name__ == '__main__':
+    runner = TestDistTrainer()
+    parser = ArgumentParser()
+    parser.add_argument('--test', type=int)
+    args, _ = parser.parse_known_args()
+    if args.test and hasattr(runner, 'run%s' % args.test):
+        getattr(runner, 'run%s' % args.test)()
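+    # run_dist() re-launches this file under torch.distributed.launch, so each
+    # test spawns one process per available GPU, e.g.:
+    #   python -m torch.distributed.launch --nproc_per_node 4 test_dist_trainer.py --test 1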