Chinese word segmentation interface (tags/v0.1.0)
@@ -19,13 +19,13 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
 def save_pickle(obj, pickle_path, file_name):
     with open(os.path.join(pickle_path, file_name), "wb") as f:
         _pickle.dump(obj, f)
-    print("{} saved. ".format(file_name))
+    print("{} saved in {}".format(file_name, pickle_path))


 def load_pickle(pickle_path, file_name):
     with open(os.path.join(pickle_path, file_name), "rb") as f:
         obj = _pickle.load(f)
-    print("{} loaded. ".format(file_name))
+    print("{} loaded from {}".format(file_name, pickle_path))
     return obj
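For reference, the round trip these helpers provide looks like this (a minimal sketch; the `./save/` directory and dictionary contents are illustrative):

```python
import os

# Sketch: uses save_pickle/load_pickle as defined above (fastNLP.core.preprocess).
word2index = {"<pad>": 0, "<unk>": 1}
os.makedirs("./save/", exist_ok=True)                # save_pickle expects the directory to exist
save_pickle(word2index, "./save/", "word2id.pkl")    # prints: word2id.pkl saved in ./save/
restored = load_pickle("./save/", "word2id.pkl")     # prints: word2id.pkl loaded from ./save/
assert restored == word2index
```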
@@ -98,7 +98,7 @@ class BaseTester(object):
             print_output = "[test step {}] {}".format(step, eval_results)
             logger.info(print_output)
-            if step % self.print_every_step == 0:
+            if self.print_every_step > 0 and step % self.print_every_step == 0:
                 print(print_output)
             step += 1
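The added guard makes `print_every_step=0` mean "never print to the console" instead of crashing, since `step % 0` raises `ZeroDivisionError`. The Trainer relies on this when it builds its validator with `"print_every_step": 0` (see the trainer hunks below). The semantics in isolation:

```python
def should_print(step: int, print_every_step: int) -> bool:
    # 0 disables console output entirely; without the first clause,
    # `step % 0` would raise ZeroDivisionError.
    return print_every_step > 0 and step % print_every_step == 0

assert should_print(100, 10)
assert not should_print(100, 0)
```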
@@ -187,7 +187,7 @@ class SeqLabelTester(BaseTester):
         # make sure "results" is in the same device as "truth"
         results = results.to(truth)
         accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
-        return [loss.data, accuracy.data]
+        return [float(loss), float(accuracy)]

     def metrics(self):
         batch_loss = np.mean([x[0] for x in self.eval_history])
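Returning plain Python floats rather than `.data` tensors detaches the values from the autograd graph and moves them off the GPU, so the `np.mean` aggregation over `eval_history` operates on ordinary numbers. A quick sketch of the difference:

```python
import numpy as np
import torch

loss = torch.tensor(0.25, requires_grad=True)
eval_history = [[float(loss), 0.90], [0.15, 0.80]]   # plain floats: safe for numpy
print(np.mean([x[0] for x in eval_history]))          # 0.2
# Averaging a list of GPU tensors instead would force a device copy per
# element (or fail outright, depending on the PyTorch version).
```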
@@ -4,7 +4,6 @@ import os
 import time
 from datetime import timedelta

-import numpy as np
 import torch

 from fastNLP.core.action import Action
| @@ -47,7 +46,7 @@ class BaseTrainer(object): | |||
| Otherwise, error will raise. | |||
| """ | |||
| default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/", | |||
| "save_best_dev": True, "model_name": "default_model_name.pkl", | |||
| "save_best_dev": True, "model_name": "default_model_name.pkl", "print_every_step": 1, | |||
| "loss": Loss(None), | |||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0) | |||
| } | |||
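With `print_every_step` in `default_args`, progress printing is tunable per instance. A hypothetical construction, assuming keyword arguments override these defaults as the surrounding code suggests:

```python
# Hypothetical call: argument names mirror default_args above.
trainer = SeqLabelTrainer(epochs=10, batch_size=32, use_cuda=False,
                          pickle_path="./save/",
                          print_every_step=50)   # report every 50 steps; 0 silences the console
```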
@@ -86,6 +85,7 @@ class BaseTrainer(object):
         self.save_best_dev = default_args["save_best_dev"]
         self.use_cuda = default_args["use_cuda"]
         self.model_name = default_args["model_name"]
+        self.print_every_step = default_args["print_every_step"]

         self._model = None
         self._loss_func = default_args["loss"].get()  # return a pytorch loss function or None
@@ -93,48 +93,35 @@ class BaseTrainer(object):
         self._optimizer_proto = default_args["optimizer"]

     def train(self, network, train_data, dev_data=None):
-        """General Training Steps
+        """General Training Procedure

         :param network: a model
         :param train_data: three-level list, the training set.
         :param dev_data: three-level list, the validation data (optional)
-
-        The method is framework independent.
-        Work by calling the following methods:
-            - prepare_input
-            - mode
-            - define_optimizer
-            - data_forward
-            - get_loss
-            - grad_backward
-            - update
-        Subclasses must implement these methods with a specific framework.
         """
-        # prepare model and data, transfer model to gpu if available
+        # transfer model to gpu if available
         if torch.cuda.is_available() and self.use_cuda:
             self._model = network.cuda()
+            # self._model is used to access model-specific loss
         else:
             self._model = network

-        # define tester over dev data
+        # define Tester over dev data
         if self.validate:
             default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                                   "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
-                                  "use_cuda": self.use_cuda}
+                                  "use_cuda": self.use_cuda, "print_every_step": 0}
             validator = self._create_validator(default_valid_args)
             logger.info("validator defined as {}".format(str(validator)))

         # optimizer and loss
         self.define_optimizer()
         logger.info("optimizer defined as {}".format(str(self._optimizer)))
         self.define_loss()
         logger.info("loss function defined as {}".format(str(self._loss_func)))

-        # main training epochs
-        n_samples = len(train_data)
-        n_batches = n_samples // self.batch_size
-        n_print = 1
+        # main training procedure
         start = time.time()
         logger.info("training epochs started")

         for epoch in range(1, self.n_epochs + 1):
             logger.info("training epoch {}".format(epoch))
@@ -144,23 +131,30 @@ class BaseTrainer(object):
             data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False))
             logger.info("prepared data iterator")

-            self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)
+            # one forward and backward pass
+            self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch)

             # validation
             if self.validate:
                 logger.info("validation started")
                 validator.test(network, dev_data)

                 if self.save_best_dev and self.best_eval_result(validator):
                     self.save_model(network, self.model_name)
-                    print("saved better model selected by dev")
-                    logger.info("saved better model selected by dev")
+                    print("Saved better model selected by validation.")
+                    logger.info("Saved better model selected by validation.")

                 valid_results = validator.show_matrices()
                 print("[epoch {}] {}".format(epoch, valid_results))
                 logger.info("[epoch {}] {}".format(epoch, valid_results))

     def _train_step(self, data_iterator, network, **kwargs):
-        """Training process in one epoch."""
+        """Training process in one epoch.
+
+        kwargs should contain:
+            - n_print: int, print training information every n steps.
+            - start: time.time(), the time when training started.
+            - epoch: int, the current epoch.
+        """
         step = 0
         for batch_x, batch_y in self.make_batch(data_iterator):
@@ -170,7 +164,7 @@ class BaseTrainer(object):
             self.grad_backward(loss)
             self.update()

-            if step % kwargs["n_print"] == 0:
+            if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
                 end = time.time()
                 diff = timedelta(seconds=round(end - kwargs["start"]))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
@@ -287,10 +281,11 @@ class BaseTrainer(object):
         raise NotImplementedError

     def save_model(self, network, model_name):
-        """
+        """Save the given model under the given name.
+        This method may be called multiple times by the Trainer to overwrite the saved model with a better one.

         :param network: the PyTorch model
         :param model_name: str
+            model_best_dev.pkl may be overwritten by a better model in future epochs.
         """
         if model_name[-4:] != ".pkl":
             model_name += ".pkl"
@@ -300,33 +295,9 @@ class BaseTrainer(object):
         raise NotImplementedError


-class ToyTrainer(BaseTrainer):
-    """
-    An example to show the definition of Trainer.
-    """
-
-    def __init__(self, training_args):
-        super(ToyTrainer, self).__init__(training_args)
-
-    def load_train_data(self, data_path):
-        data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
-        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
-        return data_train, data_dev, 0, 1
-
-    def data_forward(self, network, x):
-        return network(x)
-
-    def grad_backward(self, loss):
-        self._model.zero_grad()
-        loss.backward()
-
-    def get_loss(self, pred, truth):
-        return np.mean(np.square(pred - truth))
-
-
 class SeqLabelTrainer(BaseTrainer):
     """
-    Trainer for Sequence Modeling
+    Trainer for Sequence Labeling
     """
@@ -384,7 +355,7 @@ class SeqLabelTrainer(BaseTrainer):

 class ClassificationTrainer(BaseTrainer):
-    """Trainer for classification."""
+    """Trainer for text classification."""

     def __init__(self, **train_args):
         super(ClassificationTrainer, self).__init__(**train_args)
@@ -1,4 +1,5 @@
 from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
+from fastNLP.core.preprocess import load_pickle
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
@@ -7,14 +8,13 @@ mapping from model name to [URL, file_name.class_name, model_pickle_name]
 Notice that the class of the model should be in "models" directory.

 Example:
-    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
-"""
-FastNLP_MODEL_COLLECTION = {
     "seq_label_model": {
         "url": "www.fudan.edu.cn",
-        "class": "sequence_modeling.SeqLabeling",
+        "class": "sequence_modeling.SeqLabeling",  # file_name.class_name in models/
         "pickle": "seq_label_model.pkl",
-        "type": "seq_label"
+        "type": "seq_label",
+        "config_file_name": "config",  # the name of the config file which stores model initialization parameters
+        "config_section_name": "text_class_model"  # the name of the section in the config file which stores model init params
     },
     "text_class_model": {
         "url": "www.fudan.edu.cn",
@@ -22,11 +22,18 @@ FastNLP_MODEL_COLLECTION = {
         "pickle": "text_class_model.pkl",
         "type": "text_class"
     }
+"""
+
+FastNLP_MODEL_COLLECTION = {
+    "cws_basic_model": {
+        "url": "",
+        "class": "sequence_modeling.AdvSeqLabel",
+        "pickle": "cws_basic_model_v_0.pkl",
+        "type": "seq_label",
+        "config_file_name": "config",
+        "config_section_name": "text_class_model"
+    }
 }
-CONFIG_FILE_NAME = "config"
-SECTION_NAME = "text_class_model"


 class FastNLP(object):
     """
@@ -51,10 +58,13 @@ class FastNLP(object):
         self.model = None
         self.infer_type = None  # "seq_label"/"text_class"

-    def load(self, model_name):
+    def load(self, model_name, config_file="config", section_name="model"):
         """
         Load a pre-trained FastNLP model together with additional data.

         :param model_name: str, the name of a FastNLP model.
+        :param config_file: str, the name of the config file which stores the initialization information of the model.
+            (default: "config")
+        :param section_name: str, the name of the corresponding section in the config file. (default: model)
         """
         assert type(model_name) is str
         if model_name not in FastNLP_MODEL_COLLECTION:
@@ -64,37 +74,47 @@ class FastNLP(object):
             self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])

         model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
         print("Restore model class {}".format(str(model_class)))

         model_args = ConfigSection()
-        ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
+        ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args})
         print("Restore model hyper-parameters {}".format(str(model_args.data)))

+        # fetch dictionary size and number of labels from pickle files
+        word2index = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(word2index)
+        index2label = load_pickle(self.model_dir, "id2class.pkl")
+        model_args["num_classes"] = len(index2label)
+
         # Construct the model
         model = model_class(model_args)
+        print("Model constructed.")

         # To do: framework independent
         ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"])
+        print("Model weights loaded.")

         self.model = model
         self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]

-        print("Model loaded. ")
+        print("Inference ready.")

     def run(self, raw_input):
         """
         Perform inference over given input using the loaded model.

-        :param raw_input: str, raw text
+        :param raw_input: list of strings. Each string is an input query.
         :return results:
         """

         infer = self._create_inference(self.model_dir)

-        # string ---> 2-D list of string
-        infer_input = self.string_to_list(raw_input)
+        # tokenize: list of string ---> 2-D list of string
+        infer_input = self.tokenize(raw_input, language="zh")

-        # 2-D list of string ---> list of strings
+        # 2-D list of string ---> 2-D list of tags
         results = infer.predict(self.model, infer_input)

-        # list of strings ---> final answers
+        # 2-D list of tags ---> list of final answers
         outputs = self._make_output(results, infer_input)
         return outputs
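End to end, the updated `load()`/`run()` flow looks roughly like this (paths and section names follow the updated test script at the bottom of this diff):

```python
nlp = FastNLP(model_dir="./save/")   # directory holding the config file and pickles
nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")
results = nlp.run(["大王叫我来巡山。"])
# For a "seq_label" model: [[('大', 'B'), ('王', 'E'), ('叫', 'S'), ...]]
```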
@@ -142,81 +162,100 @@ class FastNLP(object):
         """
         return True

-    def string_to_list(self, text, delimiter="\n"):
-        """
-        This function is used to transform raw input to lists, which is done by DatasetLoader in training.
-        Split text string into three-level lists.
-        [
-            [word_11, word_12, ...],
-            [word_21, word_22, ...],
-            ...
-        ]
-        :param text: string
-        :param delimiter: str, character used to split text into sentences.
-        :return data: two-level lists
+    def tokenize(self, text, language):
+        """Extract tokens from strings.
+        For English, extract words separated by space.
+        For Chinese, extract characters.
+        TODO: more complex tokenization methods
+
+        :param text: list of string
+        :param language: str, one of ('zh', 'en'), Chinese or English.
+        :return data: list of list of string, each string is a token.
         """
+        assert language in ("zh", "en")
         data = []
-        sents = text.strip().split(delimiter)
-        for sent in sents:
-            characters = []
-            for ch in sent:
-                characters.append(ch)
-            data.append(characters)
+        for sent in text:
+            if language == "en":
+                tokens = sent.strip().split()
+            elif language == "zh":
+                tokens = [char for char in sent]
+            else:
+                raise RuntimeError("Unknown language {}".format(language))
+            data.append(tokens)
         return data
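The two branches in action:

```python
nlp.tokenize(["大王叫我来巡山。"], language="zh")
# -> [['大', '王', '叫', '我', '来', '巡', '山', '。']]
nlp.tokenize(["deep learning"], language="en")
# -> [['deep', 'learning']]
```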
     def _make_output(self, results, infer_input):
+        """Transform the infer output into user-friendly output.
+
+        :param results: 1 or 2-D list of strings.
+            If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length].
+            If self.infer_type == "text_class", it is of shape [num_examples].
+        :param infer_input: 2-D list of string, the input query before inference.
+        :return outputs: list. Each entry is a prediction.
+        """
         if self.infer_type == "seq_label":
             outputs = make_seq_label_output(results, infer_input)
         elif self.infer_type == "text_class":
             outputs = make_class_output(results, infer_input)
         else:
-            raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
+            raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
         return outputs
 def make_seq_label_output(result, infer_input):
-    """
-    Transform model output into user-friendly contents.
-    :param result: 1-D list of strings. (model output)
+    """Transform model output into user-friendly contents.
+
+    :param result: 2-D list of strings. (model output)
     :param infer_input: 2-D list of string (model input)
-    :return outputs:
+    :return ret: list of list of tuples
+        [
+            [(word_11, label_11), (word_12, label_12), ...],
+            [(word_21, label_21), (word_22, label_22), ...],
+            ...
+        ]
     """
-    return result
+    ret = []
+    for example_x, example_y in zip(infer_input, result):
+        ret.append([(x, y) for x, y in zip(example_x, example_y)])
+    return ret
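For a single sequence-labeling query, the pairing works out to:

```python
make_seq_label_output([["B", "E", "S"]], [["大", "王", "叫"]])
# -> [[('大', 'B'), ('王', 'E'), ('叫', 'S')]]
```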
 def make_class_output(result, infer_input):
+    """Transform model output into user-friendly contents.
+
+    :param result: 2-D list of strings. (model output)
+    :param infer_input: 1-D list of string (model input)
+    :return ret: the same as result, [label_1, label_2, ...]
+    """
     return result
-def interpret_word_seg_results(infer_input, results):
-    """
-    Transform model output into user-friendly contents.
+def interpret_word_seg_results(char_seq, label_seq):
+    """Transform model output into user-friendly contents.

     Example: In CWS, convert <BMES> labeling into segmented text.
-    :param results: list of strings. (model output)
-    :param infer_input: 2-D list of string (model input)
-    :return output: list of strings
+    :param char_seq: list of string,
+    :param label_seq: list of string, the same length as char_seq.
+        Each entry is one of ('B', 'M', 'E', 'S').
+    :return output: list of words
     """
-    outputs = []
-    for sent_char, sent_label in zip(infer_input, results):
-        words = []
-        word = ""
-        for char, label in zip(sent_char, sent_label):
-            if label[0] == "B":
-                if word != "":
-                    words.append(word)
-                word = char
-            elif label[0] == "M":
-                word += char
-            elif label[0] == "E":
-                word += char
-                words.append(word)
-                word = ""
-            elif label[0] == "S":
-                if word != "":
-                    words.append(word)
-                    word = ""
-                words.append(char)
-            else:
-                raise ValueError("invalid label")
-        outputs.append(" ".join(words))
-    return outputs
+    words = []
+    word = ""
+    for char, label in zip(char_seq, label_seq):
+        if label[0] == "B":
+            if word != "":
+                words.append(word)
+            word = char
+        elif label[0] == "M":
+            word += char
+        elif label[0] == "E":
+            word += char
+            words.append(word)
+            word = ""
+        elif label[0] == "S":
+            if word != "":
+                words.append(word)
+                word = ""
+            words.append(char)
+        else:
+            raise ValueError("invalid label {}".format(label[0]))
+    return words
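A worked example of the BMES decoding (one word per B...E span, single-character words tagged S):

```python
chars  = ["中", "文", "分", "词", "好"]
labels = ["B", "E", "B", "E", "S"]
interpret_word_seg_results(chars, labels)
# -> ['中文', '分词', '好']
```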
@@ -1,26 +1,26 @@
-import sys, os
+import os
+import sys

 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
-from fastNLP.loader.preprocess import POSPreprocess, load_pickle
+from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
 from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import AdvSeqLabel
-from fastNLP.core.inference import SeqLabelInfer
-from fastNLP.core.optimizer import SGD
+from fastNLP.core.predictor import SeqLabelInfer

 # not in the file's dir
 if len(os.path.dirname(__file__)) != 0:
     os.chdir(os.path.dirname(__file__))

-datadir = 'icwb2-data'
-cfgfile = 'cws.cfg'
+datadir = "/home/zyfeng/data/"
+cfgfile = './cws.cfg'
 data_name = "pku_training.utf8"

-cws_data_path = os.path.join(datadir, "training/pku_training.utf8")
+cws_data_path = os.path.join(datadir, "pku_training.utf8")
 pickle_path = "save"
 data_infer_path = os.path.join(datadir, "infer.utf8")
@@ -70,12 +70,13 @@ def train():
     train_data = loader.load_pku()

     # Preprocessor
-    p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3)
-    train_args["vocab_size"] = p.vocab_size
-    train_args["num_classes"] = p.num_classes
+    preprocessor = SeqLabelPreprocess()
+    data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3)
+    train_args["vocab_size"] = preprocessor.vocab_size
+    train_args["num_classes"] = preprocessor.num_classes

     # Trainer
-    trainer = SeqLabelTrainer(train_args)
+    trainer = SeqLabelTrainer(**train_args.data)

     # Model
     model = AdvSeqLabel(train_args)
@@ -83,10 +84,11 @@ def train():
         ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
         print('model parameter loaded!')
     except Exception as e:
+        print("No saved model. Continue.")
         pass

     # Start training
-    trainer.train(model)
+    trainer.train(model, data_train, data_dev)
     print("Training finished!")

     # Saver
@@ -106,6 +108,9 @@ def test():
     index2label = load_pickle(pickle_path, "id2class.pkl")
     test_args["num_classes"] = len(index2label)

+    # load dev data
+    dev_data = load_pickle(pickle_path, "data_dev.pkl")
+
     # Define the same model
     model = AdvSeqLabel(test_args)
@@ -114,10 +119,10 @@ def test():
     print("model loaded!")

     # Tester
-    tester = SeqLabelTester(test_args)
+    tester = SeqLabelTester(**test_args.data)

     # Start testing
-    tester.test(model)
+    tester.test(model, dev_data)

     # print test results
     print(tester.show_matrices())
@@ -123,7 +123,7 @@ def train_and_test():
     tester = SeqLabelTester(save_output=False,
                             save_loss=False,
                             save_best_dev=False,
-                            batch_size=8,
+                            batch_size=4,
                             use_cuda=False,
                             pickle_path=pickle_path,
                             model_name="seq_label_in_test.pkl",
@@ -140,4 +140,4 @@ def train_and_test():

 if __name__ == "__main__":
     train_and_test()
-    infer()
+    # infer()
@@ -1,13 +1,24 @@
 import sys

 sys.path.append("..")
 from fastNLP.fastnlp import FastNLP
+from fastNLP.fastnlp import interpret_word_seg_results
+
+PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/"


 def word_seg():
-    nlp = FastNLP("./data_for_tests/")
-    nlp.load("seq_label_model")
-    text = "这是最好的基于深度学习的中文分词系统。"
-    result = nlp.run(text)
-    print(result)
-    print("FastNLP finished!")
+    nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES)
+    nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")
+    text = ["这是最好的基于深度学习的中文分词系统。",
+            "大王叫我来巡山。",
+            "我党多年来致力于改善人民生活水平。"]
+    results = nlp.run(text)
+    print(results)
+    for example in results:
+        words, labels = [], []
+        for res in example:
+            words.append(res[0])
+            labels.append(res[1])
+        print(interpret_word_seg_results(words, labels))


 def text_class():
@@ -19,5 +30,14 @@ def text_class():
     print("FastNLP finished!")


+def test_word_seg_interpret():
+    foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'),
+            ('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'),
+            ('。', 'S')]]
+    chars = [x[0] for x in foo[0]]
+    labels = [x[1] for x in foo[0]]
+    print(interpret_word_seg_results(chars, labels))
+

 if __name__ == "__main__":
-    text_class()
+    word_seg()