From 843b7c0e7ef3a425dd21b23169d317ecb5b6631a Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Fri, 20 Jul 2018 20:21:40 +0800
Subject: [PATCH] To Do (save commit)

---
 fastNLP/loader/dataset_loader.py | 12 +++++++
 fastNLP/loader/preprocess.py     | 55 ++++++++++++++++++++------------
 test/test_keras_like.py          | 28 ++++++++++++++++
 3 files changed, 75 insertions(+), 20 deletions(-)
 create mode 100644 test/test_keras_like.py

diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py
index d57a48db..19f12ab8 100644
--- a/fastNLP/loader/dataset_loader.py
+++ b/fastNLP/loader/dataset_loader.py
@@ -95,3 +95,15 @@ class ConllLoader(DatasetLoader):
                 continue
             tokens.append(line.split())
         return sentences
+
+
+class LMDatasetLoader(DatasetLoader):
+    def __init__(self, data_name, data_path):
+        super(LMDatasetLoader, self).__init__(data_name, data_path)
+
+    def load(self):
+        if not os.path.exists(self.data_path):
+            raise FileNotFoundError("file {} not found.".format(self.data_path))
+        with open(self.data_path, "r", encoding="utf-8") as f:
+            text = " ".join(f.readlines())
+        return text.strip().split()
diff --git a/fastNLP/loader/preprocess.py b/fastNLP/loader/preprocess.py
index 106fe90f..0ffd94fa 100644
--- a/fastNLP/loader/preprocess.py
+++ b/fastNLP/loader/preprocess.py
@@ -41,14 +41,24 @@ class POSPreprocess(BasePreprocess):
     to label5.
     """
 
+
     def __init__(self, data, pickle_path):
         super(POSPreprocess, self).__init__(data, pickle_path)
-        self.word_dict = None
+        self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
+                          DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
+                          DEFAULT_RESERVED_LABEL[2]: 4}
+
         self.label_dict = None
         self.data = data
         self.pickle_path = pickle_path
-        self.build_dict()
-        self.word2id()
+
+        self.build_dict(data)
+        if not self.pickle_exist("word2id.pkl"):
+            self.word_dict.update(self.word2id(data))
+            file_name = os.path.join(self.pickle_path, "word2id.pkl")
+            with open(file_name, "wb") as f:
+                _pickle.dump(self.word_dict, f)
+
         self.vocab_size = self.id2word()
         self.class2id()
         self.num_classes = self.id2class()
@@ -57,26 +67,26 @@ class POSPreprocess(BasePreprocess):
         self.data_dev()
         self.data_test()
 
-    def build_dict(self):
-        self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                          DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                          DEFAULT_RESERVED_LABEL[2]: 4}
+    def build_dict(self, data):
+        """
+        Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
+        :param data: list of list [word, label]
+        """
+        self.label_dict = {}
 
-        for w in self.data:
-            w = w.strip()
-            if len(w) <= 1:
+        for line in data:
+            line = line.strip()
+            if len(line) <= 1:
                 continue
-            word = w.split('\t')
+            tokens = line.split('\t')
 
-            if word[0] not in self.word_dict:
-                index = len(self.word_dict)
-                self.word_dict[word[0]] = index
+            if tokens[0] not in self.word_dict:
+                # add (word, index) into the dict
+                self.word_dict[tokens[0]] = len(self.word_dict)
 
-            # for label in word[1: ]:
-            label = word[1]
-            if label not in self.label_dict:
-                index = len(self.label_dict)
-                self.label_dict[label] = index
+            # for label in tokens[1: ]:
+            if tokens[1] not in self.label_dict:
+                self.label_dict[tokens[1]] = len(self.label_dict)
 
     def pickle_exist(self, pickle_name):
         """
@@ -174,4 +184,9 @@ class POSPreprocess(BasePreprocess):
         pass
 
     def data_test(self):
-        pass
\ No newline at end of file
+        pass
+
+
+class LMPreprocess(BasePreprocess):
+    def __init__(self, data, pickle_path):
+        super(LMPreprocess, self).__init__(data, pickle_path)
diff --git a/test/test_keras_like.py b/test/test_keras_like.py
new file mode 100644
index 00000000..08f7d6ae
--- /dev/null
+++ b/test/test_keras_like.py
@@ -0,0 +1,28 @@
+import aggregation
+import decoder
+import encoder
+
+
+class Input(object):
+    def __init__(self):
+        pass
+
+
+class Trainer(object):
+    def __init__(self, input, target, truth):
+        pass
+
+    def train(self):
+        pass
+
+
+def test_keras_like():
+    data_train, label_train = dataLoader("./data_path")
+
+    x = Input()
+    x = encoder.LSTM(input=x)
+    x = aggregation.max_pool(input=x)
+    y = decoder.CRF(input=x)
+
+    trainer = Trainer(input=data_train, target=y, truth=label_train)
+    trainer.train()
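
Note on the new POSPreprocess.build_dict(data): it consumes tab-separated "word<TAB>label" lines and gives each unseen word or label the next free index. Below is a minimal standalone sketch of that logic, not part of the patch; the sample lines are made up, and the placeholder strings stand in for DEFAULT_PADDING_LABEL, DEFAULT_UNKNOWN_LABEL and the DEFAULT_RESERVED_LABEL entries, whose real values are defined elsewhere in preprocess.py.

    # Hypothetical sample of the "word\tlabel" lines build_dict(data) expects.
    sample_lines = ["The\tDET", "dog\tNN", "runs\tVB", "dog\tNN", ""]

    # Placeholder strings assumed for the special tokens seeded in __init__.
    word_dict = {"<pad>": 0, "<unk>": 1, "<reserved-0>": 2, "<reserved-1>": 3, "<reserved-2>": 4}
    label_dict = {}

    for line in sample_lines:
        line = line.strip()
        if len(line) <= 1:                      # skip blank / too-short lines
            continue
        tokens = line.split('\t')
        if tokens[0] not in word_dict:          # new word -> next free index
            word_dict[tokens[0]] = len(word_dict)
        if tokens[1] not in label_dict:         # new label -> next free index
            label_dict[tokens[1]] = len(label_dict)

    # word_dict now also maps The->5, dog->6, runs->7 (the repeated "dog" is ignored);
    # label_dict maps DET->0, NN->1, VB->2.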