Browse Source

1.删除Trainer中的prefetch参数; 2.增加中文分词的下载; 3.增加DataBundle的delete_dataset, delete_vocab

tags/v0.4.10
yh 6 years ago
parent
commit
ce083de26b
2 changed files with 8 additions and 11 deletions
  1. +1
    -2
      fastNLP/core/tester.py
  2. +7
    -9
      fastNLP/core/trainer.py

+ 1
- 2
fastNLP/core/tester.py View File

@@ -180,12 +180,11 @@ class Tester(object):
f"`dict`, got {type(eval_result)}")
metric_name = metric.get_metric_name()
eval_results[metric_name] = eval_result
pbar.close()
end_time = time.time()
test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
# pbar.write(test_str)
self.logger.info(test_str)
pbar.close()
except _CheckError as e:
prev_func_signature = _get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,


+ 7
- 9
fastNLP/core/trainer.py View File

@@ -336,7 +336,7 @@ except:
import warnings

from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .callback import CallbackManager, CallbackException, Callback
from .dataset import DataSet
from .losses import _prepare_losser
from .metrics import _prepare_metrics
@@ -422,13 +422,8 @@ class Trainer(object):
batch_size=32, sampler=None, drop_last=False, update_every=1,
num_workers=0, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None,
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
validate_every=-1, save_path=None, use_tqdm=True, device=None,
callbacks=None, check_code_level=0, **kwargs):
if prefetch and num_workers==0:
num_workers = 1
if prefetch:
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")

super(Trainer, self).__init__()
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
@@ -566,6 +561,9 @@ class Trainer(object):
self.step = 0
self.start_time = None # start timestamp

if isinstance(callbacks, Callback):
callbacks = [callbacks]

self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)

@@ -617,8 +615,8 @@ class Trainer(object):

if self.dev_data is not None and self.best_dev_perf is not None:
self.logger.info(
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
self.tester._format_eval_results(self.best_dev_perf), )
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
self.logger.info(self.tester._format_eval_results(self.best_dev_perf))
results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step


Loading…
Cancel
Save