@@ -10,7 +10,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
-from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader
+from reproduction.pos_tag_model.pos_reader import ConllPOSReader
 from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler
@@ -250,7 +250,7 @@ class LossInForward(LossBase):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}")
+                raise TypeError(f"Loss is expected to be a torch.Tensor, got {type(loss)}")
             raise RuntimeError(f"The size of loss is expected to be torch.Size([]), got {loss.size()}")
         return loss
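
For reference, a small illustration (toy tensors only; not part of the diff) of what this guard accepts and rejects:

    import torch

    good = torch.tensor(1.5)              # 0-dim scalar tensor: passes the guard
    bad_shape = torch.tensor([1.5, 2.0])  # non-scalar tensor: hits the RuntimeError branch
    bad_type = 1.5                        # plain float: hits the TypeError branch

    assert isinstance(good, torch.Tensor) and len(good.size()) == 0
    assert not (isinstance(bad_shape, torch.Tensor) and len(bad_shape.size()) == 0)
    assert not isinstance(bad_type, torch.Tensor)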
@@ -436,15 +436,14 @@ class SpanFPreRecMetric(MetricBase):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor, "
                             f"got {type(seq_lens)}.")

-        num_classes = pred.size(-1)
-        if (target >= num_classes).any():
-            raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
-                             "id >= {}, the number of classes.".format(num_classes))
         if pred.size() == target.size() and len(target.size()) == 2:
             pass
         elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
+            num_classes = pred.size(-1)
+            if (target >= num_classes).any():
+                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
+                                 "id >= {}, the number of classes.".format(num_classes))
             pred = pred.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred has "
                                f"size: {pred.size()}, target should have size: {pred.size()} or "
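
Note that num_classes must be read from pred's trailing dimension before argmax collapses it, hence the ordering above. To make the shape logic concrete, a minimal sketch (toy tensors; not part of the diff) of the two pred layouts the check accepts:

    import torch

    batch_size, max_len, num_classes = 2, 5, 4
    target = torch.randint(0, num_classes, (batch_size, max_len))

    # Case 1: pred already holds label ids and matches target's shape.
    pred_ids = torch.randint(0, num_classes, (batch_size, max_len))
    assert pred_ids.size() == target.size()

    # Case 2: pred holds per-class scores with one extra trailing dimension;
    # the metric validates the gold ids against that dimension, then argmaxes.
    pred_scores = torch.randn(batch_size, max_len, num_classes)
    assert len(pred_scores.size()) == len(target.size()) + 1
    assert pred_scores.argmax(dim=-1).size() == target.size()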
@@ -1,8 +1,8 @@
 import torch
 import numpy as np

 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder, encoder
+from fastNLP.modules.decoder.CRF import allowed_transitions
 from fastNLP.modules.utils import seq_mask
@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling):
     Advanced Sequence Labeling Model
     """

-    def __init__(self, args, emb=None):
+    def __init__(self, args, emb=None, id2words=None):
         super(AdvSeqLabel, self).__init__(args)

         vocab_size = args["vocab_size"]
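
The new id2words parameter maps tag ids to BMES-encoded tag strings; it is used below to constrain the CRF via allowed_transitions. A minimal sketch of such a mapping (the tag inventory here is hypothetical; in training it comes from the tag vocabulary, see tag_proc.vocab.idx2word further down):

    from fastNLP.modules.decoder.CRF import allowed_transitions

    # Hypothetical BMES tag inventory; real ids and tags come from the corpus.
    id2words = {
        0: "B-NN", 1: "M-NN", 2: "E-NN", 3: "S-NN",
        4: "B-VV", 5: "M-VV", 6: "E-VV", 7: "S-VV",
    }
    # Presumably yields the legal (from_id, to_id) transition pairs under the
    # BMES scheme, so the CRF can never decode e.g. B-NN followed by B-VV.
    transitions = allowed_transitions(id2words, encoding_type="bmes")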
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling):
         self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
         self.norm1 = torch.nn.LayerNorm(word_emb_dim)
         # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
-        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True)
+        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout,
+                                 bidirectional=True, batch_first=True)
         self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
         self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
         # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling):
         self.drop = torch.nn.Dropout(dropout)
         self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)

-        self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        if id2words is None:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        else:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
+                                                          allowed_transitions=allowed_transitions(id2words,
+                                                                                                  encoding_type="bmes"))

     def forward(self, word_seq, word_seq_origin_len, truth=None):
         """
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling):
         assert 'loss' in kwargs
         return kwargs['loss']

+
 if __name__ == '__main__':
     args = {
         'vocab_size': 20,
@@ -208,11 +215,11 @@ if __name__ == '__main__':
         res = model(word_seq, word_seq_len, truth)
         loss = res['loss']
         pred = res['predict']
-        print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
+        print('loss: {} acc {}'.format(loss.item(),
+                                       ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()

         curidx = endidx
         if curidx == len(data):
             curidx = 0
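
One caveat in the accuracy printed above: (pred.data == truth) also counts matches at padded positions, while the denominator word_seq_len.sum() only counts real tokens. A mask-aware variant, sketched with toy tensors (not part of this diff):

    import torch

    pred = torch.tensor([[1, 2, 0, 0], [3, 1, 2, 0]])
    truth = torch.tensor([[1, 2, 9, 9], [3, 1, 2, 9]])  # 9 marks padding here
    seq_len = torch.tensor([2, 3])

    # Build a (batch, max_len) mask from the true lengths, then count
    # matches only where the mask is on.
    mask = torch.arange(pred.size(1)).unsqueeze(0) < seq_len.unsqueeze(1)
    acc = ((pred == truth) & mask).sum().float() / seq_len.sum().float()
    print(acc.item())  # 1.0 for this toy batch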
@@ -4,6 +4,7 @@ from collections import Counter
 from fastNLP.api.processor import Processor
 from fastNLP.core.dataset import DataSet

+
 class CombineWordAndPosProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(CombineWordAndPosProcessor, self).__init__(None, None)
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor):
         return dataset

+
 class PosOutputStrProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(PosOutputStrProcessor, self).__init__(None, None)
@@ -0,0 +1,71 @@
+import torch
+
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import SeqLenProcessor
+from fastNLP.core.metrics import SpanFPreRecMetric
+from fastNLP.core.trainer import Trainer
+from fastNLP.io.config_io import ConfigLoader, ConfigSection
+from fastNLP.models.sequence_modeling import AdvSeqLabel
+from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
+from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
+
+cfgfile = './pos_tag.cfg'
+pickle_path = "save"
+
+
+def train():
+    # load config
+    train_param = ConfigSection()
+    model_param = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
+    print("config loaded")
+
+    # Data Loader
+    dataset = ZhConllPOSReader().load("/home/hyan/train.conllx")
+    print(dataset)
+    print("dataset transformed")
+
+    vocab_proc = VocabIndexerProcessor("words")
+    tag_proc = VocabIndexerProcessor("tag")
+    seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len")
+
+    vocab_proc(dataset)
+    tag_proc(dataset)
+    seq_len_proc(dataset)
+
+    dataset.rename_field("words", "word_seq")
+    dataset.rename_field("tag", "truth")
+    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
+    dataset.set_target("truth", "word_seq_origin_len")
+    print("processors defined")
+
+    # dataset.set_is_target(tag_ids=True)
+    model_param["vocab_size"] = vocab_proc.get_vocab_size()
+    model_param["num_classes"] = tag_proc.get_vocab_size()
+    print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))
+
+    # define a model
+    model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
+
+    # call trainer to train
+    trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
+                                                                           target="truth",
+                                                                           seq_lens="word_seq_origin_len"),
+                      dev_data=dataset, metric_key="f",
+                      use_tqdm=False, use_cuda=True, print_every=20)
+    trainer.train()
+
+    # save model & pipeline
+    pp = Pipeline([vocab_proc, seq_len_proc])
+    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
+    torch.save(save_dict, "model_pp.pkl")
+    print("pipeline saved")
+
+
+def infer():
+    pass
+
+
+if __name__ == "__main__":
+    train()
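
infer() is left as a stub in this commit. A minimal sketch of how the saved artifacts could be reloaded (the dict keys match save_dict above; the map_location argument and the idea of running the pipeline on a fresh DataSet are assumptions, not part of this diff):

    import torch

    def infer():
        # "model_pp.pkl" is the path used by train() above.
        save_dict = torch.load("model_pp.pkl", map_location="cpu")
        pipeline = save_dict["pipeline"]    # processors to index raw fields
        model = save_dict["model"]          # the trained AdvSeqLabel
        tag_vocab = save_dict["tag_vocab"]  # maps predicted ids back to tags
        # Building a DataSet from raw text and feeding it through the
        # pipeline and model is omitted here.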