From abd10e24a8457752d2e441fdbaa4934b5762f3e7 Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 15 Oct 2019 10:47:16 +0800 Subject: [PATCH 1/5] =?UTF-8?q?1.=20=E4=BF=AE=E5=A4=8DTester=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E8=BF=87=E7=A8=8B=E4=B8=AD=E5=87=BA=E7=8E=B0=E5=BC=82?= =?UTF-8?q?=E5=B8=B8model=E4=B8=8D=E8=83=BD=E9=87=8D=E7=BD=AE=E4=B8=BATrai?= =?UTF-8?q?ning=E7=8A=B6=E6=80=81=E7=9A=84bug;=202.=20FitlogCallback,=20Ev?= =?UTF-8?q?aluateCallback=E5=9C=A8=E9=81=AD=E9=81=87exception=E6=97=B6?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5raise=EF=BC=8C=E9=98=B2=E6=AD=A2=E5=87=BA?= =?UTF-8?q?=E7=8E=B0metric=E6=B2=A1=E6=9C=89=E9=87=8D=E6=96=B0=E5=BD=92?= =?UTF-8?q?=E9=9B=B6=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callback.py | 9 +++++---- fastNLP/core/tester.py | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index ad417340..095ebc3d 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -592,9 +592,10 @@ class FitlogCallback(Callback): fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) if better_result: fitlog.add_best_metric(eval_result, name=key) - except Exception: + except Exception as e: self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) - + raise e + def on_train_end(self): fitlog.finish() @@ -660,9 +661,9 @@ class EvaluateCallback(Callback): eval_result = tester.test() self.logger.info("EvaluateCallback evaluation on {}:".format(key)) self.logger.info(tester._format_eval_results(eval_result)) - except Exception: + except Exception as e: self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key)) - + raise e class LRScheduler(Callback): """ diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index e92eb422..8e2ac8a7 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -189,10 +189,10 @@ class Tester(object): _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, dataset=self.data, check_level=0) - + finally: + self._mode(network, is_test=False) if self.verbose >= 1: logger.info("[tester] \n{}".format(self._format_eval_results(eval_results))) - self._mode(network, is_test=False) return eval_results def _mode(self, model, is_test=False): From 1eec5b234b3bf938dd9affde22ee5460493f7421 Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 15 Oct 2019 19:14:50 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E4=BF=AE=E5=A4=8DCrossEntropyLoss=E5=9C=A8?= =?UTF-8?q?seq=5Flen=E4=B8=8D=E4=B8=BANone=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=E4=BC=9A=E5=87=BA=E7=8E=B0=E8=AE=A1=E7=AE=97=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 2166734d..909e90a9 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -228,17 +228,18 @@ class CrossEntropyLoss(LossBase): self.class_in_dim = class_in_dim def get_loss(self, pred, target, seq_len=None): + if seq_len is not None and target.dim()>1: + mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(0) + target = target.masked_fill(mask, self.padding_idx) + if pred.dim() > 2: if self.class_in_dim == -1: if pred.size(1) != target.size(1): # 有可能顺序替换了 pred = pred.transpose(1, 2) else: - pred = pred.tranpose(-1, pred) + pred = pred.transpose(1, 2) pred = pred.reshape(-1, pred.size(-1)) target = target.reshape(-1) - if seq_len is not None and target.dim()>1: - mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) - target = target.masked_fill(mask, self.padding_idx) return F.cross_entropy(input=pred, target=target, ignore_index=self.padding_idx, reduction=self.reduction) From bf800dc30406c0ed4b171f75a02aca56d1f7645d Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 15 Oct 2019 19:18:03 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E4=BF=AE=E5=A4=8DCrossEntropy=E6=9C=AA?= =?UTF-8?q?=E4=BD=BF=E7=94=A8class=5Fin=5Fdim=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 909e90a9..19fb5724 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -237,7 +237,7 @@ class CrossEntropyLoss(LossBase): if pred.size(1) != target.size(1): # 有可能顺序替换了 pred = pred.transpose(1, 2) else: - pred = pred.transpose(1, 2) + pred = pred.transpose(-1, self.class_in_dim) pred = pred.reshape(-1, pred.size(-1)) target = target.reshape(-1) From 9c0190fbd82f4c50956d41e32a6370eb93317add Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 15 Oct 2019 19:28:57 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=AF=B9CrossEntropyLoss?= =?UTF-8?q?=E4=B8=ADclass=5Fin=5Fdim=E7=9A=84=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/core/test_loss.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/core/test_loss.py b/test/core/test_loss.py index 9ba8159f..a57e6542 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -13,6 +13,18 @@ class TestLoss(unittest.TestCase): b = torch.empty(3, dtype=torch.long).random_(5) ans = ce({"my_predict": a}, {"my_truth": b}) self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + + ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=1) + a = torch.randn(3, 4, 3) + b = torch.randint(3, (3, 3)) + ans = ce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + + ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) + a = torch.randn(3, 4, 3) + b = torch.randint(3, (3, 4)) + ans = ce({"my_predict": a}, {"my_truth": b}) + self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b)) def test_BCELoss(self): bce = loss.BCELoss(pred="my_predict", target="my_truth") From f887da12a15d70acf5570b27b252ba2ba3e0b14f Mon Sep 17 00:00:00 2001 From: yh Date: Tue, 15 Oct 2019 22:32:33 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E8=A7=A3=E5=86=B3CrossEntropyLoss=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E5=9B=A0=E4=B8=BA=E6=95=B0=E5=80=BC=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E6=97=A0=E6=B3=95=E9=80=9A=E8=BF=87=E6=B5=8B=E8=AF=95=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/core/test_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/core/test_loss.py b/test/core/test_loss.py index a57e6542..976285a9 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -18,13 +18,13 @@ class TestLoss(unittest.TestCase): a = torch.randn(3, 4, 3) b = torch.randint(3, (3, 3)) ans = ce({"my_predict": a}, {"my_truth": b}) - self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) + self.assertAlmostEqual(ans.item(), torch.nn.functional.cross_entropy(a, b).item(), places=4) ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth", class_in_dim=2) a = torch.randn(3, 4, 3) b = torch.randint(3, (3, 4)) ans = ce({"my_predict": a}, {"my_truth": b}) - self.assertEqual(ans, torch.nn.functional.cross_entropy(a.transpose(1, 2), b)) + self.assertAlmostEqual(ans.item(), torch.nn.functional.cross_entropy(a.transpose(1, 2), b).item(), places=4) def test_BCELoss(self): bce = loss.BCELoss(pred="my_predict", target="my_truth")