From 9c7f3cf26125234a785d9383839df1ee6779905f Mon Sep 17 00:00:00 2001 From: yunfan Date: Tue, 18 Sep 2018 16:43:56 +0800 Subject: [PATCH] add vocabulary into preprocessor --- fastNLP/core/preprocess.py | 80 ++++++++++++-------------------------- fastNLP/core/vocabulary.py | 29 +++++++++++++- test/core/test_field.py | 69 -------------------------------- test/core/test_vocab.py | 6 +-- 4 files changed, 53 insertions(+), 131 deletions(-) delete mode 100644 test/core/test_field.py diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py index b5d348e6..2671b4f4 100644 --- a/fastNLP/core/preprocess.py +++ b/fastNLP/core/preprocess.py @@ -6,16 +6,7 @@ import numpy as np from fastNLP.core.dataset import DataSet from fastNLP.core.field import TextField, LabelField from fastNLP.core.instance import Instance - -DEFAULT_PADDING_LABEL = '' # dict index = 0 -DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 -DEFAULT_RESERVED_LABEL = ['', - '', - ''] # dict index = 2~4 - -DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, - DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, - DEFAULT_RESERVED_LABEL[2]: 4} +from fastNLP.core.vocabulary import Vocabulary # the first vocab in dict with the index = 5 @@ -68,24 +59,22 @@ class BasePreprocess(object): - "word2id.pkl", a mapping from words(tokens) to indices - "id2word.pkl", a reversed dictionary - - "label2id.pkl", a dictionary on labels - - "id2label.pkl", a reversed dictionary on labels These four pickle files are expected to be saved in the given pickle directory once they are constructed. Preprocessors will check if those files are already in the directory and will reuse them in future calls. """ def __init__(self): - self.word2index = None - self.label2index = None + self.data_vocab = Vocabulary() + self.label_vocab = Vocabulary() @property def vocab_size(self): - return len(self.word2index) + return len(self.data_vocab) @property def num_classes(self): - return len(self.label2index) + return len(self.label_vocab) def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): """Main pre-processing pipeline. @@ -102,20 +91,14 @@ class BasePreprocess(object): """ if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): - self.word2index = load_pickle(pickle_path, "word2id.pkl") - self.label2index = load_pickle(pickle_path, "class2id.pkl") + self.data_vocab = load_pickle(pickle_path, "word2id.pkl") + self.label_vocab = load_pickle(pickle_path, "class2id.pkl") else: - self.word2index, self.label2index = self.build_dict(train_dev_data) - save_pickle(self.word2index, pickle_path, "word2id.pkl") - save_pickle(self.label2index, pickle_path, "class2id.pkl") - - if not pickle_exist(pickle_path, "id2word.pkl"): - index2word = self.build_reverse_dict(self.word2index) - save_pickle(index2word, pickle_path, "id2word.pkl") + self.data_vocab, self.label_vocab = self.build_dict(train_dev_data) + save_pickle(self.data_vocab, pickle_path, "word2id.pkl") + save_pickle(self.label_vocab, pickle_path, "class2id.pkl") - if not pickle_exist(pickle_path, "id2class.pkl"): - index2label = self.build_reverse_dict(self.label2index) - save_pickle(index2label, pickle_path, "id2class.pkl") + self.build_reverse_dict() train_set = [] dev_set = [] @@ -125,13 +108,13 @@ class BasePreprocess(object): split = int(len(train_dev_data) * train_dev_split) data_dev = train_dev_data[: split] data_train = train_dev_data[split:] - train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) - dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) + train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab) + dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab) save_pickle(dev_set, pickle_path, "data_dev.pkl") print("{} of the training data is split for validation. ".format(train_dev_split)) else: - train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) + train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab) save_pickle(train_set, pickle_path, "data_train.pkl") else: train_set = load_pickle(pickle_path, "data_train.pkl") @@ -143,8 +126,8 @@ class BasePreprocess(object): # cross validation data_cv = self.cv_split(train_dev_data, n_fold) for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): - data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) - data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) + data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab) + data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab) save_pickle( data_train_cv, pickle_path, "data_train_{}.pkl".format(i)) @@ -165,7 +148,7 @@ class BasePreprocess(object): test_set = [] if test_data is not None: if not pickle_exist(pickle_path, "data_test.pkl"): - test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) + test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab) save_pickle(test_set, pickle_path, "data_test.pkl") # return preprocessed results @@ -180,28 +163,15 @@ class BasePreprocess(object): return tuple(results) def build_dict(self, data): - label2index = DEFAULT_WORD_TO_INDEX.copy() - word2index = DEFAULT_WORD_TO_INDEX.copy() for example in data: - for word in example[0]: - if word not in word2index: - word2index[word] = len(word2index) - label = example[1] - if isinstance(label, str): - # label is a string - if label not in label2index: - label2index[label] = len(label2index) - elif isinstance(label, list): - # label is a list of strings - for single_label in label: - if single_label not in label2index: - label2index[single_label] = len(label2index) - return word2index, label2index - - - def build_reverse_dict(self, word_dict): - id2word = {word_dict[w]: w for w in word_dict} - return id2word + word, label = example + self.data_vocab.update(word) + self.label_vocab.update(label) + return self.data_vocab, self.label_vocab + + def build_reverse_dict(self): + self.data_vocab.build_reverse_vocab() + self.label_vocab.build_reverse_vocab() def data_split(self, data, train_dev_split): """Split data into train and dev set.""" diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index baae3753..79b70939 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -18,6 +18,16 @@ def isiterable(p_object): return True class Vocabulary(object): + """Use for word and index one to one mapping + + Example:: + + vocab = Vocabulary() + word_list = "this is a word list".split() + vocab.update(word_list) + vocab["word"] + vocab.to_word(5) + """ def __init__(self): self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) self.padding_label = DEFAULT_PADDING_LABEL @@ -29,6 +39,8 @@ class Vocabulary(object): def update(self, word): """add word or list of words into Vocabulary + + :param word: a list of str or str """ if not isinstance(word, str) and isiterable(word): # it's a nested list @@ -43,13 +55,22 @@ class Vocabulary(object): def __getitem__(self, w): - """ like to_index(w) function, turn a word to the index - if w is not in Vocabulary, return the unknown label + """To support usage like:: + + vocab[w] """ if w in self.word2idx: return self.word2idx[w] else: return self.word2idx[DEFAULT_UNKNOWN_LABEL] + + def to_index(self, w): + """ like to_index(w) function, turn a word to the index + if w is not in Vocabulary, return the unknown label + + :param str w: + """ + return self[w] def unknown_idx(self): return self.word2idx[self.unknown_label] @@ -58,10 +79,14 @@ class Vocabulary(object): return self.word2idx[self.padding_label] def build_reverse_vocab(self): + """build 'index to word' dict based on 'word to index' dict + """ self.idx2word = {self.word2idx[w] : w for w in self.word2idx} def to_word(self, idx): """given a word's index, return the word itself + + :param int idx: """ if self.idx2word is None: self.build_reverse_vocab() diff --git a/test/core/test_field.py b/test/core/test_field.py deleted file mode 100644 index 7c1b6343..00000000 --- a/test/core/test_field.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) - -import unittest -import torch -from fastNLP.data.field import TextField, LabelField -from fastNLP.data.instance import Instance -from fastNLP.data.dataset import DataSet -from fastNLP.data.batch import Batch - - - -class TestField(unittest.TestCase): - def check_batched_data_equal(self, data1, data2): - self.assertEqual(len(data1), len(data2)) - for i in range(len(data1)): - self.assertTrue(data1[i].keys(), data2[i].keys()) - for i in range(len(data1)): - for t1, t2 in zip(data1[i].values(), data2[i].values()): - self.assertTrue(torch.equal(t1, t2)) - - def test_batchiter(self): - texts = [ - "i am a cat", - "this is a test of new batch", - "haha" - ] - labels = [0, 1, 0] - - # prepare vocabulary - vocab = {} - for text in texts: - for tokens in text.split(): - if tokens not in vocab: - vocab[tokens] = len(vocab) - - # prepare input dataset - data = DataSet() - for text, label in zip(texts, labels): - x = TextField(text.split(), False) - y = LabelField(label, is_target=True) - ins = Instance(text=x, label=y) - data.append(ins) - - # use vocabulary to index data - data.index_field("text", vocab) - - # define naive sampler for batch class - class SeqSampler: - def __call__(self, dataset): - return list(range(len(dataset))) - - # use bacth to iterate dataset - batcher = Batch(data, SeqSampler(), 2) - TRUE_X = [{'text': torch.tensor([[0, 1, 2, 3, 0, 0, 0], [4, 5, 2, 6, 7, 8, 9]])}, {'text': torch.tensor([[10]])}] - TRUE_Y = [{'label': torch.tensor([[0], [1]])}, {'label': torch.tensor([[0]])}] - for epoch in range(3): - test_x, test_y = [], [] - for batch_x, batch_y in batcher: - test_x.append(batch_x) - test_y.append(batch_y) - self.check_batched_data_equal(TRUE_X, test_x) - self.check_batched_data_equal(TRUE_Y, test_y) - - -if __name__ == "__main__": - unittest.main() - \ No newline at end of file diff --git a/test/core/test_vocab.py b/test/core/test_vocab.py index dd51c197..89b0691a 100644 --- a/test/core/test_vocab.py +++ b/test/core/test_vocab.py @@ -1,9 +1,5 @@ -import os -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) - import unittest -from fastNLP.data.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX +from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX class TestVocabulary(unittest.TestCase): def test_vocab(self):