@@ -0,0 +1,252 @@ | |||
import torch.nn as nn | |||
from pathes import * | |||
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner | |||
from fastNLP.embeddings import StaticEmbedding | |||
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1 | |||
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward | |||
import torch.optim as optim | |||
import argparse | |||
import torch | |||
import sys | |||
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ | |||
from fastNLP import Tester | |||
import fitlog | |||
from fastNLP.core.callback import FitlogCallback | |||
from utils import set_seed | |||
import os | |||
from fastNLP import LRScheduler | |||
from torch.optim.lr_scheduler import LambdaLR | |||
# sys.path.append('.') | |||
# sys.path.append('..') | |||
# for p in sys.path: | |||
# print(p) | |||
# fitlog.add_hyper_in_file (__file__) # record your hyperparameters | |||
########hyper
parser = argparse.ArgumentParser() | |||
parser.add_argument('--device',default='cpu')
parser.add_argument('--debug',default=True)
parser.add_argument('--batch',type=int,default=1)
parser.add_argument('--test_batch',type=int,default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',type=float,default=0.015)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)  # in the paper it's False
parser.add_argument('--hidden',type=int,default=100)
parser.add_argument('--momentum',type=float,default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=False)
parser.add_argument('--embed_dropout',type=float,default=0)
parser.add_argument('--output_dropout',type=float,default=0)
parser.add_argument('--epoch',type=int,default=100)
parser.add_argument('--seed',type=int,default=100)
args = parser.parse_args() | |||
set_seed(args.seed) | |||
fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)] | |||
if args.model == 'lattice': | |||
fit_msg_list.append(str(args.skip_before_head)) | |||
fit_msg = ' '.join(fit_msg_list) | |||
# fitlog.commit(__file__,fit_msg=fit_msg) | |||
# fitlog.add_hyper(args) | |||
device = torch.device(args.device) | |||
for k,v in args.__dict__.items(): | |||
print(k,v) | |||
refresh_data = False | |||
if args.dataset == 'ontonote': | |||
datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||
_refresh=refresh_data,index_token=False) | |||
elif args.dataset == 'resume': | |||
datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||
_refresh=refresh_data,index_token=False) | |||
# exit() | |||
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, | |||
_refresh=refresh_data) | |||
cache_name = os.path.join('cache',args.dataset+'_lattice') | |||
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, | |||
_refresh=refresh_data,_cache_fp=cache_name) | |||
print("embedding of '中': {}".format(embeddings['char'](24)))
print('embed lookup dropout:{}'.format(embeddings['word'].word_dropout)) | |||
# for k, v in datasets.items(): | |||
# # v.apply_field(lambda x: list(map(len, x)), 'skips_l2r_word', 'lexicon_count') | |||
# # v.apply_field(lambda x: | |||
# # list(map(lambda y: | |||
# # list(map(lambda z: vocabs['word'].to_index(z), y)), x)), | |||
# # 'skips_l2r_word') | |||
print(datasets['train'][0]) | |||
print('vocab info:') | |||
for k,v in vocabs.items(): | |||
print('{}:{}'.format(k,len(v))) | |||
# print(datasets['dev'][0]) | |||
# print(datasets['test'][0]) | |||
# print(datasets['train'].get_all_fields().keys()) | |||
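# The skips_* fields produced by equip_chinese_ner_with_skip are ragged per-character lists
# (for each character position: the lexicon words ending / starting there plus their positions),
# so they need LatticeLexiconPadder instead of the default padder; lexicon_count and
# lexicon_count_back record how many real lexicon entries each position has, so the model can
# mask the padded ones. skip_before_head picks between two padding conventions and is also
# passed to the lattice layer itself.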
for k,v in datasets.items(): | |||
if args.model == 'lattice': | |||
v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source') | |||
if args.skip_before_head: | |||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||
v.set_padder('skips_r2l_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True)) | |||
else: | |||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_r2l_word', LatticeLexiconPadder()) | |||
v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1)) | |||
v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1)) | |||
if args.bi: | |||
v.set_input('chars','bigrams','seq_len', | |||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||
'skips_r2l_word', 'skips_r2l_source','lexicon_count_back', | |||
'target', | |||
use_1st_ins_infer_dim_type=True) | |||
else: | |||
v.set_input('chars','bigrams','seq_len', | |||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||
'target', | |||
use_1st_ins_infer_dim_type=True) | |||
v.set_target('target','seq_len') | |||
v['target'].set_pad_val(0) | |||
elif args.model == 'lstm': | |||
v.set_ignore_type('skips_l2r_word','skips_l2r_source') | |||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||
v.set_input('chars','bigrams','seq_len','target', | |||
use_1st_ins_infer_dim_type=True) | |||
v.set_target('target','seq_len') | |||
v['target'].set_pad_val(0) | |||
print(datasets['dev']['skips_l2r_word'][100]) | |||
if args.model =='lattice': | |||
model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||
skip_batch_first=True,bidirectional=args.bi,debug=args.debug, | |||
skip_before_head=args.skip_before_head,use_bigram=args.use_bigram, | |||
vocabs=vocabs | |||
) | |||
elif args.model == 'lstm': | |||
model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||
bidirectional=args.bi, | |||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||
use_bigram=args.use_bigram) | |||
for k,v in model.state_dict().items(): | |||
print('{}:{}'.format(k,v.size())) | |||
# exit() | |||
weight_dict = torch.load(open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb')) | |||
# print(weight_dict.keys()) | |||
# for k,v in weight_dict.items(): | |||
# print('{}:{}'.format(k,v.size())) | |||
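# The two helpers below exist only for this debug script: state_dict_param builds a
# name -> Parameter dict for the model, and copy_yangjie_lattice_weight overwrites those
# parameters with the corresponding tensors from Yang Jie's original Lattice LSTM checkpoint
# (lattice_yangjie.pkl loaded above), so the two implementations can be compared with identical
# weights. The key mapping mirrors the original module names
# (lstm.forward_lstm / lstm.backward_lstm, rnn / word_rnn).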
def state_dict_param(model): | |||
param_list = list(model.named_parameters()) | |||
print(len(param_list)) | |||
param_dict = {} | |||
for i in range(len(param_list)): | |||
param_dict[param_list[i][0]] = param_list[i][1] | |||
return param_dict | |||
def copy_yangjie_lattice_weight(target,source_dict): | |||
t = state_dict_param(target) | |||
with torch.no_grad(): | |||
t['encoder.char_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.weight_ih']) | |||
t['encoder.char_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.weight_hh']) | |||
t['encoder.char_cell.alpha_weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_ih']) | |||
t['encoder.char_cell.alpha_weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_hh']) | |||
t['encoder.char_cell.bias'].set_(source_dict['lstm.forward_lstm.rnn.bias']) | |||
t['encoder.char_cell.alpha_bias'].set_(source_dict['lstm.forward_lstm.rnn.alpha_bias']) | |||
t['encoder.word_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_ih']) | |||
t['encoder.word_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_hh']) | |||
t['encoder.word_cell.bias'].set_(source_dict['lstm.forward_lstm.word_rnn.bias']) | |||
t['encoder_back.char_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.weight_ih']) | |||
t['encoder_back.char_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.weight_hh']) | |||
t['encoder_back.char_cell.alpha_weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_ih']) | |||
t['encoder_back.char_cell.alpha_weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_hh']) | |||
t['encoder_back.char_cell.bias'].set_(source_dict['lstm.backward_lstm.rnn.bias']) | |||
t['encoder_back.char_cell.alpha_bias'].set_(source_dict['lstm.backward_lstm.rnn.alpha_bias']) | |||
t['encoder_back.word_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_ih']) | |||
t['encoder_back.word_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_hh']) | |||
t['encoder_back.word_cell.bias'].set_(source_dict['lstm.backward_lstm.word_rnn.bias']) | |||
for k,v in t.items(): | |||
print('{}:{}'.format(k,v)) | |||
copy_yangjie_lattice_weight(model,weight_dict) | |||
# print(vocabs['label'].word2idx.keys()) | |||
loss = LossInForward() | |||
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso') | |||
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj') | |||
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len') | |||
metrics = [f1_metric,f1_metric_yj,acc_metric] | |||
if args.optim == 'adam': | |||
optimizer = optim.Adam(model.parameters(),lr=args.lr) | |||
elif args.optim == 'sgd': | |||
optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum) | |||
# tester = Tester(datasets['dev'],model,metrics=metrics,batch_size=args.test_batch,device=device) | |||
# test_result = tester.test() | |||
# print(test_result) | |||
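# The LambdaLR below multiplies the base learning rate by 1 / (1 + 0.05)**epoch, i.e. an
# exponential decay of roughly 5% per epoch: lr_ep = lr_0 / 1.05**ep.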
callbacks = [ | |||
LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05)**ep)) | |||
] | |||
print(datasets['train'][:2]) | |||
print(vocabs['char'].to_index(':')) | |||
# StaticEmbedding | |||
# datasets['train'] = datasets['train'][1:] | |||
from fastNLP import SequentialSampler | |||
trainer = Trainer(datasets['train'],model, | |||
optimizer=optimizer, | |||
loss=loss, | |||
metrics=metrics, | |||
dev_data=datasets['dev'], | |||
device=device, | |||
batch_size=args.batch, | |||
n_epochs=args.epoch, | |||
dev_batch_size=args.test_batch, | |||
callbacks=callbacks, | |||
check_code_level=-1, | |||
sampler=SequentialSampler()) | |||
trainer.train() |
@@ -0,0 +1,772 @@ | |||
from fastNLP.io import CSVLoader | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
import numpy as np | |||
import fitlog | |||
import pickle | |||
import os | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP import cache_results | |||
@cache_results(_cache_fp='mtl16', _refresh=False) | |||
def load_16_task(dict_path): | |||
    '''
    :param dict_path: directory containing data.pkl and word_embedding.npy
                      (e.g. /remote-home/txsun/fnlp/MTL-LT/data)
    :return: (task_list, embedding)
    '''
task_path = os.path.join(dict_path,'data.pkl') | |||
embedding_path = os.path.join(dict_path,'word_embedding.npy') | |||
embedding = np.load(embedding_path).astype(np.float32) | |||
task_list = pickle.load(open(task_path, 'rb'))['task_lst'] | |||
for t in task_list: | |||
t.train_set.rename_field('words_idx', 'words') | |||
t.dev_set.rename_field('words_idx', 'words') | |||
t.test_set.rename_field('words_idx', 'words') | |||
t.train_set.rename_field('label', 'target') | |||
t.dev_set.rename_field('label', 'target') | |||
t.test_set.rename_field('label', 'target') | |||
t.train_set.add_seq_len('words') | |||
t.dev_set.add_seq_len('words') | |||
t.test_set.add_seq_len('words') | |||
t.train_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||
t.dev_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||
t.test_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||
return task_list,embedding | |||
@cache_results(_cache_fp='SST2', _refresh=False) | |||
def load_sst2(dict_path,embedding_path=None): | |||
    '''
    :param dict_path: SST-2 directory containing train.tsv and dev.tsv
                      (e.g. /remote-home/xnli/data/corpus/text_classification/SST-2/)
    :param embedding_path: path to a 300-dimensional GloVe embedding text file
    :return: (train_data, dev_data), (vocab, label_vocab) and, if embedding_path is given,
             the loaded pretrained embedding as a third element
    '''
train_path = os.path.join(dict_path,'train.tsv') | |||
dev_path = os.path.join(dict_path,'dev.tsv') | |||
loader = CSVLoader(headers=('words', 'target'), sep='\t') | |||
train_data = loader.load(train_path).datasets['train'] | |||
dev_data = loader.load(dev_path).datasets['train'] | |||
train_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words') | |||
dev_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words') | |||
train_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len') | |||
dev_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len') | |||
vocab = Vocabulary(min_freq=2) | |||
vocab.from_dataset(train_data, field_name='words') | |||
vocab.from_dataset(dev_data, field_name='words') | |||
# pretrained_embedding = load_word_emb(embedding_path, 300, vocab) | |||
label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(train_data, field_name='target') | |||
label_vocab.index_dataset(train_data, field_name='target') | |||
label_vocab.index_dataset(dev_data, field_name='target') | |||
vocab.index_dataset(train_data, field_name='words', new_field_name='words') | |||
vocab.index_dataset(dev_data, field_name='words', new_field_name='words') | |||
train_data.set_input(Const.INPUT, Const.INPUT_LEN) | |||
train_data.set_target(Const.TARGET) | |||
dev_data.set_input(Const.INPUT, Const.INPUT_LEN) | |||
dev_data.set_target(Const.TARGET) | |||
if embedding_path is not None: | |||
pretrained_embedding = load_word_emb(embedding_path, 300, vocab) | |||
return (train_data,dev_data),(vocab,label_vocab),pretrained_embedding | |||
else: | |||
return (train_data,dev_data),(vocab,label_vocab) | |||
@cache_results(_cache_fp='OntonotesPOS', _refresh=False) | |||
def load_conllized_ontonote_POS(path,embedding_path=None): | |||
from fastNLP.io.data_loader import ConllLoader | |||
header2index = {'words':3,'POS':4,'NER':10} | |||
headers = ['words','POS'] | |||
if 'NER' in headers: | |||
        print('Warning! The NER labels read by load_conllized_ontonote are plain CoNLL tags, not BIOES, and are therefore wrong!')
indexes = list(map(lambda x:header2index[x],headers)) | |||
loader = ConllLoader(headers,indexes) | |||
bundle = loader.load(path) | |||
# print(bundle.datasets) | |||
train_set = bundle.datasets['train'] | |||
dev_set = bundle.datasets['dev'] | |||
test_set = bundle.datasets['test'] | |||
# train_set = loader.load(os.path.join(path,'train.txt')) | |||
# dev_set = loader.load(os.path.join(path, 'dev.txt')) | |||
# test_set = loader.load(os.path.join(path, 'test.txt')) | |||
# print(len(train_set)) | |||
train_set.add_seq_len('words','seq_len') | |||
dev_set.add_seq_len('words','seq_len') | |||
test_set.add_seq_len('words','seq_len') | |||
# print(dataset['POS']) | |||
vocab = Vocabulary(min_freq=1) | |||
vocab.from_dataset(train_set,field_name='words') | |||
vocab.from_dataset(dev_set, field_name='words') | |||
vocab.from_dataset(test_set, field_name='words') | |||
vocab.index_dataset(train_set,field_name='words') | |||
vocab.index_dataset(dev_set, field_name='words') | |||
vocab.index_dataset(test_set, field_name='words') | |||
label_vocab_dict = {} | |||
for i,h in enumerate(headers): | |||
if h == 'words': | |||
continue | |||
label_vocab_dict[h] = Vocabulary(min_freq=1,padding=None,unknown=None) | |||
label_vocab_dict[h].from_dataset(train_set,field_name=h) | |||
label_vocab_dict[h].index_dataset(train_set,field_name=h) | |||
label_vocab_dict[h].index_dataset(dev_set,field_name=h) | |||
label_vocab_dict[h].index_dataset(test_set,field_name=h) | |||
train_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||
train_set.set_target(headers[1]) | |||
dev_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||
dev_set.set_target(headers[1]) | |||
test_set.set_input(Const.INPUT, Const.INPUT_LEN) | |||
test_set.set_target(headers[1]) | |||
if len(headers) > 2: | |||
        print('Warning: more than one task is loaded, so the target field must be set manually each time!')
print('train:',len(train_set),'dev:',len(dev_set),'test:',len(test_set)) | |||
if embedding_path is not None: | |||
pretrained_embedding = load_word_emb(embedding_path, 300, vocab) | |||
return (train_set,dev_set,test_set),(vocab,label_vocab_dict),pretrained_embedding | |||
else: | |||
return (train_set, dev_set, test_set), (vocab, label_vocab_dict) | |||
@cache_results(_cache_fp='OntonotesNER', _refresh=False) | |||
def load_conllized_ontonote_NER(path,embedding_path=None): | |||
from fastNLP.io.pipe.conll import OntoNotesNERPipe | |||
ontoNotesNERPipe = OntoNotesNERPipe(lower=True,target_pad_val=-100) | |||
bundle_NER = ontoNotesNERPipe.process_from_file(path) | |||
train_set_NER = bundle_NER.datasets['train'] | |||
dev_set_NER = bundle_NER.datasets['dev'] | |||
test_set_NER = bundle_NER.datasets['test'] | |||
train_set_NER.add_seq_len('words','seq_len') | |||
dev_set_NER.add_seq_len('words','seq_len') | |||
test_set_NER.add_seq_len('words','seq_len') | |||
NER_vocab = bundle_NER.get_vocab('target') | |||
word_vocab = bundle_NER.get_vocab('words') | |||
if embedding_path is not None: | |||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||
dropout=0.5,lower=True) | |||
# pretrained_embedding = load_word_emb(embedding_path, 300, word_vocab) | |||
return (train_set_NER,dev_set_NER,test_set_NER),\ | |||
(word_vocab,NER_vocab),embed | |||
else: | |||
        return (train_set_NER, dev_set_NER, test_set_NER), (word_vocab, NER_vocab)
@cache_results(_cache_fp='OntonotesPOSNER', _refresh=False) | |||
def load_conllized_ontonote_NER_POS(path,embedding_path=None): | |||
from fastNLP.io.pipe.conll import OntoNotesNERPipe | |||
ontoNotesNERPipe = OntoNotesNERPipe(lower=True) | |||
bundle_NER = ontoNotesNERPipe.process_from_file(path) | |||
train_set_NER = bundle_NER.datasets['train'] | |||
dev_set_NER = bundle_NER.datasets['dev'] | |||
test_set_NER = bundle_NER.datasets['test'] | |||
NER_vocab = bundle_NER.get_vocab('target') | |||
word_vocab = bundle_NER.get_vocab('words') | |||
(train_set_POS,dev_set_POS,test_set_POS),(_,POS_vocab) = load_conllized_ontonote_POS(path) | |||
POS_vocab = POS_vocab['POS'] | |||
train_set_NER.add_field('pos',train_set_POS['POS'],is_target=True) | |||
dev_set_NER.add_field('pos', dev_set_POS['POS'], is_target=True) | |||
test_set_NER.add_field('pos', test_set_POS['POS'], is_target=True) | |||
if train_set_NER.has_field('target'): | |||
train_set_NER.rename_field('target','ner') | |||
if dev_set_NER.has_field('target'): | |||
dev_set_NER.rename_field('target','ner') | |||
if test_set_NER.has_field('target'): | |||
test_set_NER.rename_field('target','ner') | |||
if train_set_NER.has_field('pos'): | |||
train_set_NER.rename_field('pos','posid') | |||
if dev_set_NER.has_field('pos'): | |||
dev_set_NER.rename_field('pos','posid') | |||
if test_set_NER.has_field('pos'): | |||
test_set_NER.rename_field('pos','posid') | |||
if train_set_NER.has_field('ner'): | |||
train_set_NER.rename_field('ner','nerid') | |||
if dev_set_NER.has_field('ner'): | |||
dev_set_NER.rename_field('ner','nerid') | |||
if test_set_NER.has_field('ner'): | |||
test_set_NER.rename_field('ner','nerid') | |||
if embedding_path is not None: | |||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||
dropout=0.5,lower=True) | |||
return (train_set_NER,dev_set_NER,test_set_NER),\ | |||
(word_vocab,POS_vocab,NER_vocab),embed | |||
else: | |||
        return (train_set_NER, dev_set_NER, test_set_NER), (word_vocab, POS_vocab, NER_vocab)
@cache_results(_cache_fp='Ontonotes3', _refresh=True) | |||
def load_conllized_ontonote_pkl(path,embedding_path=None): | |||
data_bundle = pickle.load(open(path,'rb')) | |||
train_set = data_bundle.datasets['train'] | |||
dev_set = data_bundle.datasets['dev'] | |||
test_set = data_bundle.datasets['test'] | |||
train_set.rename_field('pos','posid') | |||
train_set.rename_field('ner','nerid') | |||
train_set.rename_field('chunk','chunkid') | |||
dev_set.rename_field('pos','posid') | |||
dev_set.rename_field('ner','nerid') | |||
dev_set.rename_field('chunk','chunkid') | |||
test_set.rename_field('pos','posid') | |||
test_set.rename_field('ner','nerid') | |||
test_set.rename_field('chunk','chunkid') | |||
word_vocab = data_bundle.vocabs['words'] | |||
pos_vocab = data_bundle.vocabs['pos'] | |||
ner_vocab = data_bundle.vocabs['ner'] | |||
chunk_vocab = data_bundle.vocabs['chunk'] | |||
if embedding_path is not None: | |||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||
dropout=0.5,lower=True) | |||
return (train_set,dev_set,test_set),\ | |||
(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed | |||
else: | |||
        return (train_set, dev_set, test_set), (word_vocab, pos_vocab, ner_vocab, chunk_vocab)
# print(data_bundle) | |||
# @cache_results(_cache_fp='Conll2003', _refresh=False) | |||
# def load_conll_2003(path,embedding_path=None): | |||
# f = open(path, 'rb') | |||
# data_pkl = pickle.load(f) | |||
# | |||
# task_lst = data_pkl['task_lst'] | |||
# vocabs = data_pkl['vocabs'] | |||
# # word_vocab = vocabs['words'] | |||
# # pos_vocab = vocabs['pos'] | |||
# # chunk_vocab = vocabs['chunk'] | |||
# # ner_vocab = vocabs['ner'] | |||
# | |||
# if embedding_path is not None: | |||
# embed = StaticEmbedding(vocab=vocabs['words'], model_dir_or_name=embedding_path, word_dropout=0.01, | |||
# dropout=0.5) | |||
# return task_lst,vocabs,embed | |||
# else: | |||
# return task_lst,vocabs | |||
# @cache_results(_cache_fp='Conll2003_mine', _refresh=False) | |||
@cache_results(_cache_fp='Conll2003_mine_embed_100', _refresh=True) | |||
def load_conll_2003_mine(path,embedding_path=None,pad_val=-100): | |||
    with open(path, 'rb') as f:
        data_pkl = pickle.load(f)
    # print(data_pkl)
train_set = data_pkl[0]['train'] | |||
dev_set = data_pkl[0]['dev'] | |||
test_set = data_pkl[0]['test'] | |||
train_set.set_pad_val('posid',pad_val) | |||
train_set.set_pad_val('nerid', pad_val) | |||
train_set.set_pad_val('chunkid', pad_val) | |||
dev_set.set_pad_val('posid',pad_val) | |||
dev_set.set_pad_val('nerid', pad_val) | |||
dev_set.set_pad_val('chunkid', pad_val) | |||
test_set.set_pad_val('posid',pad_val) | |||
test_set.set_pad_val('nerid', pad_val) | |||
test_set.set_pad_val('chunkid', pad_val) | |||
if train_set.has_field('task_id'): | |||
train_set.delete_field('task_id') | |||
if dev_set.has_field('task_id'): | |||
dev_set.delete_field('task_id') | |||
if test_set.has_field('task_id'): | |||
test_set.delete_field('task_id') | |||
if train_set.has_field('words_idx'): | |||
train_set.rename_field('words_idx','words') | |||
if dev_set.has_field('words_idx'): | |||
dev_set.rename_field('words_idx','words') | |||
if test_set.has_field('words_idx'): | |||
test_set.rename_field('words_idx','words') | |||
word_vocab = data_pkl[1]['words'] | |||
pos_vocab = data_pkl[1]['pos'] | |||
ner_vocab = data_pkl[1]['ner'] | |||
chunk_vocab = data_pkl[1]['chunk'] | |||
if embedding_path is not None: | |||
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01, | |||
dropout=0.5,lower=True) | |||
return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed | |||
else: | |||
return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab) | |||
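# The pickle consumed above is expected (as read from the indexing above) to be a 2-tuple:
# data_pkl[0] is a dict of DataSets ('train'/'dev'/'test') and data_pkl[1] a dict of
# Vocabularies ('words'/'pos'/'ner'/'chunk'). Hypothetical usage, with a placeholder path:
# (train, dev, test), (word_v, pos_v, ner_v, chunk_v) = load_conll_2003_mine('/path/to/conll2003.pkl')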
def load_conllized_ontonote_pkl_yf(path): | |||
def init_task(task): | |||
task_name = task.task_name | |||
for ds in [task.train_set, task.dev_set, task.test_set]: | |||
if ds.has_field('words'): | |||
ds.rename_field('words', 'x') | |||
else: | |||
ds.rename_field('words_idx', 'x') | |||
if ds.has_field('label'): | |||
ds.rename_field('label', 'y') | |||
else: | |||
ds.rename_field(task_name, 'y') | |||
ds.set_input('x', 'y', 'task_id') | |||
ds.set_target('y') | |||
if task_name in ['ner', 'chunk'] or 'pos' in task_name: | |||
ds.set_input('seq_len') | |||
ds.set_target('seq_len') | |||
return task | |||
#/remote-home/yfshao/workdir/datasets/conll03/data.pkl | |||
def pload(fn): | |||
with open(fn, 'rb') as f: | |||
return pickle.load(f) | |||
DB = pload(path) | |||
task_lst = DB['task_lst'] | |||
vocabs = DB['vocabs'] | |||
task_lst = [init_task(task) for task in task_lst] | |||
return task_lst, vocabs | |||
@cache_results(_cache_fp='weiboNER uni+bi', _refresh=False) | |||
def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
                   normalize={'char':True,'bigram':True,'word':False}):
from fastNLP.io.data_loader import ConllLoader | |||
from utils import get_bigrams | |||
loader = ConllLoader(['chars','target']) | |||
bundle = loader.load(path) | |||
datasets = bundle.datasets | |||
for k,v in datasets.items(): | |||
print('{}:{}'.format(k,len(v))) | |||
# print(*list(datasets.keys())) | |||
vocabs = {} | |||
word_vocab = Vocabulary() | |||
bigram_vocab = Vocabulary() | |||
label_vocab = Vocabulary(padding=None,unknown=None) | |||
for k,v in datasets.items(): | |||
# ignore the word segmentation tag | |||
v.apply_field(lambda x: [w[0] for w in x],'chars','chars') | |||
v.apply_field(get_bigrams,'chars','bigrams') | |||
word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||
label_vocab.from_dataset(datasets['train'],field_name='target') | |||
print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word)) | |||
for k,v in datasets.items(): | |||
# v.set_pad_val('target',-100) | |||
v.add_seq_len('chars',new_field_name='seq_len') | |||
vocabs['char'] = word_vocab | |||
vocabs['label'] = label_vocab | |||
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||
if index_token: | |||
        word_vocab.index_dataset(*list(datasets.values()), field_name='chars', new_field_name='chars')
        bigram_vocab.index_dataset(*list(datasets.values()), field_name='bigrams', new_field_name='bigrams')
        label_vocab.index_dataset(*list(datasets.values()), field_name='target', new_field_name='target')
# for k,v in datasets.items(): | |||
# v.set_input('chars','bigrams','seq_len','target') | |||
# v.set_target('target','seq_len') | |||
vocabs['bigram'] = bigram_vocab | |||
embeddings = {} | |||
if unigram_embedding_path is not None: | |||
        unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path,
                                            word_dropout=0.01, normalize=normalize['char'])
embeddings['char'] = unigram_embedding | |||
if bigram_embedding_path is not None: | |||
        bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path,
                                           word_dropout=0.01, normalize=normalize['bigram'])
embeddings['bigram'] = bigram_embedding | |||
return datasets, vocabs, embeddings | |||
# datasets,vocabs = load_weibo_ner('/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo') | |||
# | |||
# print(datasets['train'][:5]) | |||
# print(vocabs['word'].idx2word) | |||
# print(vocabs['target'].idx2word) | |||
@cache_results(_cache_fp='cache/ontonotes4ner',_refresh=False) | |||
def load_ontonotes4ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True, | |||
normalize={'char':True,'bigram':True,'word':False}): | |||
from fastNLP.io.data_loader import ConllLoader | |||
from utils import get_bigrams | |||
train_path = os.path.join(path,'train.char.bmes') | |||
dev_path = os.path.join(path,'dev.char.bmes') | |||
test_path = os.path.join(path,'test.char.bmes') | |||
loader = ConllLoader(['chars','target']) | |||
train_bundle = loader.load(train_path) | |||
dev_bundle = loader.load(dev_path) | |||
test_bundle = loader.load(test_path) | |||
datasets = dict() | |||
datasets['train'] = train_bundle.datasets['train'] | |||
datasets['dev'] = dev_bundle.datasets['train'] | |||
datasets['test'] = test_bundle.datasets['train'] | |||
datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams') | |||
datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||
datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||
datasets['train'].add_seq_len('chars') | |||
datasets['dev'].add_seq_len('chars') | |||
datasets['test'].add_seq_len('chars') | |||
char_vocab = Vocabulary() | |||
bigram_vocab = Vocabulary() | |||
label_vocab = Vocabulary(padding=None,unknown=None) | |||
print(datasets.keys()) | |||
print(len(datasets['dev'])) | |||
print(len(datasets['test'])) | |||
print(len(datasets['train'])) | |||
char_vocab.from_dataset(datasets['train'],field_name='chars', | |||
no_create_entry_dataset=[datasets['dev'],datasets['test']] ) | |||
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams', | |||
no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||
label_vocab.from_dataset(datasets['train'],field_name='target') | |||
if index_token: | |||
char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||
field_name='chars',new_field_name='chars') | |||
bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||
field_name='bigrams',new_field_name='bigrams') | |||
label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||
field_name='target',new_field_name='target') | |||
    vocabs = {}
    vocabs['char'] = char_vocab
    vocabs['label'] = label_vocab
    vocabs['bigram'] = bigram_vocab
embeddings = {} | |||
if char_embedding_path is not None: | |||
char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01, | |||
normalize=normalize['char']) | |||
embeddings['char'] = char_embedding | |||
if bigram_embedding_path is not None: | |||
bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01, | |||
normalize=normalize['bigram']) | |||
embeddings['bigram'] = bigram_embedding | |||
return datasets,vocabs,embeddings | |||
@cache_results(_cache_fp='cache/resume_ner',_refresh=False) | |||
def load_resume_ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True, | |||
normalize={'char':True,'bigram':True,'word':False}): | |||
from fastNLP.io.data_loader import ConllLoader | |||
from utils import get_bigrams | |||
train_path = os.path.join(path,'train.char.bmes') | |||
dev_path = os.path.join(path,'dev.char.bmes') | |||
test_path = os.path.join(path,'test.char.bmes') | |||
loader = ConllLoader(['chars','target']) | |||
train_bundle = loader.load(train_path) | |||
dev_bundle = loader.load(dev_path) | |||
test_bundle = loader.load(test_path) | |||
datasets = dict() | |||
datasets['train'] = train_bundle.datasets['train'] | |||
datasets['dev'] = dev_bundle.datasets['train'] | |||
datasets['test'] = test_bundle.datasets['train'] | |||
datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams') | |||
datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||
datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams') | |||
datasets['train'].add_seq_len('chars') | |||
datasets['dev'].add_seq_len('chars') | |||
datasets['test'].add_seq_len('chars') | |||
char_vocab = Vocabulary() | |||
bigram_vocab = Vocabulary() | |||
label_vocab = Vocabulary(padding=None,unknown=None) | |||
print(datasets.keys()) | |||
print(len(datasets['dev'])) | |||
print(len(datasets['test'])) | |||
print(len(datasets['train'])) | |||
char_vocab.from_dataset(datasets['train'],field_name='chars', | |||
no_create_entry_dataset=[datasets['dev'],datasets['test']] ) | |||
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams', | |||
no_create_entry_dataset=[datasets['dev'],datasets['test']]) | |||
label_vocab.from_dataset(datasets['train'],field_name='target') | |||
if index_token: | |||
char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||
field_name='chars',new_field_name='chars') | |||
bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||
field_name='bigrams',new_field_name='bigrams') | |||
label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'], | |||
field_name='target',new_field_name='target') | |||
vocabs = {} | |||
vocabs['char'] = char_vocab | |||
vocabs['label'] = label_vocab | |||
vocabs['bigram'] = bigram_vocab | |||
embeddings = {} | |||
if char_embedding_path is not None: | |||
char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01,normalize=normalize['char']) | |||
embeddings['char'] = char_embedding | |||
if bigram_embedding_path is not None: | |||
bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01,normalize=normalize['bigram']) | |||
embeddings['bigram'] = bigram_embedding | |||
return datasets,vocabs,embeddings | |||
@cache_results(_cache_fp='need_to_defined_fp',_refresh=False) | |||
def equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,word_embedding_path=None, | |||
normalize={'char':True,'bigram':True,'word':False}): | |||
from utils_ import Trie,get_skip_path | |||
from functools import partial | |||
w_trie = Trie() | |||
for w in w_list: | |||
w_trie.insert(w) | |||
# for k,v in datasets.items(): | |||
# v.apply_field(partial(get_skip_path,w_trie=w_trie),'chars','skips') | |||
def skips2skips_l2r(chars,w_trie): | |||
        '''
        Group the matched lexicon words by the position of their LAST character, so that at
        position e the model can see every word ending there.
        :param chars: the character list of one sentence
        :param w_trie: Trie built from the external word list
        :return: skips_l2r, a list (one entry per character) of [start, word] pairs
        '''
# print(lexicons) | |||
# print('******') | |||
lexicons = get_skip_path(chars,w_trie=w_trie) | |||
# max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0 | |||
result = [[] for _ in range(len(chars))] | |||
for lex in lexicons: | |||
s = lex[0] | |||
e = lex[1] | |||
w = lex[2] | |||
result[e].append([s,w]) | |||
return result | |||
def skips2skips_r2l(chars,w_trie): | |||
        '''
        Group the matched lexicon words by the position of their FIRST character, i.e. the
        right-to-left counterpart of skips2skips_l2r.
        :param chars: the character list of one sentence
        :param w_trie: Trie built from the external word list
        :return: skips_r2l, a list (one entry per character) of [end, word] pairs
        '''
# print(lexicons) | |||
# print('******') | |||
lexicons = get_skip_path(chars,w_trie=w_trie) | |||
# max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0 | |||
result = [[] for _ in range(len(chars))] | |||
for lex in lexicons: | |||
s = lex[0] | |||
e = lex[1] | |||
w = lex[2] | |||
result[s].append([e,w]) | |||
return result | |||
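    # A worked example of what the two helpers above produce (hypothetical lexicon, for
    # illustration only): for the sentence 南京市长江大桥 with matched words
    # 南京 (0,1), 南京市 (0,2), 长江 (3,4) and 大桥 (5,6),
    # skips_l2r groups words by their END position:
    #   [[], [[0,'南京']], [[0,'南京市']], [], [[3,'长江']], [], [[5,'大桥']]]
    # while skips_r2l groups them by their START position:
    #   [[[1,'南京'],[2,'南京市']], [], [], [[4,'长江']], [], [[6,'大桥']], []]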
for k,v in datasets.items(): | |||
v.apply_field(partial(skips2skips_l2r,w_trie=w_trie),'chars','skips_l2r') | |||
for k,v in datasets.items(): | |||
v.apply_field(partial(skips2skips_r2l,w_trie=w_trie),'chars','skips_r2l') | |||
# print(v['skips_l2r'][0]) | |||
word_vocab = Vocabulary() | |||
word_vocab.add_word_lst(w_list) | |||
vocabs['word'] = word_vocab | |||
for k,v in datasets.items(): | |||
v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_l2r','skips_l2r_source') | |||
v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_l2r', 'skips_l2r_word') | |||
for k,v in datasets.items(): | |||
v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_r2l','skips_r2l_source') | |||
v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_r2l', 'skips_r2l_word') | |||
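    # Next, map the matched words to indices in the word vocabulary and record, per character
    # position, how many lexicon entries exist (lexicon_count / lexicon_count_back); these counts
    # are later used to mask the entries added by LatticeLexiconPadder.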
for k,v in datasets.items(): | |||
v.apply_field(lambda x:list(map(len,x)), 'skips_l2r_word', 'lexicon_count') | |||
v.apply_field(lambda x: | |||
list(map(lambda y: | |||
list(map(lambda z:word_vocab.to_index(z),y)),x)), | |||
'skips_l2r_word',new_field_name='skips_l2r_word') | |||
v.apply_field(lambda x:list(map(len,x)), 'skips_r2l_word', 'lexicon_count_back') | |||
v.apply_field(lambda x: | |||
list(map(lambda y: | |||
list(map(lambda z:word_vocab.to_index(z),y)),x)), | |||
'skips_r2l_word',new_field_name='skips_r2l_word') | |||
if word_embedding_path is not None: | |||
word_embedding = StaticEmbedding(word_vocab,word_embedding_path,word_dropout=0,normalize=normalize['word']) | |||
embeddings['word'] = word_embedding | |||
vocabs['char'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], | |||
field_name='chars', new_field_name='chars') | |||
vocabs['bigram'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], | |||
field_name='bigrams', new_field_name='bigrams') | |||
vocabs['label'].index_dataset(datasets['train'], datasets['dev'], datasets['test'], | |||
field_name='target', new_field_name='target') | |||
return datasets,vocabs,embeddings | |||
@cache_results(_cache_fp='cache/load_yangjie_rich_pretrain_word_list',_refresh=False) | |||
def load_yangjie_rich_pretrain_word_list(embedding_path,drop_characters=True): | |||
    with open(embedding_path,'r') as f:
        lines = f.readlines()
    w_list = []
    for line in lines:
        splitted = line.strip().split(' ')
        w = splitted[0]
        w_list.append(w)
if drop_characters: | |||
w_list = list(filter(lambda x:len(x) != 1, w_list)) | |||
return w_list | |||
# from pathes import * | |||
# | |||
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path, | |||
# yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path) | |||
# print(datasets.keys()) | |||
# print(vocabs.keys()) | |||
# print(embeddings) | |||
# yangjie_rich_pretrain_word_path | |||
# datasets['train'].set_pad_val |
@@ -0,0 +1,189 @@ | |||
import torch.nn as nn | |||
# from pathes import * | |||
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner,load_weibo_ner | |||
from fastNLP.embeddings import StaticEmbedding | |||
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1 | |||
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward | |||
import torch.optim as optim | |||
import argparse | |||
import torch | |||
import sys | |||
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ | |||
from fastNLP import Tester | |||
import fitlog | |||
from fastNLP.core.callback import FitlogCallback | |||
from utils import set_seed | |||
import os | |||
from fastNLP import LRScheduler | |||
from torch.optim.lr_scheduler import LambdaLR | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument('--device',default='cuda:4')
parser.add_argument('--debug',default=False)
parser.add_argument('--norm_embed',default=True)
parser.add_argument('--batch',type=int,default=10)
parser.add_argument('--test_batch',type=int,default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',type=float,default=0.045)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)  # in the paper it's False
parser.add_argument('--hidden',type=int,default=100)
parser.add_argument('--momentum',type=float,default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=True)
parser.add_argument('--embed_dropout',type=float,default=0.5)
parser.add_argument('--output_dropout',type=float,default=0.5)
parser.add_argument('--epoch',type=int,default=100)
parser.add_argument('--seed',type=int,default=100)
args = parser.parse_args() | |||
set_seed(args.seed) | |||
fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)] | |||
if args.model == 'lattice': | |||
fit_msg_list.append(str(args.skip_before_head)) | |||
fit_msg = ' '.join(fit_msg_list) | |||
fitlog.commit(__file__,fit_msg=fit_msg) | |||
fitlog.add_hyper(args) | |||
device = torch.device(args.device) | |||
for k,v in args.__dict__.items(): | |||
print(k,v) | |||
refresh_data = False | |||
from pathes import * | |||
# ontonote4ner_cn_path = 0 | |||
# yangjie_rich_pretrain_unigram_path = 0 | |||
# yangjie_rich_pretrain_bigram_path = 0 | |||
# resume_ner_path = 0 | |||
# weibo_ner_path = 0 | |||
if args.dataset == 'ontonote': | |||
datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||
_refresh=refresh_data,index_token=False, | |||
) | |||
elif args.dataset == 'resume': | |||
datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||
_refresh=refresh_data,index_token=False, | |||
) | |||
elif args.dataset == 'weibo': | |||
datasets,vocabs,embeddings = load_weibo_ner(weibo_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||
_refresh=refresh_data,index_token=False, | |||
) | |||
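# Per-dataset hyper-parameters are overridden below (batch size / lr for OntoNotes and Resume,
# dropout for Weibo). Note that this happens after fitlog.add_hyper(args) above, so the values
# logged by fitlog are the command-line ones, not these overrides.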
if args.dataset == 'ontonote': | |||
args.batch = 10 | |||
args.lr = 0.045 | |||
elif args.dataset == 'resume': | |||
args.batch = 1 | |||
args.lr = 0.015 | |||
elif args.dataset == 'weibo': | |||
args.embed_dropout = 0.1 | |||
args.output_dropout = 0.1 | |||
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, | |||
_refresh=refresh_data) | |||
cache_name = os.path.join('cache',args.dataset+'_lattice') | |||
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, | |||
_refresh=refresh_data,_cache_fp=cache_name) | |||
print(datasets['train'][0]) | |||
print('vocab info:') | |||
for k,v in vocabs.items(): | |||
print('{}:{}'.format(k,len(v))) | |||
for k,v in datasets.items(): | |||
if args.model == 'lattice': | |||
v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source') | |||
if args.skip_before_head: | |||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||
v.set_padder('skips_r2l_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True)) | |||
else: | |||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_r2l_word', LatticeLexiconPadder()) | |||
v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1)) | |||
v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1)) | |||
if args.bi: | |||
v.set_input('chars','bigrams','seq_len', | |||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||
'skips_r2l_word', 'skips_r2l_source','lexicon_count_back', | |||
'target', | |||
use_1st_ins_infer_dim_type=True) | |||
else: | |||
v.set_input('chars','bigrams','seq_len', | |||
'skips_l2r_word','skips_l2r_source','lexicon_count', | |||
'target', | |||
use_1st_ins_infer_dim_type=True) | |||
v.set_target('target','seq_len') | |||
v['target'].set_pad_val(0) | |||
elif args.model == 'lstm': | |||
v.set_ignore_type('skips_l2r_word','skips_l2r_source') | |||
v.set_padder('skips_l2r_word',LatticeLexiconPadder()) | |||
v.set_padder('skips_l2r_source',LatticeLexiconPadder()) | |||
v.set_input('chars','bigrams','seq_len','target', | |||
use_1st_ins_infer_dim_type=True) | |||
v.set_target('target','seq_len') | |||
v['target'].set_pad_val(0) | |||
print(datasets['dev']['skips_l2r_word'][100]) | |||
if args.model =='lattice': | |||
model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||
skip_batch_first=True,bidirectional=args.bi,debug=args.debug, | |||
skip_before_head=args.skip_before_head,use_bigram=args.use_bigram | |||
) | |||
elif args.model == 'lstm': | |||
model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'], | |||
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device, | |||
bidirectional=args.bi, | |||
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout, | |||
use_bigram=args.use_bigram) | |||
loss = LossInForward() | |||
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso') | |||
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj') | |||
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len') | |||
metrics = [f1_metric,f1_metric_yj,acc_metric] | |||
if args.optim == 'adam': | |||
optimizer = optim.Adam(model.parameters(),lr=args.lr) | |||
elif args.optim == 'sgd': | |||
optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum) | |||
callbacks = [ | |||
FitlogCallback({'test':datasets['test'],'train':datasets['train']}), | |||
LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep)) | |||
] | |||
trainer = Trainer(datasets['train'],model, | |||
optimizer=optimizer, | |||
loss=loss, | |||
metrics=metrics, | |||
dev_data=datasets['dev'], | |||
device=device, | |||
batch_size=args.batch, | |||
n_epochs=args.epoch, | |||
dev_batch_size=args.test_batch, | |||
callbacks=callbacks) | |||
trainer.train() |
@@ -0,0 +1,299 @@ | |||
import torch.nn as nn | |||
from fastNLP.embeddings import StaticEmbedding | |||
from fastNLP.modules import LSTM, ConditionalRandomField | |||
import torch | |||
from fastNLP import seq_len_to_mask | |||
from utils import better_init_rnn | |||
class LatticeLSTM_SeqLabel(nn.Module): | |||
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False, | |||
device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False, | |||
skip_before_head=False,use_bigram=True,vocabs=None): | |||
        super().__init__()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)
        from modules import LatticeLSTMLayer_sup_back_V0
self.debug = debug | |||
self.skip_batch_first = skip_batch_first | |||
self.char_embed_size = char_embed.embedding.weight.size(1) | |||
self.bigram_embed_size = bigram_embed.embedding.weight.size(1) | |||
self.word_embed_size = word_embed.embedding.weight.size(1) | |||
self.hidden_size = hidden_size | |||
self.label_size = label_size | |||
self.bidirectional = bidirectional | |||
self.use_bigram = use_bigram | |||
self.vocabs = vocabs | |||
if self.use_bigram: | |||
self.input_size = self.char_embed_size + self.bigram_embed_size | |||
else: | |||
self.input_size = self.char_embed_size | |||
self.char_embed = char_embed | |||
self.bigram_embed = bigram_embed | |||
self.word_embed = word_embed | |||
self.encoder = LatticeLSTMLayer_sup_back_V0(self.input_size,self.word_embed_size, | |||
self.hidden_size, | |||
left2right=True, | |||
bias=bias, | |||
device=self.device, | |||
debug=self.debug, | |||
skip_before_head=skip_before_head) | |||
if self.bidirectional: | |||
self.encoder_back = LatticeLSTMLayer_sup_back_V0(self.input_size, | |||
self.word_embed_size, self.hidden_size, | |||
left2right=False, | |||
bias=bias, | |||
device=self.device, | |||
debug=self.debug, | |||
skip_before_head=skip_before_head) | |||
self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) | |||
self.crf = ConditionalRandomField(label_size, True) | |||
self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True)) | |||
if self.crf.include_start_end_trans: | |||
self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||
self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||
self.loss_func = nn.CrossEntropyLoss() | |||
self.embed_dropout = nn.Dropout(embed_dropout) | |||
self.output_dropout = nn.Dropout(output_dropout) | |||
def forward(self, chars, bigrams, seq_len, target, | |||
skips_l2r_source, skips_l2r_word, lexicon_count, | |||
skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None): | |||
# print('skips_l2r_word_id:{}'.format(skips_l2r_word.size())) | |||
batch = chars.size(0) | |||
max_seq_len = chars.size(1) | |||
# max_lexicon_count = skips_l2r_word.size(2) | |||
embed_char = self.char_embed(chars) | |||
if self.use_bigram: | |||
embed_bigram = self.bigram_embed(bigrams) | |||
embedding = torch.cat([embed_char, embed_bigram], dim=-1) | |||
else: | |||
embedding = embed_char | |||
embed_nonword = self.embed_dropout(embedding) | |||
# skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1]) | |||
embed_word = self.word_embed(skips_l2r_word) | |||
embed_word = self.embed_dropout(embed_word) | |||
# embed_word = torch.reshape(embed_word,shape=[batch,max_seq_len,max_lexicon_count,-1]) | |||
encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count) | |||
if self.bidirectional: | |||
embed_word_back = self.word_embed(skips_r2l_word) | |||
embed_word_back = self.embed_dropout(embed_word_back) | |||
encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source, | |||
embed_word_back, lexicon_count_back) | |||
encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1) | |||
encoded_h = self.output_dropout(encoded_h) | |||
pred = self.output(encoded_h) | |||
mask = seq_len_to_mask(seq_len) | |||
if self.training: | |||
loss = self.crf(pred, target, mask) | |||
return {'loss': loss} | |||
else: | |||
pred, path = self.crf.viterbi_decode(pred, mask) | |||
return {'pred': pred} | |||
# batch_size, sent_len = pred.shape[0], pred.shape[1] | |||
# loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len)) | |||
# return {'pred':pred,'loss':loss} | |||
class LatticeLSTM_SeqLabel_V1(nn.Module): | |||
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False, | |||
device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False, | |||
skip_before_head=False,use_bigram=True,vocabs=None): | |||
        super().__init__()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)
        from modules import LatticeLSTMLayer_sup_back_V1
self.count = 0 | |||
self.debug = debug | |||
self.skip_batch_first = skip_batch_first | |||
self.char_embed_size = char_embed.embedding.weight.size(1) | |||
self.bigram_embed_size = bigram_embed.embedding.weight.size(1) | |||
self.word_embed_size = word_embed.embedding.weight.size(1) | |||
self.hidden_size = hidden_size | |||
self.label_size = label_size | |||
self.bidirectional = bidirectional | |||
self.use_bigram = use_bigram | |||
self.vocabs = vocabs | |||
if self.use_bigram: | |||
self.input_size = self.char_embed_size + self.bigram_embed_size | |||
else: | |||
self.input_size = self.char_embed_size | |||
self.char_embed = char_embed | |||
self.bigram_embed = bigram_embed | |||
self.word_embed = word_embed | |||
self.encoder = LatticeLSTMLayer_sup_back_V1(self.input_size,self.word_embed_size, | |||
self.hidden_size, | |||
left2right=True, | |||
bias=bias, | |||
device=self.device, | |||
debug=self.debug, | |||
skip_before_head=skip_before_head) | |||
if self.bidirectional: | |||
self.encoder_back = LatticeLSTMLayer_sup_back_V1(self.input_size, | |||
self.word_embed_size, self.hidden_size, | |||
left2right=False, | |||
bias=bias, | |||
device=self.device, | |||
debug=self.debug, | |||
skip_before_head=skip_before_head) | |||
self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) | |||
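        # The CRF's transition / start / end scores are re-initialized to zero below, presumably
        # to mirror the initialization of the original Lattice LSTM implementation rather than
        # fastNLP's default random initialization (an assumption, not stated in the code).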
self.crf = ConditionalRandomField(label_size, True) | |||
self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True)) | |||
if self.crf.include_start_end_trans: | |||
self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||
self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True)) | |||
self.loss_func = nn.CrossEntropyLoss() | |||
self.embed_dropout = nn.Dropout(embed_dropout) | |||
self.output_dropout = nn.Dropout(output_dropout) | |||
def forward(self, chars, bigrams, seq_len, target, | |||
skips_l2r_source, skips_l2r_word, lexicon_count, | |||
skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None): | |||
batch = chars.size(0) | |||
max_seq_len = chars.size(1) | |||
embed_char = self.char_embed(chars) | |||
if self.use_bigram: | |||
embed_bigram = self.bigram_embed(bigrams) | |||
embedding = torch.cat([embed_char, embed_bigram], dim=-1) | |||
else: | |||
embedding = embed_char | |||
embed_nonword = self.embed_dropout(embedding) | |||
# skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1]) | |||
embed_word = self.word_embed(skips_l2r_word) | |||
embed_word = self.embed_dropout(embed_word) | |||
encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count) | |||
if self.bidirectional: | |||
embed_word_back = self.word_embed(skips_r2l_word) | |||
embed_word_back = self.embed_dropout(embed_word_back) | |||
encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source, | |||
embed_word_back, lexicon_count_back) | |||
encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1) | |||
encoded_h = self.output_dropout(encoded_h) | |||
pred = self.output(encoded_h) | |||
mask = seq_len_to_mask(seq_len) | |||
if self.training: | |||
loss = self.crf(pred, target, mask) | |||
return {'loss': loss} | |||
else: | |||
pred, path = self.crf.viterbi_decode(pred, mask) | |||
return {'pred': pred} | |||
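# Shape sketch for LatticeLSTM_SeqLabel_V1.forward, as read from the code above (not a formal spec):
#   chars, bigrams, target              : LongTensor [batch, max_seq_len]
#   seq_len                             : LongTensor [batch]
#   skips_l2r_word / skips_r2l_word     : LongTensor [batch, max_seq_len, max_lexicon_count] (word ids)
#   skips_l2r_source / skips_r2l_source : LongTensor [batch, max_seq_len, max_lexicon_count] (positions)
#   lexicon_count / lexicon_count_back  : LongTensor [batch, max_seq_len]
# In training mode the CRF returns {'loss': ...}; in eval mode Viterbi decoding returns {'pred': ...}.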
class LSTM_SeqLabel(nn.Module): | |||
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, | |||
bidirectional=False, device=None, embed_dropout=0, output_dropout=0,use_bigram=True): | |||
        super().__init__()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)
self.char_embed_size = char_embed.embedding.weight.size(1) | |||
self.bigram_embed_size = bigram_embed.embedding.weight.size(1) | |||
self.word_embed_size = word_embed.embedding.weight.size(1) | |||
self.hidden_size = hidden_size | |||
self.label_size = label_size | |||
self.bidirectional = bidirectional | |||
self.use_bigram = use_bigram | |||
self.char_embed = char_embed | |||
self.bigram_embed = bigram_embed | |||
self.word_embed = word_embed | |||
if self.use_bigram: | |||
self.input_size = self.char_embed_size + self.bigram_embed_size | |||
else: | |||
self.input_size = self.char_embed_size | |||
self.encoder = LSTM(self.input_size, self.hidden_size, | |||
bidirectional=self.bidirectional) | |||
better_init_rnn(self.encoder.lstm) | |||
self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size) | |||
self.debug = False | |||
self.loss_func = nn.CrossEntropyLoss() | |||
self.embed_dropout = nn.Dropout(embed_dropout) | |||
self.output_dropout = nn.Dropout(output_dropout) | |||
self.crf = ConditionalRandomField(label_size, True) | |||
def forward(self, chars, bigrams, seq_len, target): | |||
embed_char = self.char_embed(chars) | |||
if self.use_bigram: | |||
embed_bigram = self.bigram_embed(bigrams) | |||
embedding = torch.cat([embed_char, embed_bigram], dim=-1) | |||
else: | |||
embedding = embed_char | |||
embedding = self.embed_dropout(embedding) | |||
encoded_h, encoded_c = self.encoder(embedding, seq_len) | |||
encoded_h = self.output_dropout(encoded_h) | |||
pred = self.output(encoded_h) | |||
mask = seq_len_to_mask(seq_len) | |||
# pred = self.crf(pred) | |||
# batch_size, sent_len = pred.shape[0], pred.shape[1] | |||
# loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len)) | |||
if self.training: | |||
loss = self.crf(pred, target, mask) | |||
return {'loss': loss} | |||
else: | |||
pred, path = self.crf.viterbi_decode(pred, mask) | |||
return {'pred': pred} |
@@ -0,0 +1,638 @@ | |||
import torch.nn as nn | |||
import torch | |||
from fastNLP.core.utils import seq_len_to_mask | |||
from utils import better_init_rnn | |||
import numpy as np | |||
class WordLSTMCell_yangjie(nn.Module): | |||
"""A basic LSTM cell.""" | |||
def __init__(self, input_size, hidden_size, use_bias=True,debug=False, left2right=True): | |||
""" | |||
Most parts are copied from torch.nn.LSTMCell. | |||
""" | |||
super().__init__() | |||
self.left2right = left2right | |||
self.debug = debug | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.use_bias = use_bias | |||
self.weight_ih = nn.Parameter( | |||
torch.FloatTensor(input_size, 3 * hidden_size)) | |||
self.weight_hh = nn.Parameter( | |||
torch.FloatTensor(hidden_size, 3 * hidden_size)) | |||
if use_bias: | |||
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) | |||
else: | |||
self.register_parameter('bias', None) | |||
self.reset_parameters() | |||
def reset_parameters(self): | |||
""" | |||
Initialize parameters following the way proposed in the paper. | |||
""" | |||
        nn.init.orthogonal_(self.weight_ih.data)
weight_hh_data = torch.eye(self.hidden_size) | |||
weight_hh_data = weight_hh_data.repeat(1, 3) | |||
with torch.no_grad(): | |||
self.weight_hh.set_(weight_hh_data) | |||
# The bias is just set to zero vectors. | |||
if self.use_bias: | |||
            nn.init.constant_(self.bias.data, val=0)
def forward(self, input_, hx): | |||
""" | |||
Args: | |||
input_: A (batch, input_size) tensor containing input | |||
features. | |||
hx: A tuple (h_0, c_0), which contains the initial hidden | |||
and cell state, where the size of both states is | |||
(batch, hidden_size). | |||
        Returns:
            c_1: Tensor containing the next cell state. Note that this word-level cell only
            produces a cell state; no hidden state is returned.
        """
h_0, c_0 = hx | |||
batch_size = h_0.size(0) | |||
bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size())) | |||
wh_b = torch.addmm(bias_batch, h_0, self.weight_hh) | |||
wi = torch.mm(input_, self.weight_ih) | |||
f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1) | |||
c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g) | |||
return c_1 | |||
def __repr__(self): | |||
s = '{name}({input_size}, {hidden_size})' | |||
return s.format(name=self.__class__.__name__, **self.__dict__) | |||
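# A minimal, illustrative shape check for WordLSTMCell_yangjie (made-up sizes; assumes this file's
# imports, i.e. fastNLP and the local utils module, are available when the file is run directly).
if __name__ == '__main__':
    _word_cell = WordLSTMCell_yangjie(input_size=50, hidden_size=100)
    _word_embed = torch.randn(4, 50)                     # (batch, word_embed_size)
    _h0, _c0 = torch.zeros(4, 100), torch.zeros(4, 100)  # states taken at the word's start character
    _c_word = _word_cell(_word_embed, (_h0, _c0))        # only a cell state is returned
    print(_c_word.size())                                # torch.Size([4, 100])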
class MultiInputLSTMCell_V0(nn.Module): | |||
def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False): | |||
super().__init__() | |||
self.char_input_size = char_input_size | |||
self.hidden_size = hidden_size | |||
self.use_bias = use_bias | |||
self.weight_ih = nn.Parameter( | |||
torch.FloatTensor(char_input_size, 3 * hidden_size) | |||
) | |||
self.weight_hh = nn.Parameter( | |||
torch.FloatTensor(hidden_size, 3 * hidden_size) | |||
) | |||
self.alpha_weight_ih = nn.Parameter( | |||
torch.FloatTensor(char_input_size, hidden_size) | |||
) | |||
self.alpha_weight_hh = nn.Parameter( | |||
torch.FloatTensor(hidden_size, hidden_size) | |||
) | |||
if self.use_bias: | |||
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) | |||
self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size)) | |||
else: | |||
self.register_parameter('bias', None) | |||
self.register_parameter('alpha_bias', None) | |||
self.debug = debug | |||
self.reset_parameters() | |||
def reset_parameters(self): | |||
""" | |||
Initialize parameters following the way proposed in the paper. | |||
""" | |||
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)
weight_hh_data = torch.eye(self.hidden_size) | |||
weight_hh_data = weight_hh_data.repeat(1, 3) | |||
with torch.no_grad(): | |||
self.weight_hh.set_(weight_hh_data) | |||
alpha_weight_hh_data = torch.eye(self.hidden_size) | |||
alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1) | |||
with torch.no_grad(): | |||
self.alpha_weight_hh.set_(alpha_weight_hh_data) | |||
# The bias is just set to zero vectors. | |||
if self.use_bias: | |||
nn.init.constant_(self.bias.data, val=0) | |||
nn.init.constant_(self.alpha_bias.data, val=0) | |||
def forward(self, inp, skip_c, skip_count, hx): | |||
        '''
        :param inp: char embeddings, B * char_input_size
        :param skip_c: cell states produced from the skip (lexicon) edges, B * X * hidden
        :param skip_count: number of skip edges at the current position for each example in the
            batch; used for masking
        :param hx: tuple (h_0, c_0)
        :return: h_1, c_1
        '''
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wi = torch.matmul(inp, self.weight_ih)
        wh = torch.matmul(h_0, self.weight_hh)
        i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
        i = torch.sigmoid(i).unsqueeze(1)
        o = torch.sigmoid(o).unsqueeze(1)
        g = torch.tanh(g).unsqueeze(1)
        alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
        alpha_wi.unsqueeze_(1)
        # alpha_wi = alpha_wi.expand(1,skip_count,self.hidden_size)
        alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
        alpha_bias_batch = self.alpha_bias.unsqueeze(0)
        alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
        skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1])
        skip_mask = skip_mask.eq(False)  # True where the lexicon slot is padding
        skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
        skip_mask = skip_mask.float() * 1e20
        alpha = alpha - skip_mask
        alpha = torch.exp(torch.cat([i, alpha], dim=1))
        alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
        alpha = torch.div(alpha, alpha_sum)
        merge_i_c = torch.cat([g, skip_c], dim=1)
        c_1 = merge_i_c * alpha
        c_1 = c_1.sum(1, keepdim=True)
        # h_1 = o * c_1
        h_1 = o * torch.tanh(c_1)
        return h_1.squeeze(1), c_1.squeeze(1)
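# In MultiInputLSTMCell_V0 the new cell state is a normalised mixture of the char candidate g_t and
# the word-level cells c^w_j (roughly): alpha_j = sigmoid(W_ax x_t + W_ac c^w_j + b_a); the gate
# values [i_t, alpha_1, ..., alpha_k] are exponentiated and normalised (padded lexicon slots are
# suppressed by subtracting 1e20 beforehand), and c_t = sum_j weight_j * [g_t, c^w_1, ..., c^w_k]_j,
# h_t = o_t * tanh(c_t).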
class MultiInputLSTMCell_V1(nn.Module): | |||
def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False): | |||
super().__init__() | |||
self.char_input_size = char_input_size | |||
self.hidden_size = hidden_size | |||
self.use_bias = use_bias | |||
self.weight_ih = nn.Parameter( | |||
torch.FloatTensor(char_input_size, 3 * hidden_size) | |||
) | |||
self.weight_hh = nn.Parameter( | |||
torch.FloatTensor(hidden_size, 3 * hidden_size) | |||
) | |||
self.alpha_weight_ih = nn.Parameter( | |||
torch.FloatTensor(char_input_size, hidden_size) | |||
) | |||
self.alpha_weight_hh = nn.Parameter( | |||
torch.FloatTensor(hidden_size, hidden_size) | |||
) | |||
if self.use_bias: | |||
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size)) | |||
self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size)) | |||
else: | |||
self.register_parameter('bias', None) | |||
self.register_parameter('alpha_bias', None) | |||
self.debug = debug | |||
self.reset_parameters() | |||
def reset_parameters(self): | |||
""" | |||
Initialize parameters following the way proposed in the paper. | |||
""" | |||
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)
weight_hh_data = torch.eye(self.hidden_size) | |||
weight_hh_data = weight_hh_data.repeat(1, 3) | |||
with torch.no_grad(): | |||
self.weight_hh.set_(weight_hh_data) | |||
alpha_weight_hh_data = torch.eye(self.hidden_size) | |||
alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1) | |||
with torch.no_grad(): | |||
self.alpha_weight_hh.set_(alpha_weight_hh_data) | |||
# The bias is just set to zero vectors. | |||
if self.use_bias: | |||
nn.init.constant_(self.bias.data, val=0) | |||
nn.init.constant_(self.alpha_bias.data, val=0) | |||
def forward(self, inp, skip_c, skip_count, hx): | |||
        '''
        :param inp: char embeddings, B * char_input_size
        :param skip_c: cell states produced from the skip (lexicon) edges, B * X * hidden
        :param skip_count: number of skip edges at the current position for each example in the
            batch; used for masking
        :param hx: tuple (h_0, c_0)
        :return: h_1, c_1
        '''
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wi = torch.matmul(inp, self.weight_ih)
        wh = torch.matmul(h_0, self.weight_hh)
        i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)
        i = torch.sigmoid(i).unsqueeze(1)
        o = torch.sigmoid(o).unsqueeze(1)
        g = torch.tanh(g).unsqueeze(1)
        ## basic (coupled-gate) lstm update, used when no lexicon word matches
        f = 1 - i
        c_1_basic = f * c_0.unsqueeze(1) + i * g
        c_1_basic = c_1_basic.squeeze(1)
        alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
        alpha_wi.unsqueeze_(1)
        alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)
        alpha_bias_batch = self.alpha_bias.unsqueeze(0)
        alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)
        skip_mask = seq_len_to_mask(skip_count, max_len=skip_c.size()[1])
        skip_mask = skip_mask.eq(False)  # True where the lexicon slot is padding
        skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)
        skip_mask = skip_mask.float() * 1e20
        alpha = alpha - skip_mask
        alpha = torch.exp(torch.cat([i, alpha], dim=1))
        alpha_sum = torch.sum(alpha, dim=1, keepdim=True)
        alpha = torch.div(alpha, alpha_sum)
        merge_i_c = torch.cat([g, skip_c], dim=1)
        c_1 = merge_i_c * alpha
        c_1 = c_1.sum(1, keepdim=True)
        # h_1 = o * c_1
        c_1 = c_1.squeeze(1)
        count_select = (skip_count != 0).float().unsqueeze(-1)
        c_1 = c_1 * count_select + c_1_basic * (1 - count_select)
        o = o.squeeze(1)
        h_1 = o * torch.tanh(c_1)
        return h_1, c_1
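# When a position has no matched lexicon word (skip_count == 0), MultiInputLSTMCell_V1 falls back to
# a coupled-gate vanilla LSTM update: c_t = (1 - i_t) * c_{t-1} + i_t * g_t, h_t = o_t * tanh(c_t).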
class LatticeLSTMLayer_sup_back_V0(nn.Module): | |||
def __init__(self, char_input_size, word_input_size, hidden_size, left2right, | |||
bias=True,device=None,debug=False,skip_before_head=False): | |||
        super().__init__()
        self.debug = debug
        self.skip_before_head = skip_before_head
        self.hidden_size = hidden_size
        self.word_input_size = word_input_size
        self.left2right = left2right
        self.bias = bias
        self.device = device
        self.char_cell = MultiInputLSTMCell_V0(char_input_size, hidden_size, bias, debug)
        self.word_cell = WordLSTMCell_yangjie(word_input_size, hidden_size, bias, debug=self.debug)
def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None): | |||
        '''
        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip (lexicon) edges
        :param skip_words: batch * seq_len * X * embedding, the words on the skip edges
        :param skip_count: batch * seq_len, number of matched lexicon words per example per position
        :param init_state: the hx of rnn
        :return:
        '''
if self.left2right: | |||
max_seq_len = max(seq_len) | |||
batch_size = inp.size(0) | |||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
for i in range(max_seq_len): | |||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||
h_0, c_0 = h_[:, i, :], c_[:, i, :] | |||
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||
index_1 = skip_source_flat | |||
if not self.skip_before_head: | |||
c_x = c_[[index_0, index_1+1]] | |||
h_x = h_[[index_0, index_1+1]] | |||
else: | |||
c_x = c_[[index_0,index_1]] | |||
h_x = h_[[index_0,index_1]] | |||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||
h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1) | |||
c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1) | |||
return h_[:,1:],c_[:,1:] | |||
else: | |||
mask_for_seq_len = seq_len_to_mask(seq_len) | |||
max_seq_len = max(seq_len) | |||
batch_size = inp.size(0) | |||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
for i in reversed(range(max_seq_len)): | |||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||
h_0, c_0 = h_[:, 0, :], c_[:, 0, :] | |||
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||
index_1 = skip_source_flat-i | |||
if not self.skip_before_head: | |||
c_x = c_[[index_0, index_1-1]] | |||
h_x = h_[[index_0, index_1-1]] | |||
else: | |||
c_x = c_[[index_0,index_1]] | |||
h_x = h_[[index_0,index_1]] | |||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||
                h_1_mask = h_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1) == 0, 0)
                c_1_mask = c_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1) == 0, 0)
h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1) | |||
c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1) | |||
return h_[:,:-1],c_[:,:-1] | |||
class LatticeLSTMLayer_sup_back_V1(nn.Module): | |||
    # V1 differs from V0 in that, when the current position has no lexicon match at all, V1 falls
    # back to the ordinary (coupled-gate) LSTM update; that update differs from Yang Jie's lattice
    # LSTM implementation in the case where the lexicon count is 0.
def __init__(self, char_input_size, word_input_size, hidden_size, left2right, | |||
bias=True,device=None,debug=False,skip_before_head=False): | |||
super().__init__() | |||
self.debug = debug | |||
self.skip_before_head = skip_before_head | |||
self.hidden_size = hidden_size | |||
self.char_cell = MultiInputLSTMCell_V1(char_input_size, hidden_size, bias,debug) | |||
self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=self.debug) | |||
self.word_input_size = word_input_size | |||
self.left2right = left2right | |||
self.bias = bias | |||
self.device = device | |||
def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None): | |||
        '''
        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip (lexicon) edges
        :param skip_words: batch * seq_len * X * embedding_size, the words on the skip edges
        :param skip_count: batch * seq_len,
            skip_count[i, j] is the number of lexicon words in example i that end at position j
        :param init_state: the hx of rnn
        :return:
        '''
if self.left2right: | |||
max_seq_len = max(seq_len) | |||
batch_size = inp.size(0) | |||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
for i in range(max_seq_len): | |||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||
h_0, c_0 = h_[:, i, :], c_[:, i, :] | |||
                # reshape to 2D so the word cell can consume the B * lexicon_count * embedding_size tensor,
                # and to match PyTorch's advanced [] indexing used below
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||
index_1 = skip_source_flat | |||
if not self.skip_before_head: | |||
c_x = c_[[index_0, index_1+1]] | |||
h_x = h_[[index_0, index_1+1]] | |||
else: | |||
c_x = c_[[index_0,index_1]] | |||
h_x = h_[[index_0,index_1]] | |||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||
h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1) | |||
c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1) | |||
return h_[:,1:],c_[:,1:] | |||
else: | |||
mask_for_seq_len = seq_len_to_mask(seq_len) | |||
max_seq_len = max(seq_len) | |||
batch_size = inp.size(0) | |||
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device) | |||
for i in reversed(range(max_seq_len)): | |||
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1) | |||
h_0, c_0 = h_[:, 0, :], c_[:, 0, :] | |||
skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous() | |||
skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size) | |||
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count) | |||
index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count) | |||
index_1 = skip_source_flat-i | |||
if not self.skip_before_head: | |||
c_x = c_[[index_0, index_1-1]] | |||
h_x = h_[[index_0, index_1-1]] | |||
else: | |||
c_x = c_[[index_0,index_1]] | |||
h_x = h_[[index_0,index_1]] | |||
c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size) | |||
c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat)) | |||
c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size) | |||
h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0)) | |||
                h_1_mask = h_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1) == 0, 0)
                c_1_mask = c_1.masked_fill(mask_for_seq_len[:, i].unsqueeze(-1) == 0, 0)
h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1) | |||
c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1) | |||
return h_[:,:-1],c_[:,:-1] | |||
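# Note on the indexing in both lattice layers: the running buffers h_ / c_ keep one extra all-zero
# initial state, so in the left-to-right pass the state of character position j is stored at buffer
# index j + 1. With skip_before_head=False a word starting at source position s therefore reads the
# state just after its first character (buffer index s + 1), while skip_before_head=True reads the
# state before that character (buffer index s); the right-to-left pass mirrors this with a buffer
# that grows from the front.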
@@ -0,0 +1,23 @@ | |||
glove_100_path = 'en-glove-6b-100d' | |||
glove_50_path = 'en-glove-6b-50d' | |||
glove_200_path = '' | |||
glove_300_path = 'en-glove-840b-300' | |||
fasttext_path = 'en-fasttext' #300 | |||
tencent_chinese_word_path = 'cn' # tencent 200 | |||
fasttext_cn_path = 'cn-fasttext' # 300 | |||
yangjie_rich_pretrain_unigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.uni.ite50.vec' | |||
yangjie_rich_pretrain_bigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.bi.ite50.vec' | |||
yangjie_rich_pretrain_word_path = '/remote-home/xnli/data/pretrain/chinese/ctb.50d.vec' | |||
conll_2003_path = '/remote-home/xnli/data/corpus/multi_task/conll_2013/data_mine.pkl' | |||
conllized_ontonote_path = '/remote-home/txsun/data/OntoNotes-5.0-NER-master/v12/english' | |||
conllized_ontonote_pkl_path = '/remote-home/txsun/data/ontonotes5.pkl' | |||
sst2_path = '/remote-home/xnli/data/corpus/text_classification/SST-2/' | |||
# weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo' | |||
ontonote4ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/OntoNote4NER' | |||
msra_ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/MSRANER' | |||
resume_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/ResumeNER' | |||
weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER' |
@@ -0,0 +1,126 @@ | |||
from utils_ import get_skip_path_trivial, Trie, get_skip_path | |||
from load_data import load_yangjie_rich_pretrain_word_list, load_ontonotes4ner, equip_chinese_ner_with_skip | |||
from pathes import * | |||
from functools import partial | |||
from fastNLP import cache_results | |||
from fastNLP.embeddings.static_embedding import StaticEmbedding | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from fastNLP.core.metrics import _bmes_tag_to_spans,_bmeso_tag_to_spans | |||
from load_data import load_resume_ner | |||
# embed = StaticEmbedding(None,embedding_dim=2) | |||
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path, | |||
# _refresh=True,index_token=False) | |||
# | |||
# w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path, | |||
# _refresh=False) | |||
# | |||
# datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path, | |||
# _refresh=True) | |||
# | |||
def reverse_style(input_string): | |||
target_position = input_string.index('[') | |||
input_len = len(input_string) | |||
output_string = input_string[target_position:input_len] + input_string[0:target_position] | |||
# print('in:{}.out:{}'.format(input_string, output_string)) | |||
return output_string | |||
def get_yangjie_bmeso(label_list): | |||
def get_ner_BMESO_yj(label_list): | |||
# list_len = len(word_list) | |||
# assert(list_len == len(label_list)), "word list size unmatch with label list" | |||
list_len = len(label_list) | |||
begin_label = 'b-' | |||
end_label = 'e-' | |||
single_label = 's-' | |||
whole_tag = '' | |||
index_tag = '' | |||
tag_list = [] | |||
stand_matrix = [] | |||
for i in range(0, list_len): | |||
# wordlabel = word_list[i] | |||
current_label = label_list[i].lower() | |||
if begin_label in current_label: | |||
if index_tag != '': | |||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||
whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) | |||
index_tag = current_label.replace(begin_label, "", 1) | |||
elif single_label in current_label: | |||
if index_tag != '': | |||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||
whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i) | |||
tag_list.append(whole_tag) | |||
whole_tag = "" | |||
index_tag = "" | |||
elif end_label in current_label: | |||
if index_tag != '': | |||
tag_list.append(whole_tag + ',' + str(i)) | |||
whole_tag = '' | |||
index_tag = '' | |||
else: | |||
continue | |||
if (whole_tag != '') & (index_tag != ''): | |||
tag_list.append(whole_tag) | |||
tag_list_len = len(tag_list) | |||
for i in range(0, tag_list_len): | |||
if len(tag_list[i]) > 0: | |||
tag_list[i] = tag_list[i] + ']' | |||
insert_list = reverse_style(tag_list[i]) | |||
stand_matrix.append(insert_list) | |||
# print stand_matrix | |||
return stand_matrix | |||
def transform_YJ_to_fastNLP(span): | |||
span = span[1:] | |||
span_split = span.split(']') | |||
# print('span_list:{}'.format(span_split)) | |||
span_type = span_split[1] | |||
# print('span_split[0].split(','):{}'.format(span_split[0].split(','))) | |||
if ',' in span_split[0]: | |||
b, e = span_split[0].split(',') | |||
else: | |||
b = span_split[0] | |||
e = b | |||
b = int(b) | |||
e = int(e) | |||
e += 1 | |||
return (span_type, (b, e)) | |||
yj_form = get_ner_BMESO_yj(label_list) | |||
# print('label_list:{}'.format(label_list)) | |||
# print('yj_from:{}'.format(yj_form)) | |||
fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form)) | |||
return fastNLP_form | |||
# tag_list = ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O'] | |||
# span_list = get_ner_BMES(tag_list) | |||
# print(span_list) | |||
# yangjie_label_list = ['B-NAME', 'E-NAME', 'O', 'B-CONT', 'M-CONT', 'E-CONT', 'B-RACE', 'E-RACE', 'B-TITLE', 'M-TITLE', 'E-TITLE', 'B-EDU', 'M-EDU', 'E-EDU', 'B-ORG', 'M-ORG', 'E-ORG', 'M-NAME', 'B-PRO', 'M-PRO', 'E-PRO', 'S-RACE', 'S-NAME', 'B-LOC', 'M-LOC', 'E-LOC', 'M-RACE', 'S-ORG'] | |||
# my_label_list = ['O', 'M-ORG', 'M-TITLE', 'B-TITLE', 'E-TITLE', 'B-ORG', 'E-ORG', 'M-EDU', 'B-NAME', 'E-NAME', 'B-EDU', 'E-EDU', 'M-NAME', 'M-PRO', 'M-CONT', 'B-PRO', 'E-PRO', 'B-CONT', 'E-CONT', 'M-LOC', 'B-RACE', 'E-RACE', 'S-NAME', 'B-LOC', 'E-LOC', 'M-RACE', 'S-RACE', 'S-ORG'] | |||
# yangjie_label = set(yangjie_label_list) | |||
# my_label = set(my_label_list) | |||
a = torch.tensor([0,2,0,3]) | |||
b = (a==0) | |||
print(b) | |||
print(b.float()) | |||
from fastNLP import RandomSampler | |||
# f = open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb') | |||
# weight_dict = torch.load(f) | |||
# print(weight_dict.keys()) | |||
# for k,v in weight_dict.items(): | |||
# print("{}:{}".format(k,v.size())) |
@@ -0,0 +1,361 @@ | |||
import torch.nn.functional as F | |||
import torch | |||
import random | |||
import numpy as np | |||
from fastNLP import Const | |||
from fastNLP import CrossEntropyLoss | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import Tester | |||
import os | |||
from fastNLP import logger | |||
def should_mask(name, t=''): | |||
if 'bias' in name: | |||
return False | |||
if 'embedding' in name: | |||
splited = name.split('.') | |||
if splited[-1]!='weight': | |||
return False | |||
if 'embedding' in splited[-2]: | |||
return False | |||
if 'c0' in name: | |||
return False | |||
if 'h0' in name: | |||
return False | |||
if 'output' in name and t not in name: | |||
return False | |||
return True | |||
def get_init_mask(model): | |||
init_masks = {} | |||
for name, param in model.named_parameters(): | |||
if should_mask(name): | |||
init_masks[name+'.mask'] = torch.ones_like(param) | |||
# logger.info(init_masks[name+'.mask'].requires_grad) | |||
return init_masks | |||
def set_seed(seed): | |||
random.seed(seed) | |||
np.random.seed(seed+100) | |||
torch.manual_seed(seed+200) | |||
torch.cuda.manual_seed_all(seed+300) | |||
def get_parameters_size(model): | |||
result = {} | |||
for name,p in model.state_dict().items(): | |||
result[name] = p.size() | |||
return result | |||
def prune_by_proportion_model(model,proportion,task): | |||
# print('this time prune to ',proportion*100,'%') | |||
for name, p in model.named_parameters(): | |||
# print(name) | |||
if not should_mask(name,task): | |||
continue | |||
tensor = p.data.cpu().numpy() | |||
index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy()) | |||
# print(name,'alive count',len(index[0])) | |||
alive = tensor[index] | |||
# print('p and mask size:',p.size(),print(model.mask[task][name+'.mask'].size())) | |||
percentile_value = np.percentile(abs(alive), (1 - proportion) * 100) | |||
# tensor = p | |||
# index = torch.nonzero(model.mask[task][name+'.mask']) | |||
# # print('nonzero len',index) | |||
# alive = tensor[index] | |||
# print('alive size:',alive.shape) | |||
# prune_by_proportion_model() | |||
# percentile_value = torch.topk(abs(alive), int((1-proportion)*len(index[0]))).values | |||
# print('the',(1-proportion)*len(index[0]),'th big') | |||
# print('threshold:',percentile_value) | |||
prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value) | |||
# for | |||
def prune_by_proportion_model_global(model,proportion,task): | |||
# print('this time prune to ',proportion*100,'%') | |||
alive = None | |||
for name, p in model.named_parameters(): | |||
# print(name) | |||
if not should_mask(name,task): | |||
continue | |||
tensor = p.data.cpu().numpy() | |||
index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy()) | |||
# print(name,'alive count',len(index[0])) | |||
if alive is None: | |||
alive = tensor[index] | |||
else: | |||
alive = np.concatenate([alive,tensor[index]],axis=0) | |||
percentile_value = np.percentile(abs(alive), (1 - proportion) * 100) | |||
for name, p in model.named_parameters(): | |||
if should_mask(name,task): | |||
prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value) | |||
def prune_by_threshold_parameter(p, mask, threshold): | |||
p_abs = torch.abs(p) | |||
new_mask = (p_abs > threshold).float() | |||
# print(mask) | |||
mask[:]*=new_mask | |||
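# A minimal, illustrative check (made-up values) of the in-place mask update above: entries whose
# absolute weight does not exceed the threshold are zeroed out of the mask.
if __name__ == '__main__':
    _p = torch.tensor([0.1, -0.5, 0.9])
    _mask = torch.ones(3)
    prune_by_threshold_parameter(_p, _mask, 0.4)
    print(_mask)  # expected: tensor([0., 1., 1.])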
def one_time_train_and_prune_single_task(trainer,PRUNE_PER, | |||
optimizer_init_state_dict=None, | |||
model_init_state_dict=None, | |||
is_global=None, | |||
): | |||
from fastNLP import Trainer | |||
trainer.optimizer.load_state_dict(optimizer_init_state_dict) | |||
trainer.model.load_state_dict(model_init_state_dict) | |||
# print('metrics:',metrics.__dict__) | |||
# print('loss:',loss.__dict__) | |||
# print('trainer input:',task.train_set.get_input_name()) | |||
# trainer = Trainer(model=model, train_data=task.train_set, dev_data=task.dev_set, loss=loss, metrics=metrics, | |||
# optimizer=optimizer, n_epochs=EPOCH, batch_size=BATCH, device=device,callbacks=callbacks) | |||
trainer.train(load_best_model=True) | |||
# tester = Tester(task.train_set, model, metrics, BATCH, device=device, verbose=1,use_tqdm=False) | |||
# print('FOR DEBUG: test train_set:',tester.test()) | |||
# print('**'*20) | |||
# if task.test_set: | |||
# tester = Tester(task.test_set, model, metrics, BATCH, device=device, verbose=1) | |||
# tester.test() | |||
if is_global: | |||
prune_by_proportion_model_global(trainer.model, PRUNE_PER, trainer.model.now_task) | |||
else: | |||
prune_by_proportion_model(trainer.model, PRUNE_PER, trainer.model.now_task) | |||
# def iterative_train_and_prune_single_task(get_trainer,ITER,PRUNE,is_global=False,save_path=None): | |||
def iterative_train_and_prune_single_task(get_trainer,args,model,train_set,dev_set,test_set,device,save_path=None): | |||
''' | |||
:param trainer: | |||
:param ITER: | |||
:param PRUNE: | |||
:param is_global: | |||
:param save_path: should be a dictionary which will be filled with mask and state dict | |||
:return: | |||
''' | |||
from fastNLP import Trainer | |||
import torch | |||
import math | |||
import copy | |||
PRUNE = args.prune | |||
ITER = args.iter | |||
trainer = get_trainer(args,model,train_set,dev_set,test_set,device) | |||
optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict()) | |||
model_init_state_dict = copy.deepcopy(trainer.model.state_dict()) | |||
if save_path is not None: | |||
if not os.path.exists(save_path): | |||
os.makedirs(save_path) | |||
# if not os.path.exists(os.path.join(save_path, 'model_init.pkl')): | |||
# f = open(os.path.join(save_path, 'model_init.pkl'), 'wb') | |||
# torch.save(trainer.model.state_dict(),f) | |||
mask_count = 0 | |||
model = trainer.model | |||
task = trainer.model.now_task | |||
for name, p in model.mask[task].items(): | |||
mask_count += torch.sum(p).item() | |||
init_mask_count = mask_count | |||
logger.info('init mask count:{}'.format(mask_count)) | |||
# logger.info('{}th traning mask count: {} / {} = {}%'.format(i, mask_count, init_mask_count, | |||
# mask_count / init_mask_count * 100)) | |||
prune_per_iter = math.pow(PRUNE, 1 / ITER) | |||
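    # Geometric schedule: e.g. with PRUNE=0.1 and ITER=5, each round keeps 0.1 ** (1 / 5) ≈ 63.1% of
    # the surviving weights, so roughly 10% remain after all ITER rounds (illustrative numbers).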
for i in range(ITER): | |||
trainer = get_trainer(args,model,train_set,dev_set,test_set,device) | |||
one_time_train_and_prune_single_task(trainer,prune_per_iter,optimizer_init_state_dict,model_init_state_dict) | |||
if save_path is not None: | |||
f = open(os.path.join(save_path,task+'_mask_'+str(i)+'.pkl'),'wb') | |||
torch.save(model.mask[task],f) | |||
mask_count = 0 | |||
for name, p in model.mask[task].items(): | |||
mask_count += torch.sum(p).item() | |||
logger.info('{}th traning mask count: {} / {} = {}%'.format(i,mask_count,init_mask_count,mask_count/init_mask_count*100)) | |||
def get_appropriate_cuda(task_scale='s'): | |||
if task_scale not in {'s','m','l'}: | |||
logger.info('task scale wrong!') | |||
exit(2) | |||
import pynvml | |||
pynvml.nvmlInit() | |||
total_cuda_num = pynvml.nvmlDeviceGetCount() | |||
for i in range(total_cuda_num): | |||
logger.info(i) | |||
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
        logger.info('{} mem: {} util: {}'.format(i, memInfo.used / memInfo.total, utilizationInfo.gpu))
        if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu < 0.2:
            logger.info('{} {}'.format(i, memInfo.used / memInfo.total))
return 'cuda:'+str(i) | |||
if task_scale=='s': | |||
max_memory=2000 | |||
elif task_scale=='m': | |||
max_memory=6000 | |||
else: | |||
max_memory = 9000 | |||
max_id = -1 | |||
for i in range(total_cuda_num): | |||
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id
memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle) | |||
utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle) | |||
if max_memory < memInfo.free: | |||
max_memory = memInfo.free | |||
max_id = i | |||
    if max_id == -1:
logger.info('no appropriate gpu, wait!') | |||
exit(2) | |||
return 'cuda:'+str(max_id) | |||
# if memInfo.used / memInfo.total < 0.5: | |||
# return | |||
def print_mask(mask_dict): | |||
def seq_mul(*X): | |||
res = 1 | |||
for x in X: | |||
res*=x | |||
return res | |||
for name,p in mask_dict.items(): | |||
total_size = seq_mul(*p.size()) | |||
        unmasked_size = int((p != 0).sum().item())
print(name,':',unmasked_size,'/',total_size,'=',unmasked_size/total_size*100,'%') | |||
print() | |||
def check_words_same(dataset_1,dataset_2,field_1,field_2): | |||
if len(dataset_1[field_1]) != len(dataset_2[field_2]): | |||
logger.info('CHECK: example num not same!') | |||
return False | |||
for i, words in enumerate(dataset_1[field_1]): | |||
if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]): | |||
logger.info('CHECK {} th example length not same'.format(i)) | |||
logger.info('1:{}'.format(dataset_1[field_1][i])) | |||
            logger.info('2:{}'.format(dataset_2[field_2][i]))
return False | |||
# for j,w in enumerate(words): | |||
# if dataset_1[field_1][i][j] != dataset_2[field_2][i][j]: | |||
# print('CHECK', i, 'th example has words different!') | |||
# print('1:',dataset_1[field_1][i]) | |||
# print('2:',dataset_2[field_2][i]) | |||
# return False | |||
logger.info('CHECK: totally same!') | |||
return True | |||
def get_now_time(): | |||
    from datetime import datetime, timezone, timedelta
    dt = datetime.now(timezone.utc)
# print(dt) | |||
tzutc_8 = timezone(timedelta(hours=8)) | |||
local_dt = dt.astimezone(tzutc_8) | |||
result = ("_{}_{}_{}__{}_{}_{}".format(local_dt.year, local_dt.month, local_dt.day, local_dt.hour, local_dt.minute, | |||
local_dt.second)) | |||
return result | |||
def get_bigrams(words): | |||
result = [] | |||
for i,w in enumerate(words): | |||
if i!=len(words)-1: | |||
result.append(words[i]+words[i+1]) | |||
else: | |||
result.append(words[i]+'<end>') | |||
return result | |||
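# Illustrative example: get_bigrams(['中', '国', '人']) -> ['中国', '国人', '人<end>']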
def print_info(*inp,islog=False,sep=' '): | |||
from fastNLP import logger | |||
if islog: | |||
print(*inp,sep=sep) | |||
else: | |||
inp = sep.join(map(str,inp)) | |||
logger.info(inp) | |||
def better_init_rnn(rnn,coupled=False): | |||
import torch.nn as nn | |||
if coupled: | |||
repeat_size = 3 | |||
else: | |||
repeat_size = 4 | |||
# print(list(rnn.named_parameters())) | |||
if hasattr(rnn,'num_layers'): | |||
for i in range(rnn.num_layers): | |||
            nn.init.orthogonal_(getattr(rnn, 'weight_ih_l' + str(i)).data)
weight_hh_data = torch.eye(rnn.hidden_size) | |||
weight_hh_data = weight_hh_data.repeat(1, repeat_size) | |||
with torch.no_grad(): | |||
getattr(rnn,'weight_hh_l'+str(i)).set_(weight_hh_data) | |||
            nn.init.constant_(getattr(rnn, 'bias_ih_l' + str(i)).data, val=0)
            nn.init.constant_(getattr(rnn, 'bias_hh_l' + str(i)).data, val=0)
if rnn.bidirectional: | |||
for i in range(rnn.num_layers): | |||
                nn.init.orthogonal_(getattr(rnn, 'weight_ih_l' + str(i) + '_reverse').data)
weight_hh_data = torch.eye(rnn.hidden_size) | |||
weight_hh_data = weight_hh_data.repeat(1, repeat_size) | |||
with torch.no_grad(): | |||
getattr(rnn, 'weight_hh_l' + str(i)+'_reverse').set_(weight_hh_data) | |||
                nn.init.constant_(getattr(rnn, 'bias_ih_l' + str(i) + '_reverse').data, val=0)
                nn.init.constant_(getattr(rnn, 'bias_hh_l' + str(i) + '_reverse').data, val=0)
else: | |||
        nn.init.orthogonal_(rnn.weight_ih.data)
weight_hh_data = torch.eye(rnn.hidden_size) | |||
weight_hh_data = weight_hh_data.repeat(repeat_size,1) | |||
with torch.no_grad(): | |||
rnn.weight_hh.set_(weight_hh_data) | |||
# The bias is just set to zero vectors. | |||
print('rnn param size:{},{}'.format(rnn.weight_hh.size(),type(rnn))) | |||
if rnn.bias: | |||
            nn.init.constant_(rnn.bias_ih.data, val=0)
            nn.init.constant_(rnn.bias_hh.data, val=0)
# print(list(rnn.named_parameters())) | |||
@@ -0,0 +1,405 @@ | |||
import collections | |||
from fastNLP import cache_results | |||
def get_skip_path(chars,w_trie): | |||
sentence = ''.join(chars) | |||
result = w_trie.get_lexicon(sentence) | |||
return result | |||
# @cache_results(_cache_fp='cache/get_skip_path_trivial',_refresh=True) | |||
def get_skip_path_trivial(chars,w_list): | |||
chars = ''.join(chars) | |||
w_set = set(w_list) | |||
result = [] | |||
# for i in range(len(chars)): | |||
# result.append([]) | |||
for i in range(len(chars)-1): | |||
for j in range(i+2,len(chars)+1): | |||
if chars[i:j] in w_set: | |||
result.append([i,j-1,chars[i:j]]) | |||
return result | |||
class TrieNode: | |||
def __init__(self): | |||
self.children = collections.defaultdict(TrieNode) | |||
self.is_w = False | |||
class Trie: | |||
def __init__(self): | |||
self.root = TrieNode() | |||
def insert(self,w): | |||
current = self.root | |||
for c in w: | |||
current = current.children[c] | |||
current.is_w = True | |||
def search(self,w): | |||
''' | |||
:param w: | |||
:return: | |||
-1:not w route | |||
0:subroute but not word | |||
1:subroute and word | |||
''' | |||
current = self.root | |||
for c in w: | |||
current = current.children.get(c) | |||
if current is None: | |||
return -1 | |||
if current.is_w: | |||
return 1 | |||
else: | |||
return 0 | |||
def get_lexicon(self,sentence): | |||
result = [] | |||
for i in range(len(sentence)): | |||
current = self.root | |||
for j in range(i, len(sentence)): | |||
current = current.children.get(sentence[j]) | |||
if current is None: | |||
break | |||
if current.is_w: | |||
result.append([i,j,sentence[i:j+1]]) | |||
return result | |||
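# A tiny, illustrative check of the Trie lexicon matcher (made-up words; assumes fastNLP is
# importable, since it is imported at the top of this file).
if __name__ == '__main__':
    _trie = Trie()
    for _w in ['长江', '长江大桥', '大桥']:
        _trie.insert(_w)
    print(_trie.get_lexicon('长江大桥'))
    # expected: [[0, 1, '长江'], [0, 3, '长江大桥'], [2, 3, '大桥']]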
from fastNLP.core.field import Padder | |||
import numpy as np | |||
import torch | |||
from collections import defaultdict | |||
class LatticeLexiconPadder(Padder): | |||
def __init__(self, pad_val=0, pad_val_dynamic=False,dynamic_offset=0, **kwargs): | |||
''' | |||
:param pad_val: | |||
:param pad_val_dynamic: if True, pad_val is the seq_len | |||
:param kwargs: | |||
''' | |||
self.pad_val = pad_val | |||
self.pad_val_dynamic = pad_val_dynamic | |||
self.dynamic_offset = dynamic_offset | |||
def __call__(self, contents, field_name, field_ele_dtype, dim: int): | |||
        # same as the dim == 2 case in fastNLP's AutoPadder
max_len = max(map(len, contents)) | |||
max_len = max(max_len,1)#avoid 0 size dim which causes cuda wrong | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in contents]) | |||
max_word_len = max(max_word_len,1) | |||
if self.pad_val_dynamic: | |||
# print('pad_val_dynamic:{}'.format(max_len-1)) | |||
array = np.full((len(contents), max_len, max_word_len), max_len-1+self.dynamic_offset, | |||
dtype=field_ele_dtype) | |||
else: | |||
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
for j, content_ii in enumerate(content_i): | |||
array[i, j, :len(content_ii)] = content_ii | |||
array = torch.tensor(array) | |||
return array | |||
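# Example of the padding behaviour above (illustrative values): for a batch of per-position lexicon
# index lists such as [[[1, 2], [3]], [[4]]] with pad_val=0, the padder returns a tensor of shape
# (2, 2, 2): [[[1, 2], [3, 0]], [[4, 0], [0, 0]]]. With pad_val_dynamic=True the fill value becomes
# max_len - 1 + dynamic_offset instead of pad_val.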
from fastNLP.core.metrics import MetricBase | |||
def get_yangjie_bmeso(label_list,ignore_labels=None): | |||
def get_ner_BMESO_yj(label_list): | |||
def reverse_style(input_string): | |||
target_position = input_string.index('[') | |||
input_len = len(input_string) | |||
output_string = input_string[target_position:input_len] + input_string[0:target_position] | |||
# print('in:{}.out:{}'.format(input_string, output_string)) | |||
return output_string | |||
# list_len = len(word_list) | |||
# assert(list_len == len(label_list)), "word list size unmatch with label list" | |||
list_len = len(label_list) | |||
begin_label = 'b-' | |||
end_label = 'e-' | |||
single_label = 's-' | |||
whole_tag = '' | |||
index_tag = '' | |||
tag_list = [] | |||
stand_matrix = [] | |||
for i in range(0, list_len): | |||
# wordlabel = word_list[i] | |||
current_label = label_list[i].lower() | |||
if begin_label in current_label: | |||
if index_tag != '': | |||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||
whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i) | |||
index_tag = current_label.replace(begin_label, "", 1) | |||
elif single_label in current_label: | |||
if index_tag != '': | |||
tag_list.append(whole_tag + ',' + str(i - 1)) | |||
whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i) | |||
tag_list.append(whole_tag) | |||
whole_tag = "" | |||
index_tag = "" | |||
elif end_label in current_label: | |||
if index_tag != '': | |||
tag_list.append(whole_tag + ',' + str(i)) | |||
whole_tag = '' | |||
index_tag = '' | |||
else: | |||
continue | |||
if (whole_tag != '') & (index_tag != ''): | |||
tag_list.append(whole_tag) | |||
tag_list_len = len(tag_list) | |||
for i in range(0, tag_list_len): | |||
if len(tag_list[i]) > 0: | |||
tag_list[i] = tag_list[i] + ']' | |||
insert_list = reverse_style(tag_list[i]) | |||
stand_matrix.append(insert_list) | |||
# print stand_matrix | |||
return stand_matrix | |||
def transform_YJ_to_fastNLP(span): | |||
span = span[1:] | |||
span_split = span.split(']') | |||
# print('span_list:{}'.format(span_split)) | |||
span_type = span_split[1] | |||
# print('span_split[0].split(','):{}'.format(span_split[0].split(','))) | |||
if ',' in span_split[0]: | |||
b, e = span_split[0].split(',') | |||
else: | |||
b = span_split[0] | |||
e = b | |||
b = int(b) | |||
e = int(e) | |||
e += 1 | |||
return (span_type, (b, e)) | |||
yj_form = get_ner_BMESO_yj(label_list) | |||
# print('label_list:{}'.format(label_list)) | |||
# print('yj_from:{}'.format(yj_form)) | |||
fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form)) | |||
return fastNLP_form | |||
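# A small, illustrative check of the BMES(O)-to-span conversion above (made-up labels):
if __name__ == '__main__':
    print(get_yangjie_bmeso(['B-NAME', 'E-NAME', 'O', 'S-ORG']))
    # expected: [('name', (0, 2)), ('org', (3, 4))]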
class SpanFPreRecMetric_YJ(MetricBase):
    r"""
    Adapted from :class:`fastNLP.core.metrics.SpanFPreRecMetric`.
    For sequence labelling tasks, computes span-level F, pre and rec. For instance, in Chinese POS
    tagging annotated at the character level, the sentence `中国在亚洲` may be labelled (in BMES) as
    ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']; this metric computes F1 for such settings.
    The resulting metric is::
        {
            'f': xxx, # 'f' is used so that f_beta values can also be reported later
            'pre': xxx,
            'rec': xxx
        }
    If only_gross=False, per-label statistics are returned as well::
        {
            'f': xxx,
            'pre': xxx,
            'rec': xxx,
            'f-label': xxx,
            'pre-label': xxx,
            'rec-label': xxx,
            ...
        }
    :param tag_vocab: the label :class:`~fastNLP.Vocabulary` . Supported tags are "B" (without a label)
        or "B-xxx" (where xxx is a label such as NN in POS tagging). During decoding, tags with the
        same xxx are treated as the same label, e.g. ['B-NN', 'E-NN'] is merged into one 'NN' span.
    :param str pred: the key used to fetch the prediction from the dict passed to evaluate(); if None,
        the key `pred` is used.
    :param str target: the key used to fetch the target from the dict passed to evaluate(); if None,
        the key `target` is used.
    :param str seq_len: the key used to fetch the sequence lengths from the dict passed to evaluate();
        if None, the key `seq_len` is used.
    :param str encoding_type: currently supports bio, bmes, bmeso, bioes (and bmesoyj)
    :param list ignore_labels: a list of str. Classes in this list are excluded from the computation,
        e.g. passing ['NN'] in POS tagging means the label 'NN' is not evaluated.
    :param bool only_gross: whether to return only the overall f1, precision and recall; if False,
        per-label f1, pre and rec are returned as well.
    :param str f_type: `micro` or `macro` . `micro` : compute the overall TP, FN and FP first, then
        f, precision and recall; `macro` : compute f, precision and recall per class, then average
        them (each class has the same weight).
    :param float beta: the f_beta score, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
        Commonly beta=0.5, 1 or 2. With 0.5 precision is weighted higher than recall; with 1 they are
        weighted equally; with 2 recall is weighted higher than precision.
    """
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | |||
only_gross=True, f_type='micro', beta=1): | |||
from fastNLP.core import Vocabulary | |||
from fastNLP.core.metrics import _bmes_tag_to_spans,_bio_tag_to_spans,\ | |||
_bioes_tag_to_spans,_bmeso_tag_to_spans | |||
from collections import defaultdict | |||
encoding_type = encoding_type.lower() | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
self.encoding_type = encoding_type | |||
# print('encoding_type:{}'self.encoding_type) | |||
if self.encoding_type == 'bmes': | |||
self.tag_to_span_func = _bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
self.tag_to_span_func = _bio_tag_to_spans | |||
elif self.encoding_type == 'bmeso': | |||
self.tag_to_span_func = _bmeso_tag_to_spans | |||
elif self.encoding_type == 'bioes': | |||
self.tag_to_span_func = _bioes_tag_to_spans | |||
elif self.encoding_type == 'bmesoyj': | |||
self.tag_to_span_func = get_yangjie_bmeso | |||
# self.tag_to_span_func = | |||
else: | |||
            raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' and 'bmesoyj' encoding types, "
                             "got {}.".format(encoding_type))
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
self.beta_square = self.beta ** 2 | |||
self.only_gross = only_gross | |||
super().__init__() | |||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
self.tag_vocab = tag_vocab | |||
self._true_positives = defaultdict(int) | |||
self._false_positives = defaultdict(int) | |||
self._false_negatives = defaultdict(int) | |||
def evaluate(self, pred, target, seq_len): | |||
        """evaluate() accumulates the metric statistics over one batch of predictions.
        :param pred: [batch, seq_len] or [batch, seq_len, len(tag_vocab)], the predicted results
        :param target: [batch, seq_len], the gold labels
        :param seq_len: [batch], the text lengths
        :return:
        """
        from fastNLP.core.utils import _get_func_signature
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
if not isinstance(target, torch.Tensor): | |||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if not isinstance(seq_len, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_len)}.") | |||
if pred.size() == target.size() and len(target.size()) == 2: | |||
pass | |||
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: | |||
num_classes = pred.size(-1) | |||
pred = pred.argmax(dim=-1) | |||
if (target >= num_classes).any(): | |||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||
"id >= {}, the number of classes.".format(num_classes)) | |||
else: | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
f"{pred.size()[:-1]}, got {target.size()}.") | |||
batch_size = pred.size(0) | |||
pred = pred.tolist() | |||
target = target.tolist() | |||
for i in range(batch_size): | |||
pred_tags = pred[i][:int(seq_len[i])] | |||
gold_tags = target[i][:int(seq_len[i])] | |||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | |||
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | |||
for span in pred_spans: | |||
if span in gold_spans: | |||
self._true_positives[span[0]] += 1 | |||
gold_spans.remove(span) | |||
else: | |||
self._false_positives[span[0]] += 1 | |||
for span in gold_spans: | |||
self._false_negatives[span[0]] += 1 | |||
    def get_metric(self, reset=True):
        """get_metric() computes the final metric values from the statistics accumulated by evaluate()."""
evaluate_result = {} | |||
if not self.only_gross or self.f_type == 'macro': | |||
tags = set(self._false_negatives.keys()) | |||
tags.update(set(self._false_positives.keys())) | |||
tags.update(set(self._true_positives.keys())) | |||
f_sum = 0 | |||
pre_sum = 0 | |||
rec_sum = 0 | |||
for tag in tags: | |||
tp = self._true_positives[tag] | |||
fn = self._false_negatives[tag] | |||
fp = self._false_positives[tag] | |||
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
rec_sum += rec | |||
                if not self.only_gross and tag != '':  # tag != '' guards against entries without a tag
f_key = 'f-{}'.format(tag) | |||
pre_key = 'pre-{}'.format(tag) | |||
rec_key = 'rec-{}'.format(tag) | |||
evaluate_result[f_key] = f | |||
evaluate_result[pre_key] = pre | |||
evaluate_result[rec_key] = rec | |||
if self.f_type == 'macro': | |||
evaluate_result['f'] = f_sum / len(tags) | |||
evaluate_result['pre'] = pre_sum / len(tags) | |||
evaluate_result['rec'] = rec_sum / len(tags) | |||
if self.f_type == 'micro': | |||
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), | |||
sum(self._false_negatives.values()), | |||
sum(self._false_positives.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
if reset: | |||
self._true_positives = defaultdict(int) | |||
self._false_positives = defaultdict(int) | |||
self._false_negatives = defaultdict(int) | |||
for key, value in evaluate_result.items(): | |||
evaluate_result[key] = round(value, 6) | |||
return evaluate_result | |||
def _compute_f_pre_rec(self, tp, fn, fp): | |||
""" | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
:param fp: int, false positive | |||
:return: (f, pre, rec) | |||
""" | |||
pre = tp / (fp + tp + 1e-13) | |||
rec = tp / (fn + tp + 1e-13) | |||
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) | |||
return f, pre, rec | |||
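# Worked example of the micro-averaged scores above (illustrative counts): with overall TP=2, FP=1,
# FN=1 and beta=1, pre = 2/3, rec = 2/3 and f = 2 * pre * rec / (pre + rec) = 2/3.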