diff --git a/fastNLP/core/inference.py b/fastNLP/core/inference.py
index c1085554..3937e3f4 100644
--- a/fastNLP/core/inference.py
+++ b/fastNLP/core/inference.py
@@ -149,7 +149,7 @@ class SeqLabelInfer(Inference):
         """
         Transform list of batch outputs into strings.
         :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
-        :return:
+        :return results: 2-D list of strings
         """
         results = []
         for batch in batch_outputs:
@@ -178,7 +178,7 @@ class ClassificationInfer(Inference):
         """
         Transform list of batch outputs into strings.
         :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
-        :return:
+        :return results: list of strings
         """
         results = []
         for batch_out in batch_outputs:
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index d7515a40..8fcdc692 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -315,14 +315,8 @@ class ClassificationTrainer(BaseTrainer):
 
     def __init__(self, train_args):
         super(ClassificationTrainer, self).__init__(train_args)
-        if "learn_rate" in train_args:
-            self.learn_rate = train_args["learn_rate"]
-        else:
-            self.learn_rate = 1e-3
-        if "momentum" in train_args:
-            self.momentum = train_args["momentum"]
-        else:
-            self.momentum = 0.9
+        self.learn_rate = train_args["learn_rate"]
+        self.momentum = train_args["momentum"]
         self.iterator = None
         self.loss_func = None
 
diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py
index 066123da..e67fc63b 100644
--- a/fastNLP/fastnlp.py
+++ b/fastNLP/fastnlp.py
@@ -1,4 +1,4 @@
-from fastNLP.core.inference import Inference
+from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
 
@@ -10,14 +10,28 @@ Example:
         "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
 """
 FastNLP_MODEL_COLLECTION = {
-    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
+    "seq_label_model": {
+        "url": "www.fudan.edu.cn",
+        "class": "sequence_modeling.SeqLabeling",
+        "pickle": "seq_label_model.pkl",
+        "type": "seq_label"
+    },
+    "text_class_model": {
+        "url": "www.fudan.edu.cn",
+        "class": "cnn_text_classification.CNNText",
+        "pickle": "text_class_model.pkl",
+        "type": "text_class"
+    }
 }
 
+CONFIG_FILE_NAME = "config"
+SECTION_NAME = "text_class_model"
+
 
 class FastNLP(object):
     """
     High-level interface for direct model inference.
-    Usage:
+    Example Usage:
         fastnlp = FastNLP()
         fastnlp.load("zh_pos_tag_model")
         text = "这是最好的基于深度学习的中文分词系统。"
@@ -35,6 +49,7 @@ class FastNLP(object):
         """
         self.model_dir = model_dir
         self.model = None
+        self.infer_type = None  # "seq_label"/"text_class"
 
     def load(self, model_name):
         """
@@ -46,21 +61,21 @@ class FastNLP(object):
             raise ValueError("No FastNLP model named {}.".format(model_name))
 
         if not self.model_exist(model_dir=self.model_dir):
-            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0])
+            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])
 
-        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1])
+        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
 
         model_args = ConfigSection()
-        # To do: customized config file for model init parameters
-        ConfigLoader.load_config(self.model_dir + "config", {"POS_infer": model_args})
+        ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
 
         # Construct the model
         model = model_class(model_args)
 
         # To do: framework independent
-        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name][2])
+        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"])
 
         self.model = model
+        self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]
 
         print("Model loaded. ")
 
@@ -71,12 +86,16 @@ class FastNLP(object):
         """
         :return results:
         """
-        infer = Inference(self.model_dir)
+        infer = self._create_inference(self.model_dir)
+
+        # string ---> 2-D list of string
         infer_input = self.string_to_list(raw_input)
 
+        # 2-D list of string ---> list of strings
         results = infer.predict(self.model, infer_input)
 
-        outputs = self.make_output(results)
+        # list of strings ---> final answers
+        outputs = self._make_output(results, infer_input)
         return outputs
 
     @staticmethod
@@ -95,6 +114,14 @@ class FastNLP(object):
             module = getattr(module, sub)
         return module
 
+    def _create_inference(self, model_dir):
+        if self.infer_type == "seq_label":
+            return SeqLabelInfer(model_dir)
+        elif self.infer_type == "text_class":
+            return ClassificationInfer(model_dir)
+        else:
+            raise ValueError("fail to create inference instance")
+
     def _load(self, model_dir, model_name):
         # To do
         return 0
@@ -117,7 +144,6 @@ class FastNLP(object):
 
     def string_to_list(self, text, delimiter="\n"):
         """
-        For word seg only, currently.
        This function is used to transform raw input to lists, which is done by DatasetLoader in training.
         Split text string into three-level lists.
         [
@@ -127,7 +153,7 @@ class FastNLP(object):
         ]
         :param text: string
         :param delimiter: str, character used to split text into sentences.
-        :return data: three-level lists
+        :return data: two-level lists
         """
         data = []
         sents = text.strip().split(delimiter)
@@ -136,38 +162,61 @@ class FastNLP(object):
             for ch in sent:
                 characters.append(ch)
             data.append(characters)
-        # To refactor: this is used in make_output
-        self.data = data
         return data
 
-    def make_output(self, results):
-        """
-        Transform model output into user-friendly contents.
-        Example: In CWS, convert labeling into segmented text.
-        :param results:
-        :return:
-        """
-        outputs = []
-        for sent_char, sent_label in zip(self.data, results):
-            words = []
-            word = ""
-            for char, label in zip(sent_char, sent_label):
-                if label[0] == "B":
-                    if word != "":
-                        words.append(word)
-                    word = char
-                elif label[0] == "M":
-                    word += char
-                elif label[0] == "E":
-                    word += char
-                    words.append(word)
-                    word = ""
-                elif label[0] == "S":
-                    if word != "":
-                        words.append(word)
-                    word = ""
-                    words.append(char)
-                else:
-                    raise ValueError("invalid label")
-            outputs.append(" ".join(words))
+    def _make_output(self, results, infer_input):
+        if self.infer_type == "seq_label":
+            outputs = make_seq_label_output(results, infer_input)
+        elif self.infer_type == "text_class":
+            outputs = make_class_output(results, infer_input)
+        else:
+            raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
         return outputs
+
+
+def make_seq_label_output(result, infer_input):
+    """
+    Transform model output into user-friendly contents.
+    :param result: 1-D list of strings. (model output)
+    :param infer_input: 2-D list of string (model input)
+    :return outputs:
+    """
+    return result
+
+
+def make_class_output(result, infer_input):
+    return result
+
+
+def interpret_word_seg_results(infer_input, results):
+    """
+    Transform model output into user-friendly contents.
+    Example: In CWS, convert labeling into segmented text.
+    :param results: list of strings. (model output)
+    :param infer_input: 2-D list of string (model input)
+    :return output: list of strings
+    """
+    outputs = []
+    for sent_char, sent_label in zip(infer_input, results):
+        words = []
+        word = ""
+        for char, label in zip(sent_char, sent_label):
+            if label[0] == "B":
+                if word != "":
+                    words.append(word)
+                word = char
+            elif label[0] == "M":
+                word += char
+            elif label[0] == "E":
+                word += char
+                words.append(word)
+                word = ""
+            elif label[0] == "S":
+                if word != "":
+                    words.append(word)
+                word = ""
+                words.append(char)
+            else:
+                raise ValueError("invalid label")
+        outputs.append(" ".join(words))
+    return outputs
diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py
index 66bb5ecc..b6dcafb3 100644
--- a/fastNLP/models/cnn_text_classification.py
+++ b/fastNLP/models/cnn_text_classification.py
@@ -15,12 +15,17 @@ class CNNText(torch.nn.Module):
     Classification.'
""" - def __init__(self, class_num=9, - kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5], - embed_num=1000, embed_dim=300, pretrained_embed=None, - drop_prob=0.5): + def __init__(self, args): super(CNNText, self).__init__() + class_num = args["num_classes"] + kernel_nums = [100, 100, 100] + kernel_sizes = [3, 4, 5] + embed_num = args["vocab_size"] + embed_dim = 300 + pretrained_embed = None + drop_prob = 0.5 + # no support for pre-trained embedding currently self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0) self.conv_pool = ConvMaxpool( diff --git a/test/data_for_tests/config b/test/data_for_tests/config index a88ecd5d..60a7c9a5 100644 --- a/test/data_for_tests/config +++ b/test/data_for_tests/config @@ -89,5 +89,20 @@ rnn_hidden_units = 100 rnn_layers = 1 rnn_bi_direction = true word_emb_dim = 100 -vocab_size = 52 -num_classes = 22 \ No newline at end of file +vocab_size = 53 +num_classes = 27 + +[text_class] +epochs = 1 +batch_size = 10 +pickle_path = "./data_for_tests/" +validate = false +save_best_dev = false +model_saved_path = "./data_for_tests/" +use_cuda = true +learn_rate = 1e-3 +momentum = 0.9 + +[text_class_model] +vocab_size = 867 +num_classes = 18 \ No newline at end of file diff --git a/test/seq_labeling.py b/test/seq_labeling.py index db171215..adc686df 100644 --- a/test/seq_labeling.py +++ b/test/seq_labeling.py @@ -112,5 +112,5 @@ def train_and_test(): if __name__ == "__main__": - # train_and_test() - infer() + train_and_test() + # infer() diff --git a/test/test_fastNLP.py b/test/test_fastNLP.py index 35bac153..0b64b261 100644 --- a/test/test_fastNLP.py +++ b/test/test_fastNLP.py @@ -1,9 +1,18 @@ from fastNLP.fastnlp import FastNLP -def foo(): +def word_seg(): nlp = FastNLP("./data_for_tests/") - nlp.load("zh_pos_tag_model") + nlp.load("seq_label_model") + text = "这是最好的基于深度学习的中文分词系统。" + result = nlp.run(text) + print(result) + print("FastNLP finished!") + + +def text_class(): + nlp = FastNLP("./data_for_tests/") + nlp.load("text_class_model") text = "这是最好的基于深度学习的中文分词系统。" result = nlp.run(text) print(result) @@ -11,4 +20,4 @@ def foo(): if __name__ == "__main__": - foo() + text_class() diff --git a/test/text_classify.py b/test/text_classify.py index f18d4a38..7400b1da 100644 --- a/test/text_classify.py +++ b/test/text_classify.py @@ -5,6 +5,7 @@ import os from fastNLP.core.inference import ClassificationInfer from fastNLP.core.trainer import ClassificationTrainer +from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.dataset_loader import ClassDatasetLoader from fastNLP.loader.model_loader import ModelLoader from fastNLP.loader.preprocess import ClassPreprocess @@ -29,9 +30,13 @@ def infer(): print("vocabulary size:", vocab_size) print("number of classes:", n_classes) + model_args = ConfigSection() + ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args}) + # construct model print("Building model...") - cnn = CNNText(class_num=n_classes, embed_num=vocab_size) + cnn = CNNText(model_args) + # Dump trained parameters into the model ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl") print("model loaded!") @@ -42,6 +47,9 @@ def infer(): def train(): + train_args, model_args = ConfigSection(), ConfigSection() + ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args}) + # load dataset print("Loading data...") ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) @@ -56,19 +64,11 @@ def train(): # construct model 
print("Building model...") - cnn = CNNText(class_num=n_classes, embed_num=vocab_size) + cnn = CNNText(model_args) # train print("Training...") - train_args = { - "epochs": 1, - "batch_size": 10, - "pickle_path": data_dir, - "validate": False, - "save_best_dev": False, - "model_saved_path": "./data_for_tests/", - "use_cuda": True - } + trainer = ClassificationTrainer(train_args) trainer.train(cnn)