@@ -10,7 +10,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
-from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader
+from reproduction.pos_tag_model.pos_reader import ConllPOSReader
 from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
@@ -250,7 +250,7 @@ class LossInForward(LossBase):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}")
+                raise TypeError(f"Loss expected to be a torch.Tensor, got {type(loss)}")
             raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")
         return loss
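
A note on what this check enforces: the value returned under the model's loss key must be a scalar (0-dim) torch.Tensor; a per-instance loss vector is rejected by the outer check. A minimal sketch, with illustrative tensors that are not part of the diff:

import torch

batch_losses = torch.rand(4)  # per-instance losses, torch.Size([4]) -- rejected
loss = batch_losses.mean()    # reduced to a scalar, torch.Size([]) -- accepted
assert isinstance(loss, torch.Tensor) and len(loss.size()) == 0
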
@@ -436,15 +436,14 @@ class SpanFPreRecMetric(MetricBase):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")

-        num_classes = pred.size(-1)
-        if (target >= num_classes).any():
-            raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
-                             "id >= {}, the number of classes.".format(num_classes))
         if pred.size() == target.size() and len(target.size()) == 2:
             pass
         elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
+            num_classes = pred.size(-1)
             pred = pred.argmax(dim=-1)
+            if (target >= num_classes).any():
+                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
+                                 "id >= {}, the number of classes.".format(num_classes))
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
@@ -1,8 +1,8 @@
 import torch
 import numpy as np
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder, encoder
+from fastNLP.modules.decoder.CRF import allowed_transitions
 from fastNLP.modules.utils import seq_mask
@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling):
     Advanced Sequence Labeling Model
     """

-    def __init__(self, args, emb=None):
+    def __init__(self, args, emb=None, id2words=None):
         super(AdvSeqLabel, self).__init__(args)

         vocab_size = args["vocab_size"]
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling):
         self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
         self.norm1 = torch.nn.LayerNorm(word_emb_dim)
         # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
-        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True)
+        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout,
+                                 bidirectional=True, batch_first=True)
         self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
         self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
         # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
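
For orientation, the stack above threads shapes roughly as follows; the symbolic sizes come from the code, the layout is just a reading aid:

# [batch, seq]              word ids
# -> Embedding + norm1      [batch, seq, word_emb_dim]
# -> bidirectional LSTM     [batch, seq, hidden_dim * 2]
# -> Linear1 + norm2        [batch, seq, hidden_dim * 2 // 3]
# -> Linear2                [batch, seq, num_classes]  -> CRF
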
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling):
         self.drop = torch.nn.Dropout(dropout)
         self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)

-        self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        if id2words is None:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        else:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
+                                                          allowed_transitions=allowed_transitions(id2words,
+                                                                                                  encoding_type="bmes"))

     def forward(self, word_seq, word_seq_origin_len, truth=None):
         """
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling):
         assert 'loss' in kwargs
         return kwargs['loss']

+
 if __name__ == '__main__':
     args = {
         'vocab_size': 20,
@@ -208,11 +215,11 @@ if __name__ == '__main__':
         res = model(word_seq, word_seq_len, truth)
         loss = res['loss']
         pred = res['predict']
-        print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
+        print('loss: {} acc {}'.format(loss.item(),
+                                       ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         curidx = endidx
         if curidx == len(data):
             curidx = 0
@@ -4,6 +4,7 @@ from collections import Counter
 from fastNLP.api.processor import Processor
 from fastNLP.core.dataset import DataSet

+
 class CombineWordAndPosProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(CombineWordAndPosProcessor, self).__init__(None, None)
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor):
         return dataset

+
 class PosOutputStrProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(PosOutputStrProcessor, self).__init__(None, None)
@@ -0,0 +1,71 @@
+import torch
+
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import SeqLenProcessor
+from fastNLP.core.metrics import SpanFPreRecMetric
+from fastNLP.core.trainer import Trainer
+from fastNLP.io.config_io import ConfigLoader, ConfigSection
+from fastNLP.models.sequence_modeling import AdvSeqLabel
+from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
+from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
+
+cfgfile = './pos_tag.cfg'
+pickle_path = "save"
+
+
+def train():
+    # load config
+    train_param = ConfigSection()
+    model_param = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
+    print("config loaded")
+
+    # Data Loader
+    dataset = ZhConllPOSReader().load("/home/hyan/train.conllx")
+    print(dataset)
+    print("dataset transformed")
+
+    vocab_proc = VocabIndexerProcessor("words")
+    tag_proc = VocabIndexerProcessor("tag")
+    seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len")
+
+    vocab_proc(dataset)
+    tag_proc(dataset)
+    seq_len_proc(dataset)
+
+    dataset.rename_field("words", "word_seq")
+    dataset.rename_field("tag", "truth")
+    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
+    dataset.set_target("truth", "word_seq_origin_len")
+    print("processors defined")
+
+    # dataset.set_is_target(tag_ids=True)
+    model_param["vocab_size"] = vocab_proc.get_vocab_size()
+    model_param["num_classes"] = tag_proc.get_vocab_size()
+    print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))
+
+    # define a model
+    model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
+
+    # call trainer to train
+    trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
+                                                                           target="truth",
+                                                                           seq_lens="word_seq_origin_len"),
+                      dev_data=dataset, metric_key="f",
+                      use_tqdm=False, use_cuda=True, print_every=20)
+    trainer.train()
+
+    # save model & pipeline
+    pp = Pipeline([vocab_proc, seq_len_proc])
+    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
+    torch.save(save_dict, "model_pp.pkl")
+    print("pipeline saved")
+
+
+def infer():
+    pass
+
+
+if __name__ == "__main__":
+    train()
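
`infer()` is left as a stub in this PR. For completeness, a minimal sketch of how the saved artifacts could be reloaded; this is an assumption about intended usage, not the author's implementation:

saved = torch.load("model_pp.pkl")
pp, model, tag_vocab = saved["pipeline"], saved["model"], saved["tag_vocab"]
# `pp` re-applies the same indexing and length processors to a raw
# DataSet before it is fed to `model`.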