@@ -2,6 +2,7 @@ import _pickle
import numpy as np
import torch
import os

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -174,3 +175,153 @@ class POSTester(BaseTester):
        """
        loss, accuracy = self.matrices()
        return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)

class ClassTester(BaseTester):
    """Tester for classification."""

    def __init__(self, test_args):
        """
        :param test_args: a dict-like object that has a __getitem__ method and
            can be accessed as test_args["key_str"]
        """
        # super(ClassTester, self).__init__()
        self.pickle_path = test_args["pickle_path"]

        self.save_dev_data = None
        self.output = None
        self.mean_loss = None
        self.iterator = None

        if "test_name" in test_args:
            self.test_name = test_args["test_name"]
        else:
            self.test_name = "data_test.pkl"
        if "validate_in_training" in test_args:
            self.validate_in_training = test_args["validate_in_training"]
        else:
            self.validate_in_training = False
        if "save_output" in test_args:
            self.save_output = test_args["save_output"]
        else:
            self.save_output = False
        if "save_loss" in test_args:
            self.save_loss = test_args["save_loss"]
        else:
            self.save_loss = True
        if "batch_size" in test_args:
            self.batch_size = test_args["batch_size"]
        else:
            self.batch_size = 50
        if "use_cuda" in test_args:
            self.use_cuda = test_args["use_cuda"]
        else:
            self.use_cuda = True
        if "max_len" in test_args:
            self.max_len = test_args["max_len"]
        else:
            self.max_len = None

        self.model = None
        self.eval_history = []
        self.batch_output = []

    def test(self, network):
        # prepare model
        if torch.cuda.is_available() and self.use_cuda:
            self.model = network.cuda()
        else:
            self.model = network

        # no gradients needed during evaluation
        for param in self.model.parameters():
            param.requires_grad = False

        # turn on the testing mode; clean up the history
        self.mode(network, test=True)

        # prepare test data
        data_test = self.prepare_input(self.pickle_path, self.test_name)

        # data generator
        self.iterator = iter(Batchifier(
            RandomSampler(data_test), self.batch_size, drop_last=False))

        # test
        n_batches = len(data_test) // self.batch_size
        n_print = max(n_batches // 10, 1)  # avoid a zero modulus on small test sets
        step = 0
        for batch_x, batch_y in self.batchify(data_test, max_len=self.max_len):
            prediction = self.data_forward(network, batch_x)
            eval_results = self.evaluate(prediction, batch_y)

            if self.save_output:
                self.batch_output.append(prediction)
            if self.save_loss:
                self.eval_history.append(eval_results)

            if step % n_print == 0:
                print("step: {:>5}".format(step))
            step += 1

    def prepare_input(self, data_dir, file_name):
        """Load pickled test data from data_dir/file_name."""
        file_path = os.path.join(data_dir, file_name)
        with open(file_path, 'rb') as f:
            data = _pickle.load(f)
        return data

    def batchify(self, data, max_len=None):
        """Batch and pad data."""
        for indices in self.iterator:
            # generate batch and pad
            batch = [data[idx] for idx in indices]
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            batch_x = self.pad(batch_x)

            # convert to tensor
            batch_x = torch.tensor(batch_x, dtype=torch.long)
            batch_y = torch.tensor(batch_y, dtype=torch.long)
            if torch.cuda.is_available() and self.use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            # trim data to max_len
            if max_len is not None and batch_x.size(1) > max_len:
                batch_x = batch_x[:, :max_len]

            yield batch_x, batch_y

    def data_forward(self, network, x):
        """Forward through network."""
        logits = network(x)
        return logits

    def evaluate(self, y_logit, y_true):
        """Return class probabilities and gold labels for one batch."""
        y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
        return [y_prob, y_true]

    def matrices(self):
        """Compute accuracy; return (y_true, y_prob, accuracy)."""
        y_prob, y_true = zip(*self.eval_history)
        y_prob = torch.cat(y_prob, dim=0)
        y_pred = torch.argmax(y_prob, dim=-1)
        y_true = torch.cat(y_true, dim=0)
        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc

    def mode(self, model, test=True):
        """TODO: combine this function with Trainer?"""
        if test:
            model.eval()
        else:
            model.train()
        self.eval_history.clear()
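
A minimal usage sketch for the new tester, assuming the pickles under "./save/" were
produced by ClassPreprocess and that model is an already trained classifier such as
CNNText (the path and batch size below are illustrative, not fixed by the diff):

    test_args = {
        "pickle_path": "./save/",       # hypothetical directory containing data_test.pkl
        "test_name": "data_test.pkl",
        "batch_size": 64,
        "use_cuda": False,
        "save_output": False,
        "save_loss": True,
    }
    tester = ClassTester(test_args)
    tester.test(model)                  # model: a trained classifier
    y_true, y_prob, acc = tester.matrices()
    print("test accuracy: {:.2%}".format(acc))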

@@ -2,6 +2,10 @@ import _pickle
import numpy as np
import torch
import torch.nn as nn
import os
from time import time
from datetime import timedelta

from fastNLP.action.action import Action
from fastNLP.action.action import RandomSampler, Batchifier
@@ -348,6 +352,201 @@ class LanguageModelTrainer(BaseTrainer):
    pass


class ClassTrainer(BaseTrainer):
    """Trainer for classification."""

    def __init__(self, train_args):
        # super(ClassTrainer, self).__init__(train_args)
        self.n_epochs = train_args["epochs"]
        self.batch_size = train_args["batch_size"]
        self.pickle_path = train_args["pickle_path"]

        if "validate" in train_args:
            self.validate = train_args["validate"]
        else:
            self.validate = False
        if "learn_rate" in train_args:
            self.learn_rate = train_args["learn_rate"]
        else:
            self.learn_rate = 1e-3
        if "momentum" in train_args:
            self.momentum = train_args["momentum"]
        else:
            self.momentum = 0.9
        if "use_cuda" in train_args:
            self.use_cuda = train_args["use_cuda"]
        else:
            self.use_cuda = True

        self.model = None
        self.iterator = None
        self.loss_func = None
        self.optimizer = None

    def train(self, network):
        """General training steps.

        :param network: a model

        The method is framework independent; it works by calling the
        following methods:
            - prepare_input
            - mode
            - define_optimizer
            - data_forward
            - get_loss
            - grad_backward
            - update
        Subclasses must implement these methods with a specific framework.
        """
        # prepare model and data, transfer model to gpu if available
        if torch.cuda.is_available() and self.use_cuda:
            self.model = network.cuda()
        else:
            self.model = network
        data_train, data_dev, data_test, embedding = self.prepare_input(
            self.pickle_path)

        # define tester over dev data
        # valid_args = {
        #     "save_output": True, "validate_in_training": True,
        #     "save_dev_input": True, "save_loss": True,
        #     "batch_size": self.batch_size, "pickle_path": self.pickle_path}
        # validator = POSTester(valid_args)

        # turn on network training mode, define loss and optimizer
        self.define_loss()
        self.define_optimizer()
        self.mode(test=False)

        # main training epochs
        start = time()
        n_samples = len(data_train)
        n_batches = n_samples // self.batch_size
        n_print = max(n_batches // 10, 1)  # avoid a zero modulus on small data sets

        for epoch in range(self.n_epochs):
            # prepare batch iterator
            self.iterator = iter(Batchifier(
                RandomSampler(data_train), self.batch_size, drop_last=False))

            # training iterations in one epoch
            step = 0
            for batch_x, batch_y in self.batchify(data_train):
                prediction = self.data_forward(network, batch_x)
                loss = self.get_loss(prediction, batch_y)
                self.grad_backward(loss)
                self.update()

                if step % n_print == 0:
                    acc = self.get_acc(prediction, batch_y)
                    end = time()
                    diff = timedelta(seconds=round(end - start))
                    print("epoch: {:>3} step: {:>4} loss: {:>4.2}"
                          " train acc: {:>5.1%} time: {}".format(
                              epoch, step, loss, acc, diff))
                step += 1

            # if self.validate:
            #     if data_dev is None:
            #         raise RuntimeError("No validation data provided.")
            #     validator.test(network)
            #     print("[epoch {}]".format(epoch), end=" ")
            #     print(validator.show_matrices())

        # finish training

    def prepare_input(self, data_path):
        """Load the pickled train/dev/test sets and the embedding from data_path."""
        names = [
            "data_train.pkl", "data_dev.pkl",
            "data_test.pkl", "embedding.pkl"]
        files = []
        for name in names:
            file_path = os.path.join(data_path, name)
            if os.path.exists(file_path):
                with open(file_path, 'rb') as f:
                    data = _pickle.load(f)
            else:
                data = []
            files.append(data)
        return tuple(files)

    def mode(self, test=False):
        """
        Set the model to training or evaluation mode.
        :param test: bool, evaluation mode if True
        """
        if test:
            self.model.eval()
        else:
            self.model.train()

    def define_loss(self):
        """
        Assign an instance of loss function to self.loss_func,
        e.g. self.loss_func = nn.CrossEntropyLoss()
        """
        if self.loss_func is None:
            if hasattr(self.model, "loss"):
                self.loss_func = self.model.loss
            else:
                self.loss_func = nn.CrossEntropyLoss()

    def define_optimizer(self):
        """
        Define the optimizer used to update the model.
        """
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.learn_rate,
            momentum=self.momentum)

    def data_forward(self, network, x):
        """Forward through network."""
        logits = network(x)
        return logits

    def get_loss(self, predict, truth):
        """Calculate loss."""
        return self.loss_func(predict, truth)

    def grad_backward(self, loss):
        """Compute gradient backward."""
        self.model.zero_grad()
        loss.backward()

    def update(self):
        """Apply gradient."""
        self.optimizer.step()

    def batchify(self, data):
        """Batch and pad data."""
        for indices in self.iterator:
            batch = [data[idx] for idx in indices]
            batch_x = [sample[0] for sample in batch]
            batch_y = [sample[1] for sample in batch]
            batch_x = self.pad(batch_x)
            batch_x = torch.tensor(batch_x, dtype=torch.long)
            batch_y = torch.tensor(batch_y, dtype=torch.long)
            if torch.cuda.is_available() and self.use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
            yield batch_x, batch_y

    def get_acc(self, y_logit, y_true):
        """Compute accuracy."""
        y_pred = torch.argmax(y_logit, dim=-1)
        return int(torch.sum(y_true == y_pred)) / len(y_true)
| if __name__ == "__name__": | |||
| train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} | |||
| trainer = BaseTrainer(train_args) | |||
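
One way the trainer and the CNN model might be wired together, assuming data_train.pkl
and friends under "./save/" were written by ClassPreprocess and that the vocabulary and
class sizes below match them (all names and sizes are illustrative):

    train_args = {
        "epochs": 10,
        "batch_size": 32,
        "pickle_path": "./save/",       # hypothetical output directory of ClassPreprocess
        "validate": False,
        "learn_rate": 1e-3,
        "momentum": 0.9,
        "use_cuda": False,
    }
    model = CNNText(class_num=5, embed_num=10000, embed_dim=300)
    trainer = ClassTrainer(train_args)
    trainer.train(model)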

@@ -29,11 +29,11 @@ class POSDatasetLoader(DatasetLoader):
        return lines


-class ClassificationDatasetLoader(DatasetLoader):
-    """loader for classfication data sets"""
+class ClassDatasetLoader(DatasetLoader):
+    """Loader for classification data sets"""

    def __init__(self, data_name, data_path):
-        super(ClassificationDatasetLoader, data_name).__init__()
+        super(ClassDatasetLoader, self).__init__(data_name, data_path)

    def load(self):
        assert os.path.exists(self.data_path)

@@ -44,16 +44,21 @@ class ClassificationDatasetLoader(DatasetLoader):
    @staticmethod
    def parse(lines):
        """
-        :param lines: lines from dataset
-        :return: list(list(list())): the three level of lists are
+        Params
+            lines: lines from dataset
+        Return
+            list(list(list())): the three levels of lists are
            words, sentence, and dataset
        """
        dataset = list()
        for line in lines:
-            label = line.split(" ")[0]
-            words = line.split(" ")[1:]
-            word = list([w for w in words])
-            sentence = list([word, label])
+            line = line.strip().split()
+            label = line[0]
+            words = line[1:]
+            if len(words) <= 1:
+                continue
+            sentence = [words, label]
            dataset.append(sentence)
        return dataset
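
For reference, parse() now expects one sample per line in "label word1 word2 ..." form;
a short illustrative round trip with made-up lines:

    lines = ["positive this movie is great", "negative rather dull plot"]
    print(ClassDatasetLoader.parse(lines))
    # [[['this', 'movie', 'is', 'great'], 'positive'],
    #  [['rather', 'dull', 'plot'], 'negative']]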

@@ -187,6 +187,191 @@ class POSPreprocess(BasePreprocess):
    pass


class ClassPreprocess(BasePreprocess):
    """
    Pre-process the classification datasets.

    Params:
        pickle_path - directory to save result of pre-processing
    Saves:
        word2id.pkl
        id2word.pkl
        class2id.pkl
        id2class.pkl
        embedding.pkl
        data_train.pkl
        data_dev.pkl
        data_test.pkl
    """

    def __init__(self, pickle_path):
        # super(ClassPreprocess, self).__init__(data, pickle_path)
        self.word_dict = None
        self.label_dict = None
        self.pickle_path = pickle_path  # save directory

    def process(self, data, save_name):
        """
        Process data.

        Params:
            data - nested list, data = [sample1, sample2, ...],
                sample = [sentence, label], sentence = [word1, word2, ...]
            save_name - name of processed data, such as data_train.pkl
        Returns:
            vocab_size - vocabulary size
            n_classes - number of classes
        """
        self.build_dict(data)
        self.word2id()
        vocab_size = self.id2word()
        self.class2id()
        num_classes = self.id2class()
        self.embedding()
        self.data_generate(data, save_name)

        return vocab_size, num_classes

    def build_dict(self, data):
        """Build vocabulary."""
        # just read if word2id.pkl and class2id.pkl exist
        if self.pickle_exist("word2id.pkl") and \
                self.pickle_exist("class2id.pkl"):
            file_name = os.path.join(self.pickle_path, "word2id.pkl")
            with open(file_name, 'rb') as f:
                self.word_dict = _pickle.load(f)
            file_name = os.path.join(self.pickle_path, "class2id.pkl")
            with open(file_name, 'rb') as f:
                self.label_dict = _pickle.load(f)
            return

        # build vocabulary from scratch if nothing exists
        self.word_dict = {
            DEFAULT_PADDING_LABEL: 0,
            DEFAULT_UNKNOWN_LABEL: 1,
            DEFAULT_RESERVED_LABEL[0]: 2,
            DEFAULT_RESERVED_LABEL[1]: 3,
            DEFAULT_RESERVED_LABEL[2]: 4}
        self.label_dict = {}

        # collect every word and label
        for sent, label in data:
            if len(sent) <= 1:
                continue
            if label not in self.label_dict:
                index = len(self.label_dict)
                self.label_dict[label] = index
            for word in sent:
                if word not in self.word_dict:
                    index = len(self.word_dict)
                    self.word_dict[word] = index

    def pickle_exist(self, pickle_name):
        """
        Check whether a pickle file exists.

        Params
            pickle_name: the filename of target pickle file
        Return
            True if file exists else False
        """
        if not os.path.exists(self.pickle_path):
            os.makedirs(self.pickle_path)
        file_name = os.path.join(self.pickle_path, pickle_name)
        if os.path.exists(file_name):
            return True
        else:
            return False

    def word2id(self):
        """Save the {word: id} vocabulary mapping."""
        # nothing will be done if word2id.pkl exists
        if self.pickle_exist("word2id.pkl"):
            return

        file_name = os.path.join(self.pickle_path, "word2id.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(self.word_dict, f)

    def id2word(self):
        """Save the {id: word} vocabulary mapping and return the vocabulary size."""
        # if id2word.pkl exists, just load it and return its size
        if self.pickle_exist("id2word.pkl"):
            file_name = os.path.join(self.pickle_path, "id2word.pkl")
            with open(file_name, 'rb') as f:
                id2word_dict = _pickle.load(f)
            return len(id2word_dict)

        id2word_dict = {self.word_dict[w]: w for w in self.word_dict}
        file_name = os.path.join(self.pickle_path, "id2word.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(id2word_dict, f)
        return len(id2word_dict)

    def class2id(self):
        """Save the {class: id} mapping."""
        # nothing will be done if class2id.pkl exists
        if self.pickle_exist("class2id.pkl"):
            return

        file_name = os.path.join(self.pickle_path, "class2id.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(self.label_dict, f)

    def id2class(self):
        """Save the {id: class} mapping and return the number of classes."""
        # if id2class.pkl exists, just load it and return its size
        if self.pickle_exist("id2class.pkl"):
            file_name = os.path.join(self.pickle_path, "id2class.pkl")
            with open(file_name, "rb") as f:
                id2class_dict = _pickle.load(f)
            return len(id2class_dict)

        id2class_dict = {self.label_dict[c]: c for c in self.label_dict}
        file_name = os.path.join(self.pickle_path, "id2class.pkl")
        with open(file_name, "wb") as f:
            _pickle.dump(id2class_dict, f)
        return len(id2class_dict)

    def embedding(self):
        """Save the embedding lookup table corresponding to the vocabulary."""
        # nothing will be done if embedding.pkl exists
        if self.pickle_exist("embedding.pkl"):
            return

        # retrieve vocabulary from pre-trained embedding (not implemented yet)

    def data_generate(self, data_src, save_name):
        """Convert the dataset from words to ids and save it."""
        # nothing will be done if the file exists
        save_path = os.path.join(self.pickle_path, save_name)
        if os.path.exists(save_path):
            return

        data = []
        # for every sample
        for sent, label in data_src:
            if len(sent) <= 1:
                continue
            label_id = self.label_dict[label]  # label id
            sent_id = []  # sentence ids
            for word in sent:
                if word in self.word_dict:
                    sent_id.append(self.word_dict[word])
                else:
                    sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL])
            data.append([sent_id, label_id])

        # save data
        with open(save_path, "wb") as f:
            _pickle.dump(data, f)
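
Taken together with the loader, the intended preprocessing flow might look roughly like
this (file paths are illustrative, and load() is assumed to return the parsed nested
list of [words, label] samples):

    loader = ClassDatasetLoader("train", "./data/train.txt")   # hypothetical path
    data = loader.load()                                       # [[words, label], ...]
    preprocess = ClassPreprocess(pickle_path="./save/")
    vocab_size, n_classes = preprocess.process(data, "data_train.pkl")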


class LMPreprocess(BasePreprocess):
    def __init__(self, data, pickle_path):
        super(LMPreprocess, self).__init__(data, pickle_path)

@@ -0,0 +1,37 @@
# python: 3.6
# encoding: utf-8

import torch.nn as nn
# import torch.nn.functional as F

from fastNLP.models.base_model import BaseModel
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool


class CNNText(BaseModel):
    """
    Text classification model based on a CNN over word embeddings, an
    implementation of 'Yoon Kim. 2014. Convolutional Neural Networks for
    Sentence Classification.'
    """

    def __init__(self, class_num=9,
                 kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5],
                 embed_num=1000, embed_dim=300, pretrained_embed=None,
                 drop_prob=0.5):
        super(CNNText, self).__init__()

        # no support for pre-trained embedding currently
        self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
        self.conv_pool = ConvMaxpool(
            in_channels=embed_dim,
            out_channels=kernel_nums,
            kernel_sizes=kernel_sizes)
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(sum(kernel_nums), class_num)

    def forward(self, x):
        x = self.embed(x)      # [N, L] -> [N, L, C]
        x = self.conv_pool(x)  # [N, L, C] -> [N, C]
        x = self.dropout(x)
        x = self.fc(x)         # [N, C] -> [N, N_class]
        return x
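
A quick shape check, assuming a batch of 4 padded sentences of 20 token ids each and the
default kernels (the sizes are illustrative):

    import torch

    model = CNNText(class_num=5, embed_num=1000, embed_dim=300)
    x = torch.randint(low=1, high=1000, size=(4, 20), dtype=torch.long)  # [N, L]
    logits = model(x)
    print(logits.size())  # torch.Size([4, 5])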

@@ -0,0 +1,53 @@
# python: 3.6
# encoding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvMaxpool(nn.Module):
    """
    Convolution and max-pooling module with multiple kernel sizes.
    """

    def __init__(self, in_channels, out_channels, kernel_sizes,
                 stride=1, padding=0, dilation=1,
                 groups=1, bias=True, activation='relu'):
        super(ConvMaxpool, self).__init__()

        # convolution
        if isinstance(kernel_sizes, (list, tuple, int)):
            if isinstance(kernel_sizes, int):
                out_channels = [out_channels]
                kernel_sizes = [kernel_sizes]

            self.convs = nn.ModuleList([nn.Conv1d(
                in_channels=in_channels,
                out_channels=oc,
                kernel_size=ks,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias)
                for oc, ks in zip(out_channels, kernel_sizes)])
        else:
            raise Exception(
                'Incorrect kernel sizes: should be list, tuple or int')

        # activation function
        if activation == 'relu':
            self.activation = F.relu
        else:
            raise Exception(
                "Undefined activation function: choose from: relu")

    def forward(self, x):
        # [N, L, C] -> [N, C, L]
        x = torch.transpose(x, 1, 2)
        # convolution
        xs = [self.activation(conv(x)) for conv in self.convs]  # [[N, C, L]]
        # max-pooling
        xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
              for i in xs]  # [[N, C]]
        return torch.cat(xs, dim=-1)  # [N, C]
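
A quick shape check of the module on its own, assuming 300-dimensional inputs and the
three kernel sizes used by CNNText by default (numbers are illustrative):

    import torch

    conv_pool = ConvMaxpool(in_channels=300, out_channels=[100, 100, 100],
                            kernel_sizes=[3, 4, 5])
    x = torch.randn(4, 20, 300)   # [N, L, C_in]
    out = conv_pool(x)
    print(out.size())             # torch.Size([4, 300]), i.e. sum(out_channels)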