@@ -10,7 +10,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
-from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader
+from reproduction.pos_tag_model.pos_reader import ConllPOSReader
 from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
@@ -250,7 +250,7 @@ class LossInForward(LossBase):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}")
+                raise TypeError(f"Loss expected to be a torch.Tensor, got {type(loss)}")
             raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")
         return loss
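As a side note on the hunk above: the check only accepts a scalar (zero-dimensional) tensor as the loss returned from forward. A minimal, fastNLP-independent sketch of the three cases the two branches distinguish:

    import torch

    loss_scalar = torch.tensor(0.5)          # torch.Size([])  -> passes the check
    loss_vector = torch.tensor([0.5, 0.3])   # per-sample loss -> hits the RuntimeError branch
    loss_float = 0.5                         # not a tensor    -> hits the TypeError branch

    for loss in (loss_scalar, loss_vector, loss_float):
        print(isinstance(loss, torch.Tensor) and len(loss.size()) == 0)
    # True, False, False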
@@ -436,15 +436,14 @@ class SpanFPreRecMetric(MetricBase):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")
 
-        num_classes = pred.size(-1)
-        if (target >= num_classes).any():
-            raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
-                             "id >= {}, the number of classes.".format(num_classes))
-
         if pred.size() == target.size() and len(target.size()) == 2:
             pass
         elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
+            num_classes = pred.size(-1)
             pred = pred.argmax(dim=-1)
+            if (target >= num_classes).any():
+                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
+                                 "id >= {}, the number of classes.".format(num_classes))
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
@@ -1,8 +1,8 @@
 import torch
-import numpy as np
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder, encoder
+from fastNLP.modules.decoder.CRF import allowed_transitions
 from fastNLP.modules.utils import seq_mask
@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling):
     Advanced Sequence Labeling Model
     """
 
-    def __init__(self, args, emb=None):
+    def __init__(self, args, emb=None, id2words=None):
         super(AdvSeqLabel, self).__init__(args)
 
         vocab_size = args["vocab_size"]
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling):
         self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
         self.norm1 = torch.nn.LayerNorm(word_emb_dim)
         # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
-        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True)
+        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout,
+                                 bidirectional=True, batch_first=True)
         self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
         self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
         # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling):
         self.drop = torch.nn.Dropout(dropout)
         self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)
 
-        self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        if id2words is None:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        else:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
+                                                          allowed_transitions=allowed_transitions(id2words,
+                                                                                                  encoding_type="bmes"))
 
     def forward(self, word_seq, word_seq_origin_len, truth=None):
         """
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling):
         assert 'loss' in kwargs
         return kwargs['loss']
 
+
 if __name__ == '__main__':
     args = {
         'vocab_size': 20,
@@ -208,11 +215,11 @@ if __name__ == '__main__':
         res = model(word_seq, word_seq_len, truth)
         loss = res['loss']
         pred = res['predict']
-        print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
+        print('loss: {} acc {}'.format(loss.item(),
+                                       ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         curidx = endidx
         if curidx == len(data):
             curidx = 0
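One aside on the accuracy printed in the demo above: (pred.data == truth) is summed over every position, including padded ones, while the denominator only counts real tokens. A small self-contained sketch of a masked variant (variable names mirror the demo; the tensors themselves are made up):

    import torch

    pred = torch.LongTensor([[1, 2, 0, 0], [3, 3, 3, 0]])    # [batch, max_len] predictions
    truth = torch.LongTensor([[1, 2, 9, 9], [3, 3, 1, 9]])   # 9 marks padded positions
    word_seq_len = torch.LongTensor([2, 3])                  # true length of each sentence

    mask = torch.arange(truth.size(1)).unsqueeze(0) < word_seq_len.unsqueeze(1)
    acc = ((pred == truth) & mask).long().sum().float() / word_seq_len.sum().float()
    print(acc.item())  # 4 correct tokens out of 5 -> 0.8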
@@ -4,6 +4,7 @@ from collections import Counter
 from fastNLP.api.processor import Processor
 from fastNLP.core.dataset import DataSet
 
+
 class CombineWordAndPosProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(CombineWordAndPosProcessor, self).__init__(None, None)
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor):
         return dataset
 
+
 class PosOutputStrProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(PosOutputStrProcessor, self).__init__(None, None)
@@ -0,0 +1,71 @@
+import torch
+
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import SeqLenProcessor
+from fastNLP.core.metrics import SpanFPreRecMetric
+from fastNLP.core.trainer import Trainer
+from fastNLP.io.config_io import ConfigLoader, ConfigSection
+from fastNLP.models.sequence_modeling import AdvSeqLabel
+from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
+from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
+
+cfgfile = './pos_tag.cfg'
+pickle_path = "save"
+
+
+def train():
+    # load config
+    train_param = ConfigSection()
+    model_param = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
+    print("config loaded")
+
+    # Data Loader
+    dataset = ZhConllPOSReader().load("/home/hyan/train.conllx")
+    print(dataset)
+    print("dataset transformed")
+
+    vocab_proc = VocabIndexerProcessor("words")
+    tag_proc = VocabIndexerProcessor("tag")
+    seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len")
+
+    vocab_proc(dataset)
+    tag_proc(dataset)
+    seq_len_proc(dataset)
+
+    dataset.rename_field("words", "word_seq")
+    dataset.rename_field("tag", "truth")
+    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
+    dataset.set_target("truth", "word_seq_origin_len")
+    print("processors defined")
+
+    # dataset.set_is_target(tag_ids=True)
+    model_param["vocab_size"] = vocab_proc.get_vocab_size()
+    model_param["num_classes"] = tag_proc.get_vocab_size()
+    print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))
+
+    # define a model
+    model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
+
+    # call trainer to train
+    trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
+                                                                           target="truth",
+                                                                           seq_lens="word_seq_origin_len"),
+                      dev_data=dataset, metric_key="f",
+                      use_tqdm=False, use_cuda=True, print_every=20)
+    trainer.train()
+
+    # save model & pipeline
+    pp = Pipeline([vocab_proc, seq_len_proc])
+    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
+    torch.save(save_dict, "model_pp.pkl")
+    print("pipeline saved")
+
+
+def infer():
+    pass
+
+
+if __name__ == "__main__":
+    train()
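infer() is left as a stub in this patch. Below is a minimal sketch of what it could do with the artifacts that train() saves; it assumes the model's forward returns a 'predict' entry when called without truth (as the __main__ demo in sequence_modeling.py suggests) and that the tag Vocabulary exposes to_word(), and the word ids and lengths are invented for illustration:

    import torch

    save_dict = torch.load("model_pp.pkl")   # written at the end of train()
    model = save_dict["model"]               # trained AdvSeqLabel
    tag_vocab = save_dict["tag_vocab"]       # maps tag ids back to BMES-POS strings

    model.eval()
    word_seq = torch.LongTensor([[3, 4, 5, 6, 7]])   # hypothetical indexed sentence
    word_seq_len = torch.LongTensor([5])
    with torch.no_grad():
        res = model(word_seq, word_seq_len)          # same call pattern as the demo, minus truth
    tags = [tag_vocab.to_word(int(i)) for i in res["predict"][0]]
    print(tags)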