diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 0ac5c503..bd14197a 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -10,7 +10,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.api.model_zoo import load_url from fastNLP.api.processor import ModelProcessor from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader -from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader +from reproduction.pos_tag_model.pos_reader import ConllPOSReader from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag from fastNLP.core.instance import Instance from fastNLP.core.sampler import SequentialSampler diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 04f8b73e..c11f538c 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -250,7 +250,7 @@ class LossInForward(LossBase): if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): if not isinstance(loss, torch.Tensor): - raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}") + raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") return loss diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 99637463..36f9ab90 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -436,15 +436,14 @@ class SpanFPreRecMetric(MetricBase): raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(seq_lens)}.") - num_classes = pred.size(-1) - if (target >= num_classes).any(): - raise ValueError("A gold label passed to SpanBasedF1Metric contains an " - "id >= {}, the number of classes.".format(num_classes)) - if pred.size() == target.size() and len(target.size()) == 2: pass elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: pred = pred.argmax(dim=-1) + num_classes = pred.size(-1) + if (target >= num_classes).any(): + raise ValueError("A gold label passed to SpanBasedF1Metric contains an " + "id >= {}, the number of classes.".format(num_classes)) else: raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " f"size:{pred.size()}, target should have size: {pred.size()} or " diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index e911598c..cb9e9478 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -1,8 +1,8 @@ import torch -import numpy as np from fastNLP.models.base_model import BaseModel from fastNLP.modules import decoder, encoder +from fastNLP.modules.decoder.CRF import allowed_transitions from fastNLP.modules.utils import seq_mask @@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling): Advanced Sequence Labeling Model """ - def __init__(self, args, emb=None): + def __init__(self, args, emb=None, id2words=None): super(AdvSeqLabel, self).__init__(args) vocab_size = args["vocab_size"] @@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling): self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) self.norm1 = torch.nn.LayerNorm(word_emb_dim) # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) - self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) + self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, + bidirectional=True, batch_first=True) self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) @@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling): self.drop = torch.nn.Dropout(dropout) self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) - self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) + if id2words is None: + self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) + else: + self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, + allowed_transitions=allowed_transitions(id2words, + encoding_type="bmes")) def forward(self, word_seq, word_seq_origin_len, truth=None): """ @@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling): assert 'loss' in kwargs return kwargs['loss'] + if __name__ == '__main__': args = { 'vocab_size': 20, @@ -208,11 +215,11 @@ if __name__ == '__main__': res = model(word_seq, word_seq_len, truth) loss = res['loss'] pred = res['predict'] - print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) + print('loss: {} acc {}'.format(loss.item(), + ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) optimizer.zero_grad() loss.backward() optimizer.step() curidx = endidx if curidx == len(data): curidx = 0 - diff --git a/reproduction/pos_tag_model/process/pos_processor.py b/reproduction/pos_tag_model/pos_processor.py similarity index 99% rename from reproduction/pos_tag_model/process/pos_processor.py rename to reproduction/pos_tag_model/pos_processor.py index 5c03f9cd..7a1b8e01 100644 --- a/reproduction/pos_tag_model/process/pos_processor.py +++ b/reproduction/pos_tag_model/pos_processor.py @@ -4,6 +4,7 @@ from collections import Counter from fastNLP.api.processor import Processor from fastNLP.core.dataset import DataSet + class CombineWordAndPosProcessor(Processor): def __init__(self, word_field_name, pos_field_name): super(CombineWordAndPosProcessor, self).__init__(None, None) @@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor): return dataset + class PosOutputStrProcessor(Processor): def __init__(self, word_field_name, pos_field_name): super(PosOutputStrProcessor, self).__init__(None, None) diff --git a/reproduction/pos_tag_model/pos_io/pos_reader.py b/reproduction/pos_tag_model/pos_reader.py similarity index 100% rename from reproduction/pos_tag_model/pos_io/pos_reader.py rename to reproduction/pos_tag_model/pos_reader.py diff --git a/reproduction/pos_tag_model/train_pos_tag.py b/reproduction/pos_tag_model/train_pos_tag.py new file mode 100644 index 00000000..e440b542 --- /dev/null +++ b/reproduction/pos_tag_model/train_pos_tag.py @@ -0,0 +1,71 @@ +import torch + +from fastNLP.api.pipeline import Pipeline +from fastNLP.api.processor import SeqLenProcessor +from fastNLP.core.metrics import SpanFPreRecMetric +from fastNLP.core.trainer import Trainer +from fastNLP.io.config_io import ConfigLoader, ConfigSection +from fastNLP.models.sequence_modeling import AdvSeqLabel +from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor +from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader + +cfgfile = './pos_tag.cfg' +pickle_path = "save" + + +def train(): + # load config + train_param = ConfigSection() + model_param = ConfigSection() + ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param}) + print("config loaded") + + # Data Loader + dataset = ZhConllPOSReader().load("/home/hyan/train.conllx") + print(dataset) + print("dataset transformed") + + vocab_proc = VocabIndexerProcessor("words") + tag_proc = VocabIndexerProcessor("tag") + seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len") + + vocab_proc(dataset) + tag_proc(dataset) + seq_len_proc(dataset) + + dataset.rename_field("words", "word_seq") + dataset.rename_field("tag", "truth") + dataset.set_input("word_seq", "word_seq_origin_len", "truth") + dataset.set_target("truth", "word_seq_origin_len") + + print("processors defined") + + # dataset.set_is_target(tag_ids=True) + model_param["vocab_size"] = vocab_proc.get_vocab_size() + model_param["num_classes"] = tag_proc.get_vocab_size() + print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"])) + + # define a model + model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word) + + # call trainer to train + trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", + target="truth", + seq_lens="word_seq_origin_len"), + dev_data=dataset, metric_key="f", + use_tqdm=False, use_cuda=True, print_every=20) + trainer.train() + + # save model & pipeline + pp = Pipeline([vocab_proc, seq_len_proc]) + save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} + torch.save(save_dict, "model_pp.pkl") + print("pipeline saved") + + +def infer(): + pass + + +if __name__ == "__main__": + train()