- change parameter <seq_length-->mask> in loss function defined in seq model - Trainer & Tester have Action as default parameter, shared static methods like make_batch - add seq_len in make_batch of Inference - add SeqLabelInfer, a subclass of Inference - seq_labeling.py workstags/v0.1.0
@@ -4,20 +4,16 @@ | |||||
""" | """ | ||||
from collections import Counter | from collections import Counter | ||||
import torch | |||||
import numpy as np | import numpy as np | ||||
import _pickle | |||||
class Action(object): | class Action(object): | ||||
""" | """ | ||||
Operations shared by Trainer, Tester, and Inference. | Operations shared by Trainer, Tester, and Inference. | ||||
This is designed for reducing replicate codes. | This is designed for reducing replicate codes. | ||||
- prepare_input: data preparation before a forward pass. | |||||
- make_batch: produce a min-batch of data. @staticmethod | - make_batch: produce a min-batch of data. @staticmethod | ||||
- pad: padding method used in sequence modeling. @staticmethod | - pad: padding method used in sequence modeling. @staticmethod | ||||
- mode: change network mode for either train or test. (for PyTorch) @staticmethod | - mode: change network mode for either train or test. (for PyTorch) @staticmethod | ||||
- data_forward: a forward pass of the network. | |||||
The base Action shall define operations shared by as much task-specific Actions as possible. | The base Action shall define operations shared by as much task-specific Actions as possible. | ||||
""" | """ | ||||
@@ -83,47 +79,6 @@ class Action(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def data_forward(self, network, x): | |||||
""" | |||||
Forward pass of the data. | |||||
:param network: a model | |||||
:param x: input feature matrix and label vector | |||||
:return: output by the models | |||||
For PyTorch, just do "network(*x)" | |||||
""" | |||||
raise NotImplementedError | |||||
class SeqLabelAction(Action): | |||||
def __init__(self, action_args): | |||||
""" | |||||
Define task-specific member variables. | |||||
:param action_args: | |||||
""" | |||||
super(SeqLabelAction, self).__init__() | |||||
self.max_len = None | |||||
self.mask = None | |||||
self.best_accuracy = 0.0 | |||||
self.use_cuda = action_args["use_cuda"] | |||||
self.seq_len = None | |||||
self.batch_size = None | |||||
def data_forward(self, network, inputs): | |||||
# unpack the returned value from make_batch | |||||
if isinstance(inputs, tuple): | |||||
x = inputs[0] | |||||
self.seq_len = inputs[1] | |||||
else: | |||||
x = inputs | |||||
x = torch.Tensor(x).long() | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
x = x.cuda() | |||||
self.batch_size = x.size(0) | |||||
self.max_len = x.size(1) | |||||
y = network(x) | |||||
return y | |||||
def k_means_1d(x, k, max_iter=100): | def k_means_1d(x, k, max_iter=100): | ||||
""" | """ | ||||
@@ -1,7 +1,9 @@ | |||||
import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.action import Batchifier, SequentialSampler | from fastNLP.core.action import Batchifier, SequentialSampler | ||||
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | ||||
from fastNLP.modules import utils | |||||
class Inference(object): | class Inference(object): | ||||
@@ -32,13 +34,14 @@ class Inference(object): | |||||
# turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
self.mode(network, test=True) | self.mode(network, test=True) | ||||
self.batch_output.clear() | |||||
self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | |||||
iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | |||||
num_iter = len(data) // self.batch_size | num_iter = len(data) // self.batch_size | ||||
for step in range(num_iter): | for step in range(num_iter): | ||||
batch_x = self.make_batch(data) | |||||
batch_x = self.make_batch(iterator, data) | |||||
prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
@@ -54,26 +57,18 @@ class Inference(object): | |||||
self.batch_output.clear() | self.batch_output.clear() | ||||
def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
""" | |||||
This is only for sequence labeling with CRF decoder. TODO: more general ? | |||||
:param network: | |||||
:param x: | |||||
:return: | |||||
""" | |||||
seq_len = [len(seq) for seq in x] | |||||
x = torch.Tensor(x).long() | |||||
y = network(x) | |||||
prediction = network.prediction(y, seq_len) | |||||
# To do: hide framework | |||||
results = torch.Tensor(prediction).view(-1, ) | |||||
return list(results.data) | |||||
raise NotImplementedError | |||||
def make_batch(self, data): | |||||
indices = next(self.iterator) | |||||
@staticmethod | |||||
def make_batch(iterator, data, output_length=True): | |||||
indices = next(iterator) | |||||
batch_x = [data[idx] for idx in indices] | batch_x = [data[idx] for idx in indices] | ||||
if self.batch_size > 1: | |||||
batch_x = self.pad(batch_x) | |||||
return batch_x | |||||
batch_x_pad = Inference.pad(batch_x) | |||||
if output_length: | |||||
seq_len = [len(x) for x in batch_x] | |||||
return [batch_x_pad, seq_len] | |||||
else: | |||||
return batch_x_pad | |||||
@staticmethod | @staticmethod | ||||
def pad(batch, fill=0): | def pad(batch, fill=0): | ||||
@@ -86,7 +81,7 @@ class Inference(object): | |||||
max_length = max([len(x) for x in batch]) | max_length = max([len(x) for x in batch]) | ||||
for idx, sample in enumerate(batch): | for idx, sample in enumerate(batch): | ||||
if len(sample) < max_length: | if len(sample) < max_length: | ||||
batch[idx] = sample + [fill * (max_length - len(sample))] | |||||
batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||||
return batch | return batch | ||||
def prepare_input(self, data): | def prepare_input(self, data): | ||||
@@ -109,10 +104,39 @@ class Inference(object): | |||||
def prepare_output(self, batch_outputs): | def prepare_output(self, batch_outputs): | ||||
""" | """ | ||||
Transform list of batch outputs into strings. | Transform list of batch outputs into strings. | ||||
:param batch_outputs: list of list, of shape [num_batch, tag_seq_length]. Element type is Tensor. | |||||
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length]. | |||||
:return: | :return: | ||||
""" | """ | ||||
results = [] | results = [] | ||||
for batch in batch_outputs: | for batch in batch_outputs: | ||||
results.append([self.index2label[int(x.data)] for x in batch]) | |||||
for example in np.array(batch): | |||||
results.append([self.index2label[int(x)] for x in example]) | |||||
return results | return results | ||||
class SeqLabelInfer(Inference): | |||||
""" | |||||
Inference on sequence labeling models. | |||||
""" | |||||
def __init__(self, pickle_path): | |||||
super(SeqLabelInfer, self).__init__(pickle_path) | |||||
def data_forward(self, network, inputs): | |||||
""" | |||||
This is only for sequence labeling with CRF decoder. | |||||
:param network: | |||||
:param inputs: | |||||
:return: Tensor | |||||
""" | |||||
if not isinstance(inputs[1], list) and isinstance(inputs[0], list): | |||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | |||||
# unpack the returned value from make_batch | |||||
x, seq_len = inputs[0], inputs[1] | |||||
x = torch.Tensor(x).long() | |||||
batch_size, max_len = x.size(0), x.size(1) | |||||
mask = utils.seq_mask(seq_len, max_len) | |||||
mask = mask.byte().view(batch_size, max_len) | |||||
y = network(x) | |||||
prediction = network.prediction(y, mask) | |||||
return torch.Tensor(prediction) |
@@ -6,17 +6,18 @@ import torch | |||||
from fastNLP.core.action import Action | from fastNLP.core.action import Action | ||||
from fastNLP.core.action import RandomSampler, Batchifier | from fastNLP.core.action import RandomSampler, Batchifier | ||||
from fastNLP.modules import utils | |||||
class BaseTester(Action): | class BaseTester(Action): | ||||
"""docstring for Tester""" | """docstring for Tester""" | ||||
def __init__(self, test_args, action): | |||||
def __init__(self, test_args, action=None): | |||||
""" | """ | ||||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | ||||
""" | """ | ||||
super(BaseTester, self).__init__() | super(BaseTester, self).__init__() | ||||
self.action = action | |||||
self.action = action if action is not None else Action() | |||||
self.validate_in_training = test_args["validate_in_training"] | self.validate_in_training = test_args["validate_in_training"] | ||||
self.save_dev_data = None | self.save_dev_data = None | ||||
self.save_output = test_args["save_output"] | self.save_output = test_args["save_output"] | ||||
@@ -52,7 +53,7 @@ class BaseTester(Action): | |||||
for step in range(num_iter): | for step in range(num_iter): | ||||
batch_x, batch_y = self.action.make_batch(iterator, dev_data) | batch_x, batch_y = self.action.make_batch(iterator, dev_data) | ||||
prediction = self.action.data_forward(network, batch_x) | |||||
prediction = self.data_forward(network, batch_x) | |||||
eval_results = self.evaluate(prediction, batch_y) | eval_results = self.evaluate(prediction, batch_y) | ||||
@@ -72,6 +73,9 @@ class BaseTester(Action): | |||||
self.save_dev_data = data_dev | self.save_dev_data = data_dev | ||||
return self.save_dev_data | return self.save_dev_data | ||||
def data_forward(self, network, x): | |||||
raise NotImplementedError | |||||
def evaluate(self, predict, truth): | def evaluate(self, predict, truth): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -92,7 +96,7 @@ class POSTester(BaseTester): | |||||
Tester for sequence labeling. | Tester for sequence labeling. | ||||
""" | """ | ||||
def __init__(self, test_args, action): | |||||
def __init__(self, test_args, action=None): | |||||
""" | """ | ||||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | ||||
""" | """ | ||||
@@ -101,17 +105,37 @@ class POSTester(BaseTester): | |||||
self.mask = None | self.mask = None | ||||
self.batch_result = None | self.batch_result = None | ||||
def data_forward(self, network, inputs): | |||||
if not isinstance(inputs, tuple): | |||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | |||||
# unpack the returned value from make_batch | |||||
x, seq_len = inputs[0], inputs[1] | |||||
x = torch.Tensor(x).long() | |||||
batch_size, max_len = x.size(0), x.size(1) | |||||
mask = utils.seq_mask(seq_len, max_len) | |||||
mask = mask.byte().view(batch_size, max_len) | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
x = x.cuda() | |||||
mask = mask.cuda() | |||||
self.mask = mask | |||||
y = network(x) | |||||
return y | |||||
def evaluate(self, predict, truth): | def evaluate(self, predict, truth): | ||||
truth = torch.Tensor(truth) | truth = torch.Tensor(truth) | ||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
truth = truth.cuda() | truth = truth.cuda() | ||||
loss = self.model.loss(predict, truth, self.action.seq_len) / self.batch_size | |||||
prediction = self.model.prediction(predict, self.action.seq_len) | |||||
batch_size, max_len = predict.size(0), predict.size(1) | |||||
loss = self.model.loss(predict, truth, self.mask) / batch_size | |||||
prediction = self.model.prediction(predict, self.mask) | |||||
results = torch.Tensor(prediction).view(-1,) | results = torch.Tensor(prediction).view(-1,) | ||||
if torch.cuda.is_available() and self.use_cuda: | |||||
results = results.cuda() | |||||
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0] | |||||
return [loss.data, accuracy] | |||||
# make sure "results" is in the same device as "truth" | |||||
results = results.to(truth) | |||||
accuracy = torch.sum(results == truth.view((-1,))) / results.shape[0] | |||||
return [loss.data, accuracy.data] | |||||
def metrics(self): | def metrics(self): | ||||
batch_loss = np.mean([x[0] for x in self.eval_history]) | batch_loss = np.mean([x[0] for x in self.eval_history]) | ||||
@@ -8,8 +8,9 @@ import torch | |||||
import torch.nn as nn | import torch.nn as nn | ||||
from fastNLP.core.action import Action | from fastNLP.core.action import Action | ||||
from fastNLP.core.action import RandomSampler, Batchifier, BucketSampler | |||||
from fastNLP.core.action import RandomSampler, Batchifier | |||||
from fastNLP.core.tester import POSTester | from fastNLP.core.tester import POSTester | ||||
from fastNLP.modules import utils | |||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
@@ -23,10 +24,10 @@ class BaseTrainer(Action): | |||||
- get_loss | - get_loss | ||||
""" | """ | ||||
def __init__(self, train_args, action): | |||||
def __init__(self, train_args, action=None): | |||||
""" | """ | ||||
:param train_args: dict of (key, value), or dict-like object. key is str. | :param train_args: dict of (key, value), or dict-like object. key is str. | ||||
:param action: an Action object that wrap most operations shared by Trainer, Tester, and Inference. | |||||
:param action: (optional) an Action object that wrap most operations shared by Trainer, Tester, and Inference. | |||||
The base trainer requires the following keys: | The base trainer requires the following keys: | ||||
- epochs: int, the number of epochs in training | - epochs: int, the number of epochs in training | ||||
@@ -35,7 +36,7 @@ class BaseTrainer(Action): | |||||
- pickle_path: str, the path to pickle files for pre-processing | - pickle_path: str, the path to pickle files for pre-processing | ||||
""" | """ | ||||
super(BaseTrainer, self).__init__() | super(BaseTrainer, self).__init__() | ||||
self.action = action | |||||
self.action = action if action is not None else Action() | |||||
self.n_epochs = train_args["epochs"] | self.n_epochs = train_args["epochs"] | ||||
self.batch_size = train_args["batch_size"] | self.batch_size = train_args["batch_size"] | ||||
self.pickle_path = train_args["pickle_path"] | self.pickle_path = train_args["pickle_path"] | ||||
@@ -94,7 +95,7 @@ class BaseTrainer(Action): | |||||
for step in range(iterations): | for step in range(iterations): | ||||
batch_x, batch_y = self.action.make_batch(iterator, data_train) | batch_x, batch_y = self.action.make_batch(iterator, data_train) | ||||
prediction = self.action.data_forward(network, batch_x) | |||||
prediction = self.data_forward(network, batch_x) | |||||
loss = self.get_loss(prediction, batch_y) | loss = self.get_loss(prediction, batch_y) | ||||
self.grad_backward(loss) | self.grad_backward(loss) | ||||
@@ -137,6 +138,9 @@ class BaseTrainer(Action): | |||||
""" | """ | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def data_forward(self, network, x): | |||||
raise NotImplementedError | |||||
def grad_backward(self, loss): | def grad_backward(self, loss): | ||||
""" | """ | ||||
Compute gradient with link rules. | Compute gradient with link rules. | ||||
@@ -223,7 +227,8 @@ class POSTrainer(BaseTrainer): | |||||
Trainer for Sequence Modeling | Trainer for Sequence Modeling | ||||
""" | """ | ||||
def __init__(self, train_args, action): | |||||
def __init__(self, train_args, action=None): | |||||
super(POSTrainer, self).__init__(train_args, action) | super(POSTrainer, self).__init__(train_args, action) | ||||
self.vocab_size = train_args["vocab_size"] | self.vocab_size = train_args["vocab_size"] | ||||
self.num_classes = train_args["num_classes"] | self.num_classes = train_args["num_classes"] | ||||
@@ -241,6 +246,24 @@ class POSTrainer(BaseTrainer): | |||||
def update(self): | def update(self): | ||||
self.optimizer.step() | self.optimizer.step() | ||||
def data_forward(self, network, inputs): | |||||
if not isinstance(inputs, tuple): | |||||
raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") | |||||
# unpack the returned value from make_batch | |||||
x, seq_len = inputs[0], inputs[1] | |||||
batch_size, max_len = x.size(0), x.size(1) | |||||
mask = utils.seq_mask(seq_len, max_len) | |||||
mask = mask.byte().view(batch_size, max_len) | |||||
x = torch.Tensor(x).long() | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
x = x.cuda() | |||||
mask = mask.cuda() | |||||
self.mask = mask | |||||
y = network(x) | |||||
return y | |||||
def get_loss(self, predict, truth): | def get_loss(self, predict, truth): | ||||
""" | """ | ||||
Compute loss given prediction and ground truth. | Compute loss given prediction and ground truth. | ||||
@@ -251,13 +274,10 @@ class POSTrainer(BaseTrainer): | |||||
truth = torch.Tensor(truth) | truth = torch.Tensor(truth) | ||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
truth = truth.cuda() | truth = truth.cuda() | ||||
assert truth.shape == (self.batch_size, self.action.max_len) | |||||
if self.loss_func is None: | |||||
if hasattr(self.model, "loss"): | |||||
self.loss_func = self.model.loss | |||||
else: | |||||
self.define_loss() | |||||
loss = self.loss_func(predict, truth, self.action.seq_len) | |||||
batch_size, max_len = predict.size(0), predict.size(1) | |||||
assert truth.shape == (batch_size, max_len) | |||||
loss = self.model.loss(predict, truth, self.mask) | |||||
return loss | return loss | ||||
def best_eval_result(self, validator): | def best_eval_result(self, validator): | ||||
@@ -1,7 +1,7 @@ | |||||
import torch | import torch | ||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules import decoder, encoder, utils | |||||
from fastNLP.modules import decoder, encoder | |||||
class SeqLabeling(BaseModel): | class SeqLabeling(BaseModel): | ||||
@@ -34,46 +34,25 @@ class SeqLabeling(BaseModel): | |||||
# [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
return x | return x | ||||
def loss(self, x, y, seq_length): | |||||
def loss(self, x, y, mask): | |||||
""" | """ | ||||
Negative log likelihood loss. | Negative log likelihood loss. | ||||
:param x: FloatTensor, [batch_size, max_len, tag_size] | |||||
:param y: LongTensor, [batch_size, max_len] | |||||
:param seq_length: list of int. [batch_size] | |||||
:param x: Tensor, [batch_size, max_len, tag_size] | |||||
:param y: Tensor, [batch_size, max_len] | |||||
:param mask: ByteTensor, [batch_size, ,max_len] | |||||
:return loss: a scalar Tensor | :return loss: a scalar Tensor | ||||
""" | """ | ||||
x = x.float() | x = x.float() | ||||
y = y.long() | y = y.long() | ||||
batch_size = x.size(0) | |||||
max_len = x.size(1) | |||||
mask = utils.seq_mask(seq_length, max_len) | |||||
mask = mask.byte().view(batch_size, max_len) | |||||
# TODO: remove | |||||
if torch.cuda.is_available(): | |||||
mask = mask.cuda() | |||||
# mask = x.new(batch_size, max_len) | |||||
total_loss = self.Crf(x, y, mask) | total_loss = self.Crf(x, y, mask) | ||||
return torch.mean(total_loss) | return torch.mean(total_loss) | ||||
def prediction(self, x, seq_length): | |||||
def prediction(self, x, mask): | |||||
""" | """ | ||||
:param x: FloatTensor, [batch_size, max_len, tag_size] | :param x: FloatTensor, [batch_size, max_len, tag_size] | ||||
:param seq_length: int | |||||
:return prediction: list of tuple of (decode path(list), best score) | |||||
:param mask: ByteTensor, [batch_size, max_len] | |||||
:return prediction: list of [decode path(list)] | |||||
""" | """ | ||||
x = x.float() | |||||
max_len = x.size(1) | |||||
mask = utils.seq_mask(seq_length, max_len) | |||||
# hack: make sure mask has the same device as x | |||||
mask = mask.to(x).byte() | |||||
tag_seq = self.Crf.viterbi_decode(x, mask) | tag_seq = self.Crf.viterbi_decode(x, mask) | ||||
return tag_seq | return tag_seq |
@@ -132,6 +132,7 @@ class ConditionalRandomField(nn.Module): | |||||
Given a feats matrix, return best decode path and best score. | Given a feats matrix, return best decode path and best score. | ||||
:param feats: | :param feats: | ||||
:param masks: | :param masks: | ||||
:param get_score: bool, whether to output the decode score. | |||||
:return:List[Tuple(List, float)], | :return:List[Tuple(List, float)], | ||||
""" | """ | ||||
batch_size, max_len, tag_size = feats.size() | batch_size, max_len, tag_size = feats.size() | ||||
@@ -2,7 +2,6 @@ import sys | |||||
sys.path.append("..") | sys.path.append("..") | ||||
from fastNLP.core.action import SeqLabelAction | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.core.trainer import POSTrainer | from fastNLP.core.trainer import POSTrainer | ||||
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | ||||
@@ -11,7 +10,7 @@ from fastNLP.saver.model_saver import ModelSaver | |||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.tester import POSTester | from fastNLP.core.tester import POSTester | ||||
from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
from fastNLP.core.inference import Inference | |||||
from fastNLP.core.inference import SeqLabelInfer | |||||
data_name = "people.txt" | data_name = "people.txt" | ||||
data_path = "data_for_tests/people.txt" | data_path = "data_for_tests/people.txt" | ||||
@@ -51,10 +50,11 @@ def infer(): | |||||
""" | """ | ||||
# Inference interface | # Inference interface | ||||
infer = Inference(pickle_path) | |||||
infer = SeqLabelInfer(pickle_path) | |||||
results = infer.predict(model, infer_data) | results = infer.predict(model, infer_data) | ||||
print(results) | |||||
for res in results: | |||||
print(res) | |||||
print("Inference finished!") | print("Inference finished!") | ||||
@@ -72,10 +72,8 @@ def train_and_test(): | |||||
train_args["vocab_size"] = p.vocab_size | train_args["vocab_size"] = p.vocab_size | ||||
train_args["num_classes"] = p.num_classes | train_args["num_classes"] = p.num_classes | ||||
action = SeqLabelAction(train_args) | |||||
# Trainer | # Trainer | ||||
trainer = POSTrainer(train_args, action) | |||||
trainer = POSTrainer(train_args) | |||||
# Model | # Model | ||||
model = SeqLabeling(train_args) | model = SeqLabeling(train_args) | ||||
@@ -103,7 +101,7 @@ def train_and_test(): | |||||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | ||||
# Tester | # Tester | ||||
tester = POSTester(test_args, action) | |||||
tester = POSTester(test_args) | |||||
# Start testing | # Start testing | ||||
tester.test(model) | tester.test(model) | ||||
@@ -114,5 +112,5 @@ def train_and_test(): | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train_and_test() | |||||
# train_and_test() | |||||
infer() |