diff --git a/fastNLP/action/inference.py b/fastNLP/action/inference.py index 1a1c4d2c..c0692f28 100644 --- a/fastNLP/action/inference.py +++ b/fastNLP/action/inference.py @@ -1,26 +1,116 @@ +import torch + +from fastNLP.action.action import Batchifier, SequentialSampler +from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL + + class Inference(object): """ This is an interface focusing on predicting output based on trained models. - It does not care about evaluations of the model. + It does not care about evaluations of the model, which is different from Tester. + This is a high-level model wrapper to be called by FastNLP. """ - def __init__(self): - pass + def __init__(self, pickle_path): + self.batch_size = 1 + self.batch_output = [] + self.iterator = None + self.pickle_path = pickle_path + self.index2label = load_pickle(self.pickle_path, "id2class.pkl") + self.word2index = load_pickle(self.pickle_path, "word2id.pkl") + + def predict(self, network, data): + """ + Perform inference. + :param network: + :param data: multi-level lists of strings + :return result: the model outputs + """ + # transform strings into indices + data = self.prepare_input(data) + + # turn on the testing mode; clean up the history + self.mode(network, test=True) + + self.iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) + + num_iter = len(data) // self.batch_size + + for step in range(num_iter): + batch_x = self.batchify(data) + + prediction = self.data_forward(network, batch_x) + + self.batch_output.append(prediction) + + return self.prepare_output(self.batch_output) + + def mode(self, network, test=True): + if test: + network.eval() + else: + network.train() + self.batch_output.clear() + + def data_forward(self, network, x): + """ + This is only for sequence labeling with CRF decoder. To do: more general ? + :param network: + :param x: + :return: + """ + seq_len = [len(seq) for seq in x] + x = torch.Tensor(x).long() + y = network(x) + prediction = network.prediction(y, seq_len) + # To do: hide framework + results = torch.Tensor(prediction).view(-1, ) + return list(results.data) + + def batchify(self, data): + indices = next(self.iterator) + batch_x = [data[idx] for idx in indices] + batch_x = self.pad(batch_x) + return batch_x + + @staticmethod + def pad(batch, fill=0): + """ + Pad a batch of samples to maximum length. + :param batch: list of list + :param fill: word index to pad, default 0. + :return: a padded batch + """ + max_length = max([len(x) for x in batch]) + for idx, sample in enumerate(batch): + if len(sample) < max_length: + batch[idx] = sample + [fill * (max_length - len(sample))] + return batch - def predict(self, model, data): + def prepare_input(self, data): """ - this is actually a forward pass. shall be shared by Trainer/Tester - :param model: + Transform three-level list of strings into that of index. :param data: - :return result: the output results + [ + [word_11, word_12, ...], + [word_21, word_22, ...], + ... + ] """ - raise NotImplementedError + data_index = [] + default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL] + for example in data: + data_index.append([self.word2index.get(w, default_unknown_index) for w in example]) + return data_index - def prepare_input(self, data_path): + def prepare_output(self, batch_outputs): """ - This can also be shared. - :param data_path: + Transform list of batch outputs into strings. + :param batch_outputs: list of list [num_batch, tag_seq_length] :return: """ - raise NotImplementedError + results = [] + for batch in batch_outputs: + results.append([self.index2label[int(x.data)] for x in batch]) + return results diff --git a/fastNLP/action/trainer.py b/fastNLP/action/trainer.py index a9162b2a..c3a66bac 100644 --- a/fastNLP/action/trainer.py +++ b/fastNLP/action/trainer.py @@ -86,7 +86,7 @@ class BaseTrainer(Action): # training iterations in one epoch for step in range(iterations): - batch_x, batch_y = self.batchify(data_train) # pad ? + batch_x, batch_y = self.make_batch(data_train) prediction = self.data_forward(network, batch_x) @@ -180,7 +180,7 @@ class BaseTrainer(Action): """ raise NotImplementedError - def batchify(self, data, output_length=True): + def make_batch(self, data, output_length=True): """ 1. Perform batching from data and produce a batch of training data. 2. Add padding. @@ -191,9 +191,12 @@ class BaseTrainer(Action): [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 ... ] - :return batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] + :return (batch_x, seq_len): tuple of two elements, if output_length is true. + batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] + seq_len: list. The length of the pre-padded sequence, if output_length is True. batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] - seq_len: list. The length of the pre-padded sequence, if output_length is True. + + return batch_x and batch_y, if output_length is False """ indices = next(self.iterator) batch = [data[idx] for idx in indices] @@ -202,7 +205,7 @@ class BaseTrainer(Action): batch_x_pad = self.pad(batch_x) if output_length: seq_len = [len(x) for x in batch_x] - return batch_x_pad, batch_y, seq_len + return (batch_x_pad, seq_len), batch_y else: return batch_x_pad, batch_y @@ -292,17 +295,23 @@ class POSTrainer(BaseTrainer): data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) return data_train, data_dev, 0, 1 - def data_forward(self, network, x): + def data_forward(self, network, inputs): """ :param network: the PyTorch model - :param x: list of list, [batch_size, max_len] + :param inputs: list of list, [batch_size, max_len], + or tuple of (batch_x, seq_len), batch_x == [batch_size, max_len] :return y: [batch_size, max_len, tag_size] """ - self.seq_len = [len(seq) for seq in x] + # unpack the returned value from make_batch + if isinstance(inputs, tuple): + x = inputs[0] + self.seq_len = inputs[1] + else: + x = inputs x = torch.Tensor(x).long() self.batch_size = x.size(0) self.max_len = x.size(1) - # self.mask = seq_mask(seq_len, self.max_len) + y = network(x) return y @@ -325,11 +334,12 @@ class POSTrainer(BaseTrainer): def get_loss(self, predict, truth): """ Compute loss given prediction and ground truth. - :param predict: prediction label vector, [batch_size, tag_size, tag_size] + :param predict: prediction label vector, [batch_size, max_len, tag_size] :param truth: ground truth label vector, [batch_size, max_len] :return: a scalar """ truth = torch.Tensor(truth) + assert truth.shape == (self.batch_size, self.max_len) if self.loss_func is None: if hasattr(self.model, "loss"): self.loss_func = self.model.loss @@ -347,6 +357,35 @@ class POSTrainer(BaseTrainer): else: return False + def make_batch(self, data, output_length=True): + """ + 1. Perform batching from data and produce a batch of training data. + 2. Add padding. + :param data: list. Each entry is a sample, which is also a list of features and label(s). + E.g. + [ + [[word_11, word_12, word_13], [label_11. label_12]], # sample 1 + [[word_21, word_22, word_23], [label_21. label_22]], # sample 2 + ... + ] + :return (batch_x, seq_len): tuple of two elements, if output_length is true. + batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] + seq_len: list. The length of the pre-padded sequence, if output_length is True. + batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] + + return batch_x and batch_y, if output_length is False + """ + indices = next(self.iterator) + batch = [data[idx] for idx in indices] + batch_x = [sample[0] for sample in batch] + batch_y = [sample[1] for sample in batch] + batch_x_pad = self.pad(batch_x) + if output_length: + seq_len = [len(x) for x in batch_x] + return (batch_x_pad, seq_len), batch_y + else: + return batch_x_pad, batch_y + class LanguageModelTrainer(BaseTrainer): """ @@ -438,7 +477,7 @@ class ClassTrainer(BaseTrainer): # training iterations in one epoch step = 0 - for batch_x, batch_y in self.batchify(data_train): + for batch_x, batch_y in self.make_batch(data_train): prediction = self.data_forward(network, batch_x) loss = self.get_loss(prediction, batch_y) @@ -533,7 +572,7 @@ class ClassTrainer(BaseTrainer): """Apply gradient.""" self.optimizer.step() - def batchify(self, data): + def make_batch(self, data): """Batch and pad data.""" for indices in self.iterator: batch = [data[idx] for idx in indices] @@ -559,4 +598,4 @@ if __name__ == "__name__": train_args = {"epochs": 1, "validate": False, "batch_size": 3, "pickle_path": "./"} trainer = BaseTrainer(train_args) data_train = [[[1, 2, 3, 4], [0]] * 10] + [[[1, 3, 5, 2], [1]] * 10] - trainer.batchify(data=data_train) + trainer.make_batch(data=data_train) diff --git a/fastNLP/fastNLP.py b/fastNLP/fastNLP.py new file mode 100644 index 00000000..cfda830c --- /dev/null +++ b/fastNLP/fastNLP.py @@ -0,0 +1,104 @@ +from fastNLP.action.inference import Inference +from fastNLP.loader.config_loader import ConfigLoader, ConfigSection +from fastNLP.loader.model_loader import ModelLoader + +""" +mapping from model name to [URL, file_name.class_name] +Notice that the class of the model should be in "models" directory. + +Example: + "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling"] +""" +FastNLP_MODEL_COLLECTION = { + "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling"] +} + + +class FastNLP(object): + """ + High-level interface for direct model inference. + Usage: + fastnlp = FastNLP() + fastnlp.load("zh_pos_tag_model") + text = "这是最好的基于深度学习的中文分词系统。" + result = fastnlp.run(text) + print(result) # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"] + """ + + def __init__(self, model_dir="./"): + self.model_dir = model_dir + self.model = None + + def load(self, model_name): + """ + Load a pre-trained FastNLP model together with additional data. + :param model_name: str, the name of a FastNLP model. + """ + assert type(model_name) is str + if model_name not in FastNLP_MODEL_COLLECTION: + raise ValueError("No FastNLP model named {}.".format(model_name)) + + if not self.model_exist(model_dir=self.model_dir): + self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0]) + + model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1]) + + model_args = ConfigSection() + # To do: customized config file for model init parameters + ConfigLoader.load_config(self.model_dir + "default.cfg", model_args) + + model = model_class(model_args) + + # To do: framework independent + ModelLoader.load_pytorch(model, self.model_dir + model_name) + + self.model = model + + print("Model loaded. ") + + def run(self, infer_input): + """ + Perform inference over given input using the loaded model. + :param infer_input: str, raw text + :return results: + """ + infer = Inference() + data = infer.prepare_input(infer_input) + results = infer.predict(self.model, data) + return results + + @staticmethod + def _get_model_class(file_class_name): + """ + Feature the class specified by + :param file_class_name: str, contains the name of the Python module followed by the name of the class. + Example: "sequence_modeling.SeqLabeling" + :return module: the model class + """ + import_prefix = "fastNLP.models." + parts = (import_prefix + file_class_name).split(".") + from_module = ".".join(parts[:-1]) + module = __import__(from_module) + for sub in parts[1:]: + module = getattr(module, sub) + return module + + def _load(self, model_dir, model_name): + # To do + return 0 + + def _download(self, model_name, url): + """ + Download the model weights from and save in . + :param model_name: + :param url: + """ + print("Downloading {} from {}".format(model_name, url)) + # To do + + def model_exist(self, model_dir): + """ + Check whether the desired model is already in the directory. + :param model_dir: + """ + pass diff --git a/fastNLP/loader/base_loader.py b/fastNLP/loader/base_loader.py index 2863f01f..45a379c1 100644 --- a/fastNLP/loader/base_loader.py +++ b/fastNLP/loader/base_loader.py @@ -17,7 +17,7 @@ class BaseLoader(object): def load_lines(self): with open(self.data_path, "r", encoding="utf=8") as f: text = f.readlines() - return text + return [line.strip() for line in text] class ToyLoader0(BaseLoader): diff --git a/fastNLP/loader/model_loader.py b/fastNLP/loader/model_loader.py index 8224b3f2..1e1d4f8f 100644 --- a/fastNLP/loader/model_loader.py +++ b/fastNLP/loader/model_loader.py @@ -11,9 +11,11 @@ class ModelLoader(BaseLoader): def __init__(self, data_name, data_path): super(ModelLoader, self).__init__(data_name, data_path) - def load_pytorch(self, empty_model): + @staticmethod + def load_pytorch(empty_model, model_path): """ Load model parameters from .pkl files into the empty PyTorch model. :param empty_model: a PyTorch model with initialized parameters. + :param model_path: str, the path to the saved model. """ - empty_model.load_state_dict(torch.load(self.data_path)) + empty_model.load_state_dict(torch.load(model_path)) diff --git a/fastNLP/loader/preprocess.py b/fastNLP/loader/preprocess.py index 7cd91f9c..10e6e763 100644 --- a/fastNLP/loader/preprocess.py +++ b/fastNLP/loader/preprocess.py @@ -1,346 +1,361 @@ -import _pickle -import os - -DEFAULT_PADDING_LABEL = '' # dict index = 0 -DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 -DEFAULT_RESERVED_LABEL = ['', - '', - ''] # dict index = 2~4 - -DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, - DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, - DEFAULT_RESERVED_LABEL[2]: 4} - - -# the first vocab in dict with the index = 5 - - -class BasePreprocess(object): - - def __init__(self, data, pickle_path): - super(BasePreprocess, self).__init__() - self.data = data - self.pickle_path = pickle_path - if not self.pickle_path.endswith('/'): - self.pickle_path = self.pickle_path + '/' - - -class POSPreprocess(BasePreprocess): - """ - This class are used to preprocess the pos datasets. - - """ - - def __init__(self, data, pickle_path="./", train_dev_split=0): - """ - Preprocess pipeline, including building mapping from words to index, from index to words, - from labels/classes to index, from index to labels/classes. - :param data: three-level list - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - :param pickle_path: str, the directory to the pickle files. Default: "./" - :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0. - - To do: - 1. simplify __init__ - """ - super(POSPreprocess, self).__init__(data, pickle_path) - - self.pickle_path = pickle_path - - if self.pickle_exist("word2id.pkl"): - # load word2index because the construction of the following objects needs it - with open(os.path.join(self.pickle_path, "word2id.pkl"), "rb") as f: - self.word2index = _pickle.load(f) - else: - self.word2index, self.label2index = self.build_dict(data) - with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f: - _pickle.dump(self.word2index, f) - - if self.pickle_exist("class2id.pkl"): - with open(os.path.join(self.pickle_path, "class2id.pkl"), "rb") as f: - self.label2index = _pickle.load(f) - else: - with open(os.path.join(self.pickle_path, "class2id.pkl"), "wb") as f: - _pickle.dump(self.label2index, f) - #something will be wrong if word2id.pkl is found but class2id.pkl is not found - - if not self.pickle_exist("id2word.pkl"): - index2word = self.build_reverse_dict(self.word2index) - with open(os.path.join(self.pickle_path, "id2word.pkl"), "wb") as f: - _pickle.dump(index2word, f) - - if not self.pickle_exist("id2class.pkl"): - index2label = self.build_reverse_dict(self.label2index) - with open(os.path.join(self.pickle_path, "word2id.pkl"), "wb") as f: - _pickle.dump(index2label, f) - - if not self.pickle_exist("data_train.pkl"): - data_train = self.to_index(data) - if train_dev_split > 0 and not self.pickle_exist("data_dev.pkl"): - data_dev = data_train[: int(len(data_train) * train_dev_split)] - with open(os.path.join(self.pickle_path, "data_dev.pkl"), "wb") as f: - _pickle.dump(data_dev, f) - with open(os.path.join(self.pickle_path, "data_train.pkl"), "wb") as f: - _pickle.dump(data_train, f) - - def build_dict(self, data): - """ - Add new words with indices into self.word_dict, new labels with indices into self.label_dict. - :param data: three-level list - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - :return word2index: dict of {str, int} - label2index: dict of {str, int} - """ - label2index = {} - word2index = DEFAULT_WORD_TO_INDEX - for example in data: - for word, label in zip(example[0], example[1]): - if word not in word2index: - word2index[word] = len(word2index) - if label not in label2index: - label2index[label] = len(label2index) - return word2index, label2index - - def pickle_exist(self, pickle_name): - """ - :param pickle_name: the filename of target pickle file - :return: True if file exists else False - """ - if not os.path.exists(self.pickle_path): - os.makedirs(self.pickle_path) - file_name = os.path.join(self.pickle_path, pickle_name) - if os.path.exists(file_name): - return True - else: - return False - - def build_reverse_dict(self, word_dict): - id2word = {word_dict[w]: w for w in word_dict} - return id2word - - def to_index(self, data): - """ - Convert word strings and label strings into indices. - :param data: three-level list - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - :return data_index: the shape of data, but each string is replaced by its corresponding index - """ - data_index = [] - for example in data: - word_list = [] - label_list = [] - for word, label in zip(example[0], example[1]): - word_list.append(self.word2index[word]) - label_list.append(self.label2index[label]) - data_index.append([word_list, label_list]) - return data_index - - @property - def vocab_size(self): - return len(self.word2index) - - @property - def num_classes(self): - return len(self.label2index) - - -class ClassPreprocess(BasePreprocess): - """ - Pre-process the classification datasets. - - Params: - pickle_path - directory to save result of pre-processing - Saves: - word2id.pkl - id2word.pkl - class2id.pkl - id2class.pkl - embedding.pkl - data_train.pkl - data_dev.pkl - data_test.pkl - """ - - def __init__(self, pickle_path): - # super(ClassPreprocess, self).__init__(data, pickle_path) - self.word_dict = None - self.label_dict = None - self.pickle_path = pickle_path # save directory - - def process(self, data, save_name): - """ - Process data. - - Params: - data - nested list, data = [sample1, sample2, ...], - sample = [sentence, label], sentence = [word1, word2, ...] - save_name - name of processed data, such as data_train.pkl - Returns: - vocab_size - vocabulary size - n_classes - number of classes - """ - self.build_dict(data) - self.word2id() - vocab_size = self.id2word() - self.class2id() - num_classes = self.id2class() - self.embedding() - self.data_generate(data, save_name) - - return vocab_size, num_classes - - def build_dict(self, data): - """Build vocabulary.""" - - # just read if word2id.pkl and class2id.pkl exists - if self.pickle_exist("word2id.pkl") and \ - self.pickle_exist("class2id.pkl"): - file_name = os.path.join(self.pickle_path, "word2id.pkl") - with open(file_name, 'rb') as f: - self.word_dict = _pickle.load(f) - file_name = os.path.join(self.pickle_path, "class2id.pkl") - with open(file_name, 'rb') as f: - self.label_dict = _pickle.load(f) - return - - # build vocabulary from scratch if nothing exists - self.word_dict = { - DEFAULT_PADDING_LABEL: 0, - DEFAULT_UNKNOWN_LABEL: 1, - DEFAULT_RESERVED_LABEL[0]: 2, - DEFAULT_RESERVED_LABEL[1]: 3, - DEFAULT_RESERVED_LABEL[2]: 4} - self.label_dict = {} - - # collect every word and label - for sent, label in data: - if len(sent) <= 1: - continue - - if label not in self.label_dict: - index = len(self.label_dict) - self.label_dict[label] = index - - for word in sent: - if word not in self.word_dict: - index = len(self.word_dict) - self.word_dict[word[0]] = index - - def pickle_exist(self, pickle_name): - """ - Check whether a pickle file exists. - - Params - pickle_name: the filename of target pickle file - Return - True if file exists else False - """ - if not os.path.exists(self.pickle_path): - os.makedirs(self.pickle_path) - file_name = os.path.join(self.pickle_path, pickle_name) - if os.path.exists(file_name): - return True - else: - return False - - def word2id(self): - """Save vocabulary of {word:id} mapping format.""" - # nothing will be done if word2id.pkl exists - if self.pickle_exist("word2id.pkl"): - return - - file_name = os.path.join(self.pickle_path, "word2id.pkl") - with open(file_name, "wb") as f: - _pickle.dump(self.word_dict, f) - - def id2word(self): - """Save vocabulary of {id:word} mapping format.""" - # nothing will be done if id2word.pkl exists - if self.pickle_exist("id2word.pkl"): - file_name = os.path.join(self.pickle_path, "id2word.pkl") - with open(file_name, 'rb') as f: - id2word_dict = _pickle.load(f) - return len(id2word_dict) - - id2word_dict = {self.word_dict[w]: w for w in self.word_dict} - file_name = os.path.join(self.pickle_path, "id2word.pkl") - with open(file_name, "wb") as f: - _pickle.dump(id2word_dict, f) - return len(id2word_dict) - - def class2id(self): - """Save mapping of {class:id}.""" - # nothing will be done if class2id.pkl exists - if self.pickle_exist("class2id.pkl"): - return - - file_name = os.path.join(self.pickle_path, "class2id.pkl") - with open(file_name, "wb") as f: - _pickle.dump(self.label_dict, f) - - def id2class(self): - """Save mapping of {id:class}.""" - # nothing will be done if id2class.pkl exists - if self.pickle_exist("id2class.pkl"): - file_name = os.path.join(self.pickle_path, "id2class.pkl") - with open(file_name, "rb") as f: - id2class_dict = _pickle.load(f) - return len(id2class_dict) - - id2class_dict = {self.label_dict[c]: c for c in self.label_dict} - file_name = os.path.join(self.pickle_path, "id2class.pkl") - with open(file_name, "wb") as f: - _pickle.dump(id2class_dict, f) - return len(id2class_dict) - - def embedding(self): - """Save embedding lookup table corresponding to vocabulary.""" - # nothing will be done if embedding.pkl exists - if self.pickle_exist("embedding.pkl"): - return - - # retrieve vocabulary from pre-trained embedding (not implemented) - - def data_generate(self, data_src, save_name): - """Convert dataset from text to digit.""" - - # nothing will be done if file exists - save_path = os.path.join(self.pickle_path, save_name) - if os.path.exists(save_path): - return - - data = [] - # for every sample - for sent, label in data_src: - if len(sent) <= 1: - continue - - label_id = self.label_dict[label] # label id - sent_id = [] # sentence ids - for word in sent: - if word in self.word_dict: - sent_id.append(self.word_dict[word]) - else: - sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL]) - data.append([sent_id, label_id]) - - # save data - with open(save_path, "wb") as f: - _pickle.dump(data, f) - - -class LMPreprocess(BasePreprocess): - def __init__(self, data, pickle_path): - super(LMPreprocess, self).__init__(data, pickle_path) +import _pickle +import os + +DEFAULT_PADDING_LABEL = '' # dict index = 0 +DEFAULT_UNKNOWN_LABEL = '' # dict index = 1 +DEFAULT_RESERVED_LABEL = ['', + '', + ''] # dict index = 2~4 + +DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, + DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, + DEFAULT_RESERVED_LABEL[2]: 4} + + +# the first vocab in dict with the index = 5 + +def save_pickle(obj, pickle_path, file_name): + with open(os.path.join(pickle_path, file_name), "wb") as f: + _pickle.dump(obj, f) + print("{} saved. ".format(file_name)) + + +def load_pickle(pickle_path, file_name): + with open(os.path.join(pickle_path, file_name), "rb") as f: + obj = _pickle.load(f) + return obj + + +def pickle_exist(pickle_path, pickle_name): + """ + :param pickle_path: the directory of target pickle file + :param pickle_name: the filename of target pickle file + :return: True if file exists else False + """ + if not os.path.exists(pickle_path): + os.makedirs(pickle_path) + file_name = os.path.join(pickle_path, pickle_name) + if os.path.exists(file_name): + return True + else: + return False + + +class BasePreprocess(object): + + def __init__(self, data, pickle_path): + super(BasePreprocess, self).__init__() + # self.data = data + self.pickle_path = pickle_path + if not self.pickle_path.endswith('/'): + self.pickle_path = self.pickle_path + '/' + + +class POSPreprocess(BasePreprocess): + """ + This class are used to preprocess the POS Tag datasets. + + """ + + def __init__(self, data, pickle_path="./", train_dev_split=0): + """ + Preprocess pipeline, including building mapping from words to index, from index to words, + from labels/classes to index, from index to labels/classes. + :param data: three-level list + [ + [ [word_11, word_12, ...], [label_1, label_1, ...] ], + [ [word_21, word_22, ...], [label_2, label_1, ...] ], + ... + ] + :param pickle_path: str, the directory to the pickle files. Default: "./" + :param train_dev_split: float in [0, 1]. The ratio of dev data split from training data. Default: 0. + + """ + super(POSPreprocess, self).__init__(data, pickle_path) + + self.pickle_path = pickle_path + + if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): + self.word2index = load_pickle(self.pickle_path, "word2id.pkl") + self.label2index = load_pickle(self.pickle_path, "class2id.pkl") + else: + self.word2index, self.label2index = self.build_dict(data) + save_pickle(self.word2index, self.pickle_path, "word2id.pkl") + save_pickle(self.label2index, self.pickle_path, "class2id.pkl") + + if not pickle_exist(pickle_path, "id2word.pkl"): + index2word = self.build_reverse_dict(self.word2index) + save_pickle(index2word, self.pickle_path, "id2word.pkl") + + if not pickle_exist(pickle_path, "id2class.pkl"): + index2label = self.build_reverse_dict(self.label2index) + save_pickle(index2label, self.pickle_path, "id2class.pkl") + + if not pickle_exist(pickle_path, "data_train.pkl"): + data_train = self.to_index(data) + if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): + data_dev = data_train[: int(len(data_train) * train_dev_split)] + save_pickle(data_dev, self.pickle_path, "data_dev.pkl") + save_pickle(data_train, self.pickle_path, "data_train.pkl") + + def build_dict(self, data): + """ + Add new words with indices into self.word_dict, new labels with indices into self.label_dict. + :param data: three-level list + [ + [ [word_11, word_12, ...], [label_1, label_1, ...] ], + [ [word_21, word_22, ...], [label_2, label_1, ...] ], + ... + ] + :return word2index: dict of {str, int} + label2index: dict of {str, int} + """ + label2index = {} + word2index = DEFAULT_WORD_TO_INDEX + for example in data: + for word, label in zip(example[0], example[1]): + if word not in word2index: + word2index[word] = len(word2index) + if label not in label2index: + label2index[label] = len(label2index) + return word2index, label2index + + def build_reverse_dict(self, word_dict): + id2word = {word_dict[w]: w for w in word_dict} + return id2word + + def to_index(self, data): + """ + Convert word strings and label strings into indices. + :param data: three-level list + [ + [ [word_11, word_12, ...], [label_1, label_1, ...] ], + [ [word_21, word_22, ...], [label_2, label_1, ...] ], + ... + ] + :return data_index: the shape of data, but each string is replaced by its corresponding index + """ + data_index = [] + for example in data: + word_list = [] + label_list = [] + for word, label in zip(example[0], example[1]): + word_list.append(self.word2index[word]) + label_list.append(self.label2index[label]) + data_index.append([word_list, label_list]) + return data_index + + @property + def vocab_size(self): + return len(self.word2index) + + @property + def num_classes(self): + return len(self.label2index) + + +class ClassPreprocess(BasePreprocess): + """ + Pre-process the classification datasets. + + Params: + pickle_path - directory to save result of pre-processing + Saves: + word2id.pkl + id2word.pkl + class2id.pkl + id2class.pkl + embedding.pkl + data_train.pkl + data_dev.pkl + data_test.pkl + """ + + def __init__(self, pickle_path): + # super(ClassPreprocess, self).__init__(data, pickle_path) + self.word_dict = None + self.label_dict = None + self.pickle_path = pickle_path # save directory + + def process(self, data, save_name): + """ + Process data. + + Params: + data - nested list, data = [sample1, sample2, ...], + sample = [sentence, label], sentence = [word1, word2, ...] + save_name - name of processed data, such as data_train.pkl + Returns: + vocab_size - vocabulary size + n_classes - number of classes + """ + self.build_dict(data) + self.word2id() + vocab_size = self.id2word() + self.class2id() + num_classes = self.id2class() + self.embedding() + self.data_generate(data, save_name) + + return vocab_size, num_classes + + def build_dict(self, data): + """Build vocabulary.""" + + # just read if word2id.pkl and class2id.pkl exists + if self.pickle_exist("word2id.pkl") and \ + self.pickle_exist("class2id.pkl"): + file_name = os.path.join(self.pickle_path, "word2id.pkl") + with open(file_name, 'rb') as f: + self.word_dict = _pickle.load(f) + file_name = os.path.join(self.pickle_path, "class2id.pkl") + with open(file_name, 'rb') as f: + self.label_dict = _pickle.load(f) + return + + # build vocabulary from scratch if nothing exists + self.word_dict = { + DEFAULT_PADDING_LABEL: 0, + DEFAULT_UNKNOWN_LABEL: 1, + DEFAULT_RESERVED_LABEL[0]: 2, + DEFAULT_RESERVED_LABEL[1]: 3, + DEFAULT_RESERVED_LABEL[2]: 4} + self.label_dict = {} + + # collect every word and label + for sent, label in data: + if len(sent) <= 1: + continue + + if label not in self.label_dict: + index = len(self.label_dict) + self.label_dict[label] = index + + for word in sent: + if word not in self.word_dict: + index = len(self.word_dict) + self.word_dict[word[0]] = index + + def pickle_exist(self, pickle_name): + """ + Check whether a pickle file exists. + + Params + pickle_name: the filename of target pickle file + Return + True if file exists else False + """ + if not os.path.exists(self.pickle_path): + os.makedirs(self.pickle_path) + file_name = os.path.join(self.pickle_path, pickle_name) + if os.path.exists(file_name): + return True + else: + return False + + def word2id(self): + """Save vocabulary of {word:id} mapping format.""" + # nothing will be done if word2id.pkl exists + if self.pickle_exist("word2id.pkl"): + return + + file_name = os.path.join(self.pickle_path, "word2id.pkl") + with open(file_name, "wb") as f: + _pickle.dump(self.word_dict, f) + + def id2word(self): + """Save vocabulary of {id:word} mapping format.""" + # nothing will be done if id2word.pkl exists + if self.pickle_exist("id2word.pkl"): + file_name = os.path.join(self.pickle_path, "id2word.pkl") + with open(file_name, 'rb') as f: + id2word_dict = _pickle.load(f) + return len(id2word_dict) + + id2word_dict = {self.word_dict[w]: w for w in self.word_dict} + file_name = os.path.join(self.pickle_path, "id2word.pkl") + with open(file_name, "wb") as f: + _pickle.dump(id2word_dict, f) + return len(id2word_dict) + + def class2id(self): + """Save mapping of {class:id}.""" + # nothing will be done if class2id.pkl exists + if self.pickle_exist("class2id.pkl"): + return + + file_name = os.path.join(self.pickle_path, "class2id.pkl") + with open(file_name, "wb") as f: + _pickle.dump(self.label_dict, f) + + def id2class(self): + """Save mapping of {id:class}.""" + # nothing will be done if id2class.pkl exists + if self.pickle_exist("id2class.pkl"): + file_name = os.path.join(self.pickle_path, "id2class.pkl") + with open(file_name, "rb") as f: + id2class_dict = _pickle.load(f) + return len(id2class_dict) + + id2class_dict = {self.label_dict[c]: c for c in self.label_dict} + file_name = os.path.join(self.pickle_path, "id2class.pkl") + with open(file_name, "wb") as f: + _pickle.dump(id2class_dict, f) + return len(id2class_dict) + + def embedding(self): + """Save embedding lookup table corresponding to vocabulary.""" + # nothing will be done if embedding.pkl exists + if self.pickle_exist("embedding.pkl"): + return + + # retrieve vocabulary from pre-trained embedding (not implemented) + + def data_generate(self, data_src, save_name): + """Convert dataset from text to digit.""" + + # nothing will be done if file exists + save_path = os.path.join(self.pickle_path, save_name) + if os.path.exists(save_path): + return + + data = [] + # for every sample + for sent, label in data_src: + if len(sent) <= 1: + continue + + label_id = self.label_dict[label] # label id + sent_id = [] # sentence ids + for word in sent: + if word in self.word_dict: + sent_id.append(self.word_dict[word]) + else: + sent_id.append(self.word_dict[DEFAULT_UNKNOWN_LABEL]) + data.append([sent_id, label_id]) + + # save data + with open(save_path, "wb") as f: + _pickle.dump(data, f) + + +class LMPreprocess(BasePreprocess): + def __init__(self, data, pickle_path): + super(LMPreprocess, self).__init__(data, pickle_path) + + +def infer_preprocess(pickle_path, data): + """ + Preprocess over inference data. + Transform three-level list of strings into that of index. + [ + [word_11, word_12, ...], + [word_21, word_22, ...], + ... + ] + """ + word2index = load_pickle(pickle_path, "word2id.pkl") + data_index = [] + for example in data: + data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example]) + return data_index diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index ca877261..e37109bb 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -9,17 +9,12 @@ class SeqLabeling(BaseModel): PyTorch Network for sequence labeling """ - def __init__(self, hidden_dim, - rnn_num_layer, - num_classes, - vocab_size, - word_emb_dim=100, - init_emb=None, - rnn_mode="gru", - bi_direction=False, - dropout=0.5, - use_crf=True): + def __init__(self, args): super(SeqLabeling, self).__init__() + vocab_size = args["vocab_size"] + word_emb_dim = args["word_emb_dim"] + hidden_dim = args["rnn_hidden_units"] + num_classes = args["num_classes"] self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim) self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) @@ -29,7 +24,7 @@ class SeqLabeling(BaseModel): def forward(self, x): """ :param x: LongTensor, [batch_size, mex_len] - :return y: [batch_size, tag_size, tag_size] + :return y: [batch_size, mex_len, tag_size] """ x = self.Embedding(x) # [batch_size, max_len, word_emb_dim] @@ -64,7 +59,7 @@ class SeqLabeling(BaseModel): def prediction(self, x, seq_length): """ - :param x: FloatTensor, [batch_size, tag_size, tag_size] + :param x: FloatTensor, [batch_size, max_len, tag_size] :param seq_length: int :return prediction: list of tuple of (decode path(list), best score) """ diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 6034c48d..bed6c276 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -13,7 +13,7 @@ class Lstm(nn.Module): bidirectional : If True, becomes a bidirectional RNN. Default: False. """ - def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.5, bidirectional=False): + def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False): super(Lstm, self).__init__() self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, dropout=dropout, bidirectional=bidirectional) diff --git a/test/data_for_tests/config b/test/data_for_tests/config index 181d0ebf..fad9d876 100644 --- a/test/data_for_tests/config +++ b/test/data_for_tests/config @@ -74,3 +74,9 @@ save_dev_input = false save_loss = true batch_size = 1 pickle_path = "./data_for_tests/" +rnn_hidden_units = 100 +rnn_layers = 1 +rnn_bi_direction = true +word_emb_dim = 100 +dropout = 0.5 +use_crf = true diff --git a/test/data_for_tests/people_infer.txt b/test/data_for_tests/people_infer.txt new file mode 100644 index 00000000..639ea413 --- /dev/null +++ b/test/data_for_tests/people_infer.txt @@ -0,0 +1,2 @@ +迈向充满希望的新世纪——一九九八年新年讲话 +(附图片1张) \ No newline at end of file diff --git a/test/test_POS_pipeline.py b/test/test_POS_pipeline.py index 17b1b58c..fdf5de3e 100644 --- a/test/test_POS_pipeline.py +++ b/test/test_POS_pipeline.py @@ -4,8 +4,8 @@ sys.path.append("..") from fastNLP.loader.config_loader import ConfigLoader, ConfigSection from fastNLP.action.trainer import POSTrainer -from fastNLP.loader.dataset_loader import POSDatasetLoader -from fastNLP.loader.preprocess import POSPreprocess +from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader +from fastNLP.loader.preprocess import POSPreprocess, load_pickle from fastNLP.saver.model_saver import ModelSaver from fastNLP.loader.model_loader import ModelLoader from fastNLP.action.tester import POSTester @@ -15,32 +15,49 @@ from fastNLP.action.inference import Inference data_name = "people.txt" data_path = "data_for_tests/people.txt" pickle_path = "data_for_tests" +data_infer_path = "data_for_tests/people_infer.txt" -def test_infer(): +def infer(): + # Load infer configuration, the same as test + test_args = ConfigSection() + ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args}) + + # fetch dictinary size and number of labels from pickle files + word2index = load_pickle(pickle_path, "word2id.pkl") + test_args["vocab_size"] = len(word2index) + index2label = load_pickle(pickle_path, "id2class.pkl") + test_args["num_classes"] = len(index2label) + # Define the same model - model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"], - num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"], - word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"], - rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"]) + model = SeqLabeling(test_args) # Dump trained parameters into the model - ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model) + ModelLoader.load_pytorch(model, "./saved_model.pkl") print("model loaded!") # Data Loader - pos_loader = POSDatasetLoader(data_name, data_path) - infer_data = pos_loader.load_lines() - - # Preprocessor - POSPreprocess(infer_data, pickle_path) + raw_data_loader = BaseLoader(data_name, data_infer_path) + infer_data = raw_data_loader.load_lines() + """ + Transform strings into list of list of strings. + [ + [word_11, word_12, ...], + [word_21, word_22, ...], + ... + ] + In this case, each line in "people_infer.txt" is already a sentence. So load_lines() just splits them. + """ # Inference interface - infer = Inference() + infer = Inference(pickle_path) results = infer.predict(model, infer_data) + print(results) + print("Inference finished!") -if __name__ == "__main__": + +def train_test(): # Config Loader train_args = ConfigSection() ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) @@ -58,10 +75,7 @@ if __name__ == "__main__": trainer = POSTrainer(train_args) # Model - model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"], - num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"], - word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"], - rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"]) + model = SeqLabeling(train_args) # Start training trainer.train(model) @@ -75,13 +89,10 @@ if __name__ == "__main__": del model, trainer, pos_loader # Define the same model - model = SeqLabeling(hidden_dim=train_args["rnn_hidden_units"], rnn_num_layer=train_args["rnn_layers"], - num_classes=train_args["num_classes"], vocab_size=train_args["vocab_size"], - word_emb_dim=train_args["word_emb_dim"], bi_direction=train_args["rnn_bi_direction"], - rnn_mode="gru", dropout=train_args["dropout"], use_crf=train_args["use_crf"]) + model = SeqLabeling(train_args) # Dump trained parameters into the model - ModelLoader("arbitrary_name", "./saved_model.pkl").load_pytorch(model) + ModelLoader.load_pytorch(model, "./saved_model.pkl") print("model loaded!") # Load test configuration @@ -97,3 +108,7 @@ if __name__ == "__main__": # print test results print(tester.show_matrices()) print("model tested!") + + +if __name__ == "__main__": + infer()