Merge Preprocessor into DataSet (tag: v0.1.0)
@@ -5,7 +5,6 @@ python:
 install:
   - pip install --quiet -r requirements.txt
   - pip install pytest pytest-cov
-  - pip install -U scikit-learn
 # command to run tests
 script:
   - pytest --cov=./
@@ -30,77 +30,36 @@ Run the following commands to install fastNLP package.
 pip install fastNLP
 ```
-### Cloning From GitHub
-If you just want to use fastNLP, use:
-```shell
-git clone https://github.com/fastnlp/fastNLP
-cd fastNLP
-```
-### PyTorch Installation
-Visit the [PyTorch official website] for installation instructions based on your system. In general, you could use:
-```shell
-# using conda
-conda install pytorch torchvision -c pytorch
-# or using pip
-pip3 install torch torchvision
-```
-### TensorboardX Installation
-```shell
-pip3 install tensorboardX
-```
 ## Project Structure
-```
-FastNLP
-├── docs
-├── fastNLP
-│   ├── core
-│   │   ├── action.py
-│   │   ├── __init__.py
-│   │   ├── loss.py
-│   │   ├── metrics.py
-│   │   ├── optimizer.py
-│   │   ├── predictor.py
-│   │   ├── preprocess.py
-│   │   ├── README.md
-│   │   ├── tester.py
-│   │   └── trainer.py
-│   ├── fastnlp.py
-│   ├── __init__.py
-│   ├── loader
-│   │   ├── base_loader.py
-│   │   ├── config_loader.py
-│   │   ├── dataset_loader.py
-│   │   ├── embed_loader.py
-│   │   ├── __init__.py
-│   │   └── model_loader.py
-│   ├── models
-│   ├── modules
-│   │   ├── aggregation
-│   │   ├── decoder
-│   │   ├── encoder
-│   │   ├── __init__.py
-│   │   ├── interaction
-│   │   ├── other_modules.py
-│   │   └── utils.py
-│   └── saver
-├── LICENSE
-├── README.md
-├── reproduction
-├── requirements.txt
-├── setup.py
-└── test
-    ├── core
-    ├── data_for_tests
-    ├── __init__.py
-    ├── loader
-    ├── modules
-    └── readme_example.py
-```
+<table>
+<tr>
+    <td><b> fastNLP </b></td>
+    <td> an open-source NLP library </td>
+</tr>
+<tr>
+    <td><b> fastNLP.core </b></td>
+    <td> trainer, tester, predictor </td>
+</tr>
+<tr>
+    <td><b> fastNLP.loader </b></td>
+    <td> all kinds of loaders/readers </td>
+</tr>
+<tr>
+    <td><b> fastNLP.models </b></td>
+    <td> a collection of NLP models </td>
+</tr>
+<tr>
+    <td><b> fastNLP.modules </b></td>
+    <td> a collection of PyTorch sub-models/components/wheels </td>
+</tr>
+<tr>
+    <td><b> fastNLP.saver </b></td>
+    <td> all kinds of savers/writers </td>
+</tr>
+<tr>
+    <td><b> fastNLP.fastnlp </b></td>
+    <td> a high-level interface for prediction </td>
+</tr>
+</table>
@@ -18,7 +18,7 @@ pre-processing data, constructing model and training model.
 from fastNLP.modules import aggregation
 from fastNLP.modules import decoder
-from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.loader.dataset_loader import ClassDataSetLoader
 from fastNLP.loader.preprocess import ClassPreprocess
 from fastNLP.core.trainer import ClassificationTrainer
 from fastNLP.core.inference import ClassificationInfer
@@ -50,7 +50,7 @@ pre-processing data, constructing model and training model.
 train_path = 'test/data_for_tests/text_classify.txt' # training set file
 # load dataset
-ds_loader = ClassDatasetLoader("train", train_path)
+ds_loader = ClassDataSetLoader("train", train_path)
 data = ds_loader.load()
 # pre-process dataset
@@ -3,7 +3,7 @@ from fastNLP.core.optimizer import Optimizer
 from fastNLP.core.predictor import ClassificationInfer
 from fastNLP.core.preprocess import ClassPreprocess
 from fastNLP.core.trainer import ClassificationTrainer
-from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.loader.dataset_loader import ClassDataSetLoader
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import aggregator
 from fastNLP.modules import decoder
@@ -36,7 +36,7 @@ data_dir = 'save/' # directory to save data and model
 train_path = './data_for_tests/text_classify.txt' # training set file
 # load dataset
-ds_loader = ClassDatasetLoader(train_path)
+ds_loader = ClassDataSetLoader()
 data = ds_loader.load()
 # pre-process dataset
@@ -17,7 +17,7 @@ class Batch(object):
         :param dataset: a DataSet object
         :param batch_size: int, the size of the batch
         :param sampler: a Sampler object
-        :param use_cuda: bool, whetjher to use GPU
+        :param use_cuda: bool, whether to use GPU
         """
         self.dataset = dataset
@@ -37,15 +37,12 @@ class Batch(object):
         """
         :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
-                batch_x also contains an item (str: list of int) about origin lengths,
-                which means ("field_name_origin_len": origin lengths).
                 E.g.
                 ::
                 {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]})
                 batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
         All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
-        The names of fields are defined in preprocessor's convert_to_dataset method.
         """
         if self.curidx >= len(self.idx_list):
@@ -54,10 +51,9 @@ class Batch(object):
         endidx = min(self.curidx + self.batch_size, len(self.idx_list))
         padding_length = {field_name: max(field_length[self.curidx: endidx])
                           for field_name, field_length in self.lengths.items()}
-        origin_lengths = {field_name: field_length[self.curidx: endidx]
-                          for field_name, field_length in self.lengths.items()}
         batch_x, batch_y = defaultdict(list), defaultdict(list)
+        # transform index to tensor and do padding for sequences
         for idx in range(self.curidx, endidx):
             x, y = self.dataset.to_tensor(idx, padding_length)
             for name, tensor in x.items():
@@ -65,8 +61,7 @@ class Batch(object):
             for name, tensor in y.items():
                 batch_y[name].append(tensor)
-        batch_origin_length = {}
-        # combine instances into a batch
+        # combine instances to form a batch
         for batch in (batch_x, batch_y):
             for name, tensor_list in batch.items():
                 if self.use_cuda:
@@ -74,14 +69,6 @@ class Batch(object):
                 else:
                     batch[name] = torch.stack(tensor_list, dim=0)
-        # add origin lengths in batch_x
-        for name, tensor in batch_x.items():
-            if self.use_cuda:
-                batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
-            else:
-                batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
-        batch_x.update(batch_origin_length)
         self.curidx = endidx
         return batch_x, batch_y
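For orientation (this note and sketch are editorial additions, not part of the diff): with the origin-length injection removed from `Batch`, iteration now yields only the padded field tensors, and origin lengths arrive as ordinary fields created by the `DataSet` subclasses further down. A minimal, hedged sketch, assuming a `dataset` built by one of those subclasses and the `Batch`/sampler signatures shown in this diff:

```python
# Sketch only: `dataset` is assumed to be an already-converted DataSet object.
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler  # assumed import path

data_iterator = Batch(dataset, batch_size=8, sampler=SequentialSampler(), use_cuda=False)
for batch_x, batch_y in data_iterator:
    # batch_x["word_seq"] is a LongTensor of shape [batch_size, padding_length];
    # "word_seq_origin_len" is now a regular batch_x field added by the DataSet,
    # not injected here by Batch.
    print(batch_x["word_seq"].shape)
```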
@@ -1,7 +1,12 @@
+import random
+import sys
 from collections import defaultdict
+from copy import deepcopy
-from fastNLP.core.field import TextField
+from fastNLP.core.field import TextField, LabelField
 from fastNLP.core.instance import Instance
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader
 
 
 def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
@@ -65,17 +70,19 @@ class DataSet(list):
     """A DataSet object is a list of Instance objects.
     """
 
-    def __init__(self, name="", instances=None):
+    def __init__(self, name="", instances=None, load_func=None):
         """
         :param name: str, the name of the dataset. (default: "")
         :param instances: list of Instance objects. (default: None)
+        :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
         """
         list.__init__([])
         self.name = name
         if instances is not None:
            self.extend(instances)
+        self.data_set_load_func = load_func
 
     def index_all(self, vocab):
         for ins in self:
@@ -109,3 +116,191 @@ class DataSet(list):
             for field_name, field_length in ins.get_length().items():
                 lengths[field_name].append(field_length)
         return lengths
+
+    def convert(self, data):
+        """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
+        raise NotImplementedError
+
+    def convert_with_vocabs(self, data, vocabs):
+        """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
+        raise NotImplementedError
+
+    def convert_for_infer(self, data, vocabs):
+        """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""
+
+    def load(self, data_path, vocabs=None, infer=False):
+        """Load data from the given files.
+
+        :param data_path: str, the path to the data
+        :param infer: bool. If True, there is no label information in the data. Default: False.
+        :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
+        """
+        raw_data = self.data_set_load_func(data_path)
+        if infer is True:
+            self.convert_for_infer(raw_data, vocabs)
+        else:
+            if vocabs is not None:
+                self.convert_with_vocabs(raw_data, vocabs)
+            else:
+                self.convert(raw_data)
+
+    def load_raw(self, raw_data, vocabs):
+        """Load raw data without loader. Used in FastNLP class.
+
+        :param raw_data:
+        :param vocabs:
+        :return:
+        """
+        self.convert_for_infer(raw_data, vocabs)
+
+    def split(self, ratio, shuffle=True):
+        """Train/dev splitting
+
+        :param ratio: float, between 0 and 1. The ratio of development set in origin data set.
+        :param shuffle: bool, whether shuffle the data set before splitting. Default: True.
+        :return train_set: a DataSet object, representing the training set
+                dev_set: a DataSet object, representing the validation set
+        """
+        assert 0 < ratio < 1
+        if shuffle:
+            random.shuffle(self)
+        split_idx = int(len(self) * ratio)
+        dev_set = deepcopy(self)
+        train_set = deepcopy(self)
+        del train_set[:split_idx]
+        del dev_set[split_idx:]
+        return train_set, dev_set
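As a quick orientation (editorial addition, not part of the diff), a hedged sketch of the new loading flow; the subclass used is `SeqLabelDataSet` from further down in this diff, and the file path is hypothetical:

```python
# Sketch only. load() dispatches on its arguments:
#   infer=True    -> convert_for_infer(raw, vocabs)   (no label fields)
#   vocabs given  -> convert_with_vocabs(raw, vocabs) (labels, reuse vocabularies)
#   neither       -> convert(raw)                     (labels, build new vocabularies)
data_set = SeqLabelDataSet()                 # default load_func is POSDataSetLoader().load
data_set.load("data_for_tests/people.txt")   # hypothetical path; builds word/label vocabs
train_set, dev_set = data_set.split(0.1)     # roughly 10% of instances become the dev set
```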
+
+
+class SeqLabelDataSet(DataSet):
+    def __init__(self, instances=None, load_func=POSDataSetLoader().load):
+        super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
+        self.word_vocab = Vocabulary()
+        self.label_vocab = Vocabulary()
+
+    def convert(self, data):
+        """Convert lists of strings into Instances with Fields.
+
+        :param data: 3-level lists. Entries are strings.
+        """
+        bar = ProgressBar(total=len(data))
+        for example in data:
+            word_seq, label_seq = example[0], example[1]
+            # list, list
+            self.word_vocab.update(word_seq)
+            self.label_vocab.update(label_seq)
+            x = TextField(word_seq, is_target=False)
+            x_len = LabelField(len(word_seq), is_target=False)
+            y = TextField(label_seq, is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("truth", y)
+            instance.add_field("word_seq_origin_len", x_len)
+            self.append(instance)
+            bar.move()
+        self.index_field("word_seq", self.word_vocab)
+        self.index_field("truth", self.label_vocab)
+        # no need to index "word_seq_origin_len"
+
+    def convert_with_vocabs(self, data, vocabs):
+        for example in data:
+            word_seq, label_seq = example[0], example[1]
+            # list, list
+            x = TextField(word_seq, is_target=False)
+            x_len = LabelField(len(word_seq), is_target=False)
+            y = TextField(label_seq, is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("truth", y)
+            instance.add_field("word_seq_origin_len", x_len)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
+        self.index_field("truth", vocabs["label_vocab"])
+        # no need to index "word_seq_origin_len"
+
+    def convert_for_infer(self, data, vocabs):
+        for word_seq in data:
+            # list
+            x = TextField(word_seq, is_target=False)
+            x_len = LabelField(len(word_seq), is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("word_seq_origin_len", x_len)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
+        # no need to index "word_seq_origin_len"
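Building on the sketch above (again an editorial addition, not from the diff), inference-time conversion reuses the vocabularies built during training and creates no "truth" field; the sentence is made up:

```python
# Sketch: vocabularies from a converted training set are reused via load_raw(),
# which goes through convert_for_infer() and therefore adds no label fields.
vocabs = {"word_vocab": data_set.word_vocab, "label_vocab": data_set.label_vocab}

infer_set = SeqLabelDataSet()
infer_set.load_raw([["fastNLP", "is", "an", "NLP", "library"]], vocabs)
```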
+
+
+class TextClassifyDataSet(DataSet):
+    def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
+        super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
+        self.word_vocab = Vocabulary()
+        self.label_vocab = Vocabulary(need_default=False)
+
+    def convert(self, data):
+        for example in data:
+            word_seq, label = example[0], example[1]
+            # list, str
+            self.word_vocab.update(word_seq)
+            self.label_vocab.update(label)
+            x = TextField(word_seq, is_target=False)
+            y = LabelField(label, is_target=True)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("label", y)
+            self.append(instance)
+        self.index_field("word_seq", self.word_vocab)
+        self.index_field("label", self.label_vocab)
+
+    def convert_with_vocabs(self, data, vocabs):
+        for example in data:
+            word_seq, label = example[0], example[1]
+            # list, str
+            x = TextField(word_seq, is_target=False)
+            y = LabelField(label, is_target=True)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            instance.add_field("label", y)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
+        self.index_field("label", vocabs["label_vocab"])
+
+    def convert_for_infer(self, data, vocabs):
+        for word_seq in data:
+            # list
+            x = TextField(word_seq, is_target=False)
+            instance = Instance()
+            instance.add_field("word_seq", x)
+            self.append(instance)
+        self.index_field("word_seq", vocabs["word_vocab"])
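The classification counterpart works the same way (editorial sketch, not from the diff); the path reuses the test file referenced in the README example above:

```python
# Sketch: builds a "word_seq" input field and a "label" target field; the label
# vocabulary is created with need_default=False, so it holds labels only.
clf_set = TextClassifyDataSet()                        # default load_func is ClassDataSetLoader().load
clf_set.load("test/data_for_tests/text_classify.txt")  # path taken from the README example
train_set, dev_set = clf_set.split(0.2)
```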
+
+
+def change_field_is_target(data_set, field_name, new_target):
+    """Change the flag of is_target in a field.
+
+    :param data_set: a DataSet object
+    :param field_name: str, the name of the field
+    :param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither.
+    """
+    for inst in data_set:
+        inst.fields[field_name].is_target = new_target
+
+
+class ProgressBar:
+
+    def __init__(self, count=0, total=0, width=100):
+        self.count = count
+        self.total = total
+        self.width = width
+
+    def move(self):
+        self.count += 1
+        progress = self.width * self.count // self.total
+        sys.stdout.write('{0:3}/{1:3}: '.format(self.count, self.total))
+        sys.stdout.write('#' * progress + '-' * (self.width - progress) + '\r')
+        if progress == self.width:
+            sys.stdout.write('\n')
+        sys.stdout.flush()
@@ -59,6 +59,9 @@ class TextField(Field):
 
 class LabelField(Field):
+    """The Field representing a single label. Can be a string or integer.
+    """
+
     def __init__(self, label, is_target=True):
         super(LabelField, self).__init__(is_target)
         self.label = label
@@ -73,13 +76,14 @@ class LabelField(Field):
     def index(self, vocab):
         if self._index is None:
-            self._index = vocab[self.label]
+            if isinstance(self.label, str):
+                self._index = vocab[self.label]
         return self._index
 
     def to_tensor(self, padding_length):
         if self._index is None:
             if isinstance(self.label, int):
-                return torch.LongTensor([self.label])
+                return torch.tensor(self.label)
             elif isinstance(self.label, str):
                 raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
             else:
@@ -46,8 +46,11 @@ class Instance(object):
         tensor_x = {}
         tensor_y = {}
         for name, field in self.fields.items():
-            if field.is_target:
+            if field.is_target is True:
                 tensor_y[name] = field.to_tensor(padding_length[name])
-            else:
+            elif field.is_target is False:
                 tensor_x[name] = field.to_tensor(padding_length[name])
+            else:
+                # is_target is None
+                continue
         return tensor_x, tensor_y
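To make the three-valued flag concrete (editorial sketch): `is_target=True` sends a field to batch_y, `False` sends it to batch_x, and `None` now excludes it from both. The `change_field_is_target` helper added in dataset.py above flips the flag for a whole data set; `dev_set` is assumed to be a `SeqLabelDataSet` built as in the earlier sketches:

```python
from fastNLP.core.dataset import change_field_is_target  # assumed import path

change_field_is_target(dev_set, "truth", True)   # gold tags go to batch_y for evaluation
change_field_is_target(dev_set, "truth", None)   # hide them again, e.g. before prediction
```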
@@ -33,10 +33,25 @@ class Loss(object):
         """Given a name of a loss function, return it from PyTorch.
 
         :param loss_name: str, the name of a loss function
+
+            - cross_entropy: combines log softmax and nll loss in a single function.
+            - nll: negative log likelihood
+
         :return loss: a PyTorch loss
         """
+
+        class InnerCrossEntropy:
+            """A simple wrapper to guarantee input shapes."""
+
+            def __init__(self):
+                self.f = torch.nn.CrossEntropyLoss()
+
+            def __call__(self, predict, truth):
+                truth = truth.view(-1, )
+                return self.f(predict, truth)
+
         if loss_name == "cross_entropy":
-            return torch.nn.CrossEntropyLoss()
+            return InnerCrossEntropy()
         elif loss_name == 'nll':
             return torch.nn.NLLLoss()
         else:
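A short, hedged illustration (not in the diff) of what the wrapper buys: `torch.nn.CrossEntropyLoss` expects 1-D class indices, while labels stacked from `LabelField` tensors may carry an extra trailing dimension, so `InnerCrossEntropy` flattens the truth first:

```python
import torch

# Sketch: predictions are [batch, num_classes]; a [batch, 1] label tensor must
# be flattened before CrossEntropyLoss will accept it.
predict = torch.randn(4, 3)
truth = torch.tensor([[0], [2], [1], [0]])

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(predict, truth.view(-1))  # the reshaping InnerCrossEntropy performs
print(loss.item())
```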
@@ -4,6 +4,59 @@ import numpy as np
 import torch
 
 
+class Evaluator(object):
+    def __init__(self):
+        pass
+
+    def __call__(self, predict, truth):
+        """
+        :param predict: list of tensors, the network outputs from all batches.
+        :param truth: list of dict, the ground truths from all batch_y.
+        :return:
+        """
+        raise NotImplementedError
+
+
+class ClassifyEvaluator(Evaluator):
+    def __init__(self):
+        super(ClassifyEvaluator, self).__init__()
+
+    def __call__(self, predict, truth):
+        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
+        y_prob = torch.cat(y_prob, dim=0)
+        y_pred = torch.argmax(y_prob, dim=-1)
+        y_true = torch.cat(truth, dim=0)
+        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
+        return {"accuracy": acc}
+
+
+class SeqLabelEvaluator(Evaluator):
+    def __init__(self):
+        super(SeqLabelEvaluator, self).__init__()
+
+    def __call__(self, predict, truth):
+        """
+        :param predict: list of List, the network outputs from all batches.
+        :param truth: list of dict, the ground truths from all batch_y.
+        :return accuracy:
+        """
+        truth = [item["truth"] for item in truth]
+        total_correct, total_count = 0., 0.
+        for x, y in zip(predict, truth):
+            x = torch.Tensor(x)
+            y = y.to(x)  # make sure they are in the same device
+            mask = x.ge(1).float()
+            # correct = torch.sum(x * mask.float() == (y * mask.long()).float())
+            correct = torch.sum(x * mask == y * mask)
+            correct -= torch.sum(x.le(0))
+            total_correct += float(correct)
+            total_count += float(torch.sum(mask))
+        accuracy = total_correct / total_count
+        return {"accuracy": float(accuracy)}
+
+
 def _conver_numpy(x):
     """convert input data to numpy array
@@ -16,43 +16,42 @@ class Predictor(object):
     Currently, Predictor does not support GPU.
     """
 
-    def __init__(self, pickle_path, task):
+    def __init__(self, pickle_path, post_processor):
         """
         :param pickle_path: str, the path to the pickle files.
-        :param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
+        :param post_processor: a function or callable object, that takes list of batch outputs as input
         """
         self.batch_size = 1
         self.batch_output = []
         self.pickle_path = pickle_path
-        self._task = task  # one of ("seq_label", "text_classify")
-        self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
+        self._post_processor = post_processor
+        self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl")
         self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")
 
     def predict(self, network, data):
         """Perform inference using the trained model.
 
         :param network: a PyTorch model (cpu)
-        :param data: list of list of strings, [num_examples, seq_len]
+        :param data: a DataSet object.
         :return: list of list of strings, [num_examples, tag_seq_length]
         """
         # transform strings into DataSet object
-        data = self.prepare_input(data)
+        # data = self.prepare_input(data)
 
         # turn on the testing mode; clean up the history
         self.mode(network, test=True)
-        self.batch_output.clear()
+        batch_output = []
 
         data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
 
         for batch_x, _ in data_iterator:
             with torch.no_grad():
                 prediction = self.data_forward(network, batch_x)
-            self.batch_output.append(prediction)
+            batch_output.append(prediction)
 
-        return self.prepare_output(self.batch_output)
+        return self._post_processor(batch_output, self.label_vocab)
 
     def mode(self, network, test=True):
         if test:
@@ -62,13 +61,7 @@ class Predictor(object):
     def data_forward(self, network, x):
         """Forward through network."""
-        if self._task == "seq_label":
-            y = network(x["word_seq"], x["word_seq_origin_len"])
-            y = network.prediction(y)
-        elif self._task == "text_classify":
-            y = network(x["word_seq"])
-        else:
-            raise NotImplementedError("Unknown task type {}.".format(self._task))
+        y = network(**x)
         return y
 
     def prepare_input(self, data):
@@ -88,39 +81,32 @@ class Predictor(object):
         assert isinstance(data, list)
         return create_dataset_from_lists(data, self.word_vocab, has_target=False)
 
-    def prepare_output(self, data):
-        """Transform list of batch outputs into strings."""
-        if self._task == "seq_label":
-            return self._seq_label_prepare_output(data)
-        elif self._task == "text_classify":
-            return self._text_classify_prepare_output(data)
-        else:
-            raise NotImplementedError("Unknown task type {}".format(self._task))
-
-    def _seq_label_prepare_output(self, batch_outputs):
-        results = []
-        for batch in batch_outputs:
-            for example in np.array(batch):
-                results.append([self.label_vocab.to_word(int(x)) for x in example])
-        return results
-
-    def _text_classify_prepare_output(self, batch_outputs):
-        results = []
-        for batch_out in batch_outputs:
-            idx = np.argmax(batch_out.detach().numpy(), axis=-1)
-            results.extend([self.label_vocab.to_word(i) for i in idx])
-        return results
-
 
 class SeqLabelInfer(Predictor):
     def __init__(self, pickle_path):
         print(
-            "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
-        super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")
+            "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
+        super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor)
 
 
 class ClassificationInfer(Predictor):
     def __init__(self, pickle_path):
         print(
-            "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
-        super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
+            "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
+        super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor)
+
+
+def seq_label_post_processor(batch_outputs, label_vocab):
+    results = []
+    for batch in batch_outputs:
+        for example in np.array(batch):
+            results.append([label_vocab.to_word(int(x)) for x in example])
+    return results
+
+
+def text_classify_post_processor(batch_outputs, label_vocab):
+    results = []
+    for batch_out in batch_outputs:
+        idx = np.argmax(batch_out.detach().numpy(), axis=-1)
+        results.extend([label_vocab.to_word(i) for i in idx])
+    return results
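A hedged sketch (editorial) of the new injection style: the task string is gone, and output formatting is supplied as a post-processor. The pickle directory contents, the trained `model`, and the `infer_set` DataSet are assumed to exist:

```python
from fastNLP.core.predictor import Predictor, seq_label_post_processor  # names from this diff

predictor = Predictor(pickle_path="./save/", post_processor=seq_label_post_processor)
results = predictor.predict(network=model, data=infer_set)  # list of tag sequences (strings)
```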
@@ -18,6 +18,9 @@ def save_pickle(obj, pickle_path, file_name):
     :param pickle_path: str, the directory where the pickle file is to be saved
     :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
     """
+    if not os.path.exists(pickle_path):
+        os.mkdir(pickle_path)
+        print("make dir {} before saving pickle file".format(pickle_path))
     with open(os.path.join(pickle_path, file_name), "wb") as f:
         _pickle.dump(obj, f)
     print("{} saved in {}".format(file_name, pickle_path))
@@ -66,14 +69,27 @@ class Preprocessor(object):
     Preprocessors will check if those files are already in the directory and will reuse them in future calls.
     """
 
-    def __init__(self, label_is_seq=False):
+    def __init__(self, label_is_seq=False, share_vocab=False, add_char_field=False):
         """
         :param label_is_seq: bool, whether label is a sequence. If True, label vocabulary will preserve
                 several special tokens for sequence processing.
+        :param share_vocab: bool, whether word sequence and label sequence share the same vocabulary. Typically, this
+                is only available when label_is_seq is True. Default: False.
+        :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
         """
+        print("Preprocessor is about to deprecate. Please use DataSet class.")
         self.data_vocab = Vocabulary()
-        self.label_vocab = Vocabulary(need_default=label_is_seq)
+        if label_is_seq is True:
+            if share_vocab is True:
+                self.label_vocab = self.data_vocab
+            else:
+                self.label_vocab = Vocabulary()
+        else:
+            self.label_vocab = Vocabulary(need_default=False)
+        self.character_vocab = Vocabulary(need_default=False)
+        self.add_char_field = add_char_field
 
     @property
     def vocab_size(self):
@@ -83,6 +99,12 @@ class Preprocessor(object):
     def num_classes(self):
         return len(self.label_vocab)
 
+    @property
+    def char_vocab_size(self):
+        if self.character_vocab is None:
+            self.build_char_dict()
+        return len(self.character_vocab)
+
     def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
         """Main pre-processing pipeline.
@@ -96,7 +118,6 @@ class Preprocessor(object):
            If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
            is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
         """
-
         if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
             self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
             self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
@@ -176,6 +197,16 @@ class Preprocessor(object):
                 self.label_vocab.update(label)
         return self.data_vocab, self.label_vocab
 
+    def build_char_dict(self):
+        char_collection = set()
+        for word in self.data_vocab.word2idx:
+            if len(word) == 0:
+                continue
+            for ch in word:
+                if ch not in char_collection:
+                    char_collection.add(ch)
+        self.character_vocab.update(list(char_collection))
+
     def build_reverse_dict(self):
         self.data_vocab.build_reverse_vocab()
         self.label_vocab.build_reverse_vocab()
@@ -277,11 +308,3 @@ class ClassPreprocess(Preprocessor):
         print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
         super(ClassPreprocess, self).__init__()
-
-
-if __name__ == "__main__":
-    p = Preprocessor()
-    train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
-                      [["You", "are", "pretty", "."], "1"]
-                      ]
-    training_set = p.run(train_dev_data)
-    print(training_set)
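For reference (editorial sketch, not part of the diff), the character-vocabulary additions can be exercised with the same toy data the removed `__main__` demo used:

```python
from fastNLP.core.preprocess import Preprocessor  # module path matches the README example above

p = Preprocessor(label_is_seq=False, add_char_field=False)
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
                  [["You", "are", "pretty", "."], "1"]]
training_set = p.run(train_dev_data)  # likely also writes vocabulary pickles under pickle_path ("./" by default)
p.build_char_dict()                   # characters are drawn from the entries of data_vocab
print(p.char_vocab_size)
```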
@@ -1,7 +1,7 @@
-import numpy as np
 import torch
 
 from fastNLP.core.batch import Batch
+from fastNLP.core.metrics import Evaluator
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.saver.logger import create_logger
@@ -22,28 +22,23 @@ class Tester(object):
         "kwargs" must have the same type as "default_args" on corresponding keys.
         Otherwise, error will raise.
         """
-        default_args = {"save_output": True,  # collect outputs of validation set
-                        "save_loss": True,  # collect losses in validation
-                        "save_best_dev": False,  # save best model during validation
-                        "batch_size": 8,
+        default_args = {"batch_size": 8,
                         "use_cuda": False,
                         "pickle_path": "./save/",
                         "model_name": "dev_best_model.pkl",
-                        "print_every_step": 1,
+                        "evaluator": Evaluator()
                         }
         """
         "required_args" is the collection of arguments that users must pass to Trainer explicitly.
         This is used to warn users of essential settings in the training.
         Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
         """
-        required_args = {"task"  # one of ("seq_label", "text_classify")
-                         }
+        required_args = {}
 
         for req_key in required_args:
             if req_key not in kwargs:
                 logger.error("Tester lacks argument {}".format(req_key))
                 raise ValueError("Tester lacks argument {}".format(req_key))
-        self._task = kwargs["task"]
 
         for key in default_args:
             if key in kwargs:
@@ -59,17 +54,13 @@ class Tester(object):
                 pass
         print(default_args)
 
-        self.save_output = default_args["save_output"]
-        self.save_best_dev = default_args["save_best_dev"]
-        self.save_loss = default_args["save_loss"]
         self.batch_size = default_args["batch_size"]
         self.pickle_path = default_args["pickle_path"]
         self.use_cuda = default_args["use_cuda"]
-        self.print_every_step = default_args["print_every_step"]
+        self._evaluator = default_args["evaluator"]
 
         self._model = None
         self.eval_history = []  # evaluation results of all batches
-        self.batch_output = []  # outputs of all batches
 
     def test(self, network, dev_data):
         if torch.cuda.is_available() and self.use_cuda:
@@ -80,26 +71,18 @@ class Tester(object):
         # turn on the testing mode; clean up the history
         self.mode(network, is_test=True)
         self.eval_history.clear()
-        self.batch_output.clear()
+        output_list = []
+        truth_list = []
 
         data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
-        step = 0
 
         for batch_x, batch_y in data_iterator:
             with torch.no_grad():
                 prediction = self.data_forward(network, batch_x)
-            eval_results = self.evaluate(prediction, batch_y)
-            if self.save_output:
-                self.batch_output.append(prediction)
-            if self.save_loss:
-                self.eval_history.append(eval_results)
-            print_output = "[test step {}] {}".format(step, eval_results)
-            logger.info(print_output)
-            if self.print_every_step > 0 and step % self.print_every_step == 0:
-                print(self.make_eval_output(prediction, eval_results))
-            step += 1
+            output_list.append(prediction)
+            truth_list.append(batch_y)
+        eval_results = self.evaluate(output_list, truth_list)
+        print("[tester] {}".format(self.print_eval_results(eval_results)))
 
     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -121,104 +104,30 @@ class Tester(object):
     def evaluate(self, predict, truth):
         """Compute evaluation metrics.
 
-        :param predict: Tensor
-        :param truth: Tensor
+        :param predict: list of Tensor
+        :param truth: list of dict
         :return eval_results: can be anything. It will be stored in self.eval_history
         """
-        if "label_seq" in truth:
-            truth = truth["label_seq"]
-        elif "label" in truth:
-            truth = truth["label"]
-        else:
-            raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
-
-        if self._task == "seq_label":
-            return self._seq_label_evaluate(predict, truth)
-        elif self._task == "text_classify":
-            return self._text_classify_evaluate(predict, truth)
-        else:
-            raise NotImplementedError("Unknown task type {}.".format(self._task))
-
-    def _seq_label_evaluate(self, predict, truth):
-        batch_size, max_len = predict.size(0), predict.size(1)
-        loss = self._model.loss(predict, truth) / batch_size
-        prediction = self._model.prediction(predict)
-        # pad prediction to equal length
-        for pred in prediction:
-            if len(pred) < max_len:
-                pred += [0] * (max_len - len(pred))
-        results = torch.Tensor(prediction).view(-1, )
-        # make sure "results" is in the same device as "truth"
-        results = results.to(truth)
-        accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
-        return [float(loss), float(accuracy)]
-
-    def _text_classify_evaluate(self, y_logit, y_true):
-        y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
-        return [y_prob, y_true]
-
-    @property
-    def metrics(self):
-        """Compute and return metrics.
-        Use self.eval_history to compute metrics over the whole dev set.
-        Please refer to metrics.py for common metric functions.
-
-        :return : variable number of outputs
-        """
-        if self._task == "seq_label":
-            return self._seq_label_metrics
-        elif self._task == "text_classify":
-            return self._text_classify_metrics
-        else:
-            raise NotImplementedError("Unknown task type {}.".format(self._task))
-
-    @property
-    def _seq_label_metrics(self):
-        batch_loss = np.mean([x[0] for x in self.eval_history])
-        batch_accuracy = np.mean([x[1] for x in self.eval_history])
-        return batch_loss, batch_accuracy
-
-    @property
-    def _text_classify_metrics(self):
-        y_prob, y_true = zip(*self.eval_history)
-        y_prob = torch.cat(y_prob, dim=0)
-        y_pred = torch.argmax(y_prob, dim=-1)
-        y_true = torch.cat(y_true, dim=0)
-        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
-        return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
-
-    def show_metrics(self):
-        """Customize evaluation outputs in Trainer.
-        Called by Trainer to print evaluation results on dev set during training.
-        Use self.metrics to fetch available metrics.
-
-        :return print_str: str
-        """
-        loss, accuracy = self.metrics
-        return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
-
-    def make_eval_output(self, predictions, eval_results):
-        """Customize Tester outputs.
-
-        :param predictions: Tensor
-        :param eval_results: Tensor
-        """
-        return self.show_metrics()
+        return self._evaluator(predict, truth)
+
+    def print_eval_results(self, results):
+        """Override this method to support more print formats.
+
+        :param results: dict, (str: float) is (metrics name: value)
+        :return: str, to be printed.
+        """
+        return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
 
 
 class SeqLabelTester(Tester):
     def __init__(self, **test_args):
-        test_args.update({"task": "seq_label"})
         print(
-            "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
+            "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.")
         super(SeqLabelTester, self).__init__(**test_args)
 
 
 class ClassificationTester(Tester):
     def __init__(self, **test_args):
-        test_args.update({"task": "text_classify"})
         print(
-            "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
+            "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
         super(ClassificationTester, self).__init__(**test_args)
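Since the task argument is no longer required, a Tester is now configured purely through the keys in `default_args` (editorial sketch; the `model` and `dev_data` objects are assumed to exist):

```python
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.tester import Tester

tester = Tester(batch_size=8, use_cuda=False, pickle_path="./save/",
                evaluator=SeqLabelEvaluator())
tester.test(model, dev_data)  # prints something like "[tester] accuracy=0.91"
```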
| @@ -1,4 +1,3 @@ | |||||
| import copy | |||||
| import os | import os | ||||
| import time | import time | ||||
| from datetime import timedelta | from datetime import timedelta | ||||
| @@ -8,6 +7,7 @@ from tensorboardX import SummaryWriter | |||||
| from fastNLP.core.batch import Batch | from fastNLP.core.batch import Batch | ||||
| from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
| from fastNLP.core.metrics import Evaluator | |||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
| from fastNLP.core.tester import SeqLabelTester, ClassificationTester | from fastNLP.core.tester import SeqLabelTester, ClassificationTester | ||||
| @@ -43,21 +43,20 @@ class Trainer(object): | |||||
| default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | ||||
| "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | "save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1, | ||||
| "loss": Loss(None), # used to pass type check | "loss": Loss(None), # used to pass type check | ||||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0) | |||||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | |||||
| "evaluator": Evaluator() | |||||
| } | } | ||||
| """ | """ | ||||
| "required_args" is the collection of arguments that users must pass to Trainer explicitly. | "required_args" is the collection of arguments that users must pass to Trainer explicitly. | ||||
| This is used to warn users of essential settings in the training. | This is used to warn users of essential settings in the training. | ||||
| Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | ||||
| """ | """ | ||||
| required_args = {"task" # one of ("seq_label", "text_classify") | |||||
| } | |||||
| required_args = {} | |||||
| for req_key in required_args: | for req_key in required_args: | ||||
| if req_key not in kwargs: | if req_key not in kwargs: | ||||
| logger.error("Trainer lacks argument {}".format(req_key)) | logger.error("Trainer lacks argument {}".format(req_key)) | ||||
| raise ValueError("Trainer lacks argument {}".format(req_key)) | raise ValueError("Trainer lacks argument {}".format(req_key)) | ||||
| self._task = kwargs["task"] | |||||
| for key in default_args: | for key in default_args: | ||||
| if key in kwargs: | if key in kwargs: | ||||
| @@ -86,6 +85,7 @@ class Trainer(object): | |||||
| self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | self._loss_func = default_args["loss"].get() # return a pytorch loss function or None | ||||
| self._optimizer = None | self._optimizer = None | ||||
| self._optimizer_proto = default_args["optimizer"] | self._optimizer_proto = default_args["optimizer"] | ||||
| self._evaluator = default_args["evaluator"] | |||||
| self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | ||||
| self._graph_summaried = False | self._graph_summaried = False | ||||
| self._best_accuracy = 0.0 | self._best_accuracy = 0.0 | ||||
| @@ -106,9 +106,8 @@ class Trainer(object): | |||||
| # define Tester over dev data | # define Tester over dev data | ||||
| if self.validate: | if self.validate: | ||||
| default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||||
| "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||||
| "use_cuda": self.use_cuda, "print_every_step": 0} | |||||
| default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||||
| "use_cuda": self.use_cuda, "evaluator": self._evaluator} | |||||
| validator = self._create_validator(default_valid_args) | validator = self._create_validator(default_valid_args) | ||||
| logger.info("validator defined as {}".format(str(validator))) | logger.info("validator defined as {}".format(str(validator))) | ||||
| @@ -142,15 +141,6 @@ class Trainer(object): | |||||
| logger.info("validation started") | logger.info("validation started") | ||||
| validator.test(network, dev_data) | validator.test(network, dev_data) | ||||
| if self.save_best_dev and self.best_eval_result(validator): | |||||
| self.save_model(network, self.model_name) | |||||
| print("Saved better model selected by validation.") | |||||
| logger.info("Saved better model selected by validation.") | |||||
| valid_results = validator.show_metrics() | |||||
| print("[epoch {}] {}".format(epoch, valid_results)) | |||||
| logger.info("[epoch {}] {}".format(epoch, valid_results)) | |||||
| def _train_step(self, data_iterator, network, **kwargs): | def _train_step(self, data_iterator, network, **kwargs): | ||||
| """Training process in one epoch. | """Training process in one epoch. | ||||
| @@ -178,31 +168,6 @@ class Trainer(object): | |||||
| logger.info(print_output) | logger.info(print_output) | ||||
| step += 1 | step += 1 | ||||
| def cross_validate(self, network, train_data_cv, dev_data_cv): | |||||
| """Training with cross validation. | |||||
| :param network: the model | |||||
| :param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?] | |||||
| :param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?] | |||||
| """ | |||||
| if len(train_data_cv) != len(dev_data_cv): | |||||
| logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv), | |||||
| len(dev_data_cv))) | |||||
| raise RuntimeError("the number of folds in train and dev data unequals") | |||||
| if self.validate is False: | |||||
| logger.warn("Cross validation requires self.validate to be True. Please turn it on. ") | |||||
| print("[warning] Cross validation requires self.validate to be True. Please turn it on. ") | |||||
| self.validate = True | |||||
| n_fold = len(train_data_cv) | |||||
| logger.info("perform {} folds cross validation.".format(n_fold)) | |||||
| for i in range(n_fold): | |||||
| print("CV:", i) | |||||
| logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold)) | |||||
| network_copy = copy.deepcopy(network) | |||||
| self.train(network_copy, train_data_cv[i], dev_data_cv[i]) | |||||
| def mode(self, model, is_test=False): | def mode(self, model, is_test=False): | ||||
| """Train mode or Test mode. This is for PyTorch currently. | """Train mode or Test mode. This is for PyTorch currently. | ||||
| @@ -229,18 +194,9 @@ class Trainer(object): | |||||
| self._optimizer.step() | self._optimizer.step() | ||||
| def data_forward(self, network, x): | def data_forward(self, network, x): | ||||
| if self._task == "seq_label": | |||||
| y = network(x["word_seq"], x["word_seq_origin_len"]) | |||||
| elif self._task == "text_classify": | |||||
| y = network(x["word_seq"]) | |||||
| else: | |||||
| raise NotImplementedError("Unknown task type {}.".format(self._task)) | |||||
| y = network(**x) | |||||
| if not self._graph_summaried: | if not self._graph_summaried: | ||||
| if self._task == "seq_label": | |||||
| self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False) | |||||
| elif self._task == "text_classify": | |||||
| self._summary_writer.add_graph(network, x["word_seq"], verbose=False) | |||||
| # self._summary_writer.add_graph(network, x, verbose=False) | |||||
| self._graph_summaried = True | self._graph_summaried = True | ||||
| return y | return y | ||||
| @@ -261,13 +217,9 @@ class Trainer(object): | |||||
| :param truth: ground truth label vector | :param truth: ground truth label vector | ||||
| :return: a scalar | :return: a scalar | ||||
| """ | """ | ||||
| if "label_seq" in truth: | |||||
| truth = truth["label_seq"] | |||||
| elif "label" in truth: | |||||
| truth = truth["label"] | |||||
| truth = truth.view((-1,)) | |||||
| else: | |||||
| raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | |||||
| if len(truth) > 1: | |||||
| raise NotImplementedError("Not ready to handle multi-labels.") | |||||
| truth = list(truth.values())[0] if len(truth) > 0 else None | |||||
| return self._loss_func(predict, truth) | return self._loss_func(predict, truth) | ||||
| def define_loss(self): | def define_loss(self): | ||||
| @@ -278,8 +230,8 @@ class Trainer(object): | |||||
| These two losses cannot be defined at the same time. | These two losses cannot be defined at the same time. | ||||
| Trainer does not handle loss definition or choose default losses. | Trainer does not handle loss definition or choose default losses. | ||||
| """ | """ | ||||
| if hasattr(self._model, "loss") and self._loss_func is not None: | |||||
| raise ValueError("Both the model and Trainer define loss. Please take out your loss.") | |||||
| # if hasattr(self._model, "loss") and self._loss_func is not None: | |||||
| # raise ValueError("Both the model and Trainer define loss. Please take out your loss.") | |||||
| if hasattr(self._model, "loss"): | if hasattr(self._model, "loss"): | ||||
| self._loss_func = self._model.loss | self._loss_func = self._model.loss | ||||
| @@ -322,9 +274,8 @@ class SeqLabelTrainer(Trainer): | |||||
| """ | """ | ||||
| def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
| kwargs.update({"task": "seq_label"}) | |||||
| print( | print( | ||||
| "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.") | |||||
| "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.") | |||||
| super(SeqLabelTrainer, self).__init__(**kwargs) | super(SeqLabelTrainer, self).__init__(**kwargs) | ||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
| @@ -335,9 +286,8 @@ class ClassificationTrainer(Trainer): | |||||
| """Trainer for text classification.""" | """Trainer for text classification.""" | ||||
| def __init__(self, **train_args): | def __init__(self, **train_args): | ||||
| train_args.update({"task": "text_classify"}) | |||||
| print( | print( | ||||
| "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.") | |||||
| "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.") | |||||
| super(ClassificationTrainer, self).__init__(**train_args) | super(ClassificationTrainer, self).__init__(**train_args) | ||||
| def _create_validator(self, valid_args): | def _create_validator(self, valid_args): | ||||
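Both deprecated subclasses now simply forward their keyword arguments to `Trainer`, and the warning points users at `Trainer` itself. Constructing the trainer explicitly then looks like the updated CWS reproduction script further down in this diff; the concrete values below are placeholders, not project defaults:

```python
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.trainer import SeqLabelTrainer

# epochs/batch_size/pickle_path are placeholder values for illustration only
trainer = SeqLabelTrainer(epochs=10, batch_size=32, validate=True, use_cuda=False,
                          pickle_path="./save/", save_best_dev=True, print_every_step=10,
                          model_name="trained_model.pkl", evaluator=SeqLabelEvaluator())
```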
| @@ -10,13 +10,15 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1, | |||||
| DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3, | ||||
| DEFAULT_RESERVED_LABEL[2]: 4} | DEFAULT_RESERVED_LABEL[2]: 4} | ||||
| def isiterable(p_object): | def isiterable(p_object): | ||||
| try: | try: | ||||
| it = iter(p_object) | it = iter(p_object) | ||||
| except TypeError: | |||||
| except TypeError: | |||||
| return False | return False | ||||
| return True | return True | ||||
| class Vocabulary(object): | class Vocabulary(object): | ||||
| """Use for word and index one to one mapping | """Use for word and index one to one mapping | ||||
| @@ -28,9 +30,11 @@ class Vocabulary(object): | |||||
| vocab["word"] | vocab["word"] | ||||
| vocab.to_word(5) | vocab.to_word(5) | ||||
| """ | """ | ||||
| def __init__(self, need_default=True): | def __init__(self, need_default=True): | ||||
| """ | """ | ||||
| :param bool need_default: set if the Vocabulary has default labels reserved. | |||||
| :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||||
| """ | """ | ||||
| if need_default: | if need_default: | ||||
| self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | ||||
| @@ -50,20 +54,19 @@ class Vocabulary(object): | |||||
| def update(self, word): | def update(self, word): | ||||
| """add word or list of words into Vocabulary | """add word or list of words into Vocabulary | ||||
| :param word: a list of str or str | |||||
| :param word: a list of string or a single string | |||||
| """ | """ | ||||
| if not isinstance(word, str) and isiterable(word): | if not isinstance(word, str) and isiterable(word): | ||||
| # it's a nested list | |||||
| # it's a nested list | |||||
| for w in word: | for w in word: | ||||
| self.update(w) | self.update(w) | ||||
| else: | else: | ||||
| # it's a word to be added | |||||
| # it's a word to be added | |||||
| if word not in self.word2idx: | if word not in self.word2idx: | ||||
| self.word2idx[word] = len(self) | self.word2idx[word] = len(self) | ||||
| if self.idx2word is not None: | if self.idx2word is not None: | ||||
| self.idx2word = None | self.idx2word = None | ||||
| def __getitem__(self, w): | def __getitem__(self, w): | ||||
| """To support usage like:: | """To support usage like:: | ||||
| @@ -81,12 +84,12 @@ class Vocabulary(object): | |||||
| :param str w: | :param str w: | ||||
| """ | """ | ||||
| return self[w] | return self[w] | ||||
| def unknown_idx(self): | def unknown_idx(self): | ||||
| if self.unknown_label is None: | |||||
| if self.unknown_label is None: | |||||
| return None | return None | ||||
| return self.word2idx[self.unknown_label] | return self.word2idx[self.unknown_label] | ||||
| def padding_idx(self): | def padding_idx(self): | ||||
| if self.padding_label is None: | if self.padding_label is None: | ||||
| return None | return None | ||||
| @@ -95,8 +98,8 @@ class Vocabulary(object): | |||||
| def build_reverse_vocab(self): | def build_reverse_vocab(self): | ||||
| """build 'index to word' dict based on 'word to index' dict | """build 'index to word' dict based on 'word to index' dict | ||||
| """ | """ | ||||
| self.idx2word = {self.word2idx[w] : w for w in self.word2idx} | |||||
| self.idx2word = {self.word2idx[w]: w for w in self.word2idx} | |||||
| def to_word(self, idx): | def to_word(self, idx): | ||||
| """given a word's index, return the word itself | """given a word's index, return the word itself | ||||
| @@ -105,7 +108,7 @@ class Vocabulary(object): | |||||
| if self.idx2word is None: | if self.idx2word is None: | ||||
| self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
| return self.idx2word[idx] | return self.idx2word[idx] | ||||
| def __getstate__(self): | def __getstate__(self): | ||||
| """use to prepare data for pickle | """use to prepare data for pickle | ||||
| """ | """ | ||||
| @@ -113,12 +116,9 @@ class Vocabulary(object): | |||||
| # no need to pickle idx2word as it can be constructed from word2idx | # no need to pickle idx2word as it can be constructed from word2idx | ||||
| del state['idx2word'] | del state['idx2word'] | ||||
| return state | return state | ||||
| def __setstate__(self, state): | def __setstate__(self, state): | ||||
| """use to restore state from pickle | """use to restore state from pickle | ||||
| """ | """ | ||||
| self.__dict__.update(state) | self.__dict__.update(state) | ||||
| self.idx2word = None | self.idx2word = None | ||||
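Taken together, the `Vocabulary` changes give a small two-way mapping API: `update` accepts a single word, a list, or nested lists; `__getitem__` maps words to indices; `to_word` rebuilds the reverse map lazily. A usage sketch, assuming the class lives at `fastNLP.core.vocabulary`:

```python
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary(need_default=True)         # reserves padding, unknown and three other default labels
vocab.update("apple")                         # a single word
vocab.update(["banana", "cherry", "banana"])  # a list; duplicates are ignored
vocab.update([["deep", "learning"]])          # nested lists are flattened recursively

print(len(vocab))                             # 10 = 5 reserved defaults + 5 distinct words
print(vocab["banana"])                        # word -> index
print(vocab.to_word(vocab["banana"]))         # index -> word; the reverse vocab is built lazily
print(vocab.padding_idx(), vocab.unknown_idx())
```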
| @@ -1,5 +1,6 @@ | |||||
| import os | import os | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||||
| from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer | ||||
| from fastNLP.core.preprocess import load_pickle | from fastNLP.core.preprocess import load_pickle | ||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| @@ -71,11 +72,13 @@ class FastNLP(object): | |||||
| :param model_dir: this directory should contain the following files: | :param model_dir: this directory should contain the following files: | ||||
| 1. a trained model | 1. a trained model | ||||
| 2. a config file, which is a fastNLP's configuration. | 2. a config file, which is a fastNLP's configuration. | ||||
| 3. a Vocab file, which is a pickle object of a Vocab instance. | |||||
| 3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs. | |||||
| """ | """ | ||||
| self.model_dir = model_dir | self.model_dir = model_dir | ||||
| self.model = None | self.model = None | ||||
| self.infer_type = None # "seq_label"/"text_class" | self.infer_type = None # "seq_label"/"text_class" | ||||
| self.word_vocab = None | |||||
| self.label_vocab = None | |||||
| def load(self, model_name, config_file="config", section_name="model"): | def load(self, model_name, config_file="config", section_name="model"): | ||||
| """ | """ | ||||
| @@ -100,10 +103,10 @@ class FastNLP(object): | |||||
| print("Restore model hyper-parameters {}".format(str(model_args.data))) | print("Restore model hyper-parameters {}".format(str(model_args.data))) | ||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||||
| model_args["vocab_size"] = len(word_vocab) | |||||
| label_vocab = load_pickle(self.model_dir, "class2id.pkl") | |||||
| model_args["num_classes"] = len(label_vocab) | |||||
| self.word_vocab = load_pickle(self.model_dir, "word2id.pkl") | |||||
| model_args["vocab_size"] = len(self.word_vocab) | |||||
| self.label_vocab = load_pickle(self.model_dir, "label2id.pkl") | |||||
| model_args["num_classes"] = len(self.label_vocab) | |||||
| # Construct the model | # Construct the model | ||||
| model = model_class(model_args) | model = model_class(model_args) | ||||
| @@ -130,8 +133,11 @@ class FastNLP(object): | |||||
| # tokenize: list of string ---> 2-D list of string | # tokenize: list of string ---> 2-D list of string | ||||
| infer_input = self.tokenize(raw_input, language="zh") | infer_input = self.tokenize(raw_input, language="zh") | ||||
| # 2-D list of string ---> 2-D list of tags | |||||
| results = infer.predict(self.model, infer_input) | |||||
| # create DataSet: 2-D list of strings ----> DataSet | |||||
| infer_data = self._create_data_set(infer_input) | |||||
| # DataSet ---> 2-D list of tags | |||||
| results = infer.predict(self.model, infer_data) | |||||
| # 2-D list of tags ---> list of final answers | # 2-D list of tags ---> list of final answers | ||||
| outputs = self._make_output(results, infer_input) | outputs = self._make_output(results, infer_input) | ||||
| @@ -154,6 +160,11 @@ class FastNLP(object): | |||||
| return module | return module | ||||
| def _create_inference(self, model_dir): | def _create_inference(self, model_dir): | ||||
| """Specify which task to perform. | |||||
| :param model_dir: | |||||
| :return: | |||||
| """ | |||||
| if self.infer_type == "seq_label": | if self.infer_type == "seq_label": | ||||
| return SeqLabelInfer(model_dir) | return SeqLabelInfer(model_dir) | ||||
| elif self.infer_type == "text_class": | elif self.infer_type == "text_class": | ||||
| @@ -161,8 +172,26 @@ class FastNLP(object): | |||||
| else: | else: | ||||
| raise ValueError("fail to create inference instance") | raise ValueError("fail to create inference instance") | ||||
| def _create_data_set(self, infer_input): | |||||
| """Create a DataSet object given the raw inputs. | |||||
| :param infer_input: 2-D lists of strings | |||||
| :return data_set: a DataSet object | |||||
| """ | |||||
| if self.infer_type == "seq_label": | |||||
| data_set = SeqLabelDataSet() | |||||
| data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) | |||||
| return data_set | |||||
| elif self.infer_type == "text_class": | |||||
| data_set = TextClassifyDataSet() | |||||
| data_set.load_raw(infer_input, {"word_vocab": self.word_vocab}) | |||||
| return data_set | |||||
| else: | |||||
| raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type)) | |||||
| def _load(self, model_dir, model_name): | def _load(self, model_dir, model_name): | ||||
| # To do | |||||
| return 0 | return 0 | ||||
| def _download(self, model_name, url): | def _download(self, model_name, url): | ||||
| @@ -172,7 +201,7 @@ class FastNLP(object): | |||||
| :param url: | :param url: | ||||
| """ | """ | ||||
| print("Downloading {} from {}".format(model_name, url)) | print("Downloading {} from {}".format(model_name, url)) | ||||
| # To do | |||||
| # TODO: download model via url | |||||
| def model_exist(self, model_dir): | def model_exist(self, model_dir): | ||||
| """ | """ | ||||
| @@ -1,27 +1,24 @@ | |||||
| class BaseLoader(object): | class BaseLoader(object): | ||||
| """docstring for BaseLoader""" | |||||
| def __init__(self, data_path): | |||||
| def __init__(self): | |||||
| super(BaseLoader, self).__init__() | super(BaseLoader, self).__init__() | ||||
| self.data_path = data_path | |||||
| def load(self): | |||||
| """ | |||||
| :return: string | |||||
| """ | |||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| text = f.read() | |||||
| return text | |||||
| def load_lines(self): | |||||
| with open(self.data_path, "r", encoding="utf=8") as f: | |||||
| @staticmethod | |||||
| def load_lines(data_path): | |||||
| with open(data_path, "r", encoding="utf=8") as f: | |||||
| text = f.readlines() | text = f.readlines() | ||||
| return [line.strip() for line in text] | return [line.strip() for line in text] | ||||
| @staticmethod | |||||
| def load(data_path): | |||||
| with open(data_path, "r", encoding="utf-8") as f: | |||||
| text = f.readlines() | |||||
| return [[word for word in sent.strip()] for sent in text] | |||||
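`BaseLoader` no longer keeps a `data_path` attribute; both readers are static and take the path per call, with `load` splitting every line into characters (convenient for Chinese) and `load_lines` stripping whole lines. A quick sketch against a throwaway file:

```python
from fastNLP.loader.base_loader import BaseLoader

with open("toy.txt", "w", encoding="utf-8") as f:
    f.write("I love NLP\n深度学习\n")

print(BaseLoader.load_lines("toy.txt"))  # ['I love NLP', '深度学习']
print(BaseLoader.load("toy.txt"))        # each line split into single characters
```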
| class ToyLoader0(BaseLoader): | class ToyLoader0(BaseLoader): | ||||
| """ | """ | ||||
| For charLM | |||||
| For CharLM | |||||
| """ | """ | ||||
| def __init__(self, data_path): | def __init__(self, data_path): | ||||
| @@ -8,9 +8,9 @@ from fastNLP.loader.base_loader import BaseLoader | |||||
| class ConfigLoader(BaseLoader): | class ConfigLoader(BaseLoader): | ||||
| """loader for configuration files""" | """loader for configuration files""" | ||||
| def __int__(self, data_name, data_path): | |||||
| super(ConfigLoader, self).__init__(data_path) | |||||
| self.config = self.parse(super(ConfigLoader, self).load()) | |||||
| def __int__(self, data_path): | |||||
| super(ConfigLoader, self).__init__() | |||||
| self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||||
| @staticmethod | @staticmethod | ||||
| def parse(string): | def parse(string): | ||||
| @@ -3,14 +3,17 @@ import os | |||||
| from fastNLP.loader.base_loader import BaseLoader | from fastNLP.loader.base_loader import BaseLoader | ||||
| class DatasetLoader(BaseLoader): | |||||
| class DataSetLoader(BaseLoader): | |||||
| """"loader for data sets""" | """"loader for data sets""" | ||||
| def __init__(self, data_path): | |||||
| super(DatasetLoader, self).__init__(data_path) | |||||
| def __init__(self): | |||||
| super(DataSetLoader, self).__init__() | |||||
| def load(self, path): | |||||
| raise NotImplementedError | |||||
| class POSDatasetLoader(DatasetLoader): | |||||
| class POSDataSetLoader(DataSetLoader): | |||||
| """Dataset Loader for POS Tag datasets. | """Dataset Loader for POS Tag datasets. | ||||
| In these datasets, each line is divided by '\t' | In these datasets, each line is divided by '\t' | ||||
| @@ -31,16 +34,10 @@ class POSDatasetLoader(DatasetLoader): | |||||
| to label5. | to label5. | ||||
| """ | """ | ||||
| def __init__(self, data_path): | |||||
| super(POSDatasetLoader, self).__init__(data_path) | |||||
| def load(self): | |||||
| assert os.path.exists(self.data_path) | |||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| line = f.read() | |||||
| return line | |||||
| def __init__(self): | |||||
| super(POSDataSetLoader, self).__init__() | |||||
| def load_lines(self): | |||||
| def load(self, data_path): | |||||
| """ | """ | ||||
| :return data: three-level list | :return data: three-level list | ||||
| [ | [ | ||||
| @@ -49,7 +46,7 @@ class POSDatasetLoader(DatasetLoader): | |||||
| ... | ... | ||||
| ] | ] | ||||
| """ | """ | ||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| with open(data_path, "r", encoding="utf-8") as f: | |||||
| lines = f.readlines() | lines = f.readlines() | ||||
| return self.parse(lines) | return self.parse(lines) | ||||
| @@ -79,15 +76,16 @@ class POSDatasetLoader(DatasetLoader): | |||||
| return data | return data | ||||
| class TokenizeDatasetLoader(DatasetLoader): | |||||
| class TokenizeDataSetLoader(DataSetLoader): | |||||
| """ | """ | ||||
| Data set loader for tokenization data sets | Data set loader for tokenization data sets | ||||
| """ | """ | ||||
| def __init__(self, data_path): | |||||
| super(TokenizeDatasetLoader, self).__init__(data_path) | |||||
| def __init__(self): | |||||
| super(TokenizeDataSetLoader, self).__init__() | |||||
| def load_pku(self, max_seq_len=32): | |||||
| @staticmethod | |||||
| def load(data_path, max_seq_len=32): | |||||
| """ | """ | ||||
| load pku dataset for Chinese word segmentation | load pku dataset for Chinese word segmentation | ||||
| CWS (Chinese Word Segmentation) pku training dataset format: | CWS (Chinese Word Segmentation) pku training dataset format: | ||||
| @@ -104,7 +102,7 @@ class TokenizeDatasetLoader(DatasetLoader): | |||||
| :return: three-level lists | :return: three-level lists | ||||
| """ | """ | ||||
| assert isinstance(max_seq_len, int) and max_seq_len > 0 | assert isinstance(max_seq_len, int) and max_seq_len > 0 | ||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| with open(data_path, "r", encoding="utf-8") as f: | |||||
| sentences = f.readlines() | sentences = f.readlines() | ||||
| data = [] | data = [] | ||||
| for sent in sentences: | for sent in sentences: | ||||
| @@ -135,15 +133,15 @@ class TokenizeDatasetLoader(DatasetLoader): | |||||
| return data | return data | ||||
| class ClassDatasetLoader(DatasetLoader): | |||||
| class ClassDataSetLoader(DataSetLoader): | |||||
| """Loader for classification data sets""" | """Loader for classification data sets""" | ||||
| def __init__(self, data_path): | |||||
| super(ClassDatasetLoader, self).__init__(data_path) | |||||
| def __init__(self): | |||||
| super(ClassDataSetLoader, self).__init__() | |||||
| def load(self): | |||||
| assert os.path.exists(self.data_path) | |||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| def load(self, data_path): | |||||
| assert os.path.exists(data_path) | |||||
| with open(data_path, "r", encoding="utf-8") as f: | |||||
| lines = f.readlines() | lines = f.readlines() | ||||
| return self.parse(lines) | return self.parse(lines) | ||||
| @@ -169,21 +167,21 @@ class ClassDatasetLoader(DatasetLoader): | |||||
| return dataset | return dataset | ||||
| class ConllLoader(DatasetLoader): | |||||
| class ConllLoader(DataSetLoader): | |||||
| """loader for conll format files""" | """loader for conll format files""" | ||||
| def __int__(self, data_path): | def __int__(self, data_path): | ||||
| """ | """ | ||||
| :param str data_path: the path to the conll data set | :param str data_path: the path to the conll data set | ||||
| """ | """ | ||||
| super(ConllLoader, self).__init__(data_path) | |||||
| self.data_set = self.parse(self.load()) | |||||
| super(ConllLoader, self).__init__() | |||||
| self.data_set = self.parse(self.load(data_path)) | |||||
| def load(self): | |||||
| def load(self, data_path): | |||||
| """ | """ | ||||
| :return: list lines: all lines in a conll file | :return: list lines: all lines in a conll file | ||||
| """ | """ | ||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| with open(data_path, "r", encoding="utf-8") as f: | |||||
| lines = f.readlines() | lines = f.readlines() | ||||
| return lines | return lines | ||||
| @@ -207,28 +205,48 @@ class ConllLoader(DatasetLoader): | |||||
| return sentences | return sentences | ||||
| class LMDatasetLoader(DatasetLoader): | |||||
| def __init__(self, data_path): | |||||
| super(LMDatasetLoader, self).__init__(data_path) | |||||
| class LMDataSetLoader(DataSetLoader): | |||||
| """Language Model Dataset Loader | |||||
| def load(self): | |||||
| if not os.path.exists(self.data_path): | |||||
| raise FileNotFoundError("file {} not found.".format(self.data_path)) | |||||
| with open(self.data_path, "r", encoding="utf=8") as f: | |||||
| text = " ".join(f.readlines()) | |||||
| return text.strip().split() | |||||
| This loader produces data for language model training in a supervised way. | |||||
| That is, each example carries both an input sequence (X) and a target sequence (Y) shifted by one token. | |||||
| """ | |||||
| def __init__(self): | |||||
| super(LMDataSetLoader, self).__init__() | |||||
| class PeopleDailyCorpusLoader(DatasetLoader): | |||||
| def load(self, data_path): | |||||
| if not os.path.exists(data_path): | |||||
| raise FileNotFoundError("file {} not found.".format(data_path)) | |||||
| with open(data_path, "r", encoding="utf=8") as f: | |||||
| text = " ".join(f.readlines()) | |||||
| tokens = text.strip().split() | |||||
| return self.sentence_cut(tokens) | |||||
| def sentence_cut(self, tokens, sentence_length=15): | |||||
| data_set = [] | |||||
| for idx in range(len(tokens) // sentence_length): | |||||
| start = idx * sentence_length | |||||
| x = tokens[start: start + sentence_length] | |||||
| y = tokens[start + 1: start + sentence_length + 1] | |||||
| if start + sentence_length + 1 >= len(tokens): | |||||
| # the last target window runs past the corpus end; pad it with <unk> | |||||
| y.extend(["<unk>"]) | |||||
| data_set.append([x, y]) | |||||
| return data_set | |||||
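`LMDataSetLoader.load` now returns supervised pairs instead of a flat token list: the text is cut into fixed-length windows, and each target is the input shifted by one token, padded with `<unk>` when the corpus ends. A pure-Python illustration of that windowing, mirroring the indexing above with `sentence_length` shortened to 5 for readability:

```python
tokens = "a b c d e f g h i j".split()
sentence_length = 5
pairs = []
for idx in range(len(tokens) // sentence_length):
    start = idx * sentence_length
    x = tokens[start: start + sentence_length]
    y = tokens[start + 1: start + sentence_length + 1]
    if start + sentence_length + 1 >= len(tokens):
        y.append("<unk>")        # the last target window is one token short
    pairs.append([x, y])
print(pairs)
# [[['a', 'b', 'c', 'd', 'e'], ['b', 'c', 'd', 'e', 'f']],
#  [['f', 'g', 'h', 'i', 'j'], ['g', 'h', 'i', 'j', '<unk>']]]
```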
| class PeopleDailyCorpusLoader(DataSetLoader): | |||||
| """ | """ | ||||
| People Daily Corpus: Chinese word segmentation, POS tag, NER | People Daily Corpus: Chinese word segmentation, POS tag, NER | ||||
| """ | """ | ||||
| def __init__(self, data_path): | |||||
| super(PeopleDailyCorpusLoader, self).__init__(data_path) | |||||
| def __init__(self): | |||||
| super(PeopleDailyCorpusLoader, self).__init__() | |||||
| def load(self): | |||||
| with open(self.data_path, "r", encoding="utf-8") as f: | |||||
| def load(self, data_path): | |||||
| with open(data_path, "r", encoding="utf-8") as f: | |||||
| sents = f.readlines() | sents = f.readlines() | ||||
| pos_tag_examples = [] | pos_tag_examples = [] | ||||
| @@ -1,215 +1,8 @@ | |||||
| import os | |||||
| import numpy as np | |||||
| import torch | import torch | ||||
| import torch.nn as nn | import torch.nn as nn | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| import torch.optim as optim | |||||
| from torch.autograd import Variable | |||||
| from fastNLP.models.base_model import BaseModel | |||||
| USE_GPU = True | |||||
| """ | |||||
| To be deprecated. | |||||
| """ | |||||
| class CharLM(BaseModel): | |||||
| """ | |||||
| Controller of the Character-level Neural Language Model | |||||
| """ | |||||
| def __init__(self, lstm_batch_size, lstm_seq_len): | |||||
| super(CharLM, self).__init__() | |||||
| """ | |||||
| Settings: should come from config loader or pre-processing | |||||
| """ | |||||
| self.word_embed_dim = 300 | |||||
| self.char_embedding_dim = 15 | |||||
| self.cnn_batch_size = lstm_batch_size * lstm_seq_len | |||||
| self.lstm_seq_len = lstm_seq_len | |||||
| self.lstm_batch_size = lstm_batch_size | |||||
| self.num_epoch = 10 | |||||
| self.old_PPL = 100000 | |||||
| self.best_PPL = 100000 | |||||
| """ | |||||
| These parameters are set by pre-processing. | |||||
| """ | |||||
| self.max_word_len = None | |||||
| self.num_char = None | |||||
| self.vocab_size = None | |||||
| self.preprocess("./data_for_tests/charlm.txt") | |||||
| self.data = None # named tuple to store all data set | |||||
| self.data_ready = False | |||||
| self.criterion = nn.CrossEntropyLoss() | |||||
| self._loss = None | |||||
| self.use_gpu = USE_GPU | |||||
| # word_emb_dim == hidden_size / num of hidden units | |||||
| self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)), | |||||
| to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim))) | |||||
| self.model = charLM(self.char_embedding_dim, | |||||
| self.word_embed_dim, | |||||
| self.vocab_size, | |||||
| self.num_char, | |||||
| use_gpu=self.use_gpu) | |||||
| for param in self.model.parameters(): | |||||
| nn.init.uniform(param.data, -0.05, 0.05) | |||||
| self.learning_rate = 0.1 | |||||
| self.optimizer = None | |||||
| def prepare_input(self, raw_text): | |||||
| """ | |||||
| :param raw_text: raw input text consisting of words | |||||
| :return: torch.Tensor, torch.Tensor | |||||
| feature matrix, label vector | |||||
| This function is only called once in Trainer.train, but may called multiple times in Tester.test | |||||
| So Tester will save test input for frequent calls. | |||||
| """ | |||||
| if os.path.exists("cache/prep.pt") is False: | |||||
| self.preprocess("./data_for_tests/charlm.txt") # To do: This is not good. Need to fix.. | |||||
| objects = torch.load("cache/prep.pt") | |||||
| word_dict = objects["word_dict"] | |||||
| char_dict = objects["char_dict"] | |||||
| max_word_len = self.max_word_len | |||||
| print("word/char dictionary built. Start making inputs.") | |||||
| words = raw_text | |||||
| input_vec = np.array(text2vec(words, char_dict, max_word_len)) | |||||
| # Labels are next-word index in word_dict with the same length as inputs | |||||
| input_label = np.array([word_dict[w] for w in words[1:]] + [word_dict[words[-1]]]) | |||||
| feature_input = torch.from_numpy(input_vec) | |||||
| label_input = torch.from_numpy(input_label) | |||||
| return feature_input, label_input | |||||
| def mode(self, test=False): | |||||
| if test: | |||||
| self.model.eval() | |||||
| else: | |||||
| self.model.train() | |||||
| def data_forward(self, x): | |||||
| """ | |||||
| :param x: Tensor of size [lstm_batch_size, lstm_seq_len, max_word_len+2] | |||||
| :return: Tensor of size [num_words, ?] | |||||
| """ | |||||
| # additional processing of inputs after batching | |||||
| num_seq = x.size()[0] // self.lstm_seq_len | |||||
| x = x[:num_seq * self.lstm_seq_len, :] | |||||
| x = x.view(-1, self.lstm_seq_len, self.max_word_len + 2) | |||||
| # detach hidden state of LSTM from last batch | |||||
| hidden = [state.detach() for state in self.hidden] | |||||
| output, self.hidden = self.model(to_var(x), hidden) | |||||
| return output | |||||
| def grad_backward(self): | |||||
| self.model.zero_grad() | |||||
| self._loss.backward() | |||||
| torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2) | |||||
| self.optimizer.step() | |||||
| def get_loss(self, predict, truth): | |||||
| self._loss = self.criterion(predict, to_var(truth)) | |||||
| return self._loss.data # No pytorch data structure exposed outsides | |||||
| def define_optimizer(self): | |||||
| # redefine optimizer for every new epoch | |||||
| self.optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.85) | |||||
| def save(self): | |||||
| print("network saved") | |||||
| # torch.save(self.models, "cache/models.pkl") | |||||
| def preprocess(self, all_text_files): | |||||
| word_dict, char_dict = create_word_char_dict(all_text_files) | |||||
| num_char = len(char_dict) | |||||
| self.vocab_size = len(word_dict) | |||||
| char_dict["BOW"] = num_char + 1 | |||||
| char_dict["EOW"] = num_char + 2 | |||||
| char_dict["PAD"] = 0 | |||||
| self.num_char = num_char + 3 | |||||
| # char_dict is a dict of (int, string), int counting from 0 to 47 | |||||
| reverse_word_dict = {value: key for key, value in word_dict.items()} | |||||
| self.max_word_len = max([len(word) for word in word_dict]) | |||||
| objects = { | |||||
| "word_dict": word_dict, | |||||
| "char_dict": char_dict, | |||||
| "reverse_word_dict": reverse_word_dict, | |||||
| } | |||||
| if not os.path.exists("cache"): | |||||
| os.mkdir("cache") | |||||
| torch.save(objects, "cache/prep.pt") | |||||
| print("Preprocess done.") | |||||
| """ | |||||
| Global Functions | |||||
| """ | |||||
| def batch_generator(x, batch_size): | |||||
| # x: [num_words, in_channel, height, width] | |||||
| # partitions x into batches | |||||
| num_step = x.size()[0] // batch_size | |||||
| for t in range(num_step): | |||||
| yield x[t * batch_size:(t + 1) * batch_size] | |||||
| def text2vec(words, char_dict, max_word_len): | |||||
| """ Return list of list of int """ | |||||
| word_vec = [] | |||||
| for word in words: | |||||
| vec = [char_dict[ch] for ch in word] | |||||
| if len(vec) < max_word_len: | |||||
| vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))] | |||||
| vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]] | |||||
| word_vec.append(vec) | |||||
| return word_vec | |||||
| def read_data(file_name): | |||||
| with open(file_name, 'r') as f: | |||||
| corpus = f.read().lower() | |||||
| import re | |||||
| corpus = re.sub(r"<unk>", "unk", corpus) | |||||
| return corpus.split() | |||||
| def get_char_dict(vocabulary): | |||||
| char_dict = dict() | |||||
| count = 1 | |||||
| for word in vocabulary: | |||||
| for ch in word: | |||||
| if ch not in char_dict: | |||||
| char_dict[ch] = count | |||||
| count += 1 | |||||
| return char_dict | |||||
| def create_word_char_dict(*file_name): | |||||
| text = [] | |||||
| for file in file_name: | |||||
| text += read_data(file) | |||||
| word_dict = {word: ix for ix, word in enumerate(set(text))} | |||||
| char_dict = get_char_dict(word_dict) | |||||
| return word_dict, char_dict | |||||
| def to_var(x): | |||||
| if torch.cuda.is_available() and USE_GPU: | |||||
| x = x.cuda() | |||||
| return Variable(x) | |||||
| """ | |||||
| Neural Network | |||||
| """ | |||||
| from fastNLP.modules.encoder.lstm import LSTM | |||||
| class Highway(nn.Module): | class Highway(nn.Module): | ||||
| @@ -225,9 +18,8 @@ class Highway(nn.Module): | |||||
| return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x) | return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x) | ||||
| class charLM(nn.Module): | |||||
| """Character-level Neural Language Model | |||||
| CNN + highway network + LSTM | |||||
| class CharLM(nn.Module): | |||||
| """CNN + highway network + LSTM | |||||
| # Input: | # Input: | ||||
| 4D tensor with shape [batch_size, in_channel, height, width] | 4D tensor with shape [batch_size, in_channel, height, width] | ||||
| # Output: | # Output: | ||||
| @@ -241,8 +33,8 @@ class charLM(nn.Module): | |||||
| """ | """ | ||||
| def __init__(self, char_emb_dim, word_emb_dim, | def __init__(self, char_emb_dim, word_emb_dim, | ||||
| vocab_size, num_char, use_gpu): | |||||
| super(charLM, self).__init__() | |||||
| vocab_size, num_char): | |||||
| super(CharLM, self).__init__() | |||||
| self.char_emb_dim = char_emb_dim | self.char_emb_dim = char_emb_dim | ||||
| self.word_emb_dim = word_emb_dim | self.word_emb_dim = word_emb_dim | ||||
| self.vocab_size = vocab_size | self.vocab_size = vocab_size | ||||
| @@ -254,8 +46,7 @@ class charLM(nn.Module): | |||||
| self.convolutions = [] | self.convolutions = [] | ||||
| # list of tuples: (the number of filter, width) | # list of tuples: (the number of filter, width) | ||||
| # self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)] | |||||
| self.filter_num_width = [(25, 1), (50, 2), (75, 3)] | |||||
| self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)] | |||||
| for out_channel, filter_width in self.filter_num_width: | for out_channel, filter_width in self.filter_num_width: | ||||
| self.convolutions.append( | self.convolutions.append( | ||||
| @@ -278,29 +69,13 @@ class charLM(nn.Module): | |||||
| # LSTM | # LSTM | ||||
| self.lstm_num_layers = 2 | self.lstm_num_layers = 2 | ||||
| self.lstm = nn.LSTM(input_size=self.highway_input_dim, | |||||
| hidden_size=self.word_emb_dim, | |||||
| num_layers=self.lstm_num_layers, | |||||
| bias=True, | |||||
| dropout=0.5, | |||||
| batch_first=True) | |||||
| self.lstm = LSTM(self.highway_input_dim, hidden_size=self.word_emb_dim, num_layers=self.lstm_num_layers, | |||||
| dropout=0.5) | |||||
| # output layer | # output layer | ||||
| self.dropout = nn.Dropout(p=0.5) | self.dropout = nn.Dropout(p=0.5) | ||||
| self.linear = nn.Linear(self.word_emb_dim, self.vocab_size) | self.linear = nn.Linear(self.word_emb_dim, self.vocab_size) | ||||
| if use_gpu is True: | |||||
| for x in range(len(self.convolutions)): | |||||
| self.convolutions[x] = self.convolutions[x].cuda() | |||||
| self.highway1 = self.highway1.cuda() | |||||
| self.highway2 = self.highway2.cuda() | |||||
| self.lstm = self.lstm.cuda() | |||||
| self.dropout = self.dropout.cuda() | |||||
| self.char_embed = self.char_embed.cuda() | |||||
| self.linear = self.linear.cuda() | |||||
| self.batch_norm = self.batch_norm.cuda() | |||||
| def forward(self, x, hidden): | |||||
| def forward(self, x): | |||||
| # Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2] | # Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2] | ||||
| # Return: Variable of Tensor with shape [num_words, len(word_dict)] | # Return: Variable of Tensor with shape [num_words, len(word_dict)] | ||||
| lstm_batch_size = x.size()[0] | lstm_batch_size = x.size()[0] | ||||
| @@ -313,7 +88,7 @@ class charLM(nn.Module): | |||||
| # [num_seq*seq_len, max_word_len+2, char_emb_dim] | # [num_seq*seq_len, max_word_len+2, char_emb_dim] | ||||
| x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3) | x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3) | ||||
| # [num_seq*seq_len, 1, char_emb_dim, max_word_len+2] | |||||
| # [num_seq*seq_len, 1, max_word_len+2, char_emb_dim] | |||||
| x = self.conv_layers(x) | x = self.conv_layers(x) | ||||
| # [num_seq*seq_len, total_num_filters] | # [num_seq*seq_len, total_num_filters] | ||||
| @@ -328,7 +103,7 @@ class charLM(nn.Module): | |||||
| x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1) | x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1) | ||||
| # [num_seq, seq_len, total_num_filters] | # [num_seq, seq_len, total_num_filters] | ||||
| x, hidden = self.lstm(x, hidden) | |||||
| x, hidden = self.lstm(x) | |||||
| # [seq_len, num_seq, hidden_size] | # [seq_len, num_seq, hidden_size] | ||||
| x = self.dropout(x) | x = self.dropout(x) | ||||
| @@ -339,7 +114,7 @@ class charLM(nn.Module): | |||||
| x = self.linear(x) | x = self.linear(x) | ||||
| # [num_seq*seq_len, vocab_size] | # [num_seq*seq_len, vocab_size] | ||||
| return x, hidden | |||||
| return x | |||||
| def conv_layers(self, x): | def conv_layers(self, x): | ||||
| chosen_list = list() | chosen_list = list() | ||||
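After this rewrite `CharLM` is a plain `nn.Module`: the old controller responsibilities (its own preprocessing, optimizer, GPU moves and hidden-state bookkeeping) are gone, the LSTM hidden state is handled inside the shared `LSTM` wrapper, and `forward` takes only the character-index tensor. A shape-check sketch; every dimension below is made up rather than taken from any config:

```python
import torch
from fastNLP.models.char_language_model import CharLM

num_char, vocab_size, max_word_len = 50, 1000, 10
model = CharLM(char_emb_dim=15, word_emb_dim=64, vocab_size=vocab_size, num_char=num_char)

# input: [num_seq, seq_len, max_word_len + 2] character indices (+2 for the BOW/EOW markers)
x = torch.randint(0, num_char, (4, 8, max_word_len + 2))
logits = model(x)
print(logits.shape)   # expected [4 * 8, vocab_size], one distribution per word position
```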
| @@ -31,16 +31,18 @@ class SeqLabeling(BaseModel): | |||||
| num_classes = args["num_classes"] | num_classes = args["num_classes"] | ||||
| self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim) | self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim) | ||||
| self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) | |||||
| self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim) | |||||
| self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | ||||
| self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | ||||
| self.mask = None | self.mask = None | ||||
| def forward(self, word_seq, word_seq_origin_len): | |||||
| def forward(self, word_seq, word_seq_origin_len, truth=None): | |||||
| """ | """ | ||||
| :param word_seq: LongTensor, [batch_size, max_len] | :param word_seq: LongTensor, [batch_size, max_len] | ||||
| :param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. | :param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. | ||||
| :return y: [batch_size, mex_len, tag_size] | |||||
| :param truth: LongTensor, [batch_size, max_len] | |||||
| :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting. | |||||
| If truth is not None, return loss, a scalar. Used in training. | |||||
| """ | """ | ||||
| self.mask = self.make_mask(word_seq, word_seq_origin_len) | self.mask = self.make_mask(word_seq, word_seq_origin_len) | ||||
| @@ -50,9 +52,16 @@ class SeqLabeling(BaseModel): | |||||
| # [batch_size, max_len, hidden_size * direction] | # [batch_size, max_len, hidden_size * direction] | ||||
| x = self.Linear(x) | x = self.Linear(x) | ||||
| # [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
| return x | |||||
| if truth is not None: | |||||
| return self._internal_loss(x, truth) | |||||
| else: | |||||
| return self.decode(x) | |||||
| def loss(self, x, y): | def loss(self, x, y): | ||||
| """ Since the loss has been computed in forward(), this function simply returns x.""" | |||||
| return x | |||||
| def _internal_loss(self, x, y): | |||||
| """ | """ | ||||
| Negative log likelihood loss. | Negative log likelihood loss. | ||||
| :param x: Tensor, [batch_size, max_len, tag_size] | :param x: Tensor, [batch_size, max_len, tag_size] | ||||
| @@ -74,12 +83,19 @@ class SeqLabeling(BaseModel): | |||||
| mask = mask.to(x) | mask = mask.to(x) | ||||
| return mask | return mask | ||||
| def prediction(self, x): | |||||
| def decode(self, x, pad=True): | |||||
| """ | """ | ||||
| :param x: FloatTensor, [batch_size, max_len, tag_size] | :param x: FloatTensor, [batch_size, max_len, tag_size] | ||||
| :param pad: pad the output sequence to equal lengths | |||||
| :return prediction: list of [decode path(list)] | :return prediction: list of [decode path(list)] | ||||
| """ | """ | ||||
| max_len = x.shape[1] | |||||
| tag_seq = self.Crf.viterbi_decode(x, self.mask) | tag_seq = self.Crf.viterbi_decode(x, self.mask) | ||||
| # pad prediction to equal length | |||||
| if pad is True: | |||||
| for pred in tag_seq: | |||||
| if len(pred) < max_len: | |||||
| pred += [0] * (max_len - len(pred)) | |||||
| return tag_seq | return tag_seq | ||||
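The sequence-labelling models now fold the CRF objective into `forward`: with `truth` given they return the scalar negative log-likelihood, without it they return Viterbi decode paths padded to the batch's max length, and `loss()` simply passes the pre-computed value through so the Trainer's loss hook keeps working. A sketch of that contract; the key names in `args` are assumptions about the expected hyper-parameters, not values read from a fastNLP config:

```python
import torch
from fastNLP.models.sequence_modeling import SeqLabeling

# hyper-parameter names below are assumed for illustration
args = {"vocab_size": 100, "word_emb_dim": 16, "rnn_hidden_units": 16, "num_classes": 5}
model = SeqLabeling(args)

word_seq = torch.randint(1, 100, (2, 7))   # [batch_size, max_len]
seq_lens = torch.tensor([7, 5])            # original sequence lengths
truth = torch.randint(0, 5, (2, 7))        # gold tag ids

loss = model(word_seq, seq_lens, truth)    # training: scalar CRF loss
paths = model(word_seq, seq_lens)          # inference: decode paths, padded to max_len
```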
| @@ -97,7 +113,7 @@ class AdvSeqLabel(SeqLabeling): | |||||
| num_classes = args["num_classes"] | num_classes = args["num_classes"] | ||||
| self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | ||||
| self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True) | |||||
| self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True) | |||||
| self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | ||||
| self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | ||||
| self.relu = torch.nn.ReLU() | self.relu = torch.nn.ReLU() | ||||
| @@ -106,11 +122,12 @@ class AdvSeqLabel(SeqLabeling): | |||||
| self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | ||||
| def forward(self, word_seq, word_seq_origin_len): | |||||
| def forward(self, word_seq, word_seq_origin_len, truth=None): | |||||
| """ | """ | ||||
| :param word_seq: LongTensor, [batch_size, max_len] | :param word_seq: LongTensor, [batch_size, max_len] | ||||
| :param word_seq_origin_len: list of int. | :param word_seq_origin_len: list of int. | ||||
| :return y: [batch_size, mex_len, tag_size] | |||||
| :param truth: LongTensor, [batch_size, max_len] | |||||
| :return y: | |||||
| """ | """ | ||||
| self.mask = self.make_mask(word_seq, word_seq_origin_len) | self.mask = self.make_mask(word_seq, word_seq_origin_len) | ||||
| @@ -129,4 +146,7 @@ class AdvSeqLabel(SeqLabeling): | |||||
| x = self.Linear2(x) | x = self.Linear2(x) | ||||
| x = x.view(batch_size, max_len, -1) | x = x.view(batch_size, max_len, -1) | ||||
| # [batch_size, max_len, num_classes] | # [batch_size, max_len, num_classes] | ||||
| return x | |||||
| if truth is not None: | |||||
| return self._internal_loss(x, truth) | |||||
| else: | |||||
| return self.decode(x) | |||||
| @@ -55,14 +55,13 @@ class SelfAttention(nn.Module): | |||||
| input = input.contiguous() | input = input.contiguous() | ||||
| size = input.size() # [bsz, len, nhid] | size = input.size() # [bsz, len, nhid] | ||||
| input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len] | input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len] | ||||
| input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] | |||||
| input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len] | |||||
| y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] | |||||
| attention = self.ws2(y1).transpose(1,2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] | |||||
| y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit] | |||||
| attention = self.ws2(y1).transpose(1, | |||||
| 2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len] | |||||
| attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. | attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token. | ||||
| attention = F.softmax(attention,2) # [baz ,hop, len] | |||||
| return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid] | |||||
| attention = F.softmax(attention, 2) # [baz ,hop, len] | |||||
| return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid] | |||||
| @@ -1,10 +1,10 @@ | |||||
| from .embedding import Embedding | |||||
| from .linear import Linear | |||||
| from .lstm import Lstm | |||||
| from .conv import Conv | from .conv import Conv | ||||
| from .conv_maxpool import ConvMaxpool | from .conv_maxpool import ConvMaxpool | ||||
| from .embedding import Embedding | |||||
| from .linear import Linear | |||||
| from .lstm import LSTM | |||||
| __all__ = ["Lstm", | |||||
| __all__ = ["LSTM", | |||||
| "Embedding", | "Embedding", | ||||
| "Linear", | "Linear", | ||||
| "Conv", | "Conv", | ||||
| @@ -1,9 +1,10 @@ | |||||
| import torch.nn as nn | import torch.nn as nn | ||||
| from fastNLP.modules.utils import initial_parameter | from fastNLP.modules.utils import initial_parameter | ||||
| class Lstm(nn.Module): | |||||
| """ | |||||
| LSTM module | |||||
| class LSTM(nn.Module): | |||||
| """Long Short Term Memory | |||||
| Args: | Args: | ||||
| input_size : input size | input_size : input size | ||||
| @@ -13,13 +14,17 @@ class Lstm(nn.Module): | |||||
| bidirectional : If True, becomes a bidirectional RNN. Default: False. | bidirectional : If True, becomes a bidirectional RNN. Default: False. | ||||
| """ | """ | ||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None): | |||||
| super(Lstm, self).__init__() | |||||
| def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False, | |||||
| initial_method=None): | |||||
| super(LSTM, self).__init__() | |||||
| self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True, | ||||
| dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
| initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
| def forward(self, x): | def forward(self, x): | ||||
| x, _ = self.lstm(x) | x, _ = self.lstm(x) | ||||
| return x | return x | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| lstm = Lstm(10) | |||||
| lstm = LSTM(10) | |||||
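The encoder wrapper is renamed to `LSTM` to match its uses across the code base; it wraps `nn.LSTM` with `batch_first=True` and returns only the output sequence, discarding the hidden state. A minimal sketch:

```python
import torch
from fastNLP.modules.encoder.lstm import LSTM

lstm = LSTM(input_size=10, hidden_size=20, bidirectional=True)
x = torch.randn(4, 7, 10)     # [batch_size, seq_len, input_size]
out = lstm(x)
print(out.shape)              # [4, 7, 40]: hidden_size doubled by the two directions
```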
| @@ -196,30 +196,3 @@ class BiAffine(nn.Module): | |||||
| output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2) | output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2) | ||||
| return output | return output | ||||
| class Transpose(nn.Module): | |||||
| def __init__(self, x, y): | |||||
| super(Transpose, self).__init__() | |||||
| self.x = x | |||||
| self.y = y | |||||
| def forward(self, x): | |||||
| return x.transpose(self.x, self.y) | |||||
| class WordDropout(nn.Module): | |||||
| def __init__(self, dropout_rate, drop_to_token): | |||||
| super(WordDropout, self).__init__() | |||||
| self.dropout_rate = dropout_rate | |||||
| self.drop_to_token = drop_to_token | |||||
| def forward(self, word_idx): | |||||
| if not self.training: | |||||
| return word_idx | |||||
| drop_mask = torch.rand(word_idx.shape) < self.dropout_rate | |||||
| if word_idx.device.type == 'cuda': | |||||
| drop_mask = drop_mask.cuda() | |||||
| drop_mask = drop_mask.long() | |||||
| output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx | |||||
| return output | |||||
| @@ -18,7 +18,7 @@ class ConfigSaver(object): | |||||
| :return: The section. | :return: The section. | ||||
| """ | """ | ||||
| sect = ConfigSection() | sect = ConfigSection() | ||||
| ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect}) | |||||
| ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||||
| return sect | return sect | ||||
| def _read_section(self): | def _read_section(self): | ||||
| @@ -104,7 +104,8 @@ class ConfigSaver(object): | |||||
| :return: | :return: | ||||
| """ | """ | ||||
| section_file = self._get_section(section_name) | section_file = self._get_section(section_name) | ||||
| if len(section_file.__dict__.keys()) == 0:#the section not in file before | |||||
| if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||||
| # append this section to config file | |||||
| with open(self.file_path, 'a') as f: | with open(self.file_path, 'a') as f: | ||||
| f.write('[' + section_name + ']\n') | f.write('[' + section_name + ']\n') | ||||
| for k in section.__dict__.keys(): | for k in section.__dict__.keys(): | ||||
| @@ -114,9 +115,11 @@ class ConfigSaver(object): | |||||
| else: | else: | ||||
| f.write(str(section[k]) + '\n\n') | f.write(str(section[k]) + '\n\n') | ||||
| else: | else: | ||||
| # the section exists | |||||
| change_file = False | change_file = False | ||||
| for k in section.__dict__.keys(): | for k in section.__dict__.keys(): | ||||
| if k not in section_file: | if k not in section_file: | ||||
| # find a new key in this section | |||||
| change_file = True | change_file = True | ||||
| break | break | ||||
| if section_file[k] != section[k]: | if section_file[k] != section[k]: | ||||
| @@ -0,0 +1,25 @@ | |||||
| from fastNLP.core.loss import Loss | |||||
| from fastNLP.core.preprocess import Preprocessor | |||||
| from fastNLP.core.trainer import Trainer | |||||
| from fastNLP.loader.dataset_loader import LMDataSetLoader | |||||
| from fastNLP.models.char_language_model import CharLM | |||||
| PICKLE = "./save/" | |||||
| def train(): | |||||
| loader = LMDataSetLoader() | |||||
| train_data = loader.load() | |||||
| pre = Preprocessor(label_is_seq=True, share_vocab=True) | |||||
| train_set = pre.run(train_data, pickle_path=PICKLE) | |||||
| model = CharLM(50, 50, pre.vocab_size, pre.char_vocab_size) | |||||
| trainer = Trainer(task="language_model", loss=Loss("cross_entropy")) | |||||
| trainer.train(model, train_set) | |||||
| if __name__ == "__main__": | |||||
| train() | |||||
| @@ -4,12 +4,12 @@ from fastNLP.core.preprocess import ClassPreprocess as Preprocess | |||||
| from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
| from fastNLP.loader.config_loader import ConfigLoader | from fastNLP.loader.config_loader import ConfigLoader | ||||
| from fastNLP.loader.config_loader import ConfigSection | from fastNLP.loader.config_loader import ConfigSection | ||||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader as Dataset_loader | |||||
| from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader | |||||
| from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
| from fastNLP.modules.aggregator.self_attention import SelfAttention | from fastNLP.modules.aggregator.self_attention import SelfAttention | ||||
| from fastNLP.modules.decoder.MLP import MLP | from fastNLP.modules.decoder.MLP import MLP | ||||
| from fastNLP.modules.encoder.embedding import Embedding as Embedding | from fastNLP.modules.encoder.embedding import Embedding as Embedding | ||||
| from fastNLP.modules.encoder.lstm import Lstm | |||||
| from fastNLP.modules.encoder.lstm import LSTM | |||||
| train_data_path = 'small_train_data.txt' | train_data_path = 'small_train_data.txt' | ||||
| dev_data_path = 'small_dev_data.txt' | dev_data_path = 'small_dev_data.txt' | ||||
| @@ -43,7 +43,7 @@ class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel): | |||||
| def __init__(self, args=None): | def __init__(self, args=None): | ||||
| super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__() | super(SELF_ATTENTION_YELP_CLASSIFICATION,self).__init__() | ||||
| self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None ) | self.embedding = Embedding(len(word2index) ,embeding_size , init_emb= None ) | ||||
| self.lstm = Lstm(input_size = embeding_size,hidden_size = lstm_hidden_size ,bidirectional = True) | |||||
| self.lstm = LSTM(input_size=embeding_size, hidden_size=lstm_hidden_size, bidirectional=True) | |||||
| self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops) | self.attention = SelfAttention(lstm_hidden_size * 2 ,dim =attention_unit ,num_vec=attention_hops) | ||||
| self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ]) | self.mlp = MLP(size_layer=[lstm_hidden_size * 2*attention_hops ,nfc ,class_num ]) | ||||
| def forward(self,x): | def forward(self,x): | ||||
| @@ -5,50 +5,52 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
| from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
| from fastNLP.loader.dataset_loader import BaseLoader, TokenizeDataSetLoader | |||||
| from fastNLP.core.preprocess import load_pickle | |||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.models.sequence_modeling import AdvSeqLabel | from fastNLP.models.sequence_modeling import AdvSeqLabel | ||||
| from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.preprocess import save_pickle | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| # not in the file's dir | # not in the file's dir | ||||
| if len(os.path.dirname(__file__)) != 0: | if len(os.path.dirname(__file__)) != 0: | ||||
| os.chdir(os.path.dirname(__file__)) | os.chdir(os.path.dirname(__file__)) | ||||
| datadir = "/home/zyfeng/data/" | datadir = "/home/zyfeng/data/" | ||||
| cfgfile = './cws.cfg' | cfgfile = './cws.cfg' | ||||
| data_name = "pku_training.utf8" | |||||
| cws_data_path = os.path.join(datadir, "pku_training.utf8") | cws_data_path = os.path.join(datadir, "pku_training.utf8") | ||||
| pickle_path = "save" | pickle_path = "save" | ||||
| data_infer_path = os.path.join(datadir, "infer.utf8") | data_infer_path = os.path.join(datadir, "infer.utf8") | ||||
| def infer(): | def infer(): | ||||
| # Config Loader | # Config Loader | ||||
| test_args = ConfigSection() | test_args = ConfigSection() | ||||
| ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||||
| ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # Define the same model | # Define the same model | ||||
| model = AdvSeqLabel(test_args) | model = AdvSeqLabel(test_args) | ||||
| try: | try: | ||||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
| ModelLoader.load_pytorch(model, "./save/trained_model.pkl") | |||||
| print('model loaded!') | print('model loaded!') | ||||
| except Exception as e: | except Exception as e: | ||||
| print('cannot load model!') | print('cannot load model!') | ||||
| raise | raise | ||||
| # Data Loader | # Data Loader | ||||
| raw_data_loader = BaseLoader(data_infer_path) | |||||
| infer_data = raw_data_loader.load_lines() | |||||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load_lines) | |||||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) | |||||
| print('data loaded') | print('data loaded') | ||||
| # Inference interface | # Inference interface | ||||
| @@ -63,20 +65,27 @@ def train(): | |||||
| # Config Loader | # Config Loader | ||||
| train_args = ConfigSection() | train_args = ConfigSection() | ||||
| test_args = ConfigSection() | test_args = ConfigSection() | ||||
| ConfigLoader("good_path").load_config(cfgfile, {"train": train_args, "test": test_args}) | |||||
| ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args}) | |||||
| # Data Loader | |||||
| loader = TokenizeDatasetLoader(cws_data_path) | |||||
| train_data = loader.load_pku() | |||||
| print("loading data set...") | |||||
| data = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) | |||||
| data.load(cws_data_path) | |||||
| data_train, data_dev = data.split(ratio=0.3) | |||||
| train_args["vocab_size"] = len(data.word_vocab) | |||||
| train_args["num_classes"] = len(data.label_vocab) | |||||
| print("vocab size={}, num_classes={}".format(len(data.word_vocab), len(data.label_vocab))) | |||||
| # Preprocessor | |||||
| preprocessor = SeqLabelPreprocess() | |||||
| data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3) | |||||
| train_args["vocab_size"] = preprocessor.vocab_size | |||||
| train_args["num_classes"] = preprocessor.num_classes | |||||
| change_field_is_target(data_dev, "truth", True) | |||||
| save_pickle(data_dev, "./save/", "data_dev.pkl") | |||||
| save_pickle(data.word_vocab, "./save/", "word2id.pkl") | |||||
| save_pickle(data.label_vocab, "./save/", "label2id.pkl") | |||||
| # Trainer | # Trainer | ||||
| trainer = SeqLabelTrainer(**train_args.data) | |||||
| trainer = SeqLabelTrainer(epochs=train_args["epochs"], batch_size=train_args["batch_size"], | |||||
| validate=train_args["validate"], | |||||
| use_cuda=train_args["use_cuda"], pickle_path=train_args["pickle_path"], | |||||
| save_best_dev=True, print_every_step=10, model_name="trained_model.pkl", | |||||
| evaluator=SeqLabelEvaluator()) | |||||
| # Model | # Model | ||||
| model = AdvSeqLabel(train_args) | model = AdvSeqLabel(train_args) | ||||
| @@ -86,26 +95,26 @@ def train(): | |||||
| except Exception as e: | except Exception as e: | ||||
| print("No saved model. Continue.") | print("No saved model. Continue.") | ||||
| pass | pass | ||||
| # Start training | # Start training | ||||
| trainer.train(model, data_train, data_dev) | trainer.train(model, data_train, data_dev) | ||||
| print("Training finished!") | print("Training finished!") | ||||
| # Saver | # Saver | ||||
| saver = ModelSaver("./save/saved_model.pkl") | |||||
| saver = ModelSaver("./save/trained_model.pkl") | |||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| print("Model saved!") | print("Model saved!") | ||||
| def test(): | |||||
| def predict(): | |||||
| # Config Loader | # Config Loader | ||||
| test_args = ConfigSection() | test_args = ConfigSection() | ||||
| ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args}) | |||||
| ConfigLoader().load_config(cfgfile, {"POS_test": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # load dev data | # load dev data | ||||
| @@ -115,29 +124,28 @@ def test(): | |||||
| model = AdvSeqLabel(test_args) | model = AdvSeqLabel(test_args) | ||||
| # Dump trained parameters into the model | # Dump trained parameters into the model | ||||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | |||||
| ModelLoader.load_pytorch(model, "./save/trained_model.pkl") | |||||
| print("model loaded!") | print("model loaded!") | ||||
| # Tester | # Tester | ||||
| test_args["evaluator"] = SeqLabelEvaluator() | |||||
| tester = SeqLabelTester(**test_args.data) | tester = SeqLabelTester(**test_args.data) | ||||
| # Start testing | # Start testing | ||||
| tester.test(model, dev_data) | tester.test(model, dev_data) | ||||
| # print test results | |||||
| print(tester.show_metrics()) | |||||
| print("model tested!") | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| import argparse | import argparse | ||||
| parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | ||||
| parser.add_argument('--mode', help='set the running mode', choices=['train', 'test', 'infer']) | parser.add_argument('--mode', help='set the running mode', choices=['train', 'test', 'infer']) | ||||
| args = parser.parse_args() | args = parser.parse_args() | ||||
| if args.mode == 'train': | if args.mode == 'train': | ||||
| train() | train() | ||||
| elif args.mode == 'test': | elif args.mode == 'test': | ||||
| test() | |||||
| predict() | |||||
| elif args.mode == 'infer': | elif args.mode == 'infer': | ||||
| infer() | infer() | ||||
| else: | else: | ||||
| @@ -66,7 +66,7 @@ def train(): | |||||
| ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args}) | ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args}) | ||||
| # Data Loader | # Data Loader | ||||
| loader = PeopleDailyCorpusLoader(pos_tag_data_path) | |||||
| loader = PeopleDailyCorpusLoader() | |||||
| train_data, _ = loader.load() | train_data, _ = loader.load() | ||||
| # Preprocessor | # Preprocessor | ||||
| @@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f: | |||||
| setup( | setup( | ||||
| name='fastNLP', | name='fastNLP', | ||||
| version='0.0.3', | |||||
| version='0.1.0', | |||||
| description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', | ||||
| long_description=readme, | long_description=readme, | ||||
| license=license, | license=license, | ||||
| @@ -43,8 +43,10 @@ class TestCase1(unittest.TestCase): | |||||
| # use batch to iterate dataset | # use batch to iterate dataset | ||||
| data_iterator = Batch(data, 2, SeqSampler(), False) | data_iterator = Batch(data, 2, SeqSampler(), False) | ||||
| total_data = 0 | |||||
| for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
| self.assertEqual(len(batch_x), 2) | |||||
| total_data += batch_x["text"].size(0) | |||||
| self.assertTrue(batch_x["text"].size(0) == 2 or total_data == len(raw_texts)) | |||||
| self.assertTrue(isinstance(batch_x, dict)) | self.assertTrue(isinstance(batch_x, dict)) | ||||
| self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | self.assertTrue(isinstance(batch_x["text"], torch.LongTensor)) | ||||
| self.assertTrue(isinstance(batch_y, dict)) | self.assertTrue(isinstance(batch_y, dict)) | ||||
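The loosened assertion above allows a short final batch: every batch must have the full size of 2 except possibly the last one, which is detected by comparing the running total against the dataset length. A framework-free sketch of that accounting, assuming only that the iterator yields dicts of tensors keyed by field name:

```python
def check_batches(batches, expected_batch_size, total_examples):
    """Verify that all batches are full-sized except possibly the last one."""
    seen = 0
    for batch_x, _batch_y in batches:
        size = batch_x["text"].size(0)      # number of examples in this batch
        seen += size
        # a short batch is only legal when it exhausts the dataset
        assert size == expected_batch_size or seen == total_examples
    assert seen == total_examples
```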
| @@ -0,0 +1,243 @@ | |||||
| import unittest | |||||
| from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet | |||||
| from fastNLP.core.dataset import create_dataset_from_lists | |||||
| class TestDataSet(unittest.TestCase): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||||
| def test_case_1(self): | |||||
| data_set = create_dataset_from_lists(self.labeled_data_list, self.word_vocab, has_target=True, | |||||
| label_vocab=self.label_vocab) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| self.assertTrue("label_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["label_seq"].text, self.labeled_data_list[0][1]) | |||||
| self.assertEqual(data_set[0].fields["label_seq"]._index, | |||||
| [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | |||||
| def test_case_2(self): | |||||
| data_set = create_dataset_from_lists(self.unlabeled_data_list, self.word_vocab, has_target=False) | |||||
| self.assertEqual(len(data_set), len(self.unlabeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.unlabeled_data_list[0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.unlabeled_data_list[0]]) | |||||
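These two cases pin down the create_dataset_from_lists signature used throughout the new tests. A condensed usage sketch with the same toy vocabularies as the fixtures above:

```python
from fastNLP.core.dataset import create_dataset_from_lists

word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}

# labelled data: each example is [token_list, label_list]
labelled = [[["a", "b", "e", "d"], ["1", "2", "3", "4"]]]
ds = create_dataset_from_lists(labelled, word_vocab, has_target=True, label_vocab=label_vocab)
example = ds[0]
print(example.fields["word_seq"].text)     # ['a', 'b', 'e', 'd']
print(example.fields["word_seq"]._index)   # [0, 1, 2, 3]
print(example.fields["label_seq"]._index)  # [1, 2, 3, 4]

# unlabelled data: each example is just a token list, no label field is built
unlabelled = [["a", "b", "e", "d"]]
ds_infer = create_dataset_from_lists(unlabelled, word_vocab, has_target=False)
```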
| class TestDataSetConvertion(unittest.TestCase): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
| label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||||
| def test_case_1(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = SeqLabelDataSet(load_func=loader) | |||||
| data_set.load("any_path") | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertTrue("truth" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) | |||||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||||
| def test_case_2(self): | |||||
| def loader(path): | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| return unlabeled_data_list | |||||
| data_set = SeqLabelDataSet(load_func=loader) | |||||
| data_set.load("any_path", vocabs={"word_vocab": self.word_vocab}, infer=True) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||||
| def test_case_3(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| [["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = SeqLabelDataSet(load_func=loader) | |||||
| data_set.load("any_path", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| self.assertTrue("truth" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["truth"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["truth"].text, self.labeled_data_list[0][1]) | |||||
| self.assertEqual(data_set[0].fields["truth"]._index, | |||||
| [self.label_vocab[c] for c in self.labeled_data_list[0][1]]) | |||||
| self.assertTrue("word_seq_origin_len" in data_set[0].fields) | |||||
| class TestDataSetConvertionHHH(unittest.TestCase): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], "A"], | |||||
| [["a", "b", "e", "d"], "C"], | |||||
| [["a", "b", "e", "d"], "B"], | |||||
| ] | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||||
| label_vocab = {"A": 1, "B": 2, "C": 3} | |||||
| def test_case_1(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], "A"], | |||||
| [["a", "b", "e", "d"], "C"], | |||||
| [["a", "b", "e", "d"], "B"], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = TextClassifyDataSet(load_func=loader) | |||||
| data_set.load("xxx") | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertTrue("label" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "label")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) | |||||
| def test_case_2(self): | |||||
| def loader(path): | |||||
| labeled_data_list = [ | |||||
| [["a", "b", "e", "d"], "A"], | |||||
| [["a", "b", "e", "d"], "C"], | |||||
| [["a", "b", "e", "d"], "B"], | |||||
| ] | |||||
| return labeled_data_list | |||||
| data_set = TextClassifyDataSet(load_func=loader) | |||||
| data_set.load("xxx", vocabs={"word_vocab": self.word_vocab, "label_vocab": self.label_vocab}) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
| self.assertTrue("label" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "label")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["label"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["label"].label, self.labeled_data_list[0][1]) | |||||
| self.assertEqual(data_set[0].fields["label"]._index, self.label_vocab[self.labeled_data_list[0][1]]) | |||||
| def test_case_3(self): | |||||
| def loader(path): | |||||
| unlabeled_data_list = [ | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"], | |||||
| ["a", "b", "e", "d"] | |||||
| ] | |||||
| return unlabeled_data_list | |||||
| data_set = TextClassifyDataSet(load_func=loader) | |||||
| data_set.load("xxx", vocabs={"word_vocab": self.word_vocab}, infer=True) | |||||
| self.assertEqual(len(data_set), len(self.labeled_data_list)) | |||||
| self.assertTrue(len(data_set) > 0) | |||||
| self.assertTrue(hasattr(data_set[0], "fields")) | |||||
| self.assertTrue("word_seq" in data_set[0].fields) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "text")) | |||||
| self.assertTrue(hasattr(data_set[0].fields["word_seq"], "_index")) | |||||
| self.assertEqual(data_set[0].fields["word_seq"].text, self.labeled_data_list[0][0]) | |||||
| self.assertEqual(data_set[0].fields["word_seq"]._index, | |||||
| [self.word_vocab[c] for c in self.labeled_data_list[0][0]]) | |||||
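The cases above exercise the load_func-based DataSet classes in their three modes: building vocabularies from scratch, reusing supplied vocabularies, and label-free inference. A compact sketch of the last two, reusing the fixtures' toy loaders and vocabularies (the path argument is ignored by the toy loaders):

```python
from fastNLP.core.dataset import TextClassifyDataSet

word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"A": 1, "B": 2, "C": 3}

def toy_loader(path):
    # stand-in for a real DataSetLoader.load; the path is unused here
    return [[["a", "b", "e", "d"], "A"], [["a", "b", "e", "d"], "C"]]

# training-style load with externally supplied vocabularies
train_set = TextClassifyDataSet(load_func=toy_loader)
train_set.load("unused_path", vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab})

def toy_infer_loader(path):
    return [["a", "b", "e", "d"], ["a", "b", "e", "d"]]

# inference-style load: no labels, only the word vocabulary is needed
infer_set = TextClassifyDataSet(load_func=toy_infer_loader)
infer_set.load("unused_path", vocabs={"word_vocab": word_vocab}, infer=True)
```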
| @@ -1,20 +1,42 @@ | |||||
| import sys, os | |||||
| import os | |||||
| import sys | |||||
| sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | sys.path = [os.path.join(os.path.dirname(__file__), '..')] + sys.path | ||||
| from fastNLP.core import metrics | from fastNLP.core import metrics | ||||
| # from sklearn import metrics as skmetrics | # from sklearn import metrics as skmetrics | ||||
| import unittest | import unittest | ||||
| import numpy as np | |||||
| from numpy import random | from numpy import random | ||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| import torch | |||||
| def generate_fake_label(low, high, size): | def generate_fake_label(low, high, size): | ||||
| return random.randint(low, high, size), random.randint(low, high, size) | return random.randint(low, high, size), random.randint(low, high, size) | ||||
| class TestEvaluator(unittest.TestCase): | |||||
| def test_a(self): | |||||
| evaluator = SeqLabelEvaluator() | |||||
| pred = [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]] | |||||
| truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4])}] | |||||
| ans = evaluator(pred, truth) | |||||
| print(ans) | |||||
| def test_b(self): | |||||
| evaluator = SeqLabelEvaluator() | |||||
| pred = [[1, 2, 3, 4, 5, 0, 0], [1, 2, 3, 4, 5, 0, 0]] | |||||
| truth = [{"truth": torch.LongTensor([1, 2, 3, 3, 3, 0, 0])}, {"truth": torch.LongTensor([1, 2, 3, 3, 4, 0, 0])}] | |||||
| ans = evaluator(pred, truth) | |||||
| print(ans) | |||||
| class TestMetrics(unittest.TestCase): | class TestMetrics(unittest.TestCase): | ||||
| delta = 1e-5 | delta = 1e-5 | ||||
| # test for binary, multiclass, multilabel | # test for binary, multiclass, multilabel | ||||
| data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | data_types = [((1000,), 2), ((1000,), 10), ((1000, 10), 2)] | ||||
| fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | fake_data = [generate_fake_label(0, high, shape) for shape, high in data_types] | ||||
| def test_accuracy_score(self): | def test_accuracy_score(self): | ||||
| for y_true, y_pred in self.fake_data: | for y_true, y_pred in self.fake_data: | ||||
| for normalize in [True, False]: | for normalize in [True, False]: | ||||
| @@ -22,7 +44,7 @@ class TestMetrics(unittest.TestCase): | |||||
| test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | test = metrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | ||||
| # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | # ans = skmetrics.accuracy_score(y_true, y_pred, normalize=normalize, sample_weight=sample_weight) | ||||
| # self.assertAlmostEqual(test, ans, delta=self.delta) | # self.assertAlmostEqual(test, ans, delta=self.delta) | ||||
| def test_recall_score(self): | def test_recall_score(self): | ||||
| for y_true, y_pred in self.fake_data: | for y_true, y_pred in self.fake_data: | ||||
| # print(y_true.shape) | # print(y_true.shape) | ||||
| @@ -73,5 +95,6 @@ class TestMetrics(unittest.TestCase): | |||||
| # ans = skmetrics.f1_score(y_true, y_pred) | # ans = skmetrics.f1_score(y_true, y_pred) | ||||
| # self.assertAlmostEqual(ans, test, delta=self.delta) | # self.assertAlmostEqual(ans, test, delta=self.delta) | ||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| unittest.main() | unittest.main() | ||||
| @@ -1,10 +1,13 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import TextClassifyDataSet, SeqLabelDataSet | |||||
| from fastNLP.core.predictor import Predictor | from fastNLP.core.predictor import Predictor | ||||
| from fastNLP.core.preprocess import save_pickle | from fastNLP.core.preprocess import save_pickle | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
| from fastNLP.loader.base_loader import BaseLoader | |||||
| from fastNLP.models.cnn_text_classification import CNNText | |||||
| from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| class TestPredictor(unittest.TestCase): | class TestPredictor(unittest.TestCase): | ||||
| @@ -28,23 +31,44 @@ class TestPredictor(unittest.TestCase): | |||||
| vocab = Vocabulary() | vocab = Vocabulary() | ||||
| vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab.word2idx = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
| class_vocab = Vocabulary() | class_vocab = Vocabulary() | ||||
| class_vocab.word2idx = {"0":0, "1":1, "2":2, "3":3, "4":4} | |||||
| class_vocab.word2idx = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4} | |||||
| os.system("mkdir save") | os.system("mkdir save") | ||||
| save_pickle(class_vocab, "./save/", "class2id.pkl") | |||||
| save_pickle(class_vocab, "./save/", "label2id.pkl") | |||||
| save_pickle(vocab, "./save/", "word2id.pkl") | save_pickle(vocab, "./save/", "word2id.pkl") | ||||
| model = SeqLabeling(model_args) | |||||
| predictor = Predictor("./save/", task="seq_label") | |||||
| model = CNNText(model_args) | |||||
| import fastNLP.core.predictor as pre | |||||
| predictor = Predictor("./save/", pre.text_classify_post_processor) | |||||
| results = predictor.predict(network=model, data=infer_data) | |||||
| # Load infer data | |||||
| infer_data_set = TextClassifyDataSet(load_func=BaseLoader.load) | |||||
| infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||||
| results = predictor.predict(network=model, data=infer_data_set) | |||||
| self.assertTrue(isinstance(results, list)) | self.assertTrue(isinstance(results, list)) | ||||
| self.assertGreater(len(results), 0) | self.assertGreater(len(results), 0) | ||||
| self.assertEqual(len(results), len(infer_data)) | |||||
| for res in results: | for res in results: | ||||
| self.assertTrue(isinstance(res, str)) | |||||
| self.assertTrue(res in class_vocab.word2idx) | |||||
| del model, predictor, infer_data_set | |||||
| model = SeqLabeling(model_args) | |||||
| predictor = Predictor("./save/", pre.seq_label_post_processor) | |||||
| infer_data_set = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
| infer_data_set.convert_for_infer(infer_data, vocabs={"word_vocab": vocab.word2idx}) | |||||
| results = predictor.predict(network=model, data=infer_data_set) | |||||
| self.assertTrue(isinstance(results, list)) | |||||
| self.assertEqual(len(results), len(infer_data)) | |||||
| for i in range(len(infer_data)): | |||||
| res = results[i] | |||||
| self.assertTrue(isinstance(res, list)) | self.assertTrue(isinstance(res, list)) | ||||
| self.assertEqual(len(res), 5) | |||||
| self.assertTrue(isinstance(res[0], str)) | |||||
| self.assertEqual(len(res), len(infer_data[i])) | |||||
| os.system("rm -rf save") | os.system("rm -rf save") | ||||
| print("pickle path deleted") | print("pickle path deleted") | ||||
| @@ -1,8 +1,9 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField | |||||
| from fastNLP.core.dataset import SeqLabelDataSet | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.field import TextField, LabelField | |||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| @@ -21,7 +22,7 @@ class TestTester(unittest.TestCase): | |||||
| } | } | ||||
| valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | ||||
| "save_loss": True, "batch_size": 2, "pickle_path": "./save/", | "save_loss": True, "batch_size": 2, "pickle_path": "./save/", | ||||
| "use_cuda": False, "print_every_step": 1} | |||||
| "use_cuda": False, "print_every_step": 1, "evaluator": SeqLabelEvaluator()} | |||||
| train_data = [ | train_data = [ | ||||
| [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | [['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e']], | ||||
| @@ -34,16 +35,17 @@ class TestTester(unittest.TestCase): | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | ||||
| data_set = DataSet() | |||||
| data_set = SeqLabelDataSet() | |||||
| for example in train_data: | for example in train_data: | ||||
| text, label = example[0], example[1] | text, label = example[0], example[1] | ||||
| x = TextField(text, False) | x = TextField(text, False) | ||||
| x_len = LabelField(len(text), is_target=False) | |||||
| y = TextField(label, is_target=True) | y = TextField(label, is_target=True) | ||||
| ins = Instance(word_seq=x, label_seq=y) | |||||
| ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
| data_set.append(ins) | data_set.append(ins) | ||||
| data_set.index_field("word_seq", vocab) | data_set.index_field("word_seq", vocab) | ||||
| data_set.index_field("label_seq", label_vocab) | |||||
| data_set.index_field("truth", label_vocab) | |||||
| model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
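The tester fixture now builds its DataSet field by field, adding the word_seq_origin_len length field alongside the renamed truth field. The construction pattern in isolation, with the same toy vocabularies as the test:

```python
from fastNLP.core.dataset import SeqLabelDataSet
from fastNLP.core.field import LabelField, TextField
from fastNLP.core.instance import Instance

vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9}
label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4}

data_set = SeqLabelDataSet()
for text, label in [(['a', 'b', 'c', 'd', 'e'], ['a', '@', 'c', 'd', 'e'])]:
    ins = Instance(word_seq=TextField(text, False),                         # input tokens
                   truth=TextField(label, is_target=True),                  # gold labels
                   word_seq_origin_len=LabelField(len(text), is_target=False))
    data_set.append(ins)

# map the text fields onto indices with the given vocabularies
data_set.index_field("word_seq", vocab)
data_set.index_field("truth", label_vocab)
```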
| @@ -1,8 +1,9 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| from fastNLP.core.dataset import DataSet | |||||
| from fastNLP.core.field import TextField | |||||
| from fastNLP.core.dataset import SeqLabelDataSet | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.field import TextField, LabelField | |||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| @@ -12,14 +13,15 @@ from fastNLP.models.sequence_modeling import SeqLabeling | |||||
| class TestTrainer(unittest.TestCase): | class TestTrainer(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| args = {"epochs": 3, "batch_size": 2, "validate": True, "use_cuda": False, "pickle_path": "./save/", | |||||
| args = {"epochs": 3, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/", | |||||
| "save_best_dev": True, "model_name": "default_model_name.pkl", | "save_best_dev": True, "model_name": "default_model_name.pkl", | ||||
| "loss": Loss(None), | |||||
| "loss": Loss("cross_entropy"), | |||||
| "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0), | ||||
| "vocab_size": 10, | "vocab_size": 10, | ||||
| "word_emb_dim": 100, | "word_emb_dim": 100, | ||||
| "rnn_hidden_units": 100, | "rnn_hidden_units": 100, | ||||
| "num_classes": 5 | |||||
| "num_classes": 5, | |||||
| "evaluator": SeqLabelEvaluator() | |||||
| } | } | ||||
| trainer = SeqLabelTrainer(**args) | trainer = SeqLabelTrainer(**args) | ||||
| @@ -34,16 +36,17 @@ class TestTrainer(unittest.TestCase): | |||||
| vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | vocab = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, '!': 5, '@': 6, '#': 7, '$': 8, '?': 9} | ||||
| label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | label_vocab = {'a': 0, '@': 1, 'c': 2, 'd': 3, 'e': 4} | ||||
| data_set = DataSet() | |||||
| data_set = SeqLabelDataSet() | |||||
| for example in train_data: | for example in train_data: | ||||
| text, label = example[0], example[1] | text, label = example[0], example[1] | ||||
| x = TextField(text, False) | x = TextField(text, False) | ||||
| y = TextField(label, is_target=True) | |||||
| ins = Instance(word_seq=x, label_seq=y) | |||||
| x_len = LabelField(len(text), is_target=False) | |||||
| y = TextField(label, is_target=False) | |||||
| ins = Instance(word_seq=x, truth=y, word_seq_origin_len=x_len) | |||||
| data_set.append(ins) | data_set.append(ins) | ||||
| data_set.index_field("word_seq", vocab) | data_set.index_field("word_seq", vocab) | ||||
| data_set.index_field("label_seq", label_vocab) | |||||
| data_set.index_field("truth", label_vocab) | |||||
| model = SeqLabeling(args) | model = SeqLabeling(args) | ||||
| @@ -9,10 +9,54 @@ input = [1,2,3] | |||||
| text = "this is text" | text = "this is text" | ||||
| doubles = 0.5 | |||||
| doubles = 0.8 | |||||
| tt = 0.5 | |||||
| test = 105 | |||||
| str = "this is a str" | |||||
| double = 0.5 | |||||
| [t] | [t] | ||||
| x = "this is an test section" | x = "this is an test section" | ||||
| [test-case-2] | [test-case-2] | ||||
| double = 0.5 | double = 0.5 | ||||
| doubles = 0.8 | |||||
| tt = 0.5 | |||||
| test = 105 | |||||
| str = "this is a str" | |||||
| [another-test] | |||||
| doubles = 0.8 | |||||
| tt = 0.5 | |||||
| test = 105 | |||||
| str = "this is a str" | |||||
| double = 0.5 | |||||
| [one-another-test] | |||||
| doubles = 0.8 | |||||
| tt = 0.5 | |||||
| test = 105 | |||||
| str = "this is a str" | |||||
| double = 0.5 | |||||
| @@ -31,7 +31,7 @@ class TestConfigLoader(unittest.TestCase): | |||||
| return dict | return dict | ||||
| test_arg = ConfigSection() | test_arg = ConfigSection() | ||||
| ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||||
| ConfigLoader().load_config(os.path.join("./test/loader", "config"), {"test": test_arg}) | |||||
| section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | section = read_section_from_config(os.path.join("./test/loader", "config"), "test") | ||||
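ConfigLoader no longer takes a dummy argument in its constructor; the config path is passed to load_config alone. A short sketch against one of the sections added to the test config above, using the test fixture's path:

```python
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

section = ConfigSection()
# each requested section name is filled in place from the config file
ConfigLoader().load_config("./test/loader/config", {"another-test": section})

# keys defined in the [another-test] section added above
print(section["test"], section["doubles"], section["double"])
```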
| @@ -1,6 +1,7 @@ | |||||
| import os | |||||
| import unittest | import unittest | ||||
| from fastNLP.loader.dataset_loader import POSDatasetLoader, LMDatasetLoader, TokenizeDatasetLoader, \ | |||||
| from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \ | |||||
| PeopleDailyCorpusLoader, ConllLoader | PeopleDailyCorpusLoader, ConllLoader | ||||
| @@ -8,34 +9,34 @@ class TestDatasetLoader(unittest.TestCase): | |||||
| def test_case_1(self): | def test_case_1(self): | ||||
| data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | data = """Tom\tT\nand\tF\nJerry\tT\n.\tF\n\nHello\tT\nworld\tF\n!\tF""" | ||||
| lines = data.split("\n") | lines = data.split("\n") | ||||
| answer = POSDatasetLoader.parse(lines) | |||||
| answer = POSDataSetLoader.parse(lines) | |||||
| truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | truth = [[["Tom", "and", "Jerry", "."], ["T", "F", "T", "F"]], [["Hello", "world", "!"], ["T", "F", "F"]]] | ||||
| self.assertListEqual(answer, truth, "POS Dataset Loader") | self.assertListEqual(answer, truth, "POS Dataset Loader") | ||||
| def test_case_TokenizeDatasetLoader(self): | def test_case_TokenizeDatasetLoader(self): | ||||
| loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||||
| data = loader.load_pku(max_seq_len=32) | |||||
| print("pass TokenizeDatasetLoader test!") | |||||
| loader = TokenizeDataSetLoader() | |||||
| data = loader.load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32) | |||||
| print("pass TokenizeDataSetLoader test!") | |||||
| def test_case_POSDatasetLoader(self): | def test_case_POSDatasetLoader(self): | ||||
| loader = POSDatasetLoader("./test/data_for_tests/people.txt") | |||||
| data = loader.load() | |||||
| datas = loader.load_lines() | |||||
| print("pass POSDatasetLoader test!") | |||||
| loader = POSDataSetLoader() | |||||
| data = loader.load("./test/data_for_tests/people.txt") | |||||
| datas = loader.load_lines("./test/data_for_tests/people.txt") | |||||
| print("pass POSDataSetLoader test!") | |||||
| def test_case_LMDatasetLoader(self): | def test_case_LMDatasetLoader(self): | ||||
| loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8") | |||||
| data = loader.load() | |||||
| datas = loader.load_lines() | |||||
| print("pass TokenizeDatasetLoader test!") | |||||
| loader = LMDataSetLoader() | |||||
| data = loader.load("./test/data_for_tests/charlm.txt") | |||||
| datas = loader.load_lines("./test/data_for_tests/charlm.txt") | |||||
| print("pass TokenizeDataSetLoader test!") | |||||
| def test_PeopleDailyCorpusLoader(self): | def test_PeopleDailyCorpusLoader(self): | ||||
| loader = PeopleDailyCorpusLoader("./test/data_for_tests/people_daily_raw.txt") | |||||
| _, _ = loader.load() | |||||
| loader = PeopleDailyCorpusLoader() | |||||
| _, _ = loader.load("./test/data_for_tests/people_daily_raw.txt") | |||||
| def test_ConllLoader(self): | def test_ConllLoader(self): | ||||
| loader = ConllLoader("./test/data_for_tests/conll_example.txt") | |||||
| _ = loader.load() | |||||
| loader = ConllLoader() | |||||
| _ = loader.load("./test/data_for_tests/conll_example.txt") | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
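The loader classes are now constructed without a path and receive it in load() instead; the rewritten tests above adjust the call sites accordingly. The pattern, using the same fixture paths as the tests:

```python
from fastNLP.loader.dataset_loader import (ConllLoader, LMDataSetLoader,
                                           PeopleDailyCorpusLoader,
                                           POSDataSetLoader, TokenizeDataSetLoader)

# the path moves from the constructor into load()
tokenize_data = TokenizeDataSetLoader().load("./test/data_for_tests/cws_pku_utf_8", max_seq_len=32)
pos_data = POSDataSetLoader().load("./test/data_for_tests/people.txt")
lm_data = LMDataSetLoader().load("./test/data_for_tests/charlm.txt")
# PeopleDailyCorpusLoader.load returns two parts, as used in the reproduction script
pd_part1, pd_part2 = PeopleDailyCorpusLoader().load("./test/data_for_tests/people_daily_raw.txt")
conll_data = ConllLoader().load("./test/data_for_tests/conll_example.txt")
```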
| @@ -4,14 +4,16 @@ sys.path.append("..") | |||||
| import argparse | import argparse | ||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
| from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | |||||
| from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle | |||||
| from fastNLP.loader.dataset_loader import BaseLoader | |||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| from fastNLP.core.predictor import SeqLabelInfer | from fastNLP.core.predictor import SeqLabelInfer | ||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
| parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | ||||
| @@ -33,24 +35,27 @@ data_infer_path = args.infer | |||||
| def infer(): | def infer(): | ||||
| # Load infer configuration, the same as test | # Load infer configuration, the same as test | ||||
| test_args = ConfigSection() | test_args = ConfigSection() | ||||
| ConfigLoader("config.cfg").load_config(config_dir, {"POS_infer": test_args}) | |||||
| ConfigLoader().load_config(config_dir, {"POS_infer": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | |||||
| test_args["vocab_size"] = len(word2index) | |||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | |||||
| word_vocab = load_pickle(pickle_path, "word2id.pkl") | |||||
| label_vocab = load_pickle(pickle_path, "label2id.pkl") | |||||
| test_args["vocab_size"] = len(word_vocab) | |||||
| test_args["num_classes"] = len(label_vocab) | |||||
| print("vocabularies loaded") | |||||
| # Define the same model | # Define the same model | ||||
| model = SeqLabeling(test_args) | model = SeqLabeling(test_args) | ||||
| print("model defined") | |||||
| # Dump trained parameters into the model | # Dump trained parameters into the model | ||||
| ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | ModelLoader.load_pytorch(model, os.path.join(pickle_path, model_name)) | ||||
| print("model loaded!") | print("model loaded!") | ||||
| # Data Loader | # Data Loader | ||||
| raw_data_loader = BaseLoader(data_infer_path) | |||||
| infer_data = raw_data_loader.load_lines() | |||||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}, infer=True) | |||||
| print("data set prepared") | |||||
| # Inference interface | # Inference interface | ||||
| infer = SeqLabelInfer(pickle_path) | infer = SeqLabelInfer(pickle_path) | ||||
| @@ -65,24 +70,18 @@ def train_and_test(): | |||||
| # Config Loader | # Config Loader | ||||
| trainer_args = ConfigSection() | trainer_args = ConfigSection() | ||||
| model_args = ConfigSection() | model_args = ConfigSection() | ||||
| ConfigLoader("config.cfg").load_config(config_dir, { | |||||
| ConfigLoader().load_config(config_dir, { | |||||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | ||||
| # Data Loader | |||||
| pos_loader = POSDatasetLoader(data_path) | |||||
| train_data = pos_loader.load_lines() | |||||
| # Preprocessor | |||||
| p = SeqLabelPreprocess() | |||||
| data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | |||||
| model_args["vocab_size"] = p.vocab_size | |||||
| model_args["num_classes"] = p.num_classes | |||||
| data_set = SeqLabelDataSet() | |||||
| data_set.load(data_path) | |||||
| train_set, dev_set = data_set.split(0.3, shuffle=True) | |||||
| model_args["vocab_size"] = len(data_set.word_vocab) | |||||
| model_args["num_classes"] = len(data_set.label_vocab) | |||||
| # Trainer: two definition styles | |||||
| # 1 | |||||
| # trainer = SeqLabelTrainer(trainer_args.data) | |||||
| save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||||
| # 2 | |||||
| trainer = SeqLabelTrainer( | trainer = SeqLabelTrainer( | ||||
| epochs=trainer_args["epochs"], | epochs=trainer_args["epochs"], | ||||
| batch_size=trainer_args["batch_size"], | batch_size=trainer_args["batch_size"], | ||||
| @@ -98,7 +97,7 @@ def train_and_test(): | |||||
| model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
| # Start training | # Start training | ||||
| trainer.train(model, data_train, data_dev) | |||||
| trainer.train(model, train_set, dev_set) | |||||
| print("Training finished!") | print("Training finished!") | ||||
| # Saver | # Saver | ||||
| @@ -106,7 +105,9 @@ def train_and_test(): | |||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| print("Model saved!") | print("Model saved!") | ||||
| del model, trainer, pos_loader | |||||
| del model, trainer | |||||
| change_field_is_target(dev_set, "truth", True) | |||||
| # Define the same model | # Define the same model | ||||
| model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
| @@ -117,27 +118,21 @@ def train_and_test(): | |||||
| # Load test configuration | # Load test configuration | ||||
| tester_args = ConfigSection() | tester_args = ConfigSection() | ||||
| ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| # Tester | # Tester | ||||
| tester = SeqLabelTester(save_output=False, | |||||
| save_loss=True, | |||||
| save_best_dev=False, | |||||
| batch_size=4, | |||||
| tester = SeqLabelTester(batch_size=4, | |||||
| use_cuda=False, | use_cuda=False, | ||||
| pickle_path=pickle_path, | pickle_path=pickle_path, | ||||
| model_name="seq_label_in_test.pkl", | model_name="seq_label_in_test.pkl", | ||||
| print_every_step=1 | |||||
| evaluator=SeqLabelEvaluator() | |||||
| ) | ) | ||||
| # Start testing with validation data | # Start testing with validation data | ||||
| tester.test(model, data_dev) | |||||
| # print test results | |||||
| print(tester.show_metrics()) | |||||
| tester.test(model, dev_set) | |||||
| print("model tested!") | print("model tested!") | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| train_and_test() | train_and_test() | ||||
| # infer() | |||||
| infer() | |||||
| @@ -1,30 +1,32 @@ | |||||
| import os | import os | ||||
| from fastNLP.core.predictor import Predictor | |||||
| from fastNLP.core.preprocess import Preprocessor, load_pickle | |||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.predictor import SeqLabelInfer | |||||
| from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||||
| from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader | |||||
| from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
| cws_data_path = "test/data_for_tests/cws_pku_utf_8" | |||||
| cws_data_path = "./test/data_for_tests/cws_pku_utf_8" | |||||
| pickle_path = "./save/" | pickle_path = "./save/" | ||||
| data_infer_path = "test/data_for_tests/people_infer.txt" | |||||
| config_path = "test/data_for_tests/config" | |||||
| data_infer_path = "./test/data_for_tests/people_infer.txt" | |||||
| config_path = "./test/data_for_tests/config" | |||||
| def infer(): | def infer(): | ||||
| # Load infer configuration, the same as test | # Load infer configuration, the same as test | ||||
| test_args = ConfigSection() | test_args = ConfigSection() | ||||
| ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args}) | |||||
| ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
| # fetch dictionary size and number of labels from pickle files | # fetch dictionary size and number of labels from pickle files | ||||
| word2index = load_pickle(pickle_path, "word2id.pkl") | word2index = load_pickle(pickle_path, "word2id.pkl") | ||||
| test_args["vocab_size"] = len(word2index) | test_args["vocab_size"] = len(word2index) | ||||
| index2label = load_pickle(pickle_path, "class2id.pkl") | |||||
| index2label = load_pickle(pickle_path, "label2id.pkl") | |||||
| test_args["num_classes"] = len(index2label) | test_args["num_classes"] = len(index2label) | ||||
| # Define the same model | # Define the same model | ||||
| @@ -34,31 +36,29 @@ def infer(): | |||||
| ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | ModelLoader.load_pytorch(model, "./save/saved_model.pkl") | ||||
| print("model loaded!") | print("model loaded!") | ||||
| # Data Loader | |||||
| raw_data_loader = BaseLoader(data_infer_path) | |||||
| infer_data = raw_data_loader.load_lines() | |||||
| # Load infer data | |||||
| infer_data = SeqLabelDataSet(load_func=BaseLoader.load) | |||||
| infer_data.load(data_infer_path, vocabs={"word_vocab": word2index}, infer=True) | |||||
| # Inference interface | |||||
| infer = Predictor(pickle_path, "seq_label") | |||||
| # inference | |||||
| infer = SeqLabelInfer(pickle_path) | |||||
| results = infer.predict(model, infer_data) | results = infer.predict(model, infer_data) | ||||
| print(results) | print(results) | ||||
| def train_test(): | def train_test(): | ||||
| # Config Loader | # Config Loader | ||||
| train_args = ConfigSection() | train_args = ConfigSection() | ||||
| ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": train_args}) | |||||
| ConfigLoader().load_config(config_path, {"POS_infer": train_args}) | |||||
| # Data Loader | |||||
| loader = TokenizeDatasetLoader(cws_data_path) | |||||
| train_data = loader.load_pku() | |||||
| # define dataset | |||||
| data_train = SeqLabelDataSet(load_func=TokenizeDataSetLoader.load) | |||||
| data_train.load(cws_data_path) | |||||
| train_args["vocab_size"] = len(data_train.word_vocab) | |||||
| train_args["num_classes"] = len(data_train.label_vocab) | |||||
| # Preprocessor | |||||
| p = Preprocessor(label_is_seq=True) | |||||
| data_train = p.run(train_data, pickle_path=pickle_path) | |||||
| train_args["vocab_size"] = p.vocab_size | |||||
| train_args["num_classes"] = p.num_classes | |||||
| save_pickle(data_train.word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(data_train.label_vocab, pickle_path, "label2id.pkl") | |||||
| # Trainer | # Trainer | ||||
| trainer = SeqLabelTrainer(**train_args.data) | trainer = SeqLabelTrainer(**train_args.data) | ||||
| @@ -73,7 +73,7 @@ def train_test(): | |||||
| saver = ModelSaver("./save/saved_model.pkl") | saver = ModelSaver("./save/saved_model.pkl") | ||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| del model, trainer, loader | |||||
| del model, trainer | |||||
| # Define the same model | # Define the same model | ||||
| model = SeqLabeling(train_args) | model = SeqLabeling(train_args) | ||||
| @@ -83,17 +83,16 @@ def train_test(): | |||||
| # Load test configuration | # Load test configuration | ||||
| test_args = ConfigSection() | test_args = ConfigSection() | ||||
| ConfigLoader("config.cfg").load_config(config_path, {"POS_infer": test_args}) | |||||
| ConfigLoader().load_config(config_path, {"POS_infer": test_args}) | |||||
| test_args["evaluator"] = SeqLabelEvaluator() | |||||
| # Tester | # Tester | ||||
| tester = SeqLabelTester(**test_args.data) | tester = SeqLabelTester(**test_args.data) | ||||
| # Start testing | # Start testing | ||||
| change_field_is_target(data_train, "truth", True) | |||||
| tester.test(model, data_train) | tester.test(model, data_train) | ||||
| # print test results | |||||
| print(tester.show_metrics()) | |||||
| def test(): | def test(): | ||||
| os.makedirs("save", exist_ok=True) | os.makedirs("save", exist_ok=True) | ||||
| @@ -1,11 +1,12 @@ | |||||
| import os | import os | ||||
| from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target | |||||
| from fastNLP.core.metrics import SeqLabelEvaluator | |||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.preprocess import SeqLabelPreprocess | |||||
| from fastNLP.core.preprocess import save_pickle | |||||
| from fastNLP.core.tester import SeqLabelTester | from fastNLP.core.tester import SeqLabelTester | ||||
| from fastNLP.core.trainer import SeqLabelTrainer | from fastNLP.core.trainer import SeqLabelTrainer | ||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| from fastNLP.loader.dataset_loader import POSDatasetLoader | |||||
| from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
| from fastNLP.models.sequence_modeling import SeqLabeling | from fastNLP.models.sequence_modeling import SeqLabeling | ||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| @@ -21,18 +22,17 @@ def test_training(): | |||||
| # Config Loader | # Config Loader | ||||
| trainer_args = ConfigSection() | trainer_args = ConfigSection() | ||||
| model_args = ConfigSection() | model_args = ConfigSection() | ||||
| ConfigLoader("_").load_config(config_dir, { | |||||
| ConfigLoader().load_config(config_dir, { | |||||
| "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args}) | ||||
| # Data Loader | |||||
| pos_loader = POSDatasetLoader(data_path) | |||||
| train_data = pos_loader.load_lines() | |||||
| data_set = SeqLabelDataSet() | |||||
| data_set.load(data_path) | |||||
| data_train, data_dev = data_set.split(0.3, shuffle=True) | |||||
| model_args["vocab_size"] = len(data_set.word_vocab) | |||||
| model_args["num_classes"] = len(data_set.label_vocab) | |||||
| # Preprocessor | |||||
| p = SeqLabelPreprocess() | |||||
| data_train, data_dev = p.run(train_data, pickle_path=pickle_path, train_dev_split=0.5) | |||||
| model_args["vocab_size"] = p.vocab_size | |||||
| model_args["num_classes"] = p.num_classes | |||||
| save_pickle(data_set.word_vocab, pickle_path, "word2id.pkl") | |||||
| save_pickle(data_set.label_vocab, pickle_path, "label2id.pkl") | |||||
| trainer = SeqLabelTrainer( | trainer = SeqLabelTrainer( | ||||
| epochs=trainer_args["epochs"], | epochs=trainer_args["epochs"], | ||||
| @@ -55,7 +55,7 @@ def test_training(): | |||||
| saver = ModelSaver(os.path.join(pickle_path, model_name)) | saver = ModelSaver(os.path.join(pickle_path, model_name)) | ||||
| saver.save_pytorch(model) | saver.save_pytorch(model) | ||||
| del model, trainer, pos_loader | |||||
| del model, trainer | |||||
| # Define the same model | # Define the same model | ||||
| model = SeqLabeling(model_args) | model = SeqLabeling(model_args) | ||||
| @@ -65,21 +65,16 @@ def test_training(): | |||||
| # Load test configuration | # Load test configuration | ||||
| tester_args = ConfigSection() | tester_args = ConfigSection() | ||||
| ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| ConfigLoader().load_config(config_dir, {"test_seq_label_tester": tester_args}) | |||||
| # Tester | # Tester | ||||
| tester = SeqLabelTester(save_output=False, | |||||
| save_loss=True, | |||||
| save_best_dev=False, | |||||
| batch_size=4, | |||||
| tester = SeqLabelTester(batch_size=4, | |||||
| use_cuda=False, | use_cuda=False, | ||||
| pickle_path=pickle_path, | pickle_path=pickle_path, | ||||
| model_name="seq_label_in_test.pkl", | model_name="seq_label_in_test.pkl", | ||||
| print_every_step=1 | |||||
| evaluator=SeqLabelEvaluator() | |||||
| ) | ) | ||||
| # Start testing with validation data | # Start testing with validation data | ||||
| change_field_is_target(data_dev, "truth", True) | |||||
| tester.test(model, data_dev) | tester.test(model, data_dev) | ||||
| loss, accuracy = tester.metrics | |||||
| assert 0 < accuracy < 1 | |||||
| @@ -9,13 +9,14 @@ sys.path.append("..") | |||||
| from fastNLP.core.predictor import ClassificationInfer | from fastNLP.core.predictor import ClassificationInfer | ||||
| from fastNLP.core.trainer import ClassificationTrainer | from fastNLP.core.trainer import ClassificationTrainer | ||||
| from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | ||||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader | |||||
| from fastNLP.loader.dataset_loader import ClassDataSetLoader | |||||
| from fastNLP.loader.model_loader import ModelLoader | from fastNLP.loader.model_loader import ModelLoader | ||||
| from fastNLP.core.preprocess import ClassPreprocess | |||||
| from fastNLP.models.cnn_text_classification import CNNText | from fastNLP.models.cnn_text_classification import CNNText | ||||
| from fastNLP.saver.model_saver import ModelSaver | from fastNLP.saver.model_saver import ModelSaver | ||||
| from fastNLP.core.optimizer import Optimizer | from fastNLP.core.optimizer import Optimizer | ||||
| from fastNLP.core.loss import Loss | from fastNLP.core.loss import Loss | ||||
| from fastNLP.core.dataset import TextClassifyDataSet | |||||
| from fastNLP.core.preprocess import save_pickle, load_pickle | |||||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
| parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files") | ||||
| @@ -34,21 +35,18 @@ config_dir = args.config | |||||
| def infer(): | def infer(): | ||||
| # load dataset | # load dataset | ||||
| print("Loading data...") | print("Loading data...") | ||||
| ds_loader = ClassDatasetLoader(train_data_dir) | |||||
| data = ds_loader.load() | |||||
| unlabeled_data = [x[0] for x in data] | |||||
| word_vocab = load_pickle(save_dir, "word2id.pkl") | |||||
| label_vocab = load_pickle(save_dir, "label2id.pkl") | |||||
| print("vocabulary size:", len(word_vocab)) | |||||
| print("number of classes:", len(label_vocab)) | |||||
| # pre-process data | |||||
| pre = ClassPreprocess() | |||||
| data = pre.run(data, pickle_path=save_dir) | |||||
| print("vocabulary size:", pre.vocab_size) | |||||
| print("number of classes:", pre.num_classes) | |||||
| infer_data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||||
| infer_data.load(train_data_dir, vocabs={"word_vocab": word_vocab, "label_vocab": label_vocab}) | |||||
| model_args = ConfigSection() | model_args = ConfigSection() | ||||
| # TODO: load from config file | |||||
| model_args["vocab_size"] = pre.vocab_size | |||||
| model_args["num_classes"] = pre.num_classes | |||||
| # ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||||
| model_args["vocab_size"] = len(word_vocab) | |||||
| model_args["num_classes"] = len(label_vocab) | |||||
| ConfigLoader.load_config(config_dir, {"text_class_model": model_args}) | |||||
| # construct model | # construct model | ||||
| print("Building model...") | print("Building model...") | ||||
| @@ -59,7 +57,7 @@ def infer(): | |||||
| print("model loaded!") | print("model loaded!") | ||||
| infer = ClassificationInfer(pickle_path=save_dir) | infer = ClassificationInfer(pickle_path=save_dir) | ||||
| results = infer.predict(cnn, unlabeled_data) | |||||
| results = infer.predict(cnn, infer_data) | |||||
| print(results) | print(results) | ||||
| @@ -69,32 +67,23 @@ def train(): | |||||
| # load dataset | # load dataset | ||||
| print("Loading data...") | print("Loading data...") | ||||
| ds_loader = ClassDatasetLoader(train_data_dir) | |||||
| data = ds_loader.load() | |||||
| print(data[0]) | |||||
| data = TextClassifyDataSet(load_func=ClassDataSetLoader.load) | |||||
| data.load(train_data_dir) | |||||
| # pre-process data | |||||
| pre = ClassPreprocess() | |||||
| data_train = pre.run(data, pickle_path=save_dir) | |||||
| print("vocabulary size:", pre.vocab_size) | |||||
| print("number of classes:", pre.num_classes) | |||||
| print("vocabulary size:", len(data.word_vocab)) | |||||
| print("number of classes:", len(data.label_vocab)) | |||||
| save_pickle(data.word_vocab, save_dir, "word2id.pkl") | |||||
| save_pickle(data.label_vocab, save_dir, "label2id.pkl") | |||||
| model_args["num_classes"] = pre.num_classes | |||||
| model_args["vocab_size"] = pre.vocab_size | |||||
| model_args["num_classes"] = len(data.label_vocab) | |||||
| model_args["vocab_size"] = len(data.word_vocab) | |||||
| # construct model | # construct model | ||||
| print("Building model...") | print("Building model...") | ||||
| model = CNNText(model_args) | model = CNNText(model_args) | ||||
| # ConfigSaver().save_config(config_dir, {"text_class_model": model_args}) | |||||
| # train | # train | ||||
| print("Training...") | print("Training...") | ||||
| # 1 | |||||
| # trainer = ClassificationTrainer(train_args) | |||||
| # 2 | |||||
| trainer = ClassificationTrainer(epochs=train_args["epochs"], | trainer = ClassificationTrainer(epochs=train_args["epochs"], | ||||
| batch_size=train_args["batch_size"], | batch_size=train_args["batch_size"], | ||||
| validate=train_args["validate"], | validate=train_args["validate"], | ||||
| @@ -104,7 +93,7 @@ def train(): | |||||
| model_name=model_name, | model_name=model_name, | ||||
| loss=Loss("cross_entropy"), | loss=Loss("cross_entropy"), | ||||
| optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | optimizer=Optimizer("SGD", lr=0.001, momentum=0.9)) | ||||
| trainer.train(model, data_train) | |||||
| trainer.train(model, data) | |||||
| print("Training finished!") | print("Training finished!") | ||||
| @@ -115,4 +104,4 @@ def train(): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| train() | train() | ||||
| # infer() | |||||
| infer() | |||||
| @@ -2,7 +2,7 @@ import unittest | |||||
| import torch | import torch | ||||
| from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear | |||||
| from fastNLP.modules.other_modules import GroupNorm, LayerNormalization, BiLinear, BiAffine | |||||
| class TestGroupNorm(unittest.TestCase): | class TestGroupNorm(unittest.TestCase): | ||||
| @@ -27,3 +27,25 @@ class TestBiLinear(unittest.TestCase): | |||||
| y = bl(x_left, x_right) | y = bl(x_left, x_right) | ||||
| print(bl) | print(bl) | ||||
| bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True) | bl2 = BiLinear(n_left=15, n_right=15, n_out=10, bias=True) | ||||
| class TestBiAffine(unittest.TestCase): | |||||
| def test_case_1(self): | |||||
| batch_size = 16 | |||||
| encoder_length = 21 | |||||
| decoder_length = 32 | |||||
| layer = BiAffine(10, 10, 25, biaffine=True) | |||||
| decoder_input = torch.randn((batch_size, encoder_length, 10)) | |||||
| encoder_input = torch.randn((batch_size, decoder_length, 10)) | |||||
| y = layer(decoder_input, encoder_input) | |||||
| self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, decoder_length)) | |||||
| def test_case_2(self): | |||||
| batch_size = 16 | |||||
| encoder_length = 21 | |||||
| decoder_length = 32 | |||||
| layer = BiAffine(10, 10, 25, biaffine=False) | |||||
| decoder_input = torch.randn((batch_size, encoder_length, 10)) | |||||
| encoder_input = torch.randn((batch_size, decoder_length, 10)) | |||||
| y = layer(decoder_input, encoder_input) | |||||
| self.assertEqual(tuple(y.shape), (batch_size, 25, encoder_length, 1)) | |||||
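What these two cases pin down is the output-shape contract: with `biaffine=True` the layer scores every pair of positions for each of the `num_labels` outputs, giving `(batch, num_labels, len1, len2)`, while with `biaffine=False` the bilinear term is dropped and the last dimension collapses to 1. A minimal usage sketch built only on the positional constructor call exercised above; the argument meanings noted in the comment are inferred, not taken from the source.

```python
import torch
from fastNLP.modules.other_modules import BiAffine

# BiAffine(10, 10, 25, ...): roughly (input1 feature size, input2 feature size,
# number of output labels) -- inferred, since the tests pass them positionally.
scorer = BiAffine(10, 10, 25, biaffine=True)

x1 = torch.randn(16, 21, 10)   # (batch, len1, features)
x2 = torch.randn(16, 32, 10)   # (batch, len2, features)

scores = scorer(x1, x2)
print(tuple(scores.shape))     # (16, 25, 21, 32): one score per label per position pair

# Without the bilinear term, the second sequence dimension collapses to 1.
plain = BiAffine(10, 10, 25, biaffine=False)
print(tuple(plain(x1, x2).shape))   # (16, 25, 21, 1)
```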
| @@ -1,8 +1,5 @@ | |||||
| import os | import os | ||||
| import unittest | import unittest | ||||
| import configparser | |||||
| import json | |||||
| from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | ||||
| from fastNLP.saver.config_saver import ConfigSaver | from fastNLP.saver.config_saver import ConfigSaver | ||||
| @@ -10,7 +7,7 @@ from fastNLP.saver.config_saver import ConfigSaver | |||||
| class TestConfigSaver(unittest.TestCase): | class TestConfigSaver(unittest.TestCase): | ||||
| def test_case_1(self): | def test_case_1(self): | ||||
| config_file_dir = "./test/loader/" | |||||
| config_file_dir = "test/loader/" | |||||
| config_file_name = "config" | config_file_name = "config" | ||||
| config_file_path = os.path.join(config_file_dir, config_file_name) | config_file_path = os.path.join(config_file_dir, config_file_name) | ||||
| @@ -21,7 +18,7 @@ class TestConfigSaver(unittest.TestCase): | |||||
| standard_section = ConfigSection() | standard_section = ConfigSection() | ||||
| t_section = ConfigSection() | t_section = ConfigSection() | ||||
| ConfigLoader(config_file_path).load_config(config_file_path, {"test": standard_section, "t": t_section}) | |||||
| ConfigLoader().load_config(config_file_path, {"test": standard_section, "t": t_section}) | |||||
| config_saver = ConfigSaver(config_file_path) | config_saver = ConfigSaver(config_file_path) | ||||
| @@ -48,11 +45,11 @@ class TestConfigSaver(unittest.TestCase): | |||||
| one_another_test_section = ConfigSection() | one_another_test_section = ConfigSection() | ||||
| a_test_case_2_section = ConfigSection() | a_test_case_2_section = ConfigSection() | ||||
| ConfigLoader(config_file_path).load_config(config_file_path, {"test": test_section, | |||||
| "another-test": another_test_section, | |||||
| "t": at_section, | |||||
| "one-another-test": one_another_test_section, | |||||
| "test-case-2": a_test_case_2_section}) | |||||
| ConfigLoader().load_config(config_file_path, {"test": test_section, | |||||
| "another-test": another_test_section, | |||||
| "t": at_section, | |||||
| "one-another-test": one_another_test_section, | |||||
| "test-case-2": a_test_case_2_section}) | |||||
| assert test_section == standard_section | assert test_section == standard_section | ||||
| assert at_section == t_section | assert at_section == t_section | ||||
| @@ -80,3 +77,37 @@ class TestConfigSaver(unittest.TestCase): | |||||
| tmp_config_saver = ConfigSaver("file-NOT-exist") | tmp_config_saver = ConfigSaver("file-NOT-exist") | ||||
| except Exception as e: | except Exception as e: | ||||
| pass | pass | ||||
| def test_case_2(self): | |||||
| config = "[section_A]\n[section_B]\n" | |||||
| with open("./test.cfg", "w", encoding="utf-8") as f: | |||||
| f.write(config) | |||||
| saver = ConfigSaver("./test.cfg") | |||||
| section = ConfigSection() | |||||
| section["doubles"] = 0.8 | |||||
| section["tt"] = [1, 2, 3] | |||||
| section["test"] = 105 | |||||
| section["str"] = "this is a str" | |||||
| saver.save_config_file("section_A", section) | |||||
| os.system("rm ./test.cfg") | |||||
| def test_case_3(self): | |||||
| config = "[section_A]\ndoubles = 0.9\ntt = [1, 2, 3]\n[section_B]\n" | |||||
| with open("./test.cfg", "w", encoding="utf-8") as f: | |||||
| f.write(config) | |||||
| saver = ConfigSaver("./test.cfg") | |||||
| section = ConfigSection() | |||||
| section["doubles"] = 0.8 | |||||
| section["tt"] = [1, 2, 3] | |||||
| section["test"] = 105 | |||||
| section["str"] = "this is a str" | |||||
| saver.save_config_file("section_A", section) | |||||
| os.system("rm ./test.cfg") | |||||
| @@ -54,7 +54,7 @@ def mock_cws(): | |||||
| class2id = Vocabulary(need_default=False) | class2id = Vocabulary(need_default=False) | ||||
| label_list = ['B', 'M', 'E', 'S'] | label_list = ['B', 'M', 'E', 'S'] | ||||
| class2id.update(label_list) | class2id.update(label_list) | ||||
| save_pickle(class2id, "./mock/", "class2id.pkl") | |||||
| save_pickle(class2id, "./mock/", "label2id.pkl") | |||||
| model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)} | model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)} | ||||
| config_file = """ | config_file = """ | ||||
| @@ -115,7 +115,7 @@ def mock_pos_tag(): | |||||
| idx2label = Vocabulary(need_default=False) | idx2label = Vocabulary(need_default=False) | ||||
| label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv'] | label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv'] | ||||
| idx2label.update(label_list) | idx2label.update(label_list) | ||||
| save_pickle(idx2label, "./mock/", "class2id.pkl") | |||||
| save_pickle(idx2label, "./mock/", "label2id.pkl") | |||||
| model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | ||||
| config_file = """ | config_file = """ | ||||
| @@ -163,7 +163,7 @@ def mock_text_classify(): | |||||
| idx2label = Vocabulary(need_default=False) | idx2label = Vocabulary(need_default=False) | ||||
| label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F'] | label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F'] | ||||
| idx2label.update(label_list) | idx2label.update(label_list) | ||||
| save_pickle(idx2label, "./mock/", "class2id.pkl") | |||||
| save_pickle(idx2label, "./mock/", "label2id.pkl") | |||||
| model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)} | ||||
| config_file = """ | config_file = """ | ||||