Browse Source

在linux桌面系统上Trainer中使用Tester的tqdm存在bug; 增加一个可选项使得用户可以关闭Tester的tqdm

tags/v0.4.10
yh_cc 5 years ago
parent
commit
c38e8986cc
2 changed files with 7 additions and 3 deletions
  1. +2
    -2
      fastNLP/core/callback.py
  2. +5
    -1
      fastNLP/core/trainer.py

+ 2
- 2
fastNLP/core/callback.py View File

@@ -569,7 +569,7 @@ class FitlogCallback(Callback):
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
metrics=self.trainer.metrics,
verbose=0,
use_tqdm=self.trainer.use_tqdm)
use_tqdm=self.trainer.test_use_tqdm)
self.testers[key] = tester
fitlog.add_progress(total_steps=self.n_steps)
@@ -654,7 +654,7 @@ class EvaluateCallback(Callback):
tester = Tester(data=data, model=self.model,
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
metrics=self.trainer.metrics, verbose=0,
use_tqdm=self.trainer.use_tqdm)
use_tqdm=self.trainer.test_use_tqdm)
self.testers[key] = tester

def on_valid_end(self, eval_result, metric_key, optimizer, better_result):


+ 5
- 1
fastNLP/core/trainer.py View File

@@ -545,6 +545,10 @@ class Trainer(object):
self.logger = logger

self.use_tqdm = use_tqdm
if 'test_use_tqdm' in kwargs:
self.test_use_tqdm = kwargs.get('test_use_tqdm')
else:
self.test_use_tqdm = self.use_tqdm
self.pbar = None
self.print_every = abs(self.print_every)
self.kwargs = kwargs
@@ -555,7 +559,7 @@ class Trainer(object):
batch_size=kwargs.get("dev_batch_size", self.batch_size),
device=None, # 由上面的部分处理device
verbose=0,
use_tqdm=self.use_tqdm)
use_tqdm=self.test_use_tqdm)

self.step = 0
self.start_time = None # start timestamp


Loading…
Cancel
Save