
add pos-tag training script

tags/v0.3.0^2
FengZiYjun, 5 years ago
commit 887c6fec70
7 changed files with 92 additions and 13 deletions

  1. fastNLP/api/api.py  +1 -1
  2. fastNLP/core/losses.py  +1 -1
  3. fastNLP/core/metrics.py  +4 -5
  4. fastNLP/models/sequence_modeling.py  +13 -6
  5. reproduction/pos_tag_model/pos_processor.py  +2 -0
  6. reproduction/pos_tag_model/pos_reader.py  +0 -0
  7. reproduction/pos_tag_model/train_pos_tag.py  +71 -0

fastNLP/api/api.py  +1 -1

@@ -10,7 +10,7 @@ from fastNLP.core.dataset import DataSet
 from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
-from reproduction.pos_tag_model.pos_io.pos_reader import ConllPOSReader
+from reproduction.pos_tag_model.pos_reader import ConllPOSReader
 from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
 from fastNLP.core.instance import Instance
 from fastNLP.core.sampler import SequentialSampler


fastNLP/core/losses.py  +1 -1

@@ -250,7 +250,7 @@ class LossInForward(LossBase):
 
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}")
+                raise TypeError(f"Loss excepted to be a torch.Tensor, got {type(loss)}")
             raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")
 
         return loss
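
LossInForward reads the loss directly out of the dict returned by the model's forward, so the only contract these checks enforce is that the value stored under the loss key is a zero-dimensional torch.Tensor. A minimal sketch of a model that satisfies that contract (the model and its fields are illustrative, not from this commit):

import torch

class ToyModel(torch.nn.Module):
    """Hypothetical model whose forward computes its own scalar loss."""

    def __init__(self):
        super(ToyModel, self).__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x, y):
        logits = self.linear(x)
        # cross_entropy returns a zero-dim tensor, so it passes both
        # the isinstance check and the size check in the hunk above
        loss = torch.nn.functional.cross_entropy(logits, y)
        return {"loss": loss, "predict": logits.argmax(dim=-1)}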


fastNLP/core/metrics.py  +4 -5

@@ -436,15 +436,14 @@ class SpanFPreRecMetric(MetricBase):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")
 
-        num_classes = pred.size(-1)
-        if (target >= num_classes).any():
-            raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
-                             "id >= {}, the number of classes.".format(num_classes))
-
         if pred.size() == target.size() and len(target.size()) == 2:
             pass
         elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
             pred = pred.argmax(dim=-1)
+            num_classes = pred.size(-1)
+            if (target >= num_classes).any():
+                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
+                                 "id >= {}, the number of classes.".format(num_classes))
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
                                f"size:{pred.size()}, target should have size: {pred.size()} or "

fastNLP/models/sequence_modeling.py  +13 -6

@@ -1,8 +1,8 @@
 import torch
-import numpy as np
 
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules import decoder, encoder
+from fastNLP.modules.decoder.CRF import allowed_transitions
 from fastNLP.modules.utils import seq_mask


@@ -93,7 +93,7 @@ class AdvSeqLabel(SeqLabeling):
     Advanced Sequence Labeling Model
     """
 
-    def __init__(self, args, emb=None):
+    def __init__(self, args, emb=None, id2words=None):
        super(AdvSeqLabel, self).__init__(args)
 
        vocab_size = args["vocab_size"]
@@ -105,7 +105,8 @@ class AdvSeqLabel(SeqLabeling):
         self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
         self.norm1 = torch.nn.LayerNorm(word_emb_dim)
         # self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=2, dropout=dropout, bidirectional=True)
-        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout, bidirectional=True, batch_first=True)
+        self.Rnn = torch.nn.LSTM(input_size=word_emb_dim, hidden_size=hidden_dim, num_layers=2, dropout=dropout,
+                                 bidirectional=True, batch_first=True)
         self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
         self.norm2 = torch.nn.LayerNorm(hidden_dim * 2 // 3)
         # self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
@@ -113,7 +114,12 @@ class AdvSeqLabel(SeqLabeling):
         self.drop = torch.nn.Dropout(dropout)
         self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)
 
-        self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        if id2words is None:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False)
+        else:
+            self.Crf = decoder.CRF.ConditionalRandomField(num_classes, include_start_end_trans=False,
+                                                          allowed_transitions=allowed_transitions(id2words,
+                                                                                                  encoding_type="bmes"))
 
     def forward(self, word_seq, word_seq_origin_len, truth=None):
         """
@@ -178,6 +184,7 @@ class AdvSeqLabel(SeqLabeling):
         assert 'loss' in kwargs
         return kwargs['loss']
 
+
 if __name__ == '__main__':
     args = {
         'vocab_size': 20,
@@ -208,11 +215,11 @@ if __name__ == '__main__':
         res = model(word_seq, word_seq_len, truth)
         loss = res['loss']
         pred = res['predict']
-        print('loss: {} acc {}'.format(loss.item(), ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
+        print('loss: {} acc {}'.format(loss.item(),
+                                       ((pred.data == truth).long().sum().float() / word_seq_len.sum().float())))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         curidx = endidx
         if curidx == len(data):
             curidx = 0
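
With the new id2words argument, AdvSeqLabel can hand the CRF a transition whitelist: allowed_transitions derives, from the tag vocabulary and the BMES encoding, which tag-to-tag moves are legal, so Viterbi decoding can never emit a malformed span (e.g. an M tag with no preceding B). A simplified sketch of the kind of rule such a whitelist encodes (illustrative pseudologic, not the fastNLP implementation; real tags also carry labels like B-NN):

def bmes_transition_allowed(from_tag, to_tag):
    """Toy BMES rule: every span is B M* E or a single S."""
    if from_tag in ("B", "M"):
        # an open span must be continued or closed
        return to_tag in ("M", "E")
    if from_tag in ("E", "S"):
        # a closed span can only be followed by a new span or a singleton
        return to_tag in ("B", "S")
    return False

assert bmes_transition_allowed("B", "E")
assert not bmes_transition_allowed("S", "M")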


reproduction/pos_tag_model/process/pos_processor.py → reproduction/pos_tag_model/pos_processor.py  +2 -0

@@ -4,6 +4,7 @@ from collections import Counter
 from fastNLP.api.processor import Processor
 from fastNLP.core.dataset import DataSet
 
+
 class CombineWordAndPosProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(CombineWordAndPosProcessor, self).__init__(None, None)
@@ -60,6 +61,7 @@ class CombineWordAndPosProcessor(Processor):
 
         return dataset
 
+
 class PosOutputStrProcessor(Processor):
     def __init__(self, word_field_name, pos_field_name):
         super(PosOutputStrProcessor, self).__init__(None, None)

reproduction/pos_tag_model/pos_io/pos_reader.py → reproduction/pos_tag_model/pos_reader.py  +0 -0


reproduction/pos_tag_model/train_pos_tag.py  +71 -0

@@ -0,0 +1,71 @@
+import torch
+
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import SeqLenProcessor
+from fastNLP.core.metrics import SpanFPreRecMetric
+from fastNLP.core.trainer import Trainer
+from fastNLP.io.config_io import ConfigLoader, ConfigSection
+from fastNLP.models.sequence_modeling import AdvSeqLabel
+from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
+from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
+
+cfgfile = './pos_tag.cfg'
+pickle_path = "save"
+
+
+def train():
+    # load config
+    train_param = ConfigSection()
+    model_param = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
+    print("config loaded")
+
+    # Data Loader
+    dataset = ZhConllPOSReader().load("/home/hyan/train.conllx")
+    print(dataset)
+    print("dataset transformed")
+
+    vocab_proc = VocabIndexerProcessor("words")
+    tag_proc = VocabIndexerProcessor("tag")
+    seq_len_proc = SeqLenProcessor(field_name="words", new_added_field_name="word_seq_origin_len")
+
+    vocab_proc(dataset)
+    tag_proc(dataset)
+    seq_len_proc(dataset)
+
+    dataset.rename_field("words", "word_seq")
+    dataset.rename_field("tag", "truth")
+    dataset.set_input("word_seq", "word_seq_origin_len", "truth")
+    dataset.set_target("truth", "word_seq_origin_len")
+
+    print("processors defined")
+
+    # dataset.set_is_target(tag_ids=True)
+    model_param["vocab_size"] = vocab_proc.get_vocab_size()
+    model_param["num_classes"] = tag_proc.get_vocab_size()
+    print("vocab_size={} num_classes={}".format(model_param["vocab_size"], model_param["num_classes"]))
+
+    # define a model
+    model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word)
+
+    # call trainer to train
+    trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
+                                                                           target="truth",
+                                                                           seq_lens="word_seq_origin_len"),
+                      dev_data=dataset, metric_key="f",
+                      use_tqdm=False, use_cuda=True, print_every=20)
+    trainer.train()
+
+    # save model & pipeline
+    pp = Pipeline([vocab_proc, seq_len_proc])
+    save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
+    torch.save(save_dict, "model_pp.pkl")
+    print("pipeline saved")
+
+
+def infer():
+    pass
+
+
+if __name__ == "__main__":
+    train()
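
infer() is still a stub in this commit; for orientation, a hedged sketch of how the artifacts written by train() could be loaded back, mirroring the save_dict keys used above (the loading flow itself is an assumption, not part of the commit):

import torch

def load_for_inference(path="model_pp.pkl"):
    # hypothetical helper: unpack the dict saved by train() above
    save_dict = torch.load(path)
    pipeline = save_dict["pipeline"]
    model = save_dict["model"]
    tag_vocab = save_dict["tag_vocab"]
    model.eval()  # switch off dropout for prediction
    return pipeline, model, tag_vocab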
