Browse Source

1. merge conflict

tags/v0.4.10
yh 6 years ago
parent
commit
c3d5128ab5
3 changed files with 43 additions and 9 deletions
  1. +26
    -4
      fastNLP/core/callback.py
  2. +15
    -3
      fastNLP/core/dist_trainer.py
  3. +2
    -2
      test/core/test_dist_trainer.py

+ 26
- 4
fastNLP/core/callback.py View File

@@ -956,20 +956,42 @@ class EchoCallback(Callback):


class TesterCallback(Callback):
def __init__(self, data, model, metrics, batch_size=16, num_workers=None):\
#TODO add compare & save best
def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
super(TesterCallback, self).__init__()
self.tester = Tester(data, model,
metrics=metrics, batch_size=batch_size,
num_workers=num_workers, verbose=0)
# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
self.increase_better = True
if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
else:
self.metric_key = None
self.score = None

def on_validation(self):
cur_socre = self.tester.test()
cur_score = self.tester.test()
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
self.epoch, self.n_epochs, self.step, self.n_steps,
self.tester._format_eval_results(cur_socre))
self.tester._format_eval_results(cur_score))
self.logger.info(eval_str)
is_better = self.compare_better(cur_score)
if is_better:
self.score = cur_score
return cur_score, is_better

def compare_better(self, a):
if self.score is None:
return True
k = self.metric_key
is_increase = self.score[k] <= a[k] # if equal, prefer more recent results
if self.increase_better:
return is_increase
else:
return not is_increase

def on_train_end(self):
self.logger.info('Evaluate on training ends.')


+ 15
- 3
fastNLP/core/dist_trainer.py View File

@@ -9,6 +9,7 @@ from tqdm import tqdm
import logging
import time
from datetime import datetime, timedelta
from functools import partial

from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback
@@ -46,8 +47,9 @@ class DistTrainer():
callbacks_all=None, callbacks_master=None,
batch_size_per_gpu=8, n_epochs=1,
num_data_workers=1, drop_last=False,
dev_data=None, metrics=None,
dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1,
log_path=None,
save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None):

@@ -88,6 +90,7 @@ class DistTrainer():
self.callback_manager = DistCallbackManager(
env={"trainer": self}, callbacks_all=callbacks_all,
callbacks_master=callbacks_master)
self.metric_key = metric_key

model.to(self.device)
optimizer = self._get_optimizer(optimizer)
@@ -133,7 +136,8 @@ class DistTrainer():
self.cp_save_path = None

# use INFO in the master, WARN for others
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
logging.basicConfig(filename=log_path,
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO if self.is_master else logging.WARN)
self.logger = logging.getLogger(__name__)
@@ -278,7 +282,15 @@ class DistTrainer():

if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)):
self.callback_manager.on_validation()
self.callback_manager.on_valid_begin()
eval_res = self.callback_manager.on_validation()
eval_res = list(filter(lambda x: x is not None, eval_res))
if len(eval_res):
eval_res, is_better = list(zip(*eval_res))
else:
eval_res, is_better = None, None
self.callback_manager.on_valid_end(
eval_res, self.metric_key, self.optimizer, is_better)
dist.barrier()

if self.cp_save_path and \


+ 2
- 2
test/core/test_dist_trainer.py View File

@@ -144,12 +144,12 @@ class TestDistTrainer(unittest.TestCase):
cmd = ['python', '-m', 'torch.distributed.launch',
'--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
print(' '.join(cmd))
subprocess.check_call(cmd, timeout=60.0)
subprocess.check_call(cmd)

def test_normal_run(self):
self.run_dist(1)

def test_fp16(self):
def no_test_fp16(self):
self.run_dist(2)

def test_callback(self):


Loading…
Cancel
Save