@@ -100,7 +100,8 @@ class Callback(object):
    def __init__(self):
        super(Callback, self).__init__()
        self._trainer = None  # reassigned inside the Trainer
        self._disabled = False
    @property
    def trainer(self):
        """
@@ -158,6 +159,14 @@ class Callback(object):
    def batch_per_epoch(self):
        """How many batches there are in each epoch; this property can only be accessed after on_epoch_begin."""
        return self._trainer.batch_per_epoch
    @property
    def is_master(self):
        return self._trainer.is_master
    @property
    def disabled(self):
        return self._disabled
    def on_train_begin(self):
        """
@@ -289,6 +298,8 @@ def _transfer(func):
    def wrapper(manager, *arg):
        returns = []
        for callback in manager.callbacks:
            if callback.disabled:
                continue
            returns.append(getattr(callback, func.__name__)(*arg))
        return returns
@@ -320,7 +331,7 @@ class CallbackManager(Callback):
        for env_name, env_val in env.items():
            for callback in self.callbacks:
                setattr(callback, '_' + env_name, env_val)  # Callback.trainer
    @_transfer
    def on_train_begin(self):
        pass
@@ -378,6 +389,24 @@ class CallbackManager(Callback):
        pass
class DistCallbackManager(CallbackManager):
    def __init__(self, env, callbacks_all=None, callbacks_master=None):
        assert 'trainer' in env
        is_master = env['trainer'].is_master
        self.patch_callback(callbacks_master, disabled=not is_master)
        self.callbacks_all = CallbackManager(env, callbacks_all).callbacks
        self.callbacks_master = CallbackManager(env, callbacks_master).callbacks
        self.callbacks = self.callbacks_all + self.callbacks_master
    def patch_callback(self, callbacks, disabled):
        if not callbacks:
            return
        if not isinstance(callbacks, (list, tuple)):
            callbacks = [callbacks]
        for cb in callbacks:
            cb._disabled = disabled
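
In short, `patch_callback` flags master-only callbacks with `_disabled=True` on every non-master worker, and the `_transfer` wrapper above then skips disabled callbacks when it dispatches events. A minimal sketch of the mechanism, assuming only the `Callback` API shown in this diff (the toy `PrintCallback` is hypothetical):

```python
# Illustration only: how a master-only callback is silenced on non-master workers.
from fastNLP.core.callback import Callback

class PrintCallback(Callback):  # hypothetical toy callback, not part of the PR
    def on_train_begin(self):
        print('on_train_begin called')

cb = PrintCallback()
cb._disabled = True      # what patch_callback(callbacks_master, disabled=not is_master) does when rank != 0
if not cb.disabled:      # the check _transfer's wrapper now performs before dispatching
    cb.on_train_begin()  # skipped here, so master-only callbacks stay silent on other workers
```
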
class GradientClipCallback(Callback):
    """
    Alias: :class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
@@ -415,6 +444,9 @@ class GradientClipCallback(Callback):
    def on_backward_end(self):
        if self.step%self.update_every==0:
            if self.parameters is None:
                if getattr(self.trainer, 'fp16', ''):
                    from apex import amp
                    self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
                self.clip_fun(self.model.parameters(), self.clip_value)
            else:
                self.clip_fun(self.parameters, self.clip_value)
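
For context on the fp16 branch above: with apex amp, the gradients that matter live on the FP32 master copies held by the optimizer, so clipping targets `amp.master_params(optimizer)` rather than the raw model parameters. A minimal sketch of that pattern, assuming apex and a CUDA device are available (not the PR's code):

```python
# Sketch of the amp-aware clipping pattern the hunk above follows; assumes NVIDIA apex + CUDA.
import torch
from apex import amp

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

loss = model(torch.randn(4, 10, device='cuda')).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
# Clip the FP32 master parameters, not model.parameters(), when fp16 is enabled.
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm=1.0)
optimizer.step()
```
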
@@ -896,3 +928,21 @@ class EarlyStopError(CallbackException):
    def __init__(self, msg):
        super(EarlyStopError, self).__init__(msg)
class EchoCallback(Callback):
    def __init__(self, name, out=sys.stdout):
        super(EchoCallback, self).__init__()
        self.name = name
        self.out = out
    def __getattribute__(self, item):
        if item.startswith('on_'):
            print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
                  file=self.out)
        return super(EchoCallback, self).__getattribute__(item)
class TesterCallback(Callback):
    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
        self.tester = Tester(data, model)
@@ -11,7 +11,7 @@ import time
from datetime import datetime, timedelta
from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .callback import DistCallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
from .optimizer import Optimizer
@@ -39,18 +39,36 @@ def get_local_rank():
class DistTrainer():
    def __init__(self, model, train_data, optimizer, loss, callbacks=None,
    def __init__(self, train_data, model, optimizer=None, loss=None,
                 callbacks_all=None, callbacks_master=None,
                 batch_size_per_gpu=8, n_epochs=1,
                 num_workers=1, drop_last=False,
                 num_data_workers=1, drop_last=False,
                 update_every=1, print_every=10, validate_every=-1,
                 save_every=-1, save_path=None,
                 logging_level=logging.INFO,
                 fp16='', backend='nccl', init_method=None):
                 save_every=-1, save_path=None, device='auto',
                 fp16='', backend=None, init_method=None):
        assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in ['auto', 'cuda', 'cpu']"
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if backend is None:
            backend = 'nccl' if device == 'cuda' else 'gloo'
        # init distributed
        if device == 'cuda':
            torch.cuda.set_device(get_local_rank())
            self.device = torch.device("cuda", get_local_rank())
        else:
            self.device = torch.device(device)
        dist.init_process_group(backend=backend, init_method=init_method)
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process
        self.model = model
        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)
        self.num_workers = int(num_workers)
        self.num_data_workers = int(num_data_workers)
        self.drop_last = drop_last
        self.update_every = int(update_every)
        self.print_every = int(print_every)
@@ -62,16 +80,13 @@ class DistTrainer():
        self.init_method = init_method
        self.backend = backend
        self.local_rank = get_local_rank()
        self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
        self._forward_func = model.forward
        self.callback_manager = DistCallbackManager(
            env={"trainer": self}, callbacks_all=callbacks_all,
            callbacks_master=callbacks_master)
        assert torch.cuda.is_available(), "Distributed Trainer requires cuda to be enabled."
        # init distributed
        torch.cuda.set_device(self.local_rank)
        self.device = torch.device("cuda", self.local_rank)
        dist.init_process_group(backend=self.backend, init_method=self.init_method)
        model.to(self.device)
        optimizer = self.get_optimizer(optimizer)
        optimizer = self._get_optimizer(optimizer)
        # init fp16, must before DataParallel init
        if len(self.fp16):
@@ -81,51 +96,48 @@ class DistTrainer():
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
            assert device == 'cuda', "Amp requires cuda device"
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16)
        # init DataParallel
        self.model = DDP(model, device_ids=[self.local_rank],
                         output_device=self.local_rank)
        self.optimizer = optimizer
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process
        self.sampler = DistributedSampler(self.train_data)
        self.data_iterator = self.get_data_iter(self.train_data)
        self.n_steps = self.get_n_steps()
        self.data_iterator = self._get_data_iter(self.train_data)
        self.n_steps = self._get_n_steps()
        # Setup logging
        dist.barrier()
        self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
        if self.save_path:
            self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
        else:
            self.cp_save_path = None
        # use INFO in the master, WARN for others
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S',
                            level=logging_level)
                            level=logging.INFO if self.is_master else logging.WARN)
        self.logger = logging.getLogger(__name__)
        self.logger.info("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
        self.logger.info("Setup Distributed Trainer")
        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
            os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
        if self.is_master:
            self.logger.info('Total epochs: %d'% self.n_epochs)
            self.logger.info('Total steps: %d'% self.n_steps)
            self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
            self.logger.info('Total num of samples: %d'% len(self.train_data))
            self.logger.info("Num of callbacks: {}".format(len(self.callback_manager.callbacks)))
            self.logger.info(
                "Use callbacks: {}".format([repr(cb) for cb in self.callback_manager.callbacks]))
            # only master process save model
            if self.save_path:
                self.save_path = os.path.join(
                    self.save_path,
                    datetime.now().strftime('%m_%d_%y-%H_%M_%S')+'-'+str(os.getpid()))
        self.logger.info("Num of processes: {}".format(self.world_size))
        self.logger.info("Use device: {}".format(device))
        self.logger.info("Training with fp16: {}, optimization level: {}".format(
            len(self.fp16) > 0, self.fp16 if self.fp16 else None))
    def get_n_steps(self):
    def _get_n_steps(self):
        batch_size = self.world_size * self.batch_size_per_gpu
        return (len(self.train_data) // batch_size + int(
            len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs
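
A quick worked example of `_get_n_steps` under assumed numbers (not part of the PR): with 2 processes, `batch_size_per_gpu=8`, 1000 samples, `drop_last=False` and 3 epochs, the effective batch is 16, giving 63 steps per epoch and 189 in total. Note that, as written, the `int(self.drop_last == 0)` factor makes the result 0 when `drop_last=True`.

```python
# Worked example of the step arithmetic above (hypothetical numbers).
world_size, batch_size_per_gpu = 2, 8
n_samples, n_epochs, drop_last = 1000, 3, False

batch_size = world_size * batch_size_per_gpu                                   # 16 samples per global step
steps_per_epoch = n_samples // batch_size + int(n_samples % batch_size != 0)   # 62 + 1 = 63
n_steps = steps_per_epoch * int(drop_last == 0) * n_epochs
print(n_steps)                                                                 # 189
```
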
    def get_data_iter(self, dataset):
    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
            return DataSetIter(
                dataset=dataset, batch_size=self.batch_size_per_gpu,
                num_workers=self.num_workers, sampler=self.sampler,
                num_workers=self.num_data_workers, sampler=self.sampler,
                drop_last=self.drop_last
            )
        elif isinstance(dataset, BatchIter):
@@ -133,7 +145,7 @@ class DistTrainer():
        else:
            raise TypeError("train_data type {} not support".format(type(dataset)))
    def get_optimizer(self, optimizer):
    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
@@ -148,37 +160,50 @@ class DistTrainer():
        return self.rank == 0
    def train(self, on_exception='auto'):
        start_time = time.time()
        results = {}
        if self.n_epochs <= 0:
            if self.is_master:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
            results['seconds'] = 0.
            return results
        if self.is_master:
        try:
            self.logger.info("###### Training epochs started ######")
            self.logger.info('Total epochs: %d'% self.n_epochs)
            self.logger.info('Total steps: %d'% self.n_steps)
            self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
            self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size())
            self.logger.info('Total num of samples: %d'% len(self.train_data))
            self.logger.info("Num of callbacks for all workers: {}".format(
                len(self.callback_manager.callbacks_all)))
            self.logger.info("Num of callbacks for master workers: {}".format(
                len(self.callback_manager.callbacks_master)))
            self.logger.info("Callbacks for all workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_all]))
            self.logger.info("Callbacks for master workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_master]))
            start_time = time.time()
            results = {}
            if self.n_epochs <= 0:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
                results['seconds'] = 0.
                return results
        try:
            self.callback_manager.on_train_begin()
            self._train()
            self.callback_manager.on_train_end()
        except BaseException as e:
            self.callback_manager.on_exception(e)
            if on_exception == 'auto':
                if not isinstance(e, (CallbackException, KeyboardInterrupt)):
            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end()
            except BaseException as e:
                self.callback_manager.on_exception(e)
                if on_exception == 'auto':
                    if not isinstance(e, (CallbackException, KeyboardInterrupt)):
                    raise e
                else:
                    self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
            elif on_exception == 'raise':
                raise e
                    else:
                        self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
                elif on_exception == 'raise':
                    raise e
        results['seconds'] = round(time.time() - start_time, 2)
        if self.is_master:
            results['seconds'] = round(time.time() - start_time, 2)
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
        return results
            return results
        finally:
            self.close()
    def _train(self):
        if self.fp16:
@@ -187,7 +212,7 @@ class DistTrainer():
        self.step = 0
        self.epoch = 0
        self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
                        leave=False, dynamic_ncols=True, disable=not self.is_master)
                         leave=False, dynamic_ncols=True, disable=not self.is_master)
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
@@ -238,18 +263,17 @@ class DistTrainer():
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
                    eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
                                                                                self.n_steps)
                    if self.is_master:
                        self.logger.info(eval_str)
                    self.logger.info(eval_str)
                    self.callback_manager.on_validation()
                    dist.barrier()
                if self.save_path and \
                if self.cp_save_path and \
                        self.save_every > 0 and \
                        self.step % self.save_every == 0:
                    self.save_check_point()
            # ================= mini-batch end ==================== #
            if self.save_path and self.save_every < 0:
            if self.save_every < 0 and self.cp_save_path:
                self.save_check_point()
            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
@@ -287,16 +311,15 @@ class DistTrainer():
        return loss.mean()
    def save_check_point(self, only_params=False):
        # only master save models
        if self.is_master:
            if not os.path.exists(self.save_path):
                os.makedirs(self.save_path)
            path = os.path.join(self.save_path, 'checkpoint-{}.bin'.format(self.step))
            os.makedirs(self.cp_save_path, exist_ok=True)
            path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
            self.logger.info("Save checkpoint to {}".format(path))
            model_to_save = self.model.module
            if only_params:
                model_to_save = model_to_save.state_dict()
            torch.save(model_to_save, path)
        dist.barrier()
    def close(self):
        dist.destroy_process_group()
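
For reference, a minimal usage sketch of the reworked constructor and launch flow (illustrative only; it mirrors the test's `prepare_env()` data/model and assumes the script is started via `torch.distributed.launch`, which sets the environment variables that `get_local_rank()` and `init_process_group` read):

```python
# run_dist_sketch.py -- illustrative only; not part of the PR.
from fastNLP import DataSet, Instance, BCELoss
from fastNLP.core.dist_trainer import DistTrainer
from fastNLP.models.base_model import NaiveClassifier

data_set = DataSet([Instance(x=[0.5, -0.5], y=[float(i % 2)]) for i in range(64)])
data_set.set_input('x')
data_set.set_target('y')

trainer = DistTrainer(
    data_set, NaiveClassifier(2, 1),
    optimizer=None, loss=BCELoss(pred='predict', target='y'),
    batch_size_per_gpu=8, n_epochs=2, device='auto')  # device/backend resolved automatically
trainer.train()

# Launch one process per GPU (or CPU workers via gloo), e.g.:
#   python -m torch.distributed.launch --nproc_per_node 2 run_dist_sketch.py
```
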
@@ -431,13 +431,13 @@ class Trainer(object):
        super(Trainer, self).__init__()
        if not isinstance(model, nn.Module):
            raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
        # check metrics and dev_data
        if (not metrics) and dev_data is not None:
            raise ValueError("No metric for dev_data evaluation.")
        if metrics and (dev_data is None):
            raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")
        # check update every
        assert update_every >= 1, "update_every must be no less than 1."
        self.update_every = int(update_every)
@@ -447,7 +447,7 @@ class Trainer(object):
            raise ValueError("save_path can only be None or `str`.")
        # prepare evaluate
        metrics = _prepare_metrics(metrics)
        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
@@ -546,7 +546,7 @@ class Trainer(object):
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
        else:
            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
        self.use_tqdm = use_tqdm
        self.pbar = None
        self.print_every = abs(self.print_every)
@@ -558,10 +558,10 @@ class Trainer(object):
                                 batch_size=self.batch_size,
                                 device=None,  # device is already handled by the code above
                                 verbose=0)
        self.step = 0
        self.start_time = None  # start timestamp
        self.callback_manager = CallbackManager(env={"trainer": self},
                                                callbacks=callbacks)
@@ -597,7 +597,7 @@ class Trainer(object):
        self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
        start_time = time.time()
        print("training epochs started " + self.start_time, flush=True)
        try:
            self.callback_manager.on_train_begin()
            self._train()
@@ -610,7 +610,7 @@ class Trainer(object):
                    raise e
            elif on_exception == 'raise':
                raise e
        if self.dev_data is not None and self.best_dev_perf is not None:
            print(
                "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
@@ -628,9 +628,9 @@ class Trainer(object):
        finally:
            pass
        results['seconds'] = round(time.time() - start_time, 2)
        return results
    def _train(self):
        if not self.use_tqdm:
            from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
@@ -656,21 +656,21 @@ class Trainer(object):
                    # negative sampling; replace unknown; re-weight batch_y
                    self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                    prediction = self._data_forward(self.model, batch_x)
                    # edit prediction
                    self.callback_manager.on_loss_begin(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y).mean()
                    avg_loss += loss.item()
                    loss = loss / self.update_every
                    # Is loss NaN or inf? requires_grad = False
                    self.callback_manager.on_backward_begin(loss)
                    self._grad_backward(loss)
                    self.callback_manager.on_backward_end()
                    self._update()
                    self.callback_manager.on_step_end()
                    if self.step % self.print_every == 0:
                        avg_loss = float(avg_loss) / self.print_every
                        if self.use_tqdm:
@@ -684,7 +684,7 @@ class Trainer(object):
                        pbar.set_postfix_str(print_output)
                        avg_loss = 0
                    self.callback_manager.on_batch_end()
                    if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                            and self.dev_data is not None:
@@ -693,20 +693,20 @@ class Trainer(object):
                                   self.n_steps) + \
                                   self.tester._format_eval_results(eval_res)
                        pbar.write(eval_str + '\n')
            # ================= mini-batch end ==================== #
            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
        # =============== epochs end =================== #
        pbar.close()
        self.pbar = None
        # ============ tqdm end ============== #
    def _do_validation(self, epoch, step):
        self.callback_manager.on_valid_begin()
        res = self.tester.test()
        is_better_eval = False
        if self._better_eval_result(res):
            if self.save_path is not None:
@@ -721,7 +721,7 @@ class Trainer(object):
        # get validation results; adjust optimizer
        self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
        return res
    def _mode(self, model, is_test=False):
        """Train mode or Test mode. This is for PyTorch currently.
@@ -733,14 +733,14 @@ class Trainer(object):
            model.eval()
        else:
            model.train()
    def _update(self):
        """Perform weight update on a model.
        """
        if self.step % self.update_every == 0:
            self.optimizer.step()
    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)
        y = network(**x)
@@ -748,7 +748,7 @@ class Trainer(object):
            raise TypeError(
                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
        return y
    def _grad_backward(self, loss):
        """Compute gradient with link rules.
@@ -759,7 +759,7 @@ class Trainer(object):
        if (self.step-1) % self.update_every == 0:
            self.model.zero_grad()
        loss.backward()
    def _compute_loss(self, predict, truth):
        """Compute loss given prediction and ground truth.
@@ -768,7 +768,7 @@ class Trainer(object):
        :return: a scalar
        """
        return self.losser(predict, truth)
    def _save_model(self, model, model_name, only_param=False):
        """ Save the state_dict or the full model without device (GPU) information
        :param model:
@@ -791,7 +791,7 @@ class Trainer(object):
                model.cpu()
                torch.save(model, model_path)
                model.to(self._model_device)
    def _load_model(self, model, model_name, only_param=False):
        # returns a bool indicating whether the model was successfully reloaded
        if self.save_path is not None:
@@ -809,7 +809,7 @@ class Trainer(object):
        else:
            return False
        return True
    def _better_eval_result(self, metrics):
        """Check if the current epoch yields better validation results.
@@ -835,6 +835,9 @@ class Trainer(object):
                        is_better = False
        return is_better
    @property
    def is_master(self):
        return True
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
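
One note on the `is_master` property added to `Trainer`: it is unconditionally `True` in the single-process trainer, while `DistTrainer` ties it to `rank == 0`, so a callback written against `Callback.is_master` behaves sensibly under both. A tiny illustrative sketch (the callback name is hypothetical):

```python
# Illustrative only: the same callback works under Trainer and DistTrainer.
from fastNLP.core.callback import Callback

class SaveOnMaster(Callback):  # hypothetical example callback
    def on_epoch_end(self):
        if self.is_master:  # always True in Trainer; rank == 0 check in DistTrainer
            print('master process would save / evaluate here')
```
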
@@ -4,7 +4,7 @@ import numpy as np
import torch.cuda
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import CrossEntropyLoss
from fastNLP import CrossEntropyLoss, BCELoss
from fastNLP import SGD
from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
from fastNLP.models.base_model import NaiveClassifier
@@ -12,6 +12,7 @@ import shutil
import os
import subprocess
from argparse import ArgumentParser
from fastNLP.core.callback import EchoCallback
def prepare_fake_dataset():
    mean = np.array([-3, -3])
@@ -36,6 +37,26 @@ def prepare_fake_dataset2(*args, size=100):
def set_rng_seed(seed):
    np.random.seed(seed)
def prepare_env():
    def prepare_fake_dataset():
        mean = np.array([-3, -3])
        cov = np.array([[1, 0], [0, 1]])
        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
        mean = np.array([3, 3])
        cov = np.array([[1, 0], [0, 1]])
        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
        return data_set
    data_set = prepare_fake_dataset()
    data_set.set_input("x")
    data_set.set_target("y")
    model = NaiveClassifier(2, 1)
    return data_set, model
class TestDistTrainer(unittest.TestCase):
    save_path = './save_cp'
@@ -84,23 +105,35 @@ class TestDistTrainer(unittest.TestCase):
        if trainer.is_master and os.path.exists(self.save_path):
            shutil.rmtree(self.save_path)
    def run3(self):
        data_set, model = prepare_env()
        trainer = DistTrainer(
            data_set, model, optimizer=None, loss=BCELoss(pred="predict", target="y"),
            n_epochs=3, print_every=50,
            callbacks_all=[EchoCallback('callbacks_all')],
            callbacks_master=[EchoCallback('callbacks_master')]
        )
        trainer.train()
    def run_dist(self, run_id):
        if torch.cuda.is_available():
            ngpu = min(4, torch.cuda.device_count())
            ngpu = min(2, torch.cuda.device_count())
            path = __file__
            cmd = ['python', '-m', 'torch.distributed.launch',
                   '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
            print(' '.join(cmd))
            retcode = subprocess.call(cmd)
            if retcode:
                raise RuntimeError('subprocess got non-zero exit status %d' % retcode)
            subprocess.check_call(cmd, timeout=60.0)
    def test1(self):
    def test_normal_run(self):
        self.run_dist(1)
    def test2(self):
    def test_fp16(self):
        self.run_dist(2)
    def test_callback(self):
        self.run_dist(3)
if __name__ == '__main__':
    runner = TestDistTrainer()
    parser = ArgumentParser()