
merge conflict

tags/v0.4.10
yh committed 6 years ago
commit c3d5128ab5
3 changed files with 43 additions and 9 deletions
  1. fastNLP/core/callback.py         +26  -4
  2. fastNLP/core/dist_trainer.py     +15  -3
  3. test/core/test_dist_trainer.py   +2   -2

fastNLP/core/callback.py  (+26  -4)

@@ -956,20 +956,42 @@ class EchoCallback(Callback):


 class TesterCallback(Callback):
-    def __init__(self, data, model, metrics, batch_size=16, num_workers=None):
+    #TODO add compare & save best
+    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
         super(TesterCallback, self).__init__()
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
                              num_workers=num_workers, verbose=0)
+        # parse metric_key
+        # increase_better is True. It means the exp result gets better if the indicator increases.
+        # It is true by default.
+        self.increase_better = True
+        if metric_key is not None:
+            self.increase_better = False if metric_key[0] == "-" else True
+            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        else:
+            self.metric_key = None
         self.score = None

     def on_validation(self):
-        cur_socre = self.tester.test()
+        cur_score = self.tester.test()
         eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
             self.epoch, self.n_epochs, self.step, self.n_steps,
-            self.tester._format_eval_results(cur_socre))
+            self.tester._format_eval_results(cur_score))
         self.logger.info(eval_str)
+        is_better = self.compare_better(cur_score)
+        if is_better:
+            self.score = cur_score
+        return cur_score, is_better
+
+    def compare_better(self, a):
+        if self.score is None:
+            return True
+        k = self.metric_key
+        is_increase = self.score[k] <= a[k]  # if equal, prefer more recent results
+        if self.increase_better:
+            return is_increase
+        else:
+            return not is_increase

     def on_train_end(self):
         self.logger.info('Evaluate on training ends.')
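
The metric_key handling added above follows the usual fastNLP convention: a leading "-" marks a metric where lower is better, a leading "+" (or no sign at all) one where higher is better, and compare_better keeps the newer score on ties. A minimal, self-contained sketch of that parsing behaviour (parse_metric_key is a hypothetical helper written for illustration, not part of this diff):

    def parse_metric_key(metric_key=None):
        # Mirrors the __init__ logic above: strip an optional leading sign and
        # derive whether an increase in the metric counts as an improvement.
        increase_better = True
        if metric_key is not None:
            increase_better = metric_key[0] != "-"
            if metric_key[0] in ("+", "-"):
                metric_key = metric_key[1:]
        return metric_key, increase_better

    assert parse_metric_key("-loss") == ("loss", False)   # lower loss is better
    assert parse_metric_key("+acc") == ("acc", True)      # higher accuracy is better
    assert parse_metric_key("acc") == ("acc", True)       # no sign defaults to "higher is better"
    assert parse_metric_key() == (None, True)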


fastNLP/core/dist_trainer.py  (+15  -3)

@@ -9,6 +9,7 @@ from tqdm import tqdm
 import logging
 import time
 from datetime import datetime, timedelta
+from functools import partial

 from .batch import DataSetIter, BatchIter
 from .callback import DistCallbackManager, CallbackException, TesterCallback
@@ -46,8 +47,9 @@ class DistTrainer():
                  callbacks_all=None, callbacks_master=None,
                  batch_size_per_gpu=8, n_epochs=1,
                  num_data_workers=1, drop_last=False,
-                 dev_data=None, metrics=None,
+                 dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
+                 log_path=None,
                  save_every=-1, save_path=None, device='auto',
                  fp16='', backend=None, init_method=None):


@@ -88,6 +90,7 @@ class DistTrainer():
         self.callback_manager = DistCallbackManager(
             env={"trainer": self}, callbacks_all=callbacks_all,
             callbacks_master=callbacks_master)
+        self.metric_key = metric_key

         model.to(self.device)
         optimizer = self._get_optimizer(optimizer)
@@ -133,7 +136,8 @@ class DistTrainer():
             self.cp_save_path = None

         # use INFO in the master, WARN for others
-        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+        logging.basicConfig(filename=log_path,
+                            format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                             datefmt='%m/%d/%Y %H:%M:%S',
                             level=logging.INFO if self.is_master else logging.WARN)
         self.logger = logging.getLogger(__name__)
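
The new log_path argument is forwarded straight to logging.basicConfig(filename=...). A standalone sketch of that setup, with log_path and is_master as stand-in variables rather than the trainer's real attributes:

    import logging

    log_path = None     # stand-in for the new log_path argument; a file path redirects records there
    is_master = True    # stand-in for self.is_master

    logging.basicConfig(filename=log_path,  # filename=None keeps the default stderr handler
                        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if is_master else logging.WARN)
    logging.getLogger(__name__).info('INFO-level messages are emitted on the master process only')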
@@ -278,7 +282,15 @@ class DistTrainer():

             if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
                     (self.validate_every < 0 and self.step % len(data_iterator) == 0)):
-                self.callback_manager.on_validation()
+                self.callback_manager.on_valid_begin()
+                eval_res = self.callback_manager.on_validation()
+                eval_res = list(filter(lambda x: x is not None, eval_res))
+                if len(eval_res):
+                    eval_res, is_better = list(zip(*eval_res))
+                else:
+                    eval_res, is_better = None, None
+                self.callback_manager.on_valid_end(
+                    eval_res, self.metric_key, self.optimizer, is_better)
                 dist.barrier()

             if self.cp_save_path and \
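
The reworked validation step above collects the return value of every callback's on_validation(), drops callbacks that returned nothing, and transposes the remaining (score, is_better) pairs before handing them to on_valid_end. A small standalone sketch of that filter/zip aggregation, using made-up score dicts:

    # Each TesterCallback.on_validation() now returns (cur_score, is_better);
    # other callbacks typically return None and are filtered out.
    returns = [({"acc": 0.91}, True), None, ({"loss": 0.35}, False)]

    eval_res = list(filter(lambda x: x is not None, returns))
    if len(eval_res):
        eval_res, is_better = list(zip(*eval_res))   # transpose pairs into two parallel tuples
    else:
        eval_res, is_better = None, None

    assert eval_res == ({"acc": 0.91}, {"loss": 0.35})
    assert is_better == (True, False)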


test/core/test_dist_trainer.py  (+2  -2)

@@ -144,12 +144,12 @@ class TestDistTrainer(unittest.TestCase):
         cmd = ['python', '-m', 'torch.distributed.launch',
                '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
         print(' '.join(cmd))
-        subprocess.check_call(cmd, timeout=60.0)
+        subprocess.check_call(cmd)

     def test_normal_run(self):
         self.run_dist(1)

-    def test_fp16(self):
+    def no_test_fp16(self):
         self.run_dist(2)

     def test_callback(self):
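
Renaming test_fp16 to no_test_fp16 is the usual quick way to disable a case: unittest's default loader only collects methods whose names start with "test". A tiny illustration (Demo is a throwaway class, not part of the test suite):

    import unittest

    class Demo(unittest.TestCase):
        def test_collected(self):
            pass

        def no_test_skipped(self):   # not picked up by the default "test" method prefix
            pass

    assert unittest.TestLoader().getTestCaseNames(Demo) == ['test_collected']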

