| @@ -1,58 +1,92 @@ | |||||
| # FastNLP | # FastNLP | ||||
| ``` | ``` | ||||
| FastNLP | FastNLP | ||||
| │ LICENSE | |||||
| │ README.md | |||||
| │ requirements.txt | |||||
| │ setup.py | |||||
| ├── docs | |||||
| │ └── quick_tutorial.md | |||||
| ├── fastNLP | |||||
| │ ├── action | |||||
| │ │ ├── action.py | |||||
| │ │ ├── inference.py | |||||
| │ │ ├── __init__.py | |||||
| │ │ ├── metrics.py | |||||
| │ │ ├── optimizer.py | |||||
| │ │ ├── README.md | |||||
| │ │ ├── tester.py | |||||
| │ │ └── trainer.py | |||||
| │ ├── fastnlp.py | |||||
| │ ├── __init__.py | |||||
| │ ├── loader | |||||
| │ │ ├── base_loader.py | |||||
| │ │ ├── config_loader.py | |||||
| │ │ ├── dataset_loader.py | |||||
| │ │ ├── embed_loader.py | |||||
| │ │ ├── __init__.py | |||||
| │ │ ├── model_loader.py | |||||
| │ │ └── preprocess.py | |||||
| │ ├── models | |||||
| │ │ ├── base_model.py | |||||
| │ │ ├── char_language_model.py | |||||
| │ │ ├── cnn_text_classification.py | |||||
| │ │ ├── __init__.py | |||||
| │ │ └── sequence_modeling.py | |||||
| │ ├── modules | |||||
| │ │ ├── aggregation | |||||
| │ │ │ ├── attention.py | |||||
| │ │ │ ├── avg_pool.py | |||||
| │ │ │ ├── __init__.py | |||||
| │ │ │ ├── kmax_pool.py | |||||
| │ │ │ ├── max_pool.py | |||||
| │ │ │ └── self_attention.py | |||||
| │ │ ├── decoder | |||||
| │ │ │ ├── CRF.py | |||||
| │ │ │ └── __init__.py | |||||
| │ │ ├── encoder | |||||
| │ │ │ ├── char_embedding.py | |||||
| │ │ │ ├── conv_maxpool.py | |||||
| │ │ │ ├── conv.py | |||||
| │ │ │ ├── embedding.py | |||||
| │ │ │ ├── __init__.py | |||||
| │ │ │ ├── linear.py | |||||
| │ │ │ ├── lstm.py | |||||
| │ │ │ ├── masked_rnn.py | |||||
| │ │ │ └── variational_rnn.py | |||||
| │ │ ├── __init__.py | |||||
| │ │ ├── interaction | |||||
| │ │ │ └── __init__.py | |||||
| │ │ ├── other_modules.py | |||||
| │ │ └── utils.py | |||||
| │ └── saver | |||||
| │ ├── base_saver.py | |||||
| │ ├── __init__.py | |||||
| │ ├── logger.py | |||||
| │ └── model_saver.py | |||||
| ├── LICENSE | |||||
| ├── README.md | |||||
| ├── reproduction | |||||
| │ ├── Char-aware_NLM | |||||
| │ │ | |||||
| │ ├── CNN-sentence_classification | |||||
| │ │ | |||||
| │ ├── HAN-document_classification | |||||
| │ │ | |||||
| │ └── LSTM+self_attention_sentiment_analysis | |||||
| | | | | ||||
| ├─docs (documentation) | |||||
| | | |||||
| └─tests (unit tests, integration tests, system tests) | |||||
| | │ test_charlm.py | |||||
| | │ test_loader.py | |||||
| | │ test_trainer.py | |||||
| | │ test_word_seg.py | |||||
| | │ | |||||
| | └─data_for_tests (test data used by models) | |||||
| | charlm.txt | |||||
| | cws_test | |||||
| | cws_train | |||||
| | | |||||
| └─fastNLP | |||||
| ├─action (model independent process) | |||||
| │ │ action.py (base class) | |||||
| │ │ README.md | |||||
| │ │ tester.py (model testing, for deployment and validation) | |||||
| │ │ trainer.py (main logic for model training) | |||||
| │ │ __init__.py | |||||
| │ │ | |||||
| | | |||||
| │ | |||||
| ├─loader (file loader for all loading operations) | |||||
| │ | base_loader.py (base class) | |||||
| │ | config_loader.py (model-specific configuration/parameter loader) | |||||
| │ | dataset_loader.py (data set loader, base class) | |||||
| │ | embed_loader.py (embedding loader, base class) | |||||
| │ | __init__.py | |||||
| │ | |||||
| ├─model (definitions of PyTorch models) | |||||
| │ │ base_model.py (base class, abstract) | |||||
| │ │ char_language_model.py (derived class, to implement abstract methods) | |||||
| │ │ word_seg_model.py | |||||
| │ │ __init__.py | |||||
| │ │ | |||||
| │ | |||||
| ├─reproduction (code library for paper reproduction) | |||||
| │ ├─Char-aware_NLM | |||||
| │ │ | |||||
| │ ├─CNN-sentence_classification | |||||
| │ │ | |||||
| │ └─HAN-document_classification | |||||
| │ | |||||
| ├─saver (file saver for all saving operations) | |||||
| │ base_saver.py | |||||
| │ logger.py | |||||
| │ model_saver.py | |||||
| │ | |||||
| ├── requirements.txt | |||||
| ├── setup.py | |||||
| └── test | |||||
| ├── data_for_tests | |||||
| │ ├── charlm.txt | |||||
| │ ├── config | |||||
| │ ├── cws_test | |||||
| │ ├── cws_train | |||||
| │ ├── people_infer.txt | |||||
| │ └── people.txt | |||||
| ├── test_charlm.py | |||||
| ├── test_cws.py | |||||
| ├── test_fastNLP.py | |||||
| ├── test_loader.py | |||||
| ├── test_seq_labeling.py | |||||
| ├── test_tester.py | |||||
| └── test_trainer.py | |||||
| ``` | ``` | ||||
| @@ -1,71 +0,0 @@ | |||||
| import numpy as np | |||||
| class Action(object): | |||||
| """ | |||||
| base class for Trainer and Tester | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Action, self).__init__() | |||||
| class BaseSampler(object): | |||||
| """ | |||||
| Base class for all samplers. | |||||
| """ | |||||
| def __init__(self, data_set): | |||||
| self.data_set_length = len(data_set) | |||||
| def __len__(self): | |||||
| return self.data_set_length | |||||
| def __iter__(self): | |||||
| raise NotImplementedError | |||||
| class SequentialSampler(BaseSampler): | |||||
| """ | |||||
| Sample data in the original order. | |||||
| """ | |||||
| def __init__(self, data_set): | |||||
| super(SequentialSampler, self).__init__(data_set) | |||||
| def __iter__(self): | |||||
| return iter(range(self.data_set_length)) | |||||
| class RandomSampler(BaseSampler): | |||||
| """ | |||||
| Sample data in random permutation order. | |||||
| """ | |||||
| def __init__(self, data_set): | |||||
| super(RandomSampler, self).__init__(data_set) | |||||
| def __iter__(self): | |||||
| return iter(np.random.permutation(self.data_set_length)) | |||||
| class Batchifier(object): | |||||
| """ | |||||
| Wrap random or sequential sampler to generate a mini-batch. | |||||
| """ | |||||
| def __init__(self, sampler, batch_size, drop_last=True): | |||||
| super(Batchifier, self).__init__() | |||||
| self.sampler = sampler | |||||
| self.batch_size = batch_size | |||||
| self.drop_last = drop_last | |||||
| def __iter__(self): | |||||
| batch = [] | |||||
| for idx in self.sampler: | |||||
| batch.append(idx) | |||||
| if len(batch) == self.batch_size: | |||||
| yield batch | |||||
| batch = [] | |||||
| if 0 < len(batch) < self.batch_size and self.drop_last is False: | |||||
| yield batch | |||||
| @@ -1,26 +0,0 @@ | |||||
| class Inference(object): | |||||
| """ | |||||
| This is an interface focusing on predicting output based on trained models. | |||||
| It does not care about evaluations of the model. | |||||
| """ | |||||
| def __init__(self): | |||||
| pass | |||||
| def predict(self, model, data): | |||||
| """ | |||||
| This is actually a forward pass, which shall be shared by Trainer and Tester. | |||||
| :param model: | |||||
| :param data: | |||||
| :return result: the output results | |||||
| """ | |||||
| raise NotImplementedError | |||||
| def prepare_input(self, data_path): | |||||
| """ | |||||
| This can also be shared. | |||||
| :param data_path: | |||||
| :return: | |||||
| """ | |||||
| raise NotImplementedError | |||||
| @@ -0,0 +1,150 @@ | |||||
| from collections import Counter | |||||
| import numpy as np | |||||
| class Action(object): | |||||
| """ | |||||
| base class for Trainer and Tester | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Action, self).__init__() | |||||
| def k_means_1d(x, k, max_iter=100): | |||||
| """ | |||||
| Perform k-means on 1-D data. | |||||
| :param x: list of int, representing points in 1-D. | |||||
| :param k: the number of clusters required. | |||||
| :param max_iter: maximum iteration | |||||
| :return centroids: numpy array, centroids of the k clusters | |||||
| assignment: numpy array, 1-D, the bucket id assigned to each example. | |||||
| """ | |||||
| x = np.array(x)  # accept a plain list of int, as documented above | |||||
| sorted_x = sorted(list(set(x))) | |||||
| if len(sorted_x) < k: | |||||
| raise ValueError("too few unique data points for {} buckets".format(k)) | |||||
| gap = len(sorted_x) / k | |||||
| centroids = np.array([sorted_x[int(x * gap)] for x in range(k)]) | |||||
| assign = None | |||||
| for i in range(max_iter): | |||||
| # Cluster Assignment step | |||||
| assign = np.array([np.argmin([np.absolute(x_i - c) for c in centroids]) for x_i in x]) | |||||
| # Move centroids step | |||||
| new_centroids = np.array([x[assign == i].mean() for i in range(k)]) | |||||
| if (new_centroids == centroids).all(): | |||||
| centroids = new_centroids | |||||
| break | |||||
| centroids = new_centroids | |||||
| return np.array(centroids), assign | |||||
| def k_means_bucketing(all_inst, buckets): | |||||
| """ | |||||
| :param all_inst: 3-level list | |||||
| [ | |||||
| [[word_11, word_12, word_13], [label_11, label_12]], # sample 1 | |||||
| [[word_21, word_22, word_23], [label_21, label_22]], # sample 2 | |||||
| ... | |||||
| ] | |||||
| :param buckets: list. The length of the list is the number of buckets. Each element is the maximum length | |||||
| threshold for that bucket, or None for no threshold (which is the usual case). | |||||
| :return data: 2-level list | |||||
| [ | |||||
| [index_11, index_12, ...], # bucket 1 | |||||
| [index_21, index_22, ...], # bucket 2 | |||||
| ... | |||||
| ] | |||||
| """ | |||||
| bucket_data = [[] for _ in buckets] | |||||
| num_buckets = len(buckets) | |||||
| lengths = np.array([len(inst[0]) for inst in all_inst]) | |||||
| _, assignments = k_means_1d(lengths, num_buckets) | |||||
| for idx, bucket_id in enumerate(assignments): | |||||
| if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: | |||||
| bucket_data[bucket_id].append(idx) | |||||
| return bucket_data | |||||
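A brief usage sketch of the bucketing helpers above. It assumes the new module is importable as `fastNLP.core.action` (the path used by other imports in this diff); the toy samples and the choice of three buckets are invented for illustration.

```python
# Hedged sketch: cluster a toy corpus into buckets by sentence length.
# Only len(sample[0]) matters for the bucket assignment.
from fastNLP.core.action import k_means_bucketing

toy_data = [
    [list("ab"), ["B", "E"]],                                 # length 2
    [list("cd"), ["B", "E"]],                                 # length 2
    [list("efghij"), ["B", "M", "M", "M", "M", "E"]],         # length 6
    [list("klmnopq"), ["B", "M", "M", "M", "M", "M", "E"]],   # length 7
    [list("rstuvwxyz123"), ["S"] * 12],                       # length 12
]
buckets = [None, None, None]                 # three buckets, no length cap
print(k_means_bucketing(toy_data, buckets))  # [[0, 1], [2, 3], [4]]
```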
| class BaseSampler(object): | |||||
| """ | |||||
| Base class for all samplers. | |||||
| """ | |||||
| def __init__(self, data_set): | |||||
| self.data_set_length = len(data_set) | |||||
| def __len__(self): | |||||
| return self.data_set_length | |||||
| def __iter__(self): | |||||
| raise NotImplementedError | |||||
| class SequentialSampler(BaseSampler): | |||||
| """ | |||||
| Sample data in the original order. | |||||
| """ | |||||
| def __init__(self, data_set): | |||||
| super(SequentialSampler, self).__init__(data_set) | |||||
| def __iter__(self): | |||||
| return iter(range(self.data_set_length)) | |||||
| class RandomSampler(BaseSampler): | |||||
| """ | |||||
| Sample data in random permutation order. | |||||
| """ | |||||
| def __init__(self, data_set): | |||||
| super(RandomSampler, self).__init__(data_set) | |||||
| def __iter__(self): | |||||
| return iter(np.random.permutation(self.data_set_length)) | |||||
| class BucketSampler(BaseSampler): | |||||
| """ | |||||
| Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. | |||||
| In sampling, first randomly choose a bucket, then sample data from it. | |||||
| The number of buckets is decided dynamically by the variance of sentence lengths. | |||||
| """ | |||||
| def __init__(self, data_set): | |||||
| super(BucketSampler, self).__init__(data_set) | |||||
| BUCKETS = ([None] * 20) | |||||
| self.length_freq = dict(Counter([len(example) for example in data_set])) | |||||
| self.buckets = k_means_bucketing(data_set, BUCKETS) | |||||
| def __iter__(self): | |||||
| bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))] | |||||
| np.random.shuffle(bucket_samples) | |||||
| return iter(bucket_samples) | |||||
| class Batchifier(object): | |||||
| """ | |||||
| Wrap random or sequential sampler to generate a mini-batch. | |||||
| """ | |||||
| def __init__(self, sampler, batch_size, drop_last=True): | |||||
| super(Batchifier, self).__init__() | |||||
| self.sampler = sampler | |||||
| self.batch_size = batch_size | |||||
| self.drop_last = drop_last | |||||
| def __iter__(self): | |||||
| batch = [] | |||||
| while True: | |||||
| for idx in self.sampler: | |||||
| batch.append(idx) | |||||
| if len(batch) == self.batch_size: | |||||
| yield batch | |||||
| batch = [] | |||||
| if 0 < len(batch) < self.batch_size and self.drop_last is False: | |||||
| yield batch | |||||
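A minimal sketch of how a sampler and `Batchifier` compose, again assuming the `fastNLP.core.action` import path. With the `while True` loop above, iteration does not stop on its own, so the consumer breaks out explicitly.

```python
from fastNLP.core.action import Batchifier, RandomSampler

data = list(range(10))                      # only len(data) is used by the sampler
batches = Batchifier(RandomSampler(data), batch_size=4, drop_last=True)
for i, batch_indices in enumerate(batches):
    print(batch_indices)                    # e.g. [7, 2, 9, 0] -- indices into data
    if i == 1:                              # the iterator is endless in this revision
        break
```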
| @@ -0,0 +1,118 @@ | |||||
| import torch | |||||
| from fastNLP.core.action import Batchifier, SequentialSampler | |||||
| from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | |||||
| class Inference(object): | |||||
| """ | |||||
| This is an interface focusing on predicting output based on trained models. | |||||
| It does not care about evaluations of the model, which is different from Tester. | |||||
| This is a high-level model wrapper to be called by FastNLP. | |||||
| """ | |||||
| def __init__(self, pickle_path): | |||||
| self.batch_size = 1 | |||||
| self.batch_output = [] | |||||
| self.iterator = None | |||||
| self.pickle_path = pickle_path | |||||
| self.index2label = load_pickle(self.pickle_path, "id2class.pkl") | |||||
| self.word2index = load_pickle(self.pickle_path, "word2id.pkl") | |||||
| def predict(self, network, data): | |||||
| """ | |||||
| Perform inference. | |||||
| :param network: | |||||
| :param data: multi-level lists of strings | |||||
| :return result: the model outputs | |||||
| """ | |||||
| # transform strings into indices | |||||
| data = self.prepare_input(data) | |||||
| # turn on the testing mode; clean up the history | |||||
| self.mode(network, test=True) | |||||
| self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | |||||
| num_iter = len(data) // self.batch_size | |||||
| for step in range(num_iter): | |||||
| batch_x = self.make_batch(data) | |||||
| prediction = self.data_forward(network, batch_x) | |||||
| self.batch_output.append(prediction) | |||||
| return self.prepare_output(self.batch_output) | |||||
| def mode(self, network, test=True): | |||||
| if test: | |||||
| network.eval() | |||||
| else: | |||||
| network.train() | |||||
| self.batch_output.clear() | |||||
| def data_forward(self, network, x): | |||||
| """ | |||||
| This is only for sequence labeling with CRF decoder. TODO: more general ? | |||||
| :param network: | |||||
| :param x: | |||||
| :return: | |||||
| """ | |||||
| seq_len = [len(seq) for seq in x] | |||||
| x = torch.Tensor(x).long() | |||||
| y = network(x) | |||||
| prediction = network.prediction(y, seq_len) | |||||
| # To do: hide framework | |||||
| results = torch.Tensor(prediction).view(-1, ) | |||||
| return list(results.data) | |||||
| def make_batch(self, data): | |||||
| indices = next(self.iterator) | |||||
| batch_x = [data[idx] for idx in indices] | |||||
| if self.batch_size > 1: | |||||
| batch_x = self.pad(batch_x) | |||||
| return batch_x | |||||
| @staticmethod | |||||
| def pad(batch, fill=0): | |||||
| """ | |||||
| Pad a batch of samples to maximum length. | |||||
| :param batch: list of list | |||||
| :param fill: word index to pad, default 0. | |||||
| :return: a padded batch | |||||
| """ | |||||
| max_length = max([len(x) for x in batch]) | |||||
| for idx, sample in enumerate(batch): | |||||
| if len(sample) < max_length: | |||||
| batch[idx] = sample + [fill] * (max_length - len(sample)) | |||||
| return batch | |||||
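A quick sanity check of the padding helper above; since `pad` is a staticmethod, it can be called without constructing an `Inference` instance (which would require the pickle files).

```python
# Minimal sketch, assuming fastNLP.core.inference is importable as in this diff.
from fastNLP.core.inference import Inference

batch = [[4, 8, 15], [16, 23], [42]]
print(Inference.pad(batch))   # [[4, 8, 15], [16, 23, 0], [42, 0, 0]]
```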
| def prepare_input(self, data): | |||||
| """ | |||||
| Transform the input lists of word strings into lists of word indices. | |||||
| :param data: | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| [word_21, word_22, ...], | |||||
| ... | |||||
| ] | |||||
| """ | |||||
| assert isinstance(data, list) | |||||
| data_index = [] | |||||
| default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL] | |||||
| for example in data: | |||||
| data_index.append([self.word2index.get(w, default_unknown_index) for w in example]) | |||||
| return data_index | |||||
| def prepare_output(self, batch_outputs): | |||||
| """ | |||||
| Transform list of batch outputs into strings. | |||||
| :param batch_outputs: list of list, of shape [num_batch, tag_seq_length]. Element type is Tensor. | |||||
| :return: | |||||
| """ | |||||
| results = [] | |||||
| for batch in batch_outputs: | |||||
| results.append([self.index2label[int(x.data)] for x in batch]) | |||||
| return results | |||||
| @@ -4,9 +4,8 @@ import os | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from fastNLP.action.action import Action | |||||
| from fastNLP.action.action import RandomSampler, Batchifier | |||||
| from fastNLP.modules.utils import seq_mask | |||||
| from fastNLP.core.action import Action | |||||
| from fastNLP.core.action import RandomSampler, Batchifier | |||||
| class BaseTester(Action): | class BaseTester(Action): | ||||
| @@ -26,14 +25,17 @@ class BaseTester(Action): | |||||
| self.batch_size = test_args["batch_size"] | self.batch_size = test_args["batch_size"] | ||||
| self.pickle_path = test_args["pickle_path"] | self.pickle_path = test_args["pickle_path"] | ||||
| self.iterator = None | self.iterator = None | ||||
| self.use_cuda = test_args["use_cuda"] | |||||
| self.model = None | self.model = None | ||||
| self.eval_history = [] | self.eval_history = [] | ||||
| self.batch_output = [] | self.batch_output = [] | ||||
| def test(self, network): | def test(self, network): | ||||
| # print("--------------testing----------------") | |||||
| self.model = network | |||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| self.model = network.cuda() | |||||
| else: | |||||
| self.model = network | |||||
| # turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
| self.mode(network, test=True) | self.mode(network, test=True) | ||||
| @@ -45,7 +47,7 @@ class BaseTester(Action): | |||||
| num_iter = len(dev_data) // self.batch_size | num_iter = len(dev_data) // self.batch_size | ||||
| for step in range(num_iter): | for step in range(num_iter): | ||||
| batch_x, batch_y = self.batchify(dev_data) | |||||
| batch_x, batch_y = self.make_batch(dev_data) | |||||
| prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
| eval_results = self.evaluate(prediction, batch_y) | eval_results = self.evaluate(prediction, batch_y) | ||||
| @@ -66,7 +68,7 @@ class BaseTester(Action): | |||||
| self.save_dev_data = data_dev | self.save_dev_data = data_dev | ||||
| return self.save_dev_data | return self.save_dev_data | ||||
| def batchify(self, data): | |||||
| def make_batch(self, data, output_length=True): | |||||
| """ | """ | ||||
| 1. Perform batching from data and produce a batch of training data. | 1. Perform batching from data and produce a batch of training data. | ||||
| 2. Add padding. | 2. Add padding. | ||||
| @@ -84,8 +86,13 @@ class BaseTester(Action): | |||||
| batch = [data[idx] for idx in indices] | batch = [data[idx] for idx in indices] | ||||
| batch_x = [sample[0] for sample in batch] | batch_x = [sample[0] for sample in batch] | ||||
| batch_y = [sample[1] for sample in batch] | batch_y = [sample[1] for sample in batch] | ||||
| batch_x = self.pad(batch_x) | |||||
| return batch_x, batch_y | |||||
| batch_x_pad = self.pad(batch_x) | |||||
| batch_y_pad = self.pad(batch_y) | |||||
| if output_length: | |||||
| seq_len = [len(x) for x in batch_x] | |||||
| return (batch_x_pad, seq_len), batch_y_pad | |||||
| else: | |||||
| return batch_x_pad, batch_y_pad | |||||
| @staticmethod | @staticmethod | ||||
| def pad(batch, fill=0): | def pad(batch, fill=0): | ||||
| @@ -98,7 +105,7 @@ class BaseTester(Action): | |||||
| max_length = max([len(x) for x in batch]) | max_length = max([len(x) for x in batch]) | ||||
| for idx, sample in enumerate(batch): | for idx, sample in enumerate(batch): | ||||
| if len(sample) < max_length: | if len(sample) < max_length: | ||||
| batch[idx] = sample + [fill * (max_length - len(sample))] | |||||
| batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||||
| return batch | return batch | ||||
| def data_forward(self, network, data): | def data_forward(self, network, data): | ||||
| @@ -112,7 +119,7 @@ class BaseTester(Action): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def mode(self, model, test=True): | def mode(self, model, test=True): | ||||
| """To do: combine this function with Trainer ?? """ | |||||
| """TODO: combine this function with Trainer ?? """ | |||||
| if test: | if test: | ||||
| model.eval() | model.eval() | ||||
| else: | else: | ||||
| @@ -141,25 +148,37 @@ class POSTester(BaseTester): | |||||
| self.mask = None | self.mask = None | ||||
| self.batch_result = None | self.batch_result = None | ||||
| def data_forward(self, network, x): | |||||
| """To Do: combine with Trainer | |||||
| def data_forward(self, network, inputs): | |||||
| """TODO: combine with Trainer | |||||
| :param network: the PyTorch model | :param network: the PyTorch model | ||||
| :param x: list of list, [batch_size, max_len] | :param x: list of list, [batch_size, max_len] | ||||
| :return y: [batch_size, num_classes] | :return y: [batch_size, num_classes] | ||||
| """ | """ | ||||
| seq_len = [len(seq) for seq in x] | |||||
| # unpack the returned value from make_batch | |||||
| if isinstance(inputs, tuple): | |||||
| x = inputs[0] | |||||
| self.seq_len = inputs[1] | |||||
| else: | |||||
| x = inputs | |||||
| x = torch.Tensor(x).long() | x = torch.Tensor(x).long() | ||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| x = x.cuda() | |||||
| self.batch_size = x.size(0) | self.batch_size = x.size(0) | ||||
| self.max_len = x.size(1) | self.max_len = x.size(1) | ||||
| self.mask = seq_mask(seq_len, self.max_len) | |||||
| y = network(x) | y = network(x) | ||||
| return y | return y | ||||
| def evaluate(self, predict, truth): | def evaluate(self, predict, truth): | ||||
| truth = torch.Tensor(truth) | truth = torch.Tensor(truth) | ||||
| loss, prediction = self.model.loss(predict, truth, self.mask, self.batch_size, self.max_len) | |||||
| results = torch.Tensor(prediction[0][0]).view((-1,)) | |||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| truth = truth.cuda() | |||||
| loss = self.model.loss(predict, truth, self.seq_len) / self.batch_size | |||||
| prediction = self.model.prediction(predict, self.seq_len) | |||||
| results = torch.Tensor(prediction).view(-1,) | |||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| results = results.cuda() | |||||
| accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0] | accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0] | ||||
| return [loss.data, accuracy] | return [loss.data, accuracy] | ||||
| @@ -256,7 +275,7 @@ class ClassTester(BaseTester): | |||||
| n_batches = len(data_test) // self.batch_size | n_batches = len(data_test) // self.batch_size | ||||
| n_print = n_batches // 10 | n_print = n_batches // 10 | ||||
| step = 0 | step = 0 | ||||
| for batch_x, batch_y in self.batchify(data_test, max_len=self.max_len): | |||||
| for batch_x, batch_y in self.make_batch(data_test, max_len=self.max_len): | |||||
| prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
| eval_results = self.evaluate(prediction, batch_y) | eval_results = self.evaluate(prediction, batch_y) | ||||
| @@ -277,7 +296,7 @@ class ClassTester(BaseTester): | |||||
| data = _pickle.load(f) | data = _pickle.load(f) | ||||
| return data | return data | ||||
| def batchify(self, data, max_len=None): | |||||
| def make_batch(self, data, max_len=None): | |||||
| """Batch and pad data.""" | """Batch and pad data.""" | ||||
| for indices in self.iterator: | for indices in self.iterator: | ||||
| # generate batch and pad | # generate batch and pad | ||||
| @@ -319,7 +338,7 @@ class ClassTester(BaseTester): | |||||
| return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | ||||
| def mode(self, model, test=True): | def mode(self, model, test=True): | ||||
| """To do: combine this function with Trainer ?? """ | |||||
| """TODO: combine this function with Trainer ?? """ | |||||
| if test: | if test: | ||||
| model.eval() | model.eval() | ||||
| else: | else: | ||||
| @@ -7,10 +7,9 @@ import numpy as np | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.action.action import Action | |||||
| from fastNLP.action.action import RandomSampler, Batchifier | |||||
| from fastNLP.action.tester import POSTester | |||||
| from fastNLP.modules.utils import seq_mask | |||||
| from fastNLP.core.action import Action | |||||
| from fastNLP.core.action import RandomSampler, Batchifier, BucketSampler | |||||
| from fastNLP.core.tester import POSTester | |||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| @@ -45,6 +44,7 @@ class BaseTrainer(Action): | |||||
| self.validate = train_args["validate"] | self.validate = train_args["validate"] | ||||
| self.save_best_dev = train_args["save_best_dev"] | self.save_best_dev = train_args["save_best_dev"] | ||||
| self.model_saved_path = train_args["model_saved_path"] | self.model_saved_path = train_args["model_saved_path"] | ||||
| self.use_cuda = train_args["use_cuda"] | |||||
| self.model = None | self.model = None | ||||
| self.iterator = None | self.iterator = None | ||||
| @@ -66,13 +66,19 @@ class BaseTrainer(Action): | |||||
| - update | - update | ||||
| Subclasses must implement these methods with a specific framework. | Subclasses must implement these methods with a specific framework. | ||||
| """ | """ | ||||
| # prepare model and data | |||||
| self.model = network | |||||
| # prepare model and data, transfer model to gpu if available | |||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| self.model = network.cuda() | |||||
| else: | |||||
| self.model = network | |||||
| data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) | data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) | ||||
| # define tester over dev data | # define tester over dev data | ||||
| # TODO: more flexible | |||||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | ||||
| "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path} | |||||
| "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||||
| "use_cuda": self.use_cuda} | |||||
| validator = POSTester(valid_args) | validator = POSTester(valid_args) | ||||
| # main training epochs | # main training epochs | ||||
| @@ -83,11 +89,11 @@ class BaseTrainer(Action): | |||||
| # turn on network training mode; define optimizer; prepare batch iterator | # turn on network training mode; define optimizer; prepare batch iterator | ||||
| self.mode(test=False) | self.mode(test=False) | ||||
| self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) | |||||
| self.iterator = iter(Batchifier(BucketSampler(data_train), self.batch_size, drop_last=True)) | |||||
| # training iterations in one epoch | # training iterations in one epoch | ||||
| for step in range(iterations): | for step in range(iterations): | ||||
| batch_x, batch_y = self.batchify(data_train) # pad ? | |||||
| batch_x, batch_y = self.make_batch(data_train) | |||||
| prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
| @@ -95,6 +101,9 @@ class BaseTrainer(Action): | |||||
| self.grad_backward(loss) | self.grad_backward(loss) | ||||
| self.update() | self.update() | ||||
| if step % 10 == 0: | |||||
| print("[epoch {} step {}] train loss={:.2f}".format(epoch, step, loss.data)) | |||||
| if self.validate: | if self.validate: | ||||
| if data_dev is None: | if data_dev is None: | ||||
| raise RuntimeError("No validation data provided.") | raise RuntimeError("No validation data provided.") | ||||
| @@ -110,9 +119,6 @@ class BaseTrainer(Action): | |||||
| # finish training | # finish training | ||||
| def prepare_input(self, data_path): | def prepare_input(self, data_path): | ||||
| """ | |||||
| To do: Load pkl files of train/dev/test and embedding | |||||
| """ | |||||
| data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) | data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) | ||||
| data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) | data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) | ||||
| data_test = _pickle.load(open(data_path + "data_test.pkl", "rb")) | data_test = _pickle.load(open(data_path + "data_test.pkl", "rb")) | ||||
| @@ -181,7 +187,7 @@ class BaseTrainer(Action): | |||||
| """ | """ | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| def batchify(self, data, output_length=True): | |||||
| def make_batch(self, data, output_length=True): | |||||
| """ | """ | ||||
| 1. Perform batching from data and produce a batch of training data. | 1. Perform batching from data and produce a batch of training data. | ||||
| 2. Add padding. | 2. Add padding. | ||||
| @@ -192,20 +198,24 @@ class BaseTrainer(Action): | |||||
| [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 | ||||
| ... | ... | ||||
| ] | ] | ||||
| :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||||
| :return (batch_x, seq_len): tuple of two elements, if output_length is true. | |||||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | ||||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||||
| return batch_x and batch_y, if output_length is False | |||||
| """ | """ | ||||
| indices = next(self.iterator) | indices = next(self.iterator) | ||||
| batch = [data[idx] for idx in indices] | batch = [data[idx] for idx in indices] | ||||
| batch_x = [sample[0] for sample in batch] | batch_x = [sample[0] for sample in batch] | ||||
| batch_y = [sample[1] for sample in batch] | batch_y = [sample[1] for sample in batch] | ||||
| batch_x_pad = self.pad(batch_x) | batch_x_pad = self.pad(batch_x) | ||||
| batch_y_pad = self.pad(batch_y) | |||||
| if output_length: | if output_length: | ||||
| seq_len = [len(x) for x in batch_x] | seq_len = [len(x) for x in batch_x] | ||||
| return batch_x_pad, batch_y, seq_len | |||||
| return (batch_x_pad, seq_len), batch_y_pad | |||||
| else: | else: | ||||
| return batch_x_pad, batch_y | |||||
| return batch_x_pad, batch_y_pad | |||||
| @staticmethod | @staticmethod | ||||
| def pad(batch, fill=0): | def pad(batch, fill=0): | ||||
| @@ -286,24 +296,30 @@ class POSTrainer(BaseTrainer): | |||||
| self.best_accuracy = 0.0 | self.best_accuracy = 0.0 | ||||
| def prepare_input(self, data_path): | def prepare_input(self, data_path): | ||||
| """ | |||||
| To do: Load pkl files of train/dev/test and embedding | |||||
| """ | |||||
| data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | ||||
| data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | ||||
| return data_train, data_dev, 0, 1 | return data_train, data_dev, 0, 1 | ||||
| def data_forward(self, network, x): | |||||
| def data_forward(self, network, inputs): | |||||
| """ | """ | ||||
| :param network: the PyTorch model | :param network: the PyTorch model | ||||
| :param x: list of list, [batch_size, max_len] | |||||
| :return y: [batch_size, num_classes] | |||||
| """ | |||||
| seq_len = [len(seq) for seq in x] | |||||
| :param inputs: list of list, [batch_size, max_len], | |||||
| or tuple of (batch_x, seq_len), batch_x == [batch_size, max_len] | |||||
| :return y: [batch_size, max_len, tag_size] | |||||
| """ | |||||
| # unpack the returned value from make_batch | |||||
| if isinstance(inputs, tuple): | |||||
| x = inputs[0] | |||||
| self.seq_len = inputs[1] | |||||
| else: | |||||
| x = inputs | |||||
| x = torch.Tensor(x).long() | x = torch.Tensor(x).long() | ||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| x = x.cuda() | |||||
| self.batch_size = x.size(0) | self.batch_size = x.size(0) | ||||
| self.max_len = x.size(1) | self.max_len = x.size(1) | ||||
| self.mask = seq_mask(seq_len, self.max_len) | |||||
| y = network(x) | y = network(x) | ||||
| return y | return y | ||||
| @@ -326,17 +342,20 @@ class POSTrainer(BaseTrainer): | |||||
| def get_loss(self, predict, truth): | def get_loss(self, predict, truth): | ||||
| """ | """ | ||||
| Compute loss given prediction and ground truth. | Compute loss given prediction and ground truth. | ||||
| :param predict: prediction label vector, [batch_size, num_classes] | |||||
| :param predict: prediction label vector, [batch_size, max_len, tag_size] | |||||
| :param truth: ground truth label vector, [batch_size, max_len] | :param truth: ground truth label vector, [batch_size, max_len] | ||||
| :return: a scalar | :return: a scalar | ||||
| """ | """ | ||||
| truth = torch.Tensor(truth) | truth = torch.Tensor(truth) | ||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| truth = truth.cuda() | |||||
| assert truth.shape == (self.batch_size, self.max_len) | |||||
| if self.loss_func is None: | if self.loss_func is None: | ||||
| if hasattr(self.model, "loss"): | if hasattr(self.model, "loss"): | ||||
| self.loss_func = self.model.loss | self.loss_func = self.model.loss | ||||
| else: | else: | ||||
| self.define_loss() | self.define_loss() | ||||
| loss, prediction = self.loss_func(predict, truth, self.mask, self.batch_size, self.max_len) | |||||
| loss = self.loss_func(predict, truth, self.seq_len) | |||||
| # print("loss={:.2f}".format(loss.data)) | # print("loss={:.2f}".format(loss.data)) | ||||
| return loss | return loss | ||||
| @@ -348,6 +367,36 @@ class POSTrainer(BaseTrainer): | |||||
| else: | else: | ||||
| return False | return False | ||||
| def make_batch(self, data, output_length=True): | |||||
| """ | |||||
| 1. Perform batching from data and produce a batch of training data. | |||||
| 2. Add padding. | |||||
| :param data: list. Each entry is a sample, which is also a list of features and label(s). | |||||
| E.g. | |||||
| [ | |||||
| [[word_11, word_12, word_13], [label_11, label_12]], # sample 1 | |||||
| [[word_21, word_22, word_23], [label_21, label_22]], # sample 2 | |||||
| ... | |||||
| ] | |||||
| :return (batch_x, seq_len): tuple of two elements, if output_length is true. | |||||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||||
| return batch_x and batch_y, if output_length is False | |||||
| """ | |||||
| indices = next(self.iterator) | |||||
| batch = [data[idx] for idx in indices] | |||||
| batch_x = [sample[0] for sample in batch] | |||||
| batch_y = [sample[1] for sample in batch] | |||||
| batch_x_pad = self.pad(batch_x) | |||||
| batch_y_pad = self.pad(batch_y) | |||||
| if output_length: | |||||
| seq_len = [len(x) for x in batch_x] | |||||
| return (batch_x_pad, seq_len), batch_y_pad | |||||
| else: | |||||
| return batch_x_pad, batch_y_pad | |||||
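To make the new `make_batch` contract concrete, here is a self-contained sketch of the shapes it produces with `output_length=True` and how `data_forward` unpacks them; the numbers are toy values, not from any dataset.

```python
# Hedged sketch mirroring the make_batch / data_forward hand-off above.
toy_batch_x = [[1, 2, 3], [4, 5]]                        # word indices before padding
seq_len = [len(x) for x in toy_batch_x]                  # [3, 2]
batch_x_pad = [x + [0] * (3 - len(x)) for x in toy_batch_x]
inputs = (batch_x_pad, seq_len)                          # what make_batch returns as batch_x

# data_forward unpacks the tuple exactly like this:
if isinstance(inputs, tuple):
    x, lengths = inputs
else:
    x, lengths = inputs, None
print(x, lengths)                                        # [[1, 2, 3], [4, 5, 0]] [3, 2]
```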
| class LanguageModelTrainer(BaseTrainer): | class LanguageModelTrainer(BaseTrainer): | ||||
| """ | """ | ||||
| @@ -439,7 +488,7 @@ class ClassTrainer(BaseTrainer): | |||||
| # training iterations in one epoch | # training iterations in one epoch | ||||
| step = 0 | step = 0 | ||||
| for batch_x, batch_y in self.batchify(data_train): | |||||
| for batch_x, batch_y in self.make_batch(data_train): | |||||
| prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
| loss = self.get_loss(prediction, batch_y) | loss = self.get_loss(prediction, batch_y) | ||||
| @@ -466,9 +515,6 @@ class ClassTrainer(BaseTrainer): | |||||
| # finish training | # finish training | ||||
| def prepare_input(self, data_path): | def prepare_input(self, data_path): | ||||
| """ | |||||
| To do: Load pkl files of train/dev/test and embedding | |||||
| """ | |||||
| names = [ | names = [ | ||||
| "data_train.pkl", "data_dev.pkl", | "data_train.pkl", "data_dev.pkl", | ||||
| @@ -534,7 +580,7 @@ class ClassTrainer(BaseTrainer): | |||||
| """Apply gradient.""" | """Apply gradient.""" | ||||
| self.optimizer.step() | self.optimizer.step() | ||||
| def batchify(self, data): | |||||
| def make_batch(self, data): | |||||
| """Batch and pad data.""" | """Batch and pad data.""" | ||||
| for indices in self.iterator: | for indices in self.iterator: | ||||
| batch = [data[idx] for idx in indices] | batch = [data[idx] for idx in indices] | ||||
| @@ -560,4 +606,4 @@ if __name__ == "__name__": | |||||
| train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} | train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} | ||||
| trainer = BaseTrainer(train_args) | trainer = BaseTrainer(train_args) | ||||
| data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] | data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] | ||||
| trainer.batchify(data=data_train) | |||||
| trainer.make_batch(data=data_train) | |||||
| @@ -0,0 +1,173 @@ | |||||
| from fastNLP.core.inference import Inference | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.loader.model_loader import ModelLoader | |||||
| """ | |||||
| mapping from model name to [URL, file_name.class_name, model_pickle_name] | |||||
| Notice that the class of the model should be in "models" directory. | |||||
| Example: | |||||
| "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"] | |||||
| """ | |||||
| FastNLP_MODEL_COLLECTION = { | |||||
| "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"] | |||||
| } | |||||
| class FastNLP(object): | |||||
| """ | |||||
| High-level interface for direct model inference. | |||||
| Usage: | |||||
| fastnlp = FastNLP() | |||||
| fastnlp.load("zh_pos_tag_model") | |||||
| text = "这是最好的基于深度学习的中文分词系统。" | |||||
| result = fastnlp.run(text) | |||||
| print(result) # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"] | |||||
| """ | |||||
| def __init__(self, model_dir="./"): | |||||
| """ | |||||
| :param model_dir: this directory should contain the following files: | |||||
| 1. a pre-trained model | |||||
| 2. a config file | |||||
| 3. "id2class.pkl" | |||||
| 4. "word2id.pkl" | |||||
| """ | |||||
| self.model_dir = model_dir | |||||
| self.model = None | |||||
| def load(self, model_name): | |||||
| """ | |||||
| Load a pre-trained FastNLP model together with additional data. | |||||
| :param model_name: str, the name of a FastNLP model. | |||||
| """ | |||||
| assert type(model_name) is str | |||||
| if model_name not in FastNLP_MODEL_COLLECTION: | |||||
| raise ValueError("No FastNLP model named {}.".format(model_name)) | |||||
| if not self.model_exist(model_dir=self.model_dir): | |||||
| self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0]) | |||||
| model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1]) | |||||
| model_args = ConfigSection() | |||||
| # To do: customized config file for model init parameters | |||||
| ConfigLoader.load_config(self.model_dir + "config", {"POS_infer": model_args}) | |||||
| # Construct the model | |||||
| model = model_class(model_args) | |||||
| # To do: framework independent | |||||
| ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name][2]) | |||||
| self.model = model | |||||
| print("Model loaded. ") | |||||
| def run(self, raw_input): | |||||
| """ | |||||
| Perform inference over given input using the loaded model. | |||||
| :param raw_input: str, raw text | |||||
| :return results: | |||||
| """ | |||||
| infer = Inference(self.model_dir) | |||||
| infer_input = self.string_to_list(raw_input) | |||||
| results = infer.predict(self.model, infer_input) | |||||
| outputs = self.make_output(results) | |||||
| return outputs | |||||
| @staticmethod | |||||
| def _get_model_class(file_class_name): | |||||
| """ | |||||
| Fetch the class specified by <file_class_name>. | |||||
| :param file_class_name: str, contains the name of the Python module followed by the name of the class. | |||||
| Example: "sequence_modeling.SeqLabeling" | |||||
| :return module: the model class | |||||
| """ | |||||
| import_prefix = "fastNLP.models." | |||||
| parts = (import_prefix + file_class_name).split(".") | |||||
| from_module = ".".join(parts[:-1]) | |||||
| module = __import__(from_module) | |||||
| for sub in parts[1:]: | |||||
| module = getattr(module, sub) | |||||
| return module | |||||
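The dynamic import above follows the usual `__import__` plus `getattr` walk. As an illustration that runs without fastNLP installed, the same pattern can resolve a standard-library class:

```python
# Minimal sketch of the __import__/getattr walk used by _get_model_class.
parts = "collections.abc.Mapping".split(".")
module = __import__(".".join(parts[:-1]))   # returns the top-level package "collections"
obj = module
for sub in parts[1:]:
    obj = getattr(obj, sub)                 # walk down to collections.abc.Mapping
print(obj)                                  # <class 'collections.abc.Mapping'>
```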
| def _load(self, model_dir, model_name): | |||||
| # To do | |||||
| return 0 | |||||
| def _download(self, model_name, url): | |||||
| """ | |||||
| Download the model weights from <url> and save in <self.model_dir>. | |||||
| :param model_name: | |||||
| :param url: | |||||
| """ | |||||
| print("Downloading {} from {}".format(model_name, url)) | |||||
| # To do | |||||
| def model_exist(self, model_dir): | |||||
| """ | |||||
| Check whether the desired model is already in the directory. | |||||
| :param model_dir: | |||||
| """ | |||||
| return True | |||||
| def string_to_list(self, text, delimiter="\n"): | |||||
| """ | |||||
| For word seg only, currently. | |||||
| This function is used to transform raw input to lists, which is done by DatasetLoader in training. | |||||
| Split text string into three-level lists. | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| [word_21, word_22, ...], | |||||
| ... | |||||
| ] | |||||
| :param text: string | |||||
| :param delimiter: str, character used to split text into sentences. | |||||
| :return data: three-level lists | |||||
| """ | |||||
| data = [] | |||||
| sents = text.strip().split(delimiter) | |||||
| for sent in sents: | |||||
| characters = [] | |||||
| for ch in sent: | |||||
| characters.append(ch) | |||||
| data.append(characters) | |||||
| # To refactor: this is used in make_output | |||||
| self.data = data | |||||
| return data | |||||
| def make_output(self, results): | |||||
| """ | |||||
| Transform model output into user-friendly contents. | |||||
| Example: In CWS, convert <BMES> labeling into segmented text. | |||||
| :param results: | |||||
| :return: | |||||
| """ | |||||
| outputs = [] | |||||
| for sent_char, sent_label in zip(self.data, results): | |||||
| words = [] | |||||
| word = "" | |||||
| for char, label in zip(sent_char, sent_label): | |||||
| if label[0] == "B": | |||||
| if word != "": | |||||
| words.append(word) | |||||
| word = char | |||||
| elif label[0] == "M": | |||||
| word += char | |||||
| elif label[0] == "E": | |||||
| word += char | |||||
| words.append(word) | |||||
| word = "" | |||||
| elif label[0] == "S": | |||||
| if word != "": | |||||
| words.append(word) | |||||
| word = "" | |||||
| words.append(char) | |||||
| else: | |||||
| raise ValueError("invalid label") | |||||
| outputs.append(" ".join(words)) | |||||
| return outputs | |||||
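For reference, a self-contained sketch of the `<BMES>` decoding that `make_output` performs; the characters and labels below are invented, not real model output, and the flush-on-`B`/`S` handling of the full method is omitted for brevity.

```python
chars = ["深", "度", "学", "习", "很", "有", "趣"]
labels = ["B", "M", "M", "E", "S", "B", "E"]
words, word = [], ""
for ch, lab in zip(chars, labels):
    if lab == "B":
        word = ch
    elif lab == "M":
        word += ch
    elif lab == "E":
        words.append(word + ch)
        word = ""
    else:                       # "S": a single-character word
        words.append(ch)
print(" ".join(words))          # 深度学习 很 有趣
```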
| @@ -17,7 +17,7 @@ class BaseLoader(object): | |||||
| def load_lines(self): | def load_lines(self): | ||||
| with open(self.data_path, "r", encoding="utf-8") as f: | with open(self.data_path, "r", encoding="utf-8") as f: | ||||
| text = f.readlines() | text = f.readlines() | ||||
| return text | |||||
| return [line.strip() for line in text] | |||||
| class ToyLoader0(BaseLoader): | class ToyLoader0(BaseLoader): | ||||
| @@ -20,9 +20,13 @@ class ConfigLoader(BaseLoader): | |||||
| def load_config(file_path, sections): | def load_config(file_path, sections): | ||||
| """ | """ | ||||
| :param file_path: the path of config file | :param file_path: the path of config file | ||||
| :param sections: the dict of sections | |||||
| :return: | |||||
| :param sections: the dict of {section_name(string): Section instance} | |||||
| Example: | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| :return: returns nothing, but the values of the attributes are saved into the given sections | |||||
| """ | """ | ||||
| assert isinstance(sections, dict) | |||||
| cfg = configparser.ConfigParser() | cfg = configparser.ConfigParser() | ||||
| if not os.path.exists(file_path): | if not os.path.exists(file_path): | ||||
| raise FileNotFoundError("config file {} not found. ".format(file_path)) | raise FileNotFoundError("config file {} not found. ".format(file_path)) | ||||
| @@ -22,6 +22,7 @@ class POSDatasetLoader(DatasetLoader): | |||||
| and label2 | and label2 | ||||
| Jerry label1 | Jerry label1 | ||||
| . label3 | . label3 | ||||
| (separated by an empty line) | |||||
| Hello label4 | Hello label4 | ||||
| world label5 | world label5 | ||||
| ! label3 | ! label3 | ||||
| @@ -29,6 +30,7 @@ class POSDatasetLoader(DatasetLoader): | |||||
| and "Hello world !". Each word has its own label from label1 | and "Hello world !". Each word has its own label from label1 | ||||
| to label5. | to label5. | ||||
| """ | """ | ||||
| def __init__(self, data_name, data_path): | def __init__(self, data_name, data_path): | ||||
| super(POSDatasetLoader, self).__init__(data_name, data_path) | super(POSDatasetLoader, self).__init__(data_name, data_path) | ||||
| @@ -77,6 +79,62 @@ class POSDatasetLoader(DatasetLoader): | |||||
| return data | return data | ||||
| class TokenizeDatasetLoader(DatasetLoader): | |||||
| """ | |||||
| Data set loader for tokenization data sets | |||||
| """ | |||||
| def __init__(self, data_name, data_path): | |||||
| super(TokenizeDatasetLoader, self).__init__(data_name, data_path) | |||||
| def load_pku(self, max_seq_len=32): | |||||
| """ | |||||
| load pku dataset for Chinese word segmentation | |||||
| CWS (Chinese Word Segmentation) pku training dataset format: | |||||
| 1. Each line is a sentence. | |||||
| 2. Each word in a sentence is separated by space. | |||||
| This function converts the pku dataset into three-level lists with <BMES> labels. | |||||
| B: beginning of a word | |||||
| M: middle of a word | |||||
| E: ending of a word | |||||
| S: single character | |||||
| :param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into | |||||
| several sequences. | |||||
| :return: three-level lists | |||||
| """ | |||||
| assert isinstance(max_seq_len, int) and max_seq_len > 0 | |||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| sentences = f.readlines() | |||||
| data = [] | |||||
| for sent in sentences: | |||||
| tokens = sent.strip().split() | |||||
| words = [] | |||||
| labels = [] | |||||
| for token in tokens: | |||||
| if len(token) == 1: | |||||
| words.append(token) | |||||
| labels.append("S") | |||||
| else: | |||||
| words.append(token[0]) | |||||
| labels.append("B") | |||||
| for idx in range(1, len(token) - 1): | |||||
| words.append(token[idx]) | |||||
| labels.append("M") | |||||
| words.append(token[-1]) | |||||
| labels.append("E") | |||||
| num_samples = len(words) // max_seq_len | |||||
| if len(words) % max_seq_len != 0: | |||||
| num_samples += 1 | |||||
| for sample_idx in range(num_samples): | |||||
| start = sample_idx * max_seq_len | |||||
| end = (sample_idx + 1) * max_seq_len | |||||
| seq_words = words[start:end] | |||||
| seq_labels = labels[start:end] | |||||
| data.append([seq_words, seq_labels]) | |||||
| return data | |||||
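A small self-contained sketch of the word-to-`<BMES>` conversion that `load_pku` applies to each space-separated token; the sentence is invented, not taken from the PKU corpus.

```python
sentence = "今天 天气 真 不错"
words, labels = [], []
for token in sentence.split():
    if len(token) == 1:
        words.append(token)
        labels.append("S")
    else:
        words.append(token[0])
        labels.append("B")
        for ch in token[1:-1]:
            words.append(ch)
            labels.append("M")
        words.append(token[-1])
        labels.append("E")
print(list(zip(words, labels)))
# [('今', 'B'), ('天', 'E'), ('天', 'B'), ('气', 'E'), ('真', 'S'), ('不', 'B'), ('错', 'E')]
```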
| class ClassDatasetLoader(DatasetLoader): | class ClassDatasetLoader(DatasetLoader): | ||||
| """Loader for classification data sets""" | """Loader for classification data sets""" | ||||
| @@ -163,7 +221,12 @@ class LMDatasetLoader(DatasetLoader): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| """ | |||||
| data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines() | data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines() | ||||
| for example in data: | for example in data: | ||||
| for w, l in zip(example[0], example[1]): | for w, l in zip(example[0], example[1]): | ||||
| print(w, l) | print(w, l) | ||||
| """ | |||||
| ans = TokenizeDatasetLoader("xxx", "/home/zyfeng/Desktop/data/icwb2-data/training/test").load_pku() | |||||
| print(ans) | |||||
| @@ -11,9 +11,11 @@ class ModelLoader(BaseLoader): | |||||
| def __init__(self, data_name, data_path): | def __init__(self, data_name, data_path): | ||||
| super(ModelLoader, self).__init__(data_name, data_path) | super(ModelLoader, self).__init__(data_name, data_path) | ||||
| def load_pytorch(self, empty_model): | |||||
| @staticmethod | |||||
| def load_pytorch(empty_model, model_path): | |||||
| """ | """ | ||||
| Load model parameters from .pkl files into the empty PyTorch model. | Load model parameters from .pkl files into the empty PyTorch model. | ||||
| :param empty_model: a PyTorch model with initialized parameters. | :param empty_model: a PyTorch model with initialized parameters. | ||||
| :param model_path: str, the path to the saved model. | |||||
| """ | """ | ||||
| empty_model.load_state_dict(torch.load(self.data_path)) | |||||
| empty_model.load_state_dict(torch.load(model_path)) | |||||
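A minimal usage sketch of the new static signature; the toy model and file name below are placeholders for illustration, not artifacts shipped with fastNLP.

```python
import torch
import torch.nn as nn

from fastNLP.loader.model_loader import ModelLoader

saved = nn.Linear(4, 2)                          # any nn.Module
torch.save(saved.state_dict(), "toy_model.pkl")

empty = nn.Linear(4, 2)                          # freshly initialised, same architecture
ModelLoader.load_pytorch(empty, "toy_model.pkl")
```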
| @@ -1,346 +1,363 @@ | |||||
| import _pickle | |||||
| import os | |||||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||||
| DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||||
| '<reserved-3>', | |||||
| '<reserved-4>'] # dict index = 2~4 | |||||
| DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||||
| DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||||
| DEFAULT_RESERVED_LABEL[2]: 4} | |||||
| # the first word added to the dict will take index 5 | |||||
| class BasePreprocess(object): | |||||
| def __init__(self, data, pickle_path): | |||||
| super(BasePreprocess, self).__init__() | |||||
| self.data = data | |||||
| self.pickle_path = pickle_path | |||||
| if not self.pickle_path.endswith('/'): | |||||
| self.pickle_path = self.pickle_path + '/' | |||||
| class POSPreprocess(BasePreprocess): | |||||
| """ | |||||
| This class is used to preprocess POS datasets. | |||||
| """ | |||||
| def __init__(self, data, pickle_path="./", train_dev_split=0): | |||||
| """ | |||||
| Preprocess pipeline, including building the mappings from words to indices, from indices to words, | |||||
| from labels/classes to indices, and from indices to labels/classes. | |||||
| :param data: three-level list | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :param pickle_path: str, the directory to the pickle files. Default: "./" | |||||
| :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0. | |||||
| To do: | |||||
| 1. simplify __init__ | |||||
| """ | |||||
| super(POSPreprocess, self).__init__(data, pickle_path) | |||||
| self.pickle_path = pickle_path | |||||
| if self.pickle_exist("word2id.pkl"): | |||||
| # load word2index because the construction of the following objects needs it | |||||
| with open(os.path.join(self.pickle_path, "word2id.pkl"), "rb") as f: | |||||
| self.word2index = _pickle.load(f) | |||||
| else: | |||||
| self.word2index, self.label2index = self.build_dict(data) | |||||
| with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f: | |||||
| _pickle.dump(self.word2index, f) | |||||
| if self.pickle_exist("class2id.pkl"): | |||||
| with open(os.path.join(self.pickle_path, "class2id.pkl"), "rb") as f: | |||||
| self.label2index = _pickle.load(f) | |||||
| else: | |||||
| with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f: | |||||
| _pickle.dump(self.label2index, f) | |||||
| # Note: something will go wrong if word2id.pkl is found but class2id.pkl is not | |||||
| if not self.pickle_exist("id2word.pkl"): | |||||
| index2word = self.build_reverse_dict(self.word2index) | |||||
| with open(os.path.join(self.pickle_path, "id2word.pkl"), "wb") as f: | |||||
| _pickle.dump(index2word, f) | |||||
| if not self.pickle_exist("id2class.pkl"): | |||||
| index2label = self.build_reverse_dict(self.label2index) | |||||
| with open(os.path.join(self.pickle_path, "id2class.pkl"), "wb") as f: | |||||
| _pickle.dump(index2label, f) | |||||
| if not self.pickle_exist("data_train.pkl"): | |||||
| data_train = self.to_index(data) | |||||
| if train_dev_split > 0 and not self.pickle_exist("data_dev.pkl"): | |||||
| data_dev = data_train[: int(len(data_train) * train_dev_split)] | |||||
| with open(os.path.join(self.pickle_path, "data_dev.pkl"), "wb") as f: | |||||
| _pickle.dump(data_dev, f) | |||||
| with open(os.path.join(self.pickle_path, "data_train.pkl"), "wb") as f: | |||||
| _pickle.dump(data_train, f) | |||||
| def build_dict(self, data): | |||||
| """ | |||||
| Build the word-to-index and label-to-index mappings from the data. | |||||
| :param data: three-level list | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :return word2index: dict of {str, int} | |||||
| label2index: dict of {str, int} | |||||
| """ | |||||
| label2index = {} | |||||
| word2index = dict(DEFAULT_WORD_TO_INDEX)  # copy, so the module-level default is not mutated | |||||
| for example in data: | |||||
| for word, label in zip(example[0], example[1]): | |||||
| if word not in word2index: | |||||
| word2index[word] = len(word2index) | |||||
| if label not in label2index: | |||||
| label2index[label] = len(label2index) | |||||
| return word2index, label2index | |||||
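A self-contained sketch of the mappings `build_dict` produces, mirroring the loop above on two invented samples; the five reserved entries come from `DEFAULT_WORD_TO_INDEX`.

```python
data = [
    [["I", "like", "cats"], ["PN", "V", "N"]],
    [["cats", "sleep"], ["N", "V"]],
]
word2index = {"<pad>": 0, "<unk>": 1, "<reserved-2>": 2, "<reserved-3>": 3, "<reserved-4>": 4}
label2index = {}
for words, labels in data:
    for w, l in zip(words, labels):
        word2index.setdefault(w, len(word2index))
        label2index.setdefault(l, len(label2index))
print(word2index)   # ..., 'I': 5, 'like': 6, 'cats': 7, 'sleep': 8
print(label2index)  # {'PN': 0, 'V': 1, 'N': 2}
```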
| def pickle_exist(self, pickle_name): | |||||
| """ | |||||
| :param pickle_name: the filename of target pickle file | |||||
| :return: True if file exists else False | |||||
| """ | |||||
| if not os.path.exists(self.pickle_path): | |||||
| os.makedirs(self.pickle_path) | |||||
| file_name = os.path.join(self.pickle_path, pickle_name) | |||||
| if os.path.exists(file_name): | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| def build_reverse_dict(self, word_dict): | |||||
| id2word = {word_dict[w]: w for w in word_dict} | |||||
| return id2word | |||||
| def to_index(self, data): | |||||
| """ | |||||
| Convert word strings and label strings into indices. | |||||
| :param data: three-level list | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :return data_index: the same shape as data, but with each string replaced by its corresponding index | |||||
| """ | |||||
| data_index = [] | |||||
| for example in data: | |||||
| word_list = [] | |||||
| label_list = [] | |||||
| for word, label in zip(example[0], example[1]): | |||||
| word_list.append(self.word2index[word]) | |||||
| label_list.append(self.label2index[label]) | |||||
| data_index.append([word_list, label_list]) | |||||
| return data_index | |||||
| @property | |||||
| def vocab_size(self): | |||||
| return len(self.word2index) | |||||
| @property | |||||
| def num_classes(self): | |||||
| return len(self.label2index) | |||||
| import _pickle | |||||
| import os | |||||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||||
| DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||||
| '<reserved-3>', | |||||
| '<reserved-4>'] # dict index = 2~4 | |||||
| DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||||
| DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | |||||
| DEFAULT_RESERVED_LABEL[2]: 4} | |||||
| # the first regular word added to the dict therefore gets index 5 | |||||
| def save_pickle(obj, pickle_path, file_name): | |||||
| with open(os.path.join(pickle_path, file_name), "wb") as f: | |||||
| _pickle.dump(obj, f) | |||||
| print("{} saved. ".format(file_name)) | |||||
| def load_pickle(pickle_path, file_name): | |||||
| with open(os.path.join(pickle_path, file_name), "rb") as f: | |||||
| obj = _pickle.load(f) | |||||
| print("{} loaded. ".format(file_name)) | |||||
| return obj | |||||
| def pickle_exist(pickle_path, pickle_name): | |||||
| """ | |||||
| :param pickle_path: the directory of target pickle file | |||||
| :param pickle_name: the filename of target pickle file | |||||
| :return: True if file exists else False | |||||
| """ | |||||
| if not os.path.exists(pickle_path): | |||||
| os.makedirs(pickle_path) | |||||
| file_name = os.path.join(pickle_path, pickle_name) | |||||
| if os.path.exists(file_name): | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| class BasePreprocess(object): | |||||
| def __init__(self, data, pickle_path): | |||||
| super(BasePreprocess, self).__init__() | |||||
| # self.data = data | |||||
| self.pickle_path = pickle_path | |||||
| if not self.pickle_path.endswith('/'): | |||||
| self.pickle_path = self.pickle_path + '/' | |||||
| class POSPreprocess(BasePreprocess): | |||||
| """ | |||||
| This class is used to preprocess POS tagging datasets. | |||||
| """ | |||||
| def __init__(self, data, pickle_path="./", train_dev_split=0): | |||||
| """ | |||||
| Preprocess pipeline, including building mapping from words to index, from index to words, | |||||
| from labels/classes to index, from index to labels/classes. | |||||
| :param data: three-level list | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :param pickle_path: str, the directory to the pickle files. Default: "./" | |||||
| :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0. | |||||
| """ | |||||
| super(POSPreprocess, self).__init__(data, pickle_path) | |||||
| self.pickle_path = pickle_path | |||||
| if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): | |||||
| self.word2index = load_pickle(self.pickle_path, "word2id.pkl") | |||||
| self.label2index = load_pickle(self.pickle_path, "class2id.pkl") | |||||
| else: | |||||
| self.word2index, self.label2index = self.build_dict(data) | |||||
| save_pickle(self.word2index, self.pickle_path, "word2id.pkl") | |||||
| save_pickle(self.label2index, self.pickle_path, "class2id.pkl") | |||||
| if not pickle_exist(pickle_path, "id2word.pkl"): | |||||
| index2word = self.build_reverse_dict(self.word2index) | |||||
| save_pickle(index2word, self.pickle_path, "id2word.pkl") | |||||
| if not pickle_exist(pickle_path, "id2class.pkl"): | |||||
| index2label = self.build_reverse_dict(self.label2index) | |||||
| save_pickle(index2label, self.pickle_path, "id2class.pkl") | |||||
| if not pickle_exist(pickle_path, "data_train.pkl"): | |||||
| data_train = self.to_index(data) | |||||
| if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): | |||||
| data_dev = data_train[: int(len(data_train) * train_dev_split)] | |||||
| save_pickle(data_dev, self.pickle_path, "data_dev.pkl") | |||||
| save_pickle(data_train, self.pickle_path, "data_train.pkl") | |||||
| def build_dict(self, data): | |||||
| """ | |||||
| Add new words with indices into self.word_dict, new labels with indices into self.label_dict. | |||||
| :param data: three-level list | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :return word2index: dict of {str, int} | |||||
| label2index: dict of {str, int} | |||||
| """ | |||||
| # In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch. | |||||
| label2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| word2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| for example in data: | |||||
| for word, label in zip(example[0], example[1]): | |||||
| if word not in word2index: | |||||
| word2index[word] = len(word2index) | |||||
| if label not in label2index: | |||||
| label2index[label] = len(label2index) | |||||
| return word2index, label2index | |||||
| def build_reverse_dict(self, word_dict): | |||||
| id2word = {word_dict[w]: w for w in word_dict} | |||||
| return id2word | |||||
| def to_index(self, data): | |||||
| """ | |||||
| Convert word strings and label strings into indices. | |||||
| :param data: three-level list | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :return data_index: the shape of data, but each string is replaced by its corresponding index | |||||
| """ | |||||
| data_index = [] | |||||
| for example in data: | |||||
| word_list = [] | |||||
| label_list = [] | |||||
| for word, label in zip(example[0], example[1]): | |||||
| word_list.append(self.word2index[word]) | |||||
| label_list.append(self.label2index[label]) | |||||
| data_index.append([word_list, label_list]) | |||||
| return data_index | |||||
| @property | |||||
| def vocab_size(self): | |||||
| return len(self.word2index) | |||||
| @property | |||||
| def num_classes(self): | |||||
| return len(self.label2index) | |||||
| class ClassPreprocess(BasePreprocess): | |||||
| """ | |||||
| Pre-process the classification datasets. | |||||
| Params: | |||||
| pickle_path - directory to save result of pre-processing | |||||
| Saves: | |||||
| word2id.pkl | |||||
| id2word.pkl | |||||
| class2id.pkl | |||||
| id2class.pkl | |||||
| embedding.pkl | |||||
| data_train.pkl | |||||
| data_dev.pkl | |||||
| data_test.pkl | |||||
| """ | |||||
| def __init__(self, pickle_path): | |||||
| # super(ClassPreprocess, self).__init__(data, pickle_path) | |||||
| self.word_dict = None | |||||
| self.label_dict = None | |||||
| self.pickle_path = pickle_path # save directory | |||||
| def process(self, data, save_name): | |||||
| """ | |||||
| Process data. | |||||
| Params: | |||||
| data - nested list, data = [sample1, sample2, ...], | |||||
| sample = [sentence, label], sentence = [word1, word2, ...] | |||||
| save_name - name of processed data, such as data_train.pkl | |||||
| Returns: | |||||
| vocab_size - vocabulary size | |||||
| n_classes - number of classes | |||||
| """ | |||||
| self.build_dict(data) | |||||
| self.word2id() | |||||
| vocab_size = self.id2word() | |||||
| self.class2id() | |||||
| num_classes = self.id2class() | |||||
| self.embedding() | |||||
| self.data_generate(data, save_name) | |||||
| return vocab_size, num_classes | |||||
| def build_dict(self, data): | |||||
| """Build vocabulary.""" | |||||
| # just load the existing pickles if word2id.pkl and class2id.pkl both exist | |||||
| if self.pickle_exist("word2id.pkl") and \ | |||||
| self.pickle_exist("class2id.pkl"): | |||||
| file_name = os.path.join(self.pickle_path, "word2id.pkl") | |||||
| with open(file_name, 'rb') as f: | |||||
| self.word_dict = _pickle.load(f) | |||||
| file_name = os.path.join(self.pickle_path, "class2id.pkl") | |||||
| with open(file_name, 'rb') as f: | |||||
| self.label_dict = _pickle.load(f) | |||||
| return | |||||
| # build vocabulary from scratch if nothing exists | |||||
| self.word_dict = { | |||||
| DEFAULT_PADDING_LABEL: 0, | |||||
| DEFAULT_UNKNOWN_LABEL: 1, | |||||
| DEFAULT_RESERVED_LABEL[0]: 2, | |||||
| DEFAULT_RESERVED_LABEL[1]: 3, | |||||
| DEFAULT_RESERVED_LABEL[2]: 4} | |||||
| self.label_dict = {} | |||||
| # collect every word and label | |||||
| for sent, label in data: | |||||
| if len(sent) <= 1: | |||||
| continue | |||||
| if label not in self.label_dict: | |||||
| index = len(self.label_dict) | |||||
| self.label_dict[label] = index | |||||
| for word in sent: | |||||
| if word not in self.word_dict: | |||||
| index = len(self.word_dict) | |||||
| self.word_dict[word] = index | |||||
| def pickle_exist(self, pickle_name): | |||||
| """ | |||||
| Check whether a pickle file exists. | |||||
| Params | |||||
| pickle_name: the filename of target pickle file | |||||
| Return | |||||
| True if file exists else False | |||||
| """ | |||||
| if not os.path.exists(self.pickle_path): | |||||
| os.makedirs(self.pickle_path) | |||||
| file_name = os.path.join(self.pickle_path, pickle_name) | |||||
| if os.path.exists(file_name): | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| def word2id(self): | |||||
| """Save vocabulary of {word:id} mapping format.""" | |||||
| # nothing will be done if word2id.pkl exists | |||||
| if self.pickle_exist("word2id.pkl"): | |||||
| return | |||||
| file_name = os.path.join(self.pickle_path, "word2id.pkl") | |||||
| with open(file_name, "wb") as f: | |||||
| _pickle.dump(self.word_dict, f) | |||||
| def id2word(self): | |||||
| """Save vocabulary of {id:word} mapping format.""" | |||||
| # nothing will be done if id2word.pkl exists | |||||
| if self.pickle_exist("id2word.pkl"): | |||||
| file_name = os.path.join(self.pickle_path, "id2word.pkl") | |||||
| with open(file_name, 'rb') as f: | |||||
| id2word_dict = _pickle.load(f) | |||||
| return len(id2word_dict) | |||||
| id2word_dict = {self.word_dict[w]: w for w in self.word_dict} | |||||
| file_name = os.path.join(self.pickle_path, "id2word.pkl") | |||||
| with open(file_name, "wb") as f: | |||||
| _pickle.dump(id2word_dict, f) | |||||
| return len(id2word_dict) | |||||
| def class2id(self): | |||||
| """Save mapping of {class:id}.""" | |||||
| # nothing will be done if class2id.pkl exists | |||||
| if self.pickle_exist("class2id.pkl"): | |||||
| return | |||||
| file_name = os.path.join(self.pickle_path, "class2id.pkl") | |||||
| with open(file_name, "wb") as f: | |||||
| _pickle.dump(self.label_dict, f) | |||||
| def id2class(self): | |||||
| """Save mapping of {id:class}.""" | |||||
| # nothing will be done if id2class.pkl exists | |||||
| if self.pickle_exist("id2class.pkl"): | |||||
| file_name = os.path.join(self.pickle_path, "id2class.pkl") | |||||
| with open(file_name, "rb") as f: | |||||
| id2class_dict = _pickle.load(f) | |||||
| return len(id2class_dict) | |||||
| id2class_dict = {self.label_dict[c]: c for c in self.label_dict} | |||||
| file_name = os.path.join(self.pickle_path, "id2class.pkl") | |||||
| with open(file_name, "wb") as f: | |||||
| _pickle.dump(id2class_dict, f) | |||||
| return len(id2class_dict) | |||||
| def embedding(self): | |||||
| """Save embedding lookup table corresponding to vocabulary.""" | |||||
| # nothing will be done if embedding.pkl exists | |||||
| if self.pickle_exist("embedding.pkl"): | |||||
| return | |||||
| # retrieve vocabulary from pre-trained embedding (not implemented) | |||||
| def data_generate(self, data_src, save_name): | |||||
| """Convert dataset from text to digit.""" | |||||
| # nothing will be done if file exists | |||||
| save_path = os.path.join(self.pickle_path, save_name) | |||||
| if os.path.exists(save_path): | |||||
| return | |||||
| data = [] | |||||
| # for every sample | |||||
| for sent, label in data_src: | |||||
| if len(sent) <= 1: | |||||
| continue | |||||
| label_id = self.label_dict[label] # label id | |||||
| sent_id = [] # sentence ids | |||||
| for word in sent: | |||||
| if word in self.word_dict: | |||||
| sent_id.append(self.word_dict[word]) | |||||
| else: | |||||
| sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL]) | |||||
| data.append([sent_id, label_id]) | |||||
| # save data | |||||
| with open(save_path, "wb") as f: | |||||
| _pickle.dump(data, f) | |||||
| class LMPreprocess(BasePreprocess): | |||||
| def __init__(self, data, pickle_path): | |||||
| super(LMPreprocess, self).__init__(data, pickle_path) | |||||
| def infer_preprocess(pickle_path, data): | |||||
| """ | |||||
| Preprocess inference data. | |||||
| Transform a two-level list of word strings into the corresponding lists of indices. | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| [word_21, word_22, ...], | |||||
| ... | |||||
| ] | |||||
| """ | |||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
| data_index = [] | |||||
| for example in data: | |||||
| data_index.append([word2index.get(w, word2index[DEFAULT_UNKNOWN_LABEL]) for w in example])  # map OOV words to the <unk> index | |||||
| return data_index | |||||
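For orientation, here is a minimal sketch of the new preprocessing pipeline in use. The toy sentences and tag names are hypothetical; only `POSPreprocess`, its constructor arguments, and the pickle files it writes come from the code above.

```
# hypothetical toy data in the three-level format expected by POSPreprocess
data = [
    [["I", "like", "apples"], ["PN", "V", "NN"]],
    [["You", "like", "pears"], ["PN", "V", "NN"]],
]
p = POSPreprocess(data, pickle_path="./save/", train_dev_split=0.5)
print(p.vocab_size, p.num_classes)
# word2id.pkl, id2word.pkl, class2id.pkl, id2class.pkl, data_train.pkl
# (and data_dev.pkl, because train_dev_split > 0) are written under ./save/
```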
| @@ -3,7 +3,6 @@ import torch | |||||
| class BaseModel(torch.nn.Module): | class BaseModel(torch.nn.Module): | ||||
| """Base PyTorch model for all models. | """Base PyTorch model for all models. | ||||
| To do: add some useful common features | |||||
| """ | """ | ||||
| def __init__(self): | def __init__(self): | ||||
| @@ -19,8 +19,6 @@ USE_GPU = True | |||||
| class CharLM(BaseModel): | class CharLM(BaseModel): | ||||
| """ | """ | ||||
| Controller of the Character-level Neural Language Model | Controller of the Character-level Neural Language Model | ||||
| To do: | |||||
| - where the data goes, call data savers. | |||||
| """ | """ | ||||
| def __init__(self, lstm_batch_size, lstm_seq_len): | def __init__(self, lstm_batch_size, lstm_seq_len): | ||||
| super(CharLM, self).__init__() | super(CharLM, self).__init__() | ||||
| @@ -1,9 +1,7 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | |||||
| from torch.nn import functional as F | |||||
| from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
| from fastNLP.modules.decoder.CRF import ContionalRandomField | |||||
| from fastNLP.modules import decoder, encoder, utils | |||||
| class SeqLabeling(BaseModel): | class SeqLabeling(BaseModel): | ||||
| @@ -11,87 +9,71 @@ class SeqLabeling(BaseModel): | |||||
| PyTorch Network for sequence labeling | PyTorch Network for sequence labeling | ||||
| """ | """ | ||||
| def __init__(self, hidden_dim, | |||||
| rnn_num_layer, | |||||
| num_classes, | |||||
| vocab_size, | |||||
| word_emb_dim=100, | |||||
| init_emb=None, | |||||
| rnn_mode="gru", | |||||
| bi_direction=False, | |||||
| dropout=0.5, | |||||
| use_crf=True): | |||||
| def __init__(self, args): | |||||
| super(SeqLabeling, self).__init__() | super(SeqLabeling, self).__init__() | ||||
| vocab_size = args["vocab_size"] | |||||
| word_emb_dim = args["word_emb_dim"] | |||||
| hidden_dim = args["rnn_hidden_units"] | |||||
| num_classes = args["num_classes"] | |||||
| self.Emb = nn.Embedding(vocab_size, word_emb_dim) | |||||
| if init_emb: | |||||
| self.Emb.weight = nn.Parameter(init_emb) | |||||
| self.num_classes = num_classes | |||||
| self.input_dim = word_emb_dim | |||||
| self.layers = rnn_num_layer | |||||
| self.hidden_dim = hidden_dim | |||||
| self.bi_direction = bi_direction | |||||
| self.dropout = dropout | |||||
| self.mode = rnn_mode | |||||
| if self.mode == "lstm": | |||||
| self.rnn = nn.LSTM(self.input_dim, self.hidden_dim, self.layers, batch_first=True, | |||||
| bidirectional=self.bi_direction, dropout=self.dropout) | |||||
| elif self.mode == "gru": | |||||
| self.rnn = nn.GRU(self.input_dim, self.hidden_dim, self.layers, batch_first=True, | |||||
| bidirectional=self.bi_direction, dropout=self.dropout) | |||||
| elif self.mode == "rnn": | |||||
| self.rnn = nn.RNN(self.input_dim, self.hidden_dim, self.layers, batch_first=True, | |||||
| bidirectional=self.bi_direction, dropout=self.dropout) | |||||
| else: | |||||
| raise Exception | |||||
| if bi_direction: | |||||
| self.linear = nn.Linear(self.hidden_dim * 2, self.num_classes) | |||||
| else: | |||||
| self.linear = nn.Linear(self.hidden_dim, self.num_classes) | |||||
| self.use_crf = use_crf | |||||
| if self.use_crf: | |||||
| self.crf = ContionalRandomField(num_classes) | |||||
| self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim) | |||||
| self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) | |||||
| self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | |||||
| self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| """ | """ | ||||
| :param x: LongTensor, [batch_size, max_len] | :param x: LongTensor, [batch_size, max_len] | ||||
| :return y: [batch_size, tag_size, tag_size] | |||||
| :return y: [batch_size, max_len, tag_size] | |||||
| """ | """ | ||||
| x = self.Emb(x) | |||||
| x = self.Embedding(x) | |||||
| # [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
| x, hidden = self.rnn(x) | |||||
| x = self.Rnn(x) | |||||
| # [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
| y = self.linear(x) | |||||
| x = self.Linear(x) | |||||
| # [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
| return y | |||||
| return x | |||||
| def loss(self, x, y, mask, batch_size, max_len): | |||||
| def loss(self, x, y, seq_length): | |||||
| """ | """ | ||||
| Negative log likelihood loss. | Negative log likelihood loss. | ||||
| :param x: FloatTensor, [batch_size, tag_size, tag_size] | |||||
| :param x: FloatTensor, [batch_size, max_len, tag_size] | |||||
| :param y: LongTensor, [batch_size, max_len] | :param y: LongTensor, [batch_size, max_len] | ||||
| :param mask: ByteTensor, [batch_size, max_len] | |||||
| :param batch_size: int | |||||
| :param max_len: int | |||||
| :param seq_length: list of int. [batch_size] | |||||
| :return loss: a scalar Tensor | :return loss: a scalar Tensor | ||||
| prediction: list of tuple of (decode path(list), best score) | |||||
| """ | """ | ||||
| x = x.float() | x = x.float() | ||||
| y = y.long() | y = y.long() | ||||
| mask = mask.byte() | |||||
| # print(x.shape, y.shape, mask.shape) | |||||
| if self.use_crf: | |||||
| total_loss = self.crf(x, y, mask) | |||||
| tag_seq = self.crf.viterbi_decode(x, mask) | |||||
| else: | |||||
| # error | |||||
| loss_function = nn.NLLLoss(ignore_index=0, size_average=False) | |||||
| x = x.view(batch_size * max_len, -1) | |||||
| score = F.log_softmax(x) | |||||
| total_loss = loss_function(score, y.view(batch_size * max_len)) | |||||
| _, tag_seq = torch.max(score) | |||||
| tag_seq = tag_seq.view(batch_size, max_len) | |||||
| return torch.mean(total_loss), tag_seq | |||||
| batch_size = x.size(0) | |||||
| max_len = x.size(1) | |||||
| mask = utils.seq_mask(seq_length, max_len) | |||||
| mask = mask.byte().view(batch_size, max_len) | |||||
| # TODO: remove | |||||
| if torch.cuda.is_available(): | |||||
| mask = mask.cuda() | |||||
| # mask = x.new(batch_size, max_len) | |||||
| total_loss = self.Crf(x, y, mask) | |||||
| return torch.mean(total_loss) | |||||
| def prediction(self, x, seq_length): | |||||
| """ | |||||
| :param x: FloatTensor, [batch_size, max_len, tag_size] | |||||
| :param seq_length: list of int, [batch_size] | |||||
| :return prediction: list of decoded tag-index paths, one per sequence | |||||
| """ | |||||
| x = x.float() | |||||
| max_len = x.size(1) | |||||
| mask = utils.seq_mask(seq_length, max_len) | |||||
| # hack: make sure mask has the same device as x | |||||
| mask = mask.to(x).byte() | |||||
| tag_seq = self.Crf.viterbi_decode(x, mask) | |||||
| return tag_seq | |||||
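A minimal sketch of driving the refactored `SeqLabeling` end to end. The sizes, tensors, and sequence lengths below are made up; the `args` keys are the ones read in `__init__`, and the `loss`/`prediction` calls follow the signatures defined above.

```
import torch

args = {"vocab_size": 100, "word_emb_dim": 100,
        "rnn_hidden_units": 100, "num_classes": 10}
model = SeqLabeling(args)

x = torch.randint(0, 100, (4, 20))   # [batch_size, max_len] word indices
y = torch.randint(0, 10, (4, 20))    # gold tag indices
seq_length = [20, 18, 15, 9]         # true lengths before padding

logits = model(x)                    # [batch_size, max_len, num_classes]
loss = model.loss(logits, y, seq_length)
paths = model.prediction(logits, seq_length)  # decoded tag paths
```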
| @@ -0,0 +1,11 @@ | |||||
| from . import aggregation | |||||
| from . import decoder | |||||
| from . import encoder | |||||
| from . import interaction | |||||
| __version__ = '0.0.0' | |||||
| __all__ = ['encoder', | |||||
| 'decoder', | |||||
| 'aggregation', | |||||
| 'interaction'] | |||||
| @@ -18,13 +18,13 @@ def seq_len_to_byte_mask(seq_lens): | |||||
| return mask | return mask | ||||
| class ContionalRandomField(nn.Module): | |||||
| class ConditionalRandomField(nn.Module): | |||||
| def __init__(self, tag_size, include_start_end_trans=True): | def __init__(self, tag_size, include_start_end_trans=True): | ||||
| """ | """ | ||||
| :param tag_size: int, num of tags | :param tag_size: int, num of tags | ||||
| :param include_start_end_trans: bool, whether to include start/end tag | :param include_start_end_trans: bool, whether to include start/end tag | ||||
| """ | """ | ||||
| super(ContionalRandomField, self).__init__() | |||||
| super(ConditionalRandomField, self).__init__() | |||||
| self.include_start_end_trans = include_start_end_trans | self.include_start_end_trans = include_start_end_trans | ||||
| self.tag_size = tag_size | self.tag_size = tag_size | ||||
| @@ -47,7 +47,6 @@ class ContionalRandomField(nn.Module): | |||||
| """ | """ | ||||
| Computes the (batch_size,) denominator term for the log-likelihood, which is the | Computes the (batch_size,) denominator term for the log-likelihood, which is the | ||||
| sum of the likelihoods across all possible state sequences. | sum of the likelihoods across all possible state sequences. | ||||
| :param feats:FloatTensor, batch_size x max_len x tag_size | :param feats:FloatTensor, batch_size x max_len x tag_size | ||||
| :param masks:ByteTensor, batch_size x max_len | :param masks:ByteTensor, batch_size x max_len | ||||
| :return:FloatTensor, batch_size | :return:FloatTensor, batch_size | ||||
| @@ -128,7 +127,7 @@ class ContionalRandomField(nn.Module): | |||||
| return all_path_score - gold_path_score | return all_path_score - gold_path_score | ||||
| def viterbi_decode(self, feats, masks): | |||||
| def viterbi_decode(self, feats, masks, get_score=False): | |||||
| """ | """ | ||||
| Given a feats matrix, return best decode path and best score. | Given a feats matrix, return best decode path and best score. | ||||
| :param feats: | :param feats: | ||||
| @@ -147,28 +146,28 @@ class ContionalRandomField(nn.Module): | |||||
| for t in range(self.tag_size): | for t in range(self.tag_size): | ||||
| pre_scores = self.transition_m[:, t].view( | pre_scores = self.transition_m[:, t].view( | ||||
| 1, self.tag_size) + alpha | 1, self.tag_size) + alpha | ||||
| max_scroe, indice = pre_scores.max(dim=1) | |||||
| new_alpha[:, t] = max_scroe + feats[:, i, t] | |||||
| paths[:, i - 1, t] = indice | |||||
| alpha = new_alpha * \ | |||||
| masks[:, i:i + 1].float() + alpha * \ | |||||
| (1 - masks[:, i:i + 1].float()) | |||||
| max_score, indices = pre_scores.max(dim=1) | |||||
| new_alpha[:, t] = max_score + feats[:, i, t] | |||||
| paths[:, i - 1, t] = indices | |||||
| alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float()) | |||||
| if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
| alpha += self.end_scores.view(1, -1) | alpha += self.end_scores.view(1, -1) | ||||
| max_scroes, indice = alpha.max(dim=1) | |||||
| indice = indice.cpu().numpy() | |||||
| max_scores, indices = alpha.max(dim=1) | |||||
| indices = indices.cpu().numpy() | |||||
| final_paths = [] | final_paths = [] | ||||
| paths = paths.cpu().numpy().astype(int) | paths = paths.cpu().numpy().astype(int) | ||||
| seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1] | seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1] | ||||
| for b in range(batch_size): | for b in range(batch_size): | ||||
| path = [indice[b]] | |||||
| path = [indices[b]] | |||||
| for i in range(seq_lens[b] - 2, -1, -1): | for i in range(seq_lens[b] - 2, -1, -1): | ||||
| index = paths[b, i, path[-1]] | index = paths[b, i, path[-1]] | ||||
| path.append(index) | path.append(index) | ||||
| final_paths.append(path[::-1]) | final_paths.append(path[::-1]) | ||||
| return list(zip(final_paths, max_scroes.detach().cpu().numpy())) | |||||
| if get_score: | |||||
| return list(zip(final_paths, max_scores.detach().cpu().numpy())) | |||||
| else: | |||||
| return final_paths | |||||
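A hypothetical usage sketch of the renamed `ConditionalRandomField`. Shapes are arbitrary; the forward call is assumed to take `(feats, tags, masks)`, matching the `self.Crf(x, y, mask)` call in `SeqLabeling`, and `viterbi_decode` follows the new `get_score` flag above.

```
import torch

crf = ConditionalRandomField(tag_size=5)
feats = torch.randn(2, 6, 5)          # [batch_size, max_len, tag_size] emission scores
tags = torch.randint(0, 5, (2, 6))    # gold tag sequences
masks = torch.ones(2, 6).byte()       # 1 for real tokens, 0 for padding

nll = crf(feats, tags, masks)                               # per-sequence neg. log-likelihood
paths = crf.viterbi_decode(feats, masks)                    # best paths only
scored = crf.viterbi_decode(feats, masks, get_score=True)   # (path, score) pairs
```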
| @@ -0,0 +1,3 @@ | |||||
| from .CRF import ConditionalRandomField | |||||
| __all__ = ["ConditionalRandomField"] | |||||
| @@ -0,0 +1,7 @@ | |||||
| from .embedding import Embedding | |||||
| from .linear import Linear | |||||
| from .lstm import Lstm | |||||
| __all__ = ["Lstm", | |||||
| "Embedding", | |||||
| "Linear"] | |||||
| @@ -1,25 +1,24 @@ | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| class Lookuptable(nn.Module): | |||||
| class Embedding(nn.Module): | |||||
| """ | """ | ||||
| A simple lookup table | A simple lookup table | ||||
| Args: | Args: | ||||
| nums : the size of the lookup table | nums : the size of the lookup table | ||||
| dims : the size of each vector. Default: 50. | |||||
| dims : the size of each vector | |||||
| padding_idx : pads the tensor with zeros whenever it encounters this index | padding_idx : pads the tensor with zeros whenever it encounters this index | ||||
| sparse : If True, gradient matrix will be a sparse tensor. In this case, | sparse : If True, gradient matrix will be a sparse tensor. In this case, | ||||
| only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used | only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used | ||||
| """ | """ | ||||
| def __init__(self, nums, dims=50, padding_idx=0, sparse=False): | |||||
| super(Lookuptable, self).__init__() | |||||
| def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0): | |||||
| super(Embedding, self).__init__() | |||||
| self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse) | self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse) | ||||
| def forward(self, x): | |||||
| return self.embed(x) | |||||
| if init_emb is not None: | |||||
| self.embed.weight = nn.Parameter(init_emb) | |||||
| self.dropout = nn.Dropout(dropout) | |||||
| if __name__ == "__main__": | |||||
| model = Lookuptable(10, 20) | |||||
| def forward(self, x): | |||||
| x = self.embed(x) | |||||
| return self.dropout(x) | |||||
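A small, hypothetical usage of the renamed `Embedding` wrapper showing the new `dropout` argument; the vocabulary size, dimensions, and batch shape are arbitrary.

```
import torch

embed = Embedding(nums=1000, dims=100, padding_idx=0, dropout=0.5)
x = torch.randint(0, 1000, (32, 20))  # [batch_size, max_len] token indices
out = embed(x)                        # [batch_size, max_len, 100], dropout applied
```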
| @@ -0,0 +1,21 @@ | |||||
| import torch.nn as nn | |||||
| class Linear(nn.Module): | |||||
| """ | |||||
| Linear module | |||||
| Args: | |||||
| input_size : input size | |||||
| output_size : output size | |||||
| bias : whether to add a learnable bias. Default: True. | |||||
| """ | |||||
| def __init__(self, input_size, output_size, bias=True): | |||||
| super(Linear, self).__init__() | |||||
| self.linear = nn.Linear(input_size, output_size, bias) | |||||
| def forward(self, x): | |||||
| x = self.linear(x) | |||||
| return x | |||||
| @@ -13,7 +13,7 @@ class Lstm(nn.Module): | |||||
| bidirectional : If True, becomes a bidirectional RNN. Default: False. | bidirectional : If True, becomes a bidirectional RNN. Default: False. | ||||
| """ | """ | ||||
| def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.5, bidirectional=False): | |||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False): | |||||
| super(Lstm, self).__init__() | super(Lstm, self).__init__() | ||||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | ||||
| dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
| @@ -18,7 +18,7 @@ MLP_HIDDEN = 2000 | |||||
| CLASSES_NUM = 5 | CLASSES_NUM = 5 | ||||
| from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
| from fastNLP.action.trainer import BaseTrainer | |||||
| from fastNLP.core.trainer import BaseTrainer | |||||
| class MyNet(BaseModel): | class MyNet(BaseModel): | ||||
| @@ -0,0 +1,29 @@ | |||||
| [train] | |||||
| epochs = 2 | |||||
| batch_size = 32 | |||||
| pickle_path = "./save/" | |||||
| validate = true | |||||
| save_best_dev = true | |||||
| model_saved_path = "./save/" | |||||
| rnn_hidden_units = 100 | |||||
| rnn_layers = 2 | |||||
| rnn_bi_direction = true | |||||
| word_emb_dim = 100 | |||||
| dropout = 0.5 | |||||
| use_crf = true | |||||
| use_cuda = true | |||||
| [test] | |||||
| save_output = true | |||||
| validate_in_training = true | |||||
| save_dev_input = false | |||||
| save_loss = true | |||||
| batch_size = 64 | |||||
| pickle_path = "./save/" | |||||
| rnn_hidden_units = 100 | |||||
| rnn_layers = 1 | |||||
| rnn_bi_direction = true | |||||
| word_emb_dim = 100 | |||||
| dropout = 0.5 | |||||
| use_crf = true | |||||
| use_cuda = true | |||||
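Assuming this file is saved as `cws.cfg`, its sections can be pulled into `ConfigSection` objects exactly as the training script below does; the first two `ConfigLoader` arguments mirror the placeholder strings used in that script.

```
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

train_args, test_args = ConfigSection(), ConfigSection()
ConfigLoader("cws.cfg", "").load_config("./cws.cfg", {"train": train_args, "test": test_args})
print(train_args["epochs"], train_args["batch_size"])  # values from the [train] section above
```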
| @@ -0,0 +1,110 @@ | |||||
| import sys | |||||
| sys.path.append("..") | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.core.trainer import POSTrainer | |||||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
| from fastNLP.loader.preprocess import POSPreprocess, load_pickle | |||||
| from fastNLP.saver.model_saver import ModelSaver | |||||
| from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.core.tester import POSTester | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| from fastNLP.core.inference import Inference | |||||
| data_name = "pku_training.utf8" | |||||
| cws_data_path = "/home/zyfeng/data/pku_training.utf8" | |||||
| pickle_path = "./save/" | |||||
| data_infer_path = "data_for_tests/people_infer.txt" | |||||
| def infer(): | |||||
| # Load infer configuration, the same as test | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | |||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
| test_args["vocab_size"] = len(word2index) | |||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| test_args["num_classes"] = len(index2label) | |||||
| # Define the same model | |||||
| model = SeqLabeling(test_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| print("model loaded!") | |||||
| # Data Loader | |||||
| raw_data_loader = BaseLoader(data_name, data_infer_path) | |||||
| infer_data = raw_data_loader.load_lines() | |||||
| # Inference interface | |||||
| infer = Inference(pickle_path) | |||||
| results = infer.predict(model, infer_data) | |||||
| print(results) | |||||
| print("Inference finished!") | |||||
| def train(): | |||||
| # Config Loader | |||||
| train_args = ConfigSection() | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("good_name", "good_path").load_config("./cws.cfg", {"train": train_args, "test": test_args}) | |||||
| # Data Loader | |||||
| loader = TokenizeDatasetLoader(data_name, cws_data_path) | |||||
| train_data = loader.load_pku() | |||||
| # Preprocessor | |||||
| p = POSPreprocess(train_data, pickle_path) | |||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| # Trainer | |||||
| trainer = POSTrainer(train_args) | |||||
| # Model | |||||
| model = SeqLabeling(train_args) | |||||
| # Start training | |||||
| trainer.train(model) | |||||
| print("Training finished!") | |||||
| # Saver | |||||
| saver = ModelSaver("./save/saved_model.pkl") | |||||
| saver.save_pytorch(model) | |||||
| print("Model saved!") | |||||
| def test(): | |||||
| # Config Loader | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
| # Define the same model | |||||
| model = SeqLabeling(train_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| print("model loaded!") | |||||
| # Load test configuration | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| # Tester | |||||
| tester = POSTester(test_args) | |||||
| # Start testing | |||||
| tester.test(model) | |||||
| # print test results | |||||
| print(tester.show_matrices()) | |||||
| print("model tested!") | |||||
| if __name__ == "__main__": | |||||
| train() | |||||
| @@ -54,8 +54,8 @@ test = 5 | |||||
| new_attr = 40 | new_attr = 40 | ||||
| [POS] | [POS] | ||||
| epochs = 20 | |||||
| batch_size = 1 | |||||
| epochs = 1 | |||||
| batch_size = 32 | |||||
| pickle_path = "./data_for_tests/" | pickle_path = "./data_for_tests/" | ||||
| validate = true | validate = true | ||||
| save_best_dev = true | save_best_dev = true | ||||
| @@ -66,6 +66,7 @@ rnn_bi_direction = true | |||||
| word_emb_dim = 100 | word_emb_dim = 100 | ||||
| dropout = 0.5 | dropout = 0.5 | ||||
| use_crf = true | use_crf = true | ||||
| use_cuda = true | |||||
| [POS_test] | [POS_test] | ||||
| save_output = true | save_output = true | ||||
| @@ -74,3 +75,19 @@ save_dev_input = false | |||||
| save_loss = true | save_loss = true | ||||
| batch_size = 1 | batch_size = 1 | ||||
| pickle_path = "./data_for_tests/" | pickle_path = "./data_for_tests/" | ||||
| rnn_hidden_units = 100 | |||||
| rnn_layers = 1 | |||||
| rnn_bi_direction = true | |||||
| word_emb_dim = 100 | |||||
| dropout = 0.5 | |||||
| use_crf = true | |||||
| use_cuda = true | |||||
| [POS_infer] | |||||
| pickle_path = "./data_for_tests/" | |||||
| rnn_hidden_units = 100 | |||||
| rnn_layers = 1 | |||||
| rnn_bi_direction = true | |||||
| word_emb_dim = 100 | |||||
| vocab_size = 52 | |||||
| num_classes = 22 | |||||
| @@ -0,0 +1,56 @@ | |||||
| 迈向 充满 希望 的 新 世纪 —— 一九九八年 新年 讲话 ( 附 图片 1 张 ) | |||||
| 中共中央 总书记 、 国家 主席 江 泽民 | |||||
| ( 一九九七年 十二月 三十一日 ) | |||||
| 12月 31日 , 中共中央 总书记 、 国家 主席 江 泽民 发表 1998年 新年 讲话 《 迈向 充满 希望 的 新 世纪 》 。 ( 新华社 记者 兰 红光 摄 ) | |||||
| 同胞 们 、 朋友 们 、 女士 们 、 先生 们 : | |||||
| 在 1998年 来临 之际 , 我 十分 高兴 地 通过 中央 人民 广播 电台 、 中国 国际 广播 电台 和 中央 电视台 , 向 全国 各族 人民 , 向 香港 特别 行政区 同胞 、 澳门 和 台湾 同胞 、 海外 侨胞 , 向 世界 各国 的 朋友 们 , 致以 诚挚 的 问候 和 良好 的 祝愿 ! | |||||
| 1997年 , 是 中国 发展 历史 上 非常 重要 的 很 不 平凡 的 一 年 。 中国 人民 决心 继承 邓 小平 同志 的 遗志 , 继续 把 建设 有 中国 特色 社会主义 事业 推向 前进 。 中国 政府 顺利 恢复 对 香港 行使 主权 , 并 按照 “ 一国两制 ” 、 “ 港人治港 ” 、 高度 自治 的 方针 保持 香港 的 繁荣 稳定 。 中国 共产党 成功 地 召开 了 第十五 次 全国 代表大会 , 高举 邓小平理论 伟大 旗帜 , 总结 百年 历史 , 展望 新 的 世纪 , 制定 了 中国 跨 世纪 发展 的 行动 纲领 。 | |||||
| 在 这 一 年 中 , 中国 的 改革 开放 和 现代化 建设 继续 向前 迈进 。 国民经济 保持 了 “ 高 增长 、 低 通胀 ” 的 良好 发展 态势 。 农业 生产 再次 获得 好 的 收成 , 企业 改革 继续 深化 , 人民 生活 进一步 改善 。 对外 经济 技术 合作 与 交流 不断 扩大 。 民主 法制 建设 、 精神文明 建设 和 其他 各项 事业 都 有 新 的 进展 。 我们 十分 关注 最近 一个 时期 一些 国家 和 地区 发生 的 金融 风波 , 我们 相信 通过 这些 国家 和 地区 的 努力 以及 有关 的 国际 合作 , 情况 会 逐步 得到 缓解 。 总的来说 , 中国 改革 和 发展 的 全局 继续 保持 了 稳定 。 | |||||
| 在 这 一 年 中 , 中国 的 外交 工作 取得 了 重要 成果 。 通过 高层 互访 , 中国 与 美国 、 俄罗斯 、 法国 、 日本 等 大国 确定 了 双方 关系 未来 发展 的 目标 和 指导 方针 。 中国 与 周边 国家 和 广大 发展中国家 的 友好 合作 进一步 加强 。 中国 积极 参与 亚太经合 组织 的 活动 , 参加 了 东盟 — 中 日 韩 和 中国 — 东盟 首脑 非正式 会晤 。 这些 外交 活动 , 符合 和平 与 发展 的 时代 主题 , 顺应 世界 走向 多极化 的 趋势 , 对于 促进 国际 社会 的 友好 合作 和 共同 发展 作出 了 积极 的 贡献 。 | |||||
| 1998年 , 中国 人民 将 满怀信心 地 开创 新 的 业绩 。 尽管 我们 在 经济社会 发展 中 还 面临 不少 困难 , 但 我们 有 邓小平理论 的 指引 , 有 改革 开放 近 20 年 来 取得 的 伟大 成就 和 积累 的 丰富 经验 , 还有 其他 的 各种 有利 条件 , 我们 一定 能够 克服 这些 困难 , 继续 稳步前进 。 只要 我们 进一步 解放思想 , 实事求是 , 抓住 机遇 , 开拓进取 , 建设 有 中国 特色 社会主义 的 道路 就 会 越 走 越 宽广 。 | |||||
| 实现 祖国 的 完全 统一 , 是 海内外 全体 中国 人 的 共同 心愿 。 通过 中 葡 双方 的 合作 和 努力 , 按照 “ 一国两制 ” 方针 和 澳门 《 基本法 》 , 1999年 12月 澳门 的 回归 一定 能够 顺利 实现 。 | |||||
| 台湾 是 中国 领土 不可分割 的 一 部分 。 完成 祖国 统一 , 是 大势所趋 , 民心所向 。 任何 企图 制造 “ 两 个 中国 ” 、 “ 一中一台 ” 、 “ 台湾 独立 ” 的 图谋 , 都 注定 要 更 失败 。 希望 台湾 当局 以 民族 大义 为重 , 拿 出 诚意 , 采取 实际 的 行动 , 推动 两岸 经济 文化 交流 和 人员 往来 , 促进 两岸 直接 通邮 、 通航 、 通商 的 早日 实现 , 并 尽早 回应 我们 发出 的 在 一个 中国 的 原则 下 两岸 进行 谈判 的 郑重 呼吁 。 | |||||
| 环顾 全球 , 日益 密切 的 世界 经济 联系 , 日新月异 的 科技 进步 , 正在 为 各国 经济 的 发展 提供 历史 机遇 。 但是 , 世界 还 不 安宁 。 南北 之间 的 贫富 差距 继续 扩大 ; 局部 冲突 时有发生 ; 不 公正 不 合理 的 旧 的 国际 政治经济 秩序 还 没有 根本 改变 ; 发展中国家 在 激烈 的 国际 经济 竞争 中 仍 处于 弱势 地位 ; 人类 的 生存 与 发展 还 面临 种种 威胁 和 挑战 。 和平 与 发展 的 前景 是 光明 的 , 21 世纪 将 是 充满 希望 的 世纪 。 但 前进 的 道路 不 会 也 不 可能 一帆风顺 , 关键 是 世界 各国 人民 要 进一步 团结 起来 , 共同 推动 早日 建立 公正 合理 的 国际 政治经济 新 秩序 。 | |||||
| 中国 政府 将 继续 坚持 奉行 独立自主 的 和平 外交 政策 , 在 和平共处 五 项 原则 的 基础 上 努力 发展 同 世界 各国 的 友好 关系 。 中国 愿意 加强 同 联合国 和 其他 国际 组织 的 协调 , 促进 在 扩大 经贸 科技 交流 、 保护 环境 、 消除 贫困 、 打击 国际 犯罪 等 方面 的 国际 合作 。 中国 永远 是 维护 世界 和平 与 稳定 的 重要 力量 。 中国 人民 愿 与 世界 各国 人民 一道 , 为 开创 持久 和平 、 共同 发展 的 新 世纪 而 不懈努力 ! | |||||
| 在 这 辞旧迎新 的 美好 时刻 , 我 祝 大家 新年 快乐 , 家庭 幸福 ! | |||||
| 谢谢 ! ( 新华社 北京 12月 31日 电 ) | |||||
| 在 十五大 精神 指引 下 胜利 前进 —— 元旦 献辞 | |||||
| 我们 即将 以 丰收 的 喜悦 送 走 牛年 , 以 昂扬 的 斗志 迎来 虎年 。 我们 伟大 祖国 在 新 的 一 年 , 将 是 充满 生机 、 充满 希望 的 一 年 。 | |||||
| 刚刚 过去 的 一 年 , 大气磅礴 , 波澜壮阔 。 在 这 一 年 , 以 江 泽民 同志 为 核心 的 党中央 , 继承 邓 小平 同志 的 遗志 , 高举 邓小平理论 的 伟大 旗帜 , 领导 全党 和 全国 各族 人民 坚定不移 地 沿着 建设 有 中国 特色 社会主义 道路 阔步 前进 , 写 下 了 改革 开放 和 社会主义 现代化 建设 的 辉煌 篇章 。 顺利 地 恢复 对 香港 行使 主权 , 胜利 地 召开 党 的 第十五 次 全国 代表大会 ——— 两 件 大事 办 得 圆满 成功 。 国民经济 稳中求进 , 国家 经济 实力 进一步 增强 , 人民 生活 继续 改善 , 对外 经济 技术 交流 日益 扩大 。 在 国际 金融 危机 的 风浪 波及 许多 国家 的 情况 下 , 我国 保持 了 金融 形势 和 整个 经济 形势 的 稳定 发展 。 社会主义 精神文明 建设 和 民主 法制 建设 取得 新 的 成绩 , 各项 社会 事业 全面 进步 。 外交 工作 取得 可喜 的 突破 , 我国 的 国际 地位 和 国际 威望 进一步 提高 。 实践 使 亿万 人民 对 邓小平理论 更加 信仰 , 对 以 江 泽民 同志 为 核心 的 党中央 更加 信赖 , 对 伟大 祖国 的 光辉 前景 更加 充满 信心 。 | |||||
| 1998年 , 是 全面 贯彻 落实 党 的 十五大 提 出 的 任务 的 第一 年 , 各 条 战线 改革 和 发展 的 任务 都 十分 繁重 , 有 许多 深 层次 的 矛盾 和 问题 有待 克服 和 解决 , 特别 是 国有 企业 改革 已经 进入 攻坚 阶段 。 我们 必须 进一步 深入 学习 和 掌握 党 的 十五大 精神 , 统揽全局 , 精心 部署 , 狠抓 落实 , 团结 一致 , 艰苦奋斗 , 开拓 前进 , 为 夺取 今年 改革 开放 和 社会主义 现代化 建设 的 新 胜利 而 奋斗 。 | |||||
| 今年 是 党 的 十一 届 三中全会 召开 20 周年 , 是 我们 党 和 国家 实现 伟大 的 历史 转折 、 进入 改革 开放 历史 新 时期 的 20 周年 。 在 新 的 一 年 里 , 大力 发扬 十一 届 三中全会 以来 我们 党 所 恢复 的 优良 传统 和 在 新 的 历史 条件 下 形成 的 优良 作风 , 对于 完成 好 今年 的 各项 任务 具有 十分 重要 的 意义 。 | |||||
| 我们 要 更 好 地 坚持 解放思想 、 实事求是 的 思想 路线 。 解放思想 、 实事求是 , 是 邓小平理论 的 精髓 。 实践 证明 , 只有 解放思想 、 实事求是 , 才 能 冲破 各种 不 切合 实际 的 或者 过时 的 观念 的 束缚 , 真正 做到 尊重 、 认识 和 掌握 客观 规律 , 勇于 突破 , 勇于 创新 , 不断 开创 社会主义 现代化 建设 的 新 局面 。 党 的 十五大 是 我们 党 解放思想 、 实事求是 的 新 的 里程碑 。 进一步 认真 学习 和 掌握 十五大 精神 , 解放思想 、 实事求是 , 我们 的 各项 事业 就 能 结 出 更加 丰硕 的 成果 。 | |||||
| 我们 要 更 好 地 坚持 以 经济 建设 为 中心 。 各项 工作 必须 以 经济 建设 为 中心 , 是 邓小平理论 的 基本 观点 , 是 党 的 基本 路线 的 核心 内容 , 近 20 年 来 的 实践 证明 , 坚持 这个 中心 , 是 完全 正确 的 。 今后 , 我们 能否 把 建设 有 中国 特色 社会主义 伟大 事业 全面 推向 21 世纪 , 关键 仍然 要 看 能否 把 经济 工作 搞 上去 。 各级 领导 干部 要 切实 把 精力 集中 到 贯彻 落实 好 中央 关于 今年 经济 工作 的 总体 要求 和 各项 重要 任务 上 来 , 不断 提高 领导 经济 建设 的 能力 和 水平 。 | |||||
| 我们 要 更 好 地 坚持 “ 两手抓 、 两手 都 要 硬 ” 的 方针 。 在 坚持 以 经济 建设 为 中心 的 同时 , 积极 推进 社会主义 精神文明 建设 和 民主 法制 建设 , 是 建设 富强 、 民主 、 文明 的 社会主义 现代化 国家 的 重要 内容 。 实践 证明 , 经济 建设 的 顺利 进行 , 离 不 开 精神文明 建设 和 民主 法制 建设 的 保证 。 党 的 十五大 依据 邓小平理论 和 党 的 基本 路线 提 出 的 党 在 社会主义 初级阶段 经济 、 政治 、 文化 的 基本 纲领 , 为 “ 两手抓 、 两手 都 要 硬 ” 提供 了 新 的 理论 根据 , 提 出 了 更 高 要求 , 现在 的 关键 是 认真 抓好 落实 。 | |||||
| 我们 要 更 好 地 发扬 求真务实 、 密切 联系 群众 的 作风 。 这 是 把 党 的 方针 、 政策 落到实处 , 使 改革 和 建设 取得 胜利 的 重要 保证 。 在 当前 改革 进一步 深化 , 经济 不断 发展 , 同时 又 出现 一些 新 情况 、 新 问题 和 新 困难 的 形势 下 , 更 要 发扬 这样 的 好 作风 。 要 尊重 群众 的 意愿 , 重视 群众 的 首创 精神 , 关心 群众 的 生活 疾苦 。 江 泽民 同志 最近 强调 指出 , 要 大力 倡导 说实话 、 办 实事 、 鼓 实劲 、 讲 实效 的 作风 , 坚决 制止 追求 表面文章 , 搞 花架子 等 形式主义 , 坚决 杜绝 脱离 群众 、 脱离 实际 、 浮躁 虚夸 等 官僚主义 。 这 是 非常 重要 的 。 因此 , 各级 领导 干部 务必 牢记 全心全意 为 人民 服务 的 宗旨 , 在 勤政廉政 、 艰苦奋斗 方面 以身作则 , 当 好 表率 。 | |||||
| 1998 , 瞩目 中华 。 新 的 机遇 和 挑战 , 催 人 进取 ; 新 的 目标 和 征途 , 催 人 奋发 。 英雄 的 中国 人民 在 以 江 泽民 同志 为 核心 的 党中央 坚强 领导 和 党 的 十五大 精神 指引 下 , 更 高 地 举起 邓小平理论 的 伟大 旗帜 , 团结 一致 , 扎实 工作 , 奋勇前进 , 一定 能够 创造 出 更加 辉煌 的 业绩 ! | |||||
| 北京 举行 新年 音乐会 | |||||
| 江 泽民 李 鹏 乔 石 朱 镕基 李 瑞环 刘 华清 尉 健行 李 岚清 与 万 名 首都 各界 群众 和 劳动模范 代表 一起 辞旧迎新 ( 附 图片 1 张 ) | |||||
| 党 和 国家 领导人 江 泽民 、 李 鹏 、 乔 石 、 朱 镕基 、 李 瑞环 、 刘 华清 、 尉 健行 、 李 岚清 等 与 万 名 首都 各界 群众 和 劳动模范 代表 一起 欣赏 了 ’98 北京 新年 音乐会 的 精彩 节目 。 这 是 江 泽民 等 在 演出 结束 后 同 演出 人员 合影 。 | |||||
| ( 新华社 记者 樊 如钧 摄 ) | |||||
| 本报 北京 12月 31日 讯 新华社 记者 陈 雁 、 本报 记者 何 加正 报道 : 在 度过 了 非凡 而 辉煌 的 1997年 , 迈向 充满 希望 的 1998年 之际 , ’98 北京 新年 音乐会 今晚 在 人民 大会堂 举行 。 党 和 国家 领导人 江 泽民 、 李 鹏 、 乔 石 、 朱 镕基 、 李 瑞环 、 刘 华清 、 尉 健行 、 李 岚清 与 万 名 首都 各界 群众 和 劳动模范 代表 一起 , 在 激昂 奋进 的 音乐声 中 辞旧迎新 。 | |||||
| 今晚 的 长安街 流光溢彩 , 火树银花 ; 人民 大会堂 里 灯火辉煌 , 充满 欢乐 祥和 的 喜庆 气氛 。 在 这 场 由 中共 北京 市委 宣传部 、 市政府 办公厅 等 单位 主办 的 题 为 “ 世纪 携手 、 共 奏 华章 ” 的 新年 音乐会 上 , 中国 三 个 著名 交响乐团 ——— 中国 交响乐团 、 上海 交响乐团 、 北京 交响乐团 首 次 联袂 演出 。 著名 指挥家 陈 佐湟 、 陈 燮阳 、 谭 利华 分别 指挥 演奏 了 一 批 中外 名曲 , 京 沪 两地 200 多 位 音乐家 组成 的 大型 乐队 以 饱满 的 激情 和 精湛 的 技艺 为 观众 奉献 了 一 台 高 水准 的 交响音乐会 。 | |||||
| 音乐会 在 雄壮 的 管弦乐 《 红旗 颂 》 中 拉开 帷幕 , 舒展 、 优美 的 乐曲声 使 人们 仿佛 看到 : 五星红旗 在 天安门 城楼 上 冉冉 升起 ; 仿佛 听到 : 在 红旗 的 指引 下 中国 人民 向 现代化 新 征程 迈进 的 脚步声 。 钢琴 与 管弦乐队 作品 《 东方 之 珠 》 , 把 广大 听众 耳熟能详 的 歌曲 改编 为 器乐曲 , 以 其 优美 感人 的 旋律 抒发 了 洗雪 百年 耻辱 的 香港 明天 会 更 好 的 情感 。 专程 回国 参加 音乐会 的 著名 女高音 歌唱家 迪里拜尔 演唱 的 《 春 之 声 》 , 把 人们 带 到 了 万象更新 的 田野 和 山谷 ; 享誉 国际 乐坛 的 男高音 歌唱家 莫 华伦 演唱 了 著名 歌剧 《 图兰朵 》 选段 “ 今夜 无 人 入睡 ” , 把 人们 带入 迷人 的 艺术 境地 。 音乐会 上 还 演奏 了 小提琴 协奏曲 《 梁 山伯 与 祝 英台 》 、 柴可夫斯基 的 《 第四 交响曲 ——— 第四 乐章 》 、 交响诗 《 罗马 的 松树 》 等 中外 著名 交响曲 。 | |||||
| 万 人 大会堂 今晚 座无虚席 , 观众 被 艺术家 们 精湛 的 表演 深深 打动 , 不断 报 以 经久不息 的 热烈 掌声 。 艺术家 们 频频 谢幕 , 指挥家 依次 指挥 演出 返 场 曲目 , 最后 音乐会 在 《 红色 娘子军 》 选曲 、 《 白毛女 》 选曲 、 《 北京 喜讯 到 边寨 》 等 乐曲声 中 达到 高潮 。 | |||||
| 演出 结束 后 , 江 泽民 等 党 和 国家 领导人 走 上 舞台 , 亲切 会见 了 参加 演出 的 全体 人员 , 祝贺 演出 成功 , 并 与 他们 合影 留念 。 | |||||
| 李 铁映 、 贾 庆林 、 曾 庆红 等 领导 同志 也 出席 了 今晚 音乐会 。 | |||||
| 李 鹏 在 北京 考察 企业 | |||||
| 向 广大 职工 祝贺 新年 , 对 节日 坚守 岗位 的 同志 们 表示 慰问 | |||||
| 新华社 北京 十二月 三十一日 电 ( 中央 人民 广播 电台 记者 刘 振英 、 新华社 记者 张 宿堂 ) 今天 是 一九九七年 的 最后 一 天 。 辞旧迎新 之际 , 国务院 总理 李 鹏 今天 上午 来到 北京 石景山 发电 总厂 考察 , 向 广大 企业 职工 表示 节日 的 祝贺 , 向 将要 在 节日 期间 坚守 工作 岗位 的 同志 们 表示 慰问 。 | |||||
| 上午 九时 二十分 , 李 鹏 总理 在 北京 市委 书记 、 市长 贾 庆林 的 陪同 下 , 来到 位于 北京 西郊 的 北京 石景山 发电 总厂 。 始建 于 一九一九年 的 北京 石景山 发电 总厂 是 华北 电力 集团公司 骨干 发电 企业 , 承担 着 向 首都 供电 、 供热 任务 , 装机 总 容量 一百一十六点六万 千瓦 。 总厂 年发电量 四十五亿 千瓦时 , 供热 能力 八百 百万大卡/小时 , 现 供热 面积 已 达 八百 多 万 平方米 。 早 在 担任 华北 电管局 领导 时 , 李 鹏 就 曾 多次 到 发电 总厂 检查 指导 工作 。 | |||||
| 在 总厂 所 属 的 石景山 热电厂 , 李 鹏 首先 向 华北 电管局 、 电厂 负责人 详细 询问 了 目前 电厂 生产 、 职工 生活 和 华北 电网 向 首都 供电 、 供热 的 有关 情况 。 随后 , 他 又 实地 察看 了 发电机组 的 运行 情况 和 电厂 一号机 、 二号机 控制室 。 在 控制室 , 李 鹏 与 职工 们 一一 握手 , 向 大家 表示 慰问 。 他 说 , 在 一九九八年 即将 到来之际 , 有 机会 再次 回到 石景山 发电 总厂 , 感到 十分 高兴 。 李 鹏 亲切 地 说 : 『 今天 我 看到 了 许多 新 的 、 年轻 的 面孔 , 这 说明 在 老 同志 们 作出 贡献 退 下来 后 , 新 一代 的 年轻人 成长 起来 了 、 成熟 起来 了 , 我 感到 十分 欣慰 。 』 | |||||
| ( A 、 B ) | |||||
| 李 鹏 说 : “ 作为 首都 的 电力 工作者 , 你们 为 首都 的 各项 重大 活动 的 顺利 进行 , 为 保障 人民 群众 的 工作 、 生活 和 学习 , 为 促进 首都 经济 的 发展 作出 了 自己 的 贡献 。 明天 就 是 元旦 , 你们 还有 许多 同志 要 坚守 岗位 , 我 向 你们 、 向 全体 电力 工作者 表示 感谢 。 现在 , 我们 的 首都 已经 结束 了 拉 闸 限 电 的 历史 , 希望 依靠 大家 , 使 拉 闸 限 电 的 历史 永远 不再 重演 。 同时 , 也 希望 你们 安全 生产 、 经济 调度 , 实现 经济 增长 方式 的 转变 。 ” 李 鹏 最后 向 电业 职工 , 向 全 北京市 的 人民 拜年 , 向 大家 致以 新春 的 问候 , 祝愿 电力 事业 取得 新 的 成绩 , 祝愿 北京市 在 改革 、 发展 和 稳定 的 各项 工作 中 取得 新 的 成就 。 | |||||
| 参观 工厂 结束 后 , 李 鹏 又 来到 工厂 退休 职工 郭 树范 和 闫 戌麟 家 看望 慰问 , 向 他们 拜年 。 曾经 是 高级 工程师 的 郭 树范 退休 前 一直 在 发电厂 从事 土建工程 建设 , 退休 后 , 与 老伴 一起 抚养 着 身体 欠佳 的 孙子 。 李 鹏 对 他们 倾心 照顾 下 一 代 表示 肯定 。 他 说 : “ 人 老 了 , 照顾 照顾 后代 也 是 一 件 可以 带来 快乐 的 事 , 当然 , 对 孩子 们 不 能 溺爱 , 要 让 他们 健康 成长 。 ” 在 老工人 闫 戌麟 家 , 当 李 鹏 了解 到 老闫 退休 前 一直 都 是 厂里 的 先进 工作者 、 曾经 被 评为 北京市 “ 五好 职工 ” , 退休 后 仍然 为 改善 职工 的 住房 而 奔波 时 , 十分 高兴 , 对 他 为 工厂 建设 作出 的 贡献 表示 感谢 。 在 郭 家 和 闫 家 , 李 鹏 都 具体 地 了解 了 他们 退休 后 的 生活 保障 问题 , 并 与 一些 老 职工 一起 回忆 起 了 当年 建设 电厂 的 情景 。 李 鹏 说 : “ 当年 搞 建设 , 条件 比 现在 差 多 了 , 大家 也 很 少 计较 什么 , 只是 一心 想 着 把 电厂 建 好 。 现在 条件 好 了 , 但 艰苦奋斗 、 无私奉献 的 精神 可 不 能 丢 。 ” 李 鹏 最后 祝 他们 新春 快乐 , 身体 健康 , 家庭 幸福 。 | |||||
| 陪同 考察 企业 并 看望 慰问 职工 的 国务院 有关 部门 和 北京市 负责人 还有 : 史 大桢 、 高 严 、 石 秀诗 、 阳 安江 等 。 | |||||
| 挂 起 红灯 迎 新年 ( 图片 ) | |||||
| 元旦 来临 , 安徽省 合肥市 长江路 悬挂 起 3300 盏 大 红灯笼 , 为 节日 营造 出 “ 千 盏 灯笼 凌空 舞 , 十 里 长街 别样 红 ” 的 欢乐 祥和 气氛 。 ( 新华社 记者 戴 浩 摄 ) | |||||
| ( 传真 照片 ) | |||||
| 全总 致 全国 各族 职工 慰问信 | |||||
| 勉励 广大 职工 发挥 工人阶级 主力军 作用 , 为 企业 改革 发展 建功立业 | |||||
| 本报 北京 1月 1日 讯 中华 全国 总工会 今日 发出 《 致 全国 各族 职工 慰问信 》 , 向 全国 各族 职工 祝贺 新年 。 | |||||
| 慰问信 说 , 实现 党 的 十五大 提 出 的 宏伟 目标 , 必须 依靠 工人阶级 和 全体 人民 的 长期 奋斗 。 工人阶级 是 我们 国家 的 领导 阶级 , 是 先进 生产力 和 生产关系 的 代表 , 是 两 个 文明 建设 的 主力军 , 是 维护 社会 安定团结 的 中坚 力量 。 党 的 十五大 再次 强调 要 坚持 全心全意 依靠 工人阶级 的 方针 , 具有 重大 的 意义 。 广大 职工 要 以 邓小平理论 和 党 的 基本 路线 为 指导 , 坚持 党 的 基本 纲领 和 各项 方针 政策 , 积极 投身 于 改革 和 建设 事业 。 要 坚持 站 在 改革 的 前列 , 转变 思想 观念 , 增强 市场 意识 、 竞争 意识 和 效益 意识 , 以 实际 行动 促进 改革 的 不断 深化 。 要 发扬 工人阶级 的 首创 精神 , 不断 为 企业 转机建制 、 调整 结构 、 加强 管理 、 提高 效益 献计献策 。 要 大力 开展 劳动 竞赛 、 合理化 建议 、 技术 革新 、 技术 协作 和 发明 创造 等 活动 , 努力 提高 产品 质量 和 经济效益 , 推动 企业 加快 技术 进步 , 实现 增长 方式 的 根本 转变 , 再 创 国有 企业 的 辉煌 。 要 正确 对待 企业 改革 和 发展 中 的 困难 和 问题 , 树立 起 战胜 困难 的 勇气 和 信心 , 锲而不舍 , 迎难而上 , 为 企业 的 改革 和 发展 建功立业 。 | |||||
| 慰问信 指出 , 广大 职工 要 以 主人翁 的 姿态 , 积极 行使 当家作主 的 权利 。 要 不断 提高 自身 素质 , 发扬 爱国 奉献 、 爱厂如家 、 爱岗敬业 的 精神 , 学习 掌握 先进 科学 文化 知识 , 成为 本职工作 的 行家里手 , 迎接 新 世纪 面临 的 挑战 。 | |||||
| 慰问信 最后 说 , 让 我们 在 邓小平理论 和 党 的 基本 路线 指导 下 , 更加 紧密 地 团结 在 以 江 泽民 同志 为 核心 的 党中央 周围 , 统揽全局 , 精心 部署 , 狠抓 落实 , 团结 一致 , 艰苦奋斗 , 开拓 前进 , 在 两 个 文明 建设 中 充分 发挥 工人阶级 主力军 作用 , 为 实现 跨 世纪 宏伟 目标 作出 新 的 更 大 的 贡献 。 | |||||
| 忠诚 的 共产主义 战士 , 久经考验 的 无产阶级 革命家 刘 澜涛 同志 逝世 | |||||
| ( 附 图片 1 张 ) | |||||
| @@ -0,0 +1,2 @@ | |||||
| 迈向充满希望的新世纪——一九九八年新年讲话 | |||||
| (附图片1张) | |||||
| @@ -1,99 +0,0 @@ | |||||
| import sys | |||||
| sys.path.append("..") | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.action.trainer import POSTrainer | |||||
| from fastNLP.loader.dataset_loader import POSDatasetLoader | |||||
| from fastNLP.loader.preprocess import POSPreprocess | |||||
| from fastNLP.saver.model_saver import ModelSaver | |||||
| from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.action.tester import POSTester | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| from fastNLP.action.inference import Inference | |||||
| data_name = "people.txt" | |||||
| data_path = "data_for_tests/people.txt" | |||||
| pickle_path = "data_for_tests" | |||||
| def test_infer(): | |||||
| # Define the same model | |||||
| model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"], | |||||
| num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"], | |||||
| word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"], | |||||
| rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"]) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model) | |||||
| print("model loaded!") | |||||
| # Data Loader | |||||
| pos_loader = POSDatasetLoader(data_name, data_path) | |||||
| infer_data = pos_loader.load_lines() | |||||
| # Preprocessor | |||||
| POSPreprocess(infer_data, pickle_path) | |||||
| # Inference interface | |||||
| infer = Inference() | |||||
| results = infer.predict(model, infer_data) | |||||
| if __name__ == "__main__": | |||||
| # Config Loader | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
| # Data Loader | |||||
| pos_loader = POSDatasetLoader(data_name, data_path) | |||||
| train_data = pos_loader.load_lines() | |||||
| # Preprocessor | |||||
| p = POSPreprocess(train_data, pickle_path) | |||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| # Trainer | |||||
| trainer = POSTrainer(train_args) | |||||
| # Model | |||||
| model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"], | |||||
| num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"], | |||||
| word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"], | |||||
| rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"]) | |||||
| # Start training | |||||
| trainer.train(model) | |||||
| print("Training finished!") | |||||
| # Saver | |||||
| saver = ModelSaver("./saved_model.pkl") | |||||
| saver.save_pytorch(model) | |||||
| print("Model saved!") | |||||
| del model, trainer, pos_loader | |||||
| # Define the same model | |||||
| model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"], | |||||
| num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"], | |||||
| word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"], | |||||
| rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"]) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model) | |||||
| print("model loaded!") | |||||
| # Load test configuration | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| # Tester | |||||
| tester = POSTester(test_args) | |||||
| # Start testing | |||||
| tester.test(model) | |||||
| # print test results | |||||
| print(tester.show_matrices()) | |||||
| print("model tested!") | |||||
| @@ -1,31 +1,7 @@ | |||||
| from loader.base_loader import ToyLoader0 | |||||
| from model.char_language_model import CharLM | |||||
| from fastNLP.action import Tester | |||||
| from fastNLP.action.trainer import Trainer | |||||
| def test_charlm(): | def test_charlm(): | ||||
| train_config = Trainer.TrainConfig(epochs=1, validate=True, save_when_better=True, | |||||
| log_per_step=10, log_validation=True, batch_size=160) | |||||
| trainer = Trainer(train_config) | |||||
| model = CharLM(lstm_batch_size=16, lstm_seq_len=10) | |||||
| train_data = ToyLoader0("load_train", "./data_for_tests/charlm.txt").load() | |||||
| valid_data = ToyLoader0("load_valid", "./data_for_tests/charlm.txt").load() | |||||
| trainer.train(model, train_data, valid_data) | |||||
| trainer.save_model(model) | |||||
| test_config = Tester.TestConfig(save_output=True, validate_in_training=True, | |||||
| save_dev_input=True, save_loss=True, batch_size=160) | |||||
| tester = Tester(test_config) | |||||
| test_data = ToyLoader0("load_test", "./data_for_tests/charlm.txt").load() | |||||
| tester.test(model, test_data) | |||||
| pass | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| @@ -0,0 +1,116 @@ | |||||
| import sys | |||||
| sys.path.append("..") | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.core.trainer import POSTrainer | |||||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
| from fastNLP.loader.preprocess import POSPreprocess, load_pickle | |||||
| from fastNLP.saver.model_saver import ModelSaver | |||||
| from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.core.tester import POSTester | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| from fastNLP.core.inference import Inference | |||||
| data_name = "pku_training.utf8" | |||||
| # cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8" | |||||
| cws_data_path = "data_for_tests/cws_pku_utf_8" | |||||
| pickle_path = "data_for_tests" | |||||
| data_infer_path = "data_for_tests/people_infer.txt" | |||||
| def infer(): | |||||
| # Load infer configuration, the same as test | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | |||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
| test_args["vocab_size"] = len(word2index) | |||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| test_args["num_classes"] = len(index2label) | |||||
| # Define the same model | |||||
| model = SeqLabeling(test_args) | |||||
| # Load the trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| print("model loaded!") | |||||
| # Data Loader | |||||
| raw_data_loader = BaseLoader(data_name, data_infer_path) | |||||
| infer_data = raw_data_loader.load_lines() | |||||
| """ | |||||
| Transform raw strings into a list of lists of strings: | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| [word_21, word_22, ...], | |||||
| ... | |||||
| ] | |||||
| Here, each line of "people_infer.txt" is already a sentence, so load_lines() only needs to split each line into tokens. | |||||
| """ | |||||
| # Inference interface | |||||
| inference = Inference(pickle_path) | |||||
| results = inference.predict(model, infer_data) | |||||
| print(results) | |||||
| print("Inference finished!") | |||||
| def train_test(): | |||||
| # Config Loader | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
| # Data Loader | |||||
| loader = TokenizeDatasetLoader(data_name, cws_data_path) | |||||
| train_data = loader.load_pku() | |||||
| # Preprocessor | |||||
| p = POSPreprocess(train_data, pickle_path) | |||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| # Trainer | |||||
| trainer = POSTrainer(train_args) | |||||
| # Model | |||||
| model = SeqLabeling(train_args) | |||||
| # Start training | |||||
| trainer.train(model) | |||||
| print("Training finished!") | |||||
| # Saver | |||||
| saver = ModelSaver("./data_for_tests/saved_model.pkl") | |||||
| saver.save_pytorch(model) | |||||
| print("Model saved!") | |||||
| del model, trainer, loader | |||||
| # Define the same model | |||||
| model = SeqLabeling(train_args) | |||||
| # Load the trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| print("model loaded!") | |||||
| # Load test configuration | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| # Tester | |||||
| tester = POSTester(test_args) | |||||
| # Start testing | |||||
| tester.test(model) | |||||
| # Print test results | |||||
| print(tester.show_matrices()) | |||||
| print("model tested!") | |||||
| if __name__ == "__main__": | |||||
| train_test() | |||||
| infer() | |||||
| @@ -0,0 +1,14 @@ | |||||
| from fastNLP.fastnlp import FastNLP | |||||
| def foo(): | |||||
| nlp = FastNLP("./data_for_tests/") | |||||
| nlp.load("zh_pos_tag_model") | |||||
| text = "这是最好的基于深度学习的中文分词系统。" | |||||
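| # The test sentence means: "This is the best deep-learning-based Chinese word segmentation system." | |||||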
| result = nlp.run(text) | |||||
| print(result) | |||||
| print("FastNLP finished!") | |||||
| if __name__ == "__main__": | |||||
| foo() | |||||
| @@ -1,28 +0,0 @@ | |||||
| import aggregation | |||||
| import decoder | |||||
| import encoder | |||||
| class Input(object): | |||||
| def __init__(self): | |||||
| pass | |||||
| class Trainer(object): | |||||
| def __init__(self, input, target, truth): | |||||
| pass | |||||
| def train(self): | |||||
| pass | |||||
| def test_keras_like(): | |||||
| data_train, label_train = dataLoader("./data_path") | |||||
| x = Input() | |||||
| x = encoder.LSTM(input=x) | |||||
| x = aggregation.max_pool(input=x) | |||||
| y = decoder.CRF(input=x) | |||||
| trainer = Trainer(input=data_train, target=y, truth=label_train) | |||||
| trainer.train() | |||||
| @@ -0,0 +1,115 @@ | |||||
| import sys | |||||
| sys.path.append("..") | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.core.trainer import POSTrainer | |||||
| from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | |||||
| from fastNLP.loader.preprocess import POSPreprocess, load_pickle | |||||
| from fastNLP.saver.model_saver import ModelSaver | |||||
| from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.core.tester import POSTester | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| from fastNLP.core.inference import Inference | |||||
| data_name = "people.txt" | |||||
| data_path = "data_for_tests/people.txt" | |||||
| pickle_path = "data_for_tests" | |||||
| data_infer_path = "data_for_tests/people_infer.txt" | |||||
| def infer(): | |||||
| # Load the inference configuration, which is the same as the test configuration | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| # Fetch dictionary size and number of labels from pickle files | |||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
| test_args["vocab_size"] = len(word2index) | |||||
| index2label = load_pickle(pickle_path, "id2class.pkl") | |||||
| test_args["num_classes"] = len(index2label) | |||||
| # Define the same model | |||||
| model = SeqLabeling(test_args) | |||||
| # Load the trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| print("model loaded!") | |||||
| # Data Loader | |||||
| raw_data_loader = BaseLoader(data_name, data_infer_path) | |||||
| infer_data = raw_data_loader.load_lines() | |||||
| """ | |||||
| Transform raw strings into a list of lists of strings: | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| [word_21, word_22, ...], | |||||
| ... | |||||
| ] | |||||
| Here, each line of "people_infer.txt" is already a sentence, so load_lines() only needs to split each line into tokens. | |||||
| """ | |||||
| # Inference interface | |||||
| inference = Inference(pickle_path) | |||||
| results = inference.predict(model, infer_data) | |||||
| print(results) | |||||
| print("Inference finished!") | |||||
| def train_test(): | |||||
| # Config Loader | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
| # Data Loader | |||||
| pos_loader = POSDatasetLoader(data_name, data_path) | |||||
| train_data = pos_loader.load_lines() | |||||
| # Preprocessor | |||||
| p = POSPreprocess(train_data, pickle_path) | |||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| # Trainer | |||||
| trainer = POSTrainer(train_args) | |||||
| # Model | |||||
| model = SeqLabeling(train_args) | |||||
| # Start training | |||||
| trainer.train(model) | |||||
| print("Training finished!") | |||||
| # Saver | |||||
| saver = ModelSaver("./data_for_tests/saved_model.pkl") | |||||
| saver.save_pytorch(model) | |||||
| print("Model saved!") | |||||
| del model, trainer, pos_loader | |||||
| # Define the same model | |||||
| model = SeqLabeling(train_args) | |||||
| # Load the trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, "./data_for_tests/saved_model.pkl") | |||||
| print("model loaded!") | |||||
| # Load test configuration | |||||
| test_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||||
| # Tester | |||||
| tester = POSTester(test_args) | |||||
| # Start testing | |||||
| tester.test(model) | |||||
| # Print test results | |||||
| print(tester.show_matrices()) | |||||
| print("model tested!") | |||||
| if __name__ == "__main__": | |||||
| train_test() | |||||
| # infer() | |||||
| @@ -0,0 +1,37 @@ | |||||
| from fastNLP.core.tester import POSTester | |||||
| from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader | |||||
| from fastNLP.loader.preprocess import POSPreprocess | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| data_name = "pku_training.utf8" | |||||
| cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8" | |||||
| pickle_path = "data_for_tests" | |||||
| def foo(): | |||||
| loader = TokenizeDatasetLoader(data_name, "./data_for_tests/cws_pku_utf_8") | |||||
| train_data = loader.load_pku() | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
| # Preprocessor | |||||
| p = POSPreprocess(train_data, pickle_path) | |||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| model = SeqLabeling(train_args) | |||||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
| "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/", | |||||
| "use_cuda": True} | |||||
| validator = POSTester(valid_args) | |||||
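| # Note: valid_args sets "use_cuda" to True, which assumes a CUDA-capable GPU is available; set it to False to validate on CPU | |||||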
| print("start validation.") | |||||
| validator.test(model) | |||||
| print(validator.show_matrices()) | |||||
| if __name__ == "__main__": | |||||
| foo() | |||||
| @@ -1,12 +1,5 @@ | |||||
| def test_trainer(): | def test_trainer(): | ||||
| Config = namedtuple("config", ["epochs", "validate", "save_when_better"]) | |||||
| train_config = Config(epochs=5, validate=True, save_when_better=True) | |||||
| trainer = Trainer(train_config) | |||||
| net = ToyModel() | |||||
| data = np.random.rand(20, 6) | |||||
| dev_data = np.random.rand(20, 6) | |||||
| trainer.train(net, data, dev_data) | |||||
| pass | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||