diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 3d3647c4..5687cc85 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -243,12 +243,11 @@ class AccuracyMetric(MetricBase):
 
     def evaluate(self, pred, target, seq_lens=None):
         """
-        :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
-                torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
-        :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
-        :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
+        :param pred: Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]),
+                torch.Size([B, max_len, n_classes])
+        :param target: Element's shape can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]),
+                torch.Size([B, max_len])
+        :param seq_lens: Element's shape can be: None, None, torch.Size([B]), torch.Size([B]). Ignored if masks are provided.
 
         """
         # TODO the error message here needs updating: the user does not know what pred is; report the actual value
@@ -264,7 +263,7 @@ class AccuracyMetric(MetricBase):
                                 f"got {type(seq_lens)}.")
 
         if seq_lens is not None:
-            masks = seq_lens_to_masks(seq_lens=seq_lens).long()
+            masks = seq_lens_to_masks(seq_lens=seq_lens)
         else:
             masks = None
 
@@ -277,9 +276,9 @@ class AccuracyMetric(MetricBase):
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")
 
-
+        target = target.to(pred)
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(pred, target) * masks).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
             self.total += torch.sum(masks).item()
         else:
             self.acc_count += torch.sum(torch.eq(pred, target)).item()
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index fb54ee8a..5ed1a711 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -219,8 +219,8 @@ class TestDataSetMethods(unittest.TestCase):
     def test_add_null(self):
         # TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError'
         ds = DataSet()
-        ds.add_field('test', [])
-        ds.set_target('test')
+        with self.assertRaises(RuntimeError) as RE:
+            ds.add_field('test', [])
 
 
 class TestDataSetIter(unittest.TestCase):
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
index 25138478..4fb2a04e 100644
--- a/test/core/test_metrics.py
+++ b/test/core/test_metrics.py
@@ -15,7 +15,7 @@ class TestAccuracyMetric(unittest.TestCase):
         target_dict = {'target': torch.zeros(4)}
         metric = AccuracyMetric()
 
-        metric(pred_dict=pred_dict, target_dict=target_dict, )
+        metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())
 
     def test_AccuracyMetric2(self):
@@ -30,7 +30,7 @@ class TestAccuracyMetric(unittest.TestCase):
         except Exception as e:
             print(e)
             return
-        self.assertTrue(True, False), "No exception catches."
+        print("No exception caught.")
 
     def test_AccuracyMetric3(self):
         # (3) the second batch is corrupted size
@@ -95,10 +95,9 @@ class TestAccuracyMetric(unittest.TestCase):
         self.assertAlmostEqual(res["acc"], float(ans), places=4)
 
     def test_AccuaryMetric8(self):
-        # (8) check map, does not match. use stop_fast_param to stop fast param map
         try:
             metric = AccuracyMetric(pred='predictions', target='targets')
-            pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
+            pred_dict = {"prediction": torch.zeros(4, 3, 2)}
             target_dict = {'targets': torch.zeros(4, 3)}
             metric(pred_dict=pred_dict, target_dict=target_dict, )
             self.assertDictEqual(metric.get_metric(), {'acc': 1})
diff --git a/test/io/test_config_saver.py b/test/io/test_config_saver.py
index f29097c5..a71419e5 100644
--- a/test/io/test_config_saver.py
+++ b/test/io/test_config_saver.py
@@ -6,7 +6,7 @@ from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver
 
 
 class TestConfigSaver(unittest.TestCase):
     def test_case_1(self):
-        config_file_dir = "test/io/"
+        config_file_dir = "test/io"
         config_file_name = "config"
         config_file_path = os.path.join(config_file_dir, config_file_name)
diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py
index 16e7d7ea..4dddc5d0 100644
--- a/test/io/test_dataset_loader.py
+++ b/test/io/test_dataset_loader.py
@@ -17,11 +17,3 @@ class TestDatasetLoader(unittest.TestCase):
 
     def test_PeopleDailyCorpusLoader(self):
         data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt")
-    def test_ConllCWSReader(self):
-        dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt")
-
-    def test_ZhConllPOSReader(self):
-        dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx")
-
-    def test_ConllxDataLoader(self):
-        dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx")
diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py
index a176348f..5dc60640 100644
--- a/test/modules/decoder/test_CRF.py
+++ b/test/modules/decoder/test_CRF.py
@@ -118,7 +118,7 @@ class TestCRF(unittest.TestCase):
         feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
         crf = ConditionalRandomField(num_tags, include_start_end_trans)
         optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
-        for _ in range(10000):
+        for _ in range(10):
             loss = crf(feats, tags, masks).mean()
             optimizer.zero_grad()
             loss.backward()
diff --git a/test/test_tutorials.py b/test/test_tutorials.py
index c9ffa646..bc0b5d2b 100644
--- a/test/test_tutorials.py
+++ b/test/test_tutorials.py
@@ -152,7 +152,7 @@ class TestTutorial(unittest.TestCase):
                           train_data=train_data, dev_data=dev_data,
                           loss=CrossEntropyLoss(),
-                          metrics=AccuracyMetric()
+                          metrics=AccuracyMetric(target='label_seq')
                           )
         trainer.train()
         print('Train finished!')
 
@@ -407,7 +407,7 @@ class TestTutorial(unittest.TestCase):
             train_data=train_data,
             model=model,
             loss=CrossEntropyLoss(pred='pred', target='label'),
-            metrics=AccuracyMetric(),
+            metrics=AccuracyMetric(target='label'),
             n_epochs=3,
             batch_size=16,
             print_every=-1,
@@ -424,7 +424,7 @@ class TestTutorial(unittest.TestCase):
         tester = Tester(
             data=test_data,
             model=model,
-            metrics=AccuracyMetric(),
+            metrics=AccuracyMetric(target='label'),
             batch_size=args["batch_size"],
         )
         tester.test()
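
Note on the masked-accuracy hunk in fastNLP/core/metrics.py: with `.long()` dropped,
`masks` stays a bool tensor that is True at valid (non-padding) positions, so the match
count must zero out the padding positions, hence `masked_fill(masks.eq(0), 0)` rather
than `masked_fill(masks, 0)`, which would zero the valid positions instead. A minimal
sketch of the intended semantics in plain PyTorch (standalone and illustrative; the
tensors and lengths are made up, and the mask is built by hand rather than with
fastNLP's seq_lens_to_masks):

    import torch

    # Two sequences padded to length 4, with true lengths 3 and 2.
    pred = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
    target = torch.tensor([[1, 2, 9, 7], [4, 9, 8, 8]])
    seq_lens = torch.tensor([3, 2])

    # Mask that is True at valid positions, matching what seq_lens_to_masks produces.
    masks = torch.arange(pred.size(1))[None, :] < seq_lens[:, None]

    # Zero the element-wise matches at padding positions before summing,
    # then normalize by the number of valid tokens only.
    acc_count = torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
    total = torch.sum(masks).item()
    print(acc_count / total)  # 3 correct of 5 valid tokens -> 0.6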