| @@ -1 +0,0 @@ | |||||
| @@ -4,88 +4,6 @@ import numpy as np | |||||
| import torch | import torch | ||||
| class Action(object): | |||||
| """Operations shared by Trainer, Tester, or Inference. | |||||
| This is designed for reducing replicate codes. | |||||
| - make_batch: produce a min-batch of data. @staticmethod | |||||
| - pad: padding method used in sequence modeling. @staticmethod | |||||
| - mode: change network mode for either train or test. (for PyTorch) @staticmethod | |||||
| """ | |||||
| def __init__(self): | |||||
| super(Action, self).__init__() | |||||
| @staticmethod | |||||
| def make_batch(iterator, use_cuda, output_length=True, max_len=None): | |||||
| """Batch and Pad data. | |||||
| :param iterator: an iterator, (object that implements __next__ method) which returns the next sample. | |||||
| :param use_cuda: bool, whether to use GPU | |||||
| :param output_length: bool, whether to output the original length of the sequence before padding. (default: True) | |||||
| :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None) | |||||
| :return : | |||||
| if output_length is True, | |||||
| (batch_x, seq_len): tuple of two elements | |||||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||||
| seq_len: list. The length of the pre-padded sequence, if output_length is True. | |||||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||||
| if output_length is False, | |||||
| batch_x: list. Each entry is a list of features of a sample. [batch_size, max_len] | |||||
| batch_y: list. Each entry is a list of labels of a sample. [batch_size, num_labels] | |||||
| """ | |||||
| for batch in iterator: | |||||
| batch_x = [sample[0] for sample in batch] | |||||
| batch_y = [sample[1] for sample in batch] | |||||
| batch_x = Action.pad(batch_x) | |||||
| # pad batch_y only if it is a 2-level list | |||||
| if len(batch_y) > 0 and isinstance(batch_y[0], list): | |||||
| batch_y = Action.pad(batch_y) | |||||
| # convert list to tensor | |||||
| batch_x = convert_to_torch_tensor(batch_x, use_cuda) | |||||
| batch_y = convert_to_torch_tensor(batch_y, use_cuda) | |||||
| # trim data to max_len | |||||
| if max_len is not None and batch_x.size(1) > max_len: | |||||
| batch_x = batch_x[:, :max_len] | |||||
| if output_length: | |||||
| seq_len = [len(x) for x in batch_x] | |||||
| yield (batch_x, seq_len), batch_y | |||||
| else: | |||||
| yield batch_x, batch_y | |||||
| @staticmethod | |||||
| def pad(batch, fill=0): | |||||
| """ Pad a mini-batch of sequence samples to maximum length of this batch. | |||||
| :param batch: list of list | |||||
| :param fill: word index to pad, default 0. | |||||
| :return batch: a padded mini-batch | |||||
| """ | |||||
| max_length = max([len(x) for x in batch]) | |||||
| for idx, sample in enumerate(batch): | |||||
| if len(sample) < max_length: | |||||
| batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||||
| return batch | |||||
| @staticmethod | |||||
| def mode(model, is_test=False): | |||||
| """Train mode or Test mode. This is for PyTorch currently. | |||||
| :param model: a PyTorch model | |||||
| :param is_test: bool, whether in test mode or not. | |||||
| """ | |||||
| if is_test: | |||||
| model.eval() | |||||
| else: | |||||
| model.train() | |||||
| def convert_to_torch_tensor(data_list, use_cuda): | def convert_to_torch_tensor(data_list, use_cuda): | ||||
| """Convert lists into (cuda) Tensors. | """Convert lists into (cuda) Tensors. | ||||
| @@ -168,19 +86,7 @@ class BaseSampler(object): | |||||
| """ | """ | ||||
| def __init__(self, data_set): | |||||
| """ | |||||
| :param data_set: multi-level list, of shape [num_example, *] | |||||
| """ | |||||
| self.data_set_length = len(data_set) | |||||
| self.data = data_set | |||||
| def __len__(self): | |||||
| return self.data_set_length | |||||
| def __iter__(self): | |||||
| def __call__(self, *args, **kwargs): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -189,16 +95,8 @@ class SequentialSampler(BaseSampler): | |||||
| """ | """ | ||||
| def __init__(self, data_set): | |||||
| """ | |||||
| :param data_set: multi-level list | |||||
| """ | |||||
| super(SequentialSampler, self).__init__(data_set) | |||||
| def __iter__(self): | |||||
| return iter(self.data) | |||||
| def __call__(self, data_set): | |||||
| return list(range(len(data_set))) | |||||
| class RandomSampler(BaseSampler): | class RandomSampler(BaseSampler): | ||||
| @@ -206,17 +104,9 @@ class RandomSampler(BaseSampler): | |||||
| """ | """ | ||||
| def __init__(self, data_set): | |||||
| """ | |||||
| :param data_set: multi-level list | |||||
| def __call__(self, data_set): | |||||
| return list(np.random.permutation(len(data_set))) | |||||
| """ | |||||
| super(RandomSampler, self).__init__(data_set) | |||||
| self.order = np.random.permutation(self.data_set_length) | |||||
| def __iter__(self): | |||||
| return iter((self.data[idx] for idx in self.order)) | |||||
| class Batchifier(object): | class Batchifier(object): | ||||
| @@ -252,6 +142,7 @@ class BucketBatchifier(Batchifier): | |||||
| """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. | """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. | ||||
| In sampling, first random choose a bucket. Then sample data from it. | In sampling, first random choose a bucket. Then sample data from it. | ||||
| The number of buckets is decided dynamically by the variance of sentence lengths. | The number of buckets is decided dynamically by the variance of sentence lengths. | ||||
| TODO: merge it into Batch | |||||
| """ | """ | ||||
| def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None): | def __init__(self, data_set, batch_size, num_buckets, drop_last=True, sampler=None): | ||||
| @@ -0,0 +1,126 @@ | |||||
| from collections import defaultdict | |||||
| import torch | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField, LabelField | |||||
| from fastNLP.core.instance import Instance | |||||
| class Batch(object): | |||||
| """Batch is an iterable object which iterates over mini-batches. | |||||
| :: | |||||
| for batch_x, batch_y in Batch(data_set): | |||||
| """ | |||||
| def __init__(self, dataset, batch_size, sampler, use_cuda): | |||||
| self.dataset = dataset | |||||
| self.batch_size = batch_size | |||||
| self.sampler = sampler | |||||
| self.use_cuda = use_cuda | |||||
| self.idx_list = None | |||||
| self.curidx = 0 | |||||
| def __iter__(self): | |||||
| self.idx_list = self.sampler(self.dataset) | |||||
| self.curidx = 0 | |||||
| self.lengths = self.dataset.get_length() | |||||
| return self | |||||
| def __next__(self): | |||||
| """ | |||||
| :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
| batch_x also contains an item (str: list of int) about origin lengths, | |||||
| which means ("field_name_origin_len": origin lengths). | |||||
| E.g. | |||||
| :: | |||||
| {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) | |||||
| batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||||
| All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. | |||||
| The names of fields are defined in preprocessor's convert_to_dataset method. | |||||
| """ | |||||
| if self.curidx >= len(self.idx_list): | |||||
| raise StopIteration | |||||
| else: | |||||
| endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | |||||
| padding_length = {field_name: max(field_length[self.curidx: endidx]) | |||||
| for field_name, field_length in self.lengths.items()} | |||||
| origin_lengths = {field_name: field_length[self.curidx: endidx] | |||||
| for field_name, field_length in self.lengths.items()} | |||||
| batch_x, batch_y = defaultdict(list), defaultdict(list) | |||||
| for idx in range(self.curidx, endidx): | |||||
| x, y = self.dataset.to_tensor(idx, padding_length) | |||||
| for name, tensor in x.items(): | |||||
| batch_x[name].append(tensor) | |||||
| for name, tensor in y.items(): | |||||
| batch_y[name].append(tensor) | |||||
| batch_origin_length = {} | |||||
| # combine instances into a batch | |||||
| for batch in (batch_x, batch_y): | |||||
| for name, tensor_list in batch.items(): | |||||
| if self.use_cuda: | |||||
| batch[name] = torch.stack(tensor_list, dim=0).cuda() | |||||
| else: | |||||
| batch[name] = torch.stack(tensor_list, dim=0) | |||||
| # add origin lengths in batch_x | |||||
| for name, tensor in batch_x.items(): | |||||
| if self.use_cuda: | |||||
| batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda() | |||||
| else: | |||||
| batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]) | |||||
| batch_x.update(batch_origin_length) | |||||
| self.curidx += endidx | |||||
| return batch_x, batch_y | |||||
| if __name__ == "__main__": | |||||
| """simple running example | |||||
| """ | |||||
| texts = ["i am a cat", | |||||
| "this is a test of new batch", | |||||
| "haha" | |||||
| ] | |||||
| labels = [0, 1, 0] | |||||
| # prepare vocabulary | |||||
| vocab = {} | |||||
| for text in texts: | |||||
| for tokens in text.split(): | |||||
| if tokens not in vocab: | |||||
| vocab[tokens] = len(vocab) | |||||
| print("vocabulary: ", vocab) | |||||
| # prepare input dataset | |||||
| data = DataSet() | |||||
| for text, label in zip(texts, labels): | |||||
| x = TextField(text.split(), False) | |||||
| y = LabelField(label, is_target=True) | |||||
| ins = Instance(text=x, label=y) | |||||
| data.append(ins) | |||||
| # use vocabulary to index data | |||||
| data.index_field("text", vocab) | |||||
| # define naive sampler for batch class | |||||
| class SeqSampler: | |||||
| def __call__(self, dataset): | |||||
| return list(range(len(dataset))) | |||||
| # use batch to iterate dataset | |||||
| data_iterator = Batch(data, 2, SeqSampler(), False) | |||||
| for epoch in range(1): | |||||
| for batch_x, batch_y in data_iterator: | |||||
| print(batch_x) | |||||
| print(batch_y) | |||||
| # do stuff | |||||
| @@ -0,0 +1,111 @@ | |||||
| from collections import defaultdict | |||||
| from fastNLP.core.field import TextField | |||||
| from fastNLP.core.instance import Instance | |||||
| def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None): | |||||
| if has_target is True: | |||||
| if label_vocab is None: | |||||
| raise RuntimeError("Must provide label vocabulary to transform labels.") | |||||
| return create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab) | |||||
| else: | |||||
| return create_unlabeled_dataset_from_lists(str_lists, word_vocab) | |||||
| def create_labeled_dataset_from_lists(str_lists, word_vocab, label_vocab): | |||||
| """Create an DataSet instance that contains labels. | |||||
| :param str_lists: list of list of strings, [num_examples, 2, *]. | |||||
| :: | |||||
| [ | |||||
| [[word_11, word_12, ...], [label_11, label_12, ...]], | |||||
| ... | |||||
| ] | |||||
| :param word_vocab: dict of (str: int), which means (word: index). | |||||
| :param label_vocab: dict of (str: int), which means (word: index). | |||||
| :return data_set: a DataSet instance. | |||||
| """ | |||||
| data_set = DataSet() | |||||
| for example in str_lists: | |||||
| word_seq, label_seq = example[0], example[1] | |||||
| x = TextField(word_seq, is_target=False) | |||||
| y = TextField(label_seq, is_target=True) | |||||
| data_set.append(Instance(word_seq=x, label_seq=y)) | |||||
| data_set.index_field("word_seq", word_vocab) | |||||
| data_set.index_field("label_seq", label_vocab) | |||||
| return data_set | |||||
| def create_unlabeled_dataset_from_lists(str_lists, word_vocab): | |||||
| """Create an DataSet instance that contains no labels. | |||||
| :param str_lists: list of list of strings, [num_examples, *]. | |||||
| :: | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| ... | |||||
| ] | |||||
| :param word_vocab: dict of (str: int), which means (word: index). | |||||
| :return data_set: a DataSet instance. | |||||
| """ | |||||
| data_set = DataSet() | |||||
| for word_seq in str_lists: | |||||
| x = TextField(word_seq, is_target=False) | |||||
| data_set.append(Instance(word_seq=x)) | |||||
| data_set.index_field("word_seq", word_vocab) | |||||
| return data_set | |||||
| class DataSet(list): | |||||
| """A DataSet object is a list of Instance objects. | |||||
| """ | |||||
| def __init__(self, name="", instances=None): | |||||
| """ | |||||
| :param name: str, the name of the dataset. (default: "") | |||||
| :param instances: list of Instance objects. (default: None) | |||||
| """ | |||||
| list.__init__([]) | |||||
| self.name = name | |||||
| if instances is not None: | |||||
| self.extend(instances) | |||||
| def index_all(self, vocab): | |||||
| for ins in self: | |||||
| ins.index_all(vocab) | |||||
| def index_field(self, field_name, vocab): | |||||
| for ins in self: | |||||
| ins.index_field(field_name, vocab) | |||||
| def to_tensor(self, idx: int, padding_length: dict): | |||||
| """Convert an instance in a dataset to tensor. | |||||
| :param idx: int, the index of the instance in the dataset. | |||||
| :param padding_length: int | |||||
| :return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
| tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
| """ | |||||
| ins = self[idx] | |||||
| return ins.to_tensor(padding_length) | |||||
| def get_length(self): | |||||
| """Fetch lengths of all fields in all instances in a dataset. | |||||
| :return lengths: dict of (str: list). The str is the field name. | |||||
| The list contains lengths of this field in all instances. | |||||
| """ | |||||
| lengths = defaultdict(list) | |||||
| for ins in self: | |||||
| for field_name, field_length in ins.get_length().items(): | |||||
| lengths[field_name].append(field_length) | |||||
| return lengths | |||||
| @@ -0,0 +1,93 @@ | |||||
| import torch | |||||
| class Field(object): | |||||
| """A field defines a data type. | |||||
| """ | |||||
| def __init__(self, is_target: bool): | |||||
| self.is_target = is_target | |||||
| def index(self, vocab): | |||||
| raise NotImplementedError | |||||
| def get_length(self): | |||||
| raise NotImplementedError | |||||
| def to_tensor(self, padding_length): | |||||
| raise NotImplementedError | |||||
| class TextField(Field): | |||||
| def __init__(self, text, is_target): | |||||
| """ | |||||
| :param text: list of strings | |||||
| :param is_target: bool | |||||
| """ | |||||
| super(TextField, self).__init__(is_target) | |||||
| self.text = text | |||||
| self._index = None | |||||
| def index(self, vocab): | |||||
| if self._index is None: | |||||
| self._index = [vocab[c] for c in self.text] | |||||
| else: | |||||
| raise RuntimeError("Replicate indexing of this field.") | |||||
| return self._index | |||||
| def get_length(self): | |||||
| """Fetch the length of the text field. | |||||
| :return length: int, the length of the text. | |||||
| """ | |||||
| return len(self.text) | |||||
| def to_tensor(self, padding_length: int): | |||||
| """Convert text field to tensor. | |||||
| :param padding_length: int | |||||
| :return tensor: torch.LongTensor, of shape [padding_length, ] | |||||
| """ | |||||
| pads = [] | |||||
| if self._index is None: | |||||
| raise RuntimeError("Indexing not done before to_tensor in TextField.") | |||||
| if padding_length > self.get_length(): | |||||
| pads = [0] * (padding_length - self.get_length()) | |||||
| return torch.LongTensor(self._index + pads) | |||||
| class LabelField(Field): | |||||
| def __init__(self, label, is_target=True): | |||||
| super(LabelField, self).__init__(is_target) | |||||
| self.label = label | |||||
| self._index = None | |||||
| def get_length(self): | |||||
| """Fetch the length of the label field. | |||||
| :return length: int, the length of the label, always 1. | |||||
| """ | |||||
| return 1 | |||||
| def index(self, vocab): | |||||
| if self._index is None: | |||||
| self._index = vocab[self.label] | |||||
| return self._index | |||||
| def to_tensor(self, padding_length): | |||||
| if self._index is None: | |||||
| if isinstance(self.label, int): | |||||
| return torch.LongTensor([self.label]) | |||||
| elif isinstance(self.label, str): | |||||
| raise RuntimeError("Field {} not indexed. Call index method.".format(self.label)) | |||||
| else: | |||||
| raise RuntimeError( | |||||
| "Not support type for LabelField. Expect str or int, got {}.".format(type(self.label))) | |||||
| else: | |||||
| return torch.LongTensor([self._index]) | |||||
| if __name__ == "__main__": | |||||
| tf = TextField("test the code".split(), is_target=False) | |||||
| @@ -0,0 +1,53 @@ | |||||
| class Instance(object): | |||||
| """An instance which consists of Fields is an example in the DataSet. | |||||
| """ | |||||
| def __init__(self, **fields): | |||||
| self.fields = fields | |||||
| self.has_index = False | |||||
| self.indexes = {} | |||||
| def add_field(self, field_name, field): | |||||
| self.fields[field_name] = field | |||||
| def get_length(self): | |||||
| """Fetch the length of all fields in the instance. | |||||
| :return length: dict of (str: int), which means (field name: field length). | |||||
| """ | |||||
| length = {name: field.get_length() for name, field in self.fields.items()} | |||||
| return length | |||||
| def index_field(self, field_name, vocab): | |||||
| """use `vocab` to index certain field | |||||
| """ | |||||
| self.indexes[field_name] = self.fields[field_name].index(vocab) | |||||
| def index_all(self, vocab): | |||||
| """use `vocab` to index all fields | |||||
| """ | |||||
| if self.has_index: | |||||
| print("error") | |||||
| return self.indexes | |||||
| indexes = {name: field.index(vocab) for name, field in self.fields.items()} | |||||
| self.indexes = indexes | |||||
| return indexes | |||||
| def to_tensor(self, padding_length: dict): | |||||
| """Convert instance to tensor. | |||||
| :param padding_length: dict of (str: int), which means (field name: padding_length of this field) | |||||
| :return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
| tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||||
| If is_target is False for all fields, tensor_y would be an empty dict. | |||||
| """ | |||||
| tensor_x = {} | |||||
| tensor_y = {} | |||||
| for name, field in self.fields.items(): | |||||
| if field.is_target: | |||||
| tensor_y[name] = field.to_tensor(padding_length[name]) | |||||
| else: | |||||
| tensor_x[name] = field.to_tensor(padding_length[name]) | |||||
| return tensor_x, tensor_y | |||||
| @@ -37,5 +37,7 @@ class Loss(object): | |||||
| """ | """ | ||||
| if loss_name == "cross_entropy": | if loss_name == "cross_entropy": | ||||
| return torch.nn.CrossEntropyLoss() | return torch.nn.CrossEntropyLoss() | ||||
| elif loss_name == 'nll': | |||||
| return torch.nn.NLLLoss() | |||||
| else: | else: | ||||
| raise NotImplementedError | raise NotImplementedError | ||||
| @@ -1,53 +1,10 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from fastNLP.core.action import Batchifier, SequentialSampler | |||||
| from fastNLP.core.action import convert_to_torch_tensor | |||||
| from fastNLP.core.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | |||||
| from fastNLP.modules import utils | |||||
| def make_batch(iterator, use_cuda, output_length=False, max_len=None, min_len=None): | |||||
| """Batch and Pad data, only for Inference. | |||||
| :param iterator: An iterable object that returns a list of indices representing a mini-batch of samples. | |||||
| :param use_cuda: bool, whether to use GPU | |||||
| :param output_length: bool, whether to output the original length of the sequence before padding. (default: False) | |||||
| :param max_len: int, maximum sequence length. Longer sequences will be clipped. (default: None) | |||||
| :param min_len: int, minimum sequence length. Shorter sequences will be padded. (default: None) | |||||
| :return: | |||||
| """ | |||||
| for batch_x in iterator: | |||||
| batch_x = pad(batch_x) | |||||
| # convert list to tensor | |||||
| batch_x = convert_to_torch_tensor(batch_x, use_cuda) | |||||
| # trim data to max_len | |||||
| if max_len is not None and batch_x.size(1) > max_len: | |||||
| batch_x = batch_x[:, :max_len] | |||||
| if min_len is not None and batch_x.size(1) < min_len: | |||||
| pad_tensor = torch.zeros(batch_x.size(0), min_len - batch_x.size(1)).to(batch_x) | |||||
| batch_x = torch.cat((batch_x, pad_tensor), 1) | |||||
| if output_length: | |||||
| seq_len = [len(x) for x in batch_x] | |||||
| yield tuple([batch_x, seq_len]) | |||||
| else: | |||||
| yield batch_x | |||||
| def pad(batch, fill=0): | |||||
| """ Pad a mini-batch of sequence samples to maximum length of this batch. | |||||
| :param batch: list of list | |||||
| :param fill: word index to pad, default 0. | |||||
| :return batch: a padded mini-batch | |||||
| """ | |||||
| max_length = max([len(x) for x in batch]) | |||||
| for idx, sample in enumerate(batch): | |||||
| if len(sample) < max_length: | |||||
| batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||||
| return batch | |||||
| from fastNLP.core.action import SequentialSampler | |||||
| from fastNLP.core.batch import Batch | |||||
| from fastNLP.core.dataset import create_dataset_from_lists | |||||
| from fastNLP.core.preprocess import load_pickle | |||||
| class Predictor(object): | class Predictor(object): | ||||
| @@ -59,11 +16,17 @@ class Predictor(object): | |||||
| Currently, Predictor does not support GPU. | Currently, Predictor does not support GPU. | ||||
| """ | """ | ||||
| def __init__(self, pickle_path): | |||||
| def __init__(self, pickle_path, task): | |||||
| """ | |||||
| :param pickle_path: str, the path to the pickle files. | |||||
| :param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify"). | |||||
| """ | |||||
| self.batch_size = 1 | self.batch_size = 1 | ||||
| self.batch_output = [] | self.batch_output = [] | ||||
| self.iterator = None | |||||
| self.pickle_path = pickle_path | self.pickle_path = pickle_path | ||||
| self._task = task # one of ("seq_label", "text_classify") | |||||
| self.index2label = load_pickle(self.pickle_path, "id2class.pkl") | self.index2label = load_pickle(self.pickle_path, "id2class.pkl") | ||||
| self.word2index = load_pickle(self.pickle_path, "word2id.pkl") | self.word2index = load_pickle(self.pickle_path, "word2id.pkl") | ||||
| @@ -71,19 +34,19 @@ class Predictor(object): | |||||
| """Perform inference using the trained model. | """Perform inference using the trained model. | ||||
| :param network: a PyTorch model (cpu) | :param network: a PyTorch model (cpu) | ||||
| :param data: list of list of strings | |||||
| :param data: list of list of strings, [num_examples, seq_len] | |||||
| :return: list of list of strings, [num_examples, tag_seq_length] | :return: list of list of strings, [num_examples, tag_seq_length] | ||||
| """ | """ | ||||
| # transform strings into indices | |||||
| # transform strings into DataSet object | |||||
| data = self.prepare_input(data) | data = self.prepare_input(data) | ||||
| # turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
| self.mode(network, test=True) | self.mode(network, test=True) | ||||
| self.batch_output.clear() | self.batch_output.clear() | ||||
| data_iterator = iter(Batchifier(SequentialSampler(data), self.batch_size, drop_last=False)) | |||||
| data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) | |||||
| for batch_x in self.make_batch(data_iterator, use_cuda=False): | |||||
| for batch_x, _ in data_iterator: | |||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
| @@ -99,103 +62,61 @@ class Predictor(object): | |||||
| def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
| """Forward through network.""" | """Forward through network.""" | ||||
| raise NotImplementedError | |||||
| def make_batch(self, iterator, use_cuda): | |||||
| raise NotImplementedError | |||||
| y = network(**x) | |||||
| if self._task == "seq_label": | |||||
| y = network.prediction(y) | |||||
| return y | |||||
| def prepare_input(self, data): | def prepare_input(self, data): | ||||
| """Transform two-level list of strings into that of index. | |||||
| """Transform two-level list of strings into an DataSet object. | |||||
| In the training pipeline, this is done by Preprocessor. But in inference time, we do not call Preprocessor. | |||||
| :param data: | |||||
| :param data: list of list of strings. | |||||
| :: | |||||
| [ | [ | ||||
| [word_11, word_12, ...], | [word_11, word_12, ...], | ||||
| [word_21, word_22, ...], | [word_21, word_22, ...], | ||||
| ... | ... | ||||
| ] | ] | ||||
| :return data_index: list of list of int. | |||||
| :return data_set: a DataSet instance. | |||||
| """ | """ | ||||
| assert isinstance(data, list) | assert isinstance(data, list) | ||||
| data_index = [] | |||||
| default_unknown_index = self.word2index[DEFAULT_UNKNOWN_LABEL] | |||||
| for example in data: | |||||
| data_index.append([self.word2index.get(w, default_unknown_index) for w in example]) | |||||
| return data_index | |||||
| return create_dataset_from_lists(data, self.word2index, has_target=False) | |||||
| def prepare_output(self, data): | def prepare_output(self, data): | ||||
| """Transform list of batch outputs into strings.""" | """Transform list of batch outputs into strings.""" | ||||
| raise NotImplementedError | |||||
| class SeqLabelInfer(Predictor): | |||||
| """ | |||||
| Inference on sequence labeling models. | |||||
| """ | |||||
| def __init__(self, pickle_path): | |||||
| super(SeqLabelInfer, self).__init__(pickle_path) | |||||
| if self._task == "seq_label": | |||||
| return self._seq_label_prepare_output(data) | |||||
| elif self._task == "text_classify": | |||||
| return self._text_classify_prepare_output(data) | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}".format(self._task)) | |||||
| def data_forward(self, network, inputs): | |||||
| """ | |||||
| This is only for sequence labeling with CRF decoder. | |||||
| :param network: a PyTorch model | |||||
| :param inputs: tuple of (x, seq_len) | |||||
| x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch | |||||
| after padding. | |||||
| seq_len: list of int, the lengths of sequences before padding. | |||||
| :return prediction: Tensor of shape [batch_size, max_len] | |||||
| """ | |||||
| if not isinstance(inputs[1], list) and isinstance(inputs[0], list): | |||||
| raise RuntimeError("output_length must be true for sequence modeling.") | |||||
| # unpack the returned value from make_batch | |||||
| x, seq_len = inputs[0], inputs[1] | |||||
| batch_size, max_len = x.size(0), x.size(1) | |||||
| mask = utils.seq_mask(seq_len, max_len) | |||||
| mask = mask.byte().view(batch_size, max_len) | |||||
| y = network(x) | |||||
| prediction = network.prediction(y, mask) | |||||
| return torch.Tensor(prediction) | |||||
| def make_batch(self, iterator, use_cuda): | |||||
| return make_batch(iterator, use_cuda, output_length=True) | |||||
| def prepare_output(self, batch_outputs): | |||||
| """Transform list of batch outputs into strings. | |||||
| :param batch_outputs: list of 2-D Tensor, shape [num_batch, batch-size, tag_seq_length]. | |||||
| :return results: 2-D list of strings, shape [num_examples, tag_seq_length] | |||||
| """ | |||||
| def _seq_label_prepare_output(self, batch_outputs): | |||||
| results = [] | results = [] | ||||
| for batch in batch_outputs: | for batch in batch_outputs: | ||||
| for example in np.array(batch): | for example in np.array(batch): | ||||
| results.append([self.index2label[int(x)] for x in example]) | results.append([self.index2label[int(x)] for x in example]) | ||||
| return results | return results | ||||
| class ClassificationInfer(Predictor): | |||||
| """ | |||||
| Inference on Classification models. | |||||
| """ | |||||
| def __init__(self, pickle_path): | |||||
| super(ClassificationInfer, self).__init__(pickle_path) | |||||
| def data_forward(self, network, x): | |||||
| """Forward through network.""" | |||||
| logits = network(x) | |||||
| return logits | |||||
| def make_batch(self, iterator, use_cuda): | |||||
| return make_batch(iterator, use_cuda, output_length=False, min_len=5) | |||||
| def prepare_output(self, batch_outputs): | |||||
| """ | |||||
| Transform list of batch outputs into strings. | |||||
| :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes]. | |||||
| :return results: list of strings | |||||
| """ | |||||
| def _text_classify_prepare_output(self, batch_outputs): | |||||
| results = [] | results = [] | ||||
| for batch_out in batch_outputs: | for batch_out in batch_outputs: | ||||
| idx = np.argmax(batch_out.detach().numpy(), axis=-1) | idx = np.argmax(batch_out.detach().numpy(), axis=-1) | ||||
| results.extend([self.index2label[i] for i in idx]) | results.extend([self.index2label[i] for i in idx]) | ||||
| return results | return results | ||||
| class SeqLabelInfer(Predictor): | |||||
| def __init__(self, pickle_path): | |||||
| print( | |||||
| "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.") | |||||
| super(SeqLabelInfer, self).__init__(pickle_path, "seq_label") | |||||
| class ClassificationInfer(Predictor): | |||||
| def __init__(self, pickle_path): | |||||
| print( | |||||
| "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.") | |||||
| super(ClassificationInfer, self).__init__(pickle_path, "text_classify") | |||||
| @@ -3,6 +3,10 @@ import os | |||||
| import numpy as np | import numpy as np | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField, LabelField | |||||
| from fastNLP.core.instance import Instance | |||||
| DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | ||||
| DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | ||||
| DEFAULT_RESERVED_LABEL = ['<reserved-2>', | DEFAULT_RESERVED_LABEL = ['<reserved-2>', | ||||
| @@ -84,7 +88,7 @@ class BasePreprocess(object): | |||||
| return len(self.label2index) | return len(self.label2index) | ||||
| def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): | def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): | ||||
| """Main preprocessing pipeline. | |||||
| """Main pre-processing pipeline. | |||||
| :param train_dev_data: three-level list, with either single label or multiple labels in a sample. | :param train_dev_data: three-level list, with either single label or multiple labels in a sample. | ||||
| :param test_data: three-level list, with either single label or multiple labels in a sample. (optional) | :param test_data: three-level list, with either single label or multiple labels in a sample. (optional) | ||||
| @@ -92,7 +96,9 @@ class BasePreprocess(object): | |||||
| :param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set. | :param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set. | ||||
| :param cross_val: bool, whether to do cross validation. | :param cross_val: bool, whether to do cross validation. | ||||
| :param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True. | :param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True. | ||||
| :return results: a tuple of datasets after preprocessing. | |||||
| :return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset. | |||||
| If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset | |||||
| is a list of DataSet objects; Otherwise, each dataset is a DataSet object. | |||||
| """ | """ | ||||
| if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): | if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): | ||||
| @@ -111,68 +117,87 @@ class BasePreprocess(object): | |||||
| index2label = self.build_reverse_dict(self.label2index) | index2label = self.build_reverse_dict(self.label2index) | ||||
| save_pickle(index2label, pickle_path, "id2class.pkl") | save_pickle(index2label, pickle_path, "id2class.pkl") | ||||
| data_train = [] | |||||
| data_dev = [] | |||||
| train_set = [] | |||||
| dev_set = [] | |||||
| if not cross_val: | if not cross_val: | ||||
| if not pickle_exist(pickle_path, "data_train.pkl"): | if not pickle_exist(pickle_path, "data_train.pkl"): | ||||
| data_train.extend(self.to_index(train_dev_data)) | |||||
| if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): | if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): | ||||
| split = int(len(data_train) * train_dev_split) | |||||
| data_dev = data_train[: split] | |||||
| data_train = data_train[split:] | |||||
| save_pickle(data_dev, pickle_path, "data_dev.pkl") | |||||
| split = int(len(train_dev_data) * train_dev_split) | |||||
| data_dev = train_dev_data[: split] | |||||
| data_train = train_dev_data[split:] | |||||
| train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) | |||||
| dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) | |||||
| save_pickle(dev_set, pickle_path, "data_dev.pkl") | |||||
| print("{} of the training data is split for validation. ".format(train_dev_split)) | print("{} of the training data is split for validation. ".format(train_dev_split)) | ||||
| save_pickle(data_train, pickle_path, "data_train.pkl") | |||||
| else: | |||||
| train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) | |||||
| save_pickle(train_set, pickle_path, "data_train.pkl") | |||||
| else: | else: | ||||
| data_train = load_pickle(pickle_path, "data_train.pkl") | |||||
| train_set = load_pickle(pickle_path, "data_train.pkl") | |||||
| if pickle_exist(pickle_path, "data_dev.pkl"): | if pickle_exist(pickle_path, "data_dev.pkl"): | ||||
| data_dev = load_pickle(pickle_path, "data_dev.pkl") | |||||
| dev_set = load_pickle(pickle_path, "data_dev.pkl") | |||||
| else: | else: | ||||
| # cross_val is True | # cross_val is True | ||||
| if not pickle_exist(pickle_path, "data_train_0.pkl"): | if not pickle_exist(pickle_path, "data_train_0.pkl"): | ||||
| # cross validation | # cross validation | ||||
| data_idx = self.to_index(train_dev_data) | |||||
| data_cv = self.cv_split(data_idx, n_fold) | |||||
| data_cv = self.cv_split(train_dev_data, n_fold) | |||||
| for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): | for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): | ||||
| data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) | |||||
| data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) | |||||
| save_pickle( | save_pickle( | ||||
| data_train_cv, pickle_path, | data_train_cv, pickle_path, | ||||
| "data_train_{}.pkl".format(i)) | "data_train_{}.pkl".format(i)) | ||||
| save_pickle( | save_pickle( | ||||
| data_dev_cv, pickle_path, | data_dev_cv, pickle_path, | ||||
| "data_dev_{}.pkl".format(i)) | "data_dev_{}.pkl".format(i)) | ||||
| data_train.append(data_train_cv) | |||||
| data_dev.append(data_dev_cv) | |||||
| train_set.append(data_train_cv) | |||||
| dev_set.append(data_dev_cv) | |||||
| print("{}-fold cross validation.".format(n_fold)) | print("{}-fold cross validation.".format(n_fold)) | ||||
| else: | else: | ||||
| for i in range(n_fold): | for i in range(n_fold): | ||||
| data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i)) | data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i)) | ||||
| data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i)) | data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i)) | ||||
| data_train.append(data_train_cv) | |||||
| data_dev.append(data_dev_cv) | |||||
| train_set.append(data_train_cv) | |||||
| dev_set.append(data_dev_cv) | |||||
| # prepare test data if provided | # prepare test data if provided | ||||
| data_test = [] | |||||
| test_set = [] | |||||
| if test_data is not None: | if test_data is not None: | ||||
| if not pickle_exist(pickle_path, "data_test.pkl"): | if not pickle_exist(pickle_path, "data_test.pkl"): | ||||
| data_test = self.to_index(test_data) | |||||
| save_pickle(data_test, pickle_path, "data_test.pkl") | |||||
| test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) | |||||
| save_pickle(test_set, pickle_path, "data_test.pkl") | |||||
| # return preprocessed results | # return preprocessed results | ||||
| results = [data_train] | |||||
| results = [train_set] | |||||
| if cross_val or train_dev_split > 0: | if cross_val or train_dev_split > 0: | ||||
| results.append(data_dev) | |||||
| results.append(dev_set) | |||||
| if test_data: | if test_data: | ||||
| results.append(data_test) | |||||
| results.append(test_set) | |||||
| if len(results) == 1: | if len(results) == 1: | ||||
| return results[0] | return results[0] | ||||
| else: | else: | ||||
| return tuple(results) | return tuple(results) | ||||
| def build_dict(self, data): | def build_dict(self, data): | ||||
| raise NotImplementedError | |||||
| label2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| word2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| for example in data: | |||||
| for word in example[0]: | |||||
| if word not in word2index: | |||||
| word2index[word] = len(word2index) | |||||
| label = example[1] | |||||
| if isinstance(label, str): | |||||
| # label is a string | |||||
| if label not in label2index: | |||||
| label2index[label] = len(label2index) | |||||
| elif isinstance(label, list): | |||||
| # label is a list of strings | |||||
| for single_label in label: | |||||
| if single_label not in label2index: | |||||
| label2index[single_label] = len(label2index) | |||||
| return word2index, label2index | |||||
| def to_index(self, data): | |||||
| raise NotImplementedError | |||||
| def build_reverse_dict(self, word_dict): | def build_reverse_dict(self, word_dict): | ||||
| id2word = {word_dict[w]: w for w in word_dict} | id2word = {word_dict[w]: w for w in word_dict} | ||||
| @@ -186,11 +211,23 @@ class BasePreprocess(object): | |||||
| return data_train, data_dev | return data_train, data_dev | ||||
| def cv_split(self, data, n_fold): | def cv_split(self, data, n_fold): | ||||
| """Split data for cross validation.""" | |||||
| """Split data for cross validation. | |||||
| :param data: list of string | |||||
| :param n_fold: int | |||||
| :return data_cv: | |||||
| :: | |||||
| [ | |||||
| (data_train, data_dev), # 1st fold | |||||
| (data_train, data_dev), # 2nd fold | |||||
| ... | |||||
| ] | |||||
| """ | |||||
| data_copy = data.copy() | data_copy = data.copy() | ||||
| np.random.shuffle(data_copy) | np.random.shuffle(data_copy) | ||||
| fold_size = round(len(data_copy) / n_fold) | fold_size = round(len(data_copy) / n_fold) | ||||
| data_cv = [] | data_cv = [] | ||||
| for i in range(n_fold - 1): | for i in range(n_fold - 1): | ||||
| start = i * fold_size | start = i * fold_size | ||||
| @@ -202,154 +239,72 @@ class BasePreprocess(object): | |||||
| data_dev = data_copy[start:] | data_dev = data_copy[start:] | ||||
| data_train = data_copy[:start] | data_train = data_copy[:start] | ||||
| data_cv.append((data_train, data_dev)) | data_cv.append((data_train, data_dev)) | ||||
| return data_cv | return data_cv | ||||
| def convert_to_dataset(self, data, vocab, label_vocab): | |||||
| """Convert list of indices into a DataSet object. | |||||
| class SeqLabelPreprocess(BasePreprocess): | |||||
| """Preprocess pipeline, including building mapping from words to index, from index to words, | |||||
| from labels/classes to index, from index to labels/classes. | |||||
| data of three-level list which have multiple labels in each sample. | |||||
| :: | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| """ | |||||
| def __init__(self): | |||||
| super(SeqLabelPreprocess, self).__init__() | |||||
| def build_dict(self, data): | |||||
| """Add new words with indices into self.word_dict, new labels with indices into self.label_dict. | |||||
| :param data: three-level list | |||||
| :: | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :return word2index: dict of {str, int} | |||||
| label2index: dict of {str, int} | |||||
| :param data: list. Entries are strings. | |||||
| :param vocab: a dict, mapping string (token) to index (int). | |||||
| :param label_vocab: a dict, mapping string (label) to index (int). | |||||
| :return data_set: a DataSet object | |||||
| """ | """ | ||||
| # In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch. | |||||
| label2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| word2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| for example in data: | |||||
| for word, label in zip(example[0], example[1]): | |||||
| if word not in word2index: | |||||
| word2index[word] = len(word2index) | |||||
| if label not in label2index: | |||||
| label2index[label] = len(label2index) | |||||
| return word2index, label2index | |||||
| def to_index(self, data): | |||||
| """Convert word strings and label strings into indices. | |||||
| :param data: three-level list | |||||
| :: | |||||
| use_word_seq = False | |||||
| use_label_seq = False | |||||
| use_label_str = False | |||||
| [ | |||||
| [ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||||
| [ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||||
| ... | |||||
| ] | |||||
| :return data_index: the same shape as data, but each string is replaced by its corresponding index | |||||
| """ | |||||
| data_index = [] | |||||
| # construct a DataSet object and fill it with Instances | |||||
| data_set = DataSet() | |||||
| for example in data: | for example in data: | ||||
| word_list = [] | |||||
| label_list = [] | |||||
| for word, label in zip(example[0], example[1]): | |||||
| word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | |||||
| label_list.append(self.label2index.get(label, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | |||||
| data_index.append([word_list, label_list]) | |||||
| return data_index | |||||
| words, label = example[0], example[1] | |||||
| instance = Instance() | |||||
| if isinstance(words, list): | |||||
| x = TextField(words, is_target=False) | |||||
| instance.add_field("word_seq", x) | |||||
| use_word_seq = True | |||||
| else: | |||||
| raise NotImplementedError("words is a {}".format(type(words))) | |||||
| if isinstance(label, list): | |||||
| y = TextField(label, is_target=True) | |||||
| instance.add_field("label_seq", y) | |||||
| use_label_seq = True | |||||
| elif isinstance(label, str): | |||||
| y = LabelField(label, is_target=True) | |||||
| instance.add_field("label", y) | |||||
| use_label_str = True | |||||
| else: | |||||
| raise NotImplementedError("label is a {}".format(type(label))) | |||||
| data_set.append(instance) | |||||
| class ClassPreprocess(BasePreprocess): | |||||
| """ Preprocess pipeline for classification datasets. | |||||
| Preprocess pipeline, including building mapping from words to index, from index to words, | |||||
| from labels/classes to index, from index to labels/classes. | |||||
| design for data of three-level list which has a single label in each sample. | |||||
| :: | |||||
| # convert strings to indices | |||||
| if use_word_seq: | |||||
| data_set.index_field("word_seq", vocab) | |||||
| if use_label_seq: | |||||
| data_set.index_field("label_seq", label_vocab) | |||||
| if use_label_str: | |||||
| data_set.index_field("label", label_vocab) | |||||
| [ | |||||
| [ [word_11, word_12, ...], label_1 ], | |||||
| [ [word_21, word_22, ...], label_2 ], | |||||
| ... | |||||
| ] | |||||
| return data_set | |||||
| """ | |||||
| class SeqLabelPreprocess(BasePreprocess): | |||||
| def __init__(self): | def __init__(self): | ||||
| super(ClassPreprocess, self).__init__() | |||||
| def build_dict(self, data): | |||||
| """Build vocabulary.""" | |||||
| # build vocabulary from scratch if nothing exists | |||||
| word2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| label2index = DEFAULT_WORD_TO_INDEX.copy() | |||||
| # collect every word and label | |||||
| for sent, label in data: | |||||
| if len(sent) <= 1: | |||||
| continue | |||||
| if label not in label2index: | |||||
| label2index[label] = len(label2index) | |||||
| for word in sent: | |||||
| if word not in word2index: | |||||
| word2index[word] = len(word2index) | |||||
| return word2index, label2index | |||||
| def to_index(self, data): | |||||
| """Convert word strings and label strings into indices. | |||||
| :param data: three-level list | |||||
| :: | |||||
| [ | |||||
| [ [word_11, word_12, ...], label_1 ], | |||||
| [ [word_21, word_22, ...], label_2 ], | |||||
| ... | |||||
| ] | |||||
| super(SeqLabelPreprocess, self).__init__() | |||||
| :return data_index: the same shape as data, but each string is replaced by its corresponding index | |||||
| """ | |||||
| data_index = [] | |||||
| for example in data: | |||||
| word_list = [] | |||||
| # example[0] is the word list, example[1] is the single label | |||||
| for word in example[0]: | |||||
| word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | |||||
| label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]) | |||||
| data_index.append([word_list, label_index]) | |||||
| return data_index | |||||
| def infer_preprocess(pickle_path, data): | |||||
| """Preprocess over inference data. Transform three-level list of strings into that of index. | |||||
| :: | |||||
| class ClassPreprocess(BasePreprocess): | |||||
| def __init__(self): | |||||
| super(ClassPreprocess, self).__init__() | |||||
| [ | |||||
| [word_11, word_12, ...], | |||||
| [word_21, word_22, ...], | |||||
| ... | |||||
| ] | |||||
| """ | |||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
| data_index = [] | |||||
| for example in data: | |||||
| data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example]) | |||||
| return data_index | |||||
| if __name__ == "__main__": | |||||
| p = BasePreprocess() | |||||
| train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"], | |||||
| [["You", "are", "pretty", "."], "1"] | |||||
| ] | |||||
| training_set = p.run(train_dev_data) | |||||
| print(training_set) | |||||
| @@ -1,9 +1,8 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| from fastNLP.core.action import Action | |||||
| from fastNLP.core.action import RandomSampler, Batchifier | |||||
| from fastNLP.modules import utils | |||||
| from fastNLP.core.action import RandomSampler | |||||
| from fastNLP.core.batch import Batch | |||||
| from fastNLP.saver.logger import create_logger | from fastNLP.saver.logger import create_logger | ||||
| logger = create_logger(__name__, "./train_test.log") | logger = create_logger(__name__, "./train_test.log") | ||||
| @@ -35,16 +34,16 @@ class BaseTester(object): | |||||
| """ | """ | ||||
| "required_args" is the collection of arguments that users must pass to Trainer explicitly. | "required_args" is the collection of arguments that users must pass to Trainer explicitly. | ||||
| This is used to warn users of essential settings in the training. | This is used to warn users of essential settings in the training. | ||||
| Obviously, "required_args" is the subset of "default_args". | |||||
| The value in "default_args" to the keys in "required_args" is simply for type check. | |||||
| Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||||
| """ | """ | ||||
| # add required arguments here | |||||
| required_args = {} | |||||
| required_args = {"task" # one of ("seq_label", "text_classify") | |||||
| } | |||||
| for req_key in required_args: | for req_key in required_args: | ||||
| if req_key not in kwargs: | if req_key not in kwargs: | ||||
| logger.error("Tester lacks argument {}".format(req_key)) | logger.error("Tester lacks argument {}".format(req_key)) | ||||
| raise ValueError("Tester lacks argument {}".format(req_key)) | raise ValueError("Tester lacks argument {}".format(req_key)) | ||||
| self._task = kwargs["task"] | |||||
| for key in default_args: | for key in default_args: | ||||
| if key in kwargs: | if key in kwargs: | ||||
| @@ -79,14 +78,14 @@ class BaseTester(object): | |||||
| self._model = network | self._model = network | ||||
| # turn on the testing mode; clean up the history | # turn on the testing mode; clean up the history | ||||
| self.mode(network, test=True) | |||||
| self.mode(network, is_test=True) | |||||
| self.eval_history.clear() | self.eval_history.clear() | ||||
| self.batch_output.clear() | self.batch_output.clear() | ||||
| iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=False)) | |||||
| data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||||
| step = 0 | step = 0 | ||||
| for batch_x, batch_y in self.make_batch(iterator): | |||||
| for batch_x, batch_y in data_iterator: | |||||
| with torch.no_grad(): | with torch.no_grad(): | ||||
| prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
| eval_results = self.evaluate(prediction, batch_y) | eval_results = self.evaluate(prediction, batch_y) | ||||
| @@ -102,17 +101,22 @@ class BaseTester(object): | |||||
| print(self.make_eval_output(prediction, eval_results)) | print(self.make_eval_output(prediction, eval_results)) | ||||
| step += 1 | step += 1 | ||||
| def mode(self, model, test): | |||||
| def mode(self, model, is_test=False): | |||||
| """Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
| :param model: a PyTorch model | :param model: a PyTorch model | ||||
| :param test: bool, whether in test mode. | |||||
| :param is_test: bool, whether in test mode or not. | |||||
| """ | """ | ||||
| Action.mode(model, test) | |||||
| if is_test: | |||||
| model.eval() | |||||
| else: | |||||
| model.train() | |||||
| def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
| """A forward pass of the model. """ | """A forward pass of the model. """ | ||||
| raise NotImplementedError | |||||
| y = network(**x) | |||||
| return y | |||||
| def evaluate(self, predict, truth): | def evaluate(self, predict, truth): | ||||
| """Compute evaluation metrics. | """Compute evaluation metrics. | ||||
| @@ -121,7 +125,38 @@ class BaseTester(object): | |||||
| :param truth: Tensor | :param truth: Tensor | ||||
| :return eval_results: can be anything. It will be stored in self.eval_history | :return eval_results: can be anything. It will be stored in self.eval_history | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| if "label_seq" in truth: | |||||
| truth = truth["label_seq"] | |||||
| elif "label" in truth: | |||||
| truth = truth["label"] | |||||
| else: | |||||
| raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | |||||
| if self._task == "seq_label": | |||||
| return self._seq_label_evaluate(predict, truth) | |||||
| elif self._task == "text_classify": | |||||
| return self._text_classify_evaluate(predict, truth) | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||||
| def _seq_label_evaluate(self, predict, truth): | |||||
| batch_size, max_len = predict.size(0), predict.size(1) | |||||
| loss = self._model.loss(predict, truth) / batch_size | |||||
| prediction = self._model.prediction(predict) | |||||
| # pad prediction to equal length | |||||
| for pred in prediction: | |||||
| if len(pred) < max_len: | |||||
| pred += [0] * (max_len - len(pred)) | |||||
| results = torch.Tensor(prediction).view(-1, ) | |||||
| # make sure "results" is in the same device as "truth" | |||||
| results = results.to(truth) | |||||
| accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | |||||
| return [float(loss), float(accuracy)] | |||||
| def _text_classify_evaluate(self, y_logit, y_true): | |||||
| y_prob = torch.nn.functional.softmax(y_logit, dim=-1) | |||||
| return [y_prob, y_true] | |||||
| @property | @property | ||||
| def metrics(self): | def metrics(self): | ||||
| @@ -131,7 +166,27 @@ class BaseTester(object): | |||||
| :return : variable number of outputs | :return : variable number of outputs | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| if self._task == "seq_label": | |||||
| return self._seq_label_metrics | |||||
| elif self._task == "text_classify": | |||||
| return self._text_classify_metrics | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||||
| @property | |||||
| def _seq_label_metrics(self): | |||||
| batch_loss = np.mean([x[0] for x in self.eval_history]) | |||||
| batch_accuracy = np.mean([x[1] for x in self.eval_history]) | |||||
| return batch_loss, batch_accuracy | |||||
| @property | |||||
| def _text_classify_metrics(self): | |||||
| y_prob, y_true = zip(*self.eval_history) | |||||
| y_prob = torch.cat(y_prob, dim=0) | |||||
| y_pred = torch.argmax(y_prob, dim=-1) | |||||
| y_true = torch.cat(y_true, dim=0) | |||||
| acc = float(torch.sum(y_pred == y_true)) / len(y_true) | |||||
| return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | |||||
| def show_metrics(self): | def show_metrics(self): | ||||
| """Customize evaluation outputs in Trainer. | """Customize evaluation outputs in Trainer. | ||||
| @@ -140,10 +195,8 @@ class BaseTester(object): | |||||
| :return print_str: str | :return print_str: str | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| def make_batch(self, iterator): | |||||
| raise NotImplementedError | |||||
| loss, accuracy = self.metrics | |||||
| return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | |||||
| def make_eval_output(self, predictions, eval_results): | def make_eval_output(self, predictions, eval_results): | ||||
| """Customize Tester outputs. | """Customize Tester outputs. | ||||
| @@ -152,108 +205,20 @@ class BaseTester(object): | |||||
| :param eval_results: Tensor | :param eval_results: Tensor | ||||
| :return: str, to be printed. | :return: str, to be printed. | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| return self.show_metrics() | |||||
| class SeqLabelTester(BaseTester): | |||||
| """Tester for sequence labeling. | |||||
| """ | |||||
| class SeqLabelTester(BaseTester): | |||||
| def __init__(self, **test_args): | def __init__(self, **test_args): | ||||
| """ | |||||
| :param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||||
| """ | |||||
| test_args.update({"task": "seq_label"}) | |||||
| print( | |||||
| "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.") | |||||
| super(SeqLabelTester, self).__init__(**test_args) | super(SeqLabelTester, self).__init__(**test_args) | ||||
| self.max_len = None | |||||
| self.mask = None | |||||
| self.seq_len = None | |||||
| def data_forward(self, network, inputs): | |||||
| """This is only for sequence labeling with CRF decoder. | |||||
| :param network: a PyTorch model | |||||
| :param inputs: tuple of (x, seq_len) | |||||
| x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch | |||||
| after padding. | |||||
| seq_len: list of int, the lengths of sequences before padding. | |||||
| :return y: Tensor of shape [batch_size, max_len] | |||||
| """ | |||||
| if not isinstance(inputs, tuple): | |||||
| raise RuntimeError("output_length must be true for sequence modeling.") | |||||
| # unpack the returned value from make_batch | |||||
| x, seq_len = inputs[0], inputs[1] | |||||
| batch_size, max_len = x.size(0), x.size(1) | |||||
| mask = utils.seq_mask(seq_len, max_len) | |||||
| mask = mask.byte().view(batch_size, max_len) | |||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| mask = mask.cuda() | |||||
| self.mask = mask | |||||
| self.seq_len = seq_len | |||||
| y = network(x) | |||||
| return y | |||||
| def evaluate(self, predict, truth): | |||||
| """Compute metrics (or loss). | |||||
| :param predict: Tensor, [batch_size, max_len, tag_size] | |||||
| :param truth: Tensor, [batch_size, max_len] | |||||
| :return: | |||||
| """ | |||||
| batch_size, max_len = predict.size(0), predict.size(1) | |||||
| loss = self._model.loss(predict, truth, self.mask) / batch_size | |||||
| prediction = self._model.prediction(predict, self.mask) | |||||
| results = torch.Tensor(prediction).view(-1, ) | |||||
| # make sure "results" is in the same device as "truth" | |||||
| results = results.to(truth) | |||||
| accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | |||||
| return [float(loss), float(accuracy)] | |||||
| def metrics(self): | |||||
| batch_loss = np.mean([x[0] for x in self.eval_history]) | |||||
| batch_accuracy = np.mean([x[1] for x in self.eval_history]) | |||||
| return batch_loss, batch_accuracy | |||||
| def show_metrics(self): | |||||
| """This is called by Trainer to print evaluation on dev set. | |||||
| :return print_str: str | |||||
| """ | |||||
| loss, accuracy = self.metrics() | |||||
| return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | |||||
| def make_batch(self, iterator): | |||||
| return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True) | |||||
| class ClassificationTester(BaseTester): | class ClassificationTester(BaseTester): | ||||
| """Tester for classification.""" | |||||
| def __init__(self, **test_args): | def __init__(self, **test_args): | ||||
| """ | |||||
| :param test_args: a dict-like object that has __getitem__ method. | |||||
| can be accessed by "test_args["key_str"]" | |||||
| """ | |||||
| test_args.update({"task": "seq_label"}) | |||||
| print( | |||||
| "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.") | |||||
| super(ClassificationTester, self).__init__(**test_args) | super(ClassificationTester, self).__init__(**test_args) | ||||
| def make_batch(self, iterator, max_len=None): | |||||
| return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len) | |||||
| def data_forward(self, network, x): | |||||
| """Forward through network.""" | |||||
| logits = network(x) | |||||
| return logits | |||||
| def evaluate(self, y_logit, y_true): | |||||
| """Return y_pred and y_true.""" | |||||
| y_prob = torch.nn.functional.softmax(y_logit, dim=-1) | |||||
| return [y_prob, y_true] | |||||
| def metrics(self): | |||||
| """Compute accuracy.""" | |||||
| y_prob, y_true = zip(*self.eval_history) | |||||
| y_prob = torch.cat(y_prob, dim=0) | |||||
| y_pred = torch.argmax(y_prob, dim=-1) | |||||
| y_true = torch.cat(y_true, dim=0) | |||||
| acc = float(torch.sum(y_pred == y_true)) / len(y_true) | |||||
| return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | |||||
| @@ -4,15 +4,13 @@ import time | |||||
| from datetime import timedelta | from datetime import timedelta | ||||
| import torch | import torch | ||||
| import tensorboardX | |||||
| from tensorboardX import SummaryWriter | from tensorboardX import SummaryWriter | ||||
| from fastNLP.core.action import Action | |||||
| from fastNLP.core.action import RandomSampler, Batchifier | |||||
| from fastNLP.core.action import RandomSampler | |||||
| from fastNLP.core.batch import Batch | |||||
| from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.tester import SeqLabelTester, ClassificationTester | from fastNLP.core.tester import SeqLabelTester, ClassificationTester | ||||
| from fastNLP.modules import utils | |||||
| from fastNLP.saver.logger import create_logger | from fastNLP.saver.logger import create_logger | ||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| @@ -50,16 +48,16 @@ class BaseTrainer(object): | |||||
| """ | """ | ||||
| "required_args" is the collection of arguments that users must pass to Trainer explicitly. | "required_args" is the collection of arguments that users must pass to Trainer explicitly. | ||||
| This is used to warn users of essential settings in the training. | This is used to warn users of essential settings in the training. | ||||
| Obviously, "required_args" is the subset of "default_args". | |||||
| The value in "default_args" to the keys in "required_args" is simply for type check. | |||||
| Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||||
| """ | """ | ||||
| # add required arguments here | |||||
| required_args = {} | |||||
| required_args = {"task" # one of ("seq_label", "text_classify") | |||||
| } | |||||
| for req_key in required_args: | for req_key in required_args: | ||||
| if req_key not in kwargs: | if req_key not in kwargs: | ||||
| logger.error("Trainer lacks argument {}".format(req_key)) | logger.error("Trainer lacks argument {}".format(req_key)) | ||||
| raise ValueError("Trainer lacks argument {}".format(req_key)) | raise ValueError("Trainer lacks argument {}".format(req_key)) | ||||
| self._task = kwargs["task"] | |||||
| for key in default_args: | for key in default_args: | ||||
| if key in kwargs: | if key in kwargs: | ||||
| @@ -90,13 +88,14 @@ class BaseTrainer(object): | |||||
| self._optimizer_proto = default_args["optimizer"] | self._optimizer_proto = default_args["optimizer"] | ||||
| self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | ||||
| self._graph_summaried = False | self._graph_summaried = False | ||||
| self._best_accuracy = 0.0 | |||||
| def train(self, network, train_data, dev_data=None): | def train(self, network, train_data, dev_data=None): | ||||
| """General Training Procedure | """General Training Procedure | ||||
| :param network: a model | :param network: a model | ||||
| :param train_data: three-level list, the training set. | |||||
| :param dev_data: three-level list, the validation data (optional) | |||||
| :param train_data: a DataSet instance, the training data | |||||
| :param dev_data: a DataSet instance, the validation data (optional) | |||||
| """ | """ | ||||
| # transfer model to gpu if available | # transfer model to gpu if available | ||||
| if torch.cuda.is_available() and self.use_cuda: | if torch.cuda.is_available() and self.use_cuda: | ||||
| @@ -126,9 +125,10 @@ class BaseTrainer(object): | |||||
| logger.info("training epoch {}".format(epoch)) | logger.info("training epoch {}".format(epoch)) | ||||
| # turn on network training mode | # turn on network training mode | ||||
| self.mode(network, test=False) | |||||
| self.mode(network, is_test=False) | |||||
| # prepare mini-batch iterator | # prepare mini-batch iterator | ||||
| data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False)) | |||||
| data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | |||||
| use_cuda=self.use_cuda) | |||||
| logger.info("prepared data iterator") | logger.info("prepared data iterator") | ||||
| # one forward and backward pass | # one forward and backward pass | ||||
| @@ -157,7 +157,7 @@ class BaseTrainer(object): | |||||
| - epoch: int, | - epoch: int, | ||||
| """ | """ | ||||
| step = 0 | step = 0 | ||||
| for batch_x, batch_y in self.make_batch(data_iterator): | |||||
| for batch_x, batch_y in data_iterator: | |||||
| prediction = self.data_forward(network, batch_x) | prediction = self.data_forward(network, batch_x) | ||||
| @@ -166,10 +166,6 @@ class BaseTrainer(object): | |||||
| self.update() | self.update() | ||||
| self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | ||||
| if not self._graph_summaried: | |||||
| self._summary_writer.add_graph(network, batch_x) | |||||
| self._graph_summaried = True | |||||
| if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | ||||
| end = time.time() | end = time.time() | ||||
| diff = timedelta(seconds=round(end - kwargs["start"])) | diff = timedelta(seconds=round(end - kwargs["start"])) | ||||
| @@ -204,11 +200,17 @@ class BaseTrainer(object): | |||||
| network_copy = copy.deepcopy(network) | network_copy = copy.deepcopy(network) | ||||
| self.train(network_copy, train_data_cv[i], dev_data_cv[i]) | self.train(network_copy, train_data_cv[i], dev_data_cv[i]) | ||||
| def make_batch(self, iterator): | |||||
| raise NotImplementedError | |||||
| def mode(self, model, is_test=False): | |||||
| """Train mode or Test mode. This is for PyTorch currently. | |||||
| :param model: a PyTorch model | |||||
| :param is_test: bool, whether in test mode or not. | |||||
| def mode(self, network, test): | |||||
| Action.mode(network, test) | |||||
| """ | |||||
| if is_test: | |||||
| model.eval() | |||||
| else: | |||||
| model.train() | |||||
| def define_optimizer(self): | def define_optimizer(self): | ||||
| """Define framework-specific optimizer specified by the models. | """Define framework-specific optimizer specified by the models. | ||||
| @@ -224,7 +226,20 @@ class BaseTrainer(object): | |||||
| self._optimizer.step() | self._optimizer.step() | ||||
| def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
| raise NotImplementedError | |||||
| if self._task == "seq_label": | |||||
| y = network(x["word_seq"], x["word_seq_origin_len"]) | |||||
| elif self._task == "text_classify": | |||||
| y = network(x["word_seq"]) | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||||
| if not self._graph_summaried: | |||||
| if self._task == "seq_label": | |||||
| self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False) | |||||
| elif self._task == "text_classify": | |||||
| self._summary_writer.add_graph(network, x["word_seq"], verbose=False) | |||||
| self._graph_summaried = True | |||||
| return y | |||||
| def grad_backward(self, loss): | def grad_backward(self, loss): | ||||
| """Compute gradient with link rules. | """Compute gradient with link rules. | ||||
| @@ -243,6 +258,13 @@ class BaseTrainer(object): | |||||
| :param truth: ground truth label vector | :param truth: ground truth label vector | ||||
| :return: a scalar | :return: a scalar | ||||
| """ | """ | ||||
| if "label_seq" in truth: | |||||
| truth = truth["label_seq"] | |||||
| elif "label" in truth: | |||||
| truth = truth["label"] | |||||
| truth = truth.view((-1,)) | |||||
| else: | |||||
| raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | |||||
| return self._loss_func(predict, truth) | return self._loss_func(predict, truth) | ||||
| def define_loss(self): | def define_loss(self): | ||||
| @@ -270,7 +292,12 @@ class BaseTrainer(object): | |||||
| :param validator: a Tester instance | :param validator: a Tester instance | ||||
| :return: bool, True means current results on dev set is the best. | :return: bool, True means current results on dev set is the best. | ||||
| """ | """ | ||||
| raise NotImplementedError | |||||
| loss, accuracy = validator.metrics | |||||
| if accuracy > self._best_accuracy: | |||||
| self._best_accuracy = accuracy | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| def save_model(self, network, model_name): | def save_model(self, network, model_name): | ||||
| """Save this model with such a name. | """Save this model with such a name. | ||||
| @@ -291,55 +318,11 @@ class SeqLabelTrainer(BaseTrainer): | |||||
| """Trainer for Sequence Labeling | """Trainer for Sequence Labeling | ||||
| """ | """ | ||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| kwargs.update({"task": "seq_label"}) | |||||
| print( | |||||
| "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.") | |||||
| super(SeqLabelTrainer, self).__init__(**kwargs) | super(SeqLabelTrainer, self).__init__(**kwargs) | ||||
| # self.vocab_size = kwargs["vocab_size"] | |||||
| # self.num_classes = kwargs["num_classes"] | |||||
| self.max_len = None | |||||
| self.mask = None | |||||
| self.best_accuracy = 0.0 | |||||
| def data_forward(self, network, inputs): | |||||
| if not isinstance(inputs, tuple): | |||||
| raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0]))) | |||||
| # unpack the returned value from make_batch | |||||
| x, seq_len = inputs[0], inputs[1] | |||||
| batch_size, max_len = x.size(0), x.size(1) | |||||
| mask = utils.seq_mask(seq_len, max_len) | |||||
| mask = mask.byte().view(batch_size, max_len) | |||||
| if torch.cuda.is_available() and self.use_cuda: | |||||
| mask = mask.cuda() | |||||
| self.mask = mask | |||||
| y = network(x) | |||||
| return y | |||||
| def get_loss(self, predict, truth): | |||||
| """Compute loss given prediction and ground truth. | |||||
| :param predict: prediction label vector, [batch_size, max_len, tag_size] | |||||
| :param truth: ground truth label vector, [batch_size, max_len] | |||||
| :return loss: a scalar | |||||
| """ | |||||
| batch_size, max_len = predict.size(0), predict.size(1) | |||||
| assert truth.shape == (batch_size, max_len) | |||||
| loss = self._model.loss(predict, truth, self.mask) | |||||
| return loss | |||||
| def best_eval_result(self, validator): | |||||
| loss, accuracy = validator.metrics() | |||||
| if accuracy > self.best_accuracy: | |||||
| self.best_accuracy = accuracy | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| def make_batch(self, iterator): | |||||
| return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) | |||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
| return SeqLabelTester(**valid_args) | return SeqLabelTester(**valid_args) | ||||
| @@ -349,33 +332,10 @@ class ClassificationTrainer(BaseTrainer): | |||||
| """Trainer for text classification.""" | """Trainer for text classification.""" | ||||
| def __init__(self, **train_args): | def __init__(self, **train_args): | ||||
| train_args.update({"task": "text_classify"}) | |||||
| print( | |||||
| "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.") | |||||
| super(ClassificationTrainer, self).__init__(**train_args) | super(ClassificationTrainer, self).__init__(**train_args) | ||||
| self.iterator = None | |||||
| self.loss_func = None | |||||
| self.optimizer = None | |||||
| self.best_accuracy = 0 | |||||
| def data_forward(self, network, x): | |||||
| """Forward through network.""" | |||||
| logits = network(x) | |||||
| return logits | |||||
| def make_batch(self, iterator): | |||||
| return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda) | |||||
| def get_acc(self, y_logit, y_true): | |||||
| """Compute accuracy.""" | |||||
| y_pred = torch.argmax(y_logit, dim=-1) | |||||
| return int(torch.sum(y_true == y_pred)) / len(y_true) | |||||
| def best_eval_result(self, validator): | |||||
| _, _, accuracy = validator.metrics() | |||||
| if accuracy > self.best_accuracy: | |||||
| self.best_accuracy = accuracy | |||||
| return True | |||||
| else: | |||||
| return False | |||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
| return ClassificationTester(**valid_args) | return ClassificationTester(**valid_args) | ||||
| @@ -35,8 +35,12 @@ class CNNText(torch.nn.Module): | |||||
| self.dropout = nn.Dropout(drop_prob) | self.dropout = nn.Dropout(drop_prob) | ||||
| self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) | self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes) | ||||
| def forward(self, x): | |||||
| x = self.embed(x) # [N,L] -> [N,L,C] | |||||
| def forward(self, word_seq): | |||||
| """ | |||||
| :param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
| :return x: torch.LongTensor, [batch_size, num_classes] | |||||
| """ | |||||
| x = self.embed(word_seq) # [N,L] -> [N,L,C] | |||||
| x = self.conv_pool(x) # [N,L,C] -> [N,C] | x = self.conv_pool(x) # [N,L,C] -> [N,C] | ||||
| x = self.dropout(x) | x = self.dropout(x) | ||||
| x = self.fc(x) # [N,C] -> [N, N_class] | x = self.fc(x) # [N,C] -> [N, N_class] | ||||
| @@ -4,6 +4,20 @@ from fastNLP.models.base_model import BaseModel | |||||
| from fastNLP.modules import decoder, encoder | from fastNLP.modules import decoder, encoder | ||||
| def seq_mask(seq_len, max_len): | |||||
| """Create a mask for the sequences. | |||||
| :param seq_len: list or torch.LongTensor | |||||
| :param max_len: int | |||||
| :return mask: torch.LongTensor | |||||
| """ | |||||
| if isinstance(seq_len, list): | |||||
| seq_len = torch.LongTensor(seq_len) | |||||
| mask = [torch.ge(seq_len, i + 1) for i in range(max_len)] | |||||
| mask = torch.stack(mask, 1) | |||||
| return mask | |||||
| class SeqLabeling(BaseModel): | class SeqLabeling(BaseModel): | ||||
| """ | """ | ||||
| PyTorch Network for sequence labeling | PyTorch Network for sequence labeling | ||||
| @@ -20,13 +34,17 @@ class SeqLabeling(BaseModel): | |||||
| self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) | self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) | ||||
| self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | ||||
| self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | ||||
| self.mask = None | |||||
| def forward(self, x): | |||||
| def forward(self, word_seq, word_seq_origin_len): | |||||
| """ | """ | ||||
| :param x: LongTensor, [batch_size, mex_len] | |||||
| :param word_seq: LongTensor, [batch_size, mex_len] | |||||
| :param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. | |||||
| :return y: [batch_size, mex_len, tag_size] | :return y: [batch_size, mex_len, tag_size] | ||||
| """ | """ | ||||
| x = self.Embedding(x) | |||||
| self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||||
| x = self.Embedding(word_seq) | |||||
| # [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
| x = self.Rnn(x) | x = self.Rnn(x) | ||||
| # [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
| @@ -34,27 +52,34 @@ class SeqLabeling(BaseModel): | |||||
| # [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
| return x | return x | ||||
| def loss(self, x, y, mask): | |||||
| def loss(self, x, y): | |||||
| """ | """ | ||||
| Negative log likelihood loss. | Negative log likelihood loss. | ||||
| :param x: Tensor, [batch_size, max_len, tag_size] | :param x: Tensor, [batch_size, max_len, tag_size] | ||||
| :param y: Tensor, [batch_size, max_len] | :param y: Tensor, [batch_size, max_len] | ||||
| :param mask: ByteTensor, [batch_size, ,max_len] | |||||
| :return loss: a scalar Tensor | :return loss: a scalar Tensor | ||||
| """ | """ | ||||
| x = x.float() | x = x.float() | ||||
| y = y.long() | y = y.long() | ||||
| total_loss = self.Crf(x, y, mask) | |||||
| assert x.shape[:2] == y.shape | |||||
| assert y.shape == self.mask.shape | |||||
| total_loss = self.Crf(x, y, self.mask) | |||||
| return torch.mean(total_loss) | return torch.mean(total_loss) | ||||
| def prediction(self, x, mask): | |||||
| def make_mask(self, x, seq_len): | |||||
| batch_size, max_len = x.size(0), x.size(1) | |||||
| mask = seq_mask(seq_len, max_len) | |||||
| mask = mask.byte().view(batch_size, max_len) | |||||
| mask = mask.to(x) | |||||
| return mask | |||||
| def prediction(self, x): | |||||
| """ | """ | ||||
| :param x: FloatTensor, [batch_size, max_len, tag_size] | :param x: FloatTensor, [batch_size, max_len, tag_size] | ||||
| :param mask: ByteTensor, [batch_size, max_len] | |||||
| :return prediction: list of [decode path(list)] | :return prediction: list of [decode path(list)] | ||||
| """ | """ | ||||
| tag_seq = self.Crf.viterbi_decode(x, mask) | |||||
| tag_seq = self.Crf.viterbi_decode(x, self.mask) | |||||
| return tag_seq | return tag_seq | ||||
| @@ -81,14 +106,17 @@ class AdvSeqLabel(SeqLabeling): | |||||
| self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | ||||
| def forward(self, x): | |||||
| def forward(self, word_seq, word_seq_origin_len): | |||||
| """ | """ | ||||
| :param x: LongTensor, [batch_size, mex_len] | |||||
| :param word_seq: LongTensor, [batch_size, mex_len] | |||||
| :param word_seq_origin_len: list of int. | |||||
| :return y: [batch_size, mex_len, tag_size] | :return y: [batch_size, mex_len, tag_size] | ||||
| """ | """ | ||||
| batch_size = x.size(0) | |||||
| max_len = x.size(1) | |||||
| x = self.Embedding(x) | |||||
| self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||||
| batch_size = word_seq.size(0) | |||||
| max_len = word_seq.size(1) | |||||
| x = self.Embedding(word_seq) | |||||
| # [batch_size, max_len, word_emb_dim] | # [batch_size, max_len, word_emb_dim] | ||||
| x = self.Rnn(x) | x = self.Rnn(x) | ||||
| # [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
| @@ -1,8 +1,10 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from torch.autograd import Variable | from torch.autograd import Variable | ||||
| import torch.nn.functional as F | |||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class SelfAttention(nn.Module): | class SelfAttention(nn.Module): | ||||
| """ | """ | ||||
| Self Attention Module. | Self Attention Module. | ||||
| @@ -13,13 +15,18 @@ class SelfAttention(nn.Module): | |||||
| num_vec: int, the number of encoded vectors | num_vec: int, the number of encoded vectors | ||||
| """ | """ | ||||
| def __init__(self, input_size, dim=10, num_vec=10): | |||||
| def __init__(self, input_size, dim=10, num_vec=10 ,drop = 0.5 ,initial_method =None): | |||||
| super(SelfAttention, self).__init__() | super(SelfAttention, self).__init__() | ||||
| self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) | |||||
| self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) | |||||
| # self.W_s1 = nn.Parameter(torch.randn(dim, input_size), requires_grad=True) | |||||
| # self.W_s2 = nn.Parameter(torch.randn(num_vec, dim), requires_grad=True) | |||||
| self.attention_hops = num_vec | |||||
| self.ws1 = nn.Linear(input_size, dim, bias=False) | |||||
| self.ws2 = nn.Linear(dim, num_vec, bias=False) | |||||
| self.drop = nn.Dropout(drop) | |||||
| self.softmax = nn.Softmax(dim=2) | self.softmax = nn.Softmax(dim=2) | ||||
| self.tanh = nn.Tanh() | self.tanh = nn.Tanh() | ||||
| initial_parameter(self, initial_method) | |||||
| def penalization(self, A): | def penalization(self, A): | ||||
| """ | """ | ||||
| compute the penalization term for attention module | compute the penalization term for attention module | ||||
| @@ -32,11 +39,33 @@ class SelfAttention(nn.Module): | |||||
| M = M.view(M.size(0), -1) | M = M.view(M.size(0), -1) | ||||
| return torch.sum(M ** 2, dim=1) | return torch.sum(M ** 2, dim=1) | ||||
| def forward(self, x): | |||||
| inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) | |||||
| A = self.softmax(torch.matmul(self.W_s2, inter)) | |||||
| out = torch.matmul(A, x) | |||||
| out = out.view(out.size(0), -1) | |||||
| penalty = self.penalization(A) | |||||
| return out, penalty | |||||
| def forward(self, outp ,inp): | |||||
| # the following code can not be use because some word are padding ,these is not such module! | |||||
| # inter = self.tanh(torch.matmul(self.W_s1, torch.transpose(x, 1, 2))) # [] | |||||
| # A = self.softmax(torch.matmul(self.W_s2, inter)) | |||||
| # out = torch.matmul(A, x) | |||||
| # out = out.view(out.size(0), -1) | |||||
| # penalty = self.penalization(A) | |||||
| # return out, penalty | |||||
| outp = outp.contiguous() | |||||
| size = outp.size() # [bsz, len, nhid] | |||||
| compressed_embeddings = outp.view(-1, size[2]) # [bsz*len, nhid*2] | |||||
| transformed_inp = torch.transpose(inp, 0, 1).contiguous() # [bsz, len] | |||||
| transformed_inp = transformed_inp.view(size[0], 1, size[1]) # [bsz, 1, len] | |||||
| concatenated_inp = [transformed_inp for i in range(self.attention_hops)] | |||||
| concatenated_inp = torch.cat(concatenated_inp, 1) # [bsz, hop, len] | |||||
| hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) # [bsz*len, attention-unit] | |||||
| attention = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop] | |||||
| attention = torch.transpose(attention, 1, 2).contiguous() # [bsz, hop, len] | |||||
| penalized_alphas = attention + ( | |||||
| -10000 * (concatenated_inp == 0).float()) | |||||
| # [bsz, hop, len] + [bsz, hop, len] | |||||
| attention = self.softmax(penalized_alphas.view(-1, size[1])) # [bsz*hop, len] | |||||
| attention = attention.view(size[0], self.attention_hops, size[1]) # [bsz, hop, len] | |||||
| return torch.bmm(attention, outp), attention # output --> [baz ,hop ,nhid] | |||||
| @@ -1,6 +1,7 @@ | |||||
| import torch | import torch | ||||
| from torch import nn | from torch import nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| def log_sum_exp(x, dim=-1): | def log_sum_exp(x, dim=-1): | ||||
| max_value, _ = x.max(dim=dim, keepdim=True) | max_value, _ = x.max(dim=dim, keepdim=True) | ||||
| @@ -19,7 +20,7 @@ def seq_len_to_byte_mask(seq_lens): | |||||
| class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
| def __init__(self, tag_size, include_start_end_trans=True): | |||||
| def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): | |||||
| """ | """ | ||||
| :param tag_size: int, num of tags | :param tag_size: int, num of tags | ||||
| :param include_start_end_trans: bool, whether to include start/end tag | :param include_start_end_trans: bool, whether to include start/end tag | ||||
| @@ -35,8 +36,8 @@ class ConditionalRandomField(nn.Module): | |||||
| self.start_scores = nn.Parameter(torch.randn(tag_size)) | self.start_scores = nn.Parameter(torch.randn(tag_size)) | ||||
| self.end_scores = nn.Parameter(torch.randn(tag_size)) | self.end_scores = nn.Parameter(torch.randn(tag_size)) | ||||
| self.reset_parameter() | |||||
| # self.reset_parameter() | |||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameter(self): | def reset_parameter(self): | ||||
| nn.init.xavier_normal_(self.transition_m) | nn.init.xavier_normal_(self.transition_m) | ||||
| if self.include_start_end_trans: | if self.include_start_end_trans: | ||||
| @@ -1,8 +1,8 @@ | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class MLP(nn.Module): | class MLP(nn.Module): | ||||
| def __init__(self, size_layer, num_class=2, activation='relu'): | |||||
| def __init__(self, size_layer, num_class=2, activation='relu' , initial_method = None): | |||||
| """Multilayer Perceptrons as a decoder | """Multilayer Perceptrons as a decoder | ||||
| Args: | Args: | ||||
| @@ -36,7 +36,7 @@ class MLP(nn.Module): | |||||
| self.hidden_active = activation | self.hidden_active = activation | ||||
| else: | else: | ||||
| raise ValueError("should set activation correctly: {}".format(activation)) | raise ValueError("should set activation correctly: {}".format(activation)) | ||||
| initial_parameter(self, initial_method ) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| for layer in self.hiddens: | for layer in self.hiddens: | ||||
| x = self.hidden_active(layer(x)) | x = self.hidden_active(layer(x)) | ||||
| @@ -1,11 +1,12 @@ | |||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from torch import nn | from torch import nn | ||||
| # from torch.nn.init import xavier_uniform | |||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class ConvCharEmbedding(nn.Module): | class ConvCharEmbedding(nn.Module): | ||||
| def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5)): | |||||
| def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5),initial_method = None): | |||||
| """ | """ | ||||
| Character Level Word Embedding | Character Level Word Embedding | ||||
| :param char_emb_size: the size of character level embedding. Default: 50 | :param char_emb_size: the size of character level embedding. Default: 50 | ||||
| @@ -20,6 +21,8 @@ class ConvCharEmbedding(nn.Module): | |||||
| nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | ||||
| for i in range(len(kernels))]) | for i in range(len(kernels))]) | ||||
| initial_parameter(self,initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| """ | """ | ||||
| :param x: [batch_size * sent_length, word_length, char_emb_size] | :param x: [batch_size * sent_length, word_length, char_emb_size] | ||||
| @@ -53,7 +56,7 @@ class LSTMCharEmbedding(nn.Module): | |||||
| :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. | :param hidden_size: int, the number of hidden units. Default: equal to char_emb_size. | ||||
| """ | """ | ||||
| def __init__(self, char_emb_size=50, hidden_size=None): | |||||
| def __init__(self, char_emb_size=50, hidden_size=None , initial_method= None): | |||||
| super(LSTMCharEmbedding, self).__init__() | super(LSTMCharEmbedding, self).__init__() | ||||
| self.hidden_size = char_emb_size if hidden_size is None else hidden_size | self.hidden_size = char_emb_size if hidden_size is None else hidden_size | ||||
| @@ -62,7 +65,7 @@ class LSTMCharEmbedding(nn.Module): | |||||
| num_layers=1, | num_layers=1, | ||||
| bias=True, | bias=True, | ||||
| batch_first=True) | batch_first=True) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| """ | """ | ||||
| :param x:[ n_batch*n_word, word_length, char_emb_size] | :param x:[ n_batch*n_word, word_length, char_emb_size] | ||||
| @@ -6,6 +6,7 @@ import torch.nn as nn | |||||
| from torch.nn.init import xavier_uniform_ | from torch.nn.init import xavier_uniform_ | ||||
| # import torch.nn.functional as F | # import torch.nn.functional as F | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class Conv(nn.Module): | class Conv(nn.Module): | ||||
| """ | """ | ||||
| @@ -15,7 +16,7 @@ class Conv(nn.Module): | |||||
| def __init__(self, in_channels, out_channels, kernel_size, | def __init__(self, in_channels, out_channels, kernel_size, | ||||
| stride=1, padding=0, dilation=1, | stride=1, padding=0, dilation=1, | ||||
| groups=1, bias=True, activation='relu'): | |||||
| groups=1, bias=True, activation='relu',initial_method = None ): | |||||
| super(Conv, self).__init__() | super(Conv, self).__init__() | ||||
| self.conv = nn.Conv1d( | self.conv = nn.Conv1d( | ||||
| in_channels=in_channels, | in_channels=in_channels, | ||||
| @@ -26,7 +27,7 @@ class Conv(nn.Module): | |||||
| dilation=dilation, | dilation=dilation, | ||||
| groups=groups, | groups=groups, | ||||
| bias=bias) | bias=bias) | ||||
| xavier_uniform_(self.conv.weight) | |||||
| # xavier_uniform_(self.conv.weight) | |||||
| activations = { | activations = { | ||||
| 'relu': nn.ReLU(), | 'relu': nn.ReLU(), | ||||
| @@ -37,6 +38,7 @@ class Conv(nn.Module): | |||||
| raise Exception( | raise Exception( | ||||
| 'Should choose activation function from: ' + | 'Should choose activation function from: ' + | ||||
| ', '.join([x for x in activations])) | ', '.join([x for x in activations])) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] | ||||
| @@ -5,7 +5,7 @@ import torch | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from torch.nn.init import xavier_uniform_ | from torch.nn.init import xavier_uniform_ | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class ConvMaxpool(nn.Module): | class ConvMaxpool(nn.Module): | ||||
| """ | """ | ||||
| @@ -14,7 +14,7 @@ class ConvMaxpool(nn.Module): | |||||
| def __init__(self, in_channels, out_channels, kernel_sizes, | def __init__(self, in_channels, out_channels, kernel_sizes, | ||||
| stride=1, padding=0, dilation=1, | stride=1, padding=0, dilation=1, | ||||
| groups=1, bias=True, activation='relu'): | |||||
| groups=1, bias=True, activation='relu',initial_method = None ): | |||||
| super(ConvMaxpool, self).__init__() | super(ConvMaxpool, self).__init__() | ||||
| # convolution | # convolution | ||||
| @@ -47,6 +47,8 @@ class ConvMaxpool(nn.Module): | |||||
| raise Exception( | raise Exception( | ||||
| "Undefined activation function: choose from: relu") | "Undefined activation function: choose from: relu") | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| # [N,L,C] -> [N,C,L] | # [N,L,C] -> [N,C,L] | ||||
| x = torch.transpose(x, 1, 2) | x = torch.transpose(x, 1, 2) | ||||
| @@ -1,6 +1,6 @@ | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class Linear(nn.Module): | class Linear(nn.Module): | ||||
| """ | """ | ||||
| Linear module | Linear module | ||||
| @@ -12,10 +12,10 @@ class Linear(nn.Module): | |||||
| bidirectional : If True, becomes a bidirectional RNN | bidirectional : If True, becomes a bidirectional RNN | ||||
| """ | """ | ||||
| def __init__(self, input_size, output_size, bias=True): | |||||
| def __init__(self, input_size, output_size, bias=True,initial_method = None ): | |||||
| super(Linear, self).__init__() | super(Linear, self).__init__() | ||||
| self.linear = nn.Linear(input_size, output_size, bias) | self.linear = nn.Linear(input_size, output_size, bias) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| x = self.linear(x) | x = self.linear(x) | ||||
| return x | return x | ||||
| @@ -1,6 +1,6 @@ | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| class Lstm(nn.Module): | class Lstm(nn.Module): | ||||
| """ | """ | ||||
| LSTM module | LSTM module | ||||
| @@ -13,11 +13,13 @@ class Lstm(nn.Module): | |||||
| bidirectional : If True, becomes a bidirectional RNN. Default: False. | bidirectional : If True, becomes a bidirectional RNN. Default: False. | ||||
| """ | """ | ||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False): | |||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None): | |||||
| super(Lstm, self).__init__() | super(Lstm, self).__init__() | ||||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | ||||
| dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
| initial_parameter(self, initial_method) | |||||
| def forward(self, x): | def forward(self, x): | ||||
| x, _ = self.lstm(x) | x, _ = self.lstm(x) | ||||
| return x | return x | ||||
| if __name__ == "__main__": | |||||
| lstm = Lstm(10) | |||||
| @@ -4,7 +4,7 @@ import torch | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| def MaskedRecurrent(reverse=False): | def MaskedRecurrent(reverse=False): | ||||
| def forward(input, hidden, cell, mask, train=True, dropout=0): | def forward(input, hidden, cell, mask, train=True, dropout=0): | ||||
| """ | """ | ||||
| @@ -192,7 +192,7 @@ def AutogradMaskedStep(num_layers=1, dropout=0, train=True, lstm=False): | |||||
| class MaskedRNNBase(nn.Module): | class MaskedRNNBase(nn.Module): | ||||
| def __init__(self, Cell, input_size, hidden_size, | def __init__(self, Cell, input_size, hidden_size, | ||||
| num_layers=1, bias=True, batch_first=False, | num_layers=1, bias=True, batch_first=False, | ||||
| layer_dropout=0, step_dropout=0, bidirectional=False, **kwargs): | |||||
| layer_dropout=0, step_dropout=0, bidirectional=False, initial_method = None , **kwargs): | |||||
| """ | """ | ||||
| :param Cell: | :param Cell: | ||||
| :param input_size: | :param input_size: | ||||
| @@ -226,7 +226,7 @@ class MaskedRNNBase(nn.Module): | |||||
| cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) | cell = self.Cell(layer_input_size, hidden_size, self.bias, **kwargs) | ||||
| self.all_cells.append(cell) | self.all_cells.append(cell) | ||||
| self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 | self.add_module('cell%d' % (layer * num_directions + direction), cell) # Max的代码写得真好看 | ||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | def reset_parameters(self): | ||||
| for cell in self.all_cells: | for cell in self.all_cells: | ||||
| cell.reset_parameters() | cell.reset_parameters() | ||||
| @@ -6,6 +6,7 @@ import torch.nn.functional as F | |||||
| from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend | from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend | ||||
| from torch.nn.parameter import Parameter | from torch.nn.parameter import Parameter | ||||
| from fastNLP.modules.utils import initial_parameter | |||||
| def default_initializer(hidden_size): | def default_initializer(hidden_size): | ||||
| stdv = 1.0 / math.sqrt(hidden_size) | stdv = 1.0 / math.sqrt(hidden_size) | ||||
| @@ -172,7 +173,7 @@ def AutogradVarMaskedStep(num_layers=1, lstm=False): | |||||
| class VarMaskedRNNBase(nn.Module): | class VarMaskedRNNBase(nn.Module): | ||||
| def __init__(self, Cell, input_size, hidden_size, | def __init__(self, Cell, input_size, hidden_size, | ||||
| num_layers=1, bias=True, batch_first=False, | num_layers=1, bias=True, batch_first=False, | ||||
| dropout=(0, 0), bidirectional=False, initializer=None, **kwargs): | |||||
| dropout=(0, 0), bidirectional=False, initializer=None,initial_method = None, **kwargs): | |||||
| super(VarMaskedRNNBase, self).__init__() | super(VarMaskedRNNBase, self).__init__() | ||||
| self.Cell = Cell | self.Cell = Cell | ||||
| @@ -193,7 +194,7 @@ class VarMaskedRNNBase(nn.Module): | |||||
| cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) | cell = self.Cell(layer_input_size, hidden_size, self.bias, p=dropout, initializer=initializer, **kwargs) | ||||
| self.all_cells.append(cell) | self.all_cells.append(cell) | ||||
| self.add_module('cell%d' % (layer * num_directions + direction), cell) | self.add_module('cell%d' % (layer * num_directions + direction), cell) | ||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | def reset_parameters(self): | ||||
| for cell in self.all_cells: | for cell in self.all_cells: | ||||
| cell.reset_parameters() | cell.reset_parameters() | ||||
| @@ -284,7 +285,7 @@ class VarFastLSTMCell(VarRNNCellBase): | |||||
| \end{array} | \end{array} | ||||
| """ | """ | ||||
| def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None): | |||||
| def __init__(self, input_size, hidden_size, bias=True, p=(0.5, 0.5), initializer=None,initial_method =None): | |||||
| super(VarFastLSTMCell, self).__init__() | super(VarFastLSTMCell, self).__init__() | ||||
| self.input_size = input_size | self.input_size = input_size | ||||
| self.hidden_size = hidden_size | self.hidden_size = hidden_size | ||||
| @@ -311,7 +312,7 @@ class VarFastLSTMCell(VarRNNCellBase): | |||||
| self.p_hidden = p_hidden | self.p_hidden = p_hidden | ||||
| self.noise_in = None | self.noise_in = None | ||||
| self.noise_hidden = None | self.noise_hidden = None | ||||
| initial_parameter(self, initial_method) | |||||
| def reset_parameters(self): | def reset_parameters(self): | ||||
| for weight in self.parameters(): | for weight in self.parameters(): | ||||
| if weight.dim() == 1: | if weight.dim() == 1: | ||||
| @@ -2,8 +2,8 @@ from collections import defaultdict | |||||
| import numpy as np | import numpy as np | ||||
| import torch | import torch | ||||
| import torch.nn.init as init | |||||
| import torch.nn as nn | |||||
| def mask_softmax(matrix, mask): | def mask_softmax(matrix, mask): | ||||
| if mask is None: | if mask is None: | ||||
| result = torch.nn.functional.softmax(matrix, dim=-1) | result = torch.nn.functional.softmax(matrix, dim=-1) | ||||
| @@ -11,6 +11,51 @@ def mask_softmax(matrix, mask): | |||||
| raise NotImplementedError | raise NotImplementedError | ||||
| return result | return result | ||||
| def initial_parameter(net ,initial_method =None): | |||||
| if initial_method == 'xavier_uniform': | |||||
| init_method = init.xavier_uniform_ | |||||
| elif initial_method=='xavier_normal': | |||||
| init_method = init.xavier_normal_ | |||||
| elif initial_method == 'kaiming_normal' or initial_method =='msra': | |||||
| init_method = init.kaiming_normal | |||||
| elif initial_method == 'kaiming_uniform': | |||||
| init_method = init.kaiming_normal | |||||
| elif initial_method == 'orthogonal': | |||||
| init_method = init.orthogonal_ | |||||
| elif initial_method == 'sparse': | |||||
| init_method = init.sparse_ | |||||
| elif initial_method =='normal': | |||||
| init_method = init.normal_ | |||||
| elif initial_method =='uniform': | |||||
| initial_method = init.uniform_ | |||||
| else: | |||||
| init_method = init.xavier_normal_ | |||||
| def weights_init(m): | |||||
| # classname = m.__class__.__name__ | |||||
| if isinstance(m, nn.Conv2d) or isinstance(m,nn.Conv1d) or isinstance(m,nn.Conv3d): # for all the cnn | |||||
| if initial_method != None: | |||||
| init_method(m.weight.data) | |||||
| else: | |||||
| init.xavier_normal_(m.weight.data) | |||||
| init.normal_(m.bias.data) | |||||
| elif isinstance(m, nn.LSTM): | |||||
| for w in m.parameters(): | |||||
| if len(w.data.size())>1: | |||||
| init_method(w.data) # weight | |||||
| else: | |||||
| init.normal_(w.data) # bias | |||||
| elif hasattr(m, 'weight') and m.weight.requires_grad: | |||||
| init_method(m.weight.data) | |||||
| else: | |||||
| for w in m.parameters() : | |||||
| if w.requires_grad: | |||||
| if len(w.data.size())>1: | |||||
| init_method(w.data) # weight | |||||
| else: | |||||
| init.normal_(w.data) # bias | |||||
| # print("init else") | |||||
| net.apply(weights_init) | |||||
| def seq_mask(seq_len, max_len): | def seq_mask(seq_len, max_len): | ||||
| mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | mask = [torch.ge(torch.LongTensor(seq_len), i + 1) for i in range(max_len)] | ||||
| @@ -0,0 +1,13 @@ | |||||
| [train] | |||||
| epochs = 30 | |||||
| batch_size = 32 | |||||
| pickle_path = "./save/" | |||||
| validate = true | |||||
| save_best_dev = true | |||||
| model_saved_path = "./save/" | |||||
| rnn_hidden_units = 300 | |||||
| word_emb_dim = 300 | |||||
| use_crf = true | |||||
| use_cuda = false | |||||
| loss_func = "cross_entropy" | |||||
| num_classes = 5 | |||||
| @@ -0,0 +1,80 @@ | |||||
| import os | |||||
| import torch.nn.functional as F | |||||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader | |||||
| from fastNLP.loader.embed_loader import EmbedLoader as EmbedLoader | |||||
| from fastNLP.loader.config_loader import ConfigSection | |||||
| from fastNLP.loader.config_loader import ConfigLoader | |||||
| from fastNLP.models.base_model import BaseModel | |||||
| from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||||
| from fastNLP.core.trainer import ClassificationTrainer | |||||
| from fastNLP.modules.encoder.embedding import Embedding as Embedding | |||||
| from fastNLP.modules.encoder.lstm import Lstm | |||||
| from fastNLP.modules.aggregation.self_attention import SelfAttention | |||||
| from fastNLP.modules.decoder.MLP import MLP | |||||
| train_data_path = 'small_train_data.txt' | |||||
| dev_data_path = 'small_dev_data.txt' | |||||
| # emb_path = 'glove.txt' | |||||
| lstm_hidden_size = 300 | |||||
| embeding_size = 300 | |||||
| attention_unit = 350 | |||||
| attention_hops = 10 | |||||
| class_num = 5 | |||||
| nfc = 3000 | |||||
| ### data load ### | |||||
| train_dataset = Dataset_loader(train_data_path) | |||||
| train_data = train_dataset.load() | |||||
| dev_args = Dataset_loader(dev_data_path) | |||||
| dev_data = dev_args.load() | |||||
| ###### preprocess #### | |||||
| preprocess = Preprocess() | |||||
| word2index, label2index = preprocess.build_dict(train_data) | |||||
| train_data, dev_data = preprocess.run(train_data, dev_data) | |||||
| # emb = EmbedLoader(emb_path) | |||||
| # embedding = emb.load_embedding(emb_dim= embeding_size , emb_file= emb_path ,word_dict= word2index) | |||||
| ### construct vocab ### | |||||
| class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel): | |||||
| def __init__(self, args=None): | |||||
| super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__() | |||||
| self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None ) | |||||
| self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True) | |||||
| self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops) | |||||
| self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ] ,num_class=class_num ,) | |||||
| def forward(self,x): | |||||
| x_emb = self.embedding(x) | |||||
| output = self.lstm(x_emb) | |||||
| after_attention, penalty = self.attention(output,x) | |||||
| after_attention =after_attention.view(after_attention.size(0),-1) | |||||
| output = self.mlp(after_attention) | |||||
| return output | |||||
| def loss(self, predict, ground_truth): | |||||
| print("predict:%s; g:%s" % (str(predict.size()), str(ground_truth.size()))) | |||||
| print(ground_truth) | |||||
| return F.cross_entropy(predict, ground_truth) | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("good path").load_config('config.cfg',{"train": train_args}) | |||||
| train_args['vocab'] = len(word2index) | |||||
| trainer = ClassificationTrainer(**train_args.data) | |||||
| # for k in train_args.__dict__.keys(): | |||||
| # print(k, train_args[k]) | |||||
| model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args) | |||||
| trainer.train(model,train_data , dev_data) | |||||
| @@ -2,18 +2,18 @@ | |||||
| # coding=utf-8 | # coding=utf-8 | ||||
| from setuptools import setup, find_packages | from setuptools import setup, find_packages | ||||
| with open('README.md') as f: | |||||
| with open('README.md', encoding='utf-8') as f: | |||||
| readme = f.read() | readme = f.read() | ||||
| with open('LICENSE') as f: | |||||
| with open('LICENSE', encoding='utf-8') as f: | |||||
| license = f.read() | license = f.read() | ||||
| with open('requirements.txt') as f: | |||||
| with open('requirements.txt', encoding='utf-8') as f: | |||||
| reqs = f.read() | reqs = f.read() | ||||
| setup( | setup( | ||||
| name='fastNLP', | name='fastNLP', | ||||
| version='0.0.1', | |||||
| version='0.0.3', | |||||
| description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
| long_description=readme, | long_description=readme, | ||||
| license=license, | license=license, | ||||
| @@ -1,17 +0,0 @@ | |||||
| import unittest | |||||
| from fastNLP.core.action import Action, Batchifier, SequentialSampler | |||||
| class TestAction(unittest.TestCase): | |||||
| def test_case_1(self): | |||||
| x = [1, 2, 3, 4, 5, 6, 7, 8] | |||||
| y = [1, 1, 1, 1, 2, 2, 2, 2] | |||||
| data = [] | |||||
| for i in range(len(x)): | |||||
| data.append([[x[i]], [y[i]]]) | |||||
| data = Batchifier(SequentialSampler(data), batch_size=2, drop_last=False) | |||||
| action = Action() | |||||
| for batch_x in action.make_batch(data, use_cuda=False, output_length=True, max_len=None): | |||||
| print(batch_x) | |||||
| @@ -0,0 +1,62 @@ | |||||
| import unittest | |||||
| import torch | |||||
| from fastNLP.core.batch import Batch | |||||
| from fastNLP.core.dataset import DataSet, create_dataset_from_lists | |||||
| from fastNLP.core.field import TextField, LabelField | |||||
| from fastNLP.core.instance import Instance | |||||
| raw_texts = ["i am a cat", | |||||
| "this is a test of new batch", | |||||
| "ha ha", | |||||
| "I am a good boy .", | |||||
| "This is the most beautiful girl ." | |||||
| ] | |||||
| texts = [text.strip().split() for text in raw_texts] | |||||
| labels = [0, 1, 0, 0, 1] | |||||
| # prepare vocabulary | |||||
| vocab = {} | |||||
| for text in texts: | |||||
| for tokens in text: | |||||
| if tokens not in vocab: | |||||
| vocab[tokens] = len(vocab) | |||||
| class TestCase1(unittest.TestCase): | |||||
| def test(self): | |||||
| data = DataSet() | |||||
| for text, label in zip(texts, labels): | |||||
| x = TextField(text, is_target=False) | |||||
| y = LabelField(label, is_target=True) | |||||
| ins = Instance(text=x, label=y) | |||||
| data.append(ins) | |||||
| # use vocabulary to index data | |||||
| data.index_field("text", vocab) | |||||
| # define naive sampler for batch class | |||||
| class SeqSampler: | |||||
| def __call__(self, dataset): | |||||
| return list(range(len(dataset))) | |||||
| # use batch to iterate dataset | |||||
| data_iterator = Batch(data, 2, SeqSampler(), False) | |||||
| for batch_x, batch_y in data_iterator: | |||||
| self.assertEqual(len(batch_x), 2) | |||||
| self.assertTrue(isinstance(batch_x, dict)) | |||||
| self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | |||||
| self.assertTrue(isinstance(batch_y, dict)) | |||||
| self.assertTrue(isinstance(batch_y["label"], torch.LongTensor)) | |||||
| class TestCase2(unittest.TestCase): | |||||
| def test(self): | |||||
| data = DataSet() | |||||
| for text in texts: | |||||
| x = TextField(text, is_target=False) | |||||
| ins = Instance(text=x) | |||||
| data.append(ins) | |||||
| data_set = create_dataset_from_lists(texts, vocab, has_target=False) | |||||
| self.assertTrue(type(data) == type(data_set)) | |||||
| @@ -0,0 +1,51 @@ | |||||
| import os | |||||
| import unittest | |||||
| from fastNLP.core.predictor import Predictor | |||||
| from fastNLP.core.preprocess import save_pickle | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| class TestPredictor(unittest.TestCase): | |||||
| def test_seq_label(self): | |||||
| model_args = { | |||||
| "vocab_size": 10, | |||||
| "word_emb_dim": 100, | |||||
| "rnn_hidden_units": 100, | |||||
| "num_classes": 5 | |||||
| } | |||||
| infer_data = [ | |||||
| ['a', 'b', 'c', 'd', 'e'], | |||||
| ['a', '@', 'c', 'd', 'e'], | |||||
| ['a', 'b', '#', 'd', 'e'], | |||||
| ['a', 'b', 'c', '?', 'e'], | |||||
| ['a', 'b', 'c', 'd', '$'], | |||||
| ['!', 'b', 'c', 'd', 'e'] | |||||
| ] | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| os.system("mkdir save") | |||||
| save_pickle({0: "0", 1: "1", 2: "2", 3: "3", 4: "4"}, "./save/", "id2class.pkl") | |||||
| save_pickle(vocab, "./save/", "word2id.pkl") | |||||
| model = SeqLabeling(model_args) | |||||
| predictor = Predictor("./save/", task="seq_label") | |||||
| results = predictor.predict(network=model, data=infer_data) | |||||
| self.assertTrue(isinstance(results, list)) | |||||
| self.assertGreater(len(results), 0) | |||||
| for res in results: | |||||
| self.assertTrue(isinstance(res, list)) | |||||
| self.assertEqual(len(res), 5) | |||||
| self.assertTrue(isinstance(res[0], str)) | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| class TestPredictor2(unittest.TestCase): | |||||
| def test_text_classify(self): | |||||
| # TODO | |||||
| pass | |||||
| @@ -1,24 +1,25 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.preprocess import SeqLabelPreprocess | from fastNLP.core.preprocess import SeqLabelPreprocess | ||||
| data = [ | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| ] | |||||
| class TestSeqLabelPreprocess(unittest.TestCase): | |||||
| def test_case_1(self): | |||||
| data = [ | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| [['Tom', 'and', 'Jerry', '.'], ['n', '&', 'n', '.']], | |||||
| [['Hello', 'world', '!'], ['a', 'n', '.']], | |||||
| ] | |||||
| class TestCase1(unittest.TestCase): | |||||
| def test(self): | |||||
| if os.path.exists("./save"): | if os.path.exists("./save"): | ||||
| for root, dirs, files in os.walk("./save", topdown=False): | for root, dirs, files in os.walk("./save", topdown=False): | ||||
| for name in files: | for name in files: | ||||
| @@ -27,17 +28,45 @@ class TestSeqLabelPreprocess(unittest.TestCase): | |||||
| os.rmdir(os.path.join(root, name)) | os.rmdir(os.path.join(root, name)) | ||||
| result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4, | result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4, | ||||
| pickle_path="./save") | pickle_path="./save") | ||||
| result = SeqLabelPreprocess().run(train_dev_data=data, train_dev_split=0.4, | |||||
| pickle_path="./save") | |||||
| self.assertEqual(len(result), 2) | |||||
| self.assertEqual(type(result[0]), DataSet) | |||||
| self.assertEqual(type(result[1]), DataSet) | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| class TestCase2(unittest.TestCase): | |||||
| def test(self): | |||||
| if os.path.exists("./save"): | if os.path.exists("./save"): | ||||
| for root, dirs, files in os.walk("./save", topdown=False): | for root, dirs, files in os.walk("./save", topdown=False): | ||||
| for name in files: | for name in files: | ||||
| os.remove(os.path.join(root, name)) | os.remove(os.path.join(root, name)) | ||||
| for name in dirs: | for name in dirs: | ||||
| os.rmdir(os.path.join(root, name)) | os.rmdir(os.path.join(root, name)) | ||||
| result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data, | |||||
| pickle_path="./save", train_dev_split=0.4, | |||||
| cross_val=True) | |||||
| result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data, | result = SeqLabelPreprocess().run(test_data=data, train_dev_data=data, | ||||
| pickle_path="./save", train_dev_split=0.4, | pickle_path="./save", train_dev_split=0.4, | ||||
| cross_val=True) | |||||
| cross_val=False) | |||||
| self.assertEqual(len(result), 3) | |||||
| self.assertEqual(type(result[0]), DataSet) | |||||
| self.assertEqual(type(result[1]), DataSet) | |||||
| self.assertEqual(type(result[2]), DataSet) | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| class TestCase3(unittest.TestCase): | |||||
| def test(self): | |||||
| num_folds = 2 | |||||
| result = SeqLabelPreprocess().run(test_data=None, train_dev_data=data, | |||||
| pickle_path="./save", train_dev_split=0.4, | |||||
| cross_val=True, n_fold=num_folds) | |||||
| self.assertEqual(len(result), 2) | |||||
| self.assertEqual(len(result[0]), num_folds) | |||||
| self.assertEqual(len(result[1]), num_folds) | |||||
| for data_set in result[0] + result[1]: | |||||
| self.assertEqual(type(data_set), DataSet) | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| @@ -1,37 +1,55 @@ | |||||
| from fastNLP.core.preprocess import SeqLabelPreprocess | |||||
| import os | |||||
| import unittest | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField | |||||
| from fastNLP.core.instance import Instance | |||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
| pickle_path = "data_for_tests" | pickle_path = "data_for_tests" | ||||
| def foo(): | |||||
| loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8") | |||||
| train_data = loader.load_pku() | |||||
| train_args = ConfigSection() | |||||
| ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args}) | |||||
| # Preprocessor | |||||
| p = SeqLabelPreprocess() | |||||
| train_data = p.run(train_data) | |||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| model = SeqLabeling(train_args) | |||||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
| "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/", | |||||
| "use_cuda": True} | |||||
| validator = SeqLabelTester(**valid_args) | |||||
| print("start validation.") | |||||
| validator.test(model, train_data) | |||||
| print(validator.show_metrics()) | |||||
| if __name__ == "__main__": | |||||
| foo() | |||||
| class TestTester(unittest.TestCase): | |||||
| def test_case_1(self): | |||||
| model_args = { | |||||
| "vocab_size": 10, | |||||
| "word_emb_dim": 100, | |||||
| "rnn_hidden_units": 100, | |||||
| "num_classes": 5 | |||||
| } | |||||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
| "save_loss": True, "batch_size": 2, "pickle_path": "./save/", | |||||
| "use_cuda": False, "print_every_step": 1} | |||||
| train_data = [ | |||||
| [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| ] | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
| data_set = DataSet() | |||||
| for example in train_data: | |||||
| text, label = example[0], example[1] | |||||
| x = TextField(text, False) | |||||
| y = TextField(label, is_target=True) | |||||
| ins = Instance(word_seq=x, label_seq=y) | |||||
| data_set.append(ins) | |||||
| data_set.index_field("word_seq", vocab) | |||||
| data_set.index_field("label_seq", label_vocab) | |||||
| model = SeqLabeling(model_args) | |||||
| tester = SeqLabelTester(**valid_args) | |||||
| tester.test(network=model, dev_data=data_set) | |||||
| # If this can run, everything is OK. | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| @@ -1,33 +1,54 @@ | |||||
| import os | import os | ||||
| import torch.nn as nn | |||||
| import unittest | import unittest | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | |||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField | |||||
| from fastNLP.core.instance import Instance | |||||
| from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| class TestTrainer(unittest.TestCase): | class TestTrainer(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/", | |||||
| args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/", | |||||
| "save_best_dev": True, "model_name": "default_model_name.pkl", | "save_best_dev": True, "model_name": "default_model_name.pkl", | ||||
| "loss": Loss(None), | "loss": Loss(None), | ||||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | ||||
| "vocab_size": 20, | |||||
| "vocab_size": 10, | |||||
| "word_emb_dim": 100, | "word_emb_dim": 100, | ||||
| "rnn_hidden_units": 100, | "rnn_hidden_units": 100, | ||||
| "num_classes": 3 | |||||
| "num_classes": 5 | |||||
| } | } | ||||
| trainer = SeqLabelTrainer() | |||||
| trainer = SeqLabelTrainer(**args) | |||||
| train_data = [ | train_data = [ | ||||
| [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]], | |||||
| [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]], | |||||
| [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]], | |||||
| [[1, 2, 3, 4, 5, 6], [1, 0, 1, 0, 1, 2]], | |||||
| [[2, 3, 4, 5, 1, 6], [0, 1, 0, 1, 0, 2]], | |||||
| [[1, 4, 1, 4, 1, 6], [1, 0, 1, 0, 1, 2]], | |||||
| [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', '@', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', '#', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', '?', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['a', 'b', 'c', 'd', '$'], ['a', '@', 'c', 'd', 'e']], | |||||
| [['!', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | |||||
| ] | ] | ||||
| dev_data = train_data | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | |||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | |||||
| data_set = DataSet() | |||||
| for example in train_data: | |||||
| text, label = example[0], example[1] | |||||
| x = TextField(text, False) | |||||
| y = TextField(label, is_target=True) | |||||
| ins = Instance(word_seq=x, label_seq=y) | |||||
| data_set.append(ins) | |||||
| data_set.index_field("word_seq", vocab) | |||||
| data_set.index_field("label_seq", label_vocab) | |||||
| model = SeqLabeling(args) | model = SeqLabeling(args) | ||||
| trainer.train(network=model, train_data=train_data, dev_data=dev_data) | |||||
| trainer.train(network=model, train_data=data_set, dev_data=data_set) | |||||
| # If this can run, everything is OK. | |||||
| os.system("rm -rf save") | |||||
| print("pickle path deleted") | |||||
| @@ -15,11 +15,11 @@ from fastNLP.core.optimizer import Optimizer | |||||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
| parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | ||||
| parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt", | |||||
| parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt", | |||||
| help="path to the training data") | help="path to the training data") | ||||
| parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | ||||
| parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt", | |||||
| parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt", | |||||
| help="data used for inference") | help="data used for inference") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| @@ -86,7 +86,7 @@ def train_and_test(): | |||||
| trainer = SeqLabelTrainer( | trainer = SeqLabelTrainer( | ||||
| epochs=trainer_args["epochs"], | epochs=trainer_args["epochs"], | ||||
| batch_size=trainer_args["batch_size"], | batch_size=trainer_args["batch_size"], | ||||
| validate=trainer_args["validate"], | |||||
| validate=False, | |||||
| use_cuda=trainer_args["use_cuda"], | use_cuda=trainer_args["use_cuda"], | ||||
| pickle_path=pickle_path, | pickle_path=pickle_path, | ||||
| save_best_dev=trainer_args["save_best_dev"], | save_best_dev=trainer_args["save_best_dev"], | ||||
| @@ -121,7 +121,7 @@ def train_and_test(): | |||||
| # Tester | # Tester | ||||
| tester = SeqLabelTester(save_output=False, | tester = SeqLabelTester(save_output=False, | ||||
| save_loss=False, | |||||
| save_loss=True, | |||||
| save_best_dev=False, | save_best_dev=False, | ||||
| batch_size=4, | batch_size=4, | ||||
| use_cuda=False, | use_cuda=False, | ||||
| @@ -139,5 +139,5 @@ def train_and_test(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| # train_and_test() | |||||
| infer() | |||||
| train_and_test() | |||||
| # infer() | |||||
| @@ -1,8 +0,0 @@ | |||||
| def test_charlm(): | |||||
| pass | |||||
| if __name__ == "__main__": | |||||
| test_charlm() | |||||
| @@ -0,0 +1,85 @@ | |||||
| import os | |||||
| from fastNLP.core.optimizer import Optimizer | |||||
| from fastNLP.core.preprocess import SeqLabelPreprocess | |||||
| from fastNLP.core.tester import SeqLabelTester | |||||
| from fastNLP.core.trainer import SeqLabelTrainer | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||||
| from fastNLP.loader.dataset_loader import POSDatasetLoader | |||||
| from fastNLP.loader.model_loader import ModelLoader | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| from fastNLP.saver.model_saver import ModelSaver | |||||
| pickle_path = "./seq_label/" | |||||
| model_name = "seq_label_model.pkl" | |||||
| config_dir = "test/data_for_tests/config" | |||||
| data_path = "test/data_for_tests/people.txt" | |||||
| data_infer_path = "test/data_for_tests/people_infer.txt" | |||||
| def test_training(): | |||||
| # Config Loader | |||||
| trainer_args = ConfigSection() | |||||
| model_args = ConfigSection() | |||||
| ConfigLoader("_").load_config(config_dir, { | |||||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | |||||
| # Data Loader | |||||
| pos_loader = POSDatasetLoader(data_path) | |||||
| train_data = pos_loader.load_lines() | |||||
| # Preprocessor | |||||
| p = SeqLabelPreprocess() | |||||
| data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | |||||
| model_args["vocab_size"] = p.vocab_size | |||||
| model_args["num_classes"] = p.num_classes | |||||
| trainer = SeqLabelTrainer( | |||||
| epochs=trainer_args["epochs"], | |||||
| batch_size=trainer_args["batch_size"], | |||||
| validate=False, | |||||
| use_cuda=False, | |||||
| pickle_path=pickle_path, | |||||
| save_best_dev=trainer_args["save_best_dev"], | |||||
| model_name=model_name, | |||||
| optimizer=Optimizer("SGD", lr=0.01, momentum=0.9), | |||||
| ) | |||||
| # Model | |||||
| model = SeqLabeling(model_args) | |||||
| # Start training | |||||
| trainer.train(model, data_train, data_dev) | |||||
| # Saver | |||||
| saver = ModelSaver(os.path.join(pickle_path, model_name)) | |||||
| saver.save_pytorch(model) | |||||
| del model, trainer, pos_loader | |||||
| # Define the same model | |||||
| model = SeqLabeling(model_args) | |||||
| # Dump trained parameters into the model | |||||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | |||||
| # Load test configuration | |||||
| tester_args = ConfigSection() | |||||
| ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| # Tester | |||||
| tester = SeqLabelTester(save_output=False, | |||||
| save_loss=True, | |||||
| save_best_dev=False, | |||||
| batch_size=4, | |||||
| use_cuda=False, | |||||
| pickle_path=pickle_path, | |||||
| model_name="seq_label_in_test.pkl", | |||||
| print_every_step=1 | |||||
| ) | |||||
| # Start testing with validation data | |||||
| tester.test(model, data_dev) | |||||
| loss, accuracy = tester.metrics | |||||
| assert 0 < accuracy < 1 | |||||
| @@ -19,9 +19,9 @@ from fastNLP.core.loss import Loss | |||||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
| parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | ||||
| parser.add_argument("-t", "--train", type=str, default="./data_for_tests/text_classify.txt", | |||||
| parser.add_argument("-t", "--train", type=str, default="../data_for_tests/text_classify.txt", | |||||
| help="path to the training data") | help="path to the training data") | ||||
| parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||||
| parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | parser.add_argument("-m", "--model_name", type=str, default="classify_model.pkl", help="the name of the model") | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| @@ -115,4 +115,4 @@ def train(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| train() | train() | ||||
| infer() | |||||
| # infer() | |||||