@@ -100,7 +100,8 @@ class Callback(object):
    def __init__(self):
        super(Callback, self).__init__()
        self._trainer = None  # reassigned inside the Trainer
        self._disabled = False

    @property
    def trainer(self):
        """
@@ -158,6 +159,14 @@ class Callback(object):
    def batch_per_epoch(self):
        """Number of batches in each epoch; only available after on_epoch_begin has been called."""
        return self._trainer.batch_per_epoch
    @property
    def is_master(self):
        return self._trainer.is_master
    @property
    def disabled(self):
        return self._disabled

    def on_train_begin(self):
        """
@@ -289,6 +298,8 @@ def _transfer(func):
    def wrapper(manager, *arg):
        returns = []
        for callback in manager.callbacks:
            if callback.disabled:
                continue
            returns.append(getattr(callback, func.__name__)(*arg))
        return returns
@@ -320,7 +331,7 @@ class CallbackManager(Callback): | |||||
for env_name, env_val in env.items(): | for env_name, env_val in env.items(): | ||||
for callback in self.callbacks: | for callback in self.callbacks: | ||||
setattr(callback, '_' + env_name, env_val) # Callback.trainer | setattr(callback, '_' + env_name, env_val) # Callback.trainer | ||||
@_transfer | @_transfer | ||||
def on_train_begin(self): | def on_train_begin(self): | ||||
pass | pass | ||||
@@ -378,6 +389,24 @@ class CallbackManager(Callback):
        pass


class DistCallbackManager(CallbackManager):
    def __init__(self, env, callbacks_all=None, callbacks_master=None):
        assert 'trainer' in env
        is_master = env['trainer'].is_master
        self.patch_callback(callbacks_master, disabled=not is_master)
        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
        self.callbacks = self.callbacks_all + self.callbacks_master

    def patch_callback(self, callbacks, disabled):
        if not callbacks:
            return
        if not isinstance(callbacks, (list, tuple)):
            callbacks = [callbacks]
        for cb in callbacks:
            cb._disabled = disabled
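
For orientation, a minimal self-contained sketch of the mechanism the two additions above implement: patch_callback flips _disabled on the master-only callbacks of non-master workers, and the _transfer wrapper then skips any disabled callback when a hook is dispatched. The class and function names below are invented for illustration and are not part of fastNLP.

# Standalone sketch of the dispatch-skipping mechanism (hypothetical names).
class TinyCallback:
    def __init__(self, name):
        self.name = name
        self._disabled = False            # the flag this patch adds to Callback

    def on_train_begin(self):
        print('{} runs on_train_begin'.format(self.name))

def dispatch(callbacks, hook_name):
    # mirrors _transfer's wrapper: disabled callbacks are skipped entirely
    return [getattr(cb, hook_name)() for cb in callbacks if not cb._disabled]

everywhere = TinyCallback('all-workers')
master_only = TinyCallback('master-only')
master_only._disabled = True              # what patch_callback does on a non-master rank
dispatch([everywhere, master_only], 'on_train_begin')   # only 'all-workers' prints
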
class GradientClipCallback(Callback):
    """
    Alias: :class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
@@ -415,6 +444,9 @@ class GradientClipCallback(Callback):
    def on_backward_end(self):
        if self.step % self.update_every == 0:
            if self.parameters is None:
                if getattr(self.trainer, 'fp16', ''):
                    from apex import amp
                    self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
                self.clip_fun(self.model.parameters(), self.clip_value)
            else:
                self.clip_fun(self.parameters, self.clip_value)
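
As a side note on the fp16 branch above: with apex amp, gradient clipping is typically applied to the fp32 master copies tracked by the optimizer (amp.master_params) rather than to the model's fp16 parameters. A hedged sketch of that idea in isolation; the helper name is made up, and apex is assumed to be installed when fp16 is enabled.

from torch.nn.utils import clip_grad_norm_

def clip_gradients(model, optimizer, clip_value, fp16=False):
    # with amp, clip the fp32 master copies held by the optimizer;
    # otherwise clip the model parameters directly
    if fp16:
        from apex import amp
        clip_grad_norm_(amp.master_params(optimizer), clip_value)
    else:
        clip_grad_norm_(model.parameters(), clip_value)
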
@@ -896,3 +928,21 @@ class EarlyStopError(CallbackException):
    def __init__(self, msg):
        super(EarlyStopError, self).__init__(msg)


class EchoCallback(Callback):
    def __init__(self, name, out=sys.stdout):
        super(EchoCallback, self).__init__()
        self.name = name
        self.out = out

    def __getattribute__(self, item):
        if item.startswith('on_'):
            print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
                  file=self.out)
        return super(EchoCallback, self).__getattribute__(item)


class TesterCallback(Callback):
    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
        self.tester = Tester(data, model)
@@ -11,7 +11,7 @@ import time
from datetime import datetime, timedelta

from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .callback import DistCallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
@@ -39,18 +39,36 @@ def get_local_rank():
class DistTrainer():
    def __init__(self, model, train_data, optimizer, loss, callbacks=None,
    def __init__(self, train_data, model, optimizer=None, loss=None,
                 callbacks_all=None, callbacks_master=None,
                 batch_size_per_gpu=8, n_epochs=1,
                 num_workers=1, drop_last=False,
                 num_data_workers=1, drop_last=False,
                 update_every=1, print_every=10, validate_every=-1,
                 save_every=-1, save_path=None,
                 logging_level=logging.INFO,
                 fp16='', backend='nccl', init_method=None):
                 save_every=-1, save_path=None, device='auto',
                 fp16='', backend=None, init_method=None):
        assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if backend is None:
            backend = 'nccl' if device == 'cuda' else 'gloo'

        # init distributed
        if device == 'cuda':
            torch.cuda.set_device(get_local_rank())
            self.device = torch.device("cuda", get_local_rank())
        else:
            self.device = torch.device(device)
        dist.init_process_group(backend=backend, init_method=init_method)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process

        self.model = model
        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)
        self.num_workers = int(num_workers)
        self.num_data_workers = int(num_data_workers)
        self.drop_last = drop_last
        self.update_every = int(update_every)
        self.print_every = int(print_every)
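
To see the new device handling in isolation: device='auto' picks CUDA when available, and an unset backend defaults to nccl on GPU and gloo on CPU. A small self-contained sketch of those resolution rules; the function name is invented for illustration.

import torch

def resolve_device_and_backend(device='auto', backend=None):
    # mirrors the checks at the top of DistTrainer.__init__
    assert device in ('auto', 'cuda', 'cpu')
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if backend is None:
        backend = 'nccl' if device == 'cuda' else 'gloo'
    return device, backend

print(resolve_device_and_backend())   # e.g. ('cpu', 'gloo') on a CPU-only machine
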
@@ -62,16 +80,13 @@ class DistTrainer():
        self.init_method = init_method
        self.backend = backend
        self.local_rank = get_local_rank()
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
        self._forward_func = model.forward
        self.callback_manager = DistCallbackManager(
            env={"trainer": self}, callbacks_all=callbacks_all,
            callbacks_master=callbacks_master)

        assert torch.cuda.is_available(), "Distributed Trainer requires cuda to be enabled."
        # init distributed
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device("cuda", self.local_rank)
        dist.init_process_group(backend=self.backend, init_method=self.init_method)
        model.to(self.device)
        optimizer = self.get_optimizer(optimizer)
        optimizer = self._get_optimizer(optimizer)

        # init fp16, must before DataParallel init
        if len(self.fp16):
@@ -81,51 +96,48 @@ class DistTrainer():
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
            assert device == 'cuda', "Amp requires cuda device"
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)

        # init DataParallel
        self.model = DDP(model, device_ids=[self.local_rank],
                         output_device=self.local_rank)
        self.optimizer = optimizer
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process
        self.sampler = DistributedSampler(self.train_data)
        self.data_iterator = self.get_data_iter(self.train_data)
        self.n_steps = self.get_n_steps()
        self.data_iterator = self._get_data_iter(self.train_data)
        self.n_steps = self._get_n_steps()

        # Setup logging
        dist.barrier()
        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
        if self.save_path:
            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
        else:
            self.cp_save_path = None
        # use INFO in the master, WARN for others
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S',
                            level=logging_level)
                            level=logging.INFO if self.is_master else logging.WARN)
        self.logger = logging.getLogger(__name__)
        self.logger.info("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
        self.logger.info("Setup Distributed Trainer")
        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
        if self.is_master:
            self.logger.info('Total epochs: %d'% self.n_epochs)
            self.logger.info('Total steps: %d'% self.n_steps)
            self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
            self.logger.info('Total num of samples: %d'% len(self.train_data))
            self.logger.info("Num of callbacks: {}".format(len(self.callback_manager.callbacks)))
            self.logger.info(
                "Use callbacks: {}".format([repr(cb) for cb in self.callback_manager.callbacks]))
        # only master process save model
        if self.save_path:
            self.save_path = os.path.join(
                self.save_path,
                datetime.now().strftime('%m_%d_%y-%H_%M_%S')+'-'+str(os.getpid()))
        self.logger.info("Num of processes: {}".format(self.world_size))
        self.logger.info("Use device: {}".format(device))
        self.logger.info("Training with fp16: {}, optimization level: {}".format(
            len(self.fp16) > 0, self.fp16 if self.fp16 else None))
    def get_n_steps(self):
    def _get_n_steps(self):
        batch_size = self.world_size * self.batch_size_per_gpu
        return (len(self.train_data) // batch_size + int(
            len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs

    def get_data_iter(self, dataset):
    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
            return DataSetIter(
                dataset=dataset, batch_size=self.batch_size_per_gpu,
                num_workers=self.num_workers, sampler=self.sampler,
                num_workers=self.num_data_workers, sampler=self.sampler,
                drop_last=self.drop_last
            )
        elif isinstance(dataset, BatchIter):
@@ -133,7 +145,7 @@ class DistTrainer():
        else:
            raise TypeError("train_data type {} not support".format(type(dataset)))

    def get_optimizer(self, optimizer):
    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
@@ -148,37 +160,50 @@ class DistTrainer():
        return self.rank == 0

    def train(self, on_exception='auto'):
        start_time = time.time()
        results = {}
        if self.n_epochs <= 0:
            if self.is_master:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
            results['seconds'] = 0.
            return results

        if self.is_master:
        try:
            self.logger.info("###### Training epochs started ######")
            self.logger.info('Total epochs: %d'% self.n_epochs)
            self.logger.info('Total steps: %d'% self.n_steps)
            self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d'% (self.batch_size_per_gpu * dist.get_world_size()))
            self.logger.info('Total num of samples: %d'% len(self.train_data))
            self.logger.info("Num of callbacks for all workers: {}".format(
                len(self.callback_manager.callbacks_all)))
            self.logger.info("Num of callbacks for master workers: {}".format(
                len(self.callback_manager.callbacks_master)))
            self.logger.info("Callbacks for all workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_all]))
            self.logger.info("Callbacks for master workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_master]))
            start_time = time.time()
            results = {}
            if self.n_epochs <= 0:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
                results['seconds'] = 0.
                return results
        try:
            self.callback_manager.on_train_begin()
            self._train()
            self.callback_manager.on_train_end()
        except BaseException as e:
            self.callback_manager.on_exception(e)
            if on_exception == 'auto':
                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end()
            except BaseException as e:
                self.callback_manager.on_exception(e)
                if on_exception == 'auto':
                    if not isinstance(e, (CallbackException, KeyboardInterrupt)):
                        raise e
                    else:
                        self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
                elif on_exception == 'raise':
                    raise e
                else:
                    self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
            elif on_exception == 'raise':
                raise e
        results['seconds'] = round(time.time() - start_time, 2)
        if self.is_master:
            results['seconds'] = round(time.time() - start_time, 2)
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
        return results
            return results
        finally:
            self.close()
    def _train(self):
        if self.fp16:
@@ -187,7 +212,7 @@ class DistTrainer():
        self.step = 0
        self.epoch = 0
        self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
                         leave=False, dynamic_ncols=True, disable=not self.is_master)
                         leave=False, dynamic_ncols=True, disable=not self.is_master)
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
@@ -238,18 +263,17 @@ class DistTrainer():
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                self.n_steps)
                    if self.is_master:
                        self.logger.info(eval_str)
                    self.logger.info(eval_str)
                    self.callback_manager.on_validation()
                    dist.barrier()

                if self.save_path and \
                if self.cp_save_path and \
                        self.save_every > 0 and \
                        self.step % self.save_every == 0:
                    self.save_check_point()

                # ================= mini-batch end ==================== #
            if self.save_path and self.save_every < 0:
            if self.save_every < 0 and self.cp_save_path:
                self.save_check_point()
            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
@@ -287,16 +311,15 @@ class DistTrainer():
        return loss.mean()

    def save_check_point(self, only_params=False):
        # only master save models
        if self.is_master:
            if not os.path.exists(self.save_path):
                os.makedirs(self.save_path)
            path = os.path.join(self.save_path, 'checkpoint-{}.bin'.format(self.step))
            os.makedirs(self.cp_save_path, exist_ok=True)
            path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
            self.logger.info("Save checkpoint to {}".format(path))
            model_to_save = self.model.module
            if only_params:
                model_to_save = model_to_save.state_dict()
            torch.save(model_to_save, path)
        dist.barrier()

    def close(self):
        dist.destroy_process_group()
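
Putting the pieces together, a hedged end-to-end sketch of how this trainer is meant to be driven, mirroring prepare_env, run3 and run_dist from the test file further below. The script name is hypothetical, and it has to be started through torch.distributed.launch so that every process receives its --local_rank and rendezvous information.

# train_dist.py (hypothetical file name); launch with:
#   python -m torch.distributed.launch --nproc_per_node 2 train_dist.py
import numpy as np
from fastNLP import DataSet, Instance, BCELoss
from fastNLP.core.dist_trainer import DistTrainer
from fastNLP.models.base_model import NaiveClassifier

# toy two-class dataset, as in the tests
class_A = np.random.multivariate_normal([-3, -3], np.eye(2), size=(1000,))
class_B = np.random.multivariate_normal([3, 3], np.eye(2), size=(1000,))
data_set = DataSet([Instance(x=[float(a), float(b)], y=[0.0]) for a, b in class_A] +
                   [Instance(x=[float(a), float(b)], y=[1.0]) for a, b in class_B])
data_set.set_input("x")
data_set.set_target("y")

trainer = DistTrainer(
    data_set, NaiveClassifier(2, 1),
    loss=BCELoss(pred="predict", target="y"),
    n_epochs=3, print_every=50,
    callbacks_all=[],       # run on every worker
    callbacks_master=[],    # run on rank 0 only
)
trainer.train()
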
@@ -431,13 +431,13 @@ class Trainer(object):
        super(Trainer, self).__init__()
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

        # check update every
        assert update_every >= 1, "update_every must be no less than 1."
        self.update_every = int(update_every)
@@ -447,7 +447,7 @@ class Trainer(object):
            raise ValueError("save_path can only be None or `str`.")
        # prepare evaluate
        metrics = _prepare_metrics(metrics)

        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
@@ -546,7 +546,7 @@ class Trainer(object):
                self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
        else:
            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))

        self.use_tqdm = use_tqdm
        self.pbar = None
        self.print_every = abs(self.print_every)
@@ -558,10 +558,10 @@ class Trainer(object):
                                 batch_size=self.batch_size,
                                 device=None,  # device is handled by the code above
                                 verbose=0)

        self.step = 0
        self.start_time = None  # start timestamp

        self.callback_manager = CallbackManager(env={"trainer": self},
                                                callbacks=callbacks)
@@ -597,7 +597,7 @@ class Trainer(object):
        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
        start_time = time.time()
        print("training epochs started " + self.start_time, flush=True)

        try:
            self.callback_manager.on_train_begin()
            self._train()
@@ -610,7 +610,7 @@ class Trainer(object):
                        raise e
            elif on_exception == 'raise':
                raise e

        if self.dev_data is not None and self.best_dev_perf is not None:
            print(
                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -628,9 +628,9 @@ class Trainer(object):
                finally:
                    pass
        results['seconds'] = round(time.time() - start_time, 2)

        return results

    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -656,21 +656,21 @@ class Trainer(object):
                    # negative sampling; replace unknown; re-weight batch_y
                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                    prediction = self._data_forward(self.model, batch_x)

                    # edit prediction
                    self.callback_manager.on_loss_begin(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y).mean()
                    avg_loss += loss.item()
                    loss = loss / self.update_every

                    # Is loss NaN or inf? requires_grad = False
                    self.callback_manager.on_backward_begin(loss)
                    self._grad_backward(loss)
                    self.callback_manager.on_backward_end()

                    self._update()
                    self.callback_manager.on_step_end()

                    if self.step % self.print_every == 0:
                        avg_loss = float(avg_loss) / self.print_every
                        if self.use_tqdm:
@@ -684,7 +684,7 @@ class Trainer(object):
                            pbar.set_postfix_str(print_output)
                        avg_loss = 0
                    self.callback_manager.on_batch_end()

                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                            (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                            and self.dev_data is not None:
@@ -693,20 +693,20 @@ class Trainer(object):
                                                                           self.n_steps) + \
                                   self.tester._format_eval_results(eval_res)
                        pbar.write(eval_str + '\n')

                # ================= mini-batch end ==================== #

                # lr decay; early stopping
                self.callback_manager.on_epoch_end()
            # =============== epochs end =================== #
            pbar.close()
            self.pbar = None
        # ============ tqdm end ============== #

    def _do_validation(self, epoch, step):
        self.callback_manager.on_valid_begin()
        res = self.tester.test()

        is_better_eval = False
        if self._better_eval_result(res):
            if self.save_path is not None:
@@ -721,7 +721,7 @@ class Trainer(object):
        # get validation results; adjust optimizer
        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
        return res

    def _mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.
@@ -733,14 +733,14 @@ class Trainer(object):
            model.eval()
        else:
            model.train()

    def _update(self):
        """Perform weight update on a model.
        """
        if self.step % self.update_every == 0:
            self.optimizer.step()

    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)
        y = network(**x)
@@ -748,7 +748,7 @@ class Trainer(object):
            raise TypeError(
                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
        return y

    def _grad_backward(self, loss):
        """Compute gradient with link rules.
@@ -759,7 +759,7 @@ class Trainer(object):
        if (self.step-1) % self.update_every == 0:
            self.model.zero_grad()
        loss.backward()

    def _compute_loss(self, predict, truth):
        """Compute loss given prediction and ground truth.
@@ -768,7 +768,7 @@ class Trainer(object):
        :return: a scalar
        """
        return self.losser(predict, truth)
    def _save_model(self, model, model_name, only_param=False):
        """ Save the state_dict or the full model without device information.

        :param model:
@@ -791,7 +791,7 @@ class Trainer(object):
                model.cpu()
                torch.save(model, model_path)
                model.to(self._model_device)

    def _load_model(self, model, model_name, only_param=False):
        # returns a bool indicating whether the model was reloaded successfully
        if self.save_path is not None:
@@ -809,7 +809,7 @@ class Trainer(object):
        else:
            return False
        return True

    def _better_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.
@@ -835,6 +835,9 @@ class Trainer(object):
                    is_better = False
        return is_better

    @property
    def is_master(self):
        return True


DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
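
A brief note on the is_master property added to Trainer above: because the single-process trainer always reports itself as master, a callback can guard rank-0-only work (saving files, verbose logging) behind self.is_master and behave sensibly under both Trainer and DistTrainer, where it is true only for rank 0. The callback below is a made-up illustration, not part of the patch.

from fastNLP.core.callback import Callback

class MasterOnlyReport(Callback):          # hypothetical example callback
    def on_epoch_end(self):
        if self.is_master:                 # True in Trainer; rank == 0 in DistTrainer
            print('epoch finished; safe to write files from this process')
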
@@ -4,7 +4,7 @@ import numpy as np
import torch.cuda
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import CrossEntropyLoss
from fastNLP import CrossEntropyLoss, BCELoss
from fastNLP import SGD
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
from fastNLP.models.base_model import NaiveClassifier
@@ -12,6 +12,7 @@ import shutil
import os
import subprocess
from argparse import ArgumentParser
from fastNLP.core.callback import EchoCallback


def prepare_fake_dataset():
    mean = np.array([-3, -3])
@@ -36,6 +37,26 @@ def prepare_fake_dataset2(*args, size=100):
def set_rng_seed(seed):
    np.random.seed(seed)


def prepare_env():
    def prepare_fake_dataset():
        mean = np.array([-3, -3])
        cov = np.array([[1, 0], [0, 1]])
        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

        mean = np.array([3, 3])
        cov = np.array([[1, 0], [0, 1]])
        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
        return data_set

    data_set = prepare_fake_dataset()
    data_set.set_input("x")
    data_set.set_target("y")
    model = NaiveClassifier(2, 1)
    return data_set, model


class TestDistTrainer(unittest.TestCase):
    save_path = './save_cp'
@@ -84,23 +105,35 @@ class TestDistTrainer(unittest.TestCase):
        if trainer.is_master and os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)

    def run3(self):
        data_set, model = prepare_env()
        trainer = DistTrainer(
            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
            n_epochs=3, print_every=50,
            callbacks_all=[EchoCallback('callbacks_all')],
            callbacks_master=[EchoCallback('callbacks_master')]
        )
        trainer.train()

    def run_dist(self, run_id):
        if torch.cuda.is_available():
            ngpu = min(4, torch.cuda.device_count())
            ngpu = min(2, torch.cuda.device_count())
            path = __file__
            cmd = ['python', '-m', 'torch.distributed.launch',
                   '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
            print(' '.join(cmd))
            retcode = subprocess.call(cmd)
            if retcode:
                raise RuntimeError('subprocess got non-zero exit status %d' % retcode)
            subprocess.check_call(cmd, timeout=60.0)

    def test1(self):
    def test_normal_run(self):
        self.run_dist(1)

    def test2(self):
    def test_fp16(self):
        self.run_dist(2)

    def test_callback(self):
        self.run_dist(3)


if __name__ == '__main__':
    runner = TestDistTrainer()
    parser = ArgumentParser()