tags/v0.1.0

- move preprocess.py from loader/ to core/
- change the interface of preprocess:
  1. add a run method that performs the main processing
  2. add a cross-validation split
  3. add return values
  4. merge the subclasses
- Trainer supports cross validation
- pass data as arguments to Trainer.train & Tester.test
- add readme.example.py to run the example program shown in README.md
- other corresponding changes
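In short: preprocessing now returns datasets explicitly, and those datasets are passed into training and testing instead of being reloaded from pickle_path. A minimal sketch of the new call pattern (train_data, model, and the *_args config sections are illustrative placeholders, not part of this commit):

    from fastNLP.core.preprocess import SeqLabelPreprocess
    from fastNLP.core.trainer import SeqLabelTrainer
    from fastNLP.core.tester import SeqLabelTester

    preprocess = SeqLabelPreprocess()
    # run() builds and pickles the vocabularies, then returns index-encoded datasets
    data_train, data_dev = preprocess.run(train_data, pickle_path="./save/", train_dev_split=0.3)

    trainer = SeqLabelTrainer(train_args)
    trainer.train(model, data_train, data_dev)  # data is now an argument

    tester = SeqLabelTester(test_args)
    tester.test(model, data_dev)                # likewise for the tester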
@@ -3,7 +3,7 @@ import torch
 from fastNLP.core.action import Batchifier, SequentialSampler
 from fastNLP.core.action import convert_to_torch_tensor
-from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
+from fastNLP.core.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL
 from fastNLP.modules import utils
@@ -0,0 +1,306 @@
+import _pickle
+import os
+
+import numpy as np
+
+DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
+DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
+DEFAULT_RESERVED_LABEL = ['<reserved-2>',
+                          '<reserved-3>',
+                          '<reserved-4>']  # dict index = 2~4
+DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
+                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
+                         DEFAULT_RESERVED_LABEL[2]: 4}
+# the first vocab in dict with the index = 5
+
+
+def save_pickle(obj, pickle_path, file_name):
+    with open(os.path.join(pickle_path, file_name), "wb") as f:
+        _pickle.dump(obj, f)
+    print("{} saved. ".format(file_name))
+
+
+def load_pickle(pickle_path, file_name):
+    with open(os.path.join(pickle_path, file_name), "rb") as f:
+        obj = _pickle.load(f)
+    print("{} loaded. ".format(file_name))
+    return obj
+
+
+def pickle_exist(pickle_path, pickle_name):
+    """
+    :param pickle_path: the directory of target pickle file
+    :param pickle_name: the filename of target pickle file
+    :return: True if file exists else False
+    """
+    if not os.path.exists(pickle_path):
+        os.makedirs(pickle_path)
+    file_name = os.path.join(pickle_path, pickle_name)
+    if os.path.exists(file_name):
+        return True
+    else:
+        return False
+
+
+class BasePreprocess(object):
+    def __init__(self):
+        self.word2index = None
+        self.label2index = None
+
+    @property
+    def vocab_size(self):
+        return len(self.word2index)
+
+    @property
+    def num_classes(self):
+        return len(self.label2index)
+    def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
+        """Main preprocessing pipeline.
+
+        :param train_dev_data: three-level list, with either a single label or multiple labels in each sample
+        :param test_data: three-level list, same format as train_dev_data (optional)
+        :param pickle_path: str, the path to save the pickle files
+        :param train_dev_split: float in [0, 1], the ratio of training data used as the validation set
+        :param cross_val: bool, whether to do cross validation
+        :param n_fold: int, the number of folds of cross validation; only used when cross_val is True
+        :return results: a tuple of datasets after preprocessing
+        """
+        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
+            self.word2index = load_pickle(pickle_path, "word2id.pkl")
+            self.label2index = load_pickle(pickle_path, "class2id.pkl")
+        else:
+            self.word2index, self.label2index = self.build_dict(train_dev_data)
+            save_pickle(self.word2index, pickle_path, "word2id.pkl")
+            save_pickle(self.label2index, pickle_path, "class2id.pkl")
+
+        if not pickle_exist(pickle_path, "id2word.pkl"):
+            index2word = self.build_reverse_dict(self.word2index)
+            save_pickle(index2word, pickle_path, "id2word.pkl")
+        if not pickle_exist(pickle_path, "id2class.pkl"):
+            index2label = self.build_reverse_dict(self.label2index)
+            save_pickle(index2label, pickle_path, "id2class.pkl")
+
+        data_train = []
+        data_dev = []
+        if not cross_val:
+            if not pickle_exist(pickle_path, "data_train.pkl"):
+                data_train.extend(self.to_index(train_dev_data))
+                if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
+                    split = int(len(data_train) * train_dev_split)
+                    data_dev = data_train[: split]
+                    data_train = data_train[split:]
+                    save_pickle(data_dev, pickle_path, "data_dev.pkl")
+                    print("{} of the training data is split for validation. ".format(train_dev_split))
+                save_pickle(data_train, pickle_path, "data_train.pkl")
+            else:
+                data_train = load_pickle(pickle_path, "data_train.pkl")
+        else:
+            # cross validation
+            if not pickle_exist(pickle_path, "data_train_0.pkl"):
+                data_idx = self.to_index(train_dev_data)
+                data_cv = self.cv_split(data_idx, n_fold)
+                for i, (data_train_cv, data_dev_cv) in enumerate(data_cv):
+                    save_pickle(data_train_cv, pickle_path, "data_train_{}.pkl".format(i))
+                    save_pickle(data_dev_cv, pickle_path, "data_dev_{}.pkl".format(i))
+                    data_train.append(data_train_cv)
+                    data_dev.append(data_dev_cv)
+                print("{}-fold cross validation.".format(n_fold))
+            else:
+                for i in range(n_fold):
+                    data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i))
+                    data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i))
+                    data_train.append(data_train_cv)
+                    data_dev.append(data_dev_cv)
+
+        # prepare test data if provided
+        data_test = []
+        if test_data is not None:
+            if not pickle_exist(pickle_path, "data_test.pkl"):
+                data_test = self.to_index(test_data)
+                save_pickle(data_test, pickle_path, "data_test.pkl")
+            else:
+                data_test = load_pickle(pickle_path, "data_test.pkl")
+
+        # return preprocessed results
+        results = [data_train]
+        if cross_val or train_dev_split > 0:
+            results.append(data_dev)
+        if test_data is not None:
+            results.append(data_test)
+        return tuple(results)
+    def build_dict(self, data):
+        raise NotImplementedError
+
+    def to_index(self, data):
+        raise NotImplementedError
+
+    def build_reverse_dict(self, word_dict):
+        id2word = {word_dict[w]: w for w in word_dict}
+        return id2word
+
+    def data_split(self, data, train_dev_split):
+        """Split data into train and dev set."""
+        split = int(len(data) * train_dev_split)
+        data_dev = data[: split]
+        data_train = data[split:]
+        return data_train, data_dev
+
+    def cv_split(self, data, n_fold):
+        """Split data for cross validation."""
+        data_copy = data.copy()
+        np.random.shuffle(data_copy)
+        fold_size = round(len(data_copy) / n_fold)
+        data_cv = []
+        for i in range(n_fold - 1):
+            start = i * fold_size
+            end = (i + 1) * fold_size
+            data_dev = data_copy[start:end]
+            data_train = data_copy[:start] + data_copy[end:]
+            data_cv.append((data_train, data_dev))
+        start = (n_fold - 1) * fold_size
+        data_dev = data_copy[start:]
+        data_train = data_copy[:start]
+        data_cv.append((data_train, data_dev))
+        return data_cv
+class SeqLabelPreprocess(BasePreprocess):
+    """Preprocess pipeline for sequence labeling, including building the mappings from words to indices,
+    from indices to words, from labels/classes to indices, and from indices to labels/classes.
+    Designed for data given as a three-level list with multiple labels in each sample:
+    [
+        [ [word_11, word_12, ...], [label_1, label_1, ...] ],
+        [ [word_21, word_22, ...], [label_2, label_1, ...] ],
+        ...
+    ]
+    """
+
+    def __init__(self):
+        super(SeqLabelPreprocess, self).__init__()
+
+    def build_dict(self, data):
+        """
+        Add new words with indices into self.word2index, new labels with indices into self.label2index.
+
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
+                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
+                ...
+            ]
+        :return word2index: dict of {str, int}
+                label2index: dict of {str, int}
+        """
+        # In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch.
+        label2index = DEFAULT_WORD_TO_INDEX.copy()
+        word2index = DEFAULT_WORD_TO_INDEX.copy()
+        for example in data:
+            for word, label in zip(example[0], example[1]):
+                if word not in word2index:
+                    word2index[word] = len(word2index)
+                if label not in label2index:
+                    label2index[label] = len(label2index)
+        return word2index, label2index
+
+    def to_index(self, data):
+        """
+        Convert word strings and label strings into indices.
+
+        :param data: three-level list, same format as in build_dict
+        :return data_index: the same shape as data, but each string is replaced by its corresponding index
+        """
+        data_index = []
+        for example in data:
+            word_list = []
+            label_list = []
+            for word, label in zip(example[0], example[1]):
+                word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
+                label_list.append(self.label2index.get(label, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
+            data_index.append([word_list, label_list])
+        return data_index
+class ClassPreprocess(BasePreprocess):
+    """Preprocess pipeline for classification datasets, including building the mappings from words to indices,
+    from indices to words, from labels/classes to indices, and from indices to labels/classes.
+    Designed for data given as a three-level list with a single label in each sample:
+    [
+        [ [word_11, word_12, ...], label_1 ],
+        [ [word_21, word_22, ...], label_2 ],
+        ...
+    ]
+    """
+
+    def __init__(self):
+        super(ClassPreprocess, self).__init__()
+
+    def build_dict(self, data):
+        """Build vocabulary."""
+        # build vocabulary from scratch if nothing exists
+        word2index = DEFAULT_WORD_TO_INDEX.copy()
+        label2index = DEFAULT_WORD_TO_INDEX.copy()
+
+        # collect every word and label
+        for sent, label in data:
+            if len(sent) <= 1:
+                continue
+            if label not in label2index:
+                label2index[label] = len(label2index)
+            for word in sent:
+                if word not in word2index:
+                    word2index[word] = len(word2index)
+        return word2index, label2index
+
+    def to_index(self, data):
+        """
+        Convert word strings and label strings into indices.
+
+        :param data: three-level list
+            [
+                [ [word_11, word_12, ...], label_1 ],
+                [ [word_21, word_22, ...], label_2 ],
+                ...
+            ]
+        :return data_index: the same shape as data, but each string is replaced by its corresponding index
+        """
+        data_index = []
+        for example in data:
+            word_list = []
+            for word in example[0]:
+                word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]))
+            label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])
+            data_index.append([word_list, label_index])
+        return data_index
+def infer_preprocess(pickle_path, data):
+    """
+    Preprocess over inference data: transform a two-level list of word strings into one of indices.
+    [
+        [word_11, word_12, ...],
+        [word_21, word_22, ...],
+        ...
+    ]
+    """
+    word2index = load_pickle(pickle_path, "word2id.pkl")
+    data_index = []
+    for example in data:
+        data_index.append([word2index.get(w, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]) for w in example])
+    return data_index
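Since run() returns a tuple whose length depends on its arguments, call sites unpack it differently. A toy walk-through of the three shapes (data and cache paths are made up, and each call uses a fresh pickle_path because results are cached on disk):

    from fastNLP.core.preprocess import SeqLabelPreprocess

    data = [[["word%d" % i, "word%d" % (i + 1)], ["O", "O"]] for i in range(10)]
    p = SeqLabelPreprocess()

    # no dev split, no test set: a 1-tuple (data_train,)
    (data_train,) = p.run(data, pickle_path="./cache_a/")

    # 30% dev split: a 2-tuple (data_train, data_dev)
    data_train, data_dev = p.run(data, pickle_path="./cache_b/", train_dev_split=0.3)

    # 5-fold cross validation: both elements are lists of n_fold datasets
    train_folds, dev_folds = p.run(data, pickle_path="./cache_c/", cross_val=True, n_fold=5)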
@@ -34,7 +34,7 @@ class BaseTester(object):
         self.eval_history = []
         self.batch_output = []

-    def test(self, network):
+    def test(self, network, dev_data):
         if torch.cuda.is_available() and self.use_cuda:
             self.model = network.cuda()
         else:
@@ -45,8 +45,8 @@ class BaseTester(object):
         self.eval_history.clear()
         self.batch_output.clear()

-        dev_data = self.prepare_input(self.pickle_path)
-        logger.info("validation data loaded")
+        # dev_data = self.prepare_input(self.pickle_path)
+        # logger.info("validation data loaded")

         iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
         n_batches = len(dev_data) // self.batch_size
@@ -1,4 +1,5 @@
 import _pickle
+import copy
 import os
 import time
 from datetime import timedelta
@@ -52,9 +53,11 @@ class BaseTrainer(object):
         self.loss_func = None
         self.optimizer = None

-    def train(self, network):
+    def train(self, network, train_data, dev_data=None):
         """General Training Steps

         :param network: a model
+        :param train_data: three-level list, the training set
+        :param dev_data: three-level list, the validation data (optional)

         The method is framework independent.
         Work by calling the following methods:
@@ -73,8 +76,8 @@ class BaseTrainer(object):
         else:
             self.model = network

-        data_train = self.load_train_data(self.pickle_path)
-        logger.info("training data loaded")
+        # train_data = self.load_train_data(self.pickle_path)
+        # logger.info("training data loaded")

         # define tester over dev data
         if self.validate:
@@ -88,8 +91,7 @@ class BaseTrainer(object):
         logger.info("optimizer defined as {}".format(str(self.optimizer)))

         # main training epochs
-        n_samples = len(data_train)
+        n_samples = len(train_data)
         n_batches = n_samples // self.batch_size
         n_print = 1
         start = time.time()
@@ -101,14 +103,14 @@ class BaseTrainer(object):
             # turn on network training mode
             self.mode(network, test=False)
             # prepare mini-batch iterator
-            data_iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False))
+            data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False))
             logger.info("prepared data iterator")

             self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)

             if self.validate:
                 logger.info("validation started")
-                validator.test(network)
+                validator.test(network, dev_data)

                 if self.save_best_dev and self.best_eval_result(validator):
                     self.save_model(network)
@@ -139,6 +141,26 @@ class BaseTrainer(object):
                 logger.info(print_output)
                 step += 1

+    def cross_validate(self, network, train_data_cv, dev_data_cv):
+        """Training with cross validation.
+
+        :param network: the model
+        :param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
+        :param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
+        """
+        if len(train_data_cv) != len(dev_data_cv):
+            logger.error("the numbers of folds in train and dev data differ: {} != {}".format(
+                len(train_data_cv), len(dev_data_cv)))
+            raise RuntimeError("the numbers of folds in train and dev data differ")
+        n_fold = len(train_data_cv)
+        logger.info("performing {}-fold cross validation.".format(n_fold))
+        for i in range(n_fold):
+            print("CV:", i)
+            logger.info("running fold {} of {} in cross validation".format(i + 1, n_fold))
+            network_copy = copy.deepcopy(network)
+            self.train(network_copy, train_data_cv[i], dev_data_cv[i])

     def load_train_data(self, pickle_path):
         """
         For task-specific processing.
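cross_validate() consumes the per-fold lists produced by run(..., cross_val=True); a hedged sketch of the wiring (train_data, train_args, and network are placeholders):

    preprocess = SeqLabelPreprocess()
    train_folds, dev_folds = preprocess.run(train_data, pickle_path="./save/", cross_val=True, n_fold=10)

    trainer = SeqLabelTrainer(train_args)
    # each fold trains a deep copy of the network, so folds do not share weights
    trainer.cross_validate(network, train_folds, dev_folds)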
@@ -1,366 +0,0 @@
-import _pickle
-import os
-
-DEFAULT_PADDING_LABEL = '<pad>'  # dict index = 0
-DEFAULT_UNKNOWN_LABEL = '<unk>'  # dict index = 1
-DEFAULT_RESERVED_LABEL = ['<reserved-2>',
-                          '<reserved-3>',
-                          '<reserved-4>']  # dict index = 2~4
-DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
-                         DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
-                         DEFAULT_RESERVED_LABEL[2]: 4}
-# the first vocab in dict with the index = 5
-
-def save_pickle(obj, pickle_path, file_name):
-    with open(os.path.join(pickle_path, file_name), "wb") as f:
-        _pickle.dump(obj, f)
-    print("{} saved. ".format(file_name))
-
-def load_pickle(pickle_path, file_name):
-    with open(os.path.join(pickle_path, file_name), "rb") as f:
-        obj = _pickle.load(f)
-    print("{} loaded. ".format(file_name))
-    return obj
-
-def pickle_exist(pickle_path, pickle_name):
-    """
-    :param pickle_path: the directory of target pickle file
-    :param pickle_name: the filename of target pickle file
-    :return: True if file exists else False
-    """
-    if not os.path.exists(pickle_path):
-        os.makedirs(pickle_path)
-    file_name = os.path.join(pickle_path, pickle_name)
-    if os.path.exists(file_name):
-        return True
-    else:
-        return False
-
-class BasePreprocess(object):
-    def __init__(self, data, pickle_path):
-        super(BasePreprocess, self).__init__()
-        # self.data = data
-        self.pickle_path = pickle_path
-        if not self.pickle_path.endswith('/'):
-            self.pickle_path = self.pickle_path + '/'
-
-class POSPreprocess(BasePreprocess):
-    """
-    This class are used to preprocess the POS Tag datasets.
-    """
-
-    def __init__(self, data, pickle_path="./", train_dev_split=0):
-        """
-        Preprocess pipeline, including building mapping from words to index, from index to words,
-        from labels/classes to index, from index to labels/classes.
-        :param data: three-level list
-            [
-                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
-                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
-                ...
-            ]
-        :param pickle_path: str, the directory to the pickle files. Default: "./"
-        :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0.
-        """
-        super(POSPreprocess, self).__init__(data, pickle_path)
-        self.pickle_path = pickle_path
-        if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
-            self.word2index = load_pickle(self.pickle_path, "word2id.pkl")
-            self.label2index = load_pickle(self.pickle_path, "class2id.pkl")
-        else:
-            self.word2index, self.label2index = self.build_dict(data)
-            save_pickle(self.word2index, self.pickle_path, "word2id.pkl")
-            save_pickle(self.label2index, self.pickle_path, "class2id.pkl")
-        if not pickle_exist(pickle_path, "id2word.pkl"):
-            index2word = self.build_reverse_dict(self.word2index)
-            save_pickle(index2word, self.pickle_path, "id2word.pkl")
-        if not pickle_exist(pickle_path, "id2class.pkl"):
-            index2label = self.build_reverse_dict(self.label2index)
-            save_pickle(index2label, self.pickle_path, "id2class.pkl")
-        if not pickle_exist(pickle_path, "data_train.pkl"):
-            data_train = self.to_index(data)
-            if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"):
-                split = int(len(data_train) * train_dev_split)
-                data_dev = data_train[: split]
-                data_train = data_train[split:]
-                save_pickle(data_dev, self.pickle_path, "data_dev.pkl")
-                print("{} of the training data is split for validation. ".format(train_dev_split))
-            save_pickle(data_train, self.pickle_path, "data_train.pkl")
-
-    def build_dict(self, data):
-        """
-        Add new words with indices into self.word_dict, new labels with indices into self.label_dict.
-        :param data: three-level list
-            [
-                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
-                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
-                ...
-            ]
-        :return word2index: dict of {str, int}
-                label2index: dict of {str, int}
-        """
-        # In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch.
-        label2index = DEFAULT_WORD_TO_INDEX.copy()
-        word2index = DEFAULT_WORD_TO_INDEX.copy()
-        for example in data:
-            for word, label in zip(example[0], example[1]):
-                if word not in word2index:
-                    word2index[word] = len(word2index)
-                if label not in label2index:
-                    label2index[label] = len(label2index)
-        return word2index, label2index
-
-    def build_reverse_dict(self, word_dict):
-        id2word = {word_dict[w]: w for w in word_dict}
-        return id2word
-
-    def to_index(self, data):
-        """
-        Convert word strings and label strings into indices.
-        :param data: three-level list
-            [
-                [ [word_11, word_12, ...], [label_1, label_1, ...] ],
-                [ [word_21, word_22, ...], [label_2, label_1, ...] ],
-                ...
-            ]
-        :return data_index: the shape of data, but each string is replaced by its corresponding index
-        """
-        data_index = []
-        for example in data:
-            word_list = []
-            label_list = []
-            for word, label in zip(example[0], example[1]):
-                word_list.append(self.word2index[word])
-                label_list.append(self.label2index[label])
-            data_index.append([word_list, label_list])
-        return data_index
-
-    @property
-    def vocab_size(self):
-        return len(self.word2index)
-
-    @property
-    def num_classes(self):
-        return len(self.label2index)
-
-class ClassPreprocess(BasePreprocess):
-    """
-    Pre-process the classification datasets.
-    Params:
-        pickle_path - directory to save result of pre-processing
-    Saves:
-        word2id.pkl
-        id2word.pkl
-        class2id.pkl
-        id2class.pkl
-        embedding.pkl
-        data_train.pkl
-        data_dev.pkl
-        data_test.pkl
-    """
-
-    def __init__(self, pickle_path):
-        # super(ClassPreprocess, self).__init__(data, pickle_path)
-        self.word_dict = None
-        self.label_dict = None
-        self.pickle_path = pickle_path  # save directory
-
-    def process(self, data, save_name):
-        """
-        Process data.
-        Params:
-            data - nested list, data = [sample1, sample2, ...],
-                sample = [sentence, label], sentence = [word1, word2, ...]
-            save_name - name of processed data, such as data_train.pkl
-        Returns:
-            vocab_size - vocabulary size
-            n_classes - number of classes
-        """
-        self.build_dict(data)
-        self.word2id()
-        vocab_size = self.id2word()
-        self.class2id()
-        num_classes = self.id2class()
-        self.embedding()
-        self.data_generate(data, save_name)
-        return vocab_size, num_classes
-
-    def build_dict(self, data):
-        """Build vocabulary."""
-        # just read if word2id.pkl and class2id.pkl exists
-        if self.pickle_exist("word2id.pkl") and \
-                self.pickle_exist("class2id.pkl"):
-            file_name = os.path.join(self.pickle_path, "word2id.pkl")
-            with open(file_name, 'rb') as f:
-                self.word_dict = _pickle.load(f)
-            file_name = os.path.join(self.pickle_path, "class2id.pkl")
-            with open(file_name, 'rb') as f:
-                self.label_dict = _pickle.load(f)
-            return
-
-        # build vocabulary from scratch if nothing exists
-        self.word_dict = {
-            DEFAULT_PADDING_LABEL: 0,
-            DEFAULT_UNKNOWN_LABEL: 1,
-            DEFAULT_RESERVED_LABEL[0]: 2,
-            DEFAULT_RESERVED_LABEL[1]: 3,
-            DEFAULT_RESERVED_LABEL[2]: 4}
-        self.label_dict = {}
-
-        # collect every word and label
-        for sent, label in data:
-            if len(sent) <= 1:
-                continue
-            if label not in self.label_dict:
-                index = len(self.label_dict)
-                self.label_dict[label] = index
-            for word in sent:
-                if word not in self.word_dict:
-                    index = len(self.word_dict)
-                    self.word_dict[word[0]] = index
-
-    def pickle_exist(self, pickle_name):
-        """
-        Check whether a pickle file exists.
-        Params
-            pickle_name: the filename of target pickle file
-        Return
-            True if file exists else False
-        """
-        if not os.path.exists(self.pickle_path):
-            os.makedirs(self.pickle_path)
-        file_name = os.path.join(self.pickle_path, pickle_name)
-        if os.path.exists(file_name):
-            return True
-        else:
-            return False
-
-    def word2id(self):
-        """Save vocabulary of {word:id} mapping format."""
-        # nothing will be done if word2id.pkl exists
-        if self.pickle_exist("word2id.pkl"):
-            return
-        file_name = os.path.join(self.pickle_path, "word2id.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(self.word_dict, f)
-
-    def id2word(self):
-        """Save vocabulary of {id:word} mapping format."""
-        # nothing will be done if id2word.pkl exists
-        if self.pickle_exist("id2word.pkl"):
-            file_name = os.path.join(self.pickle_path, "id2word.pkl")
-            with open(file_name, 'rb') as f:
-                id2word_dict = _pickle.load(f)
-            return len(id2word_dict)
-        id2word_dict = {self.word_dict[w]: w for w in self.word_dict}
-        file_name = os.path.join(self.pickle_path, "id2word.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(id2word_dict, f)
-        return len(id2word_dict)
-
-    def class2id(self):
-        """Save mapping of {class:id}."""
-        # nothing will be done if class2id.pkl exists
-        if self.pickle_exist("class2id.pkl"):
-            return
-        file_name = os.path.join(self.pickle_path, "class2id.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(self.label_dict, f)
-
-    def id2class(self):
-        """Save mapping of {id:class}."""
-        # nothing will be done if id2class.pkl exists
-        if self.pickle_exist("id2class.pkl"):
-            file_name = os.path.join(self.pickle_path, "id2class.pkl")
-            with open(file_name, "rb") as f:
-                id2class_dict = _pickle.load(f)
-            return len(id2class_dict)
-        id2class_dict = {self.label_dict[c]: c for c in self.label_dict}
-        file_name = os.path.join(self.pickle_path, "id2class.pkl")
-        with open(file_name, "wb") as f:
-            _pickle.dump(id2class_dict, f)
-        return len(id2class_dict)
-
-    def embedding(self):
-        """Save embedding lookup table corresponding to vocabulary."""
-        # nothing will be done if embedding.pkl exists
-        if self.pickle_exist("embedding.pkl"):
-            return
-        # retrieve vocabulary from pre-trained embedding (not implemented)
-
-    def data_generate(self, data_src, save_name):
-        """Convert dataset from text to digit."""
-        # nothing will be done if file exists
-        save_path = os.path.join(self.pickle_path, save_name)
-        if os.path.exists(save_path):
-            return
-        data = []
-        # for every sample
-        for sent, label in data_src:
-            if len(sent) <= 1:
-                continue
-            label_id = self.label_dict[label]  # label id
-            sent_id = []  # sentence ids
-            for word in sent:
-                if word in self.word_dict:
-                    sent_id.append(self.word_dict[word])
-                else:
-                    sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL])
-            data.append([sent_id, label_id])
-        # save data
-        with open(save_path, "wb") as f:
-            _pickle.dump(data, f)
-
-class LMPreprocess(BasePreprocess):
-    def __init__(self, data, pickle_path):
-        super(LMPreprocess, self).__init__(data, pickle_path)
-
-def infer_preprocess(pickle_path, data):
-    """
-    Preprocess over inference data.
-    Transform three-level list of strings into that of index.
-    [
-        [word_11, word_12, ...],
-        [word_21, word_22, ...],
-        ...
-    ]
-    """
-    word2index = load_pickle(pickle_path, "word2id.pkl")
-    data_index = []
-    for example in data:
-        data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example])
-    return data_index
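The deleted module above did its work as a side effect of construction (POSPreprocess) or of a process() call (ClassPreprocess). A rough before/after sketch of the migration to the new API (variable names illustrative):

    # before this commit (fastNLP/loader/preprocess.py): datasets are only written to pickles
    POSPreprocess(train_data, pickle_path, train_dev_split=0.3)

    # after this commit (fastNLP/core/preprocess.py): datasets are returned to the caller
    p = SeqLabelPreprocess()
    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.3)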
@@ -1,13 +1,12 @@
 import os
-import
-import
 import torch
 import torch.nn as nn
-.dataset as dst
-from .model import CNN_text
 from torch.autograd import Variable
+from . import dataset as dst
+from .model import CNN_text

 # Hyper Parameters
 batch_size = 50
 learning_rate = 0.0001
@@ -5,7 +5,7 @@ sys.path.append("..")
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
-from fastNLP.loader.preprocess import POSPreprocess, load_pickle
+from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
 from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
@@ -48,7 +48,7 @@ def infer():
     print("Inference finished!")

-def train():
+def train_test():
     # Config Loader
     train_args = ConfigSection()
     test_args = ConfigSection()
@@ -59,9 +59,10 @@ def train():
     train_data = loader.load_pku()

     # Preprocessor
-    p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3)
-    train_args["vocab_size"] = p.vocab_size
-    train_args["num_classes"] = p.num_classes
+    preprocess = SeqLabelPreprocess()
+    data_train, data_dev = preprocess.run(train_data, pickle_path=pickle_path, train_dev_split=0.3)
+    train_args["vocab_size"] = preprocess.vocab_size
+    train_args["num_classes"] = preprocess.num_classes

     # Trainer
     trainer = SeqLabelTrainer(train_args)
@@ -70,7 +71,7 @@ def train():
     model = SeqLabeling(train_args)

     # Start training
-    trainer.train(model)
+    trainer.train(model, data_train, data_dev)
     print("Training finished!")

     # Saver
@@ -78,8 +79,11 @@ def train():
     saver.save_pytorch(model)
     print("Model saved!")

+    # testing with validation set
+    test(data_dev)

-def test():
+def test(test_data):
     # Config Loader
     train_args = ConfigSection()
     ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
@@ -99,7 +103,7 @@ def test():
     tester = SeqLabelTester(test_args)

     # Start testing
-    tester.test(model)
+    tester.test(model, test_data)

     # print test results
     print(tester.show_matrices())
@@ -107,4 +111,4 @@ def test():
 if __name__ == "__main__":
-    train()
+    train_test()
@@ -4,9 +4,9 @@ import os
 import numpy as np
 import torch

+from fastNLP.core.preprocess import SeqLabelPreprocess
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.core.trainer import SeqLabelTrainer
-from fastNLP.loader.preprocess import POSPreprocess
 from fastNLP.models.sequence_modeling import AdvSeqLabel
@@ -114,7 +114,8 @@ emb_path = "data_for_tests/emb50.txt"
 save_path = "data_for_tests/"
 if __name__ == "__main__":
     data = data_load(data_path)
-    p = POSPreprocess(data, pickle_path=pick_path, train_dev_split=0.3)
+    preprocess = SeqLabelPreprocess()
+    data_train, data_dev = preprocess.run(data, pickle_path=pick_path, train_dev_split=0.3)
     # emb = embedding_process(emb_path, p.word2index, 50, os.path.join(pick_path, "embedding.pkl"))
     emb = None
     args = {"epochs": 20,
@@ -125,13 +126,13 @@ if __name__ == "__main__":
             "model_saved_path": save_path,
             "use_cuda": True,
-            "vocab_size": p.vocab_size,
-            "num_classes": p.num_classes,
+            "vocab_size": preprocess.vocab_size,
+            "num_classes": preprocess.num_classes,
             "word_emb_dim": 50,
             "rnn_hidden_units": 100
             }
     # emb = torch.Tensor(emb).float().cuda()
     networks = AdvSeqLabel(args, emb)
     trainer = MyNERTrainer(args)
-    trainer.train(network=networks)
+    trainer.train(networks, data_train, data_dev)
     print("Training finished!")
@@ -0,0 +1,78 @@
+# python: 3.5
+# pytorch: 0.4
+
+################
+# Test cross validation.
+################
+
+from fastNLP.core.preprocess import ClassPreprocess
+from fastNLP.core.predictor import ClassificationInfer
+from fastNLP.core.trainer import ClassificationTrainer
+from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.models.base_model import BaseModel
+from fastNLP.modules import aggregation
+from fastNLP.modules import encoder
+
+
+class ClassificationModel(BaseModel):
+    """
+    Simple text classification model based on CNN.
+    """
+
+    def __init__(self, class_num, vocab_size):
+        super(ClassificationModel, self).__init__()
+
+        self.embed = encoder.Embedding(nums=vocab_size, dims=300)
+        self.conv = encoder.Conv(
+            in_channels=300, out_channels=100, kernel_size=3)
+        self.pool = aggregation.MaxPool()
+        self.output = encoder.Linear(input_size=100, output_size=class_num)
+
+    def forward(self, x):
+        x = self.embed(x)   # [N,L] -> [N,L,C]
+        x = self.conv(x)    # [N,L,C_in] -> [N,L,C_out]
+        x = self.pool(x)    # [N,L,C] -> [N,C]
+        x = self.output(x)  # [N,C] -> [N,N_class]
+        return x
+
+
+data_dir = 'data'  # directory to save data and model
+train_path = 'test/data_for_tests/text_classify.txt'  # training set file
+
+# load dataset
+ds_loader = ClassDatasetLoader("train", train_path)
+data = ds_loader.load()
+
+# pre-process dataset: one (train, dev) pair per fold
+pre = ClassPreprocess()
+train_folds, dev_folds = pre.run(data, pickle_path=data_dir, cross_val=True, n_fold=5)
+n_classes = pre.num_classes
+vocab_size = pre.vocab_size
+
+# construct model
+model_args = {
+    'num_classes': n_classes,
+    'vocab_size': vocab_size
+}
+model = ClassificationModel(class_num=n_classes, vocab_size=vocab_size)
+
+# train model
+train_args = {
+    "epochs": 10,
+    "batch_size": 50,
+    "pickle_path": data_dir,
+    "validate": False,
+    "save_best_dev": False,
+    "model_saved_path": None,
+    "use_cuda": True,
+    "learn_rate": 1e-3,
+    "momentum": 0.9}
+trainer = ClassificationTrainer(train_args)
+trainer.cross_validate(model, train_folds, dev_folds)
+
+# predict using model
+data_infer = [x[0] for x in data]
+infer = ClassificationInfer(data_dir)
+labels_pred = infer.predict(model, data_infer)
@@ -5,7 +5,7 @@ sys.path.append("..")
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader
-from fastNLP.loader.preprocess import POSPreprocess, load_pickle
+from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
 from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
@@ -68,7 +68,8 @@ def train_and_test():
     train_data = pos_loader.load_lines()

     # Preprocessor
-    p = POSPreprocess(train_data, pickle_path, train_dev_split=0.5)
+    p = SeqLabelPreprocess()
+    data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5)
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes
@@ -79,7 +80,7 @@ def train_and_test():
     model = SeqLabeling(train_args)

     # Start training
-    trainer.train(model)
+    trainer.train(model, data_train, data_dev)
     print("Training finished!")

     # Saver
@@ -103,8 +104,8 @@ def train_and_test():
     # Tester
     tester = SeqLabelTester(test_args)

-    # Start testing
-    tester.test(model)
+    # Start testing with validation data
+    tester.test(model, data_dev)

     # print test results
     print(tester.show_matrices())
@@ -5,7 +5,7 @@ sys.path.append("..")
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
-from fastNLP.loader.preprocess import POSPreprocess, load_pickle
+from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
 from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
@@ -68,7 +68,8 @@ def train_test():
     train_data = loader.load_pku()

     # Preprocessor
-    p = POSPreprocess(train_data, pickle_path)
+    p = SeqLabelPreprocess()
+    data_train = p.run(train_data, pickle_path=pickle_path)[0]
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes
@@ -79,7 +80,7 @@ def train_test():
     model = SeqLabeling(train_args)

     # Start training
-    trainer.train(model)
+    trainer.train(model, data_train)
     print("Training finished!")

     # Saver
@@ -104,7 +105,7 @@ def train_test():
     tester = SeqLabelTester(test_args)

     # Start testing
-    tester.test(model)
+    tester.test(model, data_train)

     # print test results
     print(tester.show_matrices())
@@ -1,7 +1,7 @@
+from fastNLP.core.preprocess import SeqLabelPreprocess
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
 from fastNLP.loader.dataset_loader import TokenizeDatasetLoader
-from fastNLP.loader.preprocess import POSPreprocess
 from fastNLP.models.sequence_modeling import SeqLabeling

 data_name = "pku_training.utf8"
@@ -17,7 +17,7 @@ def foo():
     ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})

     # Preprocessor
-    p = POSPreprocess(train_data, pickle_path)
+    p = SeqLabelPreprocess()
+    p.run(train_data, pickle_path=pickle_path)
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes
@@ -10,7 +10,7 @@ from fastNLP.core.trainer import ClassificationTrainer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.dataset_loader import ClassDatasetLoader
 from fastNLP.loader.model_loader import ModelLoader
-from fastNLP.loader.preprocess import ClassPreprocess
+from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.models.cnn_text_classification import CNNText
 from fastNLP.saver.model_saver import ModelSaver
@@ -59,28 +59,28 @@ def train():
     print(data[0])

     # pre-process data
-    pre = ClassPreprocess(data_dir)
-    vocab_size, n_classes = pre.process(data, "data_train.pkl")
-    print("vocabulary size:", vocab_size)
-    print("number of classes:", n_classes)
+    pre = ClassPreprocess()
+    data_train = pre.run(data, pickle_path=data_dir)[0]
+    print("vocabulary size:", pre.vocab_size)
+    print("number of classes:", pre.num_classes)

     # construct model
     print("Building model...")
-    cnn = CNNText(model_args)
+    model = CNNText(model_args)

     # train
     print("Training...")
     trainer = ClassificationTrainer(train_args)
-    trainer.train(cnn)
+    trainer.train(model, data_train)
     print("Training finished!")

     saver = ModelSaver("./data_for_tests/saved_model.pkl")
-    saver.save_pytorch(cnn)
+    saver.save_pytorch(model)
     print("Model saved!")

 if __name__ == "__main__":
-    # train()
-    infer()
+    train()
+    # infer()