|
|
@@ -59,7 +59,6 @@ class BasePreprocess(object): |
|
|
|
|
|
|
|
def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): |
|
|
|
"""Main preprocessing pipeline. |
|
|
|
|
|
|
|
:param train_dev_data: three-level list, with either single label or multiple labels in a sample. |
|
|
|
:param test_data: three-level list, with either single label or multiple labels in a sample. (optional) |
|
|
|
:param pickle_path: str, the path to save the pickle files. |
|
|
@@ -98,6 +97,7 @@ class BasePreprocess(object): |
|
|
|
save_pickle(data_train, pickle_path, "data_train.pkl") |
|
|
|
else: |
|
|
|
data_train = load_pickle(pickle_path, "data_train.pkl") |
|
|
|
data_dev = load_pickle(pickle_path, "data_dev.pkl") |
|
|
|
else: |
|
|
|
# cross_val is True |
|
|
|
if not pickle_exist(pickle_path, "data_train_0.pkl"): |
|
|
@@ -307,4 +307,4 @@ def infer_preprocess(pickle_path, data): |
|
|
|
data_index = [] |
|
|
|
for example in data: |
|
|
|
data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example]) |
|
|
|
return data_index |
|
|
|
return data_index |