diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py
index 65daafed..dc5640f1 100644
--- a/fastNLP/loader/dataset_loader.py
+++ b/fastNLP/loader/dataset_loader.py
@@ -11,8 +11,24 @@ class DatasetLoader(BaseLoader):
 
 
 class POSDatasetLoader(DatasetLoader):
-    """loader for pos data sets"""
-
+    """Dataset loader for POS tag datasets.
+
+    In these datasets, each line is split by '\t':
+    the first column is the word and the second
+    column is its label.
+    Sentences are separated by an empty line.
+    e.g.:
+    Tom label1
+    and label2
+    Jerry label1
+    . label3
+    Hello label4
+    world label5
+    ! label3
+    This file contains two sentences, "Tom and Jerry ."
+    and "Hello world !", and each word carries its own
+    label from label1 to label5.
+    """
     def __init__(self, data_name, data_path):
         super(POSDatasetLoader, self).__init__(data_name, data_path)
 
@@ -23,10 +39,44 @@ class POSDatasetLoader(DatasetLoader):
         return line
 
     def load_lines(self):
-        assert (os.path.exists(self.data_path))
+        """
+        :return data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        """
         with open(self.data_path, "r", encoding="utf-8") as f:
             lines = f.readlines()
-        return lines
+        return self.parse(lines)
+
+    @staticmethod
+    def parse(lines):
+        data = []
+        sentence = []
+        for line in lines:
+            line = line.strip()
+            if len(line) > 1:
+                sentence.append(line.split('\t'))
+            elif len(sentence) > 0:
+                # an empty line marks the end of a sentence
+                words = []
+                labels = []
+                for tokens in sentence:
+                    words.append(tokens[0])
+                    labels.append(tokens[1])
+                data.append([words, labels])
+                sentence = []
+        if len(sentence) != 0:
+            # flush the last sentence if the file has no trailing empty line
+            words = []
+            labels = []
+            for tokens in sentence:
+                words.append(tokens[0])
+                labels.append(tokens[1])
+            data.append([words, labels])
+        return data
 
 
 class ClassDatasetLoader(DatasetLoader):
@@ -112,3 +162,10 @@ class LMDatasetLoader(DatasetLoader):
         with open(self.data_path, "r", encoding="utf=8") as f:
             text = " ".join(f.readlines())
         return text.strip().split()
+
+
+if __name__ == "__main__":
+    data = POSDatasetLoader("xxx", "../../test/data_for_tests/people.txt").load_lines()
+    for example in data:
+        for w, l in zip(example[0], example[1]):
+            print(w, l)
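
For reference, a minimal sketch of how the reworked loader behaves, mirroring the docstring's two-sentence example (the `T`/`F` labels are the placeholders used by the tests further down):

```python
from fastNLP.loader.dataset_loader import POSDatasetLoader

# Two sentences in the tab-separated format described in the docstring,
# separated by an empty line.
lines = ["Tom\tT", "and\tF", "Jerry\tT", ".\tF", "",
         "Hello\tT", "world\tF", "!\tF"]

data = POSDatasetLoader.parse(lines)
# data == [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
#          [["Hello", "world", "!"], ["T", "F", "F"]]]
```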
diff --git a/fastNLP/loader/preprocess.py b/fastNLP/loader/preprocess.py
index ec70db88..7cd91f9c 100644
--- a/fastNLP/loader/preprocess.py
+++ b/fastNLP/loader/preprocess.py
@@ -28,33 +28,24 @@ class BasePreprocess(object):
 
 class POSPreprocess(BasePreprocess):
     """
-    This class are used to preprocess the pos datasets.
-    In these datasets, each line are divided by '\t'
-    while the first Col is the vocabulary and the second
-    Col is the label.
-    Different sentence are divided by an empty line.
-    e.g:
-    Tom label1
-    and label2
-    Jerry label1
-    . label3
-    Hello label4
-    world label5
-    ! label3
-    In this file, there are two sentence "Tom and Jerry ."
-    and "Hello world !". Each word has its own label from label1
-    to label5.
+    This class is used to preprocess POS datasets.
+    The dataset format is described in POSDatasetLoader.
     """
 
-    def __init__(self, data, pickle_path, train_dev_split=0):
+    def __init__(self, data, pickle_path="./", train_dev_split=0):
         """
         Preprocess pipeline, including building mapping from words to index,
         from index to words, from labels/classes to index, from index to labels/classes.
-        :param data:
-        :param pickle_path:
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        :param pickle_path: str, the directory where the pickle files are stored. Default: "./"
+        :param train_dev_split: float in [0, 1]. The ratio of dev data
+            split from the training data. Default: 0.
 
         To do:
-        1. use @contextmanager to handle pickle dumps and loads
+        1. simplify __init__
         """
         super(POSPreprocess, self).__init__(data, pickle_path)
@@ -75,6 +66,7 @@ class POSPreprocess(BasePreprocess):
         else:
             with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f:
                 _pickle.dump(self.label2index, f)
+        # FIXME: inconsistent state if word2id.pkl exists but class2id.pkl does not
 
         if not self.pickle_exist("id2word.pkl"):
             index2word = self.build_reverse_dict(self.word2index)
@@ -98,25 +90,23 @@ class POSPreprocess(BasePreprocess):
     def build_dict(self, data):
         """
         Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
-        :param data: list of list [word, label]
-        :return word2index: dict of (str, int)
-                label2index: dict of (str, int)
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        :return word2index: dict mapping str to int
+                label2index: dict mapping str to int
         """
         label2index = {}
         word2index = DEFAULT_WORD_TO_INDEX
-        for line in data:
-            line = line.strip()
-            if len(line) <= 1:
-                continue
-            tokens = line.split('\t')
-
-            if tokens[0] not in word2index:
-                # add (word, index) into the dict
-                word2index[tokens[0]] = len(word2index)
-
-            # for label in tokens[1: ]:
-            if tokens[1] not in label2index:
-                label2index[tokens[1]] = len(label2index)
+        for example in data:
+            for word, label in zip(example[0], example[1]):
+                if word not in word2index:
+                    word2index[word] = len(word2index)
+                if label not in label2index:
+                    label2index[label] = len(label2index)
         return word2index, label2index
 
     def pickle_exist(self, pickle_name):
@@ -139,24 +129,31 @@ class POSPreprocess(BasePreprocess):
     def to_index(self, data):
         """
         Convert word strings and label strings into indices.
-        :param data: list of str. Each string is a line, described above.
-        :return data_index: list of tuple (word index, label index)
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_11, label_12, ...] ],
+                [ [word_21, word_22, ...], [label_21, label_22, ...] ],
+                ...
+            ]
+        :return data_index: same shape as data, with each string replaced by its index
         """
-        data_train = []
-        sentence = []
-        for w in data:
-            w = w.strip()
-            if len(w) <= 1:
-                wid = []
-                lid = []
-                for i in range(len(sentence)):
-                    wid.append(self.word2index[sentence[i][0]])
-                    lid.append(self.label2index[sentence[i][1]])
-                data_train.append((wid, lid))
-                sentence = []
-                continue
-            sentence.append(w.split('\t'))
-        return data_train
+        data_index = []
+        for example in data:
+            word_list = []
+            label_list = []
+            for word, label in zip(example[0], example[1]):
+                word_list.append(self.word2index[word])
+                label_list.append(self.label2index[label])
+            data_index.append([word_list, label_list])
+        return data_index
+
+    @property
+    def vocab_size(self):
+        return len(self.word2index)
+
+    @property
+    def num_classes(self):
+        return len(self.label2index)
 
 
 class ClassPreprocess(BasePreprocess):
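
A hedged sketch of the reworked `POSPreprocess` interface: the constructor signature, `to_index`, and the new `vocab_size`/`num_classes` properties all appear in the diff above, but the pickle side effects of `__init__` are assumed and the pickle directory is illustrative:

```python
from fastNLP.loader.preprocess import POSPreprocess

data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
        [["Hello", "world", "!"], ["T", "F", "F"]]]

# __init__ builds word2index/label2index from the three-level list and,
# per the code above, dumps them as pickles under pickle_path
# (assumed to be an existing directory).
p = POSPreprocess(data, pickle_path="./data_for_tests/")
print(p.vocab_size, p.num_classes)  # sizes of word2index / label2index

indexed = p.to_index(data)  # same shape as data, strings replaced by indices
```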
diff --git a/test/test_loader.py b/test/test_loader.py
index 58e5dfe5..b18a2fcf 100644
--- a/test/test_loader.py
+++ b/test/test_loader.py
@@ -1,9 +1,23 @@
 import unittest
 
+from fastNLP.loader.dataset_loader import POSDatasetLoader
 
-class MyTestCase(unittest.TestCase):
-    def test_something(self):
-        self.assertEqual(True, False)
+
+class TestPreprocess(unittest.TestCase):
+    def test_case_1(self):
+        data = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]],
+                [["Hello", "world", "!"], ["T", "F", "F"]]]
+        pickle_path = "./data_for_tests/"
+        # POSPreprocess(data, pickle_path)
+
+
+class TestDatasetLoader(unittest.TestCase):
+    def test_case_1(self):
+        data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF"""
+        lines = data.split("\n")
+        answer = POSDatasetLoader.parse(lines)
+        truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]]
+        self.assertListEqual(answer, truth, "POS Dataset Loader")
 
 
 if __name__ == '__main__':
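
Putting the pieces together, a sketch of the intended end-to-end pipeline, assuming the `people.txt` sample referenced by the loader's `__main__` block exists at the path below and follows the tab-separated format:

```python
from fastNLP.loader.dataset_loader import POSDatasetLoader
from fastNLP.loader.preprocess import POSPreprocess

loader = POSDatasetLoader("pos", "test/data_for_tests/people.txt")
data = loader.load_lines()  # three-level [words, labels] list
p = POSPreprocess(data, pickle_path="test/data_for_tests/")
print(p.vocab_size, p.num_classes)
```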