- add character vocab in preprocessor
- add dataset loader for language model dataset
- other minor adjustments
- preserve only a little example data for language model
@@ -33,6 +33,10 @@ class Loss(object):
         """Given a name of a loss function, return it from PyTorch.

         :param loss_name: str, the name of a loss function
+            - cross_entropy: combines log softmax and nll loss in a single function.
+            - nll: negative log likelihood
         :return loss: a PyTorch loss
         """
         if loss_name == "cross_entropy":
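The docstring above notes that cross entropy combines log softmax and negative log likelihood. A minimal sketch, independent of fastNLP, illustrating that relationship with plain PyTorch (the tensor shapes and values are made up for the example):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(3, 5)        # [batch_size, num_classes]
    target = torch.tensor([1, 0, 4])  # gold class index per example

    # cross_entropy == log_softmax followed by nll_loss
    combined = F.cross_entropy(logits, target)
    two_step = F.nll_loss(F.log_softmax(logits, dim=1), target)
    assert torch.allclose(combined, two_step)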
@@ -66,14 +66,26 @@ class Preprocessor(object):
     Preprocessors will check if those files are already in the directory and will reuse them in future calls.
     """

-    def __init__(self, label_is_seq=False):
+    def __init__(self, label_is_seq=False, share_vocab=False, add_char_field=False):
         """
         :param label_is_seq: bool, whether label is a sequence. If True, label vocabulary will preserve
                 several special tokens for sequence processing.
+        :param share_vocab: bool, whether the word sequence and the label sequence share the same vocabulary.
+                Typically, this is only meaningful when label_is_seq is True. Default: False.
+        :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
         """
         self.data_vocab = Vocabulary()
-        self.label_vocab = Vocabulary(need_default=label_is_seq)
+        if label_is_seq is True:
+            if share_vocab is True:
+                self.label_vocab = self.data_vocab
+            else:
+                self.label_vocab = Vocabulary()
+        else:
+            self.label_vocab = Vocabulary(need_default=False)
+        self.character_vocab = Vocabulary(need_default=False)
+        self.add_char_field = add_char_field

     @property
     def vocab_size(self):
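A usage sketch of the new constructor arguments, assuming this version of fastNLP is importable; the assertion simply restates the vocabulary wiring implemented above:

    from fastNLP.core.preprocess import Preprocessor

    # share_vocab only takes effect when the label is itself a sequence, e.g. for
    # language modelling, where inputs and targets come from the same vocabulary.
    pre = Preprocessor(label_is_seq=True, share_vocab=True, add_char_field=True)
    assert pre.label_vocab is pre.data_vocab  # one shared Vocabulary object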
@@ -83,6 +95,12 @@ class Preprocessor(object):
     def num_classes(self):
         return len(self.label_vocab)

+    @property
+    def char_vocab_size(self):
+        if len(self.character_vocab) == 0:
+            self.build_char_dict()
+        return len(self.character_vocab)
+
     def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
         """Main pre-processing pipeline.
@@ -176,6 +194,16 @@ class Preprocessor(object):
             self.label_vocab.update(label)
         return self.data_vocab, self.label_vocab

+    def build_char_dict(self):
+        char_collection = set()
+        for word in self.data_vocab.word2idx:
+            if len(word) == 0:
+                continue
+            for ch in word:
+                if ch not in char_collection:
+                    char_collection.add(ch)
+        self.character_vocab.update(list(char_collection))
+
     def build_reverse_dict(self):
         self.data_vocab.build_reverse_vocab()
         self.label_vocab.build_reverse_vocab()
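A small standalone sketch of what build_char_dict computes: the set of characters appearing in every entry of the word vocabulary. Note that special tokens such as <pad> would contribute their literal characters too; the toy word2idx below is made up:

    word2idx = {"<pad>": 0, "<unk>": 1, "cat": 2, "dog": 3}

    char_collection = set()
    for word in word2idx:
        for ch in word:
            char_collection.add(ch)

    # ['<', '>', 'a', 'c', 'd', 'g', 'k', 'n', 'o', 'p', 't', 'u']
    print(sorted(char_collection))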
@@ -231,7 +231,7 @@ class Trainer(object):
     def data_forward(self, network, x):
         if self._task == "seq_label":
             y = network(x["word_seq"], x["word_seq_origin_len"])
-        elif self._task == "text_classify":
+        elif self._task == "text_classify" or self._task == "language_model":
             y = network(x["word_seq"])
         else:
             raise NotImplementedError("Unknown task type {}.".format(self._task))

@@ -239,7 +239,7 @@ class Trainer(object):
         if not self._graph_summaried:
             if self._task == "seq_label":
                 self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
-            elif self._task == "text_classify":
+            elif self._task == "text_classify" or self._task == "language_model":
                 self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
             self._graph_summaried = True
         return y
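For reference, the dispatch that both hunks extend boils down to the following standalone sketch (a hypothetical free function; the real code lives on Trainer and also logs the graph to the summary writer):

    def forward_by_task(network, x, task):
        if task == "seq_label":
            # sequence labelling models also need the original sequence lengths
            return network(x["word_seq"], x["word_seq_origin_len"])
        elif task in ("text_classify", "language_model"):
            # classification and language modelling only consume the word ids
            return network(x["word_seq"])
        raise NotImplementedError("Unknown task type {}.".format(task))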
@@ -10,13 +10,15 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
                          DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
                          DEFAULT_RESERVED_LABEL[2]: 4}


 def isiterable(p_object):
     try:
         it = iter(p_object)
     except TypeError:
         return False
     return True


 class Vocabulary(object):
     """Use for word and index one to one mapping

@@ -28,9 +30,11 @@ class Vocabulary(object):
         vocab["word"]
         vocab.to_word(5)
     """

     def __init__(self, need_default=True):
         """
-        :param bool need_default: set if the Vocabulary has default labels reserved.
+        :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.
         """
         if need_default:
             self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
@@ -53,17 +57,16 @@ class Vocabulary(object):
         :param word: a list of str or str
         """
         if not isinstance(word, str) and isiterable(word):
             # it's a nested list
             for w in word:
                 self.update(w)
         else:
             # it's a word to be added
             if word not in self.word2idx:
                 self.word2idx[word] = len(self)
             if self.idx2word is not None:
                 self.idx2word = None

     def __getitem__(self, w):
         """To support usage like::

@@ -81,12 +84,12 @@ class Vocabulary(object):
         :param str w:
         """
         return self[w]

     def unknown_idx(self):
         if self.unknown_label is None:
             return None
         return self.word2idx[self.unknown_label]

     def padding_idx(self):
         if self.padding_label is None:
             return None
@@ -95,8 +98,8 @@ class Vocabulary(object):
     def build_reverse_vocab(self):
         """build 'index to word' dict based on 'word to index' dict
         """
-        self.idx2word = {self.word2idx[w] : w for w in self.word2idx}
+        self.idx2word = {self.word2idx[w]: w for w in self.word2idx}

     def to_word(self, idx):
         """given a word's index, return the word itself

@@ -105,7 +108,7 @@ class Vocabulary(object):
         if self.idx2word is None:
             self.build_reverse_vocab()
         return self.idx2word[idx]

     def __getstate__(self):
         """use to prepare data for pickle
         """
@@ -113,12 +116,9 @@ class Vocabulary(object):
         # no need to pickle idx2word as it can be constructed from word2idx
         del state['idx2word']
         return state

     def __setstate__(self, state):
         """use to restore state from pickle
         """
         self.__dict__.update(state)
         self.idx2word = None
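A minimal sketch of the pickling behaviour handled by __getstate__/__setstate__ above, assuming the class lives at fastNLP.core.vocabulary (assumed import path; the words are arbitrary):

    import pickle

    from fastNLP.core.vocabulary import Vocabulary  # assumed module path

    vocab = Vocabulary()
    vocab.update(["apple", "banana"])

    # idx2word is dropped before pickling and rebuilt lazily after unpickling
    restored = pickle.loads(pickle.dumps(vocab))
    assert restored.to_word(restored["apple"]) == "apple"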
@@ -21,7 +21,7 @@ class BaseLoader(object):

 class ToyLoader0(BaseLoader):
     """
-        For charLM
+        For CharLM
     """

     def __init__(self, data_path):
@@ -208,6 +208,12 @@ class ConllLoader(DatasetLoader):

 class LMDatasetLoader(DatasetLoader):
+    """Language Model Dataset Loader.
+
+    This loader produces data for supervised language model training:
+    each example provides both an input sequence X and a target sequence Y.
+    """
+
     def __init__(self, data_path):
         super(LMDatasetLoader, self).__init__(data_path)
@@ -216,8 +222,20 @@ class LMDatasetLoader(DatasetLoader):
             raise FileNotFoundError("file {} not found.".format(self.data_path))
         with open(self.data_path, "r", encoding="utf-8") as f:
             text = " ".join(f.readlines())
-        return text.strip().split()
+        tokens = text.strip().split()
+        return self.sentence_cut(tokens)
+
+    def sentence_cut(self, tokens, sentence_length=15):
+        start_idx = 0
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            start = start_idx + idx * sentence_length
+            x = tokens[start: start + sentence_length]
+            y = tokens[start + 1: start + sentence_length + 1]
+            if start + sentence_length + 1 >= len(tokens):
+                # ad hoc: pad the last target so that x and y have the same length
+                y.extend(["<unk>"])
+            data_set.append([x, y])
+        return data_set

 class PeopleDailyCorpusLoader(DatasetLoader):
     """
@@ -1,215 +1,8 @@
-import os
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import torch.optim as optim
-from torch.autograd import Variable
-from fastNLP.models.base_model import BaseModel
-USE_GPU = True
-"""
-To be deprecated.
-"""
-class CharLM(BaseModel):
-    """
-    Controller of the Character-level Neural Language Model
-    """
-    def __init__(self, lstm_batch_size, lstm_seq_len):
-        super(CharLM, self).__init__()
-        """
-        Settings: should come from config loader or pre-processing
-        """
-        self.word_embed_dim = 300
-        self.char_embedding_dim = 15
-        self.cnn_batch_size = lstm_batch_size * lstm_seq_len
-        self.lstm_seq_len = lstm_seq_len
-        self.lstm_batch_size = lstm_batch_size
-        self.num_epoch = 10
-        self.old_PPL = 100000
-        self.best_PPL = 100000
-        """
-        These parameters are set by pre-processing.
-        """
-        self.max_word_len = None
-        self.num_char = None
-        self.vocab_size = None
-        self.preprocess("./data_for_tests/charlm.txt")
-        self.data = None  # named tuple to store all data set
-        self.data_ready = False
-        self.criterion = nn.CrossEntropyLoss()
-        self._loss = None
-        self.use_gpu = USE_GPU
-        # word_emb_dim == hidden_size / num of hidden units
-        self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)),
-                       to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)))
-        self.model = charLM(self.char_embedding_dim,
-                            self.word_embed_dim,
-                            self.vocab_size,
-                            self.num_char,
-                            use_gpu=self.use_gpu)
-        for param in self.model.parameters():
-            nn.init.uniform(param.data, -0.05, 0.05)
-        self.learning_rate = 0.1
-        self.optimizer = None
-    def prepare_input(self, raw_text):
-        """
-        :param raw_text: raw input text consisting of words
-        :return: torch.Tensor, torch.Tensor
-                 feature matrix, label vector
-        This function is only called once in Trainer.train, but may called multiple times in Tester.test
-        So Tester will save test input for frequent calls.
-        """
-        if os.path.exists("cache/prep.pt") is False:
-            self.preprocess("./data_for_tests/charlm.txt")  # To do: This is not good. Need to fix..
-        objects = torch.load("cache/prep.pt")
-        word_dict = objects["word_dict"]
-        char_dict = objects["char_dict"]
-        max_word_len = self.max_word_len
-        print("word/char dictionary built. Start making inputs.")
-        words = raw_text
-        input_vec = np.array(text2vec(words, char_dict, max_word_len))
-        # Labels are next-word index in word_dict with the same length as inputs
-        input_label = np.array([word_dict[w] for w in words[1:]] + [word_dict[words[-1]]])
-        feature_input = torch.from_numpy(input_vec)
-        label_input = torch.from_numpy(input_label)
-        return feature_input, label_input
-    def mode(self, test=False):
-        if test:
-            self.model.eval()
-        else:
-            self.model.train()
-    def data_forward(self, x):
-        """
-        :param x: Tensor of size [lstm_batch_size, lstm_seq_len, max_word_len+2]
-        :return: Tensor of size [num_words, ?]
-        """
-        # additional processing of inputs after batching
-        num_seq = x.size()[0] // self.lstm_seq_len
-        x = x[:num_seq * self.lstm_seq_len, :]
-        x = x.view(-1, self.lstm_seq_len, self.max_word_len + 2)
-        # detach hidden state of LSTM from last batch
-        hidden = [state.detach() for state in self.hidden]
-        output, self.hidden = self.model(to_var(x), hidden)
-        return output
-    def grad_backward(self):
-        self.model.zero_grad()
-        self._loss.backward()
-        torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2)
-        self.optimizer.step()
-    def get_loss(self, predict, truth):
-        self._loss = self.criterion(predict, to_var(truth))
-        return self._loss.data  # No pytorch data structure exposed outsides
-    def define_optimizer(self):
-        # redefine optimizer for every new epoch
-        self.optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.85)
-    def save(self):
-        print("network saved")
-        # torch.save(self.models, "cache/models.pkl")
-    def preprocess(self, all_text_files):
-        word_dict, char_dict = create_word_char_dict(all_text_files)
-        num_char = len(char_dict)
-        self.vocab_size = len(word_dict)
-        char_dict["BOW"] = num_char + 1
-        char_dict["EOW"] = num_char + 2
-        char_dict["PAD"] = 0
-        self.num_char = num_char + 3
-        # char_dict is a dict of (int, string), int counting from 0 to 47
-        reverse_word_dict = {value: key for key, value in word_dict.items()}
-        self.max_word_len = max([len(word) for word in word_dict])
-        objects = {
-            "word_dict": word_dict,
-            "char_dict": char_dict,
-            "reverse_word_dict": reverse_word_dict,
-        }
-        if not os.path.exists("cache"):
-            os.mkdir("cache")
-        torch.save(objects, "cache/prep.pt")
-        print("Preprocess done.")
-"""
-Global Functions
-"""
-def batch_generator(x, batch_size):
-    # x: [num_words, in_channel, height, width]
-    # partitions x into batches
-    num_step = x.size()[0] // batch_size
-    for t in range(num_step):
-        yield x[t * batch_size:(t + 1) * batch_size]
-def text2vec(words, char_dict, max_word_len):
-    """ Return list of list of int """
-    word_vec = []
-    for word in words:
-        vec = [char_dict[ch] for ch in word]
-        if len(vec) < max_word_len:
-            vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))]
-        vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]]
-        word_vec.append(vec)
-    return word_vec
-def read_data(file_name):
-    with open(file_name, 'r') as f:
-        corpus = f.read().lower()
-    import re
-    corpus = re.sub(r"<unk>", "unk", corpus)
-    return corpus.split()
-def get_char_dict(vocabulary):
-    char_dict = dict()
-    count = 1
-    for word in vocabulary:
-        for ch in word:
-            if ch not in char_dict:
-                char_dict[ch] = count
-                count += 1
-    return char_dict
-def create_word_char_dict(*file_name):
-    text = []
-    for file in file_name:
-        text += read_data(file)
-    word_dict = {word: ix for ix, word in enumerate(set(text))}
-    char_dict = get_char_dict(word_dict)
-    return word_dict, char_dict
-def to_var(x):
-    if torch.cuda.is_available() and USE_GPU:
-        x = x.cuda()
-    return Variable(x)
-"""
-Neural Network
-"""
+from fastNLP.modules.encoder.lstm import LSTM

 class Highway(nn.Module):
@@ -225,9 +18,8 @@ class Highway(nn.Module):
         return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x)

-class charLM(nn.Module):
-    """Character-level Neural Language Model
-    CNN + highway network + LSTM
+class CharLM(nn.Module):
+    """CNN + highway network + LSTM
     # Input:
         4D tensor with shape [batch_size, in_channel, height, width]
     # Output:

@@ -241,8 +33,8 @@ class charLM(nn.Module):
     """

     def __init__(self, char_emb_dim, word_emb_dim,
-                 vocab_size, num_char, use_gpu):
-        super(charLM, self).__init__()
+                 vocab_size, num_char):
+        super(CharLM, self).__init__()
         self.char_emb_dim = char_emb_dim
         self.word_emb_dim = word_emb_dim
         self.vocab_size = vocab_size

@@ -254,8 +46,7 @@ class charLM(nn.Module):
         self.convolutions = []

         # list of tuples: (the number of filter, width)
-        # self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]
-        self.filter_num_width = [(25, 1), (50, 2), (75, 3)]
+        self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]
         for out_channel, filter_width in self.filter_num_width:
             self.convolutions.append(
@@ -278,29 +69,13 @@ class charLM(nn.Module):
         # LSTM
         self.lstm_num_layers = 2

-        self.lstm = nn.LSTM(input_size=self.highway_input_dim,
-                            hidden_size=self.word_emb_dim,
-                            num_layers=self.lstm_num_layers,
-                            bias=True,
-                            dropout=0.5,
-                            batch_first=True)
+        self.lstm = LSTM(self.highway_input_dim, hidden_size=self.word_emb_dim, num_layers=self.lstm_num_layers,
+                         dropout=0.5)

         # output layer
         self.dropout = nn.Dropout(p=0.5)
         self.linear = nn.Linear(self.word_emb_dim, self.vocab_size)

-        if use_gpu is True:
-            for x in range(len(self.convolutions)):
-                self.convolutions[x] = self.convolutions[x].cuda()
-            self.highway1 = self.highway1.cuda()
-            self.highway2 = self.highway2.cuda()
-            self.lstm = self.lstm.cuda()
-            self.dropout = self.dropout.cuda()
-            self.char_embed = self.char_embed.cuda()
-            self.linear = self.linear.cuda()
-            self.batch_norm = self.batch_norm.cuda()

-    def forward(self, x, hidden):
+    def forward(self, x):
         # Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2]
         # Return: Variable of Tensor with shape [num_words, len(word_dict)]
         lstm_batch_size = x.size()[0]
@@ -313,7 +88,7 @@ class charLM(nn.Module):
         # [num_seq*seq_len, max_word_len+2, char_emb_dim]

         x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3)
-        # [num_seq*seq_len, 1, char_emb_dim, max_word_len+2]
+        # [num_seq*seq_len, 1, max_word_len+2, char_emb_dim]

         x = self.conv_layers(x)
         # [num_seq*seq_len, total_num_filters]

@@ -328,7 +103,7 @@ class charLM(nn.Module):
         x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1)
         # [num_seq, seq_len, total_num_filters]

-        x, hidden = self.lstm(x, hidden)
+        x = self.lstm(x)
         # [seq_len, num_seq, hidden_size]

         x = self.dropout(x)

@@ -339,7 +114,7 @@ class charLM(nn.Module):
         x = self.linear(x)
         # [num_seq*seq_len, vocab_size]

-        return x, hidden
+        return x

     def conv_layers(self, x):
         chosen_list = list()
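The highway layer used by this model computes y = t * relu(W2 x) + (1 - t) * x. A minimal self-contained sketch of such a layer, assuming fc1 produces the sigmoid transform gate t; it mirrors the expression in the Highway hunk above but is not copied from the fastNLP source:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class HighwaySketch(nn.Module):
        def __init__(self, input_size):
            super(HighwaySketch, self).__init__()
            self.fc1 = nn.Linear(input_size, input_size)  # transform gate
            self.fc2 = nn.Linear(input_size, input_size)  # nonlinear transform

        def forward(self, x):
            t = torch.sigmoid(self.fc1(x))
            # carry the input through where the gate is closed, transform it where open
            return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x)

    out = HighwaySketch(8)(torch.randn(4, 8))  # shape is preserved: [4, 8]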
@@ -31,7 +31,7 @@ class SeqLabeling(BaseModel):
         num_classes = args["num_classes"]

         self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
-        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
+        self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim)
         self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
         self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
         self.mask = None

@@ -97,7 +97,7 @@ class AdvSeqLabel(SeqLabeling):
         num_classes = args["num_classes"]

         self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
-        self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
+        self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
         self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
         self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
         self.relu = torch.nn.ReLU()
@@ -1,10 +1,10 @@
-from .embedding import Embedding
-from .linear import Linear
-from .lstm import Lstm
 from .conv import Conv
 from .conv_maxpool import ConvMaxpool
+from .embedding import Embedding
+from .linear import Linear
+from .lstm import LSTM

-__all__ = ["Lstm",
+__all__ = ["LSTM",
            "Embedding",
            "Linear",
            "Conv",
@@ -1,9 +1,10 @@
 import torch.nn as nn

 from fastNLP.modules.utils import initial_parameter

-class Lstm(nn.Module):
-    """
-    LSTM module
+class LSTM(nn.Module):
+    """Long Short Term Memory

     Args:
         input_size : input size

@@ -13,13 +14,17 @@ class Lstm(nn.Module):
         bidirectional : If True, becomes a bidirectional RNN. Default: False.
     """

-    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None):
-        super(Lstm, self).__init__()
+    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False,
+                 initial_method=None):
+        super(LSTM, self).__init__()
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
                             dropout=dropout, bidirectional=bidirectional)
         initial_parameter(self, initial_method)

     def forward(self, x):
         x, _ = self.lstm(x)
         return x


 if __name__ == "__main__":
-    lstm = Lstm(10)
+    lstm = LSTM(10)
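A usage sketch of the renamed encoder module, assuming a PyTorch version where plain tensors can be fed directly (the input values are arbitrary; only the output sequence is returned because forward discards the hidden state):

    import torch

    from fastNLP.modules.encoder.lstm import LSTM

    lstm = LSTM(input_size=10, hidden_size=20, bidirectional=True)
    x = torch.randn(4, 7, 10)  # [batch_size, seq_len, input_size]; the module is batch_first
    out = lstm(x)              # [4, 7, 40] -> hidden_size * 2 because bidirectional=True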
@@ -0,0 +1,25 @@
+from fastNLP.core.loss import Loss
+from fastNLP.core.preprocess import Preprocessor
+from fastNLP.core.trainer import Trainer
+from fastNLP.loader.dataset_loader import LMDatasetLoader
+from fastNLP.models.char_language_model import CharLM
+
+PICKLE = "./save/"
+
+
+def train():
+    loader = LMDatasetLoader("./train.txt")
+    train_data = loader.load()
+
+    pre = Preprocessor(label_is_seq=True, share_vocab=True)
+    train_set = pre.run(train_data, pickle_path=PICKLE)
+
+    model = CharLM(50, 50, pre.vocab_size, pre.char_vocab_size)
+
+    trainer = Trainer(task="language_model", loss=Loss("cross_entropy"))
+
+    trainer.train(model, train_set)
+
+
+if __name__ == "__main__":
+    train()
@@ -9,7 +9,7 @@ from fastNLP.models.base_model import BaseModel
 from fastNLP.modules.aggregator.self_attention import SelfAttention
 from fastNLP.modules.decoder.MLP import MLP
 from fastNLP.modules.encoder.embedding import Embedding as Embedding
-from fastNLP.modules.encoder.lstm import Lstm
+from fastNLP.modules.encoder.lstm import LSTM

 train_data_path = 'small_train_data.txt'
 dev_data_path = 'small_dev_data.txt'
@@ -43,7 +43,7 @@ class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
     def __init__(self, args=None):
         super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__()
         self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None )
-        self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True)
+        self.lstm = LSTM(input_size=embeding_size, hidden_size=lstm_hidden_size, bidirectional=True)
         self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops)
         self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ])

     def forward(self,x):