Browse Source

1.删除Trainer中的prefetch参数; 2.增加中文分词的下载; 3.增加DataBundle的delete_dataset, delete_vocab

tags/v0.4.10
yh 6 years ago
parent
commit
ce083de26b
2 changed files with 8 additions and 11 deletions
  1. +1
    -2
      fastNLP/core/tester.py
  2. +7
    -9
      fastNLP/core/trainer.py

+ 1
- 2
fastNLP/core/tester.py View File

@@ -180,12 +180,11 @@ class Tester(object):
f"`dict`, got {type(eval_result)}") f"`dict`, got {type(eval_result)}")
metric_name = metric.get_metric_name() metric_name = metric.get_metric_name()
eval_results[metric_name] = eval_result eval_results[metric_name] = eval_result
pbar.close()
end_time = time.time()
test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
# pbar.write(test_str)
self.logger.info(test_str)
pbar.close()
except _CheckError as e:
prev_func_signature = _get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,


+ 7
- 9
fastNLP/core/trainer.py View File

@@ -336,7 +336,7 @@ except:
import warnings


from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .callback import CallbackManager, CallbackException, Callback
from .dataset import DataSet
from .losses import _prepare_losser
from .metrics import _prepare_metrics
@@ -422,13 +422,8 @@ class Trainer(object):
batch_size=32, sampler=None, drop_last=False, update_every=1,
num_workers=0, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None,
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
validate_every=-1, save_path=None, use_tqdm=True, device=None,
callbacks=None, check_code_level=0, **kwargs):
if prefetch and num_workers==0:
num_workers = 1
if prefetch:
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")

super(Trainer, self).__init__()
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
@@ -566,6 +561,9 @@ class Trainer(object):
self.step = 0
self.start_time = None # start timestamp


if isinstance(callbacks, Callback):
callbacks = [callbacks]

self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)


@@ -617,8 +615,8 @@ class Trainer(object):


if self.dev_data is not None and self.best_dev_perf is not None:
self.logger.info(
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
self.tester._format_eval_results(self.best_dev_perf), )
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
self.logger.info(self.tester._format_eval_results(self.best_dev_perf))
results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step


Loading…
Cancel
Save