@@ -19,13 +19,13 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||||
def save_pickle(obj, pickle_path, file_name): | def save_pickle(obj, pickle_path, file_name): | ||||
with open(os.path.join(pickle_path, file_name), "wb") as f: | with open(os.path.join(pickle_path, file_name), "wb") as f: | ||||
_pickle.dump(obj, f) | _pickle.dump(obj, f) | ||||
print("{} saved. ".format(file_name)) | |||||
print("{} saved in {}".format(file_name, pickle_path)) | |||||
def load_pickle(pickle_path, file_name): | def load_pickle(pickle_path, file_name): | ||||
with open(os.path.join(pickle_path, file_name), "rb") as f: | with open(os.path.join(pickle_path, file_name), "rb") as f: | ||||
obj = _pickle.load(f) | obj = _pickle.load(f) | ||||
print("{} loaded. ".format(file_name)) | |||||
print("{} loaded from {}".format(file_name, pickle_path)) | |||||
return obj | return obj | ||||
@@ -98,7 +98,7 @@ class BaseTester(object): | |||||
print_output = "[test step {}] {}".format(step, eval_results) | print_output = "[test step {}] {}".format(step, eval_results) | ||||
logger.info(print_output) | logger.info(print_output) | ||||
if step % self.print_every_step == 0: | |||||
if self.print_every_step > 0 and step % self.print_every_step == 0: | |||||
print(print_output) | print(print_output) | ||||
step += 1 | step += 1 | ||||
@@ -187,7 +187,7 @@ class SeqLabelTester(BaseTester): | |||||
# make sure "results" is in the same device as "truth" | # make sure "results" is in the same device as "truth" | ||||
results = results.to(truth) | results = results.to(truth) | ||||
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | ||||
return [loss.data, accuracy.data] | |||||
return [float(loss), float(accuracy)] | |||||
def metrics(self): | def metrics(self): | ||||
batch_loss = np.mean([x[0] for x in self.eval_history]) | batch_loss = np.mean([x[0] for x in self.eval_history]) | ||||
@@ -4,7 +4,6 @@ import os | |||||
import time | import time | ||||
from datetime import timedelta | from datetime import timedelta | ||||
import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.action import Action | from fastNLP.core.action import Action | ||||
@@ -47,7 +46,7 @@ class BaseTrainer(object): | |||||
Otherwise, error will raise. | Otherwise, error will raise. | ||||
""" | """ | ||||
default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/", | default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/", | ||||
"save_best_dev": True, "model_name": "default_model_name.pkl", | |||||
"save_best_dev": True, "model_name": "default_model_name.pkl", "print_every_step": 1, | |||||
"loss": Loss(None), | "loss": Loss(None), | ||||
"optimizer": Optimizer("Adam", lr=0.001, weight_decay=0) | "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0) | ||||
} | } | ||||
@@ -86,6 +85,7 @@ class BaseTrainer(object): | |||||
self.save_best_dev = default_args["save_best_dev"] | self.save_best_dev = default_args["save_best_dev"] | ||||
self.use_cuda = default_args["use_cuda"] | self.use_cuda = default_args["use_cuda"] | ||||
self.model_name = default_args["model_name"] | self.model_name = default_args["model_name"] | ||||
self.print_every_step = default_args["print_every_step"] | |||||
self._model = None | self._model = None | ||||
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | ||||
@@ -93,48 +93,35 @@ class BaseTrainer(object): | |||||
self._optimizer_proto = default_args["optimizer"] | self._optimizer_proto = default_args["optimizer"] | ||||
def train(self, network, train_data, dev_data=None): | def train(self, network, train_data, dev_data=None): | ||||
"""General Training Steps | |||||
"""General Training Procedure | |||||
:param network: a model | :param network: a model | ||||
:param train_data: three-level list, the training set. | :param train_data: three-level list, the training set. | ||||
:param dev_data: three-level list, the validation data (optional) | :param dev_data: three-level list, the validation data (optional) | ||||
The method is framework independent. | |||||
Work by calling the following methods: | |||||
- prepare_input | |||||
- mode | |||||
- define_optimizer | |||||
- data_forward | |||||
- get_loss | |||||
- grad_backward | |||||
- update | |||||
Subclasses must implement these methods with a specific framework. | |||||
""" | """ | ||||
# prepare model and data, transfer model to gpu if available | |||||
# transfer model to gpu if available | |||||
if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
self._model = network.cuda() | self._model = network.cuda() | ||||
# self._model is used to access model-specific loss | |||||
else: | else: | ||||
self._model = network | self._model = network | ||||
# define tester over dev data | |||||
# define Tester over dev data | |||||
if self.validate: | if self.validate: | ||||
default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | ||||
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | ||||
"use_cuda": self.use_cuda} | |||||
"use_cuda": self.use_cuda, "print_every_step": 0} | |||||
validator = self._create_validator(default_valid_args) | validator = self._create_validator(default_valid_args) | ||||
logger.info("validator defined as {}".format(str(validator))) | logger.info("validator defined as {}".format(str(validator))) | ||||
# optimizer and loss | |||||
self.define_optimizer() | self.define_optimizer() | ||||
logger.info("optimizer defined as {}".format(str(self._optimizer))) | logger.info("optimizer defined as {}".format(str(self._optimizer))) | ||||
self.define_loss() | self.define_loss() | ||||
logger.info("loss function defined as {}".format(str(self._loss_func))) | logger.info("loss function defined as {}".format(str(self._loss_func))) | ||||
# main training epochs | |||||
n_samples = len(train_data) | |||||
n_batches = n_samples // self.batch_size | |||||
n_print = 1 | |||||
# main training procedure | |||||
start = time.time() | start = time.time() | ||||
logger.info("training epochs started") | logger.info("training epochs started") | ||||
for epoch in range(1, self.n_epochs + 1): | for epoch in range(1, self.n_epochs + 1): | ||||
logger.info("training epoch {}".format(epoch)) | logger.info("training epoch {}".format(epoch)) | ||||
@@ -144,23 +131,30 @@ class BaseTrainer(object): | |||||
data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False)) | data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False)) | ||||
logger.info("prepared data iterator") | logger.info("prepared data iterator") | ||||
self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch) | |||||
# one forward and backward pass | |||||
self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch) | |||||
# validation | |||||
if self.validate: | if self.validate: | ||||
logger.info("validation started") | logger.info("validation started") | ||||
validator.test(network, dev_data) | validator.test(network, dev_data) | ||||
if self.save_best_dev and self.best_eval_result(validator): | if self.save_best_dev and self.best_eval_result(validator): | ||||
self.save_model(network, self.model_name) | self.save_model(network, self.model_name) | ||||
print("saved better model selected by dev") | |||||
logger.info("saved better model selected by dev") | |||||
print("Saved better model selected by validation.") | |||||
logger.info("Saved better model selected by validation.") | |||||
valid_results = validator.show_matrices() | valid_results = validator.show_matrices() | ||||
print("[epoch {}] {}".format(epoch, valid_results)) | print("[epoch {}] {}".format(epoch, valid_results)) | ||||
logger.info("[epoch {}] {}".format(epoch, valid_results)) | logger.info("[epoch {}] {}".format(epoch, valid_results)) | ||||
def _train_step(self, data_iterator, network, **kwargs): | def _train_step(self, data_iterator, network, **kwargs): | ||||
"""Training process in one epoch.""" | |||||
"""Training process in one epoch. | |||||
kwargs should contain: | |||||
- n_print: int, print training information every n steps. | |||||
- start: time.time(), the starting time of this step. | |||||
- epoch: int, | |||||
""" | |||||
step = 0 | step = 0 | ||||
for batch_x, batch_y in self.make_batch(data_iterator): | for batch_x, batch_y in self.make_batch(data_iterator): | ||||
@@ -170,7 +164,7 @@ class BaseTrainer(object): | |||||
self.grad_backward(loss) | self.grad_backward(loss) | ||||
self.update() | self.update() | ||||
if step % kwargs["n_print"] == 0: | |||||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | |||||
end = time.time() | end = time.time() | ||||
diff = timedelta(seconds=round(end - kwargs["start"])) | diff = timedelta(seconds=round(end - kwargs["start"])) | ||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( | print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( | ||||
@@ -287,10 +281,11 @@ class BaseTrainer(object): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
def save_model(self, network, model_name): | def save_model(self, network, model_name): | ||||
""" | |||||
"""Save this model with such a name. | |||||
This method may be called multiple times by Trainer to overwritten a better model. | |||||
:param network: the PyTorch model | :param network: the PyTorch model | ||||
:param model_name: str | :param model_name: str | ||||
model_best_dev.pkl may be overwritten by a better model in future epochs. | |||||
""" | """ | ||||
if model_name[-4:] != ".pkl": | if model_name[-4:] != ".pkl": | ||||
model_name += ".pkl" | model_name += ".pkl" | ||||
@@ -300,33 +295,9 @@ class BaseTrainer(object): | |||||
raise NotImplementedError | raise NotImplementedError | ||||
class ToyTrainer(BaseTrainer): | |||||
""" | |||||
An example to show the definition of Trainer. | |||||
""" | |||||
def __init__(self, training_args): | |||||
super(ToyTrainer, self).__init__(training_args) | |||||
def load_train_data(self, data_path): | |||||
data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||||
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||||
return data_train, data_dev, 0, 1 | |||||
def data_forward(self, network, x): | |||||
return network(x) | |||||
def grad_backward(self, loss): | |||||
self._model.zero_grad() | |||||
loss.backward() | |||||
def get_loss(self, pred, truth): | |||||
return np.mean(np.square(pred - truth)) | |||||
class SeqLabelTrainer(BaseTrainer): | class SeqLabelTrainer(BaseTrainer): | ||||
""" | """ | ||||
Trainer for Sequence Modeling | |||||
Trainer for Sequence Labeling | |||||
""" | """ | ||||
@@ -384,7 +355,7 @@ class SeqLabelTrainer(BaseTrainer): | |||||
class ClassificationTrainer(BaseTrainer): | class ClassificationTrainer(BaseTrainer): | ||||
"""Trainer for classification.""" | |||||
"""Trainer for text classification.""" | |||||
def __init__(self, **train_args): | def __init__(self, **train_args): | ||||
super(ClassificationTrainer, self).__init__(**train_args) | super(ClassificationTrainer, self).__init__(**train_args) | ||||
@@ -1,4 +1,5 @@ | |||||
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | ||||
from fastNLP.core.preprocess import load_pickle | |||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
@@ -7,14 +8,13 @@ mapping from model name to [URL, file_name.class_name, model_pickle_name] | |||||
Notice that the class of the model should be in "models" directory. | Notice that the class of the model should be in "models" directory. | ||||
Example: | Example: | ||||
"zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"] | |||||
""" | |||||
FastNLP_MODEL_COLLECTION = { | |||||
"seq_label_model": { | "seq_label_model": { | ||||
"url": "www.fudan.edu.cn", | "url": "www.fudan.edu.cn", | ||||
"class": "sequence_modeling.SeqLabeling", | |||||
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/ | |||||
"pickle": "seq_label_model.pkl", | "pickle": "seq_label_model.pkl", | ||||
"type": "seq_label" | |||||
"type": "seq_label", | |||||
"config_file_name": "config", # the name of the config file which stores model initialization parameters | |||||
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params | |||||
}, | }, | ||||
"text_class_model": { | "text_class_model": { | ||||
"url": "www.fudan.edu.cn", | "url": "www.fudan.edu.cn", | ||||
@@ -22,11 +22,18 @@ FastNLP_MODEL_COLLECTION = { | |||||
"pickle": "text_class_model.pkl", | "pickle": "text_class_model.pkl", | ||||
"type": "text_class" | "type": "text_class" | ||||
} | } | ||||
""" | |||||
FastNLP_MODEL_COLLECTION = { | |||||
"cws_basic_model": { | |||||
"url": "", | |||||
"class": "sequence_modeling.AdvSeqLabel", | |||||
"pickle": "cws_basic_model_v_0.pkl", | |||||
"type": "seq_label", | |||||
"config_file_name": "config", | |||||
"config_section_name": "text_class_model" | |||||
} | |||||
} | } | ||||
CONFIG_FILE_NAME = "config" | |||||
SECTION_NAME = "text_class_model" | |||||
class FastNLP(object): | class FastNLP(object): | ||||
""" | """ | ||||
@@ -51,10 +58,13 @@ class FastNLP(object): | |||||
self.model = None | self.model = None | ||||
self.infer_type = None # "seq_label"/"text_class" | self.infer_type = None # "seq_label"/"text_class" | ||||
def load(self, model_name): | |||||
def load(self, model_name, config_file="config", section_name="model"): | |||||
""" | """ | ||||
Load a pre-trained FastNLP model together with additional data. | Load a pre-trained FastNLP model together with additional data. | ||||
:param model_name: str, the name of a FastNLP model. | :param model_name: str, the name of a FastNLP model. | ||||
:param config_file: str, the name of the config file which stores the initialization information of the model. | |||||
(default: "config") | |||||
:param section_name: str, the name of the corresponding section in the config file. (default: model) | |||||
""" | """ | ||||
assert type(model_name) is str | assert type(model_name) is str | ||||
if model_name not in FastNLP_MODEL_COLLECTION: | if model_name not in FastNLP_MODEL_COLLECTION: | ||||
@@ -64,37 +74,47 @@ class FastNLP(object): | |||||
self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"]) | self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"]) | ||||
model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"]) | model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"]) | ||||
print("Restore model class {}".format(str(model_class))) | |||||
model_args = ConfigSection() | model_args = ConfigSection() | ||||
ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args}) | |||||
ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args}) | |||||
print("Restore model hyper-parameters {}".format(str(model_args.data))) | |||||
# fetch dictionary size and number of labels from pickle files | |||||
word2index = load_pickle(self.model_dir, "word2id.pkl") | |||||
model_args["vocab_size"] = len(word2index) | |||||
index2label = load_pickle(self.model_dir, "id2class.pkl") | |||||
model_args["num_classes"] = len(index2label) | |||||
# Construct the model | # Construct the model | ||||
model = model_class(model_args) | model = model_class(model_args) | ||||
print("Model constructed.") | |||||
# To do: framework independent | # To do: framework independent | ||||
ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"]) | ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"]) | ||||
print("Model weights loaded.") | |||||
self.model = model | self.model = model | ||||
self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"] | self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"] | ||||
print("Model loaded. ") | |||||
print("Inference ready.") | |||||
def run(self, raw_input): | def run(self, raw_input): | ||||
""" | """ | ||||
Perform inference over given input using the loaded model. | Perform inference over given input using the loaded model. | ||||
:param raw_input: str, raw text | |||||
:param raw_input: list of string. Each list is an input query. | |||||
:return results: | :return results: | ||||
""" | """ | ||||
infer = self._create_inference(self.model_dir) | infer = self._create_inference(self.model_dir) | ||||
# string ---> 2-D list of string | |||||
infer_input = self.string_to_list(raw_input) | |||||
# tokenize: list of string ---> 2-D list of string | |||||
infer_input = self.tokenize(raw_input, language="zh") | |||||
# 2-D list of string ---> list of strings | |||||
# 2-D list of string ---> 2-D list of tags | |||||
results = infer.predict(self.model, infer_input) | results = infer.predict(self.model, infer_input) | ||||
# list of strings ---> final answers | |||||
# 2-D list of tags ---> list of final answers | |||||
outputs = self._make_output(results, infer_input) | outputs = self._make_output(results, infer_input) | ||||
return outputs | return outputs | ||||
@@ -142,81 +162,100 @@ class FastNLP(object): | |||||
""" | """ | ||||
return True | return True | ||||
def string_to_list(self, text, delimiter="\n"): | |||||
""" | |||||
This function is used to transform raw input to lists, which is done by DatasetLoader in training. | |||||
Split text string into three-level lists. | |||||
[ | |||||
[word_11, word_12, ...], | |||||
[word_21, word_22, ...], | |||||
... | |||||
] | |||||
:param text: string | |||||
:param delimiter: str, character used to split text into sentences. | |||||
:return data: two-level lists | |||||
def tokenize(self, text, language): | |||||
"""Extract tokens from strings. | |||||
For English, extract words separated by space. | |||||
For Chinese, extract characters. | |||||
TODO: more complex tokenization methods | |||||
:param text: list of string | |||||
:param language: str, one of ('zh', 'en'), Chinese or English. | |||||
:return data: list of list of string, each string is a token. | |||||
""" | """ | ||||
assert language in ("zh", "en") | |||||
data = [] | data = [] | ||||
sents = text.strip().split(delimiter) | |||||
for sent in sents: | |||||
characters = [] | |||||
for ch in sent: | |||||
characters.append(ch) | |||||
data.append(characters) | |||||
for sent in text: | |||||
if language == "en": | |||||
tokens = sent.strip().split() | |||||
elif language == "zh": | |||||
tokens = [char for char in sent] | |||||
else: | |||||
raise RuntimeError("Unknown language {}".format(language)) | |||||
data.append(tokens) | |||||
return data | return data | ||||
def _make_output(self, results, infer_input): | def _make_output(self, results, infer_input): | ||||
"""Transform the infer output into user-friendly output. | |||||
:param results: 1 or 2-D list of strings. | |||||
If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length] | |||||
If self.infer_type == "text_class", it is of shape [num_examples] | |||||
:param infer_input: 2-D list of string, the input query before inference. | |||||
:return outputs: list. Each entry is a prediction. | |||||
""" | |||||
if self.infer_type == "seq_label": | if self.infer_type == "seq_label": | ||||
outputs = make_seq_label_output(results, infer_input) | outputs = make_seq_label_output(results, infer_input) | ||||
elif self.infer_type == "text_class": | elif self.infer_type == "text_class": | ||||
outputs = make_class_output(results, infer_input) | outputs = make_class_output(results, infer_input) | ||||
else: | else: | ||||
raise ValueError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
return outputs | return outputs | ||||
def make_seq_label_output(result, infer_input): | def make_seq_label_output(result, infer_input): | ||||
""" | |||||
Transform model output into user-friendly contents. | |||||
:param result: 1-D list of strings. (model output) | |||||
"""Transform model output into user-friendly contents. | |||||
:param result: 2-D list of strings. (model output) | |||||
:param infer_input: 2-D list of string (model input) | :param infer_input: 2-D list of string (model input) | ||||
:return outputs: | |||||
:return ret: list of list of tuples | |||||
[ | |||||
[(word_11, label_11), (word_12, label_12), ...], | |||||
[(word_21, label_21), (word_22, label_22), ...], | |||||
... | |||||
] | |||||
""" | """ | ||||
return result | |||||
ret = [] | |||||
for example_x, example_y in zip(infer_input, result): | |||||
ret.append([(x, y) for x, y in zip(example_x, example_y)]) | |||||
return ret | |||||
def make_class_output(result, infer_input): | def make_class_output(result, infer_input): | ||||
"""Transform model output into user-friendly contents. | |||||
:param result: 2-D list of strings. (model output) | |||||
:param infer_input: 1-D list of string (model input) | |||||
:return ret: the same as result, [label_1, label_2, ...] | |||||
""" | |||||
return result | return result | ||||
def interpret_word_seg_results(infer_input, results): | |||||
""" | |||||
Transform model output into user-friendly contents. | |||||
def interpret_word_seg_results(char_seq, label_seq): | |||||
"""Transform model output into user-friendly contents. | |||||
Example: In CWS, convert <BMES> labeling into segmented text. | Example: In CWS, convert <BMES> labeling into segmented text. | ||||
:param results: list of strings. (model output) | |||||
:param infer_input: 2-D list of string (model input) | |||||
:return output: list of strings | |||||
:param char_seq: list of string, | |||||
:param label_seq: list of string, the same length as char_seq | |||||
Each entry is one of ('B', 'M', 'E', 'S'). | |||||
:return output: list of words | |||||
""" | """ | ||||
outputs = [] | |||||
for sent_char, sent_label in zip(infer_input, results): | |||||
words = [] | |||||
word = "" | |||||
for char, label in zip(sent_char, sent_label): | |||||
if label[0] == "B": | |||||
if word != "": | |||||
words.append(word) | |||||
word = char | |||||
elif label[0] == "M": | |||||
word += char | |||||
elif label[0] == "E": | |||||
word += char | |||||
words = [] | |||||
word = "" | |||||
for char, label in zip(char_seq, label_seq): | |||||
if label[0] == "B": | |||||
if word != "": | |||||
words.append(word) | words.append(word) | ||||
word = "" | |||||
elif label[0] == "S": | |||||
if word != "": | |||||
words.append(word) | |||||
word = "" | |||||
words.append(char) | |||||
else: | |||||
raise ValueError("invalid label") | |||||
outputs.append(" ".join(words)) | |||||
return outputs | |||||
word = char | |||||
elif label[0] == "M": | |||||
word += char | |||||
elif label[0] == "E": | |||||
word += char | |||||
words.append(word) | |||||
word = "" | |||||
elif label[0] == "S": | |||||
if word != "": | |||||
words.append(word) | |||||
word = "" | |||||
words.append(char) | |||||
else: | |||||
raise ValueError("invalid label {}".format(label[0])) | |||||
return words |
@@ -1,26 +1,26 @@ | |||||
import sys, os | |||||
import os | |||||
import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | ||||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | ||||
from fastNLP.loader.preprocess import POSPreprocess, load_pickle | |||||
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
from fastNLP.core.inference import SeqLabelInfer | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.predictor import SeqLabelInfer | |||||
# not in the file's dir | # not in the file's dir | ||||
if len(os.path.dirname(__file__)) != 0: | if len(os.path.dirname(__file__)) != 0: | ||||
os.chdir(os.path.dirname(__file__)) | os.chdir(os.path.dirname(__file__)) | ||||
datadir = 'icwb2-data' | |||||
cfgfile = 'cws.cfg' | |||||
datadir = "/home/zyfeng/data/" | |||||
cfgfile = './cws.cfg' | |||||
data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
cws_data_path = os.path.join(datadir, "training/pku_training.utf8") | |||||
cws_data_path = os.path.join(datadir, "pku_training.utf8") | |||||
pickle_path = "save" | pickle_path = "save" | ||||
data_infer_path = os.path.join(datadir, "infer.utf8") | data_infer_path = os.path.join(datadir, "infer.utf8") | ||||
@@ -70,12 +70,13 @@ def train(): | |||||
train_data = loader.load_pku() | train_data = loader.load_pku() | ||||
# Preprocessor | # Preprocessor | ||||
p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3) | |||||
train_args["vocab_size"] = p.vocab_size | |||||
train_args["num_classes"] = p.num_classes | |||||
preprocessor = SeqLabelPreprocess() | |||||
data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3) | |||||
train_args["vocab_size"] = preprocessor.vocab_size | |||||
train_args["num_classes"] = preprocessor.num_classes | |||||
# Trainer | # Trainer | ||||
trainer = SeqLabelTrainer(train_args) | |||||
trainer = SeqLabelTrainer(**train_args.data) | |||||
# Model | # Model | ||||
model = AdvSeqLabel(train_args) | model = AdvSeqLabel(train_args) | ||||
@@ -83,10 +84,11 @@ def train(): | |||||
ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | ||||
print('model parameter loaded!') | print('model parameter loaded!') | ||||
except Exception as e: | except Exception as e: | ||||
print("No saved model. Continue.") | |||||
pass | pass | ||||
# Start training | # Start training | ||||
trainer.train(model) | |||||
trainer.train(model, data_train, data_dev) | |||||
print("Training finished!") | print("Training finished!") | ||||
# Saver | # Saver | ||||
@@ -106,6 +108,9 @@ def test(): | |||||
index2label = load_pickle(pickle_path, "id2class.pkl") | index2label = load_pickle(pickle_path, "id2class.pkl") | ||||
test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
# load dev data | |||||
dev_data = load_pickle(pickle_path, "data_dev.pkl") | |||||
# Define the same model | # Define the same model | ||||
model = AdvSeqLabel(test_args) | model = AdvSeqLabel(test_args) | ||||
@@ -114,10 +119,10 @@ def test(): | |||||
print("model loaded!") | print("model loaded!") | ||||
# Tester | # Tester | ||||
tester = SeqLabelTester(test_args) | |||||
tester = SeqLabelTester(**test_args.data) | |||||
# Start testing | # Start testing | ||||
tester.test(model) | |||||
tester.test(model, dev_data) | |||||
# print test results | # print test results | ||||
print(tester.show_matrices()) | print(tester.show_matrices()) | ||||
@@ -123,7 +123,7 @@ def train_and_test(): | |||||
tester = SeqLabelTester(save_output=False, | tester = SeqLabelTester(save_output=False, | ||||
save_loss=False, | save_loss=False, | ||||
save_best_dev=False, | save_best_dev=False, | ||||
batch_size=8, | |||||
batch_size=4, | |||||
use_cuda=False, | use_cuda=False, | ||||
pickle_path=pickle_path, | pickle_path=pickle_path, | ||||
model_name="seq_label_in_test.pkl", | model_name="seq_label_in_test.pkl", | ||||
@@ -140,4 +140,4 @@ def train_and_test(): | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
train_and_test() | train_and_test() | ||||
infer() | |||||
# infer() |
@@ -1,13 +1,24 @@ | |||||
import sys | |||||
sys.path.append("..") | |||||
from fastNLP.fastnlp import FastNLP | from fastNLP.fastnlp import FastNLP | ||||
from fastNLP.fastnlp import interpret_word_seg_results | |||||
PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/" | |||||
def word_seg(): | def word_seg(): | ||||
nlp = FastNLP("./data_for_tests/") | |||||
nlp.load("seq_label_model") | |||||
text = "这是最好的基于深度学习的中文分词系统。" | |||||
result = nlp.run(text) | |||||
print(result) | |||||
print("FastNLP finished!") | |||||
nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES) | |||||
nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test") | |||||
text = ["这是最好的基于深度学习的中文分词系统。", | |||||
"大王叫我来巡山。", | |||||
"我党多年来致力于改善人民生活水平。"] | |||||
results = nlp.run(text) | |||||
print(results) | |||||
for example in results: | |||||
words, labels = [], [] | |||||
for res in example: | |||||
words.append(res[0]) | |||||
labels.append(res[1]) | |||||
print(interpret_word_seg_results(words, labels)) | |||||
def text_class(): | def text_class(): | ||||
@@ -19,5 +30,14 @@ def text_class(): | |||||
print("FastNLP finished!") | print("FastNLP finished!") | ||||
def test_word_seg_interpret(): | |||||
foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'), | |||||
('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'), | |||||
('。', 'S')]] | |||||
chars = [x[0] for x in foo[0]] | |||||
labels = [x[1] for x in foo[0]] | |||||
print(interpret_word_seg_results(chars, labels)) | |||||
if __name__ == "__main__": | if __name__ == "__main__": | ||||
text_class() | |||||
word_seg() |