@@ -0,0 +1,29 @@
+class Inference(object):
+    """
+    An interface for predicting outputs based on trained models.
+    It is not concerned with evaluating the model.
+
+    Possible improvements:
+        - use batches to make better use of the GPU
+    """
+
+    def __init__(self):
+        pass
+
+    def predict(self, model, data):
+        """
+        This is essentially a forward pass, and should be shared by Trainer/Tester.
+        :param model: a trained model
+        :param data: the input data
+        :return result: the output results
+        """
+        raise NotImplementedError
+
+    def prepare_input(self, data_path):
+        """
+        This can also be shared.
+        :param data_path: path to the input data
+        :return:
+        """
+        raise NotImplementedError
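For orientation, a concrete subclass might fill in these hooks as below. This is a sketch, not part of the commit: `LinearInference`, the dummy model in the usage comment, and the `data_infer.pkl` name are all invented.

    import _pickle
    import torch
    import torch.nn as nn

    class LinearInference(Inference):
        def predict(self, model, data):
            model.eval()                 # disable dropout etc. for prediction
            with torch.no_grad():        # no autograd bookkeeping needed
                return [model(x) for x in data]

        def prepare_input(self, data_path):
            # hypothetical: unpickle tensors produced by a preprocessor
            with open(data_path + "/data_infer.pkl", "rb") as f:
                return _pickle.load(f)

    # usage:
    # infer = LinearInference()
    # results = infer.predict(nn.Linear(5, 2), [torch.rand(1, 5)])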
@@ -11,6 +11,7 @@ from fastNLP.action.action import Action
 from fastNLP.action.action import RandomSampler, Batchifier
 from fastNLP.action.tester import POSTester
 from fastNLP.modules.utils import seq_mask
+from fastNLP.saver.model_saver import ModelSaver


 class BaseTrainer(Action):
@@ -38,9 +39,13 @@ class BaseTrainer(Action):
         """
         super(BaseTrainer, self).__init__()
         self.n_epochs = train_args["epochs"]
-        self.validate = train_args["validate"]
         self.batch_size = train_args["batch_size"]
         self.pickle_path = train_args["pickle_path"]
+        self.validate = train_args["validate"]
+        self.save_best_dev = train_args["save_best_dev"]
+        self.model_saved_path = train_args["model_saved_path"]
+
         self.model = None
         self.iterator = None
         self.loss_func = None
@@ -72,7 +77,7 @@ class BaseTrainer(Action):
         # main training epochs
         iterations = len(data_train) // self.batch_size

-        for epoch in range(self.n_epochs):
+        for epoch in range(1, self.n_epochs + 1):

             # turn on network training mode; define optimizer; prepare batch iterator
             self.mode(test=False)
@@ -93,6 +98,11 @@ class BaseTrainer(Action):
                 if data_dev is None:
                     raise RuntimeError("No validation data provided.")
                 validator.test(network)
+
+                if self.save_best_dev and self.best_eval_result(validator):
+                    self.save_model(network)
+                    print("saved better model selected by dev")
+
                 print("[epoch {}]".format(epoch), end=" ")
                 print(validator.show_matrices())
@@ -205,124 +215,49 @@ class BaseTrainer(Action):
            batch[idx] = sample + [fill] * (max_length - len(sample))
         return batch

+    def best_eval_result(self, validator):
+        """
+        :param validator: a Tester instance
+        :return: bool, True means current results on dev set is the best.
+        """
+        raise NotImplementedError
+
+    def save_model(self, network):
+        """
+        :param network: the PyTorch model
+        model_best_dev.pkl may be overwritten by a better model in future epochs.
+        """
+        ModelSaver(self.model_saved_path + "model_best_dev.pkl").save_pytorch(network)
+

 class ToyTrainer(BaseTrainer):
     """
-        deprecated
+        An example to show the definition of Trainer.
     """

-    def __init__(self, train_args):
-        super(ToyTrainer, self).__init__(train_args)
-        self.test_mode = False
-        self.weight = np.random.rand(5, 1)
-        self.bias = np.random.rand()
-        self._loss = 0
-        self._optimizer = None
+    def __init__(self, training_args):
+        super(ToyTrainer, self).__init__(training_args)

-    def prepare_input(self, data):
-        return data[:, :-1], data[:, -1]
+    def prepare_input(self, data_path):
+        data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
+        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
+        return data_train, data_dev, 0, 1

     def mode(self, test=False):
         self.model.mode(test)

     def data_forward(self, network, x):
-        return np.matmul(x, self.weight) + self.bias
+        return network(x)

     def grad_backward(self, loss):
+        self.model.zero_grad()
         loss.backward()

     def get_loss(self, pred, truth):
-        self._loss = np.mean(np.square(pred - truth))
-        return self._loss
-
-    def define_optimizer(self):
-        self._optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
-
-    def update(self):
-        self._optimizer.step()
-
-
-class WordSegTrainer(BaseTrainer):
-    """
-        deprecated
-    """
-
-    def __init__(self, train_args):
-        super(WordSegTrainer, self).__init__(train_args)
-        self.id2word = None
-        self.word2id = None
-        self.id2tag = None
-        self.tag2id = None
-
-        self.lstm_batch_size = 8
-        self.lstm_seq_len = 32  # Trainer batch_size == lstm_batch_size * lstm_seq_len
-        self.hidden_dim = 100
-        self.lstm_num_layers = 2
-        self.vocab_size = 100
-        self.word_emb_dim = 100
-
-        self.hidden = (self.to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)),
-                       self.to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)))
-
-        self.optimizer = None
-        self._loss = None
-
-        self.USE_GPU = False
-
-    def to_var(self, x):
-        if torch.cuda.is_available() and self.USE_GPU:
-            x = x.cuda()
-        return torch.autograd.Variable(x)
-
-    def prepare_input(self, data):
-        """
-        perform word indices lookup to convert strings into indices
-        :param data: list of string, each string contains word + space + [B, M, E, S]
-        :return
-        """
-        word_list = []
-        tag_list = []
-        for line in data:
-            if len(line) > 2:
-                tokens = line.split("#")
-                word_list.append(tokens[0])
-                tag_list.append(tokens[2][0])
-        self.id2word = list(set(word_list))
-        self.word2id = {word: idx for idx, word in enumerate(self.id2word)}
-        self.id2tag = list(set(tag_list))
-        self.tag2id = {tag: idx for idx, tag in enumerate(self.id2tag)}
-        words = np.array([self.word2id[w] for w in word_list]).reshape(-1, 1)
-        tags = np.array([self.tag2id[t] for t in tag_list]).reshape(-1, 1)
-        return words, tags
-
-    def mode(self, test=False):
-        if test:
-            self.model.eval()
-        else:
-            self.model.train()
-
-    def data_forward(self, network, x):
-        """
-        :param network: a PyTorch model
-        :param x: sequence of length [batch_size], word indices
-        :return:
-        """
-        x = x.reshape(self.lstm_batch_size, self.lstm_seq_len)
-        output, self.hidden = network(x, self.hidden)
-        return output
+        return np.mean(np.square(pred - truth))

     def define_optimizer(self):
-        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85)
-
-    def get_loss(self, predict, truth):
-        truth = torch.Tensor(truth)
-        self._loss = torch.nn.CrossEntropyLoss(predict, truth)
-        return self._loss
-
-    def grad_backward(self, network):
-        self.model.zero_grad()
-        self._loss.backward()
-        torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2)
+        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

     def update(self):
         self.optimizer.step()
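Taken together, the rewritten ToyTrainer illustrates BaseTrainer's template-method contract (note it loads data_train.pkl for both the train and dev splits, apparently as a toy placeholder). A simplified sketch of how the training loop drives these hooks, reconstructed from the hunks above rather than copied from the source:

    def train_sketch(trainer, network, batches, n_epochs):
        trainer.model = network            # BaseTrainer keeps the model on self
        trainer.define_optimizer()
        for epoch in range(1, n_epochs + 1):
            trainer.mode(test=False)       # switch the network to training mode
            for batch_x, batch_y in batches:
                pred = trainer.data_forward(network, batch_x)
                loss = trainer.get_loss(pred, batch_y)
                trainer.grad_backward(loss)   # zero grads, then backprop
                trainer.update()              # optimizer step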
@@ -339,6 +274,7 @@ class POSTrainer(BaseTrainer):
         self.num_classes = train_args["num_classes"]
         self.max_len = None
         self.mask = None
+        self.best_accuracy = 0.0

     def prepare_input(self, data_path):
         """
@@ -395,6 +331,26 @@ class POSTrainer(BaseTrainer):
         # print("loss={:.2f}".format(loss.data))
         return loss

+    def best_eval_result(self, validator):
+        loss, accuracy = validator.matrices()
+        if accuracy > self.best_accuracy:
+            self.best_accuracy = accuracy
+            return True
+        else:
+            return False
+
+
+class LanguageModelTrainer(BaseTrainer):
+    """
+    Trainer for Language Model
+    """
+
+    def __init__(self, train_args):
+        super(LanguageModelTrainer, self).__init__(train_args)
+
+    def prepare_input(self, data_path):
+        pass
+

 class ClassTrainer(BaseTrainer):
     """Trainer for classification."""
@@ -70,7 +70,7 @@ class ConfigSection(object):
         """
         if key in self.__dict__.keys():
             return getattr(self, key)
-        raise AttributeError('don\'t have attr %s' % (key))
+        raise AttributeError("do NOT have attribute %s" % key)

     def __setitem__(self, key, value):
         """
@@ -100,3 +100,15 @@ class ConllLoader(DatasetLoader):
                 continue
             tokens.append(line.split())
         return sentences
+
+
+class LMDatasetLoader(DatasetLoader):
+    def __init__(self, data_name, data_path):
+        super(LMDatasetLoader, self).__init__(data_name, data_path)
+
+    def load(self):
+        if not os.path.exists(self.data_path):
+            raise FileNotFoundError("file {} not found.".format(self.data_path))
+        with open(self.data_path, "r", encoding="utf-8") as f:
+            text = " ".join(f.readlines())
+        return text.strip().split()
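A note on LMDatasetLoader.load: `" ".join(f.readlines())` leaves the newline characters in the string, but `str.split()` with no argument splits on any whitespace, so the result is a flat token list. A usage sketch (the corpus path is invented):

    loader = LMDatasetLoader("lm", "data_for_tests/lm_corpus.txt")
    tokens = loader.load()
    # a file containing "the quick\nbrown fox\n" yields
    # tokens == ["the", "quick", "brown", "fox"]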
@@ -20,30 +20,6 @@ class BasePreprocess(object):
         if not self.pickle_path.endswith('/'):
             self.pickle_path = self.pickle_path + '/'

-    def word2id(self):
-        raise NotImplementedError
-
-    def id2word(self):
-        raise NotImplementedError
-
-    def class2id(self):
-        raise NotImplementedError
-
-    def id2class(self):
-        raise NotImplementedError
-
-    def embedding(self):
-        raise NotImplementedError
-
-    def data_train(self):
-        raise NotImplementedError
-
-    def data_dev(self):
-        raise NotImplementedError
-
-    def data_test(self):
-        raise NotImplementedError
-

 class POSPreprocess(BasePreprocess):
     """
@@ -65,14 +41,24 @@ class POSPreprocess(BasePreprocess):
    to label5.
    """

     def __init__(self, data, pickle_path):
         super(POSPreprocess, self).__init__(data, pickle_path)
-        self.word_dict = None
+        self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
+                          DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
+                          DEFAULT_RESERVED_LABEL[2]: 4}
         self.label_dict = None
         self.data = data
         self.pickle_path = pickle_path
-        self.build_dict()
-        self.word2id()
+        self.build_dict(data)
+        if not self.pickle_exist("word2id.pkl"):
+            self.word_dict.update(self.word2id(data))
+            file_name = os.path.join(self.pickle_path, "word2id.pkl")
+            with open(file_name, "wb") as f:
+                _pickle.dump(self.word_dict, f)
         self.vocab_size = self.id2word()
         self.class2id()
         self.num_classes = self.id2class()
@@ -81,26 +67,26 @@ class POSPreprocess(BasePreprocess):
         self.data_dev()
         self.data_test()

-    def build_dict(self):
-        self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                          DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                          DEFAULT_RESERVED_LABEL[2]: 4}
+    def build_dict(self, data):
+        """
+        Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
+        :param data: list of list [word, label]
+        """
         self.label_dict = {}
-        for w in self.data:
-            w = w.strip()
-            if len(w) <= 1:
+        for line in data:
+            line = line.strip()
+            if len(line) <= 1:
                 continue
-            word = w.split('\t')
+            tokens = line.split('\t')

-            if word[0] not in self.word_dict:
-                index = len(self.word_dict)
-                self.word_dict[word[0]] = index
+            if tokens[0] not in self.word_dict:
+                # add (word, index) into the dict
+                self.word_dict[tokens[0]] = len(self.word_dict)

-            # for label in word[1: ]:
-            label = word[1]
-            if label not in self.label_dict:
-                index = len(self.label_dict)
-                self.label_dict[label] = index
+            # for label in tokens[1: ]:
+            if tokens[1] not in self.label_dict:
+                self.label_dict[tokens[1]] = len(self.label_dict)

     def pickle_exist(self, pickle_name):
         """
@@ -384,3 +370,8 @@ class ClassPreprocess(BasePreprocess):
     # save data
     with open(save_path, "wb") as f:
         _pickle.dump(data, f)
+
+
+class LMPreprocess(BasePreprocess):
+    def __init__(self, data, pickle_path):
+        super(LMPreprocess, self).__init__(data, pickle_path)
@@ -1,5 +1,4 @@
 import os
-from collections import namedtuple

 import numpy as np
 import torch
@@ -23,8 +22,6 @@ class CharLM(BaseModel):
         To do:
             - where the data goes, call data savers.
     """
-    DataTuple = namedtuple("DataTuple", ["feature", "label"])

     def __init__(self, lstm_batch_size, lstm_seq_len):
         super(CharLM, self).__init__()
         """
@@ -1,46 +0,0 @@
-import torch.nn as nn
-
-from fastNLP.models.base_model import BaseModel
-
-
-class WordSeg(BaseModel):
-    """
-        PyTorch Network for word segmentation
-    """
-
-    def __init__(self, hidden_dim, lstm_num_layers, vocab_size, word_emb_dim=100):
-        super(WordSeg, self).__init__()
-
-        self.vocab_size = vocab_size
-        self.word_emb_dim = word_emb_dim
-        self.lstm_num_layers = lstm_num_layers
-        self.hidden_dim = hidden_dim
-
-        self.word_emb = nn.Embedding(self.vocab_size, self.word_emb_dim)
-
-        self.lstm = nn.LSTM(input_size=self.word_emb_dim,
-                            hidden_size=self.word_emb_dim,
-                            num_layers=self.lstm_num_layers,
-                            bias=True,
-                            dropout=0.5,
-                            batch_first=True)
-
-        self.linear = nn.Linear(self.word_emb_dim, self.vocab_size)
-
-    def forward(self, x, hidden):
-        """
-        :param x: tensor of shape [batch_size, seq_len], vocabulary index
-        :param hidden:
-        :return x: probability of vocabulary entries
-                hidden: (memory cell, hidden state) from LSTM
-        """
-        # [batch_size, seq_len]
-        x = self.word_emb(x)
-        # [batch_size, seq_len, word_emb_size]
-        x, hidden = self.lstm(x, hidden)
-        # [batch_size, seq_len, word_emb_size]
-        x = x.contiguous().view(x.shape[0] * x.shape[1], -1)
-        # [batch_size*seq_len, word_emb_size]
-        x = self.linear(x)
-        # [batch_size*seq_len, vocab_size]
-        return x, hidden
@@ -58,12 +58,19 @@ epochs = 20
 batch_size = 1
 pickle_path = "./data_for_tests/"
 validate = true
+save_best_dev = true
+model_saved_path = "./"
+rnn_hidden_units = 100
+rnn_layers = 1
+rnn_bi_direction = true
+word_emb_dim = 100
+dropout = 0.5
+use_crf = true

 [POS_test]
 save_output = true
-validate_in_training = false
+validate_in_training = true
 save_dev_input = false
 save_loss = true
 batch_size = 1
 pickle_path = "./data_for_tests/"
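These keys are read back through ConfigLoader/ConfigSection, as in the test script below. Assuming the loader JSON-decodes values, which the lowercase true/false spelling suggests, the section yields typed Python values:

    train_args = ConfigSection()
    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
    # then e.g. train_args["rnn_hidden_units"] == 100
    # and train_args["save_best_dev"] is True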
@@ -1,4 +1,5 @@
 import sys
+
 sys.path.append("..")

 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
@@ -9,12 +10,38 @@ from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.action.tester import POSTester
 from fastNLP.models.sequence_modeling import SeqLabeling
+from fastNLP.action.inference import Inference

 data_name = "people.txt"
 data_path = "data_for_tests/people.txt"
 pickle_path = "data_for_tests"


+def test_infer():
+    # Define the same model
+    model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
+                        num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
+                        word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
+                        rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
+
+    # Dump trained parameters into the model
+    ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model)
+    print("model loaded!")
+
+    # Data Loader
+    pos_loader = POSDatasetLoader(data_name, data_path)
+    infer_data = pos_loader.load_lines()
+
+    # Preprocessor
+    POSPreprocess(infer_data, pickle_path)
+
+    # Inference interface
+    infer = Inference()
+    results = infer.predict(model, infer_data)
+
+
 if __name__ == "__main__":
+    # Config Loader
     train_args = ConfigSection()
     ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
@@ -24,37 +51,49 @@ if __name__ == "__main__":
     # Preprocessor
     p = POSPreprocess(train_data, pickle_path)
-    vocab_size = p.vocab_size
-    num_classes = p.num_classes
-
-    train_args["vocab_size"] = vocab_size
-    train_args["num_classes"] = num_classes
+    train_args["vocab_size"] = p.vocab_size
+    train_args["num_classes"] = p.num_classes

+    # Trainer
     trainer = POSTrainer(train_args)

     # Model
-    model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)
+    model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
+                        num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
+                        word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
+                        rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])

     # Start training
     trainer.train(model)
     print("Training finished!")

+    # Saver
     saver = ModelSaver("./saved_model.pkl")
     saver.save_pytorch(model)
     print("Model saved!")

     del model, trainer, pos_loader

-    model = SeqLabeling(100, 1, num_classes, vocab_size, bi_direction=True)
-    ModelLoader("xxx", "./saved_model.pkl").load_pytorch(model)
+    # Define the same model
+    model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"],
+                        num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"],
+                        word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"],
+                        rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"])
+
+    # Dump trained parameters into the model
+    ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model)
     print("model loaded!")

+    # Load test configuration
     test_args = ConfigSection()
     ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
-    # test_args = {"save_output": True, "validate_in_training": False, "save_dev_input": False,
-    #              "save_loss": True, "batch_size": 1, "pickle_path": pickle_path}

+    # Tester
     tester = POSTester(test_args)

+    # Start testing
     tester.test(model)

+    # print test results
+    print(tester.show_matrices())
     print("model tested!")
@@ -0,0 +1,28 @@
+import aggregation
+import decoder
+import encoder
+
+
+class Input(object):
+    def __init__(self):
+        pass
+
+
+class Trainer(object):
+    def __init__(self, input, target, truth):
+        pass
+
+    def train(self):
+        pass
+
+
+def test_keras_like():
+    data_train, label_train = dataLoader("./data_path")
+
+    x = Input()
+    x = encoder.LSTM(input=x)
+    x = aggregation.max_pool(input=x)
+    y = decoder.CRF(input=x)
+
+    trainer = Trainer(input=data_train, target=y, truth=label_train)
+    trainer.train()
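This new file reads as a design sketch rather than a runnable test: the aggregation, decoder and encoder modules and the dataLoader function are placeholders that do not exist in the package yet. A self-contained mock of the intended chaining style, with all names hypothetical:

    class Node:
        # stand-in "layer" that just remembers its upstream node
        def __init__(self, input=None):
            self.prev = input

    def LSTM(input=None): return Node(input)
    def max_pool(input=None): return Node(input)
    def CRF(input=None): return Node(input)

    # builds the chain CRF -> max_pool -> LSTM -> Input, as in test_keras_like
    y = CRF(input=max_pool(input=LSTM(input=Node())))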
@@ -1,11 +1,3 @@
-from collections import namedtuple
-
-import numpy as np
-
-from model.base_model import ToyModel
-from fastNLP.action.trainer import Trainer
-
-
 def test_trainer():
     Config = namedtuple("config", ["epochs", "validate", "save_when_better"])
     train_config = Config(epochs=5, validate=True, save_when_better=True)
@@ -1,28 +0,0 @@
-from fastNLP.action.tester import Tester
-from fastNLP.action.trainer import WordSegTrainer
-from fastNLP.loader.base_loader import BaseLoader
-from fastNLP.models.word_seg_model import WordSeg
-
-
-def test_wordseg():
-    train_config = WordSegTrainer.TrainConfig(epochs=5, validate=False, save_when_better=False,
-                                              log_per_step=10, log_validation=False, batch_size=254)
-    trainer = WordSegTrainer(train_config)
-
-    model = WordSeg(100, 2, 1000)
-
-    train_data = BaseLoader("load_train", "./data_for_tests/cws_train").load_lines()
-
-    trainer.train(model, train_data)
-
-    test_config = Tester.TestConfig(save_output=False, validate_in_training=False,
-                                    save_dev_input=False, save_loss=False, batch_size=254)
-    tester = Tester(test_config)
-
-    test_data = BaseLoader("load_test", "./data_for_tests/cws_test").load_lines()
-
-    tester.test(model, test_data)
-
-
-if __name__ == "__main__":
-    test_wordseg()