diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 9830ff1e..e02e04a0 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -206,7 +206,7 @@ class Callback(object):
     def on_batch_begin(self, batch_x, batch_y, indices):
         r"""
         Called once every time a batch of data is collected. Deleting from or adding to batch_x or batch_y here affects what the Trainer sees, so this step
-        can be used for operations such as negative sampling.
+        can be used for operations such as negative sampling. The tensors in batch_x and batch_y have already been moved to the device the model is on.
 
         :param dict batch_x: batch of the fields set as input in the DataSet.
         :param dict batch_y: batch of the fields set as target in the DataSet.
@@ -1169,11 +1169,12 @@ class EchoCallback(Callback):
 
 
 class _TesterCallback(Callback):
-    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
+    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None, sampler=None,
+                 use_tqdm=True):
         super(_TesterCallback, self).__init__()
         self.tester = Tester(data, model,
                              metrics=metrics, batch_size=batch_size,
-                             num_workers=num_workers, verbose=0)
+                             num_workers=num_workers, verbose=0, sampler=sampler, use_tqdm=use_tqdm)
         if metric_key is not None:
             self.metric_key, self.increase_better = self._parse_metric_key(metric_key)
         else:
diff --git a/fastNLP/core/dist_trainer.py b/fastNLP/core/dist_trainer.py
index 726a5e60..a76d0a05 100644
--- a/fastNLP/core/dist_trainer.py
+++ b/fastNLP/core/dist_trainer.py
@@ -73,7 +73,7 @@ class DistTrainer():
                  dev_data=None, metrics=None, metric_key=None,
                  update_every=1, print_every=10, validate_every=-1,
                  save_path=None, device='auto',
-                 fp16='', use_tqdm=True):
+                 fp16='', use_tqdm=True, **kwargs):
         r"""
 
         :param train_data: training set, of type :class:`~fastNLP.DataSet`.
@@ -106,6 +106,9 @@ class DistTrainer():
         :param str device: the device to use; can be gpu, cpu or auto
         :param str fp16: optimization level for mixed-precision training; can be O1, O2 or O3. An empty string disables mixed precision.
         :param bool use_tqdm: whether to use tqdm to display training progress; if False, the loss is printed to the terminal.
+        :param kwargs: supported optional arguments:
+            bool test_use_tqdm: whether to enable tqdm when validating on the dev set
+            Sampler test_sampler: the sampler to use during evaluation
         """
         assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']"
         if device == 'auto':
@@ -163,16 +166,23 @@ class DistTrainer():
         self.model = self.ddp_model.module
 
         self.optimizer = optimizer
-        self.sampler = DistributedSampler(self.train_data)
+        if isinstance(self.train_data, DataSet):
+            self.sampler = DistributedSampler(self.train_data)
         self.data_iterator = self._get_data_iter(self.train_data)
         self.batch_size = self.world_size * self.batch_size_per_gpu
         self.n_steps = self._get_n_steps()
 
+        if 'test_use_tqdm' in kwargs:
+            test_use_tqdm = kwargs.get('test_use_tqdm')
+        else:
+            test_use_tqdm = self.use_tqdm
+
         # for evaluation, only run eval on master proc
         if dev_data and metrics:
             cb = _TesterCallback(
                 dev_data, model, metrics,
-                batch_size=batch_size_per_gpu, num_workers=num_workers)
+                batch_size=batch_size_per_gpu, num_workers=num_workers, sampler=kwargs.get('test_sampler', None),
+                use_tqdm=test_use_tqdm)
             self.test_manager.add_callback([cb], master=True)
 
         # Setup logging
@@ -232,8 +242,10 @@ class DistTrainer():
         elif optimizer is None:
             return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
         else:
-            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
-
+            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
+                raise TypeError("optimizer must have a callable step() function.")
+            else:
+                self.optimizer = optimizer
 
     @property
     def is_master(self):
         r"""Whether this is the master process"""
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index a562c58a..72aba38a 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -545,7 +545,10 @@ class Trainer(object):
         elif optimizer is None:
             self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
         else:
-            raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
+            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
+                raise TypeError("optimizer must have a callable step() function.")
+            else:
+                self.optimizer = optimizer
 
         self.logger = logger
 
diff --git a/fastNLP/modules/generator/seq2seq_generator.py b/fastNLP/modules/generator/seq2seq_generator.py
index faa9a93a..7f15091c 100644
--- a/fastNLP/modules/generator/seq2seq_generator.py
+++ b/fastNLP/modules/generator/seq2seq_generator.py
@@ -273,6 +273,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
         scores = F.log_softmax(scores, dim=-1)  # (batch_size, vocab_size)
         # obtain (batch_size, num_beams), (batch_size, num_beams)
         next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True)
+        # TODO handle the case where generation finishes at the very first position
 
         # reorder according to the indices
         indices = torch.arange(batch_size, dtype=torch.long).to(device)
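The docstring change in `Callback.on_batch_begin` clarifies that the batch tensors are already on the model's device and that in-place edits to `batch_x`/`batch_y` are visible to the Trainer. A rough sketch of a callback that relies on this; the field name `'target'` and the random masking rule are illustrative assumptions, not part of fastNLP:

```python
import torch

from fastNLP import Callback


class MaskTargetCallback(Callback):
    """Hypothetical callback: edits batch_y in place before each training step.

    The tensors in batch_x / batch_y are already on the model's device, so any
    tensors created here should stay on that device (rand_like preserves the
    device of its argument).
    """

    def on_batch_begin(self, batch_x, batch_y, indices):
        target = batch_y['target']                    # assumes a target field named 'target'
        keep = torch.rand_like(target.float()) > 0.1  # randomly drop ~10% of the labels
        batch_y['target'] = target * keep.long()      # in-place edit is seen by the Trainer
```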
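The new `**kwargs` on `DistTrainer` let evaluation be configured separately from training, e.g. turning off the evaluation progress bar or choosing the sampler handed to the internal Tester. A usage sketch, assuming `train_ds`, `dev_ds`, `model`, and `metrics` are already prepared (placeholders here) and that the script is launched in the usual distributed way (e.g. via `torch.distributed.launch`):

```python
from fastNLP import SequentialSampler
from fastNLP.core.dist_trainer import DistTrainer

trainer = DistTrainer(
    train_data=train_ds, model=model, optimizer=None,  # None falls back to Adam(lr=4e-3)
    dev_data=dev_ds, metrics=metrics,
    batch_size_per_gpu=8,
    use_tqdm=True,                     # progress bar for training
    test_use_tqdm=False,               # new kwarg: no progress bar during evaluation
    test_sampler=SequentialSampler(),  # new kwarg: sampler used by the internal Tester
)
trainer.train()
```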
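The relaxed optimizer check in both `Trainer` and `DistTrainer` now accepts any object with a callable `step()` method rather than only `torch.optim.Optimizer` instances. A minimal sketch of the kind of wrapper that passes the new check; `GradClipOptimizer` is a hypothetical example, not a fastNLP API:

```python
import torch


class GradClipOptimizer:
    """Hypothetical wrapper: clips gradients before delegating to a real optimizer.

    It is not a torch.optim.Optimizer subclass, but it exposes callable
    step()/zero_grad() methods, so the duck-typed check above accepts it.
    """

    def __init__(self, inner, parameters, max_norm=1.0):
        self.inner = inner
        self.params = list(parameters)
        self.max_norm = max_norm

    def step(self):
        torch.nn.utils.clip_grad_norm_(self.params, self.max_norm)
        self.inner.step()

    def zero_grad(self):
        self.inner.zero_grad()


# e.g. Trainer(train_data, model, optimizer=GradClipOptimizer(
#          torch.optim.SGD(model.parameters(), lr=0.1), model.parameters()), ...)
```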