diff --git a/fastNLP/core/inference.py b/fastNLP/core/inference.py
index 11a3ba48..3937e3f4 100644
--- a/fastNLP/core/inference.py
+++ b/fastNLP/core/inference.py
@@ -63,7 +63,7 @@ class Inference(object):
         """
         Perform inference.
         :param network:
-        :param data: multi-level lists of strings
+        :param data: two-level lists of strings
         :return result: the model outputs
         """
         # transform strings into indices
@@ -97,7 +97,7 @@ class Inference(object):
 
     def prepare_input(self, data):
         """
-        Transform three-level list of strings into that of index.
+        Transform a two-level list of strings into a two-level list of indices.
         :param data:
         [
             [word_11, word_12, ...],
@@ -140,7 +140,7 @@ class SeqLabelInfer(Inference):
         mask = mask.byte().view(batch_size, max_len)
         y = network(x)
         prediction = network.prediction(y, mask)
-        return torch.Tensor(prediction, required_grad=False)
+        return torch.Tensor(prediction)
 
     def make_batch(self, iterator, data, use_cuda):
         return make_batch(iterator, data, use_cuda, output_length=True)
@@ -149,7 +149,7 @@ class SeqLabelInfer(Inference):
         """
         Transform list of batch outputs into strings.
-        :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
-        :return:
+        :param batch_outputs: list of num_batch 2-D Tensors, each of shape [batch_size, tag_seq_length].
+        :return results: 2-D list of strings
         """
         results = []
         for batch in batch_outputs:
@@ -178,7 +178,7 @@ class ClassificationInfer(Inference):
         """
         Transform list of batch outputs into strings.
-        :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
-        :return:
+        :param batch_outputs: list of num_batch 2-D Tensors, each of shape [batch_size, num_classes].
+        :return results: list of strings
         """
         results = []
         for batch_out in batch_outputs:
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 425f2029..3799eed1 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -37,10 +37,6 @@ class BaseTester(object):
         else:
             self.model = network
 
-        # no backward setting for model
-        for param in network.parameters():
-            param.requires_grad = False
-
         # turn on the testing mode; clean up the history
         self.mode(network, test=True)
         self.eval_history.clear()
@@ -112,6 +108,7 @@ class SeqLabelTester(BaseTester):
         super(SeqLabelTester, self).__init__(test_args)
         self.max_len = None
         self.mask = None
+        self.seq_len = None
         self.batch_result = None
 
     def data_forward(self, network, inputs):
@@ -125,7 +122,7 @@ class SeqLabelTester(BaseTester):
         if torch.cuda.is_available() and self.use_cuda:
             mask = mask.cuda()
         self.mask = mask
-
+        self.seq_len = seq_len
         y = network(x)
         return y
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index d7515a40..8fcdc692 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -315,14 +315,8 @@ class ClassificationTrainer(BaseTrainer):
 
     def __init__(self, train_args):
         super(ClassificationTrainer, self).__init__(train_args)
-        if "learn_rate" in train_args:
-            self.learn_rate = train_args["learn_rate"]
-        else:
-            self.learn_rate = 1e-3
-        if "momentum" in train_args:
-            self.momentum = train_args["momentum"]
-        else:
-            self.momentum = 0.9
+        self.learn_rate = train_args["learn_rate"]
+        self.momentum = train_args["momentum"]
 
         self.iterator = None
         self.loss_func = None
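Note: with the fallback defaults gone, "learn_rate" and "momentum" become required keys. A minimal sketch of the argument dict the new constructor expects, assuming a plain dict (values mirror the [text_class] section added to test/data_for_tests/config and the inline dict removed from test/text_classify.py further below):

train_args = {
    "epochs": 1, "batch_size": 10, "pickle_path": "./data_for_tests/",
    "validate": False, "save_best_dev": False,
    "model_saved_path": "./data_for_tests/", "use_cuda": True,
    "learn_rate": 1e-3, "momentum": 0.9,   # now mandatory
}
trainer = ClassificationTrainer(train_args)  # a missing new key now raises KeyError instead of being silently defaulted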
"sequence_modeling.SeqLabeling", "saved_model.pkl"] """ FastNLP_MODEL_COLLECTION = { - "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"] + "seq_label_model": { + "url": "www.fudan.edu.cn", + "class": "sequence_modeling.SeqLabeling", + "pickle": "seq_label_model.pkl", + "type": "seq_label" + }, + "text_class_model": { + "url": "www.fudan.edu.cn", + "class": "cnn_text_classification.CNNText", + "pickle": "text_class_model.pkl", + "type": "text_class" + } } +CONFIG_FILE_NAME = "config" +SECTION_NAME = "text_class_model" + class FastNLP(object): """ High-level interface for direct model inference. - Usage: + Example Usage: fastnlp = FastNLP() fastnlp.load("zh_pos_tag_model") text = "这是最好的基于深度学习的中文分词系统。" @@ -35,6 +49,7 @@ class FastNLP(object): """ self.model_dir = model_dir self.model = None + self.infer_type = None # "seq_label"/"text_class" def load(self, model_name): """ @@ -46,21 +61,21 @@ class FastNLP(object): raise ValueError("No FastNLP model named {}.".format(model_name)) if not self.model_exist(model_dir=self.model_dir): - self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0]) + self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"]) - model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1]) + model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"]) model_args = ConfigSection() - # To do: customized config file for model init parameters - ConfigLoader.load_config(self.model_dir + "config", {"POS_infer": model_args}) + ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args}) # Construct the model model = model_class(model_args) # To do: framework independent - ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name][2]) + ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"]) self.model = model + self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"] print("Model loaded. ") @@ -71,12 +86,16 @@ class FastNLP(object): :return results: """ - infer = Inference(self.model_dir) + infer = self._create_inference(self.model_dir) + + # string ---> 2-D list of string infer_input = self.string_to_list(raw_input) + # 2-D list of string ---> list of strings results = infer.predict(self.model, infer_input) - outputs = self.make_output(results) + # list of strings ---> final answers + outputs = self._make_output(results, infer_input) return outputs @staticmethod @@ -95,6 +114,14 @@ class FastNLP(object): module = getattr(module, sub) return module + def _create_inference(self, model_dir): + if self.infer_type == "seq_label": + return SeqLabelInfer(model_dir) + elif self.infer_type == "text_class": + return ClassificationInfer(model_dir) + else: + raise ValueError("fail to create inference instance") + def _load(self, model_dir, model_name): # To do return 0 @@ -117,7 +144,6 @@ class FastNLP(object): def string_to_list(self, text, delimiter="\n"): """ - For word seg only, currently. This function is used to transform raw input to lists, which is done by DatasetLoader in training. Split text string into three-level lists. [ @@ -127,7 +153,7 @@ class FastNLP(object): ] :param text: string :param delimiter: str, character used to split text into sentences. 
@@ -117,7 +144,6 @@
     def string_to_list(self, text, delimiter="\n"):
         """
-        For word seg only, currently.
         This function is used to transform raw input to lists, which is done by DatasetLoader in training.
-        Split text string into three-level lists.
+        Split a text string into two-level lists.
         [
             [word_11, word_12, ...],
             [word_21, word_22, ...],
             ...
         ]
         :param text: string
         :param delimiter: str, character used to split text into sentences.
-        :return data: three-level lists
+        :return data: two-level lists
         """
         data = []
         sents = text.strip().split(delimiter)
@@ -136,38 +162,61 @@
         for ch in sent:
             characters.append(ch)
         data.append(characters)
-        # To refactor: this is used in make_output
-        self.data = data
         return data
 
-    def make_output(self, results):
-        """
-        Transform model output into user-friendly contents.
-        Example: In CWS, convert labeling into segmented text.
-        :param results:
-        :return:
-        """
-        outputs = []
-        for sent_char, sent_label in zip(self.data, results):
-            words = []
-            word = ""
-            for char, label in zip(sent_char, sent_label):
-                if label[0] == "B":
-                    if word != "":
-                        words.append(word)
-                    word = char
-                elif label[0] == "M":
-                    word += char
-                elif label[0] == "E":
-                    word += char
-                    words.append(word)
-                    word = ""
-                elif label[0] == "S":
-                    if word != "":
-                        words.append(word)
-                        word = ""
-                    words.append(char)
-                else:
-                    raise ValueError("invalid label")
-            outputs.append(" ".join(words))
+    def _make_output(self, results, infer_input):
+        if self.infer_type == "seq_label":
+            outputs = make_seq_label_output(results, infer_input)
+        elif self.infer_type == "text_class":
+            outputs = make_class_output(results, infer_input)
+        else:
+            raise ValueError("cannot make outputs for unknown infer_type {}".format(self.infer_type))
         return outputs
+
+
+def make_seq_label_output(result, infer_input):
+    """
+    Transform model output into user-friendly contents.
+    :param result: 1-D list of strings (model output)
+    :param infer_input: 2-D list of strings (model input)
+    :return outputs: 1-D list of strings
+    """
+    return result
+
+
+def make_class_output(result, infer_input):
+    return result
+
+
+def interpret_word_seg_results(infer_input, results):
+    """
+    Transform model output into user-friendly contents.
+    Example: In CWS, convert labeling into segmented text.
+    :param infer_input: 2-D list of strings (model input)
+    :param results: list of strings (model output)
+    :return outputs: list of strings
+    """
+    outputs = []
+    for sent_char, sent_label in zip(infer_input, results):
+        words = []
+        word = ""
+        for char, label in zip(sent_char, sent_label):
+            if label[0] == "B":
+                if word != "":
+                    words.append(word)
+                word = char
+            elif label[0] == "M":
+                word += char
+            elif label[0] == "E":
+                word += char
+                words.append(word)
+                word = ""
+            elif label[0] == "S":
+                if word != "":
+                    words.append(word)
+                    word = ""
+                words.append(char)
+            else:
+                raise ValueError("invalid label")
+        outputs.append(" ".join(words))
+    return outputs
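Note: interpret_word_seg_results turns BMES tagging back into segmented text. A worked example with hypothetical characters and tags (only the leading B/M/E/S of each tag is inspected):

chars  = [["上", "海", "自", "来", "水"]]
labels = [["B-ns", "E-ns", "B-n", "M-n", "E-n"]]   # hypothetical BMES-prefixed tags
interpret_word_seg_results(chars, labels)           # -> ["上海 自来水"]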
""" - def __init__(self, class_num=9, - kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5], - embed_num=1000, embed_dim=300, pretrained_embed=None, - drop_prob=0.5): + def __init__(self, args): super(CNNText, self).__init__() + class_num = args["num_classes"] + kernel_nums = [100, 100, 100] + kernel_sizes = [3, 4, 5] + embed_num = args["vocab_size"] + embed_dim = 300 + pretrained_embed = None + drop_prob = 0.5 + # no support for pre-trained embedding currently self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0) self.conv_pool = ConvMaxpool( diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index b28ef604..5addc73e 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -56,3 +56,49 @@ class SeqLabeling(BaseModel): """ tag_seq = self.Crf.viterbi_decode(x, mask) return tag_seq + + +class AdvSeqLabel(SeqLabeling): + """ + Advanced Sequence Labeling Model + """ + + def __init__(self, args, emb=None): + super(AdvSeqLabel, self).__init__(args) + + vocab_size = args["vocab_size"] + word_emb_dim = args["word_emb_dim"] + hidden_dim = args["rnn_hidden_units"] + num_classes = args["num_classes"] + + self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) + self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True) + self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) + self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) + self.relu = torch.nn.ReLU() + self.drop = torch.nn.Dropout(0.3) + self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) + + self.Crf = decoder.CRF.ConditionalRandomField(num_classes) + + def forward(self, x): + """ + :param x: LongTensor, [batch_size, mex_len] + :return y: [batch_size, mex_len, tag_size] + """ + batch_size = x.size(0) + max_len = x.size(1) + x = self.Embedding(x) + # [batch_size, max_len, word_emb_dim] + x = self.Rnn(x) + # [batch_size, max_len, hidden_size * direction] + x = x.contiguous() + x = x.view(batch_size * max_len, -1) + x = self.Linear1(x) + x = self.batch_norm(x) + x = self.relu(x) + x = self.drop(x) + x = self.Linear2(x) + x = x.view(batch_size, max_len, -1) + # [batch_size, max_len, num_classes] + return x diff --git a/test/data_for_tests/config b/test/data_for_tests/config index a88ecd5d..60a7c9a5 100644 --- a/test/data_for_tests/config +++ b/test/data_for_tests/config @@ -89,5 +89,20 @@ rnn_hidden_units = 100 rnn_layers = 1 rnn_bi_direction = true word_emb_dim = 100 -vocab_size = 52 -num_classes = 22 \ No newline at end of file +vocab_size = 53 +num_classes = 27 + +[text_class] +epochs = 1 +batch_size = 10 +pickle_path = "./data_for_tests/" +validate = false +save_best_dev = false +model_saved_path = "./data_for_tests/" +use_cuda = true +learn_rate = 1e-3 +momentum = 0.9 + +[text_class_model] +vocab_size = 867 +num_classes = 18 \ No newline at end of file diff --git a/test/data_for_tests/people.txt b/test/data_for_tests/people.txt index e4909679..9ef0de6d 100644 --- a/test/data_for_tests/people.txt +++ b/test/data_for_tests/people.txt @@ -123,6 +123,160 @@ 张 S-q ) S-w +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w + +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w + +中 
B-nt +共 M-nt +中 M-nt +央 E-nt +总 B-n +书 M-n +记 E-n +、 S-w +国 B-n +家 E-n +主 B-n +席 E-n +江 B-nr +泽 M-nr +民 E-nr + +( S-w +一 B-t +九 M-t +九 M-t +七 M-t +年 E-t +十 B-t +二 M-t +月 E-t +三 B-t +十 M-t +一 M-t +日 E-t +) S-w + +1 B-t +2 M-t +月 E-t +3 B-t +1 M-t +日 E-t +, S-w +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w + +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w + 迈 B-v 向 E-v 充 B-v diff --git a/test/ner.py b/test/ner.py new file mode 100644 index 00000000..beaac1d6 --- /dev/null +++ b/test/ner.py @@ -0,0 +1,137 @@ +import _pickle +import os + +import numpy as np +import torch + +from fastNLP.core.tester import SeqLabelTester +from fastNLP.core.trainer import SeqLabelTrainer +from fastNLP.loader.preprocess import POSPreprocess +from fastNLP.models.sequence_modeling import AdvSeqLabel + + +class MyNERTrainer(SeqLabelTrainer): + def __init__(self, train_args): + super(MyNERTrainer, self).__init__(train_args) + self.scheduler = None + + def define_optimizer(self): + """ + override + :return: + """ + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001) + self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3000, gamma=0.5) + + def update(self): + """ + override + :return: + """ + self.optimizer.step() + self.scheduler.step() + + def _create_validator(self, valid_args): + return MyNERTester(valid_args) + + def best_eval_result(self, validator): + accuracy = validator.metrics() + if accuracy > self.best_accuracy: + self.best_accuracy = accuracy + return True + else: + return False + + +class MyNERTester(SeqLabelTester): + def __init__(self, test_args): + super(MyNERTester, self).__init__(test_args) + + def _evaluate(self, prediction, batch_y, seq_len): + """ + :param prediction: [batch_size, seq_len, num_classes] + :param batch_y: [batch_size, seq_len] + :param seq_len: [batch_size] + :return: + """ + summ = 0 + correct = 0 + _, indices = torch.max(prediction, 2) + for p, y, l in zip(indices, batch_y, seq_len): + summ += l + correct += np.sum(p[:l].cpu().numpy() == y[:l].cpu().numpy()) + return float(correct / summ) + + def evaluate(self, predict, truth): + return self._evaluate(predict, truth, self.seq_len) + + def metrics(self): + return np.mean(self.eval_history) + + def show_matrices(self): + return "dev accuracy={:.2f}".format(float(self.metrics())) + + +def embedding_process(emb_file, word_dict, emb_dim, emb_pkl): + if os.path.exists(emb_pkl): + with open(emb_pkl, "rb") as f: + embedding_np = _pickle.load(f) + return embedding_np + with open(emb_file, "r", encoding="utf-8") as f: + embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim)) + for line in f: + line = line.strip().split() + if len(line) != emb_dim + 1: + continue + if line[0] in word_dict: + embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]] + with open(emb_pkl, "wb") as f: + _pickle.dump(embedding_np, f) + return embedding_np + + +def data_load(data_file): + with open(data_file, "r", encoding="utf-8") as f: + all_data = [] + sent = [] + label = [] + for line in f: + line = line.strip().split() + + if not len(line) <= 1: + sent.append(line[0]) + label.append(line[1]) + else: + all_data.append([sent, label]) + sent = [] + label = [] + return all_data + + +data_path = 
"data_for_tests/people.txt" +pick_path = "data_for_tests/" +emb_path = "data_for_tests/emb50.txt" +save_path = "data_for_tests/" +if __name__ == "__main__": + data = data_load(data_path) + p = POSPreprocess(data, pickle_path=pick_path, train_dev_split=0.3) + # emb = embedding_process(emb_path, p.word2index, 50, os.path.join(pick_path, "embedding.pkl")) + emb = None + args = {"epochs": 20, + "batch_size": 1, + "pickle_path": pick_path, + "validate": True, + "save_best_dev": True, + "model_saved_path": save_path, + "use_cuda": True, + + "vocab_size": p.vocab_size, + "num_classes": p.num_classes, + "word_emb_dim": 50, + "rnn_hidden_units": 100 + } + # emb = torch.Tensor(emb).float().cuda() + networks = AdvSeqLabel(args, emb) + trainer = MyNERTrainer(args) + trainer.train(network=networks) + print("Training finished!") diff --git a/test/ner_decode.py b/test/ner_decode.py new file mode 100644 index 00000000..a319a20e --- /dev/null +++ b/test/ner_decode.py @@ -0,0 +1,129 @@ +import _pickle +import os + +import torch + +from fastNLP.core.inference import SeqLabelInfer +from fastNLP.core.trainer import SeqLabelTrainer +from fastNLP.loader.model_loader import ModelLoader +from fastNLP.models.sequence_modeling import AdvSeqLabel + + +class Decode(SeqLabelTrainer): + def __init__(self, args): + super(Decode, self).__init__(args) + + def decoder(self, network, sents, model_path): + self.model = network + self.model.load_state_dict(torch.load(model_path)) + out_put = [] + self.mode(network, test=True) + for batch_x in sents: + prediction = self.data_forward(self.model, batch_x) + + seq_tag = self.model.prediction(prediction, batch_x[1]) + + out_put.append(list(seq_tag)[0]) + return out_put + + +def process_sent(sents, word2id): + sents_num = [] + for s in sents: + sent_num = [] + for c in s: + if c in word2id: + sent_num.append(word2id[c]) + else: + sent_num.append(word2id[""]) + sents_num.append(([sent_num], [len(sent_num)])) # batch_size is 1 + + return sents_num + + +def process_tag(sents, tags, id2class): + Tags = [] + for ttt in tags: + Tags.append([id2class[t] for t in ttt]) + + Segs = [] + PosNers = [] + for sent, tag in zip(sents, tags): + word__ = [] + lll__ = [] + for c, t in zip(sent, tag): + + t = id2class[t] + l = t.split("-") + split_ = l[0] + pn = l[1] + + if split_ == "S": + word__.append(c) + lll__.append(pn) + word_1 = "" + elif split_ == "E": + word_1 += c + word__.append(word_1) + lll__.append(pn) + word_1 = "" + elif split_ == "B": + word_1 = "" + word_1 += c + else: + word_1 += c + Segs.append(word__) + PosNers.append(lll__) + return Segs, PosNers + + +pickle_path = "data_for_tests/" +model_path = "data_for_tests/model_best_dev.pkl" +if __name__ == "__main__": + + with open(os.path.join(pickle_path, "id2word.pkl"), "rb") as f: + id2word = _pickle.load(f) + with open(os.path.join(pickle_path, "word2id.pkl"), "rb") as f: + word2id = _pickle.load(f) + with open(os.path.join(pickle_path, "id2class.pkl"), "rb") as f: + id2class = _pickle.load(f) + + sent = ["中共中央总书记、国家主席江泽民", + "逆向处理输入序列并返回逆序后的序列"] # here is input + + args = {"epochs": 1, + "batch_size": 1, + "pickle_path": "data_for_tests/", + "validate": True, + "save_best_dev": True, + "model_saved_path": "data_for_tests/", + "use_cuda": False, + + "vocab_size": len(word2id), + "num_classes": len(id2class), + "word_emb_dim": 50, + "rnn_hidden_units": 100, + } + """ + network = AdvSeqLabel(args, None) + decoder_ = Decode(args) + tags_num = decoder_.decoder(network, process_sent(sent, word2id), model_path=model_path) + output_seg, 
output_pn = process_tag(sent, tags_num, id2class) # here is output + print(output_seg) + print(output_pn) + """ + # Define the same model + model = AdvSeqLabel(args, None) + + # Dump trained parameters into the model + ModelLoader.load_pytorch(model, "./data_for_tests/model_best_dev.pkl") + print("model loaded!") + + # Inference interface + infer = SeqLabelInfer(pickle_path) + sent = [[ch for ch in s] for s in sent] + results = infer.predict(model, sent) + + for res in results: + print(res) + print("Inference finished!") diff --git a/test/seq_labeling.py b/test/seq_labeling.py index db171215..adc686df 100644 --- a/test/seq_labeling.py +++ b/test/seq_labeling.py @@ -112,5 +112,5 @@ def train_and_test(): if __name__ == "__main__": - # train_and_test() - infer() + train_and_test() + # infer() diff --git a/test/test_fastNLP.py b/test/test_fastNLP.py index 35bac153..0b64b261 100644 --- a/test/test_fastNLP.py +++ b/test/test_fastNLP.py @@ -1,9 +1,18 @@ from fastNLP.fastnlp import FastNLP -def foo(): +def word_seg(): nlp = FastNLP("./data_for_tests/") - nlp.load("zh_pos_tag_model") + nlp.load("seq_label_model") + text = "这是最好的基于深度学习的中文分词系统。" + result = nlp.run(text) + print(result) + print("FastNLP finished!") + + +def text_class(): + nlp = FastNLP("./data_for_tests/") + nlp.load("text_class_model") text = "这是最好的基于深度学习的中文分词系统。" result = nlp.run(text) print(result) @@ -11,4 +20,4 @@ def foo(): if __name__ == "__main__": - foo() + text_class() diff --git a/test/text_classify.py b/test/text_classify.py index f18d4a38..7400b1da 100644 --- a/test/text_classify.py +++ b/test/text_classify.py @@ -5,6 +5,7 @@ import os from fastNLP.core.inference import ClassificationInfer from fastNLP.core.trainer import ClassificationTrainer +from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.loader.dataset_loader import ClassDatasetLoader from fastNLP.loader.model_loader import ModelLoader from fastNLP.loader.preprocess import ClassPreprocess @@ -29,9 +30,13 @@ def infer(): print("vocabulary size:", vocab_size) print("number of classes:", n_classes) + model_args = ConfigSection() + ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args}) + # construct model print("Building model...") - cnn = CNNText(class_num=n_classes, embed_num=vocab_size) + cnn = CNNText(model_args) + # Dump trained parameters into the model ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl") print("model loaded!") @@ -42,6 +47,9 @@ def infer(): def train(): + train_args, model_args = ConfigSection(), ConfigSection() + ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args}) + # load dataset print("Loading data...") ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) @@ -56,19 +64,11 @@ def train(): # construct model print("Building model...") - cnn = CNNText(class_num=n_classes, embed_num=vocab_size) + cnn = CNNText(model_args) # train print("Training...") - train_args = { - "epochs": 1, - "batch_size": 10, - "pickle_path": data_dir, - "validate": False, - "save_best_dev": False, - "model_saved_path": "./data_for_tests/", - "use_cuda": True - } + trainer = ClassificationTrainer(train_args) trainer.train(cnn)
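Note: a quick shape check for the AdvSeqLabel forward pass added in fastNLP/models/sequence_modeling.py. This is a sketch only, assuming the encoder/decoder wrappers return plain tensors the way the model code uses them; the hyperparameter values are illustrative:

import torch
from fastNLP.models.sequence_modeling import AdvSeqLabel

args = {"vocab_size": 100, "word_emb_dim": 50, "rnn_hidden_units": 100, "num_classes": 27}
model = AdvSeqLabel(args)                   # emb=None -> randomly initialized embeddings
x = torch.randint(0, 100, (4, 7)).long()    # token ids, [batch_size=4, max_len=7]
y = model(x)
print(y.size())                             # expected: torch.Size([4, 7, 27])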