@@ -164,7 +164,7 @@ class Callback(object):
     @property
     def is_master(self):
-        return self._trainer.is_master()
+        return self._trainer.is_master
 
     @property
     def disabled(self):
@@ -172,7 +172,7 @@ class Callback(object):
     @property
     def logger(self):
-        return getattr(self._trainer, 'logger', logging)
+        return getattr(self._trainer, 'logger', logging.getLogger(__name__))
 
     def on_train_begin(self):
         """
@@ -405,11 +405,11 @@ class DistCallbackManager(CallbackManager):
     def __init__(self, env, callbacks_all=None, callbacks_master=None):
         super(DistCallbackManager, self).__init__(env)
         assert 'trainer' in env
-        is_master = env['trainer'].is_master
-        self.patch_callback(callbacks_master, disabled=not is_master)
-        self.callbacks_all = self.prepare_callbacks(callbacks_all)
-        self.callbacks_master = self.prepare_callbacks(callbacks_master)
-        self.callbacks = self.callbacks_all + self.callbacks_master
+        self._trainer = env['trainer']
+        self.callbacks_master = []
+        self.callbacks_all = []
+        self.add_callback(callbacks_all, master=False)
+        self.add_callback(callbacks_master, master=True)
 
     def patch_callback(self, callbacks, disabled):
         if not callbacks:
@@ -419,6 +419,14 @@ class DistCallbackManager(CallbackManager):
         for cb in callbacks:
             cb._disabled = disabled
 
+    def add_callback(self, cb, master=False):
+        if master:
+            self.patch_callback(cb, not self.is_master)
+            self.callbacks_master += self.prepare_callbacks(cb)
+        else:
+            self.callbacks_all += self.prepare_callbacks(cb)
+        self.callbacks = self.callbacks_all + self.callbacks_master
+
 class GradientClipCallback(Callback):
     """
@@ -1048,15 +1056,26 @@ class TesterCallback(Callback):
             self.score = cur_score
         return cur_score, is_better
 
+    def _get_score(self, metric_dict, key):
+        for metric in metric_dict.values():  # each value is one metric's {key: score} dict
+            if key in metric:
+                return metric[key]
+        return None
+
     def compare_better(self, a):
         if self.score is None:
             return True
+        if self.metric_key is None:
+            self.metric_key = list(list(self.score.values())[0].keys())[0]
         k = self.metric_key
-        is_increase = self.score[k] <= a[k]  # if equal, prefer more recent results
+        score = self._get_score(self.score, k)
+        new_score = self._get_score(a, k)
+        if score is None or new_score is None:
+            return False
         if self.increase_better:
-            return is_increase
+            return score <= new_score
         else:
-            return not is_increase
+            return score >= new_score
 
     def on_train_end(self):
         self.logger.info('Evaluate on training ends.')
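A worked example of the new lookup, with invented scores: fastNLP metric results arrive nested as {metric_name: {key: value}}, which is why the flat self.score[k] indexing had to go. Note the loop above is written over metric_dict.values(); iterating .items() would test the key against (name, dict) tuples and always miss:

    old = {'AccuracyMetric': {'acc': 0.90}}
    new = {'AccuracyMetric': {'acc': 0.90}}
    # metric_key defaults to the first key of the first metric -> 'acc'
    # _get_score(old, 'acc') == 0.90 and _get_score(new, 'acc') == 0.90;
    # with increase_better, 0.90 <= 0.90 holds, so a tie prefers the newer result.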
@@ -22,6 +22,7 @@ from .optimizer import Optimizer
 from .utils import _build_args
 from .utils import _move_dict_value_to_device
 from .utils import _get_func_signature
+from ..io.logger import initLogger
 from pkg_resources import parse_version
 
 __all__ = [
@@ -40,7 +41,7 @@ def get_local_rank():
     if 'local_rank' in args and args.local_rank:
         os.environ['LOCAL_RANK'] = str(args.local_rank)  # for multiple calls for this function
         return args.local_rank
-    raise RuntimeError('Please use "python -m torch.distributed.launch train_script.py"')
+    raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py"')
@@ -50,7 +51,7 @@ class DistTrainer():
     def __init__(self, train_data, model, optimizer=None, loss=None,
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
-                 num_data_workers=1, drop_last=False,
+                 num_workers=1, drop_last=False,
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  log_path=None,
@@ -78,7 +79,7 @@ class DistTrainer():
         self.train_data = train_data
         self.batch_size_per_gpu = int(batch_size_per_gpu)
         self.n_epochs = int(n_epochs)
-        self.num_data_workers = int(num_data_workers)
+        self.num_data_workers = int(num_workers)
         self.drop_last = drop_last
         self.update_every = int(update_every)
         self.print_every = int(print_every)
@@ -127,9 +128,8 @@ class DistTrainer():
         if dev_data and metrics:
             cb = TesterCallback(
                 dev_data, model, metrics,
-                batch_size=batch_size_per_gpu, num_workers=num_data_workers)
-            self.callback_manager.callbacks_master += \
-                self.callback_manager.prepare_callbacks([cb])
+                batch_size=batch_size_per_gpu, num_workers=num_workers)
+            self.callback_manager.add_callback([cb], master=True)
 
         # Setup logging
         dist.barrier()
@@ -140,10 +140,7 @@ class DistTrainer():
             self.cp_save_path = None
 
         # use INFO in the master, WARN for others
-        logging.basicConfig(filename=log_path,
-                            format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                            datefmt='%m/%d/%Y %H:%M:%S',
-                            level=logging.INFO if self.is_master else logging.WARN)
+        initLogger(log_path, level=logging.INFO if self.is_master else logging.WARNING)
         self.logger = logging.getLogger(__name__)
         self.logger.info("Setup Distributed Trainer")
         self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
@@ -284,18 +281,8 @@ class DistTrainer():
                 self.callback_manager.on_batch_end()
 
-                if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
-                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                    self.callback_manager.on_valid_begin()
-                    eval_res = self.callback_manager.on_validation()
-                    eval_res = list(filter(lambda x: x is not None, eval_res))
-                    if len(eval_res):
-                        eval_res, is_better = list(zip(*eval_res))
-                    else:
-                        eval_res, is_better = None, None
-                    self.callback_manager.on_valid_end(
-                        eval_res, self.metric_key, self.optimizer, is_better)
-                    dist.barrier()
+                if (self.validate_every > 0 and self.step % self.validate_every == 0):
+                    self._do_validation()
 
                 if self.cp_save_path and \
                     self.save_every > 0 and \
@@ -303,6 +290,9 @@ class DistTrainer():
                     self.save_check_point()
 
             # ================= mini-batch end ==================== #
+            if self.validate_every < 0:
+                self._do_validation()
+
             if self.save_every < 0 and self.cp_save_path:
                 self.save_check_point()
             # lr decay; early stopping
@@ -351,5 +341,17 @@ class DistTrainer():
             model_to_save = model_to_save.state_dict()
         torch.save(model_to_save, path)
 
+    def _do_validation(self):
+        self.callback_manager.on_valid_begin()
+        eval_res = self.callback_manager.on_validation()
+        eval_res = list(filter(lambda x: x is not None, eval_res))
+        if len(eval_res):
+            eval_res, is_better = list(zip(*eval_res))
+        else:
+            eval_res, is_better = None, None
+        self.callback_manager.on_valid_end(
+            eval_res, self.metric_key, self.optimizer, is_better)
+        dist.barrier()
+
     def close(self):
         dist.destroy_process_group()
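A worked example of the zip(*...) step in _do_validation, with invented values: each master-side TesterCallback.on_validation returns an (eval_result, is_better) pair, while callbacks without a validation step yield None, which the filter drops:

    raw = [({'AccuracyMetric': {'acc': 0.91}}, True), None]
    eval_res = list(filter(lambda x: x is not None, raw))
    eval_res, is_better = list(zip(*eval_res))
    # eval_res  == ({'AccuracyMetric': {'acc': 0.91}},)
    # is_better == (True,)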
@@ -353,6 +353,8 @@ from .utils import _get_func_signature
 from .utils import _get_model_device
 from .utils import _move_model_to_device
 from ._parallel_utils import _model_contains_inner_module
+from ..io.logger import initLogger
+import logging
 
 class Trainer(object):
@@ -547,6 +549,12 @@ class Trainer(object):
         else:
             raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
 
+        log_path = None
+        if save_path is not None:
+            log_path = os.path.join(os.path.dirname(save_path), 'log')
+        initLogger(log_path)
+        self.logger = logging.getLogger(__name__)
+
         self.use_tqdm = use_tqdm
         self.pbar = None
         self.print_every = abs(self.print_every)
@@ -588,7 +596,7 @@ class Trainer(object):
         """
         results = {}
         if self.n_epochs <= 0:
-            print(f"training epoch is {self.n_epochs}, nothing was done.")
+            self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.")
             results['seconds'] = 0.
             return results
         try:
@@ -597,7 +605,7 @@ class Trainer(object):
             self._load_best_model = load_best_model
             self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
             start_time = time.time()
-            print("training epochs started " + self.start_time, flush=True)
+            self.logger.info("training epochs started " + self.start_time)
             try:
                 self.callback_manager.on_train_begin()
@@ -613,7 +621,7 @@ class Trainer(object):
                 raise e
 
             if self.dev_data is not None and self.best_dev_perf is not None:
-                print(
+                self.logger.info(
                     "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
                     self.tester._format_eval_results(self.best_dev_perf), )
                 results['best_eval'] = self.best_dev_perf
@@ -623,9 +631,9 @@ class Trainer(object):
                 model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
                 load_succeed = self._load_model(self.model, model_name)
                 if load_succeed:
-                    print("Reloaded the best model.")
+                    self.logger.info("Reloaded the best model.")
                 else:
-                    print("Fail to reload best model.")
+                    self.logger.info("Failed to reload the best model.")
         finally:
             pass
         results['seconds'] = round(time.time() - start_time, 2)
@@ -825,12 +833,12 @@ class Trainer(object):
            self.best_metric_indicator = indicator_val
        else:
            if self.increase_better is True:
-                if indicator_val > self.best_metric_indicator:
+                if indicator_val >= self.best_metric_indicator:
                    self.best_metric_indicator = indicator_val
                else:
                    is_better = False
            else:
-                if indicator_val < self.best_metric_indicator:
+                if indicator_val <= self.best_metric_indicator:
                    self.best_metric_indicator = indicator_val
                else:
                    is_better = False
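This mirrors the tie-breaking in TesterCallback.compare_better above ("if equal, prefer more recent results"). A tiny illustration with made-up numbers:

    best, current = 0.90, 0.90
    # increase_better=True: 'current >= best' now holds on a tie, so the
    # newer checkpoint replaces the stored best; with '>' it did not.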
@@ -17,6 +17,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from typing import List
+import logging
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
@@ -659,15 +660,14 @@ class _pseudo_tqdm:
     """
     Used to print progress when tqdm cannot be imported, or when use_tqdm is set to False in Trainer.
     """
 
     def __init__(self, **kwargs):
-        pass
+        self.logger = logging.getLogger()
 
     def write(self, info):
-        print(info)
+        self.logger.info(info)
 
     def set_postfix_str(self, info):
-        print(info)
+        self.logger.info(info)
 
     def __getattr__(self, item):
         def pass_func(*args, **kwargs):
@@ -1,6 +1,5 @@
 import torch
 import torch.nn as nn
-from fastNLP.embeddings.utils import get_embeddings
 from fastNLP.core import Const as C
@@ -64,7 +63,8 @@ class RegionEmbedding(nn.Module):
             kernel_sizes = [5, 9]
         assert isinstance(
             kernel_sizes, list), 'kernel_sizes should be List(int)'
-        self.embed = get_embeddings(init_embed)
+        # self.embed = nn.Embedding.from_pretrained(torch.tensor(init_embed).float(), freeze=False)
+        self.embed = init_embed
         try:
             embed_dim = self.embed.embedding_dim
         except Exception:
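With get_embeddings gone, RegionEmbedding now expects a ready embedding module as init_embed (it must expose .embedding_dim, which the try block above reads). A hedged sketch of the new calling convention; constructor arguments other than init_embed are omitted and assumed to keep their defaults:

    # `vocab` stands for the words vocabulary built by the training script below.
    embedding = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-100d')
    region = RegionEmbedding(init_embed=embedding)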
@@ -13,10 +13,11 @@ from fastNLP.core.sampler import BucketSampler
 from fastNLP.core import LRScheduler
 from fastNLP.core.const import Const as C
 from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP.core.dist_trainer import DistTrainer
 from utils.util_init import set_rng_seeds
 import os
-os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
-os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -64,27 +65,28 @@ def load_data():
     ds.apply_field(len, C.INPUT, C.INPUT_LEN)
     ds.set_input(C.INPUT, C.INPUT_LEN)
     ds.set_target(C.TARGET)
-    embedding = StaticEmbedding(
-        datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad,
-        normalize=False
-    )
-    return datainfo, embedding
+    return datainfo
 
-datainfo, embedding = load_data()
+datainfo = load_data()
+embedding = StaticEmbedding(
+    datainfo.vocabs['words'], model_dir_or_name='en-glove-6b-100d', requires_grad=ops.embedding_grad,
+    normalize=False)
 
 embedding.embedding.weight.data /= embedding.embedding.weight.data.std()
-print(embedding.embedding.weight.mean(), embedding.embedding.weight.std())
+print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.std())
 
 # 2. Or directly reuse a model from fastNLP
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
 
+datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
+datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
 print(datainfo)
 print(datainfo.datasets['train'][0])
 
 model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
               embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
-print(model)
+# print(model)
 
 # 3. Declare loss, metric, optimizer
 loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
@@ -109,13 +111,17 @@ device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
 print(device)
 
 # 4. Define the training routine
-trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
-                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
-                  metrics=[metric],
-                  dev_data=datainfo.datasets['test'], device=device,
-                  check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
-                  n_epochs=ops.train_epoch, num_workers=4)
+# trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+#                   sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
+#                   metrics=[metric],
+#                   dev_data=datainfo.datasets['test'], device=device,
+#                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
+#                   n_epochs=ops.train_epoch, num_workers=4)
+trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                      metrics=[metric],
+                      dev_data=datainfo.datasets['test'], device='cuda',
+                      batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
+                      n_epochs=ops.train_epoch, num_workers=4)
 
 if __name__ == "__main__":
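Usage note: get_local_rank (dist_trainer.py hunk above) raises unless the script is started through the PyTorch launcher, so this file is presumably run as something like `python -m torch.distributed.launch --nproc_per_node=2 train.py` (the filename and GPU count are assumptions; the launcher supplies the --local_rank argument that get_local_rank reads, with one process per GPU).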