|
|
@@ -1,25 +1,57 @@ |
|
|
|
import pickle
|
|
|
|
import _pickle
|
|
|
|
import os
|
|
|
|
|
|
|
|
from fastNLP.loader.base_preprocess import BasePreprocess
|
|
|
|
|
|
|
|
DEFAULT_PADDING_LABEL = '<pad>' #dict index = 0
|
|
|
|
DEFAULT_UNKNOWN_LABEL = '<unk>' #dict index = 1
|
|
|
|
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
|
|
|
|
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
|
|
|
|
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
|
|
|
|
'<reserved-3>',
|
|
|
|
'<reserved-4>'] #dict index = 2~4
|
|
|
|
#the first vocab in dict with the index = 5
|
|
|
|
'<reserved-4>'] # dict index = 2~4
|
|
|
|
|
|
|
|
|
|
|
|
# the first vocab in dict with the index = 5
|
|
|
|
|
|
|
|
|
|
|
|
class BasePreprocess(object):
|
|
|
|
|
|
|
|
def __init__(self, data, pickle_path):
|
|
|
|
super(BasePreprocess, self).__init__()
|
|
|
|
self.data = data
|
|
|
|
self.pickle_path = pickle_path
|
|
|
|
if not self.pickle_path.endswith('/'):
|
|
|
|
self.pickle_path = self.pickle_path + '/'
|
|
|
|
|
|
|
|
def word2id(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def id2word(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def class2id(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def id2class(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def embedding(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def data_train(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def data_dev(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
def data_test(self):
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class POSPreprocess(BasePreprocess):
|
|
|
|
|
|
|
|
"""
|
|
|
|
This class are used to preprocess the pos datasets.
|
|
|
|
In these datasets, each line are divided by '\t'
|
|
|
|
while the first Col is the vocabulary and the second
|
|
|
|
Col is the label.
|
|
|
|
In these datasets, each line is divided by '\t'
|
|
|
|
The first Col is the vocabulary.
|
|
|
|
The second Col is the labels.
|
|
|
|
Different sentence are divided by an empty line.
|
|
|
|
e.g:
|
|
|
|
Tom label1
|
|
|
@@ -36,7 +68,9 @@ class POSPreprocess(BasePreprocess): |
|
|
|
"""
|
|
|
|
|
|
|
|
def __init__(self, data, pickle_path):
|
|
|
|
super(POSPreprocess, self).__init(data, pickle_path)
|
|
|
|
super(POSPreprocess, self).__init__(data, pickle_path)
|
|
|
|
self.word_dict = None
|
|
|
|
self.label_dict = None
|
|
|
|
self.build_dict()
|
|
|
|
self.word2id()
|
|
|
|
self.id2word()
|
|
|
@@ -46,8 +80,6 @@ class POSPreprocess(BasePreprocess): |
|
|
|
self.data_train()
|
|
|
|
self.data_dev()
|
|
|
|
self.data_test()
|
|
|
|
#...
|
|
|
|
|
|
|
|
|
|
|
|
def build_dict(self):
|
|
|
|
self.word_dict = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
|
|
|
@@ -68,7 +100,6 @@ class POSPreprocess(BasePreprocess): |
|
|
|
index = len(self.label_dict)
|
|
|
|
self.label_dict[label] = index
|
|
|
|
|
|
|
|
|
|
|
|
def pickle_exist(self, pickle_name):
|
|
|
|
"""
|
|
|
|
:param pickle_name: the filename of target pickle file
|
|
|
@@ -82,7 +113,6 @@ class POSPreprocess(BasePreprocess): |
|
|
|
else:
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
def word2id(self):
|
|
|
|
if self.pickle_exist("word2id.pkl"):
|
|
|
|
return
|
|
|
@@ -92,11 +122,10 @@ class POSPreprocess(BasePreprocess): |
|
|
|
with open(file_name, "wb", encoding='utf-8') as f:
|
|
|
|
_pickle.dump(self.word_dict, f)
|
|
|
|
|
|
|
|
|
|
|
|
def id2word(self):
|
|
|
|
if self.pickle_exist("id2word.pkl"):
|
|
|
|
return
|
|
|
|
#nothing will be done if id2word.pkl exists
|
|
|
|
# nothing will be done if id2word.pkl exists
|
|
|
|
|
|
|
|
id2word_dict = {}
|
|
|
|
for word in self.word_dict:
|
|
|
@@ -105,7 +134,6 @@ class POSPreprocess(BasePreprocess): |
|
|
|
with open(file_name, "wb", encoding='utf-8') as f:
|
|
|
|
_pickle.dump(id2word_dict, f)
|
|
|
|
|
|
|
|
|
|
|
|
def class2id(self):
|
|
|
|
if self.pickle_exist("class2id.pkl"):
|
|
|
|
return
|
|
|
@@ -115,11 +143,10 @@ class POSPreprocess(BasePreprocess): |
|
|
|
with open(file_name, "wb", encoding='utf-8') as f:
|
|
|
|
_pickle.dump(self.label_dict, f)
|
|
|
|
|
|
|
|
|
|
|
|
def id2class(self):
|
|
|
|
if self.pickle_exist("id2class.pkl"):
|
|
|
|
return
|
|
|
|
#nothing will be done if id2class.pkl exists
|
|
|
|
# nothing will be done if id2class.pkl exists
|
|
|
|
|
|
|
|
id2class_dict = {}
|
|
|
|
for label in self.label_dict:
|
|
|
@@ -128,17 +155,15 @@ class POSPreprocess(BasePreprocess): |
|
|
|
with open(file_name, "wb", encoding='utf-8') as f:
|
|
|
|
_pickle.dump(id2class_dict, f)
|
|
|
|
|
|
|
|
|
|
|
|
def embedding(self):
|
|
|
|
if self.pickle_exist("embedding.pkl"):
|
|
|
|
return
|
|
|
|
#nothing will be done if embedding.pkl exists
|
|
|
|
|
|
|
|
# nothing will be done if embedding.pkl exists
|
|
|
|
|
|
|
|
def data_train(self):
|
|
|
|
if self.pickle_exist("data_train.pkl"):
|
|
|
|
return
|
|
|
|
#nothing will be done if data_train.pkl exists
|
|
|
|
# nothing will be done if data_train.pkl exists
|
|
|
|
|
|
|
|
data_train = []
|
|
|
|
sentence = []
|
|
|
|