| @@ -248,3 +248,52 @@ class Tester(object): | |||||
| _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) | _str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()]) | ||||
| _str += '\n' | _str += '\n' | ||||
| return _str[:-1] | return _str[:-1] | ||||
| def flp_topredict(self): | |||||
| r"""开始进行预测,并返回预测结果。 | |||||
| :return 本次的预测结果,为一个字典,其中只有{predict}一个key,而key的值类型为tensor。 | |||||
| """ | |||||
| # turn on the testing mode; clean up the history | |||||
| self._model_device = _get_model_device(self._model) | |||||
| network = self._model | |||||
| self._mode(network, is_test=True) | |||||
| data_iterator = self.data_iterator | |||||
| eval_results = [] | |||||
| try: | |||||
| with torch.no_grad(): | |||||
| if not self.use_tqdm: | |||||
| from .utils import _pseudo_tqdm as inner_tqdm | |||||
| else: | |||||
| inner_tqdm = tqdm | |||||
| with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar: | |||||
| pbar.set_description_str(desc="Pred") | |||||
| start_time = time.time() | |||||
| for batch_x, batch_y in data_iterator: | |||||
| _move_dict_value_to_device(batch_x, batch_y, device=self._model_device, | |||||
| non_blocking=self.pin_memory) | |||||
| with self.auto_cast(): | |||||
| pred_dict = self._data_forward(self._predict_func, batch_x) | |||||
| eval_results.extend(pred_dict['predict'].detach().cpu().numpy()) | |||||
| if self.use_tqdm: | |||||
| pbar.update() | |||||
| pbar.close() | |||||
| end_time = time.time() | |||||
| test_str = f'Predict data in {round(end_time - start_time, 2)} seconds!' | |||||
| if self.verbose >= 0: | |||||
| self.logger.info(test_str) | |||||
| except _CheckError as e: | |||||
| prev_func_signature = _get_func_signature(self._predict_func) | |||||
| _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | |||||
| check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||||
| dataset=self.data, check_level=0) | |||||
| finally: | |||||
| self._mode(network, is_test=False) | |||||
| print(f'预测完成') | |||||
| return eval_results | |||||