| @@ -5,7 +5,7 @@ from fastNLP.core.preprocess import ClassPreprocess | |||
| from fastNLP.core.trainer import ClassificationTrainer | |||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader | |||
| from fastNLP.models.base_model import BaseModel | |||
| from fastNLP.modules import aggregation | |||
| from fastNLP.modules import aggregator | |||
| from fastNLP.modules import decoder | |||
| from fastNLP.modules import encoder | |||
| @@ -21,7 +21,7 @@ class ClassificationModel(BaseModel): | |||
| self.emb = encoder.Embedding(nums=vocab_size, dims=300) | |||
| self.enc = encoder.Conv( | |||
| in_channels=300, out_channels=100, kernel_size=3) | |||
| self.agg = aggregation.MaxPool() | |||
| self.agg = aggregator.MaxPool() | |||
| self.dec = decoder.MLP(size_layer=[100, num_classes]) | |||
| def forward(self, x): | |||
| @@ -2,10 +2,6 @@ from collections import defaultdict | |||
| import torch | |||
| from fastNLP.core.dataset import DataSet | |||
| from fastNLP.core.field import TextField, LabelField | |||
| from fastNLP.core.instance import Instance | |||
| class Batch(object): | |||
| """Batch is an iterable object which iterates over mini-batches. | |||
| @@ -16,6 +12,14 @@ class Batch(object): | |||
| """ | |||
| def __init__(self, dataset, batch_size, sampler, use_cuda): | |||
| """ | |||
| :param dataset: a DataSet object | |||
| :param batch_size: int, the size of the batch | |||
| :param sampler: a Sampler object | |||
| :param use_cuda: bool, whether to use GPU | |||
| """ | |||
| self.dataset = dataset | |||
| self.batch_size = batch_size | |||
| self.sampler = sampler | |||
| @@ -81,46 +85,3 @@ class Batch(object): | |||
| self.curidx = endidx | |||
| return batch_x, batch_y | |||
| if __name__ == "__main__": | |||
| """simple running example | |||
| """ | |||
| texts = ["i am a cat", | |||
| "this is a test of new batch", | |||
| "haha" | |||
| ] | |||
| labels = [0, 1, 0] | |||
| # prepare vocabulary | |||
| vocab = {} | |||
| for text in texts: | |||
| for tokens in text.split(): | |||
| if tokens not in vocab: | |||
| vocab[tokens] = len(vocab) | |||
| print("vocabulary: ", vocab) | |||
| # prepare input dataset | |||
| data = DataSet() | |||
| for text, label in zip(texts, labels): | |||
| x = TextField(text.split(), False) | |||
| y = LabelField(label, is_target=True) | |||
| ins = Instance(text=x, label=y) | |||
| data.append(ins) | |||
| # use vocabulary to index data | |||
| data.index_field("text", vocab) | |||
| # define naive sampler for batch class | |||
| class SeqSampler: | |||
| def __call__(self, dataset): | |||
| return list(range(len(dataset))) | |||
| # use batch to iterate dataset | |||
| data_iterator = Batch(data, 2, SeqSampler(), False) | |||
| for epoch in range(1): | |||
| for batch_x, batch_y in data_iterator: | |||
| print(batch_x) | |||
| print(batch_y) | |||
| # do stuff | |||
| @@ -1,10 +1,10 @@ | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.core.action import SequentialSampler | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.dataset import create_dataset_from_lists | |||
| from fastNLP.core.preprocess import load_pickle | |||
| from fastNLP.core.sampler import SequentialSampler | |||
| class Predictor(object): | |||
| @@ -62,9 +62,13 @@ class Predictor(object): | |||
| def data_forward(self, network, x): | |||
| """Forward through network.""" | |||
| y = network(**x) | |||
| if self._task == "seq_label": | |||
| y = network(x["word_seq"], x["word_seq_origin_len"]) | |||
| y = network.prediction(y) | |||
| elif self._task == "text_classify": | |||
| y = network(x["word_seq"]) | |||
| else: | |||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||
| return y | |||
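The dispatch above means a prediction batch is a dict keyed by field name, using the keys shown in the branches ("word_seq", plus "word_seq_origin_len" for sequence labeling). A minimal sketch of feeding the seq_label branch, assuming `predictor` is an already-constructed Predictor whose task is "seq_label" and `network` is a compatible model (neither construction is shown in this diff):

```python
import torch

# Hypothetical batch for the "seq_label" branch; tensor contents are illustrative only.
x = {
    "word_seq": torch.LongTensor([[4, 7, 2, 0]]),     # one padded sentence of word indices
    "word_seq_origin_len": torch.LongTensor([3]),     # its real (unpadded) length
}
y = predictor.data_forward(network, x)                # returns the output of network.prediction(...)
```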
| def prepare_input(self, data): | |||
| @@ -52,21 +52,28 @@ def pickle_exist(pickle_path, pickle_name): | |||
| return False | |||
| class BasePreprocess(object): | |||
| """Base class of all preprocessors. | |||
| Preprocessors are responsible for converting data of strings into data of indices. | |||
| class Preprocessor(object): | |||
| """Preprocessors are responsible for converting data of strings into data of indices. | |||
| During the pre-processing, the following pickle files will be built: | |||
| - "word2id.pkl", a mapping from words(tokens) to indices | |||
| - "id2word.pkl", a reversed dictionary | |||
| - "word2id.pkl", a Vocabulary object, mapping words to indices. | |||
| - "class2id.pkl", a Vocabulary object, mapping labels to indices. | |||
| - "data_train.pkl", a DataSet object for training | |||
| - "data_dev.pkl", a DataSet object for validation, if train_dev_split > 0. | |||
| - "data_test.pkl", a DataSet object for testing, if test_data is not None. | |||
| These pickle files are expected to be saved in the given pickle directory once they are constructed. | |||
| Preprocessors will check if those files are already in the directory and will reuse them in future calls. | |||
| """ | |||
| def __init__(self): | |||
| def __init__(self, label_is_seq=False): | |||
| """ | |||
| :param label_is_seq: bool, whether the label is a sequence. If True, the label vocabulary will reserve | |||
| several special tokens for sequence processing. | |||
| """ | |||
| self.data_vocab = Vocabulary() | |||
| self.label_vocab = Vocabulary() | |||
| self.label_vocab = Vocabulary(need_default=label_is_seq) | |||
| @property | |||
| def vocab_size(self): | |||
| @@ -259,20 +266,20 @@ class BasePreprocess(object): | |||
| return data_set | |||
| class SeqLabelPreprocess(BasePreprocess): | |||
| class SeqLabelPreprocess(Preprocessor): | |||
| def __init__(self): | |||
| print("[FastNLP warning] SeqLabelPreprocess is about to deprecate. Please use Preprocess directly.") | |||
| super(SeqLabelPreprocess, self).__init__() | |||
| class ClassPreprocess(BasePreprocess): | |||
| class ClassPreprocess(Preprocessor): | |||
| def __init__(self): | |||
| print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.") | |||
| super(ClassPreprocess, self).__init__() | |||
| if __name__ == "__main__": | |||
| p = BasePreprocess() | |||
| p = Preprocessor() | |||
| train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"], | |||
| [["You", "are", "pretty", "."], "1"] | |||
| ] | |||
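The `__main__` snippet above only builds the raw (tokens, label) pairs. A rough sketch of the next step, assuming the preprocessor exposes a `run()` entry point taking this data plus the pickle directory and split ratio hinted at in the docstring (the exact signature and return shape are not shown in this diff):

```python
# Hypothetical call; argument names follow the docstring above and are not confirmed here.
p = Preprocessor(label_is_seq=False)
datasets = p.run(train_dev_data, pickle_path="./save/", train_dev_split=0.1)  # DataSet object(s)
print(p.vocab_size)   # vocabulary size built from the training strings
```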
| @@ -1,5 +1,3 @@ | |||
| from collections import Counter | |||
| import numpy as np | |||
| import torch | |||
| @@ -17,6 +15,56 @@ def convert_to_torch_tensor(data_list, use_cuda): | |||
| return data_list | |||
| class BaseSampler(object): | |||
| """The base class of all samplers. | |||
| Sub-classes must implement the __call__ method. | |||
| __call__ takes a DataSet object and returns a list of int - the sampling indices. | |||
| """ | |||
| def __call__(self, *args, **kwargs): | |||
| raise NotImplementedError | |||
| class SequentialSampler(BaseSampler): | |||
| """Sample data in the original order. | |||
| """ | |||
| def __call__(self, data_set): | |||
| return list(range(len(data_set))) | |||
| class RandomSampler(BaseSampler): | |||
| """Sample data in random permutation order. | |||
| """ | |||
| def __call__(self, data_set): | |||
| return list(np.random.permutation(len(data_set))) | |||
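Both built-in samplers follow the `BaseSampler` contract described above: called with a dataset-like object, they return a list of integer indices. A quick illustration (a plain list stands in for a DataSet, since only `__len__` is needed here):

```python
data = ["a", "b", "c", "d"]          # anything with __len__ works for these two samplers
print(SequentialSampler()(data))     # [0, 1, 2, 3]
print(RandomSampler()(data))         # e.g. [2, 0, 3, 1] -- a random permutation of the indices
```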
| def simple_sort_bucketing(lengths): | |||
| """ | |||
| :param lengths: list of int, the lengths of all examples. | |||
| :return: list of int, the indices of all examples, sorted by length in ascending order | |||
| (bucketing into separate lists is not implemented yet; see the TODO below). | |||
| """ | |||
| lengths_mapping = [(idx, length) for idx, length in enumerate(lengths)] | |||
| sorted_lengths = sorted(lengths_mapping, key=lambda x: x[1]) | |||
| # TODO: need to return buckets | |||
| return [idx for idx, _ in sorted_lengths] | |||
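For example, sorting indices by length keeps examples of similar length adjacent, which is what downstream batching relies on:

```python
print(simple_sort_bucketing([5, 2, 9, 3]))   # [1, 3, 0, 2] -- indices ordered by ascending length
```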
| def k_means_1d(x, k, max_iter=100): | |||
| """Perform k-means on 1-D data. | |||
| @@ -46,18 +94,10 @@ def k_means_1d(x, k, max_iter=100): | |||
| return np.array(centroids), assign | |||
| def k_means_bucketing(all_inst, buckets): | |||
| def k_means_bucketing(lengths, buckets): | |||
| """Assign all instances into possible buckets using k-means, such that instances in the same bucket have similar lengths. | |||
| :param all_inst: 3-level list | |||
| E.g. :: | |||
| [ | |||
| [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 | |||
| [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | |||
| ... | |||
| ] | |||
| :param lengths: list of int, the length of all samples. | |||
| :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length | |||
| threshold for each bucket (This is usually None.). | |||
| :return data: 2-level list | |||
| @@ -72,7 +112,6 @@ def k_means_bucketing(all_inst, buckets): | |||
| """ | |||
| bucket_data = [[] for _ in buckets] | |||
| num_buckets = len(buckets) | |||
| lengths = np.array([len(inst[0]) for inst in all_inst]) | |||
| _, assignments = k_means_1d(lengths, num_buckets) | |||
| for idx, bucket_id in enumerate(assignments): | |||
| @@ -81,102 +120,33 @@ def k_means_bucketing(all_inst, buckets): | |||
| return bucket_data | |||
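A small illustration of the new signature, which takes the raw length list directly instead of the old 3-level instance list (output may vary slightly because k-means is initialized randomly):

```python
lengths = [3, 4, 30, 32, 5, 31]
buckets = [None, None]                        # two buckets, no per-bucket length cap
print(k_means_bucketing(lengths, buckets))    # e.g. [[0, 1, 4], [2, 3, 5]] -- short vs. long examples
```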
| class BaseSampler(object): | |||
| """The base class of all samplers. | |||
| """ | |||
| def __call__(self, *args, **kwargs): | |||
| raise NotImplementedError | |||
| class SequentialSampler(BaseSampler): | |||
| """Sample data in the original order. | |||
| """ | |||
| def __call__(self, data_set): | |||
| return list(range(len(data_set))) | |||
| class RandomSampler(BaseSampler): | |||
| """Sample data in random permutation order. | |||
| """ | |||
| def __call__(self, data_set): | |||
| return list(np.random.permutation(len(data_set))) | |||
| class Batchifier(object): | |||
| """Wrap random or sequential sampler to generate a mini-batch. | |||
| """ | |||
| def __init__(self, sampler, batch_size, drop_last=True): | |||
| """ | |||
| :param sampler: a Sampler object | |||
| :param batch_size: int, the size of the mini-batch | |||
| :param drop_last: bool, whether to drop the last examples that are not enough to make a mini-batch. | |||
| """ | |||
| super(Batchifier, self).__init__() | |||
| self.sampler = sampler | |||
| self.batch_size = batch_size | |||
| self.drop_last = drop_last | |||
| def __iter__(self): | |||
| batch = [] | |||
| for example in self.sampler: | |||
| batch.append(example) | |||
| if len(batch) == self.batch_size: | |||
| yield batch | |||
| batch = [] | |||
| if 0 < len(batch) < self.batch_size and self.drop_last is False: | |||
| yield batch | |||
| class BucketBatchifier(Batchifier): | |||
| class BucketSampler(BaseSampler): | |||
| """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. | |||
| When sampling, first randomly choose a bucket, then sample data from it. | |||
| The number of buckets is decided dynamically by the variance of sentence lengths. | |||
| TODO: merge it into Batch | |||
| """ | |||
| def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None): | |||
| def __call__(self, data_set, batch_size, num_buckets): | |||
| return self._process(data_set, batch_size, num_buckets) | |||
| def _process(self, data_set, batch_size, num_buckets, use_kmeans=False): | |||
| """ | |||
| :param data_set: three-level list, shape [num_samples, 2] | |||
| :param data_set: a DataSet object | |||
| :param batch_size: int | |||
| :param num_buckets: int, number of buckets for grouping these sequences. | |||
| :param drop_last: bool, useless currently. | |||
| :param sampler: Sampler, useless currently. | |||
| :param use_kmeans: bool, whether to use k-means to create buckets. | |||
| """ | |||
| super(BucketBatchifier, self).__init__(sampler, batch_size, drop_last) | |||
| buckets = ([None] * num_buckets) | |||
| self.data = data_set | |||
| self.batch_size = batch_size | |||
| self.length_freq = dict(Counter([len(example) for example in data_set])) | |||
| self.buckets = k_means_bucketing(data_set, buckets) | |||
| def __iter__(self): | |||
| """Make a min-batch of data.""" | |||
| for _ in range(len(self.data) // self.batch_size): | |||
| bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))] | |||
| np.random.shuffle(bucket_samples) | |||
| yield [self.data[idx] for idx in bucket_samples[:batch_size]] | |||
| if __name__ == "__main__": | |||
| import random | |||
| data = [[[y] * random.randint(0, 50), [y]] for y in range(500)] | |||
| batch_size = 8 | |||
| iterator = iter(BucketBatchifier(data, batch_size, num_buckets=5)) | |||
| for d in iterator: | |||
| print("\nbatch:") | |||
| for dd in d: | |||
| print(len(dd[0]), end=" ") | |||
| if use_kmeans is True: | |||
| buckets = k_means_bucketing(data_set, buckets) | |||
| else: | |||
| buckets = simple_sort_bucketing(data_set) | |||
| index_list = [] | |||
| for _ in range(len(data_set) // batch_size): | |||
| chosen_bucket = buckets[np.random.randint(0, len(buckets))] | |||
| np.random.shuffle(chosen_bucket) | |||
| index_list += [idx for idx in chosen_bucket[:batch_size]] | |||
| return index_list | |||
| @@ -1,32 +1,32 @@ | |||
| import numpy as np | |||
| import torch | |||
| from fastNLP.core.action import RandomSampler | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.sampler import RandomSampler | |||
| from fastNLP.saver.logger import create_logger | |||
| logger = create_logger(__name__, "./train_test.log") | |||
| class BaseTester(object): | |||
| class Tester(object): | |||
| """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ | |||
| def __init__(self, **kwargs): | |||
| """ | |||
| :param kwargs: a dict-like object with a __getitem__ method, whose entries are accessed as test_args["key_str"] | |||
| """ | |||
| super(BaseTester, self).__init__() | |||
| super(Tester, self).__init__() | |||
| """ | |||
| "default_args" provides default value for important settings. | |||
| The initialization arguments "kwargs" with the same key (name) will override the default value. | |||
| "kwargs" must have the same type as "default_args" on corresponding keys. | |||
| Otherwise, error will raise. | |||
| """ | |||
| default_args = {"save_output": False, # collect outputs of validation set | |||
| "save_loss": False, # collect losses in validation | |||
| default_args = {"save_output": True, # collect outputs of validation set | |||
| "save_loss": True, # collect losses in validation | |||
| "save_best_dev": False, # save best model during validation | |||
| "batch_size": 8, | |||
| "use_cuda": True, | |||
| "use_cuda": False, | |||
| "pickle_path": "./save/", | |||
| "model_name": "dev_best_model.pkl", | |||
| "print_every_step": 1, | |||
| @@ -55,7 +55,7 @@ class BaseTester(object): | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| else: | |||
| # BaseTester doesn't care about extra arguments | |||
| # Tester doesn't care about extra arguments | |||
| pass | |||
| print(default_args) | |||
| @@ -208,7 +208,7 @@ class BaseTester(object): | |||
| return self.show_metrics() | |||
| class SeqLabelTester(BaseTester): | |||
| class SeqLabelTester(Tester): | |||
| def __init__(self, **test_args): | |||
| test_args.update({"task": "seq_label"}) | |||
| print( | |||
| @@ -216,9 +216,9 @@ class SeqLabelTester(BaseTester): | |||
| super(SeqLabelTester, self).__init__(**test_args) | |||
| class ClassificationTester(BaseTester): | |||
| class ClassificationTester(Tester): | |||
| def __init__(self, **test_args): | |||
| test_args.update({"task": "seq_label"}) | |||
| test_args.update({"task": "text_classify"}) | |||
| print( | |||
| "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.") | |||
| super(ClassificationTester, self).__init__(**test_args) | |||
| @@ -6,10 +6,10 @@ from datetime import timedelta | |||
| import torch | |||
| from tensorboardX import SummaryWriter | |||
| from fastNLP.core.action import RandomSampler | |||
| from fastNLP.core.batch import Batch | |||
| from fastNLP.core.loss import Loss | |||
| from fastNLP.core.optimizer import Optimizer | |||
| from fastNLP.core.sampler import RandomSampler | |||
| from fastNLP.core.tester import SeqLabelTester, ClassificationTester | |||
| from fastNLP.saver.logger import create_logger | |||
| from fastNLP.saver.model_saver import ModelSaver | |||
| @@ -17,7 +17,7 @@ from fastNLP.saver.model_saver import ModelSaver | |||
| logger = create_logger(__name__, "./train_test.log") | |||
| class BaseTrainer(object): | |||
| class Trainer(object): | |||
| """Operations of training a model, including data loading, gradient descent, and validation. | |||
| """ | |||
| @@ -32,7 +32,7 @@ class BaseTrainer(object): | |||
| - batch_size: int | |||
| - pickle_path: str, the path to pickle files for pre-processing | |||
| """ | |||
| super(BaseTrainer, self).__init__() | |||
| super(Trainer, self).__init__() | |||
| """ | |||
| "default_args" provides default value for important settings. | |||
| @@ -40,8 +40,8 @@ class BaseTrainer(object): | |||
| "kwargs" must have the same type as "default_args" on corresponding keys. | |||
| Otherwise, an error will be raised. | |||
| """ | |||
| default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/", | |||
| "save_best_dev": True, "model_name": "default_model_name.pkl", "print_every_step": 1, | |||
| default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||
| "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | |||
| "loss": Loss(None), # used to pass type check | |||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0) | |||
| } | |||
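Because user-supplied keyword arguments override these defaults key by key, a caller only spells out the settings that differ. A hedged sketch (any additional task-specific arguments required by the concrete trainer still apply and are not shown in this diff):

```python
# Hypothetical call: more epochs, validation on, everything else keeps the defaults above.
trainer = ClassificationTrainer(
    epochs=10,
    validate=True,
    save_best_dev=True,
    optimizer=Optimizer("SGD", lr=0.01, weight_decay=0),
)
```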
| @@ -69,7 +69,7 @@ class BaseTrainer(object): | |||
| logger.error(msg) | |||
| raise ValueError(msg) | |||
| else: | |||
| # BaseTrainer doesn't care about extra arguments | |||
| # Trainer doesn't care about extra arguments | |||
| pass | |||
| print(default_args) | |||
| @@ -136,6 +136,9 @@ class BaseTrainer(object): | |||
| # validation | |||
| if self.validate: | |||
| if dev_data is None: | |||
| raise RuntimeError( | |||
| "self.validate is True in trainer, but dev_data is None. Please provide the validation data.") | |||
| logger.info("validation started") | |||
| validator.test(network, dev_data) | |||
| @@ -314,7 +317,7 @@ class BaseTrainer(object): | |||
| raise NotImplementedError | |||
| class SeqLabelTrainer(BaseTrainer): | |||
| class SeqLabelTrainer(Trainer): | |||
| """Trainer for Sequence Labeling | |||
| """ | |||
| @@ -328,7 +331,7 @@ class SeqLabelTrainer(BaseTrainer): | |||
| return SeqLabelTester(**valid_args) | |||
| class ClassificationTrainer(BaseTrainer): | |||
| class ClassificationTrainer(Trainer): | |||
| """Trainer for text classification.""" | |||
| def __init__(self, **train_args): | |||
| @@ -31,7 +31,7 @@ FastNLP_MODEL_COLLECTION = { | |||
| "class": "sequence_modeling.AdvSeqLabel", | |||
| "pickle": "cws_basic_model_v_0.pkl", | |||
| "type": "seq_label", | |||
| "config_file_name": "config", | |||
| "config_file_name": "cws.cfg", | |||
| "config_section_name": "text_class_model" | |||
| }, | |||
| "pos_tag_model": { | |||
| @@ -39,7 +39,7 @@ FastNLP_MODEL_COLLECTION = { | |||
| "class": "sequence_modeling.AdvSeqLabel", | |||
| "pickle": "pos_tag_model_v_0.pkl", | |||
| "type": "seq_label", | |||
| "config_file_name": "pos_tag.config", | |||
| "config_file_name": "pos_tag.cfg", | |||
| "config_section_name": "pos_tag_model" | |||
| }, | |||
| "text_classify_model": { | |||
| @@ -56,21 +56,22 @@ FastNLP_MODEL_COLLECTION = { | |||
| class FastNLP(object): | |||
| """ | |||
| High-level interface for direct model inference. | |||
| Example Usage: | |||
| Example Usage | |||
| :: | |||
| fastnlp = FastNLP() | |||
| fastnlp.load("zh_pos_tag_model") | |||
| text = "这是最好的基于深度学习的中文分词系统。" | |||
| result = fastnlp.run(text) | |||
| print(result) # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"] | |||
| """ | |||
| def __init__(self, model_dir="./"): | |||
| """ | |||
| :param model_dir: this directory should contain the following files: | |||
| 1. a pre-trained model | |||
| 2. a config file | |||
| 3. "class2id.pkl" | |||
| 4. "word2id.pkl" | |||
| 1. a trained model | |||
| 2. a config file, which is a fastNLP's configuration. | |||
| 3. a Vocab file, which is a pickle object of a Vocab instance. | |||
| """ | |||
| self.model_dir = model_dir | |||
| self.model = None | |||
| @@ -172,9 +172,8 @@ class ClassDatasetLoader(DatasetLoader): | |||
| class ConllLoader(DatasetLoader): | |||
| """loader for conll format files""" | |||
| def __init__(self, data_name, data_path): | |||
| def __init__(self, data_path): | |||
| """ | |||
| :param str data_name: the name of the conll data set | |||
| :param str data_path: the path to the conll data set | |||
| """ | |||
| super(ConllLoader, self).__init__(data_path) | |||
| @@ -269,8 +268,3 @@ class PeopleDailyCorpusLoader(DatasetLoader): | |||
| ner_examples.append([sent_words, sent_ner]) | |||
| return pos_tag_examples, ner_examples | |||
| if __name__ == "__main__": | |||
| loader = PeopleDailyCorpusLoader("./") | |||
| pos, ner = loader.load() | |||
| print(pos[:10]) | |||
| print(ner[:10]) | |||
| @@ -1,11 +1,11 @@ | |||
| from . import aggregation | |||
| from . import aggregator | |||
| from . import decoder | |||
| from . import encoder | |||
| from . import interaction | |||
| from . import interactor | |||
| __version__ = '0.0.0' | |||
| __all__ = ['encoder', | |||
| 'decoder', | |||
| 'aggregation', | |||
| 'interaction'] | |||
| 'aggregator', | |||
| 'interactor'] | |||
| @@ -1,8 +1,7 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch.autograd import Variable | |||
| import torch.nn.functional as F | |||
| from torch.autograd import Variable | |||
| from fastNLP.modules.utils import initial_parameter | |||
| @@ -1,19 +1,10 @@ | |||
| """ | |||
| This is borrowed from FudanParser. Not stable. Do not use !!! | |||
| """ | |||
| import numpy | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import torch.utils.data | |||
| from torch import optim | |||
| from torch.autograd import Function, Variable | |||
| from torch.nn import Parameter | |||
| from .utils import orthogonal | |||
| class GroupNorm(nn.Module): | |||
| def __init__(self, num_features, num_groups=20, eps=1e-5): | |||
| @@ -59,15 +50,6 @@ class LayerNormalization(nn.Module): | |||
| return ln_out | |||
| class OrthEmbedding(nn.Embedding): | |||
| def __init__(self, *args, **kwargs): | |||
| super(OrthEmbedding, self).__init__(*args, **kwargs) | |||
| def reset_parameters(self): | |||
| self.weight = orthogonal(self.weight) | |||
| nn.init.constant_(self.bias, 0.) | |||
| class BiLinear(nn.Module): | |||
| def __init__(self, n_left, n_right, n_out, bias=True): | |||
| """ | |||
| @@ -241,250 +223,3 @@ class WordDropout(nn.Module): | |||
| drop_mask = drop_mask.long() | |||
| output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx | |||
| return output | |||
| class WlossLayer(torch.nn.Module): | |||
| def __init__(self, lam=100, sinkhorn_iter=50): | |||
| super(WlossLayer, self).__init__() | |||
| # cost = matrix M = distance matrix | |||
| # lam = lambda of type float > 0 | |||
| # sinkhorn_iter > 0 | |||
| # diagonal cost should be 0 | |||
| self.lam = lam | |||
| self.sinkhorn_iter = sinkhorn_iter | |||
| # self.register_buffer("K", torch.exp(-self.cost / self.lam).double()) | |||
| # self.register_buffer("KM", (self.cost * self.K).double()) | |||
| def forward(self, pred, target, cost): | |||
| return WassersteinLossStab.apply(pred, target, | |||
| cost, self.lam, self.sinkhorn_iter) | |||
| class WassersteinLossStab(Function): | |||
| @staticmethod | |||
| def forward(ctx, pred, target, cost, lam=1e-3, sinkhorn_iter=4): | |||
| """pred: Batch * K: K = # mass points | |||
| target: Batch * L: L = # mass points""" | |||
| # import pdb | |||
| # pdb.set_trace() | |||
| eps = 1e-8 | |||
| # pred = pred.gather(dim=1, index=) | |||
| na = pred.size(1) | |||
| nb = target.size(1) | |||
| cost = cost.double() | |||
| pred = pred.double() | |||
| target = target.double() | |||
| cost = cost[:na, :nb].double() | |||
| K = torch.exp(-cost / lam).double() | |||
| KM = (cost * K).double() | |||
| batch_size = pred.size(0) | |||
| # pdb.set_trace() | |||
| log_a, log_b = torch.log(pred + eps), torch.log(target + eps) | |||
| log_u = cost.new(batch_size, na).fill_(-numpy.log(na)) | |||
| log_v = cost.new(batch_size, nb).fill_(-numpy.log(nb)) | |||
| # import pdb | |||
| # pdb.set_trace() | |||
| for i in range(int(sinkhorn_iter)): | |||
| log_u_max = torch.max(log_u, dim=1)[0] | |||
| u_stab = torch.exp(log_u - log_u_max.unsqueeze(1) + eps) | |||
| log_v = log_b - torch.log(torch.mm(K.t(), u_stab.t()).t()) - log_u_max.unsqueeze(1) | |||
| log_v_max = torch.max(log_v, dim=1)[0] | |||
| v_stab = torch.exp(log_v - log_v_max.unsqueeze(1)) | |||
| tmp = log_u | |||
| log_u = log_a - torch.log(torch.mm(K, v_stab.t()).t() + eps) - log_v_max.unsqueeze(1) | |||
| # print(log_u.sum()) | |||
| if torch.norm(tmp - log_u) / torch.norm(log_u) < eps: | |||
| break | |||
| log_v_max = torch.max(log_v, dim=1)[0] | |||
| v_stab = torch.exp(log_v - log_v_max.unsqueeze(1)) | |||
| logcostpart1 = torch.log(torch.mm(KM, v_stab.t()).t() + eps) + log_v_max.unsqueeze(1) | |||
| wnorm = torch.exp(log_u + logcostpart1).mean(0).sum() # sum(1) for per item pair loss... | |||
| grad_input = log_u * lam | |||
| # print("log_u", log_u) | |||
| grad_input = grad_input - torch.mean(grad_input, dim=1).unsqueeze(1) | |||
| grad_input = grad_input - torch.mean(grad_input, dim=1).unsqueeze(1) | |||
| grad_input = grad_input / batch_size | |||
| ctx.save_for_backward(grad_input) | |||
| # print("grad type", type(grad_input)) | |||
| return pred.new((wnorm,)), grad_input | |||
| @staticmethod | |||
| def backward(ctx, grad_output, _): | |||
| grad_input = ctx.saved_variables | |||
| # print(grad) | |||
| res = grad_output.clone() | |||
| res.data.resize_(grad_input[0].size()).copy_(grad_input[0].data) | |||
| res = res.mul_(grad_output[0]).float() | |||
| # print("in backward func:\n\n", res) | |||
| return res, None, None, None, None, None, None | |||
| class Sinkhorn(Function): | |||
| def __init__(self): | |||
| super(Sinkhorn, self).__init__() | |||
| def forward(ctx, a, b, M, reg, tau, warmstart, numItermax, stop): | |||
| a = a.double() | |||
| b = b.double() | |||
| M = M.double() | |||
| nbb = b.size(1) | |||
| # init data | |||
| na = len(a) | |||
| nb = len(b) | |||
| cpt = 0 | |||
| # we assume that no distances are null except those of the diagonal of | |||
| # distances | |||
| if warmstart is None: | |||
| alpha, beta = np.zeros(na), np.zeros(nb) | |||
| else: | |||
| alpha, beta = warmstart | |||
| if nbb: | |||
| u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb | |||
| else: | |||
| u, v = np.ones(na) / na, np.ones(nb) / nb | |||
| def get_K(alpha, beta): | |||
| """log space computation""" | |||
| return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) | |||
| def get_Gamma(alpha, beta, u, v): | |||
| """log space gamma computation""" | |||
| return np.exp( | |||
| -(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log( | |||
| v.reshape((1, nb)))) | |||
| # print(np.min(K)) | |||
| K = get_K(alpha, beta) | |||
| transp = K | |||
| cpt = 0 | |||
| err = 1 | |||
| while 1: | |||
| uprev = u | |||
| vprev = v | |||
| # sinkhorn update | |||
| v = b / (np.dot(K.T, u) + 1e-16) | |||
| u = a / (np.dot(K, v) + 1e-16) | |||
| # remove numerical problems and store them in K | |||
| if np.abs(u).max() > tau or np.abs(v).max() > tau: | |||
| if nbb: | |||
| alpha, beta = alpha + reg * \ | |||
| np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) | |||
| else: | |||
| alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) | |||
| if nbb: | |||
| u, v = np.ones((na, nbb)) / na, np.ones((nb, nbb)) / nb | |||
| else: | |||
| u, v = np.ones(na) / na, np.ones(nb) / nb | |||
| K = get_K(alpha, beta) | |||
| if cpt % print_period == 0: | |||
| # we can speed up the process by checking for the error only all | |||
| # the 10th iterations | |||
| if nbb: | |||
| err = np.sum((u - uprev) ** 2) / np.sum((u) ** 2) + \ | |||
| np.sum((v - vprev) ** 2) / np.sum((v) ** 2) | |||
| else: | |||
| transp = get_Gamma(alpha, beta, u, v) | |||
| err = np.linalg.norm((np.sum(transp, axis=0) - b)) ** 2 | |||
| if log: | |||
| log['err'].append(err) | |||
| if verbose: | |||
| if cpt % (print_period * 20) == 0: | |||
| print( | |||
| '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) | |||
| print('{:5d}|{:8e}|'.format(cpt, err)) | |||
| if err <= stopThr: | |||
| loop = False | |||
| if cpt >= numItermax: | |||
| loop = False | |||
| if np.any(np.isnan(u)) or np.any(np.isnan(v)): | |||
| # we have reached the machine precision | |||
| # come back to previous solution and quit loop | |||
| print('Warning: numerical errors at iteration', cpt) | |||
| u = uprev | |||
| v = vprev | |||
| break | |||
| cpt = cpt + 1 | |||
| # print('err=',err,' cpt=',cpt) | |||
| if log: | |||
| log['logu'] = alpha / reg + np.log(u) | |||
| log['logv'] = beta / reg + np.log(v) | |||
| log['alpha'] = alpha + reg * np.log(u) | |||
| log['beta'] = beta + reg * np.log(v) | |||
| log['warmstart'] = (log['alpha'], log['beta']) | |||
| if nbb: | |||
| res = np.zeros((nbb)) | |||
| for i in range(nbb): | |||
| res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) | |||
| return res, log | |||
| else: | |||
| return get_Gamma(alpha, beta, u, v), log | |||
| else: | |||
| if nbb: | |||
| res = np.zeros((nbb)) | |||
| for i in range(nbb): | |||
| res[i] = np.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) | |||
| return res | |||
| else: | |||
| return get_Gamma(alpha, beta, u, v) | |||
| if __name__ == "__main__": | |||
| cost = (torch.Tensor(2, 2).fill_(1) - torch.diag(torch.Tensor(2).fill_(1))) # .cuda() | |||
| mylayer = WlossLayer(cost) # .cuda() | |||
| inp = Variable(torch.Tensor([[1, 0], [0.5, 0.5]]), requires_grad=True) # .cuda() | |||
| ground_true = Variable(torch.Tensor([[0, 1], [0.5, 0.5]])) # .cuda() | |||
| res, _ = mylayer(inp, ground_true) | |||
| # print(inp.requires_grad, res.requires_grad) | |||
| # print(res, inp) | |||
| mylayer.zero_grad() | |||
| res.backward() | |||
| print("inp's gradient is good:") | |||
| print(inp.grad) | |||
| print("convert to gpu:\n", inp.cuda().grad) | |||
| print("==============================================" | |||
| "\n However, this does not work on pytorch when GPU is enabled") | |||
| cost = (torch.Tensor(2, 2).fill_(1) - torch.diag(torch.Tensor(2).fill_(1))).cuda() | |||
| mylayer = WlossLayer(cost).cuda() | |||
| inp = Variable(torch.Tensor([[1, 0], [0.5, 0.5]]), requires_grad=True).cuda() | |||
| ground_true = Variable(torch.Tensor([[0, 1], [0.5, 0.5]])).cuda() | |||
| opt = optim.SGD([ | |||
| {'params': mylayer.parameters()}, | |||
| ], lr=1e-2, momentum=0.9) | |||
| res, _ = mylayer(inp, ground_true) | |||
| # print(inp.requires_grad, res.requires_grad) | |||
| # print(res, inp) | |||
| mylayer.zero_grad() | |||
| res.backward() | |||
| print("input's gradient is None!!!!!!!!!!!!!!!!") | |||
| print(inp.grad) | |||
| @@ -1,9 +1,8 @@ | |||
| from collections import defaultdict | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.init as init | |||
| import torch.nn as nn | |||
| import torch.nn.init as init | |||
| def mask_softmax(matrix, mask): | |||
| if mask is None: | |||
| result = torch.nn.functional.softmax(matrix, dim=-1) | |||
| @@ -11,13 +10,28 @@ def mask_softmax(matrix, mask): | |||
| raise NotImplementedError | |||
| return result | |||
| def initial_parameter(net ,initial_method =None): | |||
| def initial_parameter(net, initial_method=None): | |||
| """A method used to initialize the weights of PyTorch models. | |||
| :param net: a PyTorch model | |||
| :param initial_method: str, one of the following initializations | |||
| - xavier_uniform | |||
| - xavier_normal (default) | |||
| - kaiming_normal, or msra | |||
| - kaiming_uniform | |||
| - orthogonal | |||
| - sparse | |||
| - normal | |||
| - uniform | |||
| """ | |||
| if initial_method == 'xavier_uniform': | |||
| init_method = init.xavier_uniform_ | |||
| elif initial_method=='xavier_normal': | |||
| elif initial_method == 'xavier_normal': | |||
| init_method = init.xavier_normal_ | |||
| elif initial_method == 'kaiming_normal' or initial_method =='msra': | |||
| elif initial_method == 'kaiming_normal' or initial_method == 'msra': | |||
| init_method = init.kaiming_normal | |||
| elif initial_method == 'kaiming_uniform': | |||
| init_method = init.kaiming_uniform_ | |||
| @@ -25,263 +39,49 @@ def initial_parameter(net ,initial_method =None): | |||
| init_method = init.orthogonal_ | |||
| elif initial_method == 'sparse': | |||
| init_method = init.sparse_ | |||
| elif initial_method =='normal': | |||
| elif initial_method == 'normal': | |||
| init_method = init.normal_ | |||
| elif initial_method =='uniform': | |||
| elif initial_method == 'uniform': | |||
| init_method = init.uniform_ | |||
| else: | |||
| init_method = init.xavier_normal_ | |||
| def weights_init(m): | |||
| # classname = m.__class__.__name__ | |||
| if isinstance(m, nn.Conv2d) or isinstance(m,nn.Conv1d) or isinstance(m,nn.Conv3d): # for all the cnn | |||
| if initial_method != None: | |||
| if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn | |||
| if initial_method is not None: | |||
| init_method(m.weight.data) | |||
| else: | |||
| init.xavier_normal_(m.weight.data) | |||
| init.normal_(m.bias.data) | |||
| elif isinstance(m, nn.LSTM): | |||
| for w in m.parameters(): | |||
| if len(w.data.size())>1: | |||
| if len(w.data.size()) > 1: | |||
| init_method(w.data) # weight | |||
| else: | |||
| init.normal_(w.data) # bias | |||
| elif hasattr(m, 'weight') and m.weight.requires_grad: | |||
| init_method(m.weight.data) | |||
| else: | |||
| for w in m.parameters() : | |||
| if w.requires_grad: | |||
| if len(w.data.size())>1: | |||
| for w in m.parameters(): | |||
| if w.requires_grad: | |||
| if len(w.data.size()) > 1: | |||
| init_method(w.data) # weight | |||
| else: | |||
| init.normal_(w.data) # bias | |||
| # print("init else") | |||
| net.apply(weights_init) | |||
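In use, `initial_parameter` is called once on a freshly constructed model; `net.apply(weights_init)` then walks every submodule and initializes convolution, LSTM, and other weight-bearing layers with the chosen scheme. A small sketch:

```python
import torch.nn as nn
from fastNLP.modules.utils import initial_parameter

class TinyModel(nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.conv = nn.Conv1d(300, 100, kernel_size=3)   # handled by the Conv branch above
        self.lstm = nn.LSTM(100, 50, batch_first=True)   # handled by the LSTM branch
        self.out = nn.Linear(50, 2)                      # handled by the generic weight branch

model = TinyModel()
initial_parameter(model, initial_method="xavier_uniform")  # passing None falls back to xavier_normal_
```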
| def seq_mask(seq_len, max_len): | |||
| mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | |||
| mask = torch.stack(mask, 1) | |||
| return mask | |||
| """ | |||
| Codes from FudanParser. Not tested. Do not use !!! | |||
| """ | |||
| def expand_gt(gt): | |||
| """expand_gt: Expand ground truth to matrix | |||
| Arguments: | |||
| gt: tensor of (n, l) | |||
| Return: | |||
| f: ground truth matrix of (n, l), $gt[i][j] = k$ leads to $f[i][j][k] = 1$. | |||
| """ | |||
| n, l = gt.shape | |||
| ret = torch.zeros(n, l, l).long() | |||
| for i in range(n): | |||
| ret[i][torch.arange(l).long(), gt[i]] = 1 | |||
| return ret | |||
| def greedy_decoding(arc_f): | |||
| """greedy_decoding | |||
| Arguments: | |||
| arc_f: a tensor in shape of (n, l+1, l+1) | |||
| length of the sentence is l and index 0 is <root> | |||
| Output: | |||
| arc_pred: a tensor in shape of (n, l), indicating the head words | |||
| """ | |||
| f_arc = arc_f[:, 1:, :] # ignore the root | |||
| _, arc_pred = torch.max(f_arc.data, dim=-1, keepdim=False) | |||
| return arc_pred | |||
| def mst_decoding(arc_f): | |||
| batch_size = arc_f.shape[0] | |||
| length = arc_f.shape[1] | |||
| arc_score = arc_f.data.cpu() | |||
| pred_collection = [] | |||
| for i in range(batch_size): | |||
| head = mst(arc_score[i].numpy()) | |||
| pred_collection.append(head[1:].reshape((1, length - 1))) | |||
| arc_pred = torch.LongTensor(np.concatenate(pred_collection, axis=0)).type_as(arc_f).long() | |||
| return arc_pred | |||
| def outer_product(features): | |||
| """InterProduct: Get inter sequence product of features | |||
| Arguments: | |||
| features: feature vectors of sequence in the shape of (n, l, h) | |||
| Return: | |||
| f: product result in (n, l, l, h) shape | |||
| """ | |||
| n, l, c = features.shape | |||
| features = features.contiguous() | |||
| x = features.view(n, l, 1, c) | |||
| x = x.expand(n, l, l, c) | |||
| y = features.view(n, 1, l, c).contiguous() | |||
| y = y.expand(n, l, l, c) | |||
| return x * y | |||
| def outer_concat(features): | |||
| """InterProduct: Get inter sequence concatenation of features | |||
| Arguments: | |||
| features: feature vectors of sequence in the shape of (n, l, h) | |||
| Return: | |||
| f: product result in (n, l, l, h) shape | |||
| """ | |||
| n, l, c = features.shape | |||
| x = features.contiguous().view(n, l, 1, c) | |||
| x = x.expand(n, l, l, c) | |||
| y = features.view(n, 1, l, c) | |||
| y = y.expand(n, l, l, c) | |||
| return torch.cat((x, y), dim=3) | |||
| def mst(scores): | |||
| """ | |||
| https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 # NOQA | |||
| """ | |||
| length = scores.shape[0] | |||
| min_score = scores.min() - 1 | |||
| eye = np.eye(length) | |||
| scores = scores * (1 - eye) + min_score * eye | |||
| heads = np.argmax(scores, axis=1) | |||
| heads[0] = 0 | |||
| tokens = np.arange(1, length) | |||
| roots = np.where(heads[tokens] == 0)[0] + 1 | |||
| if len(roots) < 1: | |||
| root_scores = scores[tokens, 0] | |||
| head_scores = scores[tokens, heads[tokens]] | |||
| new_root = tokens[np.argmax(root_scores / head_scores)] | |||
| heads[new_root] = 0 | |||
| elif len(roots) > 1: | |||
| root_scores = scores[roots, 0] | |||
| scores[roots, 0] = 0 | |||
| new_heads = np.argmax(scores[roots][:, tokens], axis=1) + 1 | |||
| new_root = roots[np.argmin( | |||
| scores[roots, new_heads] / root_scores)] | |||
| heads[roots] = new_heads | |||
| heads[new_root] = 0 | |||
| edges = defaultdict(set) | |||
| vertices = set((0,)) | |||
| for dep, head in enumerate(heads[tokens]): | |||
| vertices.add(dep + 1) | |||
| edges[head].add(dep + 1) | |||
| for cycle in _find_cycle(vertices, edges): | |||
| dependents = set() | |||
| to_visit = set(cycle) | |||
| while len(to_visit) > 0: | |||
| node = to_visit.pop() | |||
| if node not in dependents: | |||
| dependents.add(node) | |||
| to_visit.update(edges[node]) | |||
| cycle = np.array(list(cycle)) | |||
| old_heads = heads[cycle] | |||
| old_scores = scores[cycle, old_heads] | |||
| non_heads = np.array(list(dependents)) | |||
| scores[np.repeat(cycle, len(non_heads)), | |||
| np.repeat([non_heads], len(cycle), axis=0).flatten()] = min_score | |||
| new_heads = np.argmax(scores[cycle][:, tokens], axis=1) + 1 | |||
| new_scores = scores[cycle, new_heads] / old_scores | |||
| change = np.argmax(new_scores) | |||
| changed_cycle = cycle[change] | |||
| old_head = old_heads[change] | |||
| new_head = new_heads[change] | |||
| heads[changed_cycle] = new_head | |||
| edges[new_head].add(changed_cycle) | |||
| edges[old_head].remove(changed_cycle) | |||
| return heads | |||
| def _find_cycle(vertices, edges): | |||
| """ | |||
| https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm # NOQA | |||
| https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/etc/tarjan.py # NOQA | |||
| """ | |||
| _index = 0 | |||
| _stack = [] | |||
| _indices = {} | |||
| _lowlinks = {} | |||
| _onstack = defaultdict(lambda: False) | |||
| _SCCs = [] | |||
| def _strongconnect(v): | |||
| nonlocal _index | |||
| _indices[v] = _index | |||
| _lowlinks[v] = _index | |||
| _index += 1 | |||
| _stack.append(v) | |||
| _onstack[v] = True | |||
| for w in edges[v]: | |||
| if w not in _indices: | |||
| _strongconnect(w) | |||
| _lowlinks[v] = min(_lowlinks[v], _lowlinks[w]) | |||
| elif _onstack[w]: | |||
| _lowlinks[v] = min(_lowlinks[v], _indices[w]) | |||
| if _lowlinks[v] == _indices[v]: | |||
| SCC = set() | |||
| while True: | |||
| w = _stack.pop() | |||
| _onstack[w] = False | |||
| SCC.add(w) | |||
| if not (w != v): | |||
| break | |||
| _SCCs.append(SCC) | |||
| net.apply(weights_init) | |||
| for v in vertices: | |||
| if v not in _indices: | |||
| _strongconnect(v) | |||
| return [SCC for SCC in _SCCs if len(SCC) > 1] | |||
| def seq_mask(seq_len, max_len): | |||
| """Create sequence mask. | |||
| :param seq_len: list of int, the lengths of sequences in a batch. | |||
| :param max_len: int, the maximum sequence length in a batch. | |||
| :return mask: torch.LongTensor, [batch_size, max_len] | |||
| # https://github.com/alykhantejani/nninit/blob/master/nninit.py | |||
| def orthogonal(tensor, gain=1): | |||
| """Fills the input Tensor or Variable with a (semi) orthogonal matrix. The input tensor must have at least 2 dimensions, | |||
| and for tensors with more than 2 dimensions the trailing dimensions are flattened. viewed as 2D representation with | |||
| rows equal to the first dimension and columns equal to the product of as a sparse matrix, where the non-zero elements | |||
| will be drawn from a normal distribution with mean=0 and std=`std`. | |||
| Reference: "Exact solutions to the nonlinear dynamics of learning in deep linear neural networks" - Saxe, A. et al. | |||
| Args: | |||
| tensor: a n-dimension torch.Tensor, where n >= 2 | |||
| gain: optional gain to be applied | |||
| Examples: | |||
| >>> w = torch.Tensor(3, 5) | |||
| >>> nninit.orthogonal(w) | |||
| """ | |||
| if tensor.ndimension() < 2: | |||
| raise ValueError("Only tensors with 2 or more dimensions are supported.") | |||
| flattened_shape = (tensor.size(0), int(np.prod(tensor.detach().numpy().shape[1:]))) | |||
| flattened = torch.Tensor(flattened_shape[0], flattened_shape[1]).normal_(0, 1) | |||
| u, s, v = np.linalg.svd(flattened.numpy(), full_matrices=False) | |||
| if u.shape == flattened.detach().numpy().shape: | |||
| tensor.view_as(flattened).copy_(torch.from_numpy(u)) | |||
| else: | |||
| tensor.view_as(flattened).copy_(torch.from_numpy(v)) | |||
| tensor.mul_(gain) | |||
| with torch.no_grad(): | |||
| return tensor | |||
| def generate_step_dropout(masks, hidden_dim, step_dropout, training=False): | |||
| # assume batch first | |||
| # import pdb | |||
| # pdb.set_trace() | |||
| batch, length = masks.size() | |||
| if not training: | |||
| return torch.ones(batch, length, hidden_dim).fill_(1 - step_dropout).cuda(masks.device) * masks.view(batch, | |||
| length, 1) | |||
| masked = torch.zeros(batch, 1, hidden_dim).fill_(step_dropout) | |||
| masked = torch.bernoulli(masked).repeat(1, length, 1) | |||
| masked = masked.cuda(masks.device) * masks.view(batch, length, 1) | |||
| return masked | |||
| mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | |||
| mask = torch.stack(mask, 1) | |||
| return mask | |||
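A worked example: with three sequences of real lengths 3, 1 and 2 padded to `max_len=3`, the mask is 1 at real token positions and 0 at padding positions:

```python
print(seq_mask([3, 1, 2], 3))
# tensor([[1, 1, 1],
#         [1, 0, 0],
#         [1, 1, 0]])   # dtype is uint8 or bool depending on the torch version
```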
| @@ -2,16 +2,23 @@ import torch | |||
| class ModelSaver(object): | |||
| """Save a models""" | |||
| """Save a model | |||
| Example:: | |||
| saver = ModelSaver("./save/model_ckpt_100.pkl") | |||
| saver.save_pytorch(model) | |||
| """ | |||
| def __init__(self, save_path): | |||
| """ | |||
| :param save_path: str, the path of the file in which to save the model. | |||
| """ | |||
| self.save_path = save_path | |||
| # TODO: check whether the path exists; if not, create it. | |||
| def save_pytorch(self, model): | |||
| """ | |||
| Save a pytorch model into .pkl file. | |||
| """Save a pytorch model into .pkl file. | |||
| :param model: a PyTorch model | |||
| :return: | |||
| """ | |||
| torch.save(model.state_dict(), self.save_path) | |||
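Since only the `state_dict` is written, restoring the checkpoint later needs a model instance of the same architecture; this is the standard PyTorch counterpart and not part of the saver itself (the model class here is hypothetical):

```python
import torch

model = MyModel()                                               # hypothetical: same architecture as at save time
model.load_state_dict(torch.load("./save/model_ckpt_100.pkl"))  # path matches the ModelSaver example above
model.eval()
```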
| @@ -1,23 +1,15 @@ | |||
| import os | |||
| import torch.nn.functional as F | |||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader | |||
| from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader | |||
| from fastNLP.loader.config_loader import ConfigSection | |||
| from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||
| from fastNLP.core.trainer import ClassificationTrainer | |||
| from fastNLP.loader.config_loader import ConfigLoader | |||
| from fastNLP.loader.config_loader import ConfigSection | |||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader | |||
| from fastNLP.models.base_model import BaseModel | |||
| from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||
| from fastNLP.core.trainer import ClassificationTrainer | |||
| from fastNLP.modules.aggregator.self_attention import SelfAttention | |||
| from fastNLP.modules.decoder.MLP import MLP | |||
| from fastNLP.modules.encoder.embedding import Embedding as Embedding | |||
| from fastNLP.modules.encoder.lstm import Lstm | |||
| from fastNLP.modules.aggregation.self_attention import SelfAttention | |||
| from fastNLP.modules.decoder.MLP import MLP | |||
| train_data_path = 'small_train_data.txt' | |||
| dev_data_path = 'small_dev_data.txt' | |||
| @@ -0,0 +1,30 @@ | |||
| import torch | |||
| from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler | |||
| def test_convert_to_torch_tensor(): | |||
| data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||
| ans = convert_to_torch_tensor(data, False) | |||
| assert isinstance(ans, torch.Tensor) | |||
| assert tuple(ans.shape) == (3, 5) | |||
| def test_sequential_sampler(): | |||
| sampler = SequentialSampler() | |||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
| for idx, i in enumerate(sampler(data)): | |||
| assert idx == i | |||
| def test_random_sampler(): | |||
| sampler = RandomSampler() | |||
| data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
| ans = [data[i] for i in sampler(data)] | |||
| assert len(ans) == len(data) | |||
| for d in ans: | |||
| assert d in data | |||
| if __name__ == "__main__": | |||
| test_sequential_sampler() | |||
| @@ -0,0 +1,15 @@ | |||
| 1 I _ PRP PRP _ 2 SUB | |||
| 2 solved _ VBD VBD _ 0 ROOT | |||
| 3 the _ DT DT _ 4 NMOD | |||
| 4 problem _ NN NN _ 2 OBJ | |||
| 5 with _ IN IN _ 2 VMOD | |||
| 6 statistics _ NNS NNS _ 5 PMOD | |||
| 7 . _ . . _ 2 P | |||
| 1 I _ PRP PRP _ 2 SUB | |||
| 2 solved _ VBD VBD _ 0 ROOT | |||
| 3 the _ DT DT _ 4 NMOD | |||
| 4 problem _ NN NN _ 2 OBJ | |||
| 5 with _ IN IN _ 2 VMOD | |||
| 6 statistics _ NNS NNS _ 5 PMOD | |||
| 7 . _ . . _ 2 P | |||
| @@ -0,0 +1,27 @@ | |||
| 19980101-01-001-001/m 迈向/v 充满/v 希望/n 的/u 新/a 世纪/n ——/w 一九九八年/t 新年/t 讲话/n (/w 附/v 图片/n 1/m 张/q )/w | |||
| 19980101-01-001-002/m 中共中央/nt 总书记/n 、/w 国家/n 主席/n 江/nr 泽民/nr | |||
| 19980101-01-001-003/m (/w 一九九七年/t 十二月/t 三十一日/t )/w | |||
| 19980101-01-001-004/m 12月/t 31日/t ,/w 中共中央/nt 总书记/n 、/w 国家/n 主席/n 江/nr 泽民/nr 发表/v 1998年/t 新年/t 讲话/n 《/w 迈向/v 充满/v 希望/n 的/u 新/a 世纪/n 》/w 。/w (/w 新华社/nt 记者/n 兰/nr 红光/nr 摄/Vg )/w | |||
| 19980101-01-001-005/m 同胞/n 们/k 、/w 朋友/n 们/k 、/w 女士/n 们/k 、/w 先生/n 们/k :/w | |||
| 19980101-01-001-006/m 在/p 1998年/t 来临/v 之际/f ,/w 我/r 十分/m 高兴/a 地/u 通过/p [中央/n 人民/n 广播/vn 电台/n]nt 、/w [中国/ns 国际/n 广播/vn 电台/n]nt 和/c [中央/n 电视台/n]nt ,/w 向/p 全国/n 各族/r 人民/n ,/w 向/p [香港/ns 特别/a 行政区/n]ns 同胞/n 、/w 澳门/ns 和/c 台湾/ns 同胞/n 、/w 海外/s 侨胞/n ,/w 向/p 世界/n 各国/r 的/u 朋友/n 们/k ,/w 致以/v 诚挚/a 的/u 问候/vn 和/c 良好/a 的/u 祝愿/vn !/w | |||
| 19980101-01-001-007/m 1997年/t ,/w 是/v 中国/ns 发展/vn 历史/n 上/f 非常/d 重要/a 的/u 很/d 不/d 平凡/a 的/u 一/m 年/q 。/w 中国/ns 人民/n 决心/d 继承/v 邓/nr 小平/nr 同志/n 的/u 遗志/n ,/w 继续/v 把/p 建设/v 有/v 中国/ns 特色/n 社会主义/n 事业/n 推向/v 前进/v 。/w [中国/ns 政府/n]nt 顺利/ad 恢复/v 对/p 香港/ns 行使/v 主权/n ,/w 并/c 按照/p “/w 一国两制/j ”/w 、/w “/w 港人治港/l ”/w 、/w 高度/d 自治/v 的/u 方针/n 保持/v 香港/ns 的/u 繁荣/an 稳定/an 。/w [中国/ns 共产党/n]nt 成功/a 地/u 召开/v 了/u 第十五/m 次/q 全国/n 代表大会/n ,/w 高举/v 邓小平理论/n 伟大/a 旗帜/n ,/w 总结/v 百年/m 历史/n ,/w 展望/v 新/a 的/u 世纪/n ,/w 制定/v 了/u 中国/ns 跨/v 世纪/n 发展/v 的/u 行动/vn 纲领/n 。/w | |||
| 19980101-01-001-008/m 在/p 这/r 一/m 年/q 中/f ,/w 中国/ns 的/u 改革/vn 开放/vn 和/c 现代化/vn 建设/vn 继续/v 向前/v 迈进/v 。/w 国民经济/n 保持/v 了/u “/w 高/a 增长/vn 、/w 低/a 通胀/j ”/w 的/u 良好/a 发展/vn 态势/n 。/w 农业/n 生产/vn 再次/d 获得/v 好/a 的/u 收成/n ,/w 企业/n 改革/vn 继续/v 深化/v ,/w 人民/n 生活/vn 进一步/d 改善/v 。/w 对外/vn 经济/n 技术/n 合作/vn 与/c 交流/vn 不断/d 扩大/v 。/w 民主/a 法制/n 建设/vn 、/w 精神文明/n 建设/vn 和/c 其他/r 各项/r 事业/n 都/d 有/v 新/a 的/u 进展/vn 。/w 我们/r 十分/m 关注/v 最近/t 一个/m 时期/n 一些/m 国家/n 和/c 地区/n 发生/v 的/u 金融/n 风波/n ,/w 我们/r 相信/v 通过/p 这些/r 国家/n 和/c 地区/n 的/u 努力/an 以及/c 有关/v 的/u 国际/n 合作/vn ,/w 情况/n 会/v 逐步/d 得到/v 缓解/vn 。/w 总的来说/c ,/w 中国/ns 改革/v 和/c 发展/v 的/u 全局/n 继续/v 保持/v 了/u 稳定/an 。/w | |||
| 19980101-01-001-009/m 在/p 这/r 一/m 年/q 中/f ,/w 中国/ns 的/u 外交/n 工作/vn 取得/v 了/u 重要/a 成果/n 。/w 通过/p 高层/n 互访/v ,/w 中国/ns 与/p 美国/ns 、/w 俄罗斯/ns 、/w 法国/ns 、/w 日本/ns 等/u 大国/n 确定/v 了/u 双方/n 关系/n 未来/t 发展/v 的/u 目标/n 和/c 指导/vn 方针/n 。/w 中国/ns 与/p 周边/n 国家/n 和/c 广大/b 发展中国家/l 的/u 友好/a 合作/vn 进一步/d 加强/v 。/w 中国/ns 积极/ad 参与/v [亚/j 太/j 经合/j 组织/n]nt 的/u 活动/vn ,/w 参加/v 了/u 东盟/ns —/w 中/j 日/j 韩/j 和/c 中国/ns —/w 东盟/ns 首脑/n 非正式/b 会晤/vn 。/w 这些/r 外交/n 活动/vn ,/w 符合/v 和平/n 与/c 发展/v 的/u 时代/n 主题/n ,/w 顺应/v 世界/n 走向/v 多极化/v 的/u 趋势/n ,/w 对于/p 促进/v 国际/n 社会/n 的/u 友好/a 合作/vn 和/c 共同/b 发展/vn 作出/v 了/u 积极/a 的/u 贡献/n 。/w | |||
| 19980101-01-001-010/m 1998年/t ,/w 中国/ns 人民/n 将/d 满怀信心/l 地/u 开创/v 新/a 的/u 业绩/n 。/w 尽管/c 我们/r 在/p 经济/n 社会/n 发展/v 中/f 还/d 面临/v 不少/m 困难/an ,/w 但/c 我们/r 有/v 邓小平理论/n 的/u 指引/vn ,/w 有/v 改革/v 开放/v 近/a 20/m 年/q 来/f 取得/v 的/u 伟大/a 成就/n 和/c 积累/v 的/u 丰富/a 经验/n ,/w 还/d 有/v 其他/r 的/u 各种/r 有利/a 条件/n ,/w 我们/r 一定/d 能够/v 克服/v 这些/r 困难/an ,/w 继续/v 稳步前进/l 。/w 只要/c 我们/r 进一步/d 解放思想/i ,/w 实事求是/i ,/w 抓住/v 机遇/n ,/w 开拓进取/l ,/w 建设/v 有/v 中国/ns 特色/n 社会主义/n 的/u 道路/n 就/c 会/v 越/d 走/v 越/d 宽广/a 。/w | |||
| 19980101-01-001-011/m 实现/v 祖国/n 的/u 完全/a 统一/vn ,/w 是/v 海内外/s 全体/n 中国/ns 人/n 的/u 共同/b 心愿/n 。/w 通过/p 中/j 葡/j 双方/n 的/u 合作/vn 和/c 努力/an ,/w 按照/p “/w 一国两制/j ”/w 方针/n 和/c 澳门/ns 《/w 基本法/n 》/w ,/w 1999年/t 12月/t 澳门/ns 的/u 回归/vn 一定/d 能够/v 顺利/ad 实现/v 。/w | |||
| 19980101-01-001-012/m 台湾/ns 是/v 中国/ns 领土/n 不可分割/l 的/u 一/m 部分/n 。/w 完成/v 祖国/n 统一/vn ,/w 是/v 大势所趋/i ,/w 民心所向/l 。/w 任何/r 企图/v 制造/v “/w 两/m 个/q 中国/ns ”/w 、/w “/w 一中一台/j ”/w 、/w “/w 台湾/ns 独立/v ”/w 的/u 图谋/n ,/w 都/d 注定/v 要/v 失败/v 。/w 希望/v 台湾/ns 当局/n 以/p 民族/n 大义/n 为重/v ,/w 拿/v 出/v 诚意/n ,/w 采取/v 实际/a 的/u 行动/vn ,/w 推动/v 两岸/n 经济/n 文化/n 交流/vn 和/c 人员/n 往来/vn ,/w 促进/v 两岸/n 直接/ad 通邮/v 、/w 通航/v 、/w 通商/v 的/u 早日/d 实现/v ,/w 并/c 尽早/d 回应/v 我们/r 发出/v 的/u 在/p 一个/m 中国/ns 的/u 原则/n 下/f 两岸/n 进行/v 谈判/vn 的/u 郑重/a 呼吁/vn 。/w | |||
| 19980101-01-001-013/m 环顾/v 全球/n ,/w 日益/d 密切/a 的/u 世界/n 经济/n 联系/vn ,/w 日新月异/i 的/u 科技/n 进步/vn ,/w 正在/d 为/p 各国/r 经济/n 的/u 发展/vn 提供/v 历史/n 机遇/n 。/w 但是/c ,/w 世界/n 还/d 不/d 安宁/a 。/w 南北/f 之间/f 的/u 贫富/n 差距/n 继续/v 扩大/v ;/w 局部/n 冲突/vn 时有发生/l ;/w 不/d 公正/a 不/d 合理/a 的/u 旧/a 的/u 国际/n 政治/n 经济/n 秩序/n 还/d 没有/v 根本/a 改变/vn ;/w 发展中国家/l 在/p 激烈/a 的/u 国际/n 经济/n 竞争/vn 中/f 仍/d 处于/v 弱势/n 地位/n ;/w 人类/n 的/u 生存/vn 与/c 发展/vn 还/d 面临/v 种种/q 威胁/vn 和/c 挑战/vn 。/w 和平/n 与/c 发展/vn 的/u 前景/n 是/v 光明/a 的/u ,/w 21/m 世纪/n 将/d 是/v 充满/v 希望/n 的/u 世纪/n 。/w 但/c 前进/v 的/u 道路/n 不/d 会/v 也/d 不/d 可能/v 一帆风顺/i ,/w 关键/n 是/v 世界/n 各国/r 人民/n 要/v 进一步/d 团结/a 起来/v ,/w 共同/d 推动/v 早日/d 建立/v 公正/a 合理/a 的/u 国际/n 政治/n 经济/n 新/a 秩序/n 。/w | |||
| 19980101-01-001-014/m [中国/ns 政府/n]nt 将/d 继续/v 坚持/v 奉行/v 独立自主/i 的/u 和平/n 外交/n 政策/n ,/w 在/p 和平共处/l 五/m 项/q 原则/n 的/u 基础/n 上/f 努力/ad 发展/v 同/p 世界/n 各国/r 的/u 友好/a 关系/n 。/w 中国/ns 愿意/v 加强/v 同/p 联合国/nt 和/c 其他/r 国际/n 组织/n 的/u 协调/vn ,/w 促进/v 在/p 扩大/v 经贸/j 科技/n 交流/vn 、/w 保护/v 环境/n 、/w 消除/v 贫困/an 、/w 打击/v 国际/n 犯罪/vn 等/u 方面/n 的/u 国际/n 合作/vn 。/w 中国/ns 永远/d 是/v 维护/v 世界/n 和平/n 与/c 稳定/an 的/u 重要/a 力量/n 。/w 中国/ns 人民/n 愿/v 与/p 世界/n 各国/r 人民/n 一道/d ,/w 为/p 开创/v 持久/a 和平/n 、/w 共同/d 发展/v 的/u 新/a 世纪/n 而/c 不懈努力/l !/w | |||
| 19980101-01-001-015/m 在/p 这/r 辞旧迎新/l 的/u 美好/a 时刻/n ,/w 我/r 祝/v 大家/r 新年/t 快乐/a ,/w 家庭/n 幸福/a !/w | |||
| 19980101-01-001-016/m 谢谢/v !/w (/w 新华社/nt 北京/ns 12月/t 31日/t 电/n )/w | |||
| 19980101-01-002-001/m 在/p 十五大/j 精神/n 指引/vn 下/f 胜利/vd 前进/v ——/w 元旦/t 献辞/n | |||
| 19980101-01-002-002/m 我们/r 即将/d 以/p 丰收/vn 的/u 喜悦/an 送/v 走/v 牛年/t ,/w 以/p 昂扬/a 的/u 斗志/n 迎来/v 虎年/t 。/w 我们/r 伟大/a 祖国/n 在/p 新/a 的/u 一/m 年/q ,/w 将/d 是/v 充满/v 生机/n 、/w 充满/v 希望/n 的/u 一/m 年/q 。/w | |||
| 19980101-01-002-003/m 刚刚/d 过去/v 的/u 一/m 年/q ,/w 大气磅礴/i ,/w 波澜壮阔/i 。/w 在/p 这/r 一/m 年/q ,/w 以/p 江/nr 泽民/nr 同志/n 为/v 核心/n 的/u 党中央/nt ,/w 继承/v 邓/nr 小平/nr 同志/n 的/u 遗志/n ,/w 高举/v 邓小平理论/n 的/u 伟大/a 旗帜/n ,/w 领导/v 全党/n 和/c 全国/n 各族/r 人民/n 坚定不移/i 地/u 沿着/p 建设/v 有/v 中国/ns 特色/n 社会主义/n 道路/n 阔步/d 前进/v ,/w 写/v 下/v 了/u 改革/v 开放/v 和/c 社会主义/n 现代化/vn 建设/vn 的/u 辉煌/a 篇章/n 。/w 顺利/a 地/u 恢复/v 对/p 香港/ns 行使/v 主权/n ,/w 胜利/v 地/u 召开/v 党/n 的/u 第十五/m 次/q 全国/n 代表大会/n ———/w 两/m 件/q 大事/n 办/v 得/u 圆满/a 成功/a 。/w 国民经济/n 稳中求进/l ,/w 国家/n 经济/n 实力/n 进一步/d 增强/v ,/w 人民/n 生活/vn 继续/v 改善/v ,/w 对外/vn 经济/n 技术/n 交流/vn 日益/d 扩大/v 。/w 在/p 国际/n 金融/n 危机/n 的/u 风浪/n 波及/v 许多/m 国家/n 的/u 情况/n 下/f ,/w 我国/n 保持/v 了/u 金融/n 形势/n 和/c 整个/b 经济/n 形势/n 的/u 稳定/a 发展/vn 。/w 社会主义/n 精神文明/n 建设/vn 和/c 民主/a 法制/n 建设/vn 取得/v 新/a 的/u 成绩/n ,/w 各项/r 社会/n 事业/n 全面/ad 进步/v 。/w 外交/n 工作/vn 取得/v 可喜/a 的/u 突破/vn ,/w 我国/n 的/u 国际/n 地位/n 和/c 国际/n 威望/n 进一步/d 提高/v 。/w 实践/v 使/v 亿万/m 人民/n 对/p 邓小平理论/n 更加/d 信仰/v ,/w 对/p 以/p 江/nr 泽民/nr 同志/n 为/v 核心/n 的/u 党中央/nt 更加/d 信赖/v ,/w 对/p 伟大/a 祖国/n 的/u 光辉/n 前景/n 更加/d 充满/v 信心/n 。/w | |||
| 19980101-01-002-004/m 1998年/t ,/w 是/v 全面/ad 贯彻/v 落实/v 党/n 的/u 十五大/j 提出/v 的/u 任务/n 的/u 第一/m 年/q ,/w 各/r 条/q 战线/n 改革/v 和/c 发展/v 的/u 任务/n 都/d 十分/m 繁重/a ,/w 有/v 许多/m 深/a 层次/n 的/u 矛盾/an 和/c 问题/n 有待/v 克服/v 和/c 解决/v ,/w 特别/d 是/v 国有/vn 企业/n 改革/vn 已经/d 进入/v 攻坚/vn 阶段/n 。/w 我们/r 必须/d 进一步/d 深入/ad 学习/v 和/c 掌握/v 党/n 的/u 十五大/j 精神/n ,/w 统揽全局/l ,/w 精心/ad 部署/v ,/w 狠抓/v 落实/v ,/w 团结/a 一致/a ,/w 艰苦奋斗/i ,/w 开拓/v 前进/v ,/w 为/p 夺取/v 今年/t 改革/v 开放/v 和/c 社会主义/n 现代化/vn 建设/vn 的/u 新/a 胜利/vn 而/c 奋斗/v 。/w | |||
| 19980101-01-002-005/m 今年/t 是/v 党/n 的/u 十一/m 届/q 三中全会/j 召开/v 20/m 周年/q ,/w 是/v 我们/r 党/n 和/c 国家/n 实现/v 伟大/a 的/u 历史/n 转折/vn 、/w 进入/v 改革/vn 开放/vn 历史/n 新/a 时期/n 的/u 20/m 周年/q 。/w 在/p 新/a 的/u 一/m 年/q 里/f ,/w 大力/d 发扬/v 十一/m 届/q 三中全会/j 以来/f 我们/r 党/n 所/u 恢复/v 的/u 优良/z 传统/n 和/c 在/p 新/a 的/u 历史/n 条件/n 下/f 形成/v 的/u 优良/z 作风/n ,/w 对于/p 完成/v 好/a 今年/t 的/u 各项/r 任务/n 具有/v 十分/m 重要/a 的/u 意义/n 。/w | |||
| 19980101-01-002-006/m 我们/r 要/v 更/d 好/a 地/u 坚持/v 解放思想/i 、/w 实事求是/i 的/u 思想/n 路线/n 。/w 解放思想/i 、/w 实事求是/i ,/w 是/v 邓小平理论/n 的/u 精髓/n 。/w 实践/v 证明/v ,/w 只有/c 解放思想/i 、/w 实事求是/i ,/w 才/c 能/v 冲破/v 各种/r 不/d 切合/v 实际/n 的/u 或者/c 过时/a 的/u 观念/n 的/u 束缚/vn ,/w 真正/d 做到/v 尊重/v 、/w 认识/v 和/c 掌握/v 客观/a 规律/n ,/w 勇于/v 突破/v ,/w 勇于/v 创新/v ,/w 不断/d 开创/v 社会主义/n 现代化/vn 建设/vn 的/u 新/a 局面/n 。/w 党/n 的/u 十五大/j 是/v 我们/r 党/n 解放思想/i 、/w 实事求是/i 的/u 新/a 的/u 里程碑/n 。/w 进一步/d 认真/ad 学习/v 和/c 掌握/v 十五大/j 精神/n ,/w 解放思想/i 、/w 实事求是/i ,/w 我们/r 的/u 各项/r 事业/n 就/d 能/v 结/v 出/v 更加/d 丰硕/a 的/u 成果/n 。/w | |||
| 19980101-01-002-007/m 我们/r 要/v 更/d 好/a 地/u 坚持/v 以/p 经济/n 建设/vn 为/v 中心/n 。/w 各项/r 工作/vn 必须/d 以/p 经济/n 建设/vn 为/v 中心/n ,/w 是/v 邓小平理论/n 的/u 基本/a 观点/n ,/w 是/v 党/n 的/u 基本/a 路线/n 的/u 核心/n 内容/n ,/w 近/a 20/m 年/q 来/f 的/u 实践/vn 证明/v ,/w 坚持/v 这个/r 中心/n ,/w 是/v 完全/ad 正确/a 的/u 。/w 今后/t ,/w 我们/r 能否/v 把/p 建设/v 有/v 中国/ns 特色/n 社会主义/n 伟大/a 事业/n 全面/ad 推向/v 21/m 世纪/n ,/w 关键/n 仍然/d 要/v 看/v 能否/v 把/p 经济/n 工作/vn 搞/v 上去/v 。/w 各级/r 领导/n 干部/n 要/v 切实/ad 把/p 精力/n 集中/v 到/v 贯彻/v 落实/v 好/a 中央/n 关于/p 今年/t 经济/n 工作/vn 的/u 总体/n 要求/n 和/c 各项/r 重要/a 任务/n 上/f 来/v ,/w 不断/d 提高/v 领导/v 经济/n 建设/vn 的/u 能力/n 和/c 水平/n 。/w | |||
| 19980101-01-002-008/m 我们/r 要/v 更/d 好/a 地/u 坚持/v “/w 两手抓/l 、/w 两手/m 都/d 要/v 硬/a ”/w 的/u 方针/n 。/w 在/p 坚持/v 以/p 经济/n 建设/vn 为/v 中心/n 的/u 同时/n ,/w 积极/ad 推进/v 社会主义/n 精神文明/n 建设/vn 和/c 民主/a 法制/n 建设/vn ,/w 是/v 建设/v 富强/a 、/w 民主/a 、/w 文明/a 的/u 社会主义/n 现代化/vn 国家/n 的/u 重要/a 内容/n 。/w 实践/v 证明/v ,/w 经济/n 建设/vn 的/u 顺利/a 进行/vn ,/w 离/v 不/d 开/v 精神文明/n 建设/vn 和/c 民主/a 法制/n 建设/vn 的/u 保证/vn 。/w 党/n 的/u 十五大/j 依据/p 邓小平理论/n 和/c 党/n 的/u 基本/a 路线/n 提出/v 的/u 党/n 在/p 社会主义/n 初级/b 阶段/n 经济/n 、/w 政治/n 、/w 文化/n 的/u 基本/a 纲领/n ,/w 为/p “/w 两手抓/l 、/w 两手/m 都/d 要/v 硬/a ”/w 提供/v 了/u 新/a 的/u 理论/n 根据/n ,/w 提出/v 了/u 更/d 高/a 要求/n ,/w 现在/t 的/u 关键/n 是/v 认真/ad 抓好/v 落实/v 。/w | |||
| 19980101-01-002-009/m 我们/r 要/v 更/d 好/a 地/u 发扬/v 求真务实/l 、/w 密切/ad 联系/v 群众/n 的/u 作风/n 。/w 这/r 是/v 把/p 党/n 的/u 方针/n 、/w 政策/n 落到实处/l ,/w 使/v 改革/v 和/c 建设/v 取得/v 胜利/vn 的/u 重要/a 保证/vn 。/w 在/p 当前/t 改革/v 进一步/d 深化/v ,/w 经济/n 不断/d 发展/v ,/w 同时/c 又/d 出现/v 一些/m 新/a 情况/n 、/w 新/a 问题/n 和/c 新/a 困难/an 的/u 形势/n 下/f ,/w 更/d 要/v 发扬/v 这样/r 的/u 好/a 作风/n 。/w 要/v 尊重/v 群众/n 的/u 意愿/n ,/w 重视/v 群众/n 的/u 首创/vn 精神/n ,/w 关心/v 群众/n 的/u 生活/vn 疾苦/n 。/w 江/nr 泽民/nr 同志/n 最近/t 强调/vd 指出/v ,/w 要/v 大力/d 倡导/v 说实话/l 、/w 办/v 实事/n 、/w 鼓/v 实劲/n 、/w 讲/v 实效/n 的/u 作风/n ,/w 坚决/ad 制止/v 追求/v 表面文章/i ,/w 搞/v 花架子/n 等/u 形式主义/n ,/w 坚决/ad 杜绝/v 脱离/v 群众/n 、/w 脱离/v 实际/n 、/w 浮躁/a 虚夸/v 等/u 官僚主义/n 。/w 这/r 是/v 非常/d 重要/a 的/u 。/w 因此/c ,/w 各级/r 领导/n 干部/n 务必/d 牢记/v 全心全意/i 为/p 人民/n 服务/v 的/u 宗旨/n ,/w 在/p 勤政廉政/l 、/w 艰苦奋斗/i 方面/n 以身作则/i ,/w 当/v 好/a 表率/n 。/w | |||
| 19980101-01-002-010/m 1998/m ,/w 瞩目/v 中华/nz 。/w 新/a 的/u 机遇/n 和/c 挑战/vn ,/w 催/v 人/n 进取/v ;/w 新/a 的/u 目标/n 和/c 征途/n ,/w 催/v 人/n 奋发/v 。/w 英雄/n 的/u 中国/ns 人民/n 在/p 以/p 江/nr 泽民/nr 同志/n 为/v 核心/n 的/u 党中央/nt 坚强/a 领导/vn 和/c 党/n 的/u 十五大/j 精神/n 指引/v 下/f ,/w 更/d 高/a 地/u 举起/v 邓小平理论/n 的/u 伟大/a 旗帜/n ,/w 团结/a 一致/a ,/w 扎实/ad 工作/v ,/w 奋勇/d 前进/v ,/w 一定/d 能够/v 创造/v 出/v 更加/d 辉煌/a 的/u 业绩/n !/w | |||
| @@ -4,7 +4,6 @@ import os | |||
| import unittest | |||
| from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader | |||
| class TestConfigLoader(unittest.TestCase): | |||
| @@ -52,21 +51,3 @@ class TestConfigLoader(unittest.TestCase): | |||
| print("pass config test!") | |||
| class TestDatasetLoader(unittest.TestCase): | |||
| def test_case_TokenizeDatasetLoader(self): | |||
| loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||
| data = loader.load_pku(max_seq_len=32) | |||
| print("pass TokenizeDatasetLoader test!") | |||
| def test_case_POSDatasetLoader(self): | |||
| loader = POSDatasetLoader("./test/data_for_tests/people.txt") | |||
| data = loader.load() | |||
| datas = loader.load_lines() | |||
| print("pass POSDatasetLoader test!") | |||
| def test_case_LMDatasetLoader(self): | |||
| loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||
| data = loader.load() | |||
| datas = loader.load_lines() | |||
| print("pass TokenizeDatasetLoader test!") | |||
| @@ -0,0 +1,42 @@ | |||
| import unittest | |||
| from fastNLP.loader.dataset_loader import POSDatasetLoader, LMDatasetLoader, TokenizeDatasetLoader, \ | |||
| PeopleDailyCorpusLoader, ConllLoader | |||
| class TestDatasetLoader(unittest.TestCase): | |||
| def test_case_1(self): | |||
| data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | |||
| lines = data.split("\n") | |||
| answer = POSDatasetLoader.parse(lines) | |||
| truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | |||
| self.assertListEqual(answer, truth, "POS Dataset Loader") | |||
| def test_case_TokenizeDatasetLoader(self): | |||
| loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||
| data = loader.load_pku(max_seq_len=32) | |||
| print("pass TokenizeDatasetLoader test!") | |||
| def test_case_POSDatasetLoader(self): | |||
| loader = POSDatasetLoader("./test/data_for_tests/people.txt") | |||
| data = loader.load() | |||
| datas = loader.load_lines() | |||
| print("pass POSDatasetLoader test!") | |||
| def test_case_LMDatasetLoader(self): | |||
| loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||
| data = loader.load() | |||
| datas = loader.load_lines() | |||
| print("pass TokenizeDatasetLoader test!") | |||
| def test_PeopleDailyCorpusLoader(self): | |||
| loader = PeopleDailyCorpusLoader("./test/data_for_tests/people_daily_raw.txt") | |||
| _, _ = loader.load() | |||
| def test_ConllLoader(self): | |||
| loader = ConllLoader("./test/data_for_tests/conll_example.txt") | |||
| _ = loader.load() | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -1,24 +0,0 @@ | |||
| import unittest | |||
| from fastNLP.loader.dataset_loader import POSDatasetLoader | |||
| class TestPreprocess(unittest.TestCase): | |||
| def test_case_1(self): | |||
| data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], | |||
| ["Hello", "world", "!"], ["T", "F", "F"]] | |||
| pickle_path = "./data_for_tests/" | |||
| # POSPreprocess(data, pickle_path) | |||
| class TestDatasetLoader(unittest.TestCase): | |||
| def test_case_1(self): | |||
| data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | |||
| lines = data.split("\n") | |||
| answer = POSDatasetLoader.parse(lines) | |||
| truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | |||
| self.assertListEqual(answer, truth, "POS Dataset Loader") | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||
| @@ -1,28 +1,25 @@ | |||
| import sys | |||
| import os | |||
| sys.path.append("..") | |||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.core.predictor import Predictor | |||
| from fastNLP.core.preprocess import Preprocessor, load_pickle | |||
| from fastNLP.core.tester import SeqLabelTester | |||
| from fastNLP.core.trainer import SeqLabelTrainer | |||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||
| from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||
| from fastNLP.saver.model_saver import ModelSaver | |||
| from fastNLP.loader.model_loader import ModelLoader | |||
| from fastNLP.core.tester import SeqLabelTester | |||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||
| from fastNLP.core.predictor import Predictor | |||
| from fastNLP.saver.model_saver import ModelSaver | |||
| data_name = "pku_training.utf8" | |||
| # cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8" | |||
| cws_data_path = "data_for_tests/cws_pku_utf_8" | |||
| pickle_path = "data_for_tests" | |||
| data_infer_path = "data_for_tests/people_infer.txt" | |||
| cws_data_path = "test/data_for_tests/cws_pku_utf_8" | |||
| pickle_path = "./save/" | |||
| data_infer_path = "test/data_for_tests/people_infer.txt" | |||
| config_path = "test/data_for_tests/config" | |||
| def infer(): | |||
| # Load inference configuration (same section layout as the test configuration) | |||
| test_args = ConfigSection() | |||
| ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
| ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args}) | |||
| # fetch dictionary size and number of labels from pickle files | |||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||
| @@ -34,41 +31,31 @@ def infer(): | |||
| model = SeqLabeling(test_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
| print("model loaded!") | |||
| # Data Loader | |||
| raw_data_loader = BaseLoader(data_infer_path) | |||
| infer_data = raw_data_loader.load_lines() | |||
| """ | |||
| Transform strings into list of list of strings. | |||
| [ | |||
| [word_11, word_12, ...], | |||
| [word_21, word_22, ...], | |||
| ... | |||
| ] | |||
| In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them. | |||
| """ | |||
| # Inference interface | |||
| infer = Predictor(pickle_path) | |||
| infer = Predictor(pickle_path, "seq_label") | |||
| results = infer.predict(model, infer_data) | |||
| print(results) | |||
| print("Inference finished!") | |||
| def train_test(): | |||
| # Config Loader | |||
| train_args = ConfigSection() | |||
| ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args}) | |||
| ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": train_args}) | |||
| # Data Loader | |||
| loader = TokenizeDatasetLoader(cws_data_path) | |||
| train_data = loader.load_pku() | |||
| # Preprocessor | |||
| p = SeqLabelPreprocess() | |||
| p = Preprocessor(label_is_seq=True) | |||
| data_train = p.run(train_data, pickle_path=pickle_path) | |||
| train_args["vocab_size"] = p.vocab_size | |||
| train_args["num_classes"] = p.num_classes | |||
| @@ -81,12 +68,10 @@ def train_test(): | |||
| # Start training | |||
| trainer.train(model, data_train) | |||
| print("Training finished!") | |||
| # Saver | |||
| saver = ModelSaver("./data_for_tests/saved_model.pkl") | |||
| saver = ModelSaver("./save/saved_model.pkl") | |||
| saver.save_pytorch(model) | |||
| print("Model saved!") | |||
| del model, trainer, loader | |||
| @@ -94,12 +79,11 @@ def train_test(): | |||
| model = SeqLabeling(train_args) | |||
| # Dump trained parameters into the model | |||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||
| print("model loaded!") | |||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||
| # Load test configuration | |||
| test_args = ConfigSection() | |||
| ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
| ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args}) | |||
| # Tester | |||
| tester = SeqLabelTester(**test_args.data) | |||
| @@ -109,7 +93,13 @@ def train_test(): | |||
| # print test results | |||
| print(tester.show_metrics()) | |||
| print("model tested!") | |||
| def test(): | |||
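| # End-to-end smoke test: train and evaluate a small sequence labeling model, then run inference with it; | |||
| # artifacts are written to ./save/ and removed afterwards. | |||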
| os.makedirs("save", exist_ok=True) | |||
| train_test() | |||
| infer() | |||
| os.system("rm -rf save") | |||
| if __name__ == "__main__": | |||
| @@ -1,7 +1,6 @@ | |||
| import unittest | |||
| import torch | |||
| import unittest | |||
| from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear | |||
| @@ -1,18 +1,9 @@ | |||
| import torch | |||
| import numpy as np | |||
| import unittest | |||
| import fastNLP.modules.utils as utils | |||
| class TestUtils(unittest.TestCase): | |||
| def test_case_1(self): | |||
| a = torch.tensor([ | |||
| [1, 2, 3, 4, 5], [2, 3, 4, 5, 6] | |||
| ]) | |||
| utils.orthogonal(a) | |||
| pass | |||
| def test_case_2(self): | |||
| a = np.random.rand(100, 100) | |||
| utils.mst(a) | |||
| pass | |||
| @@ -1,16 +1,32 @@ | |||
| import sys | |||
| # encoding: utf-8 | |||
| import os | |||
| sys.path.append("..") | |||
| from fastNLP.core.preprocess import save_pickle | |||
| from fastNLP.core.vocabulary import Vocabulary | |||
| from fastNLP.fastnlp import FastNLP | |||
| from fastNLP.fastnlp import interpret_word_seg_results, interpret_cws_pos_results | |||
| from fastNLP.models.cnn_text_classification import CNNText | |||
| from fastNLP.models.sequence_modeling import AdvSeqLabel | |||
| from fastNLP.saver.model_saver import ModelSaver | |||
| PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/" | |||
| PATH_TO_POS_TAG_PICKLE_FILES = "/home/zyfeng/data/crf_seg/" | |||
| PATH_TO_TEXT_CLASSIFICATION_PICKLE_FILES = "/home/zyfeng/data/text_classify/" | |||
| def word_seg(): | |||
| nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES) | |||
| nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test") | |||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
| DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||
| '<reserved-3>', | |||
| '<reserved-4>'] # dict index = 2~4 | |||
| DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||
| DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||
| DEFAULT_RESERVED_LABEL[2]: 4} | |||
| def word_seg(model_dir, config, section): | |||
| nlp = FastNLP(model_dir=model_dir) | |||
| nlp.load("cws_basic_model", config_file=config, section_name=section) | |||
| text = ["这是最好的基于深度学习的中文分词系统。", | |||
| "大王叫我来巡山。", | |||
| "我党多年来致力于改善人民生活水平。"] | |||
| @@ -24,38 +40,52 @@ def word_seg(): | |||
| print(interpret_word_seg_results(words, labels)) | |||
| def text_class(): | |||
| nlp = FastNLP("./data_for_tests/") | |||
| nlp.load("text_class_model") | |||
| text = "这是最好的基于深度学习的中文分词系统。" | |||
| result = nlp.run(text) | |||
| print(result) | |||
| print("FastNLP finished!") | |||
| def mock_cws(): | |||
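| # Build the minimal artifacts word_seg() needs in model_dir: word2id.pkl, class2id.pkl, | |||
| # a config file with the model hyper-parameters, and an (untrained) saved model. | |||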
| os.makedirs("mock", exist_ok=True) | |||
| text = ["这是最好的基于深度学习的中文分词系统。", | |||
| "大王叫我来巡山。", | |||
| "我党多年来致力于改善人民生活水平。"] | |||
| word2id = Vocabulary() | |||
| word_list = [ch for ch in "".join(text)] | |||
| word2id.update(word_list) | |||
| save_pickle(word2id, "./mock/", "word2id.pkl") | |||
| def test_word_seg_interpret(): | |||
| foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'), | |||
| ('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'), | |||
| ('。', 'S')]] | |||
| chars = [x[0] for x in foo[0]] | |||
| labels = [x[1] for x in foo[0]] | |||
| print(interpret_word_seg_results(chars, labels)) | |||
| class2id = Vocabulary(need_default=False) | |||
| label_list = ['B', 'M', 'E', 'S'] | |||
| class2id.update(label_list) | |||
| save_pickle(class2id, "./mock/", "class2id.pkl") | |||
| model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)} | |||
| config_file = """ | |||
| [test_section] | |||
| vocab_size = {} | |||
| word_emb_dim = 50 | |||
| rnn_hidden_units = 50 | |||
| num_classes = {} | |||
| """.format(len(word2id), len(class2id)) | |||
| with open("mock/test.cfg", "w", encoding="utf-8") as f: | |||
| f.write(config_file) | |||
| def test_interpret_cws_pos_results(): | |||
| foo = [ | |||
| [('这', 'S-r'), ('是', 'S-v'), ('最', 'S-d'), ('好', 'S-a'), ('的', 'S-u'), ('基', 'B-p'), ('于', 'E-p'), ('深', 'B-d'), | |||
| ('度', 'E-d'), ('学', 'B-v'), ('习', 'E-v'), ('的', 'S-u'), ('中', 'B-nz'), ('文', 'E-nz'), ('分', 'B-vn'), | |||
| ('词', 'E-vn'), ('系', 'B-n'), ('统', 'E-n'), ('。', 'S-w')] | |||
| ] | |||
| chars = [x[0] for x in foo[0]] | |||
| labels = [x[1] for x in foo[0]] | |||
| print(interpret_cws_pos_results(chars, labels)) | |||
| model = AdvSeqLabel(model_args) | |||
| ModelSaver("mock/cws_basic_model_v_0.pkl").save_pytorch(model) | |||
| def test_word_seg(): | |||
| # fake the model and pickles | |||
| print("start mocking") | |||
| mock_cws() | |||
| # run the inference code | |||
| print("start testing") | |||
| word_seg("./mock/", "test.cfg", "test_section") | |||
| # clean up the test environment | |||
| print("clean up") | |||
| os.system("rm -rf mock") | |||
| def pos_tag(): | |||
| nlp = FastNLP(model_dir=PATH_TO_POS_TAG_PICKLE_FILES) | |||
| nlp.load("pos_tag_model", config_file="pos_tag.config", section_name="pos_tag_model") | |||
| def pos_tag(model_dir, config, section): | |||
| nlp = FastNLP(model_dir=model_dir) | |||
| nlp.load("pos_tag_model", config_file=config, section_name=section) | |||
| text = ["这是最好的基于深度学习的中文分词系统。", | |||
| "大王叫我来巡山。", | |||
| "我党多年来致力于改善人民生活水平。"] | |||
| @@ -65,21 +95,119 @@ def pos_tag(): | |||
| for res in example: | |||
| words.append(res[0]) | |||
| labels.append(res[1]) | |||
| print(interpret_cws_pos_results(words, labels)) | |||
| try: | |||
| print(interpret_cws_pos_results(words, labels)) | |||
| except RuntimeError: | |||
| print("inconsistent pos tags. this is for test only.") | |||
| def mock_pos_tag(): | |||
| os.makedirs("mock", exist_ok=True) | |||
| text = ["这是最好的基于深度学习的中文分词系统。", | |||
| "大王叫我来巡山。", | |||
| "我党多年来致力于改善人民生活水平。"] | |||
| vocab = Vocabulary() | |||
| word_list = [ch for ch in "".join(text)] | |||
| vocab.update(word_list) | |||
| save_pickle(vocab, "./mock/", "word2id.pkl") | |||
| idx2label = Vocabulary(need_default=False) | |||
| label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv'] | |||
| idx2label.update(label_list) | |||
| save_pickle(idx2label, "./mock/", "class2id.pkl") | |||
| model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | |||
| config_file = """ | |||
| [test_section] | |||
| vocab_size = {} | |||
| word_emb_dim = 50 | |||
| rnn_hidden_units = 50 | |||
| num_classes = {} | |||
| """.format(len(vocab), len(idx2label)) | |||
| with open("mock/test.cfg", "w", encoding="utf-8") as f: | |||
| f.write(config_file) | |||
| def text_classify(): | |||
| nlp = FastNLP(model_dir=PATH_TO_TEXT_CLASSIFICATION_PICKLE_FILES) | |||
| nlp.load("text_classify_model", config_file="text_classify.cfg", section_name="model") | |||
| model = AdvSeqLabel(model_args) | |||
| ModelSaver("mock/pos_tag_model_v_0.pkl").save_pytorch(model) | |||
| def test_pos_tag(): | |||
| mock_pos_tag() | |||
| pos_tag("./mock/", "test.cfg", "test_section") | |||
| os.system("rm -rf mock") | |||
| def text_classify(model_dir, config, section): | |||
| nlp = FastNLP(model_dir=model_dir) | |||
| nlp.load("text_classify_model", config_file=config, section_name=section) | |||
| text = [ | |||
| "世界物联网大会明日在京召开龙头股启动在即", | |||
| "乌鲁木齐市新增一处城市中心旅游目的地", | |||
| "朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"] | |||
| results = nlp.run(text) | |||
| print(results) | |||
| """ | |||
| ['finance', 'travel', 'history'] | |||
| """ | |||
| def mock_text_classify(): | |||
| os.makedirs("mock", exist_ok=True) | |||
| text = ["世界物联网大会明日在京召开龙头股启动在即", | |||
| "乌鲁木齐市新增一处城市中心旅游目的地", | |||
| "朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”" | |||
| ] | |||
| vocab = Vocabulary() | |||
| word_list = [ch for ch in "".join(text)] | |||
| vocab.update(word_list) | |||
| save_pickle(vocab, "./mock/", "word2id.pkl") | |||
| idx2label = Vocabulary(need_default=False) | |||
| label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F'] | |||
| idx2label.update(label_list) | |||
| save_pickle(idx2label, "./mock/", "class2id.pkl") | |||
| model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | |||
| config_file = """ | |||
| [test_section] | |||
| vocab_size = {} | |||
| word_emb_dim = 50 | |||
| rnn_hidden_units = 50 | |||
| num_classes = {} | |||
| """.format(len(vocab), len(idx2label)) | |||
| with open("mock/test.cfg", "w", encoding="utf-8") as f: | |||
| f.write(config_file) | |||
| model = CNNText(model_args) | |||
| ModelSaver("mock/text_class_model_v0.pkl").save_pytorch(model) | |||
| def test_text_classify(): | |||
| mock_text_classify() | |||
| text_classify("./mock/", "test.cfg", "test_section") | |||
| os.system("rm -rf mock") | |||
| def test_word_seg_interpret(): | |||
| foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'), | |||
| ('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'), | |||
| ('。', 'S')]] | |||
| chars = [x[0] for x in foo[0]] | |||
| labels = [x[1] for x in foo[0]] | |||
| print(interpret_word_seg_results(chars, labels)) | |||
| def test_interpret_cws_pos_results(): | |||
| foo = [ | |||
| [('这', 'S-r'), ('是', 'S-v'), ('最', 'S-d'), ('好', 'S-a'), ('的', 'S-u'), ('基', 'B-p'), ('于', 'E-p'), ('深', 'B-d'), | |||
| ('度', 'E-d'), ('学', 'B-v'), ('习', 'E-v'), ('的', 'S-u'), ('中', 'B-nz'), ('文', 'E-nz'), ('分', 'B-vn'), | |||
| ('词', 'E-vn'), ('系', 'B-n'), ('统', 'E-n'), ('。', 'S-w')] | |||
| ] | |||
| chars = [x[0] for x in foo[0]] | |||
| labels = [x[1] for x in foo[0]] | |||
| print(interpret_cws_pos_results(chars, labels)) | |||
| if __name__ == "__main__": | |||
| text_classify() | |||
| test_word_seg() | |||
| test_pos_tag() | |||
| test_text_classify() | |||
| test_word_seg_interpret() | |||
| test_interpret_cws_pos_results() | |||