@@ -10,7 +10,7 @@ from fastNLP.core.dataset import DataSet | |||||
from fastNLP.api.model_zoo import load_url | from fastNLP.api.model_zoo import load_url | ||||
from fastNLP.api.processor import ModelProcessor | from fastNLP.api.processor import ModelProcessor | ||||
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader | ||||
from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader | |||||
from reproduction.pos_tag_model.pos_reader import ConllPOSReader | |||||
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.sampler import SequentialSampler | from fastNLP.core.sampler import SequentialSampler | ||||
@@ -250,7 +250,7 @@ class LossInForward(LossBase): | |||||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | ||||
if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}") | |||||
raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}") | |||||
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | ||||
return loss | return loss | ||||
@@ -436,15 +436,14 @@ class SpanFPreRecMetric(MetricBase): | |||||
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," | ||||
f"got {type(seq_lens)}.") | f"got {type(seq_lens)}.") | ||||
num_classes = pred.size(-1) | |||||
if (target >= num_classes).any(): | |||||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||||
"id >= {}, the number of classes.".format(num_classes)) | |||||
if pred.size() == target.size() and len(target.size()) == 2: | if pred.size() == target.size() and len(target.size()) == 2: | ||||
pass | pass | ||||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | ||||
pred = pred.argmax(dim=-1) | pred = pred.argmax(dim=-1) | ||||
num_classes = pred.size(-1) | |||||
if (target >= num_classes).any(): | |||||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||||
"id >= {}, the number of classes.".format(num_classes)) | |||||
else: | else: | ||||
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " | ||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | f"size:{pred.size()}, target should have size: {pred.size()} or " | ||||
@@ -1,8 +1,8 @@ | |||||
import torch | import torch | ||||
import numpy as np | |||||
from fastNLP.models.base_model import BaseModel | from fastNLP.models.base_model import BaseModel | ||||
from fastNLP.modules import decoder, encoder | from fastNLP.modules import decoder, encoder | ||||
from fastNLP.modules.decoder.CRF import allowed_transitions | |||||
from fastNLP.modules.utils import seq_mask | from fastNLP.modules.utils import seq_mask | ||||
@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling): | |||||
Advanced Sequence Labeling Model | Advanced Sequence Labeling Model | ||||
""" | """ | ||||
def __init__(self, args, emb=None): | |||||
def __init__(self, args, emb=None, id2words=None): | |||||
super(AdvSeqLabel, self).__init__(args) | super(AdvSeqLabel, self).__init__(args) | ||||
vocab_size = args["vocab_size"] | vocab_size = args["vocab_size"] | ||||
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling): | |||||
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb) | ||||
self.norm1 = torch.nn.LayerNorm(word_emb_dim) | self.norm1 = torch.nn.LayerNorm(word_emb_dim) | ||||
# self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True) | ||||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True) | |||||
self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, | |||||
bidirectional=True, batch_first=True) | |||||
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3) | ||||
self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3) | ||||
# self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3) | ||||
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling): | |||||
self.drop = torch.nn.Dropout(dropout) | self.drop = torch.nn.Dropout(dropout) | ||||
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes) | ||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||||
if id2words is None: | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False) | |||||
else: | |||||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False, | |||||
allowed_transitions=allowed_transitions(id2words, | |||||
encoding_type="bmes")) | |||||
def forward(self, word_seq, word_seq_origin_len, truth=None): | def forward(self, word_seq, word_seq_origin_len, truth=None): | ||||
""" | """ | ||||
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling): | |||||
assert 'loss' in kwargs | assert 'loss' in kwargs | ||||
return kwargs['loss'] | return kwargs['loss'] | ||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||
args = { | args = { | ||||
'vocab_size': 20, | 'vocab_size': 20, | ||||
@@ -208,11 +215,11 @@ if __name__ == '__main__': | |||||
res = model(word_seq, word_seq_len, truth) | res = model(word_seq, word_seq_len, truth) | ||||
loss = res['loss'] | loss = res['loss'] | ||||
pred = res['predict'] | pred = res['predict'] | ||||
print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||||
print('loss: {} acc {}'.format(loss.item(), | |||||
((pred.data == truth).long().sum().float() / word_seq_len.sum().float()))) | |||||
optimizer.zero_grad() | optimizer.zero_grad() | ||||
loss.backward() | loss.backward() | ||||
optimizer.step() | optimizer.step() | ||||
curidx = endidx | curidx = endidx | ||||
if curidx == len(data): | if curidx == len(data): | ||||
curidx = 0 | curidx = 0 | ||||
@@ -4,6 +4,7 @@ from collections import Counter | |||||
from fastNLP.api.processor import Processor | from fastNLP.api.processor import Processor | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
class CombineWordAndPosProcessor(Processor): | class CombineWordAndPosProcessor(Processor): | ||||
def __init__(self, word_field_name, pos_field_name): | def __init__(self, word_field_name, pos_field_name): | ||||
super(CombineWordAndPosProcessor, self).__init__(None, None) | super(CombineWordAndPosProcessor, self).__init__(None, None) | ||||
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor): | |||||
return dataset | return dataset | ||||
class PosOutputStrProcessor(Processor): | class PosOutputStrProcessor(Processor): | ||||
def __init__(self, word_field_name, pos_field_name): | def __init__(self, word_field_name, pos_field_name): | ||||
super(PosOutputStrProcessor, self).__init__(None, None) | super(PosOutputStrProcessor, self).__init__(None, None) |
@@ -0,0 +1,71 @@ | |||||
import torch | |||||
from fastNLP.api.pipeline import Pipeline | |||||
from fastNLP.api.processor import SeqLenProcessor | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||||
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor | |||||
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader | |||||
# Path of the configuration file that train() hands to ConfigLoader.
cfgfile = './pos_tag.cfg'

# Name of a directory for pickled artifacts.
# NOTE(review): not referenced anywhere in the visible part of this script -- confirm it is still needed.
pickle_path = "save"
def train(train_data_path="/home/hyan/train.conllx", config_path=None, save_path="model_pp.pkl"):
    """Train an ``AdvSeqLabel`` tagger on a CoNLL-X file and save the result.

    The function loads the "train" and "model" config sections, reads the
    dataset, runs the word/tag/sequence-length processors on it, trains the
    model with ``Trainer`` and finally serialises the model together with the
    preprocessing pipeline and the tag vocabulary.

    Called with no arguments it behaves exactly like the original hard-coded
    script.

    :param train_data_path: path of the CoNLL-X training file passed to
        ``ZhConllPOSReader().load``.
    :param config_path: path of the config file; ``None`` means the
        module-level ``cfgfile``.
    :param save_path: file that ``torch.save`` writes the result dict to.
    :return: None. The side effects are console output and the saved file.
    """
    if config_path is None:
        config_path = cfgfile

    # Load config. Only model_param is used below; train_param is filled by
    # the loader but not read in this function.
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(config_path, {"train": train_param, "model": model_param})
    print("config loaded")

    # Data loader
    dataset = ZhConllPOSReader().load(train_data_path)
    print(dataset)
    print("dataset transformed")

    # Processors for the raw fields.
    # NOTE(review): VocabIndexerProcessor presumably builds a vocabulary and
    # replaces the field with indices -- confirm against its implementation.
    vocab_proc = VocabIndexerProcessor("words")
    tag_proc = VocabIndexerProcessor("tag")
    seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len")

    vocab_proc(dataset)
    tag_proc(dataset)
    seq_len_proc(dataset)

    # Rename the fields to the argument names of AdvSeqLabel.forward
    # (word_seq, word_seq_origin_len, truth).
    dataset.rename_field("words", "word_seq")
    dataset.rename_field("tag", "truth")
    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
    dataset.set_target("truth", "word_seq_origin_len")

    print("processors defined")

    # Model sizes come from the vocabularies built above.
    model_param["vocab_size"] = vocab_proc.get_vocab_size()
    model_param["num_classes"] = tag_proc.get_vocab_size()
    print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))

    # Define a model. Passing the tag id->word mapping makes AdvSeqLabel
    # build its CRF with allowed_transitions(...) instead of an
    # unconstrained one.
    model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)

    # Call trainer to train.
    # The same dataset is used for training and as dev_data, so the reported
    # "f" is measured on the training data.
    # NOTE(review): loss=None presumably makes Trainer take the loss from the
    # model's forward output -- confirm against Trainer.
    span_metric = SpanFPreRecMetric(tag_proc.vocab, pred="predict",
                                    target="truth",
                                    seq_lens="word_seq_origin_len")
    trainer = Trainer(dataset, model, loss=None, metrics=span_metric,
                      dev_data=dataset, metric_key="f",
                      use_tqdm=False, use_cuda=True, print_every=20)
    trainer.train()

    # Save model & pipeline. tag_proc is not part of the saved pipeline; the
    # tag vocabulary is stored separately under "tag_vocab".
    pp = Pipeline([vocab_proc, seq_len_proc])
    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
    torch.save(save_dict, save_path)
    print("pipeline saved")
def infer():
    """Placeholder for inference; not implemented, does nothing and returns None."""
    return None
# Script entry point: run training only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    train()