diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 4ba4b945..24b42b6e 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -569,7 +569,7 @@ class FitlogCallback(Callback): batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), metrics=self.trainer.metrics, verbose=0, - use_tqdm=self.trainer.use_tqdm) + use_tqdm=self.trainer.test_use_tqdm) self.testers[key] = tester fitlog.add_progress(total_steps=self.n_steps) @@ -654,7 +654,7 @@ class EvaluateCallback(Callback): tester = Tester(data=data, model=self.model, batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), metrics=self.trainer.metrics, verbose=0, - use_tqdm=self.trainer.use_tqdm) + use_tqdm=self.trainer.test_use_tqdm) self.testers[key] = tester def on_valid_end(self, eval_result, metric_key, optimizer, better_result): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 2c52d104..290a89c1 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -545,6 +545,10 @@ class Trainer(object): self.logger = logger self.use_tqdm = use_tqdm + if 'test_use_tqdm' in kwargs: + self.test_use_tqdm = kwargs.get('test_use_tqdm') + else: + self.test_use_tqdm = self.use_tqdm self.pbar = None self.print_every = abs(self.print_every) self.kwargs = kwargs @@ -555,7 +559,7 @@ class Trainer(object): batch_size=kwargs.get("dev_batch_size", self.batch_size), device=None, # 由上面的部分处理device verbose=0, - use_tqdm=self.use_tqdm) + use_tqdm=self.test_use_tqdm) self.step = 0 self.start_time = None # start timestamp