|
|
|
@@ -59,6 +59,8 @@ class BaseTester(object): |
|
|
|
self.batch_output.append(prediction) |
|
|
|
if self.save_loss: |
|
|
|
self.eval_history.append(eval_results) |
|
|
|
if step % n_print == 0: |
|
|
|
print('[test step: {:>4}]'.format(step)) |
|
|
|
step += 1 |
|
|
|
|
|
|
|
def prepare_input(self, data_path): |
|
|
|
@@ -134,7 +136,7 @@ class SeqLabelTester(BaseTester): |
|
|
|
results = torch.Tensor(prediction).view(-1,) |
|
|
|
# make sure "results" is in the same device as "truth" |
|
|
|
results = results.to(truth) |
|
|
|
accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0] |
|
|
|
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] |
|
|
|
return [loss.data, accuracy.data] |
|
|
|
|
|
|
|
def metrics(self): |
|
|
|
@@ -153,7 +155,6 @@ class SeqLabelTester(BaseTester): |
|
|
|
def make_batch(self, iterator, data): |
|
|
|
return Action.make_batch(iterator, data, use_cuda=self.use_cuda, output_length=True) |
|
|
|
|
|
|
|
|
|
|
|
class ClassificationTester(BaseTester): |
|
|
|
"""Tester for classification.""" |
|
|
|
|
|
|
|
|