Browse Source

准备发布0.4.0版本“

tags/v0.4.10
yh_cc 6 years ago
parent
commit
29f81e79ad
7 changed files with 18 additions and 28 deletions
  1. +8
    -9
      fastNLP/core/metrics.py
  2. +2
    -2
      test/core/test_dataset.py
  3. +3
    -4
      test/core/test_metrics.py
  4. +1
    -1
      test/io/test_config_saver.py
  5. +0
    -8
      test/io/test_dataset_loader.py
  6. +1
    -1
      test/modules/decoder/test_CRF.py
  7. +3
    -3
      test/test_tutorials.py

+ 8
- 9
fastNLP/core/metrics.py View File

@@ -243,12 +243,11 @@ class AccuracyMetric(MetricBase):
def evaluate(self, pred, target, seq_lens=None):
"""

:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
:param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
:param pred: . Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]),
torch.Size([B, max_len, n_classes])
:param target: Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]),
torch.Size([B, max_len])
:param seq_lens: Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.

"""
# TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value
@@ -264,7 +263,7 @@ class AccuracyMetric(MetricBase):
f"got {type(seq_lens)}.")

if seq_lens is not None:
masks = seq_lens_to_masks(seq_lens=seq_lens).long()
masks = seq_lens_to_masks(seq_lens=seq_lens)
else:
masks = None

@@ -277,9 +276,9 @@ class AccuracyMetric(MetricBase):
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

target = target.to(pred)
if masks is not None:
self.acc_count += torch.sum(torch.eq(pred, target) * masks).item()
self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks, 0)).item()
self.total += torch.sum(masks).item()
else:
self.acc_count += torch.sum(torch.eq(pred, target)).item()


+ 2
- 2
test/core/test_dataset.py View File

@@ -219,8 +219,8 @@ class TestDataSetMethods(unittest.TestCase):
def test_add_null(self):
# TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError'
ds = DataSet()
ds.add_field('test', [])
ds.set_target('test')
with self.assertRaises(RuntimeError) as RE:
ds.add_field('test', [])


class TestDataSetIter(unittest.TestCase):


+ 3
- 4
test/core/test_metrics.py View File

@@ -15,7 +15,7 @@ class TestAccuracyMetric(unittest.TestCase):
target_dict = {'target': torch.zeros(4)}
metric = AccuracyMetric()

metric(pred_dict=pred_dict, target_dict=target_dict, )
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())

def test_AccuracyMetric2(self):
@@ -30,7 +30,7 @@ class TestAccuracyMetric(unittest.TestCase):
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."
print("No exception catches.")

def test_AccuracyMetric3(self):
# (3) the second batch is corrupted size
@@ -95,10 +95,9 @@ class TestAccuracyMetric(unittest.TestCase):
self.assertAlmostEqual(res["acc"], float(ans), places=4)

def test_AccuaryMetric8(self):
# (8) check map, does not match. use stop_fast_param to stop fast param map
try:
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
pred_dict = {"prediction": torch.zeros(4, 3, 2)}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict, )
self.assertDictEqual(metric.get_metric(), {'acc': 1})


+ 1
- 1
test/io/test_config_saver.py View File

@@ -6,7 +6,7 @@ from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver

class TestConfigSaver(unittest.TestCase):
def test_case_1(self):
config_file_dir = "test/io/"
config_file_dir = "test/io"
config_file_name = "config"
config_file_path = os.path.join(config_file_dir, config_file_name)



+ 0
- 8
test/io/test_dataset_loader.py View File

@@ -17,11 +17,3 @@ class TestDatasetLoader(unittest.TestCase):
def test_PeopleDailyCorpusLoader(self):
data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt")

def test_ConllCWSReader(self):
dataset = ConllCWSReader().load("test/data_for_tests/conll_example.txt")

def test_ZhConllPOSReader(self):
dataset = ZhConllPOSReader().load("test/data_for_tests/zh_sample.conllx")

def test_ConllxDataLoader(self):
dataset = ConllxDataLoader().load("test/data_for_tests/zh_sample.conllx")

+ 1
- 1
test/modules/decoder/test_CRF.py View File

@@ -118,7 +118,7 @@ class TestCRF(unittest.TestCase):
feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
crf = ConditionalRandomField(num_tags, include_start_end_trans)
optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
for _ in range(10000):
for _ in range(10):
loss = crf(feats, tags, masks).mean()
optimizer.zero_grad()
loss.backward()


+ 3
- 3
test/test_tutorials.py View File

@@ -152,7 +152,7 @@ class TestTutorial(unittest.TestCase):
train_data=train_data,
dev_data=dev_data,
loss=CrossEntropyLoss(),
metrics=AccuracyMetric()
metrics=AccuracyMetric(target='label_seq')
)
trainer.train()
print('Train finished!')
@@ -407,7 +407,7 @@ class TestTutorial(unittest.TestCase):
train_data=train_data,
model=model,
loss=CrossEntropyLoss(pred='pred', target='label'),
metrics=AccuracyMetric(),
metrics=AccuracyMetric(target='label'),
n_epochs=3,
batch_size=16,
print_every=-1,
@@ -424,7 +424,7 @@ class TestTutorial(unittest.TestCase):
tester = Tester(
data=test_data,
model=model,
metrics=AccuracyMetric(),
metrics=AccuracyMetric(target='label'),
batch_size=args["batch_size"],
)
tester.test()


Loading…
Cancel
Save