|
|
@@ -41,14 +41,24 @@ class POSPreprocess(BasePreprocess): |
|
|
|
to label5.
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, data, pickle_path):
|
|
|
|
super(POSPreprocess, self).__init__(data, pickle_path)
|
|
|
|
self.word_dict = None
|
|
|
|
self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
|
|
|
|
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
|
|
|
|
DEFAULT_RESERVED_LABEL[2]: 4}
|
|
|
|
|
|
|
|
self.label_dict = None
|
|
|
|
self.data = data
|
|
|
|
self.pickle_path = pickle_path
|
|
|
|
self.build_dict()
|
|
|
|
self.word2id()
|
|
|
|
|
|
|
|
self.build_dict(data)
|
|
|
|
if not self.pickle_exist("word2id.pkl"):
|
|
|
|
self.word_dict.update(self.word2id(data))
|
|
|
|
file_name = os.path.join(self.pickle_path, "word2id.pkl")
|
|
|
|
with open(file_name, "wb") as f:
|
|
|
|
_pickle.dump(self.word_dict, f)
|
|
|
|
|
|
|
|
self.vocab_size = self.id2word()
|
|
|
|
self.class2id()
|
|
|
|
self.num_classes = self.id2class()
|
|
|
@@ -57,26 +67,26 @@ class POSPreprocess(BasePreprocess): |
|
|
|
self.data_dev()
|
|
|
|
self.data_test()
|
|
|
|
|
|
|
|
def build_dict(self, data):
    """
    Add new words with indices into self.word_dict, new labels with indices into self.label_dict.

    self.label_dict is reset to an empty dict on every call; self.word_dict
    is extended in place, so it must already be a dict when this is called.
    Each new key receives the current size of its dict as its index.

    :param data: iterable of strings, one "word<TAB>label" record per item.
        Items whose stripped length is 0 or 1 are skipped.
    :raises IndexError: if a non-skipped item contains no tab character.
    """
    self.label_dict = {}
    for line in data:
        line = line.strip()
        if len(line) <= 1:
            # Blank or single-character lines carry no word/label pair.
            continue
        tokens = line.split('\t')

        if tokens[0] not in self.word_dict:
            # add (word, index) into the dict
            self.word_dict[tokens[0]] = len(self.word_dict)

        # Only the first label column is used; any further tab-separated
        # fields on the line are ignored.
        if tokens[1] not in self.label_dict:
            self.label_dict[tokens[1]] = len(self.label_dict)
|
|
|
|
|
|
|
|
def pickle_exist(self, pickle_name):
|
|
|
|
"""
|
|
|
@@ -174,4 +184,9 @@ class POSPreprocess(BasePreprocess): |
|
|
|
pass
|
|
|
|
|
|
|
|
def data_test(self):
    """
    Placeholder: this method currently does nothing and returns None.

    NOTE(review): presumably meant to prepare the test split - confirm
    against the intended design before relying on it.
    """
    pass
|
|
|
|
|
|
|
|
|
|
|
|
class LMPreprocess(BasePreprocess):
|
|
|
|
def __init__(self, data, pickle_path):
|
|
|
|
super(LMPreprocess, self).__init__(data, pickle_path)
|