diff --git a/fastNLP/action/README.md b/fastNLP/action/README.md new file mode 100644 index 00000000..af0e39c3 --- /dev/null +++ b/fastNLP/action/README.md @@ -0,0 +1,8 @@ +SpaCy "Doc" +https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/doc.pyx#L80 + +SpaCy "Vocab" +https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/vocab.pyx#L25 + +SpaCy "Token" +https://github.com/explosion/spaCy/blob/75d2a05c2938f412f0fae44748374e4de19cc2be/spacy/tokens/token.pyx#L27 diff --git a/fastNLP/action/__init__.py b/fastNLP/action/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/action/action.py b/fastNLP/action/action.py new file mode 100644 index 00000000..c85a74df --- /dev/null +++ b/fastNLP/action/action.py @@ -0,0 +1,46 @@ +from saver.logger import Logger + + +class Action(object): + """ + base class for Trainer and Tester + """ + + def __init__(self): + super(Action, self).__init__() + self.logger = Logger("logger_output.txt") + + def load_config(self, args): + raise NotImplementedError + + def load_dataset(self, args): + raise NotImplementedError + + def log(self, string): + self.logger.log(string) + + def batchify(self, batch_size, X, Y=None): + """ + :param batch_size: int + :param X: feature matrix of size [n_sample, m_feature] + :param Y: label vector of size [n_sample, 1] (optional) + :return iteration:int, the number of step in each epoch + generator:generator, to generate batch inputs + """ + n_samples = X.shape[0] + num_iter = n_samples // batch_size + if Y is None: + generator = self._batch_generate(batch_size, num_iter, X) + else: + generator = self._batch_generate(batch_size, num_iter, X, Y) + return num_iter, generator + + @staticmethod + def _batch_generate(batch_size, num_iter, *data): + for step in range(num_iter): + start = batch_size * step + end = batch_size * (step + 1) + yield tuple([x[start:end] for x in data]) + + def make_log(self, *args): + return "log" diff --git a/fastNLP/action/tester.py b/fastNLP/action/tester.py new file mode 100644 index 00000000..0be1b010 --- /dev/null +++ b/fastNLP/action/tester.py @@ -0,0 +1,87 @@ +from collections import namedtuple + +import numpy as np + +from fastNLP.action import Action + + +class Tester(Action): + """docstring for Tester""" + + TestConfig = namedtuple("config", ["validate_in_training", "save_dev_input", "save_output", + "save_loss", "batch_size"]) + + def __init__(self, test_args): + """ + :param test_args: named tuple + """ + super(Tester, self).__init__() + self.validate_in_training = test_args.validate_in_training + self.save_dev_input = test_args.save_dev_input + self.valid_x = None + self.valid_y = None + self.save_output = test_args.save_output + self.output = None + self.save_loss = test_args.save_loss + self.mean_loss = None + self.batch_size = test_args.batch_size + + def test(self, network, data): + print("testing") + network.mode(test=True) # turn on the testing mode + if self.save_dev_input: + if self.valid_x is None: + valid_x, valid_y = network.prepare_input(data) + self.valid_x = valid_x + self.valid_y = valid_y + else: + valid_x = self.valid_x + valid_y = self.valid_y + else: + valid_x, valid_y = network.prepare_input(data) + + # split into batches by self.batch_size + iterations, test_batch_generator = self.batchify(self.batch_size, valid_x, valid_y) + + batch_output = list() + loss_history = list() + # turn on the testing mode of the network + network.mode(test=True) + + for step in range(iterations): + 
batch_x, batch_y = test_batch_generator.__next__() + + # forward pass from test input to predicted output + prediction = network.data_forward(batch_x) + + loss = network.get_loss(prediction, batch_y) + + if self.save_output: + batch_output.append(prediction.data) + if self.save_loss: + loss_history.append(loss) + self.log(self.make_log(step, loss)) + + if self.save_loss: + self.mean_loss = np.mean(np.array(loss_history)) + if self.save_output: + self.output = self.make_output(batch_output) + + @property + def loss(self): + return self.mean_loss + + @property + def result(self): + return self.output + + @staticmethod + def make_output(batch_outputs): + # construct full prediction with batch outputs + return np.concatenate(batch_outputs, axis=0) + + def load_config(self, args): + raise NotImplementedError + + def load_dataset(self, args): + raise NotImplementedError diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py new file mode 100644 index 00000000..b3640ba2 --- /dev/null +++ b/fastNLP/action/trainer.py @@ -0,0 +1,93 @@ +from collections import namedtuple + +from .action import Action +from .tester import Tester + + +class Trainer(Action): + """ + Trainer is a common training pipeline shared among all models. + """ + TrainConfig = namedtuple("config", ["epochs", "validate", "save_when_better", + "log_per_step", "log_validation", "batch_size"]) + + def __init__(self, train_args): + """ + :param train_args: namedtuple + """ + super(Trainer, self).__init__() + self.n_epochs = train_args.epochs + self.validate = train_args.validate + self.save_when_better = train_args.save_when_better + self.log_per_step = train_args.log_per_step + self.log_validation = train_args.log_validation + self.batch_size = train_args.batch_size + + def train(self, network, train_data, dev_data=None): + """ + :param network: the models controller + :param train_data: raw data for training + :param dev_data: raw data for validation + This method will call all the base methods of network (implemented in models.base_model). 
+ """ + train_x, train_y = network.prepare_input(train_data) + + iterations, train_batch_generator = self.batchify(self.batch_size, train_x, train_y) + + test_args = Tester.TestConfig(save_output=True, validate_in_training=True, + save_dev_input=True, save_loss=True, batch_size=self.batch_size) + evaluator = Tester(test_args) + + best_loss = 1e10 + loss_history = list() + + for epoch in range(self.n_epochs): + network.mode(test=False) # turn on the train mode + + network.define_optimizer() + for step in range(iterations): + batch_x, batch_y = train_batch_generator.__next__() + + prediction = network.data_forward(batch_x) + + loss = network.get_loss(prediction, batch_y) + network.grad_backward() + + if step % self.log_per_step == 0: + print("step ", step) + loss_history.append(loss) + self.log(self.make_log(epoch, step, loss)) + + #################### evaluate over dev set ################### + if self.validate: + if dev_data is None: + raise RuntimeError("No validation data provided.") + # give all controls to tester + evaluator.test(network, dev_data) + + if self.log_validation: + self.log(self.make_valid_log(epoch, evaluator.loss)) + if evaluator.loss < best_loss: + best_loss = evaluator.loss + if self.save_when_better: + self.save_model(network) + + # finish training + + def make_log(self, *args): + return "make a log" + + def make_valid_log(self, *args): + return "make a valid log" + + def save_model(self, model): + model.save() + + def load_data(self, data_name): + print("load data") + + def load_config(self, args): + raise NotImplementedError + + def load_dataset(self, args): + raise NotImplementedError diff --git a/fastNLP/loader/__init__.py b/fastNLP/loader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/loader/base_loader.py b/fastNLP/loader/base_loader.py new file mode 100644 index 00000000..2863f01f --- /dev/null +++ b/fastNLP/loader/base_loader.py @@ -0,0 +1,36 @@ +class BaseLoader(object): + """docstring for BaseLoader""" + + def __init__(self, data_name, data_path): + super(BaseLoader, self).__init__() + self.data_name = data_name + self.data_path = data_path + + def load(self): + """ + :return: string + """ + with open(self.data_path, "r", encoding="utf-8") as f: + text = f.read() + return text + + def load_lines(self): + with open(self.data_path, "r", encoding="utf=8") as f: + text = f.readlines() + return text + + +class ToyLoader0(BaseLoader): + """ + For charLM + """ + + def __init__(self, name, path): + super(ToyLoader0, self).__init__(name, path) + + def load(self): + with open(self.data_path, 'r') as f: + corpus = f.read().lower() + import re + corpus = re.sub(r"", "unk", corpus) + return corpus.split() diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py new file mode 100644 index 00000000..fa1d446d --- /dev/null +++ b/fastNLP/loader/config_loader.py @@ -0,0 +1,13 @@ +from loader.base_loader import BaseLoader + + +class ConfigLoader(BaseLoader): + """loader for configuration files""" + + def __int__(self, data_name, data_path): + super(ConfigLoader, self).__init__(data_name, data_path) + self.config = self.parse(super(ConfigLoader, self).load()) + + @staticmethod + def parse(string): + raise NotImplementedError diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py new file mode 100644 index 00000000..624032b0 --- /dev/null +++ b/fastNLP/loader/dataset_loader.py @@ -0,0 +1,47 @@ +from loader.base_loader import BaseLoader + + +class DatasetLoader(BaseLoader): + """"loader for data 
sets""" + + def __init__(self, data_name, data_path): + super(DatasetLoader, self).__init__(data_name, data_path) + + +class ConllLoader(DatasetLoader): + """loader for conll format files""" + + def __int__(self, data_name, data_path): + """ + :param str data_name: the name of the conll data set + :param str data_path: the path to the conll data set + """ + super(ConllLoader, self).__init__(data_name, data_path) + self.data_set = self.parse(self.load()) + + def load(self): + """ + :return: list lines: all lines in a conll file + """ + with open(self.data_path, "r", encoding="utf-8") as f: + lines = f.readlines() + return lines + + @staticmethod + def parse(lines): + """ + :param list lines:a list containing all lines in a conll file. + :return: a 3D list + """ + sentences = list() + tokens = list() + for line in lines: + if line[0] == "#": + # skip the comments + continue + if line == "\n": + sentences.append(tokens) + tokens = [] + continue + tokens.append(line.split()) + return sentences diff --git a/fastNLP/loader/embed_loader.py b/fastNLP/loader/embed_loader.py new file mode 100644 index 00000000..9610ca2d --- /dev/null +++ b/fastNLP/loader/embed_loader.py @@ -0,0 +1,8 @@ +from loader.base_loader import BaseLoader + + +class EmbedLoader(BaseLoader): + """docstring for EmbedLoader""" + + def __init__(self, data_name, data_path): + super(EmbedLoader, self).__init__(data_name, data_path) diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py new file mode 100644 index 00000000..f593c2b7 --- /dev/null +++ b/fastNLP/models/base_model.py @@ -0,0 +1,158 @@ +import numpy as np + + +class BaseModel(object): + """The base class of all models. + This class and its subclasses are actually "wrappers" of the PyTorch models. + They act as an interface between Trainer and the deep learning networks. + This interface provides the following methods to be called by Trainer. + - prepare_input + - mode + - define_optimizer + - data_forward + - grad_backward + - get_loss + """ + + def __init__(self): + pass + + def prepare_input(self, data): + """ + Perform data transformation from raw input to vector/matrix inputs. + :param data: raw inputs + :return (X, Y): tuple, input features and labels + """ + raise NotImplementedError + + def mode(self, test=False): + """ + Tell the network to be trained or not, required by PyTorch. + :param test: bool + """ + raise NotImplementedError + + def define_optimizer(self): + """ + Define PyTorch optimizer specified by the models. + """ + raise NotImplementedError + + def data_forward(self, *x): + """ + Forward pass of the data. + :param x: input feature matrix and label vector + :return: output by the models + """ + # required by PyTorch nn + raise NotImplementedError + + def grad_backward(self): + """ + Perform gradient descent to update the models parameters. + """ + raise NotImplementedError + + def get_loss(self, pred, truth): + """ + Compute loss given models prediction and ground truth. Loss function specified by the models. 
+ :param pred: prediction label vector + :param truth: ground truth label vector + :return: a scalar + """ + raise NotImplementedError + + +class ToyModel(BaseModel): + """This is for code testing.""" + + def __init__(self): + super(ToyModel, self).__init__() + self.test_mode = False + self.weight = np.random.rand(5, 1) + self.bias = np.random.rand() + self._loss = 0 + + def prepare_input(self, data): + return data[:, :-1], data[:, -1] + + def mode(self, test=False): + self.test_mode = test + + def data_forward(self, x): + return np.matmul(x, self.weight) + self.bias + + def grad_backward(self): + print("loss gradient backward") + + def get_loss(self, pred, truth): + self._loss = np.mean(np.square(pred - truth)) + return self._loss + + def define_optimizer(self): + pass + + +class Vocabulary(object): + """A look-up table that allows you to access `Lexeme` objects. The `Vocab` + instance also provides access to the `StringStore`, and owns underlying + data that is shared between `Doc` objects. + """ + + def __init__(self): + """Create the vocabulary. + RETURNS (Vocab): The newly constructed object. + """ + self.data_frame = None + + +class Document(object): + """A sequence of Token objects. Access sentences and named entities, export + annotations to numpy arrays, losslessly serialize to compressed binary + strings. The `Doc` object holds an array of `Token` objects. The + Python-level `Token` and `Span` objects are views of this array, i.e. + they don't own the data themselves. -- spacy + """ + + def __init__(self, vocab, words=None, spaces=None): + """Create a Doc object. + vocab (Vocab): A vocabulary object, which must match any models you + want to use (e.g. tokenizer, parser, entity recognizer). + words (list or None): A list of unicode strings, to add to the document + as words. If `None`, defaults to empty list. + spaces (list or None): A list of boolean values, of the same length as + words. True means that the word is followed by a space, False means + it is not. If `None`, defaults to `[True]*len(words)` + user_data (dict or None): Optional extra data to attach to the Doc. + RETURNS (Doc): The newly constructed object. + """ + self.vocab = vocab + self.spaces = spaces + self.words = words + if spaces is None: + self.spaces = [True] * len(self.words) + elif len(spaces) != len(self.words): + raise ValueError("dismatch spaces and words") + + def get_chunker(self, vocab): + return None + + def push_back(self, vocab): + pass + + +class Token(object): + """An individual token – i.e. a word, punctuation symbol, whitespace, + etc. + """ + + def __init__(self, vocab, doc, offset): + """Construct a `Token` object. + vocab (Vocabulary): A storage container for lexical types. + doc (Document): The parent document. + offset (int): The index of the token within the document. + """ + self.vocab = vocab + self.doc = doc + self.token = doc[offset] + self.i = offset diff --git a/fastNLP/models/char_language_model.py b/fastNLP/models/char_language_model.py new file mode 100644 index 00000000..9a6997b9 --- /dev/null +++ b/fastNLP/models/char_language_model.py @@ -0,0 +1,354 @@ +import os +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from model.base_model import BaseModel +from torch.autograd import Variable + +USE_GPU = True + + +class CharLM(BaseModel): + """ + Controller of the Character-level Neural Language Model + To do: + - where the data goes, call data savers. 
+ """ + DataTuple = namedtuple("DataTuple", ["feature", "label"]) + + def __init__(self, lstm_batch_size, lstm_seq_len): + super(CharLM, self).__init__() + """ + Settings: should come from config loader or pre-processing + """ + self.word_embed_dim = 300 + self.char_embedding_dim = 15 + self.cnn_batch_size = lstm_batch_size * lstm_seq_len + self.lstm_seq_len = lstm_seq_len + self.lstm_batch_size = lstm_batch_size + self.num_epoch = 10 + self.old_PPL = 100000 + self.best_PPL = 100000 + + """ + These parameters are set by pre-processing. + """ + self.max_word_len = None + self.num_char = None + self.vocab_size = None + self.preprocess("./data_for_tests/charlm.txt") + + self.data = None # named tuple to store all data set + self.data_ready = False + self.criterion = nn.CrossEntropyLoss() + self._loss = None + self.use_gpu = USE_GPU + + # word_emb_dim == hidden_size / num of hidden units + self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)), + to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim))) + + self.model = charLM(self.char_embedding_dim, + self.word_embed_dim, + self.vocab_size, + self.num_char, + use_gpu=self.use_gpu) + for param in self.model.parameters(): + nn.init.uniform(param.data, -0.05, 0.05) + + self.learning_rate = 0.1 + self.optimizer = None + + def prepare_input(self, raw_text): + """ + :param raw_text: raw input text consisting of words + :return: torch.Tensor, torch.Tensor + feature matrix, label vector + This function is only called once in Trainer.train, but may called multiple times in Tester.test + So Tester will save test input for frequent calls. + """ + if os.path.exists("cache/prep.pt") is False: + self.preprocess("./data_for_tests/charlm.txt") # To do: This is not good. Need to fix.. + objects = torch.load("cache/prep.pt") + word_dict = objects["word_dict"] + char_dict = objects["char_dict"] + max_word_len = self.max_word_len + print("word/char dictionary built. Start making inputs.") + + words = raw_text + input_vec = np.array(text2vec(words, char_dict, max_word_len)) + # Labels are next-word index in word_dict with the same length as inputs + input_label = np.array([word_dict[w] for w in words[1:]] + [word_dict[words[-1]]]) + feature_input = torch.from_numpy(input_vec) + label_input = torch.from_numpy(input_label) + return feature_input, label_input + + def mode(self, test=False): + if test: + self.model.eval() + else: + self.model.train() + + def data_forward(self, x): + """ + :param x: Tensor of size [lstm_batch_size, lstm_seq_len, max_word_len+2] + :return: Tensor of size [num_words, ?] 
+ """ + # additional processing of inputs after batching + num_seq = x.size()[0] // self.lstm_seq_len + x = x[:num_seq * self.lstm_seq_len, :] + x = x.view(-1, self.lstm_seq_len, self.max_word_len + 2) + + # detach hidden state of LSTM from last batch + hidden = [state.detach() for state in self.hidden] + output, self.hidden = self.model(to_var(x), hidden) + return output + + def grad_backward(self): + self.model.zero_grad() + self._loss.backward() + torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2) + self.optimizer.step() + + def get_loss(self, predict, truth): + self._loss = self.criterion(predict, to_var(truth)) + return self._loss.data # No pytorch data structure exposed outsides + + def define_optimizer(self): + # redefine optimizer for every new epoch + self.optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.85) + + def save(self): + print("network saved") + # torch.save(self.models, "cache/models.pkl") + + def preprocess(self, all_text_files): + word_dict, char_dict = create_word_char_dict(all_text_files) + num_char = len(char_dict) + self.vocab_size = len(word_dict) + char_dict["BOW"] = num_char + 1 + char_dict["EOW"] = num_char + 2 + char_dict["PAD"] = 0 + self.num_char = num_char + 3 + # char_dict is a dict of (int, string), int counting from 0 to 47 + reverse_word_dict = {value: key for key, value in word_dict.items()} + self.max_word_len = max([len(word) for word in word_dict]) + objects = { + "word_dict": word_dict, + "char_dict": char_dict, + "reverse_word_dict": reverse_word_dict, + } + torch.save(objects, "cache/prep.pt") + print("Preprocess done.") + + +""" + Global Functions +""" + + +def batch_generator(x, batch_size): + # x: [num_words, in_channel, height, width] + # partitions x into batches + num_step = x.size()[0] // batch_size + for t in range(num_step): + yield x[t * batch_size:(t + 1) * batch_size] + + +def text2vec(words, char_dict, max_word_len): + """ Return list of list of int """ + word_vec = [] + for word in words: + vec = [char_dict[ch] for ch in word] + if len(vec) < max_word_len: + vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))] + vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]] + word_vec.append(vec) + return word_vec + + +def read_data(file_name): + with open(file_name, 'r') as f: + corpus = f.read().lower() + import re + corpus = re.sub(r"", "unk", corpus) + return corpus.split() + + +def get_char_dict(vocabulary): + char_dict = dict() + count = 1 + for word in vocabulary: + for ch in word: + if ch not in char_dict: + char_dict[ch] = count + count += 1 + return char_dict + + +def create_word_char_dict(*file_name): + text = [] + for file in file_name: + text += read_data(file) + word_dict = {word: ix for ix, word in enumerate(set(text))} + char_dict = get_char_dict(word_dict) + return word_dict, char_dict + + +def to_var(x): + if torch.cuda.is_available() and USE_GPU: + x = x.cuda() + return Variable(x) + + +""" + Neural Network +""" + + +class Highway(nn.Module): + """Highway network""" + + def __init__(self, input_size): + super(Highway, self).__init__() + self.fc1 = nn.Linear(input_size, input_size, bias=True) + self.fc2 = nn.Linear(input_size, input_size, bias=True) + + def forward(self, x): + t = F.sigmoid(self.fc1(x)) + return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x) + + +class charLM(nn.Module): + """Character-level Neural Language Model + CNN + highway network + LSTM + # Input: + 4D tensor with shape [batch_size, in_channel, height, width] + # Output: + 2D 
Tensor with shape [batch_size, vocab_size] + # Arguments: + char_emb_dim: the size of each character's attention + word_emb_dim: the size of each word's attention + vocab_size: num of unique words + num_char: num of characters + use_gpu: True or False + """ + + def __init__(self, char_emb_dim, word_emb_dim, + vocab_size, num_char, use_gpu): + super(charLM, self).__init__() + self.char_emb_dim = char_emb_dim + self.word_emb_dim = word_emb_dim + self.vocab_size = vocab_size + + # char attention layer + self.char_embed = nn.Embedding(num_char, char_emb_dim) + + # convolutions of filters with different sizes + self.convolutions = [] + + # list of tuples: (the number of filter, width) + # self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)] + self.filter_num_width = [(25, 1), (50, 2), (75, 3)] + + for out_channel, filter_width in self.filter_num_width: + self.convolutions.append( + nn.Conv2d( + 1, # in_channel + out_channel, # out_channel + kernel_size=(char_emb_dim, filter_width), # (height, width) + bias=True + ) + ) + + self.highway_input_dim = sum([x for x, y in self.filter_num_width]) + + self.batch_norm = nn.BatchNorm1d(self.highway_input_dim, affine=False) + + # highway net + self.highway1 = Highway(self.highway_input_dim) + self.highway2 = Highway(self.highway_input_dim) + + # LSTM + self.lstm_num_layers = 2 + + self.lstm = nn.LSTM(input_size=self.highway_input_dim, + hidden_size=self.word_emb_dim, + num_layers=self.lstm_num_layers, + bias=True, + dropout=0.5, + batch_first=True) + + # output layer + self.dropout = nn.Dropout(p=0.5) + self.linear = nn.Linear(self.word_emb_dim, self.vocab_size) + + if use_gpu is True: + for x in range(len(self.convolutions)): + self.convolutions[x] = self.convolutions[x].cuda() + self.highway1 = self.highway1.cuda() + self.highway2 = self.highway2.cuda() + self.lstm = self.lstm.cuda() + self.dropout = self.dropout.cuda() + self.char_embed = self.char_embed.cuda() + self.linear = self.linear.cuda() + self.batch_norm = self.batch_norm.cuda() + + def forward(self, x, hidden): + # Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2] + # Return: Variable of Tensor with shape [num_words, len(word_dict)] + lstm_batch_size = x.size()[0] + lstm_seq_len = x.size()[1] + + x = x.contiguous().view(-1, x.size()[2]) + # [num_seq*seq_len, max_word_len+2] + + x = self.char_embed(x) + # [num_seq*seq_len, max_word_len+2, char_emb_dim] + + x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3) + # [num_seq*seq_len, 1, char_emb_dim, max_word_len+2] + + x = self.conv_layers(x) + # [num_seq*seq_len, total_num_filters] + + x = self.batch_norm(x) + # [num_seq*seq_len, total_num_filters] + + x = self.highway1(x) + x = self.highway2(x) + # [num_seq*seq_len, total_num_filters] + + x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1) + # [num_seq, seq_len, total_num_filters] + + x, hidden = self.lstm(x, hidden) + # [seq_len, num_seq, hidden_size] + + x = self.dropout(x) + # [seq_len, num_seq, hidden_size] + + x = x.contiguous().view(lstm_batch_size * lstm_seq_len, -1) + # [num_seq*seq_len, hidden_size] + + x = self.linear(x) + # [num_seq*seq_len, vocab_size] + return x, hidden + + def conv_layers(self, x): + chosen_list = list() + for conv in self.convolutions: + feature_map = F.tanh(conv(x)) + # (batch_size, out_channel, 1, max_word_len-width+1) + chosen = torch.max(feature_map, 3)[0] + # (batch_size, out_channel, 1) + chosen = chosen.squeeze() + # (batch_size, out_channel) + chosen_list.append(chosen) + + # 
(batch_size, total_num_filers) + return torch.cat(chosen_list, 1) diff --git a/fastNLP/models/word_seg_model.py b/fastNLP/models/word_seg_model.py new file mode 100644 index 00000000..71f3b43e --- /dev/null +++ b/fastNLP/models/word_seg_model.py @@ -0,0 +1,134 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from model.base_model import BaseModel +from torch.autograd import Variable + +USE_GPU = True + + +def to_var(x): + if torch.cuda.is_available() and USE_GPU: + x = x.cuda() + return Variable(x) + + +class WordSegModel(BaseModel): + """ + Model controller for WordSeg + """ + + def __init__(self): + super(WordSegModel, self).__init__() + self.id2word = None + self.word2id = None + self.id2tag = None + self.tag2id = None + + self.lstm_batch_size = 8 + self.lstm_seq_len = 32 # Trainer batch_size == lstm_batch_size * lstm_seq_len + self.hidden_dim = 100 + self.lstm_num_layers = 2 + self.vocab_size = 100 + self.word_emb_dim = 100 + + self.model = WordSeg(self.hidden_dim, self.lstm_num_layers, self.vocab_size, self.word_emb_dim) + self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim)), + to_var(torch.zeros(2, self.lstm_batch_size, self.word_emb_dim))) + + self.optimizer = None + self._loss = None + + def prepare_input(self, data): + """ + perform word indices lookup to convert strings into indices + :param data: list of string, each string contains word + space + [B, M, E, S] + :return + """ + word_list = [] + tag_list = [] + for line in data: + if len(line) > 2: + tokens = line.split("#") + word_list.append(tokens[0]) + tag_list.append(tokens[2][0]) + self.id2word = list(set(word_list)) + self.word2id = {word: idx for idx, word in enumerate(self.id2word)} + self.id2tag = list(set(tag_list)) + self.tag2id = {tag: idx for idx, tag in enumerate(self.id2tag)} + words = np.array([self.word2id[w] for w in word_list]).reshape(-1, 1) + tags = np.array([self.tag2id[t] for t in tag_list]).reshape(-1, 1) + return words, tags + + def mode(self, test=False): + if test: + self.model.eval() + else: + self.model.train() + + def data_forward(self, x): + """ + :param x: sequence of length [batch_size], word indices + :return: + """ + x = x.reshape(self.lstm_batch_size, self.lstm_seq_len) + output, self.hidden = self.model(x, self.hidden) + return output + + def define_optimizer(self): + self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.85) + + def get_loss(self, pred, truth): + + self._loss = nn.CrossEntropyLoss(pred, truth) + return self._loss + + def grad_backward(self): + self.model.zero_grad() + self._loss.backward() + torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2) + self.optimizer.step() + + +class WordSeg(nn.Module): + """ + PyTorch Network for word segmentation + """ + + def __init__(self, hidden_dim, lstm_num_layers, vocab_size, word_emb_dim=100): + super(WordSeg, self).__init__() + + self.vocab_size = vocab_size + self.word_emb_dim = word_emb_dim + self.lstm_num_layers = lstm_num_layers + self.hidden_dim = hidden_dim + + self.word_emb = nn.Embedding(self.vocab_size, self.word_emb_dim) + + self.lstm = nn.LSTM(input_size=self.word_emb_dim, + hidden_size=self.word_emb_dim, + num_layers=self.lstm_num_layers, + bias=True, + dropout=0.5, + batch_first=True) + + self.linear = nn.Linear(self.word_emb_dim, self.vocab_size) + + def forward(self, x, hidden): + """ + :param x: tensor of shape [batch_size, seq_len], vocabulary index + :param hidden: + :return x: probability of vocabulary entries + 
hidden: (hidden state, memory cell) from the LSTM
+        """
+        # [batch_size, seq_len]
+        x = self.word_emb(x)
+        # [batch_size, seq_len, word_emb_size]
+        x, hidden = self.lstm(x, hidden)
+        # [batch_size, seq_len, word_emb_size]
+        x = x.contiguous().view(x.shape[0] * x.shape[1], -1)
+        # [batch_size*seq_len, word_emb_size]
+        x = self.linear(x)
+        # [batch_size*seq_len, vocab_size]
+        return x, hidden
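
Usage note (not part of the diff): the sketch below shows how a model implementing the BaseModel interface introduced here is driven through one pass of prepare_input -> batchify -> data_forward -> get_loss -> grad_backward, mirroring what Trainer.train does per epoch. It deliberately avoids Trainer and Tester, since Action imports Logger from saver.logger, which is not part of this change set. It assumes numpy is installed and that base_model.py is importable under the fastNLP.models package added above (otherwise copy ToyModel next to the script); the toy data shape, 100 rows of 5 features plus 1 target, is an arbitrary choice for illustration.

# Minimal, self-contained sketch of the BaseModel control flow (assumptions noted above).
import numpy as np

from fastNLP.models.base_model import ToyModel

data = np.random.rand(100, 6)          # 100 samples: 5 features + 1 target per row
model = ToyModel()
X, Y = model.prepare_input(data)       # split into features (100, 5) and labels (100,)

batch_size = 10
num_iter = X.shape[0] // batch_size    # same batching rule as Action.batchify
model.mode(test=False)                 # train mode
model.define_optimizer()
for step in range(num_iter):
    start, end = batch_size * step, batch_size * (step + 1)
    pred = model.data_forward(X[start:end])                      # (10, 1) predictions
    loss = model.get_loss(pred, Y[start:end].reshape(-1, 1))     # mean squared error
    model.grad_backward()                                        # ToyModel only logs here
    print("step", step, "mse", loss)

Trainer.train runs the same loop for each epoch and, when validation is enabled, hands control to Tester.test with the shared batch_size.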