@@ -243,12 +243,11 @@ class AccuracyMetric(MetricBase):
     def evaluate(self, pred, target, seq_lens=None):
         """
-        :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
-                torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
-        :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
-        :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
+        :param pred: Shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]),
+                torch.Size([B, max_len, n_classes])
+        :param target: Shape can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]),
+                torch.Size([B, max_len])
+        :param seq_lens: Shape can be: None, None, torch.Size([B]), torch.Size([B]). Ignored if masks are provided.
         """
         # TODO: this error message needs to change, because the user has no idea what `pred` is; report the actual value to the user.
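A note on the docstring rewritten above: the four shape combinations correspond to plain classification, classification with class scores, sequence labeling, and sequence labeling with class scores. A toy illustration of the second combination, assuming the dev branch's behavior of taking argmax over a trailing class dimension (the tensor values here are invented for illustration):

    import torch
    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric()
    pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # torch.Size([B=2, n_classes=2]); argmax -> [1, 0]
    target = torch.tensor([1, 1])                  # torch.Size([B=2])
    metric.evaluate(pred=pred, target=target)
    print(metric.get_metric())  # one of two predictions matches -> acc 0.5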
@@ -264,7 +263,7 @@ class AccuracyMetric(MetricBase): | |||||
f"got {type(seq_lens)}.") | f"got {type(seq_lens)}.") | ||||
if seq_lens is not None: | if seq_lens is not None: | ||||
masks = seq_lens_to_masks(seq_lens=seq_lens).long() | |||||
masks = seq_lens_to_masks(seq_lens=seq_lens) | |||||
else: | else: | ||||
masks = None | masks = None | ||||
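For reference, seq_lens_to_masks turns a vector of true lengths into a [B, max_len] mask that is nonzero at valid positions; the `.long()` cast can be dropped above because the mask is no longer multiplied into the comparison (see the next hunk) and can stay boolean. A self-contained sketch of the equivalent computation, not fastNLP's actual implementation:

    import torch

    def seq_lens_to_masks_sketch(seq_lens: torch.Tensor) -> torch.Tensor:
        # seq_lens: shape [B], true length of each sequence.
        # Returns a [B, max_len] boolean mask, True at valid positions.
        max_len = int(seq_lens.max())
        return torch.arange(max_len).unsqueeze(0) < seq_lens.unsqueeze(1)

    print(seq_lens_to_masks_sketch(torch.tensor([3, 1])))
    # tensor([[ True,  True,  True],
    #         [ True, False, False]])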
@@ -277,9 +276,9 @@ class AccuracyMetric(MetricBase):
                     f"size:{pred.size()}, target should have size: {pred.size()} or "
                     f"{pred.size()[:-1]}, got {target.size()}.")

+        target = target.to(pred)
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(pred, target) * masks).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
             self.total += torch.sum(masks).item()
         else:
             self.acc_count += torch.sum(torch.eq(pred, target)).item()
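The updated branch zeroes out the (pred == target) comparison at padded positions before summing, and divides by the number of valid positions rather than by the full tensor size. A toy walk-through with invented values:

    import torch

    pred   = torch.tensor([[1, 2, 0], [1, 0, 0]])         # [B=2, max_len=3]
    target = torch.tensor([[1, 2, 2], [1, 9, 9]])         # the 9s sit in padding
    masks  = torch.tensor([[1, 1, 1], [1, 0, 0]]).bool()  # True at valid positions

    correct = torch.eq(pred, target).masked_fill(masks.eq(0), 0)
    acc_count = torch.sum(correct).item()  # 3 correct among valid positions
    total = torch.sum(masks).item()        # 4 valid positions
    print(acc_count / total)               # 0.75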
@@ -219,8 +219,8 @@ class TestDataSetMethods(unittest.TestCase):
     def test_add_null(self):
         # TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError'
         ds = DataSet()
-        ds.add_field('test', [])
-        ds.set_target('test')
+        with self.assertRaises(RuntimeError) as RE:
+            ds.add_field('test', [])

 class TestDataSetIter(unittest.TestCase):
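The rewritten test above relies on assertRaises as a context manager: the block must raise the named exception for the test to pass, and the captured context exposes the exception instance for further checks. A minimal standalone illustration (the message text is invented):

    import unittest

    class Demo(unittest.TestCase):
        def test_raises(self):
            with self.assertRaises(RuntimeError) as RE:
                raise RuntimeError("cannot add an empty field")
            # The context manager records the raised exception for inspection.
            self.assertIn("empty field", str(RE.exception))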
@@ -15,7 +15,7 @@ class TestAccuracyMetric(unittest.TestCase):
         target_dict = {'target': torch.zeros(4)}
         metric = AccuracyMetric()
-        metric(pred_dict=pred_dict, target_dict=target_dict, )
+        metric(pred_dict=pred_dict, target_dict=target_dict)
         print(metric.get_metric())

     def test_AccuracyMetric2(self):
@@ -30,7 +30,7 @@ class TestAccuracyMetric(unittest.TestCase):
         except Exception as e:
             print(e)
             return
-        self.assertTrue(True, False), "No exception catches."
+        print("No exception caught.")

     def test_AccuracyMetric3(self):
         # (3) the second batch is corrupted size
@@ -95,10 +95,9 @@ class TestAccuracyMetric(unittest.TestCase):
         self.assertAlmostEqual(res["acc"], float(ans), places=4)

     def test_AccuaryMetric8(self):
-        # (8) check map, does not match. use stop_fast_param to stop fast param map
         try:
             metric = AccuracyMetric(pred='predictions', target='targets')
-            pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
+            pred_dict = {"prediction": torch.zeros(4, 3, 2)}
             target_dict = {'targets': torch.zeros(4, 3)}
             metric(pred_dict=pred_dict, target_dict=target_dict, )
             self.assertDictEqual(metric.get_metric(), {'acc': 1})
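In test_AccuaryMetric8 the metric is constructed with an explicit key map, so it looks for 'predictions' and 'targets' in the dicts it is called with; the deliberately mismatched key 'prediction' is what the surrounding try/except expects to trip over. Under that reading, the matching-key call would look like this (a sketch, assuming the dev branch resolves an all-zero pred/target pair to perfect accuracy via argmax):

    import torch
    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric(pred='predictions', target='targets')
    pred_dict = {'predictions': torch.zeros(4, 3, 2)}  # key matches the mapping
    target_dict = {'targets': torch.zeros(4, 3)}
    metric(pred_dict=pred_dict, target_dict=target_dict)
    print(metric.get_metric())  # argmax of zeros equals the zero targets -> acc 1.0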
@@ -6,7 +6,7 @@ from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver
 class TestConfigSaver(unittest.TestCase):
     def test_case_1(self):
-        config_file_dir = "test/io/"
+        config_file_dir = "test/io"
         config_file_name = "config"
         config_file_path = os.path.join(config_file_dir, config_file_name)
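The trailing slash is dropped above because os.path.join inserts the separator itself; both spellings yield the same path here, so the change is purely cosmetic:

    import os

    print(os.path.join("test/io", "config"))   # test/io/config
    print(os.path.join("test/io/", "config"))  # test/io/config as well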
@@ -17,11 +17,3 @@ class TestDatasetLoader(unittest.TestCase):
     def test_PeopleDailyCorpusLoader(self):
         data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt")

-    def test_ConllCWSReader(self):
-        dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt")
-
-    def test_ZhConllPOSReader(self):
-        dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx")
-
-    def test_ConllxDataLoader(self):
-        dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx")
@@ -118,7 +118,7 @@ class TestCRF(unittest.TestCase):
         feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
         crf = ConditionalRandomField(num_tags, include_start_end_trans)
         optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
-        for _ in range(10000):
+        for _ in range(10):
             loss = crf(feats, tags, masks).mean()
             optimizer.zero_grad()
             loss.backward()
@@ -152,7 +152,7 @@ class TestTutorial(unittest.TestCase):
             train_data=train_data,
             dev_data=dev_data,
             loss=CrossEntropyLoss(),
-            metrics=AccuracyMetric()
+            metrics=AccuracyMetric(target='label_seq')
         )
         trainer.train()
         print('Train finished!')
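The tutorial's DataSet keeps its gold labels in a field named 'label_seq' rather than the default 'target', so the metric has to be pointed at that field explicitly; the same pattern recurs in the two hunks below with a field named 'label'. Schematically:

    from fastNLP.core.metrics import AccuracyMetric

    # Map the metric's 'target' argument onto the DataSet field holding gold labels.
    metric = AccuracyMetric(target='label_seq')
    # AccuracyMetric() with no arguments would look for a field literally named 'target'.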
@@ -407,7 +407,7 @@ class TestTutorial(unittest.TestCase):
             train_data=train_data,
             model=model,
             loss=CrossEntropyLoss(pred='pred', target='label'),
-            metrics=AccuracyMetric(),
+            metrics=AccuracyMetric(target='label'),
             n_epochs=3,
             batch_size=16,
             print_every=-1,
@@ -424,7 +424,7 @@ class TestTutorial(unittest.TestCase):
         tester = Tester(
             data=test_data,
             model=model,
-            metrics=AccuracyMetric(),
+            metrics=AccuracyMetric(target='label'),
             batch_size=args["batch_size"],
         )
         tester.test()