@@ -1,5 +1,6 @@
import _pickle

import numpy as np
import torch

from fastNLP.action.action import Action
@@ -16,8 +17,7 @@ class BaseTester(Action):
        """
        super(BaseTester, self).__init__()
        self.validate_in_training = test_args["validate_in_training"]
        self.valid_x = None
        self.valid_y = None
        self.save_dev_data = None
        self.save_output = test_args["save_output"]
        self.output = None
        self.save_loss = test_args["save_loss"]
@@ -26,8 +26,14 @@ class BaseTester(Action):
        self.pickle_path = test_args["pickle_path"]
        self.iterator = None

        self.model = None
        self.eval_history = []
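
    # Evaluate `network` on the pickled dev set: batches come from a random sampler
    # (the last incomplete batch is dropped), and predictions / per-batch losses are
    # accumulated only when save_output / save_loss are set.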
    def test(self, network):
        # print("--------------testing----------------")
        self.model = network

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)

        dev_data = self.prepare_input(self.pickle_path)
@@ -35,7 +41,6 @@
        self.iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))

        batch_output = list()
        eval_history = list()
        num_iter = len(dev_data) // self.batch_size

        for step in range(num_iter):
@@ -47,11 +52,18 @@
            if self.save_output:
                batch_output.append(prediction)
            if self.save_loss:
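                # the batch loss is recorded so matrices() can report its mean afterwards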
                eval_history.append(eval_results)
                self.eval_history.append(eval_results)

    def prepare_input(self, data_path):
        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
        return data_dev
        """
        Save the dev data once it is loaded. Can return directly next time.
        :param data_path: str, the path to the pickle data for dev
        :return save_dev_data: list. Each entry is a sample, which is also a list of features and label(s).
        """
        if self.save_dev_data is None:
            data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
            self.save_dev_data = data_dev
        return self.save_dev_data

    def batchify(self, data):
        """
@@ -99,11 +111,12 @@
        raise NotImplementedError

    def mode(self, model, test=True):
        """To do: combine this function with Trainer"""
        """To do: combine this function with Trainer ?? """
        if test:
            model.eval()
        else:
            model.train()
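        # any history accumulated during a previous pass is discarded on every mode switch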
        self.eval_history.clear()


class POSTester(BaseTester):
@@ -115,6 +128,7 @@ class POSTester(BaseTester):
        super(POSTester, self).__init__(test_args)
        self.max_len = None
        self.mask = None
        self.batch_result = None

    def data_forward(self, network, x):
        """To Do: combine with Trainer
@@ -132,5 +146,9 @@ class POSTester(BaseTester):
        return y

    def evaluate(self, predict, truth):
        """To Do: """
        return 0
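        # compute the model-defined loss over the padded batch, using the stored mask and lengths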
        truth = torch.Tensor(truth)
        loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len)
        return loss.data

    def matrices(self):
        return np.mean(self.eval_history)
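
# --- Illustrative usage sketch (not part of the patch) ----------------------
# A minimal example of driving the tester defined above. It assumes that the
# elided parts of BaseTester.__init__ also read "batch_size" from test_args,
# that the import path below is where these classes live, and that
# `trained_model` is a POS-tagging model whose loss() signature matches the
# call in POSTester.evaluate; the pickle_path value is hypothetical.
from fastNLP.action.tester import POSTester  # assumed module path

test_args = {
    "validate_in_training": True,
    "save_output": True,
    "save_loss": True,
    "batch_size": 8,                      # assumed key; self.batch_size is used above
    "pickle_path": "./data_for_tests",    # hypothetical directory holding the dev pickle
}
tester = POSTester(test_args)
tester.test(trained_model)                # one full pass over the dev set
print("mean dev loss:", tester.matrices())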