@@ -0,0 +1,252 @@
import torch.nn as nn | |||||
from pathes import * | |||||
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1 | |||||
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward | |||||
import torch.optim as optim | |||||
import argparse | |||||
import torch | |||||
import sys | |||||
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ | |||||
from fastNLP import Tester | |||||
import fitlog | |||||
from fastNLP.core.callback import FitlogCallback | |||||
from utils import set_seed | |||||
import os | |||||
from fastNLP import LRScheduler | |||||
from torch.optim.lr_scheduler import LambdaLR | |||||
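# Debug/training entry for the Lattice-LSTM reimplementation: it loads a Chinese NER dataset
# (OntoNotes 4 or Resume), augments it with lexicon "skip" paths, builds a (bi)Lattice-LSTM or
# plain (bi)LSTM sequence labeller, optionally copies weights from Yang Jie's original
# Lattice-LSTM checkpoint for a parity check, and trains it with fastNLP's Trainer.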
# sys.path.append('.') | |||||
# sys.path.append('..') | |||||
# for p in sys.path: | |||||
# print(p) | |||||
# fitlog.add_hyper_in_file (__file__) # record your hyperparameters | |||||
########hyper
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--device',default='cpu') | |||||
parser.add_argument('--debug',default=True) | |||||
parser.add_argument('--batch',type=int,default=1)
parser.add_argument('--test_batch',type=int,default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',type=float,default=0.015)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)#in paper it's false
parser.add_argument('--hidden',type=int,default=100)
parser.add_argument('--momentum',type=float,default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=False)
parser.add_argument('--embed_dropout',type=float,default=0)
parser.add_argument('--output_dropout',type=float,default=0)
parser.add_argument('--epoch',type=int,default=100)
parser.add_argument('--seed',type=int,default=100)
args = parser.parse_args() | |||||
set_seed(args.seed) | |||||
fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)] | |||||
if args.model == 'lattice': | |||||
fit_msg_list.append(str(args.skip_before_head)) | |||||
fit_msg = ' '.join(fit_msg_list) | |||||
# fitlog.commit(__file__,fit_msg=fit_msg) | |||||
# fitlog.add_hyper(args) | |||||
device = torch.device(args.device) | |||||
for k,v in args.__dict__.items(): | |||||
print(k,v) | |||||
refresh_data = False | |||||
if args.dataset == 'ontonote': | |||||
datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||||
_refresh=refresh_data,index_token=False) | |||||
elif args.dataset == 'resume': | |||||
datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||||
_refresh=refresh_data,index_token=False) | |||||
# exit() | |||||
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, | |||||
_refresh=refresh_data) | |||||
cache_name = os.path.join('cache',args.dataset+'_lattice') | |||||
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, | |||||
_refresh=refresh_data,_cache_fp=cache_name) | |||||
print('embedding of 中:{}'.format(embeddings['char'](24)))
print('embed lookup dropout:{}'.format(embeddings['word'].word_dropout)) | |||||
# for k, v in datasets.items(): | |||||
# # v.apply_field(lambda x: list(map(len, x)), 'skips_l2r_word', 'lexicon_count') | |||||
# # v.apply_field(lambda x: | |||||
# # list(map(lambda y: | |||||
# # list(map(lambda z: vocabs['word'].to_index(z), y)), x)), | |||||
# # 'skips_l2r_word') | |||||
print(datasets['train'][0]) | |||||
print('vocab info:') | |||||
for k,v in vocabs.items(): | |||||
print('{}:{}'.format(k,len(v))) | |||||
# print(datasets['dev'][0]) | |||||
# print(datasets['test'][0]) | |||||
# print(datasets['train'].get_all_fields().keys()) | |||||
for k,v in datasets.items(): | |||||
if args.model == 'lattice': | |||||
v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source') | |||||
if args.skip_before_head: | |||||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||||
v.set_padder('skips_r2l_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True)) | |||||
else: | |||||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_r2l_word', LatticeLexiconPadder()) | |||||
v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1)) | |||||
v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1)) | |||||
if args.bi: | |||||
v.set_input('chars','bigrams','seq_len', | |||||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||||
'skips_r2l_word', 'skips_r2l_source','lexicon_count_back', | |||||
'target', | |||||
use_1st_ins_infer_dim_type=True) | |||||
else: | |||||
v.set_input('chars','bigrams','seq_len', | |||||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||||
'target', | |||||
use_1st_ins_infer_dim_type=True) | |||||
v.set_target('target','seq_len') | |||||
v['target'].set_pad_val(0) | |||||
elif args.model == 'lstm': | |||||
v.set_ignore_type('skips_l2r_word','skips_l2r_source') | |||||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||||
v.set_input('chars','bigrams','seq_len','target', | |||||
use_1st_ins_infer_dim_type=True) | |||||
v.set_target('target','seq_len') | |||||
v['target'].set_pad_val(0) | |||||
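# Note on the lattice inputs prepared above: for each character position, skips_l2r_word holds
# the lexicon words ending at that position and skips_l2r_source their start indices, while the
# r2l fields are the mirror image keyed by a word's start character (see
# equip_chinese_ner_with_skip). Because these are ragged list-of-list fields, they use
# LatticeLexiconPadder instead of the default padder, and lexicon_count / lexicon_count_back
# record how many matched words each position has.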
print(datasets['dev']['skips_l2r_word'][100]) | |||||
if args.model =='lattice': | |||||
model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||||
skip_batch_first=True,bidirectional=args.bi,debug=args.debug, | |||||
skip_before_head=args.skip_before_head,use_bigram=args.use_bigram, | |||||
vocabs=vocabs | |||||
) | |||||
elif args.model == 'lstm': | |||||
model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||||
bidirectional=args.bi, | |||||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||||
use_bigram=args.use_bigram) | |||||
for k,v in model.state_dict().items(): | |||||
print('{}:{}'.format(k,v.size())) | |||||
# exit() | |||||
weight_dict = torch.load(open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb')) | |||||
# print(weight_dict.keys()) | |||||
# for k,v in weight_dict.items(): | |||||
# print('{}:{}'.format(k,v.size())) | |||||
def state_dict_param(model): | |||||
param_list = list(model.named_parameters()) | |||||
print(len(param_list)) | |||||
param_dict = {} | |||||
for i in range(len(param_list)): | |||||
param_dict[param_list[i][0]] = param_list[i][1] | |||||
return param_dict | |||||
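# Helper: builds a name -> nn.Parameter mapping via named_parameters(), so the reference
# Lattice-LSTM weights below can be copied into this model in place with Tensor.set_().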
def copy_yangjie_lattice_weight(target,source_dict): | |||||
t = state_dict_param(target) | |||||
with torch.no_grad(): | |||||
t['encoder.char_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.weight_ih']) | |||||
t['encoder.char_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.weight_hh']) | |||||
t['encoder.char_cell.alpha_weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_ih']) | |||||
t['encoder.char_cell.alpha_weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_hh']) | |||||
t['encoder.char_cell.bias'].set_(source_dict['lstm.forward_lstm.rnn.bias']) | |||||
t['encoder.char_cell.alpha_bias'].set_(source_dict['lstm.forward_lstm.rnn.alpha_bias']) | |||||
t['encoder.word_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_ih']) | |||||
t['encoder.word_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_hh']) | |||||
t['encoder.word_cell.bias'].set_(source_dict['lstm.forward_lstm.word_rnn.bias']) | |||||
t['encoder_back.char_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.weight_ih']) | |||||
t['encoder_back.char_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.weight_hh']) | |||||
t['encoder_back.char_cell.alpha_weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_ih']) | |||||
t['encoder_back.char_cell.alpha_weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_hh']) | |||||
t['encoder_back.char_cell.bias'].set_(source_dict['lstm.backward_lstm.rnn.bias']) | |||||
t['encoder_back.char_cell.alpha_bias'].set_(source_dict['lstm.backward_lstm.rnn.alpha_bias']) | |||||
t['encoder_back.word_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_ih']) | |||||
t['encoder_back.word_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_hh']) | |||||
t['encoder_back.word_cell.bias'].set_(source_dict['lstm.backward_lstm.word_rnn.bias']) | |||||
for k,v in t.items(): | |||||
print('{}:{}'.format(k,v)) | |||||
copy_yangjie_lattice_weight(model,weight_dict) | |||||
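# The mapping above mirrors Yang Jie's original LatticeLSTM parameter layout: forward/backward
# character cells (weight_ih/weight_hh plus the alpha_* gates that merge word-cell states) and
# the word cells, so this reimplementation can be checked against the reference checkpoint.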
# print(vocabs['label'].word2idx.keys()) | |||||
loss = LossInForward() | |||||
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso') | |||||
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj') | |||||
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len') | |||||
metrics = [f1_metric,f1_metric_yj,acc_metric] | |||||
if args.optim == 'adam': | |||||
optimizer = optim.Adam(model.parameters(),lr=args.lr) | |||||
elif args.optim == 'sgd': | |||||
optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum) | |||||
# tester = Tester(datasets['dev'],model,metrics=metrics,batch_size=args.test_batch,device=device) | |||||
# test_result = tester.test() | |||||
# print(test_result) | |||||
callbacks = [ | |||||
LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05)**ep)) | |||||
] | |||||
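# Per-epoch exponential decay: LambdaLR multiplies the base lr by 1/(1.05**epoch), so with the
# default lr=0.015 this gives 0.015 at epoch 0, ~0.0143 at epoch 1, ~0.0136 at epoch 2, and so on.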
print(datasets['train'][:2]) | |||||
print(vocabs['char'].to_index(':')) | |||||
# StaticEmbedding | |||||
# datasets['train'] = datasets['train'][1:] | |||||
from fastNLP import SequentialSampler | |||||
trainer = Trainer(datasets['train'],model, | |||||
optimizer=optimizer, | |||||
loss=loss, | |||||
metrics=metrics, | |||||
dev_data=datasets['dev'], | |||||
device=device, | |||||
batch_size=args.batch, | |||||
n_epochs=args.epoch, | |||||
dev_batch_size=args.test_batch, | |||||
callbacks=callbacks, | |||||
check_code_level=-1, | |||||
sampler=SequentialSampler()) | |||||
trainer.train() |
@@ -0,0 +1,772 @@
from fastNLP.io import CSVLoader | |||||
from fastNLP import Vocabulary | |||||
from fastNLP import Const | |||||
import numpy as np | |||||
import fitlog | |||||
import pickle | |||||
import os | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from fastNLP import cache_results | |||||
@cache_results(_cache_fp='mtl16', _refresh=False) | |||||
def load_16_task(dict_path): | |||||
''' | |||||
:param dict_path: /remote-home/txsun/fnlp/MTL-LT/data | |||||
:return: | |||||
''' | |||||
task_path = os.path.join(dict_path,'data.pkl') | |||||
embedding_path = os.path.join(dict_path,'word_embedding.npy') | |||||
embedding = np.load(embedding_path).astype(np.float32) | |||||
task_list = pickle.load(open(task_path, 'rb'))['task_lst'] | |||||
for t in task_list: | |||||
t.train_set.rename_field('words_idx', 'words') | |||||
t.dev_set.rename_field('words_idx', 'words') | |||||
t.test_set.rename_field('words_idx', 'words') | |||||
t.train_set.rename_field('label', 'target') | |||||
t.dev_set.rename_field('label', 'target') | |||||
t.test_set.rename_field('label', 'target') | |||||
t.train_set.add_seq_len('words') | |||||
t.dev_set.add_seq_len('words') | |||||
t.test_set.add_seq_len('words') | |||||
t.train_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
t.dev_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
t.test_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
return task_list,embedding | |||||
@cache_results(_cache_fp='SST2', _refresh=False) | |||||
def load_sst2(dict_path,embedding_path=None): | |||||
''' | |||||
:param dict_path: /remote-home/xnli/data/corpus/text_classification/SST-2/ | |||||
:param embedding_path: glove 300d txt | |||||
:return: | |||||
''' | |||||
train_path = os.path.join(dict_path,'train.tsv') | |||||
dev_path = os.path.join(dict_path,'dev.tsv') | |||||
loader = CSVLoader(headers=('words', 'target'), sep='\t') | |||||
train_data = loader.load(train_path).datasets['train'] | |||||
dev_data = loader.load(dev_path).datasets['train'] | |||||
train_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words') | |||||
dev_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words') | |||||
train_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len') | |||||
dev_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len') | |||||
vocab = Vocabulary(min_freq=2) | |||||
vocab.from_dataset(train_data, field_name='words') | |||||
vocab.from_dataset(dev_data, field_name='words') | |||||
# pretrained_embedding = load_word_emb(embedding_path, 300, vocab) | |||||
label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(train_data, field_name='target') | |||||
label_vocab.index_dataset(train_data, field_name='target') | |||||
label_vocab.index_dataset(dev_data, field_name='target') | |||||
vocab.index_dataset(train_data, field_name='words', new_field_name='words') | |||||
vocab.index_dataset(dev_data, field_name='words', new_field_name='words') | |||||
train_data.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
train_data.set_target(Const.TARGET) | |||||
dev_data.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
dev_data.set_target(Const.TARGET) | |||||
if embedding_path is not None: | |||||
pretrained_embedding = load_word_emb(embedding_path, 300, vocab) | |||||
return (train_data,dev_data),(vocab,label_vocab),pretrained_embedding | |||||
else: | |||||
return (train_data,dev_data),(vocab,label_vocab) | |||||
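# Commented usage sketch (the path is a placeholder for a local SST-2 copy). Note that the
# embedding_path branch relies on a load_word_emb helper that is not imported in this file,
# so only the two-value return is exercised here:
# (sst_train, sst_dev), (sst_vocab, sst_label_vocab) = load_sst2('/path/to/SST-2/')
# print(sst_train[:2])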
@cache_results(_cache_fp='OntonotesPOS', _refresh=False) | |||||
def load_conllized_ontonote_POS(path,embedding_path=None): | |||||
from fastNLP.io.data_loader import ConllLoader | |||||
header2index = {'words':3,'POS':4,'NER':10} | |||||
headers = ['words','POS'] | |||||
if 'NER' in headers: | |||||
print('Warning! The NER labels read via load_conllized_ontonote are not BIOS; they are plain conll format and therefore wrong!')
indexes = list(map(lambda x:header2index[x],headers)) | |||||
loader = ConllLoader(headers,indexes) | |||||
bundle = loader.load(path) | |||||
# print(bundle.datasets) | |||||
train_set = bundle.datasets['train'] | |||||
dev_set = bundle.datasets['dev'] | |||||
test_set = bundle.datasets['test'] | |||||
# train_set = loader.load(os.path.join(path,'train.txt')) | |||||
# dev_set = loader.load(os.path.join(path, 'dev.txt')) | |||||
# test_set = loader.load(os.path.join(path, 'test.txt')) | |||||
# print(len(train_set)) | |||||
train_set.add_seq_len('words','seq_len') | |||||
dev_set.add_seq_len('words','seq_len') | |||||
test_set.add_seq_len('words','seq_len') | |||||
# print(dataset['POS']) | |||||
vocab = Vocabulary(min_freq=1) | |||||
vocab.from_dataset(train_set,field_name='words') | |||||
vocab.from_dataset(dev_set, field_name='words') | |||||
vocab.from_dataset(test_set, field_name='words') | |||||
vocab.index_dataset(train_set,field_name='words') | |||||
vocab.index_dataset(dev_set, field_name='words') | |||||
vocab.index_dataset(test_set, field_name='words') | |||||
label_vocab_dict = {} | |||||
for i,h in enumerate(headers): | |||||
if h == 'words': | |||||
continue | |||||
label_vocab_dict[h] = Vocabulary(min_freq=1,padding=None,unknown=None) | |||||
label_vocab_dict[h].from_dataset(train_set,field_name=h) | |||||
label_vocab_dict[h].index_dataset(train_set,field_name=h) | |||||
label_vocab_dict[h].index_dataset(dev_set,field_name=h) | |||||
label_vocab_dict[h].index_dataset(test_set,field_name=h) | |||||
train_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
train_set.set_target(headers[1]) | |||||
dev_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
dev_set.set_target(headers[1]) | |||||
test_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||||
test_set.set_target(headers[1]) | |||||
if len(headers) > 2: | |||||
print('Warning: since there is more than one task, target has to be set manually each time!')
print('train:',len(train_set),'dev:',len(dev_set),'test:',len(test_set)) | |||||
if embedding_path is not None: | |||||
pretrained_embedding = load_word_emb(embedding_path, 300, vocab) | |||||
return (train_set,dev_set,test_set),(vocab,label_vocab_dict),pretrained_embedding | |||||
else: | |||||
return (train_set, dev_set, test_set), (vocab, label_vocab_dict) | |||||
@cache_results(_cache_fp='OntonotesNER', _refresh=False) | |||||
def load_conllized_ontonote_NER(path,embedding_path=None): | |||||
from fastNLP.io.pipe.conll import OntoNotesNERPipe | |||||
ontoNotesNERPipe = OntoNotesNERPipe(lower=True,target_pad_val=-100) | |||||
bundle_NER = ontoNotesNERPipe.process_from_file(path) | |||||
train_set_NER = bundle_NER.datasets['train'] | |||||
dev_set_NER = bundle_NER.datasets['dev'] | |||||
test_set_NER = bundle_NER.datasets['test'] | |||||
train_set_NER.add_seq_len('words','seq_len') | |||||
dev_set_NER.add_seq_len('words','seq_len') | |||||
test_set_NER.add_seq_len('words','seq_len') | |||||
NER_vocab = bundle_NER.get_vocab('target') | |||||
word_vocab = bundle_NER.get_vocab('words') | |||||
if embedding_path is not None: | |||||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||||
dropout=0.5,lower=True) | |||||
# pretrained_embedding = load_word_emb(embedding_path, 300, word_vocab) | |||||
return (train_set_NER,dev_set_NER,test_set_NER),\ | |||||
(word_vocab,NER_vocab),embed | |||||
else: | |||||
return (train_set_NER, dev_set_NER, test_set_NER), (word_vocab, NER_vocab)
@cache_results(_cache_fp='OntonotesPOSNER', _refresh=False) | |||||
def load_conllized_ontonote_NER_POS(path,embedding_path=None): | |||||
from fastNLP.io.pipe.conll import OntoNotesNERPipe | |||||
ontoNotesNERPipe = OntoNotesNERPipe(lower=True) | |||||
bundle_NER = ontoNotesNERPipe.process_from_file(path) | |||||
train_set_NER = bundle_NER.datasets['train'] | |||||
dev_set_NER = bundle_NER.datasets['dev'] | |||||
test_set_NER = bundle_NER.datasets['test'] | |||||
NER_vocab = bundle_NER.get_vocab('target') | |||||
word_vocab = bundle_NER.get_vocab('words') | |||||
(train_set_POS,dev_set_POS,test_set_POS),(_,POS_vocab) = load_conllized_ontonote_POS(path) | |||||
POS_vocab = POS_vocab['POS'] | |||||
train_set_NER.add_field('pos',train_set_POS['POS'],is_target=True) | |||||
dev_set_NER.add_field('pos', dev_set_POS['POS'], is_target=True) | |||||
test_set_NER.add_field('pos', test_set_POS['POS'], is_target=True) | |||||
if train_set_NER.has_field('target'): | |||||
train_set_NER.rename_field('target','ner') | |||||
if dev_set_NER.has_field('target'): | |||||
dev_set_NER.rename_field('target','ner') | |||||
if test_set_NER.has_field('target'): | |||||
test_set_NER.rename_field('target','ner') | |||||
if train_set_NER.has_field('pos'): | |||||
train_set_NER.rename_field('pos','posid') | |||||
if dev_set_NER.has_field('pos'): | |||||
dev_set_NER.rename_field('pos','posid') | |||||
if test_set_NER.has_field('pos'): | |||||
test_set_NER.rename_field('pos','posid') | |||||
if train_set_NER.has_field('ner'): | |||||
train_set_NER.rename_field('ner','nerid') | |||||
if dev_set_NER.has_field('ner'): | |||||
dev_set_NER.rename_field('ner','nerid') | |||||
if test_set_NER.has_field('ner'): | |||||
test_set_NER.rename_field('ner','nerid') | |||||
if embedding_path is not None: | |||||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||||
dropout=0.5,lower=True) | |||||
return (train_set_NER,dev_set_NER,test_set_NER),\ | |||||
(word_vocab,POS_vocab,NER_vocab),embed | |||||
else: | |||||
return (train_set_NER, dev_set_NER, test_set_NER), (word_vocab, POS_vocab, NER_vocab)
@cache_results(_cache_fp='Ontonotes3', _refresh=True) | |||||
def load_conllized_ontonote_pkl(path,embedding_path=None): | |||||
data_bundle = pickle.load(open(path,'rb')) | |||||
train_set = data_bundle.datasets['train'] | |||||
dev_set = data_bundle.datasets['dev'] | |||||
test_set = data_bundle.datasets['test'] | |||||
train_set.rename_field('pos','posid') | |||||
train_set.rename_field('ner','nerid') | |||||
train_set.rename_field('chunk','chunkid') | |||||
dev_set.rename_field('pos','posid') | |||||
dev_set.rename_field('ner','nerid') | |||||
dev_set.rename_field('chunk','chunkid') | |||||
test_set.rename_field('pos','posid') | |||||
test_set.rename_field('ner','nerid') | |||||
test_set.rename_field('chunk','chunkid') | |||||
word_vocab = data_bundle.vocabs['words'] | |||||
pos_vocab = data_bundle.vocabs['pos'] | |||||
ner_vocab = data_bundle.vocabs['ner'] | |||||
chunk_vocab = data_bundle.vocabs['chunk'] | |||||
if embedding_path is not None: | |||||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||||
dropout=0.5,lower=True) | |||||
return (train_set,dev_set,test_set),\ | |||||
(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed | |||||
else: | |||||
return (train_set, dev_set, test_set), (word_vocab,ner_vocab) | |||||
# print(data_bundle) | |||||
# @cache_results(_cache_fp='Conll2003', _refresh=False) | |||||
# def load_conll_2003(path,embedding_path=None): | |||||
# f = open(path, 'rb') | |||||
# data_pkl = pickle.load(f) | |||||
# | |||||
# task_lst = data_pkl['task_lst'] | |||||
# vocabs = data_pkl['vocabs'] | |||||
# # word_vocab = vocabs['words'] | |||||
# # pos_vocab = vocabs['pos'] | |||||
# # chunk_vocab = vocabs['chunk'] | |||||
# # ner_vocab = vocabs['ner'] | |||||
# | |||||
# if embedding_path is not None: | |||||
# embed = StaticEmbedding(vocab=vocabs['words'], model_dir_or_name=embedding_path, word_dropout=0.01, | |||||
# dropout=0.5) | |||||
# return task_lst,vocabs,embed | |||||
# else: | |||||
# return task_lst,vocabs | |||||
# @cache_results(_cache_fp='Conll2003_mine', _refresh=False) | |||||
@cache_results(_cache_fp='Conll2003_mine_embed_100', _refresh=True) | |||||
def load_conll_2003_mine(path,embedding_path=None,pad_val=-100): | |||||
f = open(path, 'rb')
data_pkl = pickle.load(f)
f.close()
# print(data_pkl) | |||||
# print(data_pkl) | |||||
train_set = data_pkl[0]['train'] | |||||
dev_set = data_pkl[0]['dev'] | |||||
test_set = data_pkl[0]['test'] | |||||
train_set.set_pad_val('posid',pad_val) | |||||
train_set.set_pad_val('nerid', pad_val) | |||||
train_set.set_pad_val('chunkid', pad_val) | |||||
dev_set.set_pad_val('posid',pad_val) | |||||
dev_set.set_pad_val('nerid', pad_val) | |||||
dev_set.set_pad_val('chunkid', pad_val) | |||||
test_set.set_pad_val('posid',pad_val) | |||||
test_set.set_pad_val('nerid', pad_val) | |||||
test_set.set_pad_val('chunkid', pad_val) | |||||
if train_set.has_field('task_id'): | |||||
train_set.delete_field('task_id') | |||||
if dev_set.has_field('task_id'): | |||||
dev_set.delete_field('task_id') | |||||
if test_set.has_field('task_id'): | |||||
test_set.delete_field('task_id') | |||||
if train_set.has_field('words_idx'): | |||||
train_set.rename_field('words_idx','words') | |||||
if dev_set.has_field('words_idx'): | |||||
dev_set.rename_field('words_idx','words') | |||||
if test_set.has_field('words_idx'): | |||||
test_set.rename_field('words_idx','words') | |||||
word_vocab = data_pkl[1]['words'] | |||||
pos_vocab = data_pkl[1]['pos'] | |||||
ner_vocab = data_pkl[1]['ner'] | |||||
chunk_vocab = data_pkl[1]['chunk'] | |||||
if embedding_path is not None: | |||||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||||
dropout=0.5,lower=True) | |||||
return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed | |||||
else: | |||||
return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab) | |||||
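# Commented usage sketch (placeholder paths; the @cache_results decorator also accepts the
# _refresh/_cache_fp keyword arguments used elsewhere in this repo):
# (c03_train, c03_dev, c03_test), (c03_word_v, c03_pos_v, c03_ner_v, c03_chunk_v), c03_embed = \
#     load_conll_2003_mine('/path/to/conll03/data.pkl', embedding_path='/path/to/glove.txt')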
def load_conllized_ontonote_pkl_yf(path): | |||||
def init_task(task): | |||||
task_name = task.task_name | |||||
for ds in [task.train_set, task.dev_set, task.test_set]: | |||||
if ds.has_field('words'): | |||||
ds.rename_field('words', 'x') | |||||
else: | |||||
ds.rename_field('words_idx', 'x') | |||||
if ds.has_field('label'): | |||||
ds.rename_field('label', 'y') | |||||
else: | |||||
ds.rename_field(task_name, 'y') | |||||
ds.set_input('x', 'y', 'task_id') | |||||
ds.set_target('y') | |||||
if task_name in ['ner', 'chunk'] or 'pos' in task_name: | |||||
ds.set_input('seq_len') | |||||
ds.set_target('seq_len') | |||||
return task | |||||
#/remote-home/yfshao/workdir/datasets/conll03/data.pkl | |||||
def pload(fn): | |||||
with open(fn, 'rb') as f: | |||||
return pickle.load(f) | |||||
DB = pload(path) | |||||
task_lst = DB['task_lst'] | |||||
vocabs = DB['vocabs'] | |||||
task_lst = [init_task(task) for task in task_lst] | |||||
return task_lst, vocabs | |||||
@cache_results(_cache_fp='weiboNER uni+bi', _refresh=False) | |||||
def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True, | |||||
normalize={'char':True,'bigram':True,'word':False}):
from fastNLP.io.data_loader import ConllLoader | |||||
from utils import get_bigrams | |||||
loader = ConllLoader(['chars','target']) | |||||
bundle = loader.load(path) | |||||
datasets = bundle.datasets | |||||
for k,v in datasets.items(): | |||||
print('{}:{}'.format(k,len(v))) | |||||
# print(*list(datasets.keys())) | |||||
vocabs = {} | |||||
word_vocab = Vocabulary() | |||||
bigram_vocab = Vocabulary() | |||||
label_vocab = Vocabulary(padding=None,unknown=None) | |||||
for k,v in datasets.items(): | |||||
# ignore the word segmentation tag | |||||
v.apply_field(lambda x: [w[0] for w in x],'chars','chars') | |||||
v.apply_field(get_bigrams,'chars','bigrams') | |||||
word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||||
label_vocab.from_dataset(datasets['train'],field_name='target') | |||||
print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word)) | |||||
for k,v in datasets.items(): | |||||
# v.set_pad_val('target',-100) | |||||
v.add_seq_len('chars',new_field_name='seq_len') | |||||
vocabs['char'] = word_vocab | |||||
vocabs['label'] = label_vocab | |||||
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||||
if index_token: | |||||
word_vocab.index_dataset(*list(datasets.values()), field_name='chars', new_field_name='chars')
bigram_vocab.index_dataset(*list(datasets.values()),field_name='bigrams',new_field_name='bigrams')
label_vocab.index_dataset(*list(datasets.values()), field_name='target', new_field_name='target')
# for k,v in datasets.items(): | |||||
# v.set_input('chars','bigrams','seq_len','target') | |||||
# v.set_target('target','seq_len') | |||||
vocabs['bigram'] = bigram_vocab | |||||
embeddings = {} | |||||
if unigram_embedding_path is not None: | |||||
unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path, | |||||
word_dropout=0.01,normalize=normalize['char'])
embeddings['char'] = unigram_embedding | |||||
if bigram_embedding_path is not None: | |||||
bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path, | |||||
word_dropout=0.01,normalize=normalize['bigram'])
embeddings['bigram'] = bigram_embedding | |||||
return datasets, vocabs, embeddings | |||||
# datasets,vocabs = load_weibo_ner('/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo') | |||||
# | |||||
# print(datasets['train'][:5]) | |||||
# print(vocabs['word'].idx2word) | |||||
# print(vocabs['target'].idx2word) | |||||
@cache_results(_cache_fp='cache/ontonotes4ner',_refresh=False) | |||||
def load_ontonotes4ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True, | |||||
normalize={'char':True,'bigram':True,'word':False}): | |||||
from fastNLP.io.data_loader import ConllLoader | |||||
from utils import get_bigrams | |||||
train_path = os.path.join(path,'train.char.bmes') | |||||
dev_path = os.path.join(path,'dev.char.bmes') | |||||
test_path = os.path.join(path,'test.char.bmes') | |||||
loader = ConllLoader(['chars','target']) | |||||
train_bundle = loader.load(train_path) | |||||
dev_bundle = loader.load(dev_path) | |||||
test_bundle = loader.load(test_path) | |||||
datasets = dict() | |||||
datasets['train'] = train_bundle.datasets['train'] | |||||
datasets['dev'] = dev_bundle.datasets['train'] | |||||
datasets['test'] = test_bundle.datasets['train'] | |||||
datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams') | |||||
datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||||
datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||||
datasets['train'].add_seq_len('chars') | |||||
datasets['dev'].add_seq_len('chars') | |||||
datasets['test'].add_seq_len('chars') | |||||
char_vocab = Vocabulary() | |||||
bigram_vocab = Vocabulary() | |||||
label_vocab = Vocabulary(padding=None,unknown=None) | |||||
print(datasets.keys()) | |||||
print(len(datasets['dev'])) | |||||
print(len(datasets['test'])) | |||||
print(len(datasets['train'])) | |||||
char_vocab.from_dataset(datasets['train'],field_name='chars', | |||||
no_create_entry_dataset=[datasets['dev'],datasets['test']] ) | |||||
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams', | |||||
no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||||
label_vocab.from_dataset(datasets['train'],field_name='target') | |||||
if index_token: | |||||
char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||||
field_name='chars',new_field_name='chars') | |||||
bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||||
field_name='bigrams',new_field_name='bigrams') | |||||
label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||||
field_name='target',new_field_name='target') | |||||
vocabs = {} | |||||
vocabs['char'] = char_vocab | |||||
vocabs['label'] = label_vocab | |||||
vocabs['bigram'] = bigram_vocab | |||||
vocabs['label'] = label_vocab | |||||
embeddings = {} | |||||
if char_embedding_path is not None: | |||||
char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01, | |||||
normalize=normalize['char']) | |||||
embeddings['char'] = char_embedding | |||||
if bigram_embedding_path is not None: | |||||
bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01, | |||||
normalize=normalize['bigram']) | |||||
embeddings['bigram'] = bigram_embedding | |||||
return datasets,vocabs,embeddings | |||||
@cache_results(_cache_fp='cache/resume_ner',_refresh=False) | |||||
def load_resume_ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True, | |||||
normalize={'char':True,'bigram':True,'word':False}): | |||||
from fastNLP.io.data_loader import ConllLoader | |||||
from utils import get_bigrams | |||||
train_path = os.path.join(path,'train.char.bmes') | |||||
dev_path = os.path.join(path,'dev.char.bmes') | |||||
test_path = os.path.join(path,'test.char.bmes') | |||||
loader = ConllLoader(['chars','target']) | |||||
train_bundle = loader.load(train_path) | |||||
dev_bundle = loader.load(dev_path) | |||||
test_bundle = loader.load(test_path) | |||||
datasets = dict() | |||||
datasets['train'] = train_bundle.datasets['train'] | |||||
datasets['dev'] = dev_bundle.datasets['train'] | |||||
datasets['test'] = test_bundle.datasets['train'] | |||||
datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams') | |||||
datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||||
datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||||
datasets['train'].add_seq_len('chars') | |||||
datasets['dev'].add_seq_len('chars') | |||||
datasets['test'].add_seq_len('chars') | |||||
char_vocab = Vocabulary() | |||||
bigram_vocab = Vocabulary() | |||||
label_vocab = Vocabulary(padding=None,unknown=None) | |||||
print(datasets.keys()) | |||||
print(len(datasets['dev'])) | |||||
print(len(datasets['test'])) | |||||
print(len(datasets['train'])) | |||||
char_vocab.from_dataset(datasets['train'],field_name='chars', | |||||
no_create_entry_dataset=[datasets['dev'],datasets['test']] ) | |||||
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams', | |||||
no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||||
label_vocab.from_dataset(datasets['train'],field_name='target') | |||||
if index_token: | |||||
char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||||
field_name='chars',new_field_name='chars') | |||||
bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||||
field_name='bigrams',new_field_name='bigrams') | |||||
label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||||
field_name='target',new_field_name='target') | |||||
vocabs = {} | |||||
vocabs['char'] = char_vocab | |||||
vocabs['label'] = label_vocab | |||||
vocabs['bigram'] = bigram_vocab | |||||
embeddings = {} | |||||
if char_embedding_path is not None: | |||||
char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01,normalize=normalize['char']) | |||||
embeddings['char'] = char_embedding | |||||
if bigram_embedding_path is not None: | |||||
bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01,normalize=normalize['bigram']) | |||||
embeddings['bigram'] = bigram_embedding | |||||
return datasets,vocabs,embeddings | |||||
@cache_results(_cache_fp='need_to_defined_fp',_refresh=False) | |||||
def equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,word_embedding_path=None, | |||||
normalize={'char':True,'bigram':True,'word':False}): | |||||
from utils_ import Trie,get_skip_path | |||||
from functools import partial | |||||
w_trie = Trie() | |||||
for w in w_list: | |||||
w_trie.insert(w) | |||||
# for k,v in datasets.items(): | |||||
# v.apply_field(partial(get_skip_path,w_trie=w_trie),'chars','skips') | |||||
def skips2skips_l2r(chars,w_trie): | |||||
''' | |||||
:param lexicons: list[[int,int,str]] | |||||
:return: skips_l2r | |||||
''' | |||||
# print(lexicons) | |||||
# print('******') | |||||
lexicons = get_skip_path(chars,w_trie=w_trie) | |||||
# max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0 | |||||
result = [[] for _ in range(len(chars))] | |||||
for lex in lexicons: | |||||
s = lex[0] | |||||
e = lex[1] | |||||
w = lex[2] | |||||
result[e].append([s,w]) | |||||
return result | |||||
def skips2skips_r2l(chars,w_trie): | |||||
''' | |||||
:param lexicons: list[[int,int,str]] | |||||
:return: skips_l2r | |||||
''' | |||||
# print(lexicons) | |||||
# print('******') | |||||
lexicons = get_skip_path(chars,w_trie=w_trie) | |||||
# max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0 | |||||
result = [[] for _ in range(len(chars))] | |||||
for lex in lexicons: | |||||
s = lex[0] | |||||
e = lex[1] | |||||
w = lex[2] | |||||
result[s].append([e,w]) | |||||
return result | |||||
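# Illustrative example (assuming the lexicon contains 南京, 南京市, 市长, 长江, 大桥 and 长江大桥):
# for the chars 南京市长江大桥, skips_l2r would hold [[0,'南京']] at position 1, [[0,'南京市']] at
# position 2, and entries like [5,'大桥'] and [3,'长江大桥'] at position 6, i.e. every lexicon word
# ending at that character together with its start index; skips_r2l is the mirror image, keyed
# by the word's start character and storing its end index.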
for k,v in datasets.items(): | |||||
v.apply_field(partial(skips2skips_l2r,w_trie=w_trie),'chars','skips_l2r') | |||||
for k,v in datasets.items(): | |||||
v.apply_field(partial(skips2skips_r2l,w_trie=w_trie),'chars','skips_r2l') | |||||
# print(v['skips_l2r'][0]) | |||||
word_vocab = Vocabulary() | |||||
word_vocab.add_word_lst(w_list) | |||||
vocabs['word'] = word_vocab | |||||
for k,v in datasets.items(): | |||||
v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_l2r','skips_l2r_source') | |||||
v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_l2r', 'skips_l2r_word') | |||||
for k,v in datasets.items(): | |||||
v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_r2l','skips_r2l_source') | |||||
v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_r2l', 'skips_r2l_word') | |||||
for k,v in datasets.items(): | |||||
v.apply_field(lambda x:list(map(len,x)), 'skips_l2r_word', 'lexicon_count') | |||||
v.apply_field(lambda x: | |||||
list(map(lambda y: | |||||
list(map(lambda z:word_vocab.to_index(z),y)),x)), | |||||
'skips_l2r_word',new_field_name='skips_l2r_word') | |||||
v.apply_field(lambda x:list(map(len,x)), 'skips_r2l_word', 'lexicon_count_back') | |||||
v.apply_field(lambda x: | |||||
list(map(lambda y: | |||||
list(map(lambda z:word_vocab.to_index(z),y)),x)), | |||||
'skips_r2l_word',new_field_name='skips_r2l_word') | |||||
if word_embedding_path is not None: | |||||
word_embedding = StaticEmbedding(word_vocab,word_embedding_path,word_dropout=0,normalize=normalize['word']) | |||||
embeddings['word'] = word_embedding | |||||
vocabs['char'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], | |||||
field_name='chars', new_field_name='chars') | |||||
vocabs['bigram'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], | |||||
field_name='bigrams', new_field_name='bigrams') | |||||
vocabs['label'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], | |||||
field_name='target', new_field_name='target') | |||||
return datasets,vocabs,embeddings | |||||
@cache_results(_cache_fp='cache/load_yangjie_rich_pretrain_word_list',_refresh=False) | |||||
def load_yangjie_rich_pretrain_word_list(embedding_path,drop_characters=True): | |||||
f = open(embedding_path,'r',encoding='utf-8')
lines = f.readlines()
f.close()
w_list = [] | |||||
for line in lines: | |||||
splited = line.strip().split(' ') | |||||
w = splited[0] | |||||
w_list.append(w) | |||||
if drop_characters: | |||||
w_list = list(filter(lambda x:len(x) != 1, w_list)) | |||||
return w_list | |||||
# from pathes import * | |||||
# | |||||
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path, | |||||
# yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path) | |||||
# print(datasets.keys()) | |||||
# print(vocabs.keys()) | |||||
# print(embeddings) | |||||
# yangjie_rich_pretrain_word_path | |||||
# datasets['train'].set_pad_val |
@@ -0,0 +1,189 @@
import torch.nn as nn | |||||
# from pathes import * | |||||
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner,load_weibo_ner | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1 | |||||
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward | |||||
import torch.optim as optim | |||||
import argparse | |||||
import torch | |||||
import sys | |||||
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ | |||||
from fastNLP import Tester | |||||
import fitlog | |||||
from fastNLP.core.callback import FitlogCallback | |||||
from utils import set_seed | |||||
import os | |||||
from fastNLP import LRScheduler | |||||
from torch.optim.lr_scheduler import LambdaLR | |||||
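# Main training entry (as opposed to the debug script above): it logs runs with fitlog, picks the
# dataset via --dataset, and overrides a few hyperparameters per dataset further below.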
parser = argparse.ArgumentParser() | |||||
parser.add_argument('--device',default='cuda:4') | |||||
parser.add_argument('--debug',default=False) | |||||
parser.add_argument('--norm_embed',default=True) | |||||
parser.add_argument('--batch',type=int,default=10)
parser.add_argument('--test_batch',type=int,default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',type=float,default=0.045)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)#in paper it's false
parser.add_argument('--hidden',type=int,default=100)
parser.add_argument('--momentum',type=float,default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=True)
parser.add_argument('--embed_dropout',type=float,default=0.5)
parser.add_argument('--output_dropout',type=float,default=0.5)
parser.add_argument('--epoch',type=int,default=100)
parser.add_argument('--seed',type=int,default=100)
args = parser.parse_args() | |||||
set_seed(args.seed) | |||||
fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)] | |||||
if args.model == 'lattice': | |||||
fit_msg_list.append(str(args.skip_before_head)) | |||||
fit_msg = ' '.join(fit_msg_list) | |||||
fitlog.commit(__file__,fit_msg=fit_msg) | |||||
fitlog.add_hyper(args) | |||||
device = torch.device(args.device) | |||||
for k,v in args.__dict__.items(): | |||||
print(k,v) | |||||
refresh_data = False | |||||
from pathes import * | |||||
# ontonote4ner_cn_path = 0 | |||||
# yangjie_rich_pretrain_unigram_path = 0 | |||||
# yangjie_rich_pretrain_bigram_path = 0 | |||||
# resume_ner_path = 0 | |||||
# weibo_ner_path = 0 | |||||
if args.dataset == 'ontonote': | |||||
datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||||
_refresh=refresh_data,index_token=False, | |||||
) | |||||
elif args.dataset == 'resume': | |||||
datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||||
_refresh=refresh_data,index_token=False, | |||||
) | |||||
elif args.dataset == 'weibo': | |||||
datasets,vocabs,embeddings = load_weibo_ner(weibo_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||||
_refresh=refresh_data,index_token=False, | |||||
) | |||||
if args.dataset == 'ontonote': | |||||
args.batch = 10 | |||||
args.lr = 0.045 | |||||
elif args.dataset == 'resume': | |||||
args.batch = 1 | |||||
args.lr = 0.015 | |||||
elif args.dataset == 'weibo': | |||||
args.embed_dropout = 0.1 | |||||
args.output_dropout = 0.1 | |||||
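# The overrides above replace whatever was passed on the command line: OntoNotes uses batch 10
# with lr 0.045, Resume uses batch 1 with lr 0.015, and Weibo lowers both dropouts to 0.1.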
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, | |||||
_refresh=refresh_data) | |||||
cache_name = os.path.join('cache',args.dataset+'_lattice') | |||||
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, | |||||
_refresh=refresh_data,_cache_fp=cache_name) | |||||
print(datasets['train'][0]) | |||||
print('vocab info:') | |||||
for k,v in vocabs.items(): | |||||
print('{}:{}'.format(k,len(v))) | |||||
for k,v in datasets.items(): | |||||
if args.model == 'lattice': | |||||
v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source') | |||||
if args.skip_before_head: | |||||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||||
v.set_padder('skips_r2l_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True)) | |||||
else: | |||||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_r2l_word', LatticeLexiconPadder()) | |||||
v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1)) | |||||
v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1)) | |||||
if args.bi: | |||||
v.set_input('chars','bigrams','seq_len', | |||||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||||
'skips_r2l_word', 'skips_r2l_source','lexicon_count_back', | |||||
'target', | |||||
use_1st_ins_infer_dim_type=True) | |||||
else: | |||||
v.set_input('chars','bigrams','seq_len', | |||||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||||
'target', | |||||
use_1st_ins_infer_dim_type=True) | |||||
v.set_target('target','seq_len') | |||||
v['target'].set_pad_val(0) | |||||
elif args.model == 'lstm': | |||||
v.set_ignore_type('skips_l2r_word','skips_l2r_source') | |||||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||||
v.set_input('chars','bigrams','seq_len','target', | |||||
use_1st_ins_infer_dim_type=True) | |||||
v.set_target('target','seq_len') | |||||
v['target'].set_pad_val(0) | |||||
print(datasets['dev']['skips_l2r_word'][100]) | |||||
if args.model =='lattice': | |||||
model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||||
skip_batch_first=True,bidirectional=args.bi,debug=args.debug, | |||||
skip_before_head=args.skip_before_head,use_bigram=args.use_bigram | |||||
) | |||||
elif args.model == 'lstm': | |||||
model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||||
bidirectional=args.bi, | |||||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||||
use_bigram=args.use_bigram) | |||||
loss = LossInForward() | |||||
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso') | |||||
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj') | |||||
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len') | |||||
metrics = [f1_metric,f1_metric_yj,acc_metric] | |||||
if args.optim == 'adam': | |||||
optimizer = optim.Adam(model.parameters(),lr=args.lr) | |||||
elif args.optim == 'sgd': | |||||
optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum) | |||||
callbacks = [ | |||||
FitlogCallback({'test':datasets['test'],'train':datasets['train']}), | |||||
LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep)) | |||||
] | |||||
trainer = Trainer(datasets['train'],model, | |||||
optimizer=optimizer, | |||||
loss=loss, | |||||
metrics=metrics, | |||||
dev_data=datasets['dev'], | |||||
device=device, | |||||
batch_size=args.batch, | |||||
n_epochs=args.epoch, | |||||
dev_batch_size=args.test_batch, | |||||
callbacks=callbacks) | |||||
trainer.train() |
@@ -0,0 +1,299 @@
import torch.nn as nn | |||||
from fastNLP.embeddings import StaticEmbedding | |||||
from fastNLP.modules import LSTM, ConditionalRandomField | |||||
import torch | |||||
from fastNLP import seq_len_to_mask | |||||
from utils import better_init_rnn | |||||
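# Sequence-labelling models built on fastNLP embeddings: two Lattice-LSTM variants (wrapping
# LatticeLSTMLayer_sup_back_V0 / _V1 from modules) and a plain (bi)LSTM baseline, all sharing the
# same CRF head. Each forward pass returns {'loss': ...} during training and a Viterbi-decoded
# {'pred': ...} during evaluation.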
class LatticeLSTM_SeqLabel(nn.Module): | |||||
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False, | |||||
device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False, | |||||
skip_before_head=False,use_bigram=True,vocabs=None): | |||||
if device is None: | |||||
self.device = torch.device('cpu') | |||||
else: | |||||
self.device = torch.device(device) | |||||
from modules import LatticeLSTMLayer_sup_back_V0 | |||||
super().__init__() | |||||
self.debug = debug | |||||
self.skip_batch_first = skip_batch_first | |||||
self.char_embed_size = char_embed.embedding.weight.size(1) | |||||
self.bigram_embed_size = bigram_embed.embedding.weight.size(1) | |||||
self.word_embed_size = word_embed.embedding.weight.size(1) | |||||
self.hidden_size = hidden_size | |||||
self.label_size = label_size | |||||
self.bidirectional = bidirectional | |||||
self.use_bigram = use_bigram | |||||
self.vocabs = vocabs | |||||
if self.use_bigram: | |||||
self.input_size = self.char_embed_size + self.bigram_embed_size | |||||
else: | |||||
self.input_size = self.char_embed_size | |||||
self.char_embed = char_embed | |||||
self.bigram_embed = bigram_embed | |||||
self.word_embed = word_embed | |||||
self.encoder = LatticeLSTMLayer_sup_back_V0(self.input_size,self.word_embed_size, | |||||
self.hidden_size, | |||||
left2right=True, | |||||
bias=bias, | |||||
device=self.device, | |||||
debug=self.debug, | |||||
skip_before_head=skip_before_head) | |||||
if self.bidirectional: | |||||
self.encoder_back = LatticeLSTMLayer_sup_back_V0(self.input_size, | |||||
self.word_embed_size, self.hidden_size, | |||||
left2right=False, | |||||
bias=bias, | |||||
device=self.device, | |||||
debug=self.debug, | |||||
skip_before_head=skip_before_head) | |||||
self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) | |||||
self.crf = ConditionalRandomField(label_size, True) | |||||
self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True)) | |||||
if self.crf.include_start_end_trans: | |||||
self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||||
self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||||
self.loss_func = nn.CrossEntropyLoss() | |||||
self.embed_dropout = nn.Dropout(embed_dropout) | |||||
self.output_dropout = nn.Dropout(output_dropout) | |||||
def forward(self, chars, bigrams, seq_len, target, | |||||
skips_l2r_source, skips_l2r_word, lexicon_count, | |||||
skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None): | |||||
# print('skips_l2r_word_id:{}'.format(skips_l2r_word.size())) | |||||
batch = chars.size(0) | |||||
max_seq_len = chars.size(1) | |||||
# max_lexicon_count = skips_l2r_word.size(2) | |||||
embed_char = self.char_embed(chars) | |||||
if self.use_bigram: | |||||
embed_bigram = self.bigram_embed(bigrams) | |||||
embedding = torch.cat([embed_char, embed_bigram], dim=-1) | |||||
else: | |||||
embedding = embed_char | |||||
embed_nonword = self.embed_dropout(embedding) | |||||
# skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1]) | |||||
embed_word = self.word_embed(skips_l2r_word) | |||||
embed_word = self.embed_dropout(embed_word) | |||||
# embed_word = torch.reshape(embed_word,shape=[batch,max_seq_len,max_lexicon_count,-1]) | |||||
encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count) | |||||
if self.bidirectional: | |||||
embed_word_back = self.word_embed(skips_r2l_word) | |||||
embed_word_back = self.embed_dropout(embed_word_back) | |||||
encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source, | |||||
embed_word_back, lexicon_count_back) | |||||
encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1) | |||||
encoded_h = self.output_dropout(encoded_h) | |||||
pred = self.output(encoded_h) | |||||
mask = seq_len_to_mask(seq_len) | |||||
if self.training: | |||||
loss = self.crf(pred, target, mask) | |||||
return {'loss': loss} | |||||
else: | |||||
pred, path = self.crf.viterbi_decode(pred, mask) | |||||
return {'pred': pred} | |||||
# batch_size, sent_len = pred.shape[0], pred.shape[1] | |||||
# loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len)) | |||||
# return {'pred':pred,'loss':loss} | |||||
class LatticeLSTM_SeqLabel_V1(nn.Module): | |||||
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False, | |||||
device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False, | |||||
skip_before_head=False,use_bigram=True,vocabs=None): | |||||
if device is None: | |||||
self.device = torch.device('cpu') | |||||
else: | |||||
self.device = torch.device(device) | |||||
from modules import LatticeLSTMLayer_sup_back_V1 | |||||
super().__init__() | |||||
self.count = 0 | |||||
self.debug = debug | |||||
self.skip_batch_first = skip_batch_first | |||||
self.char_embed_size = char_embed.embedding.weight.size(1) | |||||
self.bigram_embed_size = bigram_embed.embedding.weight.size(1) | |||||
self.word_embed_size = word_embed.embedding.weight.size(1) | |||||
self.hidden_size = hidden_size | |||||
self.label_size = label_size | |||||
self.bidirectional = bidirectional | |||||
self.use_bigram = use_bigram | |||||
self.vocabs = vocabs | |||||
if self.use_bigram: | |||||
self.input_size = self.char_embed_size + self.bigram_embed_size | |||||
else: | |||||
self.input_size = self.char_embed_size | |||||
self.char_embed = char_embed | |||||
self.bigram_embed = bigram_embed | |||||
self.word_embed = word_embed | |||||
self.encoder = LatticeLSTMLayer_sup_back_V1(self.input_size,self.word_embed_size, | |||||
self.hidden_size, | |||||
left2right=True, | |||||
bias=bias, | |||||
device=self.device, | |||||
debug=self.debug, | |||||
skip_before_head=skip_before_head) | |||||
if self.bidirectional: | |||||
self.encoder_back = LatticeLSTMLayer_sup_back_V1(self.input_size, | |||||
self.word_embed_size, self.hidden_size, | |||||
left2right=False, | |||||
bias=bias, | |||||
device=self.device, | |||||
debug=self.debug, | |||||
skip_before_head=skip_before_head) | |||||
self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) | |||||
self.crf = ConditionalRandomField(label_size, True) | |||||
self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True)) | |||||
if self.crf.include_start_end_trans: | |||||
self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||||
self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||||
self.loss_func = nn.CrossEntropyLoss() | |||||
self.embed_dropout = nn.Dropout(embed_dropout) | |||||
self.output_dropout = nn.Dropout(output_dropout) | |||||
def forward(self, chars, bigrams, seq_len, target, | |||||
skips_l2r_source, skips_l2r_word, lexicon_count, | |||||
skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None): | |||||
batch = chars.size(0) | |||||
max_seq_len = chars.size(1) | |||||
embed_char = self.char_embed(chars) | |||||
if self.use_bigram: | |||||
embed_bigram = self.bigram_embed(bigrams) | |||||
embedding = torch.cat([embed_char, embed_bigram], dim=-1) | |||||
else: | |||||
embedding = embed_char | |||||
embed_nonword = self.embed_dropout(embedding) | |||||
# skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1]) | |||||
embed_word = self.word_embed(skips_l2r_word) | |||||
embed_word = self.embed_dropout(embed_word) | |||||
encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count) | |||||
if self.bidirectional: | |||||
embed_word_back = self.word_embed(skips_r2l_word) | |||||
embed_word_back = self.embed_dropout(embed_word_back) | |||||
encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source, | |||||
embed_word_back, lexicon_count_back) | |||||
encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1) | |||||
encoded_h = self.output_dropout(encoded_h) | |||||
pred = self.output(encoded_h) | |||||
mask = seq_len_to_mask(seq_len) | |||||
if self.training: | |||||
loss = self.crf(pred, target, mask) | |||||
return {'loss': loss} | |||||
else: | |||||
pred, path = self.crf.viterbi_decode(pred, mask) | |||||
return {'pred': pred} | |||||
class LSTM_SeqLabel(nn.Module): | |||||
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, | |||||
bidirectional=False, device=None, embed_dropout=0, output_dropout=0,use_bigram=True): | |||||
if device is None: | |||||
self.device = torch.device('cpu') | |||||
else: | |||||
self.device = torch.device(device) | |||||
super().__init__() | |||||
self.char_embed_size = char_embed.embedding.weight.size(1) | |||||
self.bigram_embed_size = bigram_embed.embedding.weight.size(1) | |||||
self.word_embed_size = word_embed.embedding.weight.size(1) | |||||
self.hidden_size = hidden_size | |||||
self.label_size = label_size | |||||
self.bidirectional = bidirectional | |||||
self.use_bigram = use_bigram | |||||
self.char_embed = char_embed | |||||
self.bigram_embed = bigram_embed | |||||
self.word_embed = word_embed | |||||
if self.use_bigram: | |||||
self.input_size = self.char_embed_size + self.bigram_embed_size | |||||
else: | |||||
self.input_size = self.char_embed_size | |||||
self.encoder = LSTM(self.input_size, self.hidden_size, | |||||
bidirectional=self.bidirectional) | |||||
better_init_rnn(self.encoder.lstm) | |||||
self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) | |||||
self.debug = False | |||||
self.loss_func = nn.CrossEntropyLoss() | |||||
self.embed_dropout = nn.Dropout(embed_dropout) | |||||
self.output_dropout = nn.Dropout(output_dropout) | |||||
self.crf = ConditionalRandomField(label_size, True) | |||||
def forward(self, chars, bigrams, seq_len, target): | |||||
embed_char = self.char_embed(chars) | |||||
if self.use_bigram: | |||||
embed_bigram = self.bigram_embed(bigrams) | |||||
embedding = torch.cat([embed_char, embed_bigram], dim=-1) | |||||
else: | |||||
embedding = embed_char | |||||
embedding = self.embed_dropout(embedding) | |||||
encoded_h, encoded_c = self.encoder(embedding, seq_len) | |||||
encoded_h = self.output_dropout(encoded_h) | |||||
pred = self.output(encoded_h) | |||||
mask = seq_len_to_mask(seq_len) | |||||
# pred = self.crf(pred) | |||||
# batch_size, sent_len = pred.shape[0], pred.shape[1] | |||||
# loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len)) | |||||
if self.training: | |||||
loss = self.crf(pred, target, mask) | |||||
return {'loss': loss} | |||||
else: | |||||
pred, path = self.crf.viterbi_decode(pred, mask) | |||||
return {'pred': pred} |
@@ -0,0 +1,638 @@ | |||||
import torch.nn as nn | |||||
import torch | |||||
from fastNLP.core.utils import seq_len_to_mask | |||||
from utils import better_init_rnn | |||||
import numpy as np | |||||
class WordLSTMCell_yangjie(nn.Module): | |||||
"""A basic LSTM cell.""" | |||||
def __init__(self, input_size, hidden_size, use_bias=True,debug=False, left2right=True): | |||||
""" | |||||
Most parts are copied from torch.nn.LSTMCell. | |||||
""" | |||||
super().__init__() | |||||
self.left2right = left2right | |||||
self.debug = debug | |||||
self.input_size = input_size | |||||
self.hidden_size = hidden_size | |||||
self.use_bias = use_bias | |||||
self.weight_ih = nn.Parameter( | |||||
torch.FloatTensor(input_size, 3 * hidden_size)) | |||||
self.weight_hh = nn.Parameter( | |||||
torch.FloatTensor(hidden_size, 3 * hidden_size)) | |||||
if use_bias: | |||||
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) | |||||
else: | |||||
self.register_parameter('bias', None) | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
""" | |||||
Initialize parameters following the way proposed in the paper. | |||||
""" | |||||
        nn.init.orthogonal_(self.weight_ih.data)
weight_hh_data = torch.eye(self.hidden_size) | |||||
weight_hh_data = weight_hh_data.repeat(1, 3) | |||||
with torch.no_grad(): | |||||
self.weight_hh.set_(weight_hh_data) | |||||
# The bias is just set to zero vectors. | |||||
        if self.use_bias:
            nn.init.constant_(self.bias.data, val=0)
def forward(self, input_, hx): | |||||
""" | |||||
Args: | |||||
input_: A (batch, input_size) tensor containing input | |||||
features. | |||||
hx: A tuple (h_0, c_0), which contains the initial hidden | |||||
and cell state, where the size of both states is | |||||
(batch, hidden_size). | |||||
Returns: | |||||
h_1, c_1: Tensors containing the next hidden and cell state. | |||||
""" | |||||
h_0, c_0 = hx | |||||
batch_size = h_0.size(0) | |||||
bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())) | |||||
wh_b = torch.addmm(bias_batch, h_0, self.weight_hh) | |||||
wi = torch.mm(input_, self.weight_ih) | |||||
f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1) | |||||
c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g) | |||||
return c_1 | |||||
def __repr__(self): | |||||
s = '{name}({input_size}, {hidden_size})' | |||||
return s.format(name=self.__class__.__name__, **self.__dict__) | |||||
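# A minimal usage sketch (not in the original source; the shapes below are illustrative
# assumptions): the word cell consumes the embedding of a matched lexicon word together with
# the (h, c) state taken at the word's first character, and returns only a candidate cell state.
#
#   _word_cell = WordLSTMCell_yangjie(input_size=50, hidden_size=100)
#   _w = torch.randn(4, 50)                              # 4 flattened (batch * lexicon) word embeddings
#   _h0, _c0 = torch.zeros(4, 100), torch.zeros(4, 100)  # states at the words' start characters
#   _c1 = _word_cell(_w, (_h0, _c0))                     # -> (4, 100) cell states, one per word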
class MultiInputLSTMCell_V0(nn.Module): | |||||
def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False): | |||||
super().__init__() | |||||
self.char_input_size = char_input_size | |||||
self.hidden_size = hidden_size | |||||
self.use_bias = use_bias | |||||
self.weight_ih = nn.Parameter( | |||||
torch.FloatTensor(char_input_size, 3 * hidden_size) | |||||
) | |||||
self.weight_hh = nn.Parameter( | |||||
torch.FloatTensor(hidden_size, 3 * hidden_size) | |||||
) | |||||
self.alpha_weight_ih = nn.Parameter( | |||||
torch.FloatTensor(char_input_size, hidden_size) | |||||
) | |||||
self.alpha_weight_hh = nn.Parameter( | |||||
torch.FloatTensor(hidden_size, hidden_size) | |||||
) | |||||
if self.use_bias: | |||||
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) | |||||
self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size)) | |||||
else: | |||||
self.register_parameter('bias', None) | |||||
self.register_parameter('alpha_bias', None) | |||||
self.debug = debug | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
""" | |||||
Initialize parameters following the way proposed in the paper. | |||||
""" | |||||
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)
weight_hh_data = torch.eye(self.hidden_size) | |||||
weight_hh_data = weight_hh_data.repeat(1, 3) | |||||
with torch.no_grad(): | |||||
self.weight_hh.set_(weight_hh_data) | |||||
alpha_weight_hh_data = torch.eye(self.hidden_size) | |||||
alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1) | |||||
with torch.no_grad(): | |||||
self.alpha_weight_hh.set_(alpha_weight_hh_data) | |||||
# The bias is just set to zero vectors. | |||||
if self.use_bias: | |||||
nn.init.constant_(self.bias.data, val=0) | |||||
nn.init.constant_(self.alpha_bias.data, val=0) | |||||
def forward(self, inp, skip_c, skip_count, hx): | |||||
        '''
        :param inp: char input, B * hidden
        :param skip_c: candidate cell states produced by the skip (lexicon) edges, B * X * hidden
        :param skip_count: number of skip edges ending at the current position for each example
            in the batch; used to build the mask
        :param hx: (h_0, c_0)
        :return: (h_1, c_1)
        '''
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wi = torch.matmul(inp, self.weight_ih)
        wh = torch.matmul(h_0, self.weight_hh)
        i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
        i = torch.sigmoid(i).unsqueeze(1)
        o = torch.sigmoid(o).unsqueeze(1)
        g = torch.tanh(g).unsqueeze(1)
        alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
        alpha_wi.unsqueeze_(1)
        # alpha_wi = alpha_wi.expand(1,skip_count,self.hidden_size)
        alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
        alpha_bias_batch = self.alpha_bias.unsqueeze(0)
        alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
        # mask out padded lexicon slots before the softmax over {input gate, word cells}
        skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1])
        skip_mask = skip_mask.eq(0)
        skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
        skip_mask = skip_mask.float() * 1e20
        alpha = alpha - skip_mask
        alpha = torch.exp(torch.cat([i, alpha], dim=1))
        alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
        alpha = torch.div(alpha, alpha_sum)
        merge_i_c = torch.cat([g, skip_c], dim=1)
        c_1 = merge_i_c * alpha
        c_1 = c_1.sum(1, keepdim=True)
        h_1 = o * torch.tanh(c_1)
        return h_1.squeeze(1), c_1.squeeze(1)
class MultiInputLSTMCell_V1(nn.Module): | |||||
def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False): | |||||
super().__init__() | |||||
self.char_input_size = char_input_size | |||||
self.hidden_size = hidden_size | |||||
self.use_bias = use_bias | |||||
self.weight_ih = nn.Parameter( | |||||
torch.FloatTensor(char_input_size, 3 * hidden_size) | |||||
) | |||||
self.weight_hh = nn.Parameter( | |||||
torch.FloatTensor(hidden_size, 3 * hidden_size) | |||||
) | |||||
self.alpha_weight_ih = nn.Parameter( | |||||
torch.FloatTensor(char_input_size, hidden_size) | |||||
) | |||||
self.alpha_weight_hh = nn.Parameter( | |||||
torch.FloatTensor(hidden_size, hidden_size) | |||||
) | |||||
if self.use_bias: | |||||
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) | |||||
self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size)) | |||||
else: | |||||
self.register_parameter('bias', None) | |||||
self.register_parameter('alpha_bias', None) | |||||
self.debug = debug | |||||
self.reset_parameters() | |||||
def reset_parameters(self): | |||||
""" | |||||
Initialize parameters following the way proposed in the paper. | |||||
""" | |||||
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)
weight_hh_data = torch.eye(self.hidden_size) | |||||
weight_hh_data = weight_hh_data.repeat(1, 3) | |||||
with torch.no_grad(): | |||||
self.weight_hh.set_(weight_hh_data) | |||||
alpha_weight_hh_data = torch.eye(self.hidden_size) | |||||
alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1) | |||||
with torch.no_grad(): | |||||
self.alpha_weight_hh.set_(alpha_weight_hh_data) | |||||
# The bias is just set to zero vectors. | |||||
if self.use_bias: | |||||
nn.init.constant_(self.bias.data, val=0) | |||||
nn.init.constant_(self.alpha_bias.data, val=0) | |||||
def forward(self, inp, skip_c, skip_count, hx): | |||||
        '''
        :param inp: char input, B * hidden
        :param skip_c: candidate cell states produced by the skip (lexicon) edges, B * X * hidden
        :param skip_count: number of skip edges ending at the current position for each example
            in the batch; used to build the mask
        :param hx: (h_0, c_0)
        :return: (h_1, c_1)
        '''
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wi = torch.matmul(inp, self.weight_ih)
        wh = torch.matmul(h_0, self.weight_hh)
        i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
        i = torch.sigmoid(i).unsqueeze(1)
        o = torch.sigmoid(o).unsqueeze(1)
        g = torch.tanh(g).unsqueeze(1)
        # vanilla (coupled-gate) LSTM update, used where no lexicon word matches
        f = 1 - i
        c_1_basic = f * c_0.unsqueeze(1) + i * g
        c_1_basic = c_1_basic.squeeze(1)
        alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
        alpha_wi.unsqueeze_(1)
        alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
        alpha_bias_batch = self.alpha_bias.unsqueeze(0)
        alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
        # mask out padded lexicon slots before the softmax over {input gate, word cells}
        skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1])
        skip_mask = skip_mask.eq(0)
        skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
        skip_mask = skip_mask.float() * 1e20
        alpha = alpha - skip_mask
        alpha = torch.exp(torch.cat([i, alpha], dim=1))
        alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
        alpha = torch.div(alpha, alpha_sum)
        merge_i_c = torch.cat([g, skip_c], dim=1)
        c_1 = merge_i_c * alpha
        c_1 = c_1.sum(1, keepdim=True)
        c_1 = c_1.squeeze(1)
        # fall back to the vanilla LSTM cell state at positions with no matched lexicon word
        count_select = (skip_count != 0).float().unsqueeze(-1)
        c_1 = c_1 * count_select + c_1_basic * (1 - count_select)
        o = o.squeeze(1)
        h_1 = o * torch.tanh(c_1)
        return h_1, c_1
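# A hedged sketch of how the char cell is driven (names and shapes below are illustrative
# assumptions, not taken from the original scripts): skip_c holds one candidate cell state per
# matched lexicon word, and skip_count says how many of those slots are real for each example;
# padded slots are masked out before the softmax over {input gate, word cells}.
#
#   _char_cell = MultiInputLSTMCell_V1(char_input_size=50, hidden_size=100)
#   _x = torch.randn(2, 50)                      # char input for one time step, batch of 2
#   _skip_c = torch.randn(2, 3, 100)             # up to 3 lexicon candidates per example
#   _skip_count = torch.tensor([3, 0])           # example 1 has 3 matches, example 2 has none
#   _h0, _c0 = torch.zeros(2, 100), torch.zeros(2, 100)
#   _h1, _c1 = _char_cell(_x, _skip_c, _skip_count, (_h0, _c0))   # each of shape (2, 100)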
class LatticeLSTMLayer_sup_back_V0(nn.Module): | |||||
def __init__(self, char_input_size, word_input_size, hidden_size, left2right, | |||||
bias=True,device=None,debug=False,skip_before_head=False): | |||||
        super().__init__()
        self.debug = debug
        self.skip_before_head = skip_before_head
        self.hidden_size = hidden_size
        self.char_cell = MultiInputLSTMCell_V0(char_input_size, hidden_size, bias, debug)
        self.word_cell = WordLSTMCell_yangjie(word_input_size, hidden_size, bias, debug=self.debug)
        self.word_input_size = word_input_size
        self.left2right = left2right
        self.bias = bias
        self.device = device
def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None): | |||||
        '''
        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip (lexicon) edges
        :param skip_words: batch * seq_len * X * embedding, embeddings of the lexicon words on the skip edges
        :param skip_count: batch * seq_len, number of lexicon words matched per example per position
        :param init_state: the hx of rnn
        :return:
        '''
if self.left2right: | |||||
max_seq_len = max(seq_len) | |||||
batch_size = inp.size(0) | |||||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
for i in range(max_seq_len): | |||||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||||
h_0, c_0 = h_[:, i, :], c_[:, i, :] | |||||
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||||
index_1 = skip_source_flat | |||||
if not self.skip_before_head: | |||||
c_x = c_[[index_0, index_1+1]] | |||||
h_x = h_[[index_0, index_1+1]] | |||||
else: | |||||
c_x = c_[[index_0,index_1]] | |||||
h_x = h_[[index_0,index_1]] | |||||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||||
h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1) | |||||
c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1) | |||||
return h_[:,1:],c_[:,1:] | |||||
else: | |||||
mask_for_seq_len = seq_len_to_mask(seq_len) | |||||
max_seq_len = max(seq_len) | |||||
batch_size = inp.size(0) | |||||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
for i in reversed(range(max_seq_len)): | |||||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||||
h_0, c_0 = h_[:, 0, :], c_[:, 0, :] | |||||
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||||
index_1 = skip_source_flat-i | |||||
if not self.skip_before_head: | |||||
c_x = c_[[index_0, index_1-1]] | |||||
h_x = h_[[index_0, index_1-1]] | |||||
else: | |||||
c_x = c_[[index_0,index_1]] | |||||
h_x = h_[[index_0,index_1]] | |||||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||||
                h_1_mask = h_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1).eq(0), 0)
                c_1_mask = c_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1).eq(0), 0)
h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1) | |||||
c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1) | |||||
return h_[:,:-1],c_[:,:-1] | |||||
class LatticeLSTMLayer_sup_back_V1(nn.Module): | |||||
    # V1 differs from V0 in that, when the current position has no matched lexicon word at all,
    # V1 falls back to the vanilla LSTM update; Yang Jie's lattice-LSTM implementation behaves
    # differently from a vanilla LSTM when the lexicon count is 0.
def __init__(self, char_input_size, word_input_size, hidden_size, left2right, | |||||
bias=True,device=None,debug=False,skip_before_head=False): | |||||
super().__init__() | |||||
self.debug = debug | |||||
self.skip_before_head = skip_before_head | |||||
self.hidden_size = hidden_size | |||||
self.char_cell = MultiInputLSTMCell_V1(char_input_size, hidden_size, bias,debug) | |||||
self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=self.debug) | |||||
self.word_input_size = word_input_size | |||||
self.left2right = left2right | |||||
self.bias = bias | |||||
self.device = device | |||||
def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None): | |||||
        '''
        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip (lexicon) edges
        :param skip_words: batch * seq_len * X * embedding_size, embeddings of the lexicon words on the skip edges
        :param skip_count: batch * seq_len,
            skip_count[i, j] is the number of lexicon words matched in example i that end at position j
        :param init_state: the hx of rnn
        :return:
        '''
if self.left2right: | |||||
max_seq_len = max(seq_len) | |||||
batch_size = inp.size(0) | |||||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
for i in range(max_seq_len): | |||||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||||
h_0, c_0 = h_[:, i, :], c_[:, i, :] | |||||
                # reshape to 2-D so the word cell can consume a (B * lexicon_count) * embedding_size tensor
                # and so the indices match pytorch's advanced ([]) indexing
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||||
index_1 = skip_source_flat | |||||
if not self.skip_before_head: | |||||
c_x = c_[[index_0, index_1+1]] | |||||
h_x = h_[[index_0, index_1+1]] | |||||
else: | |||||
c_x = c_[[index_0,index_1]] | |||||
h_x = h_[[index_0,index_1]] | |||||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||||
h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1) | |||||
c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1) | |||||
return h_[:,1:],c_[:,1:] | |||||
else: | |||||
mask_for_seq_len = seq_len_to_mask(seq_len) | |||||
max_seq_len = max(seq_len) | |||||
batch_size = inp.size(0) | |||||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||||
for i in reversed(range(max_seq_len)): | |||||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||||
h_0, c_0 = h_[:, 0, :], c_[:, 0, :] | |||||
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||||
index_1 = skip_source_flat-i | |||||
if not self.skip_before_head: | |||||
c_x = c_[[index_0, index_1-1]] | |||||
h_x = h_[[index_0, index_1-1]] | |||||
else: | |||||
c_x = c_[[index_0,index_1]] | |||||
h_x = h_[[index_0,index_1]] | |||||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||||
                h_1_mask = h_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1).eq(0), 0)
                c_1_mask = c_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1).eq(0), 0)
h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1) | |||||
c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1) | |||||
return h_[:,:-1],c_[:,:-1] | |||||
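# A hedged end-to-end sketch of the layer interface (all names, sizes and the -1 padding
# convention below are assumptions made for illustration; the real pipeline builds these
# tensors with LatticeLexiconPadder):
#
#   _layer = LatticeLSTMLayer_sup_back_V1(char_input_size=50, word_input_size=50,
#                                         hidden_size=100, left2right=True, device='cpu')
#   _chars = torch.randn(2, 7, 50)                      # batch * seq_len * char_embedding
#   _seq_len = torch.tensor([7, 5])
#   _skip_sources = torch.full((2, 7, 3), -1).long()    # start index of each matched word (-1 = empty slot)
#   _skip_words = torch.randn(2, 7, 3, 50)              # embedding of each matched word
#   _skip_count = torch.zeros(2, 7).long()              # how many words end at each position
#   _h, _c = _layer(_chars, _seq_len, _skip_sources, _skip_words, _skip_count)
#   # _h, _c: batch * seq_len * hidden_size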
@@ -0,0 +1,23 @@ | |||||
glove_100_path = 'en-glove-6b-100d' | |||||
glove_50_path = 'en-glove-6b-50d' | |||||
glove_200_path = '' | |||||
glove_300_path = 'en-glove-840b-300' | |||||
fasttext_path = 'en-fasttext' #300 | |||||
tencent_chinese_word_path = 'cn' # tencent 200 | |||||
fasttext_cn_path = 'cn-fasttext' # 300 | |||||
yangjie_rich_pretrain_unigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.uni.ite50.vec' | |||||
yangjie_rich_pretrain_bigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.bi.ite50.vec' | |||||
yangjie_rich_pretrain_word_path = '/remote-home/xnli/data/pretrain/chinese/ctb.50d.vec' | |||||
conll_2003_path = '/remote-home/xnli/data/corpus/multi_task/conll_2013/data_mine.pkl' | |||||
conllized_ontonote_path = '/remote-home/txsun/data/OntoNotes-5.0-NER-master/v12/english' | |||||
conllized_ontonote_pkl_path = '/remote-home/txsun/data/ontonotes5.pkl' | |||||
sst2_path = '/remote-home/xnli/data/corpus/text_classification/SST-2/' | |||||
# weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo' | |||||
ontonote4ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/OntoNote4NER' | |||||
msra_ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/MSRANER' | |||||
resume_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/ResumeNER' | |||||
weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER' |
@@ -0,0 +1,126 @@ | |||||
from utils_ import get_skip_path_trivial, Trie, get_skip_path | |||||
from load_data import load_yangjie_rich_pretrain_word_list, load_ontonotes4ner, equip_chinese_ner_with_skip | |||||
from pathes import * | |||||
from functools import partial | |||||
from fastNLP import cache_results | |||||
from fastNLP.embeddings.static_embedding import StaticEmbedding | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from fastNLP.core.metrics import _bmes_tag_to_spans,_bmeso_tag_to_spans | |||||
from load_data import load_resume_ner | |||||
# embed = StaticEmbedding(None,embedding_dim=2) | |||||
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||||
# _refresh=True,index_token=False) | |||||
# | |||||
# w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, | |||||
# _refresh=False) | |||||
# | |||||
# datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, | |||||
# _refresh=True) | |||||
# | |||||
def reverse_style(input_string): | |||||
target_position = input_string.index('[') | |||||
input_len = len(input_string) | |||||
output_string = input_string[target_position:input_len] + input_string[0:target_position] | |||||
# print('in:{}.out:{}'.format(input_string, output_string)) | |||||
return output_string | |||||
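# For example (hedged, worked out from the code above): reverse_style('name[0,1') == '[0,1name',
# i.e. the part of the string starting at '[' is moved to the front.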
def get_yangjie_bmeso(label_list): | |||||
def get_ner_BMESO_yj(label_list): | |||||
# list_len = len(word_list) | |||||
# assert(list_len == len(label_list)), "word list size unmatch with label list" | |||||
list_len = len(label_list) | |||||
begin_label = 'b-' | |||||
end_label = 'e-' | |||||
single_label = 's-' | |||||
whole_tag = '' | |||||
index_tag = '' | |||||
tag_list = [] | |||||
stand_matrix = [] | |||||
for i in range(0, list_len): | |||||
# wordlabel = word_list[i] | |||||
current_label = label_list[i].lower() | |||||
if begin_label in current_label: | |||||
if index_tag != '': | |||||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||||
whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) | |||||
index_tag = current_label.replace(begin_label, "", 1) | |||||
elif single_label in current_label: | |||||
if index_tag != '': | |||||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||||
whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i) | |||||
tag_list.append(whole_tag) | |||||
whole_tag = "" | |||||
index_tag = "" | |||||
elif end_label in current_label: | |||||
if index_tag != '': | |||||
tag_list.append(whole_tag + ',' + str(i)) | |||||
whole_tag = '' | |||||
index_tag = '' | |||||
else: | |||||
continue | |||||
if (whole_tag != '') & (index_tag != ''): | |||||
tag_list.append(whole_tag) | |||||
tag_list_len = len(tag_list) | |||||
for i in range(0, tag_list_len): | |||||
if len(tag_list[i]) > 0: | |||||
tag_list[i] = tag_list[i] + ']' | |||||
insert_list = reverse_style(tag_list[i]) | |||||
stand_matrix.append(insert_list) | |||||
# print stand_matrix | |||||
return stand_matrix | |||||
def transform_YJ_to_fastNLP(span): | |||||
span = span[1:] | |||||
span_split = span.split(']') | |||||
# print('span_list:{}'.format(span_split)) | |||||
span_type = span_split[1] | |||||
# print('span_split[0].split(','):{}'.format(span_split[0].split(','))) | |||||
if ',' in span_split[0]: | |||||
b, e = span_split[0].split(',') | |||||
else: | |||||
b = span_split[0] | |||||
e = b | |||||
b = int(b) | |||||
e = int(e) | |||||
e += 1 | |||||
return (span_type, (b, e)) | |||||
yj_form = get_ner_BMESO_yj(label_list) | |||||
# print('label_list:{}'.format(label_list)) | |||||
# print('yj_from:{}'.format(yj_form)) | |||||
fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form)) | |||||
return fastNLP_form | |||||
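# A small hedged example of the conversion (worked out from the code above, not taken from the
# original repo's tests): BMES/BMESO tags are first turned into Yang Jie's bracket form and then
# into fastNLP-style (label, (start, end_exclusive)) spans.
#
#   get_yangjie_bmeso(['B-NAME', 'E-NAME', 'O', 'S-ORG'])
#   # -> [('name', (0, 2)), ('org', (3, 4))]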
# tag_list = ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'] | |||||
# span_list = get_ner_BMES(tag_list) | |||||
# print(span_list) | |||||
# yangjie_label_list = ['B-NAME', 'E-NAME', 'O', 'B-CONT', 'M-CONT', 'E-CONT', 'B-RACE', 'E-RACE', 'B-TITLE', 'M-TITLE', 'E-TITLE', 'B-EDU', 'M-EDU', 'E-EDU', 'B-ORG', 'M-ORG', 'E-ORG', 'M-NAME', 'B-PRO', 'M-PRO', 'E-PRO', 'S-RACE', 'S-NAME', 'B-LOC', 'M-LOC', 'E-LOC', 'M-RACE', 'S-ORG'] | |||||
# my_label_list = ['O', 'M-ORG', 'M-TITLE', 'B-TITLE', 'E-TITLE', 'B-ORG', 'E-ORG', 'M-EDU', 'B-NAME', 'E-NAME', 'B-EDU', 'E-EDU', 'M-NAME', 'M-PRO', 'M-CONT', 'B-PRO', 'E-PRO', 'B-CONT', 'E-CONT', 'M-LOC', 'B-RACE', 'E-RACE', 'S-NAME', 'B-LOC', 'E-LOC', 'M-RACE', 'S-RACE', 'S-ORG'] | |||||
# yangjie_label = set(yangjie_label_list) | |||||
# my_label = set(my_label_list) | |||||
a = torch.tensor([0,2,0,3]) | |||||
b = (a==0) | |||||
print(b) | |||||
print(b.float()) | |||||
from fastNLP import RandomSampler | |||||
# f = open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb') | |||||
# weight_dict = torch.load(f) | |||||
# print(weight_dict.keys()) | |||||
# for k,v in weight_dict.items(): | |||||
# print("{}:{}".format(k,v.size())) |
@@ -0,0 +1,361 @@ | |||||
import torch.nn.functional as F | |||||
import torch | |||||
import random | |||||
import numpy as np | |||||
from fastNLP import Const | |||||
from fastNLP import CrossEntropyLoss | |||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import Tester | |||||
import os | |||||
from fastNLP import logger | |||||
def should_mask(name, t=''): | |||||
if 'bias' in name: | |||||
return False | |||||
if 'embedding' in name: | |||||
splited = name.split('.') | |||||
if splited[-1]!='weight': | |||||
return False | |||||
if 'embedding' in splited[-2]: | |||||
return False | |||||
if 'c0' in name: | |||||
return False | |||||
if 'h0' in name: | |||||
return False | |||||
if 'output' in name and t not in name: | |||||
return False | |||||
return True | |||||
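# Hedged examples of the intended behaviour (the parameter names are made up for illustration):
#   should_mask('encoder.char_cell.weight_ih')   -> True   (regular weight, prunable)
#   should_mask('encoder.char_cell.bias')        -> False  ('bias' in the name)
#   should_mask('char_embed.embedding.weight')   -> False  (embedding weights are kept)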
def get_init_mask(model): | |||||
init_masks = {} | |||||
for name, param in model.named_parameters(): | |||||
if should_mask(name): | |||||
init_masks[name+'.mask'] = torch.ones_like(param) | |||||
# logger.info(init_masks[name+'.mask'].requires_grad) | |||||
return init_masks | |||||
def set_seed(seed): | |||||
random.seed(seed) | |||||
np.random.seed(seed+100) | |||||
torch.manual_seed(seed+200) | |||||
torch.cuda.manual_seed_all(seed+300) | |||||
def get_parameters_size(model): | |||||
result = {} | |||||
for name,p in model.state_dict().items(): | |||||
result[name] = p.size() | |||||
return result | |||||
def prune_by_proportion_model(model,proportion,task): | |||||
# print('this time prune to ',proportion*100,'%') | |||||
for name, p in model.named_parameters(): | |||||
# print(name) | |||||
if not should_mask(name,task): | |||||
continue | |||||
tensor = p.data.cpu().numpy() | |||||
index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy()) | |||||
# print(name,'alive count',len(index[0])) | |||||
alive = tensor[index] | |||||
# print('p and mask size:',p.size(),print(model.mask[task][name+'.mask'].size())) | |||||
percentile_value = np.percentile(abs(alive), (1 - proportion) * 100) | |||||
# tensor = p | |||||
# index = torch.nonzero(model.mask[task][name+'.mask']) | |||||
# # print('nonzero len',index) | |||||
# alive = tensor[index] | |||||
# print('alive size:',alive.shape) | |||||
# prune_by_proportion_model() | |||||
# percentile_value = torch.topk(abs(alive), int((1-proportion)*len(index[0]))).values | |||||
# print('the',(1-proportion)*len(index[0]),'th big') | |||||
# print('threshold:',percentile_value) | |||||
prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value) | |||||
# for | |||||
def prune_by_proportion_model_global(model,proportion,task): | |||||
# print('this time prune to ',proportion*100,'%') | |||||
alive = None | |||||
for name, p in model.named_parameters(): | |||||
# print(name) | |||||
if not should_mask(name,task): | |||||
continue | |||||
tensor = p.data.cpu().numpy() | |||||
index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy()) | |||||
# print(name,'alive count',len(index[0])) | |||||
if alive is None: | |||||
alive = tensor[index] | |||||
else: | |||||
alive = np.concatenate([alive,tensor[index]],axis=0) | |||||
percentile_value = np.percentile(abs(alive), (1 - proportion) * 100) | |||||
for name, p in model.named_parameters(): | |||||
if should_mask(name,task): | |||||
prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value) | |||||
def prune_by_threshold_parameter(p, mask, threshold): | |||||
p_abs = torch.abs(p) | |||||
new_mask = (p_abs > threshold).float() | |||||
# print(mask) | |||||
mask[:]*=new_mask | |||||
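# Hedged illustration of the thresholding above (the numbers are made up): weights whose
# absolute value does not exceed the threshold get a 0 in the mask, and the mask only ever
# shrinks because the new mask is multiplied into the old one.
#
#   _p = torch.tensor([0.5, -0.1, 0.9])
#   _mask = torch.ones(3)
#   prune_by_threshold_parameter(_p, _mask, threshold=0.3)
#   # _mask is now tensor([1., 0., 1.])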
def one_time_train_and_prune_single_task(trainer,PRUNE_PER, | |||||
optimizer_init_state_dict=None, | |||||
model_init_state_dict=None, | |||||
is_global=None, | |||||
): | |||||
from fastNLP import Trainer | |||||
trainer.optimizer.load_state_dict(optimizer_init_state_dict) | |||||
trainer.model.load_state_dict(model_init_state_dict) | |||||
# print('metrics:',metrics.__dict__) | |||||
# print('loss:',loss.__dict__) | |||||
# print('trainer input:',task.train_set.get_input_name()) | |||||
# trainer = Trainer(model=model, train_data=task.train_set, dev_data=task.dev_set, loss=loss, metrics=metrics, | |||||
# optimizer=optimizer, n_epochs=EPOCH, batch_size=BATCH, device=device,callbacks=callbacks) | |||||
trainer.train(load_best_model=True) | |||||
# tester = Tester(task.train_set, model, metrics, BATCH, device=device, verbose=1,use_tqdm=False) | |||||
# print('FOR DEBUG: test train_set:',tester.test()) | |||||
# print('**'*20) | |||||
# if task.test_set: | |||||
# tester = Tester(task.test_set, model, metrics, BATCH, device=device, verbose=1) | |||||
# tester.test() | |||||
if is_global: | |||||
prune_by_proportion_model_global(trainer.model, PRUNE_PER, trainer.model.now_task) | |||||
else: | |||||
prune_by_proportion_model(trainer.model, PRUNE_PER, trainer.model.now_task) | |||||
# def iterative_train_and_prune_single_task(get_trainer,ITER,PRUNE,is_global=False,save_path=None): | |||||
def iterative_train_and_prune_single_task(get_trainer,args,model,train_set,dev_set,test_set,device,save_path=None): | |||||
    '''
    :param get_trainer: factory that builds a fresh Trainer for each pruning iteration
    :param args: must provide args.prune (overall proportion of weights to keep) and args.iter
        (number of train-and-prune iterations)
    :param model, train_set, dev_set, test_set, device: passed through to get_trainer
    :param save_path: should be a directory; the mask of each iteration is saved into it
    :return:
    '''
from fastNLP import Trainer | |||||
import torch | |||||
import math | |||||
import copy | |||||
PRUNE = args.prune | |||||
ITER = args.iter | |||||
trainer = get_trainer(args,model,train_set,dev_set,test_set,device) | |||||
optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict()) | |||||
model_init_state_dict = copy.deepcopy(trainer.model.state_dict()) | |||||
if save_path is not None: | |||||
if not os.path.exists(save_path): | |||||
os.makedirs(save_path) | |||||
# if not os.path.exists(os.path.join(save_path, 'model_init.pkl')): | |||||
# f = open(os.path.join(save_path, 'model_init.pkl'), 'wb') | |||||
# torch.save(trainer.model.state_dict(),f) | |||||
mask_count = 0 | |||||
model = trainer.model | |||||
task = trainer.model.now_task | |||||
for name, p in model.mask[task].items(): | |||||
mask_count += torch.sum(p).item() | |||||
init_mask_count = mask_count | |||||
logger.info('init mask count:{}'.format(mask_count)) | |||||
# logger.info('{}th traning mask count: {} / {} = {}%'.format(i, mask_count, init_mask_count, | |||||
# mask_count / init_mask_count * 100)) | |||||
prune_per_iter = math.pow(PRUNE, 1 / ITER) | |||||
for i in range(ITER): | |||||
trainer = get_trainer(args,model,train_set,dev_set,test_set,device) | |||||
one_time_train_and_prune_single_task(trainer,prune_per_iter,optimizer_init_state_dict,model_init_state_dict) | |||||
if save_path is not None: | |||||
f = open(os.path.join(save_path,task+'_mask_'+str(i)+'.pkl'),'wb') | |||||
torch.save(model.mask[task],f) | |||||
mask_count = 0 | |||||
for name, p in model.mask[task].items(): | |||||
mask_count += torch.sum(p).item() | |||||
        logger.info('{}th training mask count: {} / {} = {}%'.format(i,mask_count,init_mask_count,mask_count/init_mask_count*100))
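# Hedged numeric note on the schedule above (derived from the code, with made-up settings):
# with args.prune = 0.1 and args.iter = 10, prune_per_iter = 0.1 ** (1 / 10) ≈ 0.794, i.e.
# roughly 79.4% of the surviving weights are kept at each round, so after 10 rounds about
# 0.794 ** 10 ≈ 10% of the original weights remain unmasked.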
def get_appropriate_cuda(task_scale='s'): | |||||
if task_scale not in {'s','m','l'}: | |||||
logger.info('task scale wrong!') | |||||
exit(2) | |||||
import pynvml | |||||
pynvml.nvmlInit() | |||||
total_cuda_num = pynvml.nvmlDeviceGetCount() | |||||
for i in range(total_cuda_num): | |||||
logger.info(i) | |||||
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id here
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
        logger.info('{} mem: {} util: {}'.format(i, memInfo.used / memInfo.total, utilizationInfo.gpu))
        if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu < 0.2:
            logger.info('{} {}'.format(i, memInfo.used / memInfo.total))
return 'cuda:'+str(i) | |||||
if task_scale=='s': | |||||
max_memory=2000 | |||||
elif task_scale=='m': | |||||
max_memory=6000 | |||||
else: | |||||
max_memory = 9000 | |||||
max_id = -1 | |||||
for i in range(total_cuda_num): | |||||
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id here
memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle) | |||||
utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle) | |||||
if max_memory < memInfo.free: | |||||
max_memory = memInfo.free | |||||
max_id = i | |||||
    if max_id == -1:
logger.info('no appropriate gpu, wait!') | |||||
exit(2) | |||||
return 'cuda:'+str(max_id) | |||||
# if memInfo.used / memInfo.total < 0.5: | |||||
# return | |||||
def print_mask(mask_dict): | |||||
def seq_mul(*X): | |||||
res = 1 | |||||
for x in X: | |||||
res*=x | |||||
return res | |||||
for name,p in mask_dict.items(): | |||||
total_size = seq_mul(*p.size()) | |||||
        unmasked_size = int(torch.sum(p != 0).item())
print(name,':',unmasked_size,'/',total_size,'=',unmasked_size/total_size*100,'%') | |||||
print() | |||||
def check_words_same(dataset_1,dataset_2,field_1,field_2): | |||||
if len(dataset_1[field_1]) != len(dataset_2[field_2]): | |||||
logger.info('CHECK: example num not same!') | |||||
return False | |||||
for i, words in enumerate(dataset_1[field_1]): | |||||
if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]): | |||||
logger.info('CHECK {} th example length not same'.format(i)) | |||||
logger.info('1:{}'.format(dataset_1[field_1][i])) | |||||
            logger.info('2:{}'.format(dataset_2[field_2][i]))
return False | |||||
# for j,w in enumerate(words): | |||||
# if dataset_1[field_1][i][j] != dataset_2[field_2][i][j]: | |||||
# print('CHECK', i, 'th example has words different!') | |||||
# print('1:',dataset_1[field_1][i]) | |||||
# print('2:',dataset_2[field_2][i]) | |||||
# return False | |||||
logger.info('CHECK: totally same!') | |||||
return True | |||||
def get_now_time(): | |||||
    from datetime import datetime, timezone, timedelta
    dt = datetime.now(timezone.utc)
    tzutc_8 = timezone(timedelta(hours=8))  # UTC+8
    local_dt = dt.astimezone(tzutc_8)
result = ("_{}_{}_{}__{}_{}_{}".format(local_dt.year, local_dt.month, local_dt.day, local_dt.hour, local_dt.minute, | |||||
local_dt.second)) | |||||
return result | |||||
def get_bigrams(words): | |||||
result = [] | |||||
for i,w in enumerate(words): | |||||
if i!=len(words)-1: | |||||
result.append(words[i]+words[i+1]) | |||||
else: | |||||
result.append(words[i]+'<end>') | |||||
return result | |||||
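# For example (hedged, follows directly from the loop above):
#   get_bigrams(['中', '国', '人']) -> ['中国', '国人', '人<end>']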
def print_info(*inp,islog=False,sep=' '): | |||||
from fastNLP import logger | |||||
if islog: | |||||
print(*inp,sep=sep) | |||||
else: | |||||
inp = sep.join(map(str,inp)) | |||||
logger.info(inp) | |||||
def better_init_rnn(rnn,coupled=False): | |||||
import torch.nn as nn | |||||
if coupled: | |||||
repeat_size = 3 | |||||
else: | |||||
repeat_size = 4 | |||||
# print(list(rnn.named_parameters())) | |||||
if hasattr(rnn,'num_layers'): | |||||
for i in range(rnn.num_layers): | |||||
            nn.init.orthogonal_(getattr(rnn,'weight_ih_l'+str(i)).data)
weight_hh_data = torch.eye(rnn.hidden_size) | |||||
weight_hh_data = weight_hh_data.repeat(1, repeat_size) | |||||
with torch.no_grad(): | |||||
getattr(rnn,'weight_hh_l'+str(i)).set_(weight_hh_data) | |||||
            nn.init.constant_(getattr(rnn,'bias_ih_l'+str(i)).data, val=0)
            nn.init.constant_(getattr(rnn,'bias_hh_l'+str(i)).data, val=0)
if rnn.bidirectional: | |||||
for i in range(rnn.num_layers): | |||||
                nn.init.orthogonal_(getattr(rnn, 'weight_ih_l' + str(i)+'_reverse').data)
weight_hh_data = torch.eye(rnn.hidden_size) | |||||
weight_hh_data = weight_hh_data.repeat(1, repeat_size) | |||||
with torch.no_grad(): | |||||
getattr(rnn, 'weight_hh_l' + str(i)+'_reverse').set_(weight_hh_data) | |||||
                nn.init.constant_(getattr(rnn, 'bias_ih_l' + str(i)+'_reverse').data, val=0)
                nn.init.constant_(getattr(rnn, 'bias_hh_l' + str(i)+'_reverse').data, val=0)
else: | |||||
        nn.init.orthogonal_(rnn.weight_ih.data)
weight_hh_data = torch.eye(rnn.hidden_size) | |||||
weight_hh_data = weight_hh_data.repeat(repeat_size,1) | |||||
with torch.no_grad(): | |||||
rnn.weight_hh.set_(weight_hh_data) | |||||
# The bias is just set to zero vectors. | |||||
print('rnn param size:{},{}'.format(rnn.weight_hh.size(),type(rnn))) | |||||
if rnn.bias: | |||||
            nn.init.constant_(rnn.bias_ih.data, val=0)
            nn.init.constant_(rnn.bias_hh.data, val=0)
# print(list(rnn.named_parameters())) | |||||
@@ -0,0 +1,405 @@ | |||||
import collections | |||||
from fastNLP import cache_results | |||||
def get_skip_path(chars,w_trie): | |||||
sentence = ''.join(chars) | |||||
result = w_trie.get_lexicon(sentence) | |||||
return result | |||||
# @cache_results(_cache_fp='cache/get_skip_path_trivial',_refresh=True) | |||||
def get_skip_path_trivial(chars,w_list): | |||||
chars = ''.join(chars) | |||||
w_set = set(w_list) | |||||
result = [] | |||||
# for i in range(len(chars)): | |||||
# result.append([]) | |||||
for i in range(len(chars)-1): | |||||
for j in range(i+2,len(chars)+1): | |||||
if chars[i:j] in w_set: | |||||
result.append([i,j-1,chars[i:j]]) | |||||
return result | |||||
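# A hedged usage sketch (the word list below is made up): both helpers return lexicon matches
# as [start_index, end_index_inclusive, word].
#
#   get_skip_path_trivial(['南', '京', '市', '长', '江', '大', '桥'], ['南京', '长江', '大桥', '南京市'])
#   # -> [[0, 1, '南京'], [0, 2, '南京市'], [3, 4, '长江'], [5, 6, '大桥']]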
class TrieNode: | |||||
def __init__(self): | |||||
self.children = collections.defaultdict(TrieNode) | |||||
self.is_w = False | |||||
class Trie: | |||||
def __init__(self): | |||||
self.root = TrieNode() | |||||
def insert(self,w): | |||||
current = self.root | |||||
for c in w: | |||||
current = current.children[c] | |||||
current.is_w = True | |||||
def search(self,w): | |||||
        '''
        :param w:
        :return:
        -1: w is not a prefix of any inserted word
         0: w is a prefix of some inserted word, but not a word itself
         1: w is an inserted word
        '''
current = self.root | |||||
for c in w: | |||||
current = current.children.get(c) | |||||
if current is None: | |||||
return -1 | |||||
if current.is_w: | |||||
return 1 | |||||
else: | |||||
return 0 | |||||
def get_lexicon(self,sentence): | |||||
result = [] | |||||
for i in range(len(sentence)): | |||||
current = self.root | |||||
for j in range(i, len(sentence)): | |||||
current = current.children.get(sentence[j]) | |||||
if current is None: | |||||
break | |||||
if current.is_w: | |||||
result.append([i,j,sentence[i:j+1]]) | |||||
return result | |||||
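# A hedged usage sketch of the Trie (the words below are made up for illustration):
#
#   _trie = Trie()
#   for _w in ['南京', '南京市', '长江大桥']:
#       _trie.insert(_w)
#   _trie.search('南京')        # -> 1, an inserted word
#   _trie.search('长江')        # -> 0, a prefix of '长江大桥' but not a word itself
#   _trie.search('北京')        # -> -1, not a prefix of any inserted word
#   _trie.get_lexicon('南京市长江大桥')
#   # -> [[0, 1, '南京'], [0, 2, '南京市'], [3, 6, '长江大桥']]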
from fastNLP.core.field import Padder | |||||
import numpy as np | |||||
import torch | |||||
from collections import defaultdict | |||||
class LatticeLexiconPadder(Padder): | |||||
def __init__(self, pad_val=0, pad_val_dynamic=False,dynamic_offset=0, **kwargs): | |||||
''' | |||||
:param pad_val: | |||||
        :param pad_val_dynamic: if True, pad with (max sequence length - 1 + dynamic_offset) instead of pad_val
:param kwargs: | |||||
''' | |||||
self.pad_val = pad_val | |||||
self.pad_val_dynamic = pad_val_dynamic | |||||
self.dynamic_offset = dynamic_offset | |||||
def __call__(self, contents, field_name, field_ele_dtype, dim: int): | |||||
        # same as the dim == 2 case of fastNLP's AutoPadder
        max_len = max(map(len, contents))
        max_len = max(max_len, 1)  # avoid a zero-sized dim, which breaks cuda
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||||
content_i in contents]) | |||||
max_word_len = max(max_word_len,1) | |||||
if self.pad_val_dynamic: | |||||
# print('pad_val_dynamic:{}'.format(max_len-1)) | |||||
array = np.full((len(contents), max_len, max_word_len), max_len-1+self.dynamic_offset, | |||||
dtype=field_ele_dtype) | |||||
else: | |||||
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) | |||||
for i, content_i in enumerate(contents): | |||||
for j, content_ii in enumerate(content_i): | |||||
array[i, j, :len(content_ii)] = content_ii | |||||
array = torch.tensor(array) | |||||
return array | |||||
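# A hedged sketch of what the padder produces (the contents below are made-up lexicon indices):
#
#   _padder = LatticeLexiconPadder(pad_val=0)
#   _padder([[[1, 2], [3]], [[4]]], field_name='skips_l2r_word', field_ele_dtype=int, dim=3)
#   # -> tensor of shape (2, 2, 2):
#   #    [[[1, 2], [3, 0]],
#   #     [[4, 0], [0, 0]]]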
from fastNLP.core.metrics import MetricBase | |||||
def get_yangjie_bmeso(label_list,ignore_labels=None): | |||||
def get_ner_BMESO_yj(label_list): | |||||
def reverse_style(input_string): | |||||
target_position = input_string.index('[') | |||||
input_len = len(input_string) | |||||
output_string = input_string[target_position:input_len] + input_string[0:target_position] | |||||
# print('in:{}.out:{}'.format(input_string, output_string)) | |||||
return output_string | |||||
# list_len = len(word_list) | |||||
# assert(list_len == len(label_list)), "word list size unmatch with label list" | |||||
list_len = len(label_list) | |||||
begin_label = 'b-' | |||||
end_label = 'e-' | |||||
single_label = 's-' | |||||
whole_tag = '' | |||||
index_tag = '' | |||||
tag_list = [] | |||||
stand_matrix = [] | |||||
for i in range(0, list_len): | |||||
# wordlabel = word_list[i] | |||||
current_label = label_list[i].lower() | |||||
if begin_label in current_label: | |||||
if index_tag != '': | |||||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||||
whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) | |||||
index_tag = current_label.replace(begin_label, "", 1) | |||||
elif single_label in current_label: | |||||
if index_tag != '': | |||||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||||
whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i) | |||||
tag_list.append(whole_tag) | |||||
whole_tag = "" | |||||
index_tag = "" | |||||
elif end_label in current_label: | |||||
if index_tag != '': | |||||
tag_list.append(whole_tag + ',' + str(i)) | |||||
whole_tag = '' | |||||
index_tag = '' | |||||
else: | |||||
continue | |||||
if (whole_tag != '') & (index_tag != ''): | |||||
tag_list.append(whole_tag) | |||||
tag_list_len = len(tag_list) | |||||
for i in range(0, tag_list_len): | |||||
if len(tag_list[i]) > 0: | |||||
tag_list[i] = tag_list[i] + ']' | |||||
insert_list = reverse_style(tag_list[i]) | |||||
stand_matrix.append(insert_list) | |||||
# print stand_matrix | |||||
return stand_matrix | |||||
def transform_YJ_to_fastNLP(span): | |||||
span = span[1:] | |||||
span_split = span.split(']') | |||||
# print('span_list:{}'.format(span_split)) | |||||
span_type = span_split[1] | |||||
# print('span_split[0].split(','):{}'.format(span_split[0].split(','))) | |||||
if ',' in span_split[0]: | |||||
b, e = span_split[0].split(',') | |||||
else: | |||||
b = span_split[0] | |||||
e = b | |||||
b = int(b) | |||||
e = int(e) | |||||
e += 1 | |||||
return (span_type, (b, e)) | |||||
yj_form = get_ner_BMESO_yj(label_list) | |||||
# print('label_list:{}'.format(label_list)) | |||||
# print('yj_from:{}'.format(yj_form)) | |||||
fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form)) | |||||
return fastNLP_form | |||||
class SpanFPreRecMetric_YJ(MetricBase): | |||||
r""" | |||||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||||
在序列标注问题中,以span的方式计算F, pre, rec. | |||||
比如中文Part of speech中,会以character的方式进行标注,句子 `中国在亚洲` 对应的POS可能为(以BMES为例) | |||||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||||
最后得到的metric结果为:: | |||||
{ | |||||
'f': xxx, # 这里使用f考虑以后可以计算f_beta值 | |||||
'pre': xxx, | |||||
'rec':xxx | |||||
} | |||||
若only_gross=False, 即还会返回各个label的metric统计值:: | |||||
{ | |||||
'f': xxx, | |||||
'pre': xxx, | |||||
'rec':xxx, | |||||
'f-label': xxx, | |||||
'pre-label': xxx, | |||||
'rec-label':xxx, | |||||
... | |||||
} | |||||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes | |||||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||||
个label | |||||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||||
label的f1, pre, rec | |||||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : | |||||
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
""" | |||||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | |||||
only_gross=True, f_type='micro', beta=1): | |||||
from fastNLP.core import Vocabulary | |||||
from fastNLP.core.metrics import _bmes_tag_to_spans,_bio_tag_to_spans,\ | |||||
_bioes_tag_to_spans,_bmeso_tag_to_spans | |||||
from collections import defaultdict | |||||
encoding_type = encoding_type.lower() | |||||
if not isinstance(tag_vocab, Vocabulary): | |||||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||||
if f_type not in ('micro', 'macro'): | |||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||||
self.encoding_type = encoding_type | |||||
# print('encoding_type:{}'self.encoding_type) | |||||
if self.encoding_type == 'bmes': | |||||
self.tag_to_span_func = _bmes_tag_to_spans | |||||
elif self.encoding_type == 'bio': | |||||
self.tag_to_span_func = _bio_tag_to_spans | |||||
elif self.encoding_type == 'bmeso': | |||||
self.tag_to_span_func = _bmeso_tag_to_spans | |||||
elif self.encoding_type == 'bioes': | |||||
self.tag_to_span_func = _bioes_tag_to_spans | |||||
elif self.encoding_type == 'bmesoyj': | |||||
self.tag_to_span_func = get_yangjie_bmeso | |||||
# self.tag_to_span_func = | |||||
else: | |||||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||||
self.ignore_labels = ignore_labels | |||||
self.f_type = f_type | |||||
self.beta = beta | |||||
self.beta_square = self.beta ** 2 | |||||
self.only_gross = only_gross | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||||
self.tag_vocab = tag_vocab | |||||
self._true_positives = defaultdict(int) | |||||
self._false_positives = defaultdict(int) | |||||
self._false_negatives = defaultdict(int) | |||||
def evaluate(self, pred, target, seq_len): | |||||
from fastNLP.core.utils import _get_func_signature | |||||
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||||
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | |||||
:param target: [batch, seq_len], 真实值 | |||||
:param seq_len: [batch] 文本长度标记 | |||||
:return: | |||||
""" | |||||
if not isinstance(pred, torch.Tensor): | |||||
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(pred)}.") | |||||
if not isinstance(target, torch.Tensor): | |||||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(target)}.") | |||||
if not isinstance(seq_len, torch.Tensor): | |||||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(seq_len)}.") | |||||
if pred.size() == target.size() and len(target.size()) == 2: | |||||
pass | |||||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||||
num_classes = pred.size(-1) | |||||
pred = pred.argmax(dim=-1) | |||||
if (target >= num_classes).any(): | |||||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||||
"id >= {}, the number of classes.".format(num_classes)) | |||||
else: | |||||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||||
f"{pred.size()[:-1]}, got {target.size()}.") | |||||
batch_size = pred.size(0) | |||||
pred = pred.tolist() | |||||
target = target.tolist() | |||||
for i in range(batch_size): | |||||
pred_tags = pred[i][:int(seq_len[i])] | |||||
gold_tags = target[i][:int(seq_len[i])] | |||||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||||
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | |||||
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | |||||
for span in pred_spans: | |||||
if span in gold_spans: | |||||
self._true_positives[span[0]] += 1 | |||||
gold_spans.remove(span) | |||||
else: | |||||
self._false_positives[span[0]] += 1 | |||||
for span in gold_spans: | |||||
self._false_negatives[span[0]] += 1 | |||||
def get_metric(self, reset=True): | |||||
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" | |||||
evaluate_result = {} | |||||
if not self.only_gross or self.f_type == 'macro': | |||||
tags = set(self._false_negatives.keys()) | |||||
tags.update(set(self._false_positives.keys())) | |||||
tags.update(set(self._true_positives.keys())) | |||||
f_sum = 0 | |||||
pre_sum = 0 | |||||
rec_sum = 0 | |||||
for tag in tags: | |||||
tp = self._true_positives[tag] | |||||
fn = self._false_negatives[tag] | |||||
fp = self._false_positives[tag] | |||||
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) | |||||
f_sum += f | |||||
pre_sum += pre | |||||
rec_sum += rec | |||||
                if not self.only_gross and tag != '':  # tag != '' guards against spans without a tag
f_key = 'f-{}'.format(tag) | |||||
pre_key = 'pre-{}'.format(tag) | |||||
rec_key = 'rec-{}'.format(tag) | |||||
evaluate_result[f_key] = f | |||||
evaluate_result[pre_key] = pre | |||||
evaluate_result[rec_key] = rec | |||||
if self.f_type == 'macro': | |||||
evaluate_result['f'] = f_sum / len(tags) | |||||
evaluate_result['pre'] = pre_sum / len(tags) | |||||
evaluate_result['rec'] = rec_sum / len(tags) | |||||
if self.f_type == 'micro': | |||||
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | |||||
sum(self._false_negatives.values()), | |||||
sum(self._false_positives.values())) | |||||
evaluate_result['f'] = f | |||||
evaluate_result['pre'] = pre | |||||
evaluate_result['rec'] = rec | |||||
if reset: | |||||
self._true_positives = defaultdict(int) | |||||
self._false_positives = defaultdict(int) | |||||
self._false_negatives = defaultdict(int) | |||||
for key, value in evaluate_result.items(): | |||||
evaluate_result[key] = round(value, 6) | |||||
return evaluate_result | |||||
def _compute_f_pre_rec(self, tp, fn, fp): | |||||
""" | |||||
:param tp: int, true positive | |||||
:param fn: int, false negative | |||||
:param fp: int, false positive | |||||
:return: (f, pre, rec) | |||||
""" | |||||
pre = tp / (fp + tp + 1e-13) | |||||
rec = tp / (fn + tp + 1e-13) | |||||
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) | |||||
return f, pre, rec | |||||
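# A hedged usage sketch (the vocabulary and tensors below are made up; in practice the label
# vocabulary comes from the data loaders and the metric is driven by fastNLP's Trainer/Tester):
#
#   from fastNLP import Vocabulary
#   _tag_vocab = Vocabulary(unknown=None, padding=None)
#   _tag_vocab.add_word_lst(['B-name', 'E-name', 'O', 'S-org'])
#   _metric = SpanFPreRecMetric_YJ(_tag_vocab, encoding_type='bmesoyj')
#   _pred = torch.tensor([[0, 1, 2, 3]])
#   _target = torch.tensor([[0, 1, 2, 3]])
#   _metric.evaluate(_pred, _target, seq_len=torch.tensor([4]))
#   _metric.get_metric()   # -> {'f': 1.0, 'pre': 1.0, 'rec': 1.0}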