Browse Source

fix bug in SeqLabelTester

tags/v0.1.0
Yunfan Shao 7 years ago
parent
commit
762a559fab
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      fastNLP/core/tester.py

+ 3
- 2
fastNLP/core/tester.py View File

@@ -59,6 +59,8 @@ class BaseTester(object):
self.batch_output.append(prediction)
if self.save_loss:
self.eval_history.append(eval_results)
if step % n_print == 0:
print('[test step: {:>4}]'.format(step))
step += 1

def prepare_input(self, data_path):
@@ -134,7 +136,7 @@ class SeqLabelTester(BaseTester):
results = torch.Tensor(prediction).view(-1,)
# make sure "results" is in the same device as "truth"
results = results.to(truth)
accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0]
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
return [loss.data, accuracy.data]

def metrics(self):
@@ -153,7 +155,6 @@ class SeqLabelTester(BaseTester):
def make_batch(self, iterator, data):
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True)


class ClassificationTester(BaseTester):
"""Tester for classification."""



Loading…
Cancel
Save