|
|
@@ -6,16 +6,7 @@ import numpy as np |
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.core.field import TextField, LabelField |
|
|
|
from fastNLP.core.instance import Instance |
|
|
|
|
|
|
|
# Special tokens that occupy the first five slots of every mapping.
DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
DEFAULT_RESERVED_LABEL = [
    '<reserved-2>',
    '<reserved-3>',
    '<reserved-4>',
]  # dict index = 2~4

# Each special token is mapped to its position in the order listed above,
# so padding is 0, unknown is 1 and the reserved tokens take 2 through 4.
DEFAULT_WORD_TO_INDEX = {
    token: position
    for position, token in enumerate(
        [DEFAULT_PADDING_LABEL, DEFAULT_UNKNOWN_LABEL] + DEFAULT_RESERVED_LABEL
    )
}
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
|
|
|
|
|
|
|
|
# The first non-reserved vocabulary entry is assigned index 5. |
|
|
@@ -68,24 +59,22 @@ class BasePreprocess(object): |
|
|
|
|
|
|
|
- "word2id.pkl", a mapping from words(tokens) to indices |
|
|
|
- "id2word.pkl", a reversed dictionary |
|
|
|
- "class2id.pkl", a dictionary on labels |
|
|
|
- "id2class.pkl", a reversed dictionary on labels |
|
|
|
|
|
|
|
These four pickle files are expected to be saved in the given pickle directory once they are constructed. |
|
|
|
Preprocessors will check if those files are already in the directory and will reuse them in future calls. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self):
    """Set up an empty preprocessor: unset index mappings, fresh vocabularies."""
    # The plain mappings start unset; run() assigns them.
    self.word2index = self.label2index = None
    # One Vocabulary for input tokens and a separate one for labels.
    self.data_vocab, self.label_vocab = Vocabulary(), Vocabulary()
|
|
|
|
|
|
|
@property |
|
|
|
def vocab_size(self): |
# Read-only property reporting the size of the word-level lookup.
# NOTE(review): two return statements follow back to back, so only the
# first (len of self.word2index) can ever execute. This looks like the
# pre- and post-refactor lines kept side by side -- confirm which one is
# intended before relying on either.
|
|
|
return len(self.word2index) |
|
|
|
return len(self.data_vocab) |
|
|
|
|
|
|
|
@property |
|
|
|
def num_classes(self): |
# Read-only property reporting the size of the label-level lookup.
# NOTE(review): two return statements follow back to back, so only the
# first (len of self.label2index) can ever execute. This looks like the
# pre- and post-refactor lines kept side by side -- confirm which one is
# intended before relying on either.
|
|
|
return len(self.label2index) |
|
|
|
return len(self.label_vocab) |
|
|
|
|
|
|
|
def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): |
|
|
|
"""Main pre-processing pipeline. |
|
|
@@ -102,20 +91,14 @@ class BasePreprocess(object): |
|
|
|
""" |
|
|
|
|
|
|
|
if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): |
|
|
|
self.word2index = load_pickle(pickle_path, "word2id.pkl") |
|
|
|
self.label2index = load_pickle(pickle_path, "class2id.pkl") |
|
|
|
self.data_vocab = load_pickle(pickle_path, "word2id.pkl") |
|
|
|
self.label_vocab = load_pickle(pickle_path, "class2id.pkl") |
|
|
|
else: |
|
|
|
self.word2index, self.label2index = self.build_dict(train_dev_data) |
|
|
|
save_pickle(self.word2index, pickle_path, "word2id.pkl") |
|
|
|
save_pickle(self.label2index, pickle_path, "class2id.pkl") |
|
|
|
|
|
|
|
if not pickle_exist(pickle_path, "id2word.pkl"): |
|
|
|
index2word = self.build_reverse_dict(self.word2index) |
|
|
|
save_pickle(index2word, pickle_path, "id2word.pkl") |
|
|
|
self.data_vocab, self.label_vocab = self.build_dict(train_dev_data) |
|
|
|
save_pickle(self.data_vocab, pickle_path, "word2id.pkl") |
|
|
|
save_pickle(self.label_vocab, pickle_path, "class2id.pkl") |
|
|
|
|
|
|
|
if not pickle_exist(pickle_path, "id2class.pkl"): |
|
|
|
index2label = self.build_reverse_dict(self.label2index) |
|
|
|
save_pickle(index2label, pickle_path, "id2class.pkl") |
|
|
|
self.build_reverse_dict() |
|
|
|
|
|
|
|
train_set = [] |
|
|
|
dev_set = [] |
|
|
@@ -125,13 +108,13 @@ class BasePreprocess(object): |
|
|
|
split = int(len(train_dev_data) * train_dev_split) |
|
|
|
data_dev = train_dev_data[: split] |
|
|
|
data_train = train_dev_data[split:] |
|
|
|
train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) |
|
|
|
dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) |
|
|
|
train_set = self.convert_to_dataset(data_train, self.data_vocab, self.label_vocab) |
|
|
|
dev_set = self.convert_to_dataset(data_dev, self.data_vocab, self.label_vocab) |
|
|
|
|
|
|
|
save_pickle(dev_set, pickle_path, "data_dev.pkl") |
|
|
|
print("{} of the training data is split for validation. ".format(train_dev_split)) |
|
|
|
else: |
|
|
|
train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) |
|
|
|
train_set = self.convert_to_dataset(train_dev_data, self.data_vocab, self.label_vocab) |
|
|
|
save_pickle(train_set, pickle_path, "data_train.pkl") |
|
|
|
else: |
|
|
|
train_set = load_pickle(pickle_path, "data_train.pkl") |
|
|
@@ -143,8 +126,8 @@ class BasePreprocess(object): |
|
|
|
# cross validation |
|
|
|
data_cv = self.cv_split(train_dev_data, n_fold) |
|
|
|
for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): |
|
|
|
data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) |
|
|
|
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) |
|
|
|
data_train_cv = self.convert_to_dataset(data_train_cv, self.data_vocab, self.label_vocab) |
|
|
|
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.data_vocab, self.label_vocab) |
|
|
|
save_pickle( |
|
|
|
data_train_cv, pickle_path, |
|
|
|
"data_train_{}.pkl".format(i)) |
|
|
@@ -165,7 +148,7 @@ class BasePreprocess(object): |
|
|
|
test_set = [] |
|
|
|
if test_data is not None: |
|
|
|
if not pickle_exist(pickle_path, "data_test.pkl"): |
|
|
|
test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) |
|
|
|
test_set = self.convert_to_dataset(test_data, self.data_vocab, self.label_vocab) |
|
|
|
save_pickle(test_set, pickle_path, "data_test.pkl") |
|
|
|
|
|
|
|
# return preprocessed results |
|
|
@@ -180,28 +163,15 @@ class BasePreprocess(object): |
|
|
|
return tuple(results) |
|
|
|
|
|
|
|
def build_dict(self, data):
    """Build token-to-index and label-to-index dictionaries from ``data``.

    Every example is read as ``example[0]`` (an iterable of tokens) and
    ``example[1]`` (either one label string or a list of label strings).
    Both dictionaries start as copies of ``DEFAULT_WORD_TO_INDEX``; each key
    not seen before receives the dictionary's current size as its index.
    A label that is neither ``str`` nor ``list`` adds nothing.

    Returns a ``(word2index, label2index)`` tuple of dictionaries.
    """
    label_ids = DEFAULT_WORD_TO_INDEX.copy()
    token_ids = DEFAULT_WORD_TO_INDEX.copy()
    for sample in data:
        for token in sample[0]:
            # setdefault only inserts when the token is new; the index
            # argument is evaluated before the insertion happens.
            token_ids.setdefault(token, len(token_ids))

        tag = sample[1]
        if isinstance(tag, str):
            # a single label string
            pending = [tag]
        elif isinstance(tag, list):
            # a list of label strings
            pending = tag
        else:
            pending = []
        for single_tag in pending:
            label_ids.setdefault(single_tag, len(label_ids))
    return token_ids, label_ids
|
|
|
|
|
|
|
|
|
|
|
def build_reverse_dict(self, word_dict): |
# Invert ``word_dict``: every value becomes a key that maps back to the
# key it was stored under, and the inverted dict is returned.
|
|
|
id2word = {word_dict[w]: w for w in word_dict} |
|
|
|
return id2word |
# NOTE(review): the statements below sit after the return above and use
# `example`, which is not defined in this function. They read like the body
# of a per-example loop (unpack one example, update both vocabularies,
# return them) and presumably belong to build_dict -- confirm placement.
|
|
|
word, label = example |
|
|
|
self.data_vocab.update(word) |
|
|
|
self.label_vocab.update(label) |
|
|
|
return self.data_vocab, self.label_vocab |
|
|
|
|
|
|
|
def build_reverse_dict(self):
    """Ask the data vocabulary, then the label vocabulary, to build its reverse lookup."""
    for vocab in (self.data_vocab, self.label_vocab):
        vocab.build_reverse_vocab()
|
|
|
|
|
|
|
def data_split(self, data, train_dev_split): |
|
|
|
"""Split data into train and dev set.""" |
|
|
|