@@ -592,9 +592,10 @@ class FitlogCallback(Callback): | |||||
fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) | fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) | ||||
if better_result: | if better_result: | ||||
fitlog.add_best_metric(eval_result, name=key) | fitlog.add_best_metric(eval_result, name=key) | ||||
except Exception: | |||||
except Exception as e: | |||||
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | ||||
raise e | |||||
def on_train_end(self): | def on_train_end(self): | ||||
fitlog.finish() | fitlog.finish() | ||||
@@ -660,9 +661,9 @@ class EvaluateCallback(Callback): | |||||
eval_result = tester.test() | eval_result = tester.test() | ||||
self.logger.info("EvaluateCallback evaluation on {}:".format(key)) | self.logger.info("EvaluateCallback evaluation on {}:".format(key)) | ||||
self.logger.info(tester._format_eval_results(eval_result)) | self.logger.info(tester._format_eval_results(eval_result)) | ||||
except Exception: | |||||
except Exception as e: | |||||
self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key)) | self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key)) | ||||
raise e | |||||
class LRScheduler(Callback): | class LRScheduler(Callback): | ||||
""" | """ | ||||
@@ -228,17 +228,18 @@ class CrossEntropyLoss(LossBase): | |||||
self.class_in_dim = class_in_dim | self.class_in_dim = class_in_dim | ||||
def get_loss(self, pred, target, seq_len=None): | def get_loss(self, pred, target, seq_len=None): | ||||
if seq_len is not None and target.dim()>1: | |||||
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0) | |||||
target = target.masked_fill(mask, self.padding_idx) | |||||
if pred.dim() > 2: | if pred.dim() > 2: | ||||
if self.class_in_dim == -1: | if self.class_in_dim == -1: | ||||
if pred.size(1) != target.size(1): # 有可能顺序替换了 | if pred.size(1) != target.size(1): # 有可能顺序替换了 | ||||
pred = pred.transpose(1, 2) | pred = pred.transpose(1, 2) | ||||
else: | else: | ||||
pred = pred.tranpose(-1, pred) | |||||
pred = pred.transpose(-1, self.class_in_dim) | |||||
pred = pred.reshape(-1, pred.size(-1)) | pred = pred.reshape(-1, pred.size(-1)) | ||||
target = target.reshape(-1) | target = target.reshape(-1) | ||||
if seq_len is not None and target.dim()>1: | |||||
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) | |||||
target = target.masked_fill(mask, self.padding_idx) | |||||
return F.cross_entropy(input=pred, target=target, | return F.cross_entropy(input=pred, target=target, | ||||
ignore_index=self.padding_idx, reduction=self.reduction) | ignore_index=self.padding_idx, reduction=self.reduction) | ||||
@@ -189,10 +189,10 @@ class Tester(object): | |||||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | ||||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | ||||
dataset=self.data, check_level=0) | dataset=self.data, check_level=0) | ||||
finally: | |||||
self._mode(network, is_test=False) | |||||
if self.verbose >= 1: | if self.verbose >= 1: | ||||
logger.info("[tester] \n{}".format(self._format_eval_results(eval_results))) | logger.info("[tester] \n{}".format(self._format_eval_results(eval_results))) | ||||
self._mode(network, is_test=False) | |||||
return eval_results | return eval_results | ||||
def _mode(self, model, is_test=False): | def _mode(self, model, is_test=False): | ||||
@@ -13,6 +13,18 @@ class TestLoss(unittest.TestCase): | |||||
b = torch.empty(3, dtype=torch.long).random_(5) | b = torch.empty(3, dtype=torch.long).random_(5) | ||||
ans = ce({"my_predict": a}, {"my_truth": b}) | ans = ce({"my_predict": a}, {"my_truth": b}) | ||||
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | ||||
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=1) | |||||
a = torch.randn(3, 4, 3) | |||||
b = torch.randint(3, (3, 3)) | |||||
ans = ce({"my_predict": a}, {"my_truth": b}) | |||||
self.assertAlmostEqual(ans.item(), torch.nn.functional.cross_entropy(a, b).item(), places=4) | |||||
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) | |||||
a = torch.randn(3, 4, 3) | |||||
b = torch.randint(3, (3, 4)) | |||||
ans = ce({"my_predict": a}, {"my_truth": b}) | |||||
self.assertAlmostEqual(ans.item(), torch.nn.functional.cross_entropy(a.transpose(1, 2), b).item(), places=4) | |||||
def test_BCELoss(self): | def test_BCELoss(self): | ||||
bce = loss.BCELoss(pred="my_predict", target="my_truth") | bce = loss.BCELoss(pred="my_predict", target="my_truth") | ||||