@@ -30,8 +30,18 @@ class DataSet(list):
         return self

     def index_field(self, field_name, vocab):
-        for ins in self:
-            ins.index_field(field_name, vocab)
+        if isinstance(field_name, str) and isinstance(vocab, Vocabulary):
+            field_list = [field_name]
+            vocab_list = [vocab]
+        else:
+            classes = (list, tuple)
+            assert isinstance(field_name, classes) and isinstance(vocab, classes) and len(field_name) == len(vocab)
+            field_list = field_name
+            vocab_list = vocab
+
+        for name, vocabs in zip(field_list, vocab_list):
+            for ins in self:
+                ins.index_field(name, vocabs)
         return self

     def to_tensor(self, idx: int, padding_length: dict):
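With this change, `DataSet.index_field` accepts either a single field name with one `Vocabulary`, or parallel lists of names and vocabularies that are indexed pairwise. A minimal usage sketch (the empty `DataSet` below stands in for one populated by a loader):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.vocabulary import Vocabulary

    dataset = DataSet()  # in practice, filled with Instances by a DataSetLoader
    vocab = Vocabulary()

    # Old behaviour: one field name, one vocabulary.
    dataset.index_field("premise", vocab)

    # New behaviour: parallel lists, indexed pairwise (lengths must match).
    dataset.index_field(["premise", "hypothesis"], [vocab, vocab])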
@@ -57,6 +57,20 @@ class SeqLabelEvaluator(Evaluator):
         return {"accuracy": float(accuracy)}


+class SNLIEvaluator(Evaluator):
+    def __init__(self):
+        super(SNLIEvaluator, self).__init__()
+
+    def __call__(self, predict, truth):
+        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
+        y_prob = torch.cat(y_prob, dim=0)
+        y_pred = torch.argmax(y_prob, dim=-1)
+        truth = [t['truth'] for t in truth]
+        y_true = torch.cat(truth, dim=0).view(-1)
+        acc = float(torch.sum(y_pred == y_true)) / y_true.size(0)
+        return {"accuracy": acc}
+

 def _conver_numpy(x):
     """convert input data to numpy array
@@ -83,6 +83,7 @@ class Tester(object):
             truth_list.append(batch_y)
         eval_results = self.evaluate(output_list, truth_list)
         print("[tester] {}".format(self.print_eval_results(eval_results)))
+        logger.info("[tester] {}".format(self.print_eval_results(eval_results)))

     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -131,3 +132,10 @@ class ClassificationTester(Tester):
         print(
             "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
         super(ClassificationTester, self).__init__(**test_args)
+
+
+class SNLITester(Tester):
+    def __init__(self, **test_args):
+        print(
+            "[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.")
+        super(SNLITester, self).__init__(**test_args)
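Both wrappers just forward `**test_args`, so new code should construct `Tester` directly. A sketch; the import path and accepted keys are assumptions based on this diff (the keys mirror the [snli_tester] config section further below):

    from fastNLP.core.tester import Tester

    test_args = {"batch_size": 512, "use_cuda": True}
    tester = Tester(**test_args)  # rather than SNLITester(**test_args)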
@@ -18,6 +18,7 @@ def isiterable(p_object):
         return False
     return True

+
 def check_build_vocab(func):
     def _wrapper(self, *args, **kwargs):
         if self.word2idx is None:
@@ -28,6 +29,7 @@ def check_build_vocab(func):
         return func(self, *args, **kwargs)

     return _wrapper

+
 class Vocabulary(object):
     """Use for word and index one-to-one mapping
@@ -52,7 +54,6 @@ class Vocabulary(object):
         self.word2idx = None
         self.idx2word = None
-
     def update(self, word):
         """Add a word or a list of words into the Vocabulary
@@ -70,7 +71,6 @@ class Vocabulary(object):
             self.word_count[word] += 1
             self.word2idx = None
-
     def build_vocab(self):
         """Build the 'word to index' dict, and filter words using `max_size` and `min_freq`
         """
@@ -163,3 +163,11 @@ class Vocabulary(object):
         """
         self.__dict__.update(state)
         self.idx2word = None
+
+    def __contains__(self, item):
+        """Check if a word is in the vocabulary.
+
+        :param item: the word
+        :return: True or False
+        """
+        return self.has_word(item)
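With `__contains__` delegating to `has_word`, membership tests now read naturally; a quick sketch (the words are arbitrary):

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    vocab.update(["premise", "hypothesis"])  # update accepts a word or a list of words

    print(vocab.has_word("premise"))  # True
    print("premise" in vocab)         # True, same check via __contains__
    print("missing" in vocab)         # False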
@@ -5,6 +5,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.field import *

+
 def convert_seq_dataset(data):
     """Create a DataSet instance that contains no labels.
@@ -23,6 +24,7 @@ def convert_seq_dataset(data):
         dataset.append(Instance(word_seq=x))
     return dataset

+
 def convert_seq2tag_dataset(data):
     """Convert a list of data into a DataSet
@@ -45,6 +47,7 @@ def convert_seq2tag_dataset(data):
         dataset.append(ins)
     return dataset

+
 def convert_seq2seq_dataset(data):
     """Convert a list of data into a DataSet
@@ -84,6 +87,7 @@ class DataSetLoader(BaseLoader):
         """
         raise NotImplementedError

+
 class RawDataSetLoader(DataSetLoader):
     def __init__(self):
         super(RawDataSetLoader, self).__init__()
@@ -98,6 +102,7 @@ class RawDataSetLoader(DataSetLoader):
     def convert(self, data):
         return convert_seq_dataset(data)

+
 class POSDataSetLoader(DataSetLoader):
     """Dataset Loader for POS Tag datasets.
@@ -166,6 +171,7 @@ class POSDataSetLoader(DataSetLoader):
         """
         return convert_seq2seq_dataset(data)

+
 class TokenizeDataSetLoader(DataSetLoader):
     """
     Data set loader for tokenization data sets
@@ -339,6 +345,7 @@ class LMDataSetLoader(DataSetLoader):
     def convert(self, data):
         pass

+
 class PeopleDailyCorpusLoader(DataSetLoader):
     """
     People's Daily Corpus: Chinese word segmentation, POS tagging, NER
@@ -390,3 +397,72 @@ class PeopleDailyCorpusLoader(DataSetLoader):
     def convert(self, data):
         pass
+
+
+class SNLIDataSetLoader(DataSetLoader):
+    """A data set loader for the SNLI data set.
+    """
+
+    def __init__(self):
+        super(SNLIDataSetLoader, self).__init__()
+
+    def load(self, path_list):
+        """
+        :param path_list: A list of file names, in the order of premise file, hypothesis file, and label file.
+        :return: data_set: A DataSet object.
+        """
+        assert len(path_list) == 3
+        line_set = []
+        for file in path_list:
+            if not os.path.exists(file):
+                raise FileNotFoundError("file {} NOT found".format(file))
+
+            with open(file, 'r', encoding='utf-8') as f:
+                lines = f.readlines()
+                line_set.append(lines)
+
+        premise_lines, hypothesis_lines, label_lines = line_set
+        assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines)
+
+        data_set = []
+        for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines):
+            p = premise.strip().split()
+            h = hypothesis.strip().split()
+            l = label.strip()
+            data_set.append([p, h, l])
+
+        return self.convert(data_set)
+
+    def convert(self, data):
+        """Convert a 3D list to a DataSet object.
+
+        :param data: A 3D list, of the form
+            [
+                [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ],
+                [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ],
+                ...
+            ]
+        :return: data_set: A DataSet object.
+        """
+        data_set = DataSet()
+        for example in data:
+            p, h, l = example
+            # list, list, str
+            x1 = TextField(p, is_target=False)
+            x2 = TextField(h, is_target=False)
+            x1_len = TextField([1] * len(p), is_target=False)
+            x2_len = TextField([1] * len(h), is_target=False)
+            y = LabelField(l, is_target=True)
+            instance = Instance()
+            instance.add_field("premise", x1)
+            instance.add_field("hypothesis", x2)
+            instance.add_field("premise_len", x1_len)
+            instance.add_field("hypothesis_len", x2_len)
+            instance.add_field("truth", y)
+            data_set.append(instance)
+
+        return data_set
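A usage sketch for the new loader; the file names are placeholders and the import path is assumed from this file's location in the package:

    from fastNLP.loader.dataset_loader import SNLIDataSetLoader

    loader = SNLIDataSetLoader()
    # Three aligned files: one whitespace-tokenized sentence (or one label) per line.
    data_set = loader.load(["premise.txt", "hypothesis.txt", "label.txt"])
    # Each Instance now carries "premise", "hypothesis", their length masks, and "truth".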
@@ -6,11 +6,12 @@ import torch
 from fastNLP.loader.base_loader import BaseLoader
 from fastNLP.core.vocabulary import Vocabulary

+
 class EmbedLoader(BaseLoader):
     """docstring for EmbedLoader"""

-    def __init__(self, data_path):
-        super(EmbedLoader, self).__init__(data_path)
+    def __init__(self):
+        super(EmbedLoader, self).__init__()

     @staticmethod
     def _load_glove(emb_file):
@@ -55,15 +56,15 @@ class EmbedLoader(BaseLoader):
         :param emb_type: str, the pre-trained embedding format, support glove now
         :param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
         :param emb_pkl: str, the embedding pickle file.
-        :return embedding_np: numpy array of shape (len(word_dict), emb_dim)
+        :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
             vocab: input vocab or vocab built by pre-train
         TODO: fragile code
         """
         # If the embedding pickle exists, load it and return.
         if os.path.exists(emb_pkl):
             with open(emb_pkl, "rb") as f:
-                embedding_np, vocab = _pickle.load(f)
-            return embedding_np, vocab
+                embedding_tensor, vocab = _pickle.load(f)
+            return embedding_tensor, vocab
         # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
@@ -71,14 +72,14 @@ class EmbedLoader(BaseLoader):
             vocab = Vocabulary()
             for w in pretrain.keys():
                 vocab.update(w)
-        embedding_np = torch.randn(len(vocab), emb_dim)
+        embedding_tensor = torch.randn(len(vocab), emb_dim)
         for w, v in pretrain.items():
             if len(v.shape) > 1 or emb_dim != v.shape[0]:
                 raise ValueError('pretrained embedding dim is {}, mismatching the required {}'.format(v.shape, (emb_dim,)))
             if vocab.has_word(w):
-                embedding_np[vocab[w]] = v
+                embedding_tensor[vocab[w]] = v
         # save and return the result
         with open(emb_pkl, "wb") as f:
-            _pickle.dump((embedding_np, vocab), f)
-        return embedding_np, vocab
+            _pickle.dump((embedding_tensor, vocab), f)
+        return embedding_tensor, vocab
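After the rename, `load_embedding` returns a torch tensor plus the vocabulary, caching both in `emb_pkl`. A hedged sketch; the keyword names follow the docstring above, but the exact signature is an assumption:

    from fastNLP.loader.embed_loader import EmbedLoader

    loader = EmbedLoader()
    embedding_tensor, vocab = loader.load_embedding(
        emb_dim=300,
        emb_file="glove.840B.300d.txt",  # placeholder path
        emb_type="glove",
        vocab=None,        # None: build the Vocabulary from the embedding file
        emb_pkl="embed.pkl")
    print(embedding_tensor.size())  # (len(vocab), 300)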
@@ -1,12 +1,15 @@
 import torch
 import torch.nn as nn

 from fastNLP.modules.utils import initial_parameter

+
 class MLP(nn.Module):
-    def __init__(self, size_layer, activation='relu' , initial_method = None):
+    def __init__(self, size_layer, activation='relu', initial_method=None):
         """Multilayer Perceptrons as a decoder

-        :param size_layer: list of int, define the size of MLP layers
-        :param activation: str or function, the activation function for hidden layers
+        :param size_layer: list of int, define the size of MLP layers.
+        :param activation: str or function, the activation function for hidden layers.
+        :param initial_method: str, the name of the initialization method.
+
+        .. note::
+            There is no activation function applied to the output layer.
@@ -23,7 +26,7 @@ class MLP(nn.Module):
         actives = {
             'relu': nn.ReLU(),
-            'tanh': nn.Tanh()
+            'tanh': nn.Tanh(),
         }
         if activation in actives:
             self.hidden_active = actives[activation]
@@ -31,7 +34,7 @@ class MLP(nn.Module):
             self.hidden_active = activation
         else:
             raise ValueError("should set activation correctly: {}".format(activation))
-        initial_parameter(self, initial_method )
+        initial_parameter(self, initial_method)

     def forward(self, x):
         for layer in self.hiddens:
@@ -40,13 +43,11 @@ class MLP(nn.Module):
         return x


 if __name__ == '__main__':
-    net1 = MLP([5,10,5])
-    net2 = MLP([5,10,5], 'tanh')
+    net1 = MLP([5, 10, 5])
+    net2 = MLP([5, 10, 5], 'tanh')
     for net in [net1, net2]:
         x = torch.randn(5, 5)
         y = net(x)
         print(x)
         print(y)
@@ -1,6 +1,8 @@
 import torch.nn as nn

 from fastNLP.modules.utils import initial_parameter

+
 class Linear(nn.Module):
     """
     Linear module
@@ -12,10 +14,11 @@ class Linear(nn.Module):
         bias : If set to False, the layer will not learn an additive bias. Default: True.
     """

-    def __init__(self, input_size, output_size, bias=True,initial_method = None ):
+    def __init__(self, input_size, output_size, bias=True, initial_method=None):
         super(Linear, self).__init__()
         self.linear = nn.Linear(input_size, output_size, bias)
         initial_parameter(self, initial_method)
+
     def forward(self, x):
         x = self.linear(x)
         return x
@@ -14,16 +14,23 @@ class LSTM(nn.Module):
         bidirectional : If True, becomes a bidirectional RNN. Default: False.
     """

-    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False,
-                 initial_method=None):
+    def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
+                 bidirectional=False, bias=True, initial_method=None, get_hidden=False):
         super(LSTM, self).__init__()
-        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
+        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                             dropout=dropout, bidirectional=bidirectional)
+        self.get_hidden = get_hidden
         initial_parameter(self, initial_method)

-    def forward(self, x):
-        x, _ = self.lstm(x)
-        return x
+    def forward(self, x, h0=None, c0=None):
+        if h0 is not None and c0 is not None:
+            x, (ht, ct) = self.lstm(x, (h0, c0))
+        else:
+            x, (ht, ct) = self.lstm(x)
+        if self.get_hidden:
+            return x, (ht, ct)
+        else:
+            return x


 if __name__ == "__main__":
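The reworked `LSTM` can now take an explicit initial state and, with `get_hidden=True`, return the final `(h, c)` pair along with the outputs. A short sketch with made-up shapes (import path assumed):

    import torch
    from fastNLP.modules.encoder.lstm import LSTM

    lstm = LSTM(input_size=300, hidden_size=300, get_hidden=True)
    x = torch.randn(8, 20, 300)   # batch_first=True: (batch, seq_len, input_size)

    output, (ht, ct) = lstm(x)    # ht, ct: (num_layers, batch, hidden_size)

    # Optionally seed the recurrence with an explicit initial state.
    h0 = torch.zeros(1, 8, 300)
    c0 = torch.zeros(1, 8, 300)
    output, (ht, ct) = lstm(x, h0, c0)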
@@ -45,3 +45,28 @@ use_cuda = true
 learn_rate = 1e-3
 momentum = 0.9
 model_name = "class_model.pkl"
+
+[snli_trainer]
+epochs = 5
+batch_size = 32
+validate = true
+save_best_dev = true
+use_cuda = true
+learn_rate = 1e-4
+loss = "cross_entropy"
+print_every_step = 1000
+
+[snli_tester]
+batch_size = 512
+use_cuda = true
+
+[snli_model]
+model_name = "snli_model.pkl"
+embed_dim = 300
+hidden_size = 300
+batch_first = true
+dropout = 0.5
+gpu = true
+embed_file = "./../data_for_tests/glove.840B.300d.txt"
+embed_pkl = "./snli/embed.pkl"
+examples = 0
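The new sections follow the same INI-style format as the existing ones, so they can be inspected with the standard library independently of fastNLP's own config loader (note that quoted string values keep their quotes under configparser):

    import configparser

    cfg = configparser.ConfigParser()
    cfg.read("config.cfg")  # placeholder path to the file containing these sections

    print(cfg.getint("snli_trainer", "epochs"))       # 5
    print(cfg.getboolean("snli_tester", "use_cuda"))  # True
    print(cfg.getint("snli_model", "embed_dim"))      # 300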