
batch-support LatticeLSTM

tags/v0.5.0
LeeSureman 5 years ago
commit 9cbcd74c58
9 changed files with 3065 additions and 0 deletions
  1. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/check_output.py  +252 -0
  2. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/load_data.py  +772 -0
  3. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/main.py  +189 -0
  4. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/models.py  +299 -0
  5. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/modules.py  +638 -0
  6. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/pathes.py  +23 -0
  7. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/small.py  +126 -0
  8. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils.py  +361 -0
  9. reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils_.py  +405 -0

+252 -0   reproduction/seqence_labelling/chinese_ner/LatticeLSTM/check_output.py

@@ -0,0 +1,252 @@
import torch.nn as nn
from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ
from fastNLP import Tester
import fitlog
from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR



# sys.path.append('.')
# sys.path.append('..')
# for p in sys.path:
# print(p)
# fitlog.add_hyper_in_file (__file__) # record your hyperparameters
########hyper

########hyper

parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cpu')
parser.add_argument('--debug',default=True)

parser.add_argument('--batch',default=1)
parser.add_argument('--test_batch',default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.015)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)  # in the paper it's False
parser.add_argument('--hidden',default=100)
parser.add_argument('--momentum',default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=False)

parser.add_argument('--embed_dropout',default=0)
parser.add_argument('--output_dropout',default=0)
parser.add_argument('--epoch',default=100)
parser.add_argument('--seed',default=100)

args = parser.parse_args()

set_seed(args.seed)

fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)]
if args.model == 'lattice':
fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
# fitlog.commit(__file__,fit_msg=fit_msg)


# fitlog.add_hyper(args)
device = torch.device(args.device)
for k,v in args.__dict__.items():
print(k,v)

refresh_data = False
if args.dataset == 'ontonote':
datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
_refresh=refresh_data,index_token=False)
elif args.dataset == 'resume':
datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
_refresh=refresh_data,index_token=False)
# exit()
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
_refresh=refresh_data)



cache_name = os.path.join('cache',args.dataset+'_lattice')
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
_refresh=refresh_data,_cache_fp=cache_name)

print('中:embedding:{}'.format(embeddings['char'](24)))
print('embed lookup dropout:{}'.format(embeddings['word'].word_dropout))

# for k, v in datasets.items():
# # v.apply_field(lambda x: list(map(len, x)), 'skips_l2r_word', 'lexicon_count')
# # v.apply_field(lambda x:
# # list(map(lambda y:
# # list(map(lambda z: vocabs['word'].to_index(z), y)), x)),
# # 'skips_l2r_word')

print(datasets['train'][0])
print('vocab info:')
for k,v in vocabs.items():
print('{}:{}'.format(k,len(v)))
# print(datasets['dev'][0])
# print(datasets['test'][0])
# print(datasets['train'].get_all_fields().keys())
for k,v in datasets.items():
if args.model == 'lattice':
v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source')
if args.skip_before_head:
v.set_padder('skips_l2r_word',LatticeLexiconPadder())
v.set_padder('skips_l2r_source',LatticeLexiconPadder())
v.set_padder('skips_r2l_word',LatticeLexiconPadder())
v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True))
else:
v.set_padder('skips_l2r_word',LatticeLexiconPadder())
v.set_padder('skips_r2l_word', LatticeLexiconPadder())
v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1))
v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1))
if args.bi:
v.set_input('chars','bigrams','seq_len',
'skips_l2r_word','skips_l2r_source','lexicon_count',
'skips_r2l_word', 'skips_r2l_source','lexicon_count_back',
'target',
use_1st_ins_infer_dim_type=True)
else:
v.set_input('chars','bigrams','seq_len',
'skips_l2r_word','skips_l2r_source','lexicon_count',
'target',
use_1st_ins_infer_dim_type=True)
v.set_target('target','seq_len')

v['target'].set_pad_val(0)
elif args.model == 'lstm':
v.set_ignore_type('skips_l2r_word','skips_l2r_source')
v.set_padder('skips_l2r_word',LatticeLexiconPadder())
v.set_padder('skips_l2r_source',LatticeLexiconPadder())
v.set_input('chars','bigrams','seq_len','target',
use_1st_ins_infer_dim_type=True)
v.set_target('target','seq_len')

v['target'].set_pad_val(0)

print(datasets['dev']['skips_l2r_word'][100])


if args.model =='lattice':
model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'],
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
skip_before_head=args.skip_before_head,use_bigram=args.use_bigram,
vocabs=vocabs
)
elif args.model == 'lstm':
model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
bidirectional=args.bi,
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
use_bigram=args.use_bigram)

for k,v in model.state_dict().items():
print('{}:{}'.format(k,v.size()))



# exit()
weight_dict = torch.load(open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb'))
# print(weight_dict.keys())
# for k,v in weight_dict.items():
# print('{}:{}'.format(k,v.size()))
def state_dict_param(model):
param_list = list(model.named_parameters())
print(len(param_list))
param_dict = {}
for i in range(len(param_list)):
param_dict[param_list[i][0]] = param_list[i][1]

return param_dict


def copy_yangjie_lattice_weight(target,source_dict):
t = state_dict_param(target)
with torch.no_grad():
t['encoder.char_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.weight_ih'])
t['encoder.char_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.weight_hh'])
t['encoder.char_cell.alpha_weight_ih'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_ih'])
t['encoder.char_cell.alpha_weight_hh'].set_(source_dict['lstm.forward_lstm.rnn.alpha_weight_hh'])
t['encoder.char_cell.bias'].set_(source_dict['lstm.forward_lstm.rnn.bias'])
t['encoder.char_cell.alpha_bias'].set_(source_dict['lstm.forward_lstm.rnn.alpha_bias'])
t['encoder.word_cell.weight_ih'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_ih'])
t['encoder.word_cell.weight_hh'].set_(source_dict['lstm.forward_lstm.word_rnn.weight_hh'])
t['encoder.word_cell.bias'].set_(source_dict['lstm.forward_lstm.word_rnn.bias'])

t['encoder_back.char_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.weight_ih'])
t['encoder_back.char_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.weight_hh'])
t['encoder_back.char_cell.alpha_weight_ih'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_ih'])
t['encoder_back.char_cell.alpha_weight_hh'].set_(source_dict['lstm.backward_lstm.rnn.alpha_weight_hh'])
t['encoder_back.char_cell.bias'].set_(source_dict['lstm.backward_lstm.rnn.bias'])
t['encoder_back.char_cell.alpha_bias'].set_(source_dict['lstm.backward_lstm.rnn.alpha_bias'])
t['encoder_back.word_cell.weight_ih'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_ih'])
t['encoder_back.word_cell.weight_hh'].set_(source_dict['lstm.backward_lstm.word_rnn.weight_hh'])
t['encoder_back.word_cell.bias'].set_(source_dict['lstm.backward_lstm.word_rnn.bias'])

for k,v in t.items():
print('{}:{}'.format(k,v))

copy_yangjie_lattice_weight(model,weight_dict)

# print(vocabs['label'].word2idx.keys())








loss = LossInForward()

f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso')
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj')
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,f1_metric_yj,acc_metric]

if args.optim == 'adam':
optimizer = optim.Adam(model.parameters(),lr=args.lr)
elif args.optim == 'sgd':
optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum)



# tester = Tester(datasets['dev'],model,metrics=metrics,batch_size=args.test_batch,device=device)
# test_result = tester.test()
# print(test_result)
callbacks = [
LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.05)**ep))
]
print(datasets['train'][:2])
print(vocabs['char'].to_index(':'))
# StaticEmbedding
# datasets['train'] = datasets['train'][1:]
from fastNLP import SequentialSampler
trainer = Trainer(datasets['train'],model,
optimizer=optimizer,
loss=loss,
metrics=metrics,
dev_data=datasets['dev'],
device=device,
batch_size=args.batch,
n_epochs=args.epoch,
dev_batch_size=args.test_batch,
callbacks=callbacks,
check_code_level=-1,
sampler=SequentialSampler())

trainer.train()
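
For reference, the LRScheduler callback above multiplies the base learning rate by 1/(1+0.05)**epoch, i.e. a roughly 5% geometric decay per epoch (main.py uses 0.03). A minimal standalone sketch of the same schedule, assuming nothing beyond a bare torch optimizer with a toy parameter:

import torch
from torch.optim.lr_scheduler import LambdaLR

# toy parameter and optimizer, only to show the decay curve
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.SGD(params, lr=0.015)
sched = LambdaLR(opt, lambda ep: 1 / (1 + 0.05) ** ep)
for ep in range(3):
    print(ep, opt.param_groups[0]['lr'])   # 0.015, ~0.01429, ~0.01361
    opt.step()
    sched.step()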

+772 -0   reproduction/seqence_labelling/chinese_ner/LatticeLSTM/load_data.py

@@ -0,0 +1,772 @@
from fastNLP.io import CSVLoader
from fastNLP import Vocabulary
from fastNLP import Const
import numpy as np
import fitlog
import pickle
import os
from fastNLP.embeddings import StaticEmbedding
from fastNLP import cache_results


@cache_results(_cache_fp='mtl16', _refresh=False)
def load_16_task(dict_path):
'''

:param dict_path: /remote-home/txsun/fnlp/MTL-LT/data
:return:
'''
task_path = os.path.join(dict_path,'data.pkl')
embedding_path = os.path.join(dict_path,'word_embedding.npy')

embedding = np.load(embedding_path).astype(np.float32)

task_list = pickle.load(open(task_path, 'rb'))['task_lst']

for t in task_list:
t.train_set.rename_field('words_idx', 'words')
t.dev_set.rename_field('words_idx', 'words')
t.test_set.rename_field('words_idx', 'words')

t.train_set.rename_field('label', 'target')
t.dev_set.rename_field('label', 'target')
t.test_set.rename_field('label', 'target')

t.train_set.add_seq_len('words')
t.dev_set.add_seq_len('words')
t.test_set.add_seq_len('words')

t.train_set.set_input(Const.INPUT, Const.INPUT_LEN)
t.dev_set.set_input(Const.INPUT, Const.INPUT_LEN)
t.test_set.set_input(Const.INPUT, Const.INPUT_LEN)

return task_list,embedding


@cache_results(_cache_fp='SST2', _refresh=False)
def load_sst2(dict_path,embedding_path=None):
'''

:param dict_path: /remote-home/xnli/data/corpus/text_classification/SST-2/
:param embedding_path: glove 300d txt
:return:
'''
train_path = os.path.join(dict_path,'train.tsv')
dev_path = os.path.join(dict_path,'dev.tsv')

loader = CSVLoader(headers=('words', 'target'), sep='\t')
train_data = loader.load(train_path).datasets['train']
dev_data = loader.load(dev_path).datasets['train']

train_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words')
dev_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words')

train_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len')
dev_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len')

vocab = Vocabulary(min_freq=2)
vocab.from_dataset(train_data, field_name='words')
vocab.from_dataset(dev_data, field_name='words')

# pretrained_embedding = load_word_emb(embedding_path, 300, vocab)

label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(train_data, field_name='target')

label_vocab.index_dataset(train_data, field_name='target')
label_vocab.index_dataset(dev_data, field_name='target')

vocab.index_dataset(train_data, field_name='words', new_field_name='words')
vocab.index_dataset(dev_data, field_name='words', new_field_name='words')

train_data.set_input(Const.INPUT, Const.INPUT_LEN)
train_data.set_target(Const.TARGET)

dev_data.set_input(Const.INPUT, Const.INPUT_LEN)
dev_data.set_target(Const.TARGET)

if embedding_path is not None:
pretrained_embedding = load_word_emb(embedding_path, 300, vocab)
return (train_data,dev_data),(vocab,label_vocab),pretrained_embedding

else:
return (train_data,dev_data),(vocab,label_vocab)

@cache_results(_cache_fp='OntonotesPOS', _refresh=False)
def load_conllized_ontonote_POS(path,embedding_path=None):
from fastNLP.io.data_loader import ConllLoader
header2index = {'words':3,'POS':4,'NER':10}
headers = ['words','POS']

if 'NER' in headers:
print('Warning: the NER labels read via load_conllized_ontonote are raw CoNLL tags, not BIOES, and are therefore wrong!')
indexes = list(map(lambda x:header2index[x],headers))

loader = ConllLoader(headers,indexes)

bundle = loader.load(path)

# print(bundle.datasets)

train_set = bundle.datasets['train']
dev_set = bundle.datasets['dev']
test_set = bundle.datasets['test']




# train_set = loader.load(os.path.join(path,'train.txt'))
# dev_set = loader.load(os.path.join(path, 'dev.txt'))
# test_set = loader.load(os.path.join(path, 'test.txt'))

# print(len(train_set))

train_set.add_seq_len('words','seq_len')
dev_set.add_seq_len('words','seq_len')
test_set.add_seq_len('words','seq_len')



# print(dataset['POS'])

vocab = Vocabulary(min_freq=1)
vocab.from_dataset(train_set,field_name='words')
vocab.from_dataset(dev_set, field_name='words')
vocab.from_dataset(test_set, field_name='words')

vocab.index_dataset(train_set,field_name='words')
vocab.index_dataset(dev_set, field_name='words')
vocab.index_dataset(test_set, field_name='words')




label_vocab_dict = {}

for i,h in enumerate(headers):
if h == 'words':
continue
label_vocab_dict[h] = Vocabulary(min_freq=1,padding=None,unknown=None)
label_vocab_dict[h].from_dataset(train_set,field_name=h)

label_vocab_dict[h].index_dataset(train_set,field_name=h)
label_vocab_dict[h].index_dataset(dev_set,field_name=h)
label_vocab_dict[h].index_dataset(test_set,field_name=h)

train_set.set_input(Const.INPUT, Const.INPUT_LEN)
train_set.set_target(headers[1])

dev_set.set_input(Const.INPUT, Const.INPUT_LEN)
dev_set.set_target(headers[1])

test_set.set_input(Const.INPUT, Const.INPUT_LEN)
test_set.set_target(headers[1])

if len(headers) > 2:
print('Warning: since there is more than one task, the target field has to be set manually each time!')


print('train:',len(train_set),'dev:',len(dev_set),'test:',len(test_set))

if embedding_path is not None:
pretrained_embedding = load_word_emb(embedding_path, 300, vocab)
return (train_set,dev_set,test_set),(vocab,label_vocab_dict),pretrained_embedding
else:
return (train_set, dev_set, test_set), (vocab, label_vocab_dict)


@cache_results(_cache_fp='OntonotesNER', _refresh=False)
def load_conllized_ontonote_NER(path,embedding_path=None):
from fastNLP.io.pipe.conll import OntoNotesNERPipe
ontoNotesNERPipe = OntoNotesNERPipe(lower=True,target_pad_val=-100)
bundle_NER = ontoNotesNERPipe.process_from_file(path)

train_set_NER = bundle_NER.datasets['train']
dev_set_NER = bundle_NER.datasets['dev']
test_set_NER = bundle_NER.datasets['test']

train_set_NER.add_seq_len('words','seq_len')
dev_set_NER.add_seq_len('words','seq_len')
test_set_NER.add_seq_len('words','seq_len')


NER_vocab = bundle_NER.get_vocab('target')
word_vocab = bundle_NER.get_vocab('words')

if embedding_path is not None:

embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
dropout=0.5,lower=True)


# pretrained_embedding = load_word_emb(embedding_path, 300, word_vocab)
return (train_set_NER,dev_set_NER,test_set_NER),\
(word_vocab,NER_vocab),embed
else:
return (train_set_NER, dev_set_NER, test_set_NER), (NER_vocab, word_vocab)

@cache_results(_cache_fp='OntonotesPOSNER', _refresh=False)

def load_conllized_ontonote_NER_POS(path,embedding_path=None):
from fastNLP.io.pipe.conll import OntoNotesNERPipe
ontoNotesNERPipe = OntoNotesNERPipe(lower=True)
bundle_NER = ontoNotesNERPipe.process_from_file(path)

train_set_NER = bundle_NER.datasets['train']
dev_set_NER = bundle_NER.datasets['dev']
test_set_NER = bundle_NER.datasets['test']

NER_vocab = bundle_NER.get_vocab('target')
word_vocab = bundle_NER.get_vocab('words')

(train_set_POS,dev_set_POS,test_set_POS),(_,POS_vocab) = load_conllized_ontonote_POS(path)
POS_vocab = POS_vocab['POS']

train_set_NER.add_field('pos',train_set_POS['POS'],is_target=True)
dev_set_NER.add_field('pos', dev_set_POS['POS'], is_target=True)
test_set_NER.add_field('pos', test_set_POS['POS'], is_target=True)

if train_set_NER.has_field('target'):
train_set_NER.rename_field('target','ner')

if dev_set_NER.has_field('target'):
dev_set_NER.rename_field('target','ner')

if test_set_NER.has_field('target'):
test_set_NER.rename_field('target','ner')



if train_set_NER.has_field('pos'):
train_set_NER.rename_field('pos','posid')
if dev_set_NER.has_field('pos'):
dev_set_NER.rename_field('pos','posid')
if test_set_NER.has_field('pos'):
test_set_NER.rename_field('pos','posid')

if train_set_NER.has_field('ner'):
train_set_NER.rename_field('ner','nerid')
if dev_set_NER.has_field('ner'):
dev_set_NER.rename_field('ner','nerid')
if test_set_NER.has_field('ner'):
test_set_NER.rename_field('ner','nerid')

if embedding_path is not None:

embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
dropout=0.5,lower=True)

return (train_set_NER,dev_set_NER,test_set_NER),\
(word_vocab,POS_vocab,NER_vocab),embed
else:
return (train_set_NER, dev_set_NER, test_set_NER), (NER_vocab, word_vocab)

@cache_results(_cache_fp='Ontonotes3', _refresh=True)
def load_conllized_ontonote_pkl(path,embedding_path=None):

data_bundle = pickle.load(open(path,'rb'))
train_set = data_bundle.datasets['train']
dev_set = data_bundle.datasets['dev']
test_set = data_bundle.datasets['test']

train_set.rename_field('pos','posid')
train_set.rename_field('ner','nerid')
train_set.rename_field('chunk','chunkid')

dev_set.rename_field('pos','posid')
dev_set.rename_field('ner','nerid')
dev_set.rename_field('chunk','chunkid')

test_set.rename_field('pos','posid')
test_set.rename_field('ner','nerid')
test_set.rename_field('chunk','chunkid')


word_vocab = data_bundle.vocabs['words']
pos_vocab = data_bundle.vocabs['pos']
ner_vocab = data_bundle.vocabs['ner']
chunk_vocab = data_bundle.vocabs['chunk']


if embedding_path is not None:

embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
dropout=0.5,lower=True)

return (train_set,dev_set,test_set),\
(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed
else:
return (train_set, dev_set, test_set), (word_vocab,ner_vocab)
# print(data_bundle)










# @cache_results(_cache_fp='Conll2003', _refresh=False)
# def load_conll_2003(path,embedding_path=None):
# f = open(path, 'rb')
# data_pkl = pickle.load(f)
#
# task_lst = data_pkl['task_lst']
# vocabs = data_pkl['vocabs']
# # word_vocab = vocabs['words']
# # pos_vocab = vocabs['pos']
# # chunk_vocab = vocabs['chunk']
# # ner_vocab = vocabs['ner']
#
# if embedding_path is not None:
# embed = StaticEmbedding(vocab=vocabs['words'], model_dir_or_name=embedding_path, word_dropout=0.01,
# dropout=0.5)
# return task_lst,vocabs,embed
# else:
# return task_lst,vocabs

# @cache_results(_cache_fp='Conll2003_mine', _refresh=False)
@cache_results(_cache_fp='Conll2003_mine_embed_100', _refresh=True)
def load_conll_2003_mine(path,embedding_path=None,pad_val=-100):
f = open(path, 'rb')

data_pkl = pickle.load(f)
# print(data_pkl)
# print(data_pkl)
train_set = data_pkl[0]['train']
dev_set = data_pkl[0]['dev']
test_set = data_pkl[0]['test']

train_set.set_pad_val('posid',pad_val)
train_set.set_pad_val('nerid', pad_val)
train_set.set_pad_val('chunkid', pad_val)

dev_set.set_pad_val('posid',pad_val)
dev_set.set_pad_val('nerid', pad_val)
dev_set.set_pad_val('chunkid', pad_val)

test_set.set_pad_val('posid',pad_val)
test_set.set_pad_val('nerid', pad_val)
test_set.set_pad_val('chunkid', pad_val)

if train_set.has_field('task_id'):

train_set.delete_field('task_id')

if dev_set.has_field('task_id'):
dev_set.delete_field('task_id')

if test_set.has_field('task_id'):
test_set.delete_field('task_id')

if train_set.has_field('words_idx'):
train_set.rename_field('words_idx','words')

if dev_set.has_field('words_idx'):
dev_set.rename_field('words_idx','words')

if test_set.has_field('words_idx'):
test_set.rename_field('words_idx','words')



word_vocab = data_pkl[1]['words']
pos_vocab = data_pkl[1]['pos']
ner_vocab = data_pkl[1]['ner']
chunk_vocab = data_pkl[1]['chunk']

if embedding_path is not None:
embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
dropout=0.5,lower=True)
return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed
else:
return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab)


def load_conllized_ontonote_pkl_yf(path):
def init_task(task):
task_name = task.task_name
for ds in [task.train_set, task.dev_set, task.test_set]:
if ds.has_field('words'):
ds.rename_field('words', 'x')
else:
ds.rename_field('words_idx', 'x')
if ds.has_field('label'):
ds.rename_field('label', 'y')
else:
ds.rename_field(task_name, 'y')
ds.set_input('x', 'y', 'task_id')
ds.set_target('y')

if task_name in ['ner', 'chunk'] or 'pos' in task_name:
ds.set_input('seq_len')
ds.set_target('seq_len')
return task
#/remote-home/yfshao/workdir/datasets/conll03/data.pkl
def pload(fn):
with open(fn, 'rb') as f:
return pickle.load(f)

DB = pload(path)
task_lst = DB['task_lst']
vocabs = DB['vocabs']
task_lst = [init_task(task) for task in task_lst]

return task_lst, vocabs


@cache_results(_cache_fp='weiboNER uni+bi', _refresh=False)
def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
normlize={'char':True,'bigram':True,'word':False}):
from fastNLP.io.data_loader import ConllLoader
from utils import get_bigrams

loader = ConllLoader(['chars','target'])
bundle = loader.load(path)

datasets = bundle.datasets
for k,v in datasets.items():
print('{}:{}'.format(k,len(v)))
# print(*list(datasets.keys()))
vocabs = {}
word_vocab = Vocabulary()
bigram_vocab = Vocabulary()
label_vocab = Vocabulary(padding=None,unknown=None)

for k,v in datasets.items():
# ignore the word segmentation tag
v.apply_field(lambda x: [w[0] for w in x],'chars','chars')
v.apply_field(get_bigrams,'chars','bigrams')


word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']])
label_vocab.from_dataset(datasets['train'],field_name='target')
print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word))


for k,v in datasets.items():
# v.set_pad_val('target',-100)
v.add_seq_len('chars',new_field_name='seq_len')


vocabs['char'] = word_vocab
vocabs['label'] = label_vocab


bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']])
if index_token:
word_vocab.index_dataset(*list(datasets.values()), field_name='chars', new_field_name='chars')
bigram_vocab.index_dataset(*list(datasets.values()),field_name='bigrams',new_field_name='bigrams')
label_vocab.index_dataset(*list(datasets.values()), field_name='target', new_field_name='target')

# for k,v in datasets.items():
# v.set_input('chars','bigrams','seq_len','target')
# v.set_target('target','seq_len')

vocabs['bigram'] = bigram_vocab

embeddings = {}

if unigram_embedding_path is not None:
unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path,
word_dropout=0.01,normalize=normlize['char'])
embeddings['char'] = unigram_embedding

if bigram_embedding_path is not None:
bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path,
word_dropout=0.01,normalize=normlize['bigram'])
embeddings['bigram'] = bigram_embedding

return datasets, vocabs, embeddings
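
As a toy illustration of the two apply_field calls in load_weibo_ner above: raw Weibo tokens carry a trailing word-segmentation tag that is stripped before bigrams are built. The raw tokens below are made up, and the bigram line is only a rough stand-in for get_bigrams, which lives in utils.py and is not part of this diff:

raw = ['科0', '技1', '全0', '方1', '位1']        # hypothetical raw Weibo tokens
chars = [w[0] for w in raw]                      # ['科', '技', '全', '方', '位']
bigrams = [a + b for a, b in zip(chars, chars[1:] + ['<eos>'])]  # rough stand-in for get_bigrams
print(chars)
print(bigrams)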



# datasets,vocabs = load_weibo_ner('/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo')
#
# print(datasets['train'][:5])
# print(vocabs['word'].idx2word)
# print(vocabs['target'].idx2word)


@cache_results(_cache_fp='cache/ontonotes4ner',_refresh=False)
def load_ontonotes4ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True,
normalize={'char':True,'bigram':True,'word':False}):
from fastNLP.io.data_loader import ConllLoader
from utils import get_bigrams

train_path = os.path.join(path,'train.char.bmes')
dev_path = os.path.join(path,'dev.char.bmes')
test_path = os.path.join(path,'test.char.bmes')

loader = ConllLoader(['chars','target'])
train_bundle = loader.load(train_path)
dev_bundle = loader.load(dev_path)
test_bundle = loader.load(test_path)


datasets = dict()
datasets['train'] = train_bundle.datasets['train']
datasets['dev'] = dev_bundle.datasets['train']
datasets['test'] = test_bundle.datasets['train']


datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams')
datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')
datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')

datasets['train'].add_seq_len('chars')
datasets['dev'].add_seq_len('chars')
datasets['test'].add_seq_len('chars')



char_vocab = Vocabulary()
bigram_vocab = Vocabulary()
label_vocab = Vocabulary(padding=None,unknown=None)
print(datasets.keys())
print(len(datasets['dev']))
print(len(datasets['test']))
print(len(datasets['train']))
char_vocab.from_dataset(datasets['train'],field_name='chars',
no_create_entry_dataset=[datasets['dev'],datasets['test']] )
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',
no_create_entry_dataset=[datasets['dev'],datasets['test']])
label_vocab.from_dataset(datasets['train'],field_name='target')
if index_token:
char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
field_name='chars',new_field_name='chars')
bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
field_name='bigrams',new_field_name='bigrams')
label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
field_name='target',new_field_name='target')

vocabs = {}
vocabs['char'] = char_vocab
vocabs['bigram'] = bigram_vocab
vocabs['label'] = label_vocab

embeddings = {}
if char_embedding_path is not None:
char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01,
normalize=normalize['char'])
embeddings['char'] = char_embedding

if bigram_embedding_path is not None:
bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01,
normalize=normalize['bigram'])
embeddings['bigram'] = bigram_embedding

return datasets,vocabs,embeddings



@cache_results(_cache_fp='cache/resume_ner',_refresh=False)
def load_resume_ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True,
normalize={'char':True,'bigram':True,'word':False}):
from fastNLP.io.data_loader import ConllLoader
from utils import get_bigrams

train_path = os.path.join(path,'train.char.bmes')
dev_path = os.path.join(path,'dev.char.bmes')
test_path = os.path.join(path,'test.char.bmes')

loader = ConllLoader(['chars','target'])
train_bundle = loader.load(train_path)
dev_bundle = loader.load(dev_path)
test_bundle = loader.load(test_path)


datasets = dict()
datasets['train'] = train_bundle.datasets['train']
datasets['dev'] = dev_bundle.datasets['train']
datasets['test'] = test_bundle.datasets['train']


datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams')
datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')
datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')

datasets['train'].add_seq_len('chars')
datasets['dev'].add_seq_len('chars')
datasets['test'].add_seq_len('chars')



char_vocab = Vocabulary()
bigram_vocab = Vocabulary()
label_vocab = Vocabulary(padding=None,unknown=None)
print(datasets.keys())
print(len(datasets['dev']))
print(len(datasets['test']))
print(len(datasets['train']))
char_vocab.from_dataset(datasets['train'],field_name='chars',
no_create_entry_dataset=[datasets['dev'],datasets['test']] )
bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',
no_create_entry_dataset=[datasets['dev'],datasets['test']])
label_vocab.from_dataset(datasets['train'],field_name='target')
if index_token:
char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
field_name='chars',new_field_name='chars')
bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
field_name='bigrams',new_field_name='bigrams')
label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
field_name='target',new_field_name='target')

vocabs = {}
vocabs['char'] = char_vocab
vocabs['label'] = label_vocab
vocabs['bigram'] = bigram_vocab

embeddings = {}
if char_embedding_path is not None:
char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01,normalize=normalize['char'])
embeddings['char'] = char_embedding

if bigram_embedding_path is not None:
bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01,normalize=normalize['bigram'])
embeddings['bigram'] = bigram_embedding

return datasets,vocabs,embeddings


@cache_results(_cache_fp='need_to_defined_fp',_refresh=False)
def equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,word_embedding_path=None,
normalize={'char':True,'bigram':True,'word':False}):
from utils_ import Trie,get_skip_path
from functools import partial
w_trie = Trie()
for w in w_list:
w_trie.insert(w)

# for k,v in datasets.items():
# v.apply_field(partial(get_skip_path,w_trie=w_trie),'chars','skips')

def skips2skips_l2r(chars,w_trie):
'''
Collect, for each character position, the lexicon words that end there.
(A toy illustration of the resulting fields follows equip_chinese_ner_with_skip below.)

:param chars: list of characters; lexicon matches are looked up in w_trie as list[[start,end,word]]
:param w_trie: Trie built from the pretrained word list
:return: skips_l2r, one list per position of [start_index, word] pairs for words ending at that position
'''
# print(lexicons)
# print('******')

lexicons = get_skip_path(chars,w_trie=w_trie)


# max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0

result = [[] for _ in range(len(chars))]

for lex in lexicons:
s = lex[0]
e = lex[1]
w = lex[2]

result[e].append([s,w])

return result

def skips2skips_r2l(chars,w_trie):
'''
Collect, for each character position, the lexicon words that start there.

:param chars: list of characters; lexicon matches are looked up in w_trie as list[[start,end,word]]
:param w_trie: Trie built from the pretrained word list
:return: skips_r2l, one list per position of [end_index, word] pairs for words starting at that position
'''
# print(lexicons)
# print('******')

lexicons = get_skip_path(chars,w_trie=w_trie)


# max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0

result = [[] for _ in range(len(chars))]

for lex in lexicons:
s = lex[0]
e = lex[1]
w = lex[2]

result[s].append([e,w])

return result

for k,v in datasets.items():
v.apply_field(partial(skips2skips_l2r,w_trie=w_trie),'chars','skips_l2r')

for k,v in datasets.items():
v.apply_field(partial(skips2skips_r2l,w_trie=w_trie),'chars','skips_r2l')

# print(v['skips_l2r'][0])
word_vocab = Vocabulary()
word_vocab.add_word_lst(w_list)
vocabs['word'] = word_vocab
for k,v in datasets.items():
v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_l2r','skips_l2r_source')
v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_l2r', 'skips_l2r_word')

for k,v in datasets.items():
v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_r2l','skips_r2l_source')
v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_r2l', 'skips_r2l_word')

for k,v in datasets.items():
v.apply_field(lambda x:list(map(len,x)), 'skips_l2r_word', 'lexicon_count')
v.apply_field(lambda x:
list(map(lambda y:
list(map(lambda z:word_vocab.to_index(z),y)),x)),
'skips_l2r_word',new_field_name='skips_l2r_word')

v.apply_field(lambda x:list(map(len,x)), 'skips_r2l_word', 'lexicon_count_back')

v.apply_field(lambda x:
list(map(lambda y:
list(map(lambda z:word_vocab.to_index(z),y)),x)),
'skips_r2l_word',new_field_name='skips_r2l_word')





if word_embedding_path is not None:
word_embedding = StaticEmbedding(word_vocab,word_embedding_path,word_dropout=0,normalize=normalize['word'])
embeddings['word'] = word_embedding

vocabs['char'].index_dataset(datasets['train'], datasets['dev'], datasets['test'],
field_name='chars', new_field_name='chars')
vocabs['bigram'].index_dataset(datasets['train'], datasets['dev'], datasets['test'],
field_name='bigrams', new_field_name='bigrams')
vocabs['label'].index_dataset(datasets['train'], datasets['dev'], datasets['test'],
field_name='target', new_field_name='target')

return datasets,vocabs,embeddings
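
A self-contained toy illustration of the skip fields built above, with hard-coded lexicon matches standing in for get_skip_path/Trie (the sentence and matches are just an example):

chars = list('南京市长江大桥')
lexicons = [[0, 1, '南京'], [0, 2, '南京市'], [2, 3, '市长'],
            [3, 4, '长江'], [3, 6, '长江大桥'], [5, 6, '大桥']]   # [start, end, word]

skips_l2r = [[] for _ in range(len(chars))]   # words ending at each position
skips_r2l = [[] for _ in range(len(chars))]   # words starting at each position
for s, e, w in lexicons:
    skips_l2r[e].append([s, w])
    skips_r2l[s].append([e, w])

skips_l2r_source = [[p[0] for p in pos] for pos in skips_l2r]   # start indices
skips_l2r_word = [[p[1] for p in pos] for pos in skips_l2r]     # the words themselves
lexicon_count = [len(pos) for pos in skips_l2r]                 # words ending at each position

print(skips_l2r_source)   # [[], [0], [0], [2], [3], [], [3, 5]]
print(lexicon_count)      # [0, 1, 1, 1, 1, 0, 2]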



@cache_results(_cache_fp='cache/load_yangjie_rich_pretrain_word_list',_refresh=False)
def load_yangjie_rich_pretrain_word_list(embedding_path,drop_characters=True):
f = open(embedding_path,'r')
lines = f.readlines()
w_list = []
for line in lines:
splited = line.strip().split(' ')
w = splited[0]
w_list.append(w)

if drop_characters:
w_list = list(filter(lambda x:len(x) != 1, w_list))

return w_list
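
A toy illustration of the file format this loader expects (one token followed by its vector per line); with drop_characters=True, single-character entries are filtered out, presumably because characters are already handled by the unigram embedding. The sample lines are made up:

sample = '中国 0.1 0.2 0.3\n中 0.4 0.5 0.6\n北京 0.7 0.8 0.9\n'
w_list = [line.strip().split(' ')[0] for line in sample.strip().split('\n')]
w_list = list(filter(lambda x: len(x) != 1, w_list))
print(w_list)   # ['中国', '北京']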



# from pathes import *
#
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,
# yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path)
# print(datasets.keys())
# print(vocabs.keys())
# print(embeddings)
# yangjie_rich_pretrain_word_path
# datasets['train'].set_pad_val

+189 -0   reproduction/seqence_labelling/chinese_ner/LatticeLSTM/main.py

@@ -0,0 +1,189 @@
import torch.nn as nn
# from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,load_resume_ner,load_weibo_ner
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ
from fastNLP import Tester
import fitlog
from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cuda:4')
parser.add_argument('--debug',default=False)

parser.add_argument('--norm_embed',default=True)
parser.add_argument('--batch',default=10)
parser.add_argument('--test_batch',default=1024)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.045)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False)  # in the paper it's False
parser.add_argument('--hidden',default=100)
parser.add_argument('--momentum',default=0)
parser.add_argument('--bi',default=True)
parser.add_argument('--dataset',default='ontonote',help='resume|ontonote|weibo|msra')
parser.add_argument('--use_bigram',default=True)

parser.add_argument('--embed_dropout',default=0.5)
parser.add_argument('--output_dropout',default=0.5)
parser.add_argument('--epoch',default=100)
parser.add_argument('--seed',default=100)

args = parser.parse_args()

set_seed(args.seed)

fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)]
if args.model == 'lattice':
fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
fitlog.commit(__file__,fit_msg=fit_msg)


fitlog.add_hyper(args)
device = torch.device(args.device)
for k,v in args.__dict__.items():
print(k,v)

refresh_data = False


from pathes import *
# ontonote4ner_cn_path = 0
# yangjie_rich_pretrain_unigram_path = 0
# yangjie_rich_pretrain_bigram_path = 0
# resume_ner_path = 0
# weibo_ner_path = 0

if args.dataset == 'ontonote':
datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
_refresh=refresh_data,index_token=False,
)
elif args.dataset == 'resume':
datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
_refresh=refresh_data,index_token=False,
)
elif args.dataset == 'weibo':
datasets,vocabs,embeddings = load_weibo_ner(weibo_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
_refresh=refresh_data,index_token=False,
)

if args.dataset == 'ontonote':
args.batch = 10
args.lr = 0.045
elif args.dataset == 'resume':
args.batch = 1
args.lr = 0.015
elif args.dataset == 'weibo':
args.embed_dropout = 0.1
args.output_dropout = 0.1

w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
_refresh=refresh_data)

cache_name = os.path.join('cache',args.dataset+'_lattice')
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
_refresh=refresh_data,_cache_fp=cache_name)

print(datasets['train'][0])
print('vocab info:')
for k,v in vocabs.items():
print('{}:{}'.format(k,len(v)))

for k,v in datasets.items():
if args.model == 'lattice':
v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source')
if args.skip_before_head:
v.set_padder('skips_l2r_word',LatticeLexiconPadder())
v.set_padder('skips_l2r_source',LatticeLexiconPadder())
v.set_padder('skips_r2l_word',LatticeLexiconPadder())
v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True))
else:
v.set_padder('skips_l2r_word',LatticeLexiconPadder())
v.set_padder('skips_r2l_word', LatticeLexiconPadder())
v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1))
v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1))
if args.bi:
v.set_input('chars','bigrams','seq_len',
'skips_l2r_word','skips_l2r_source','lexicon_count',
'skips_r2l_word', 'skips_r2l_source','lexicon_count_back',
'target',
use_1st_ins_infer_dim_type=True)
else:
v.set_input('chars','bigrams','seq_len',
'skips_l2r_word','skips_l2r_source','lexicon_count',
'target',
use_1st_ins_infer_dim_type=True)
v.set_target('target','seq_len')

v['target'].set_pad_val(0)
elif args.model == 'lstm':
v.set_ignore_type('skips_l2r_word','skips_l2r_source')
v.set_padder('skips_l2r_word',LatticeLexiconPadder())
v.set_padder('skips_l2r_source',LatticeLexiconPadder())
v.set_input('chars','bigrams','seq_len','target',
use_1st_ins_infer_dim_type=True)
v.set_target('target','seq_len')

v['target'].set_pad_val(0)

print(datasets['dev']['skips_l2r_word'][100])


if args.model =='lattice':
model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'],
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
skip_before_head=args.skip_before_head,use_bigram=args.use_bigram
)
elif args.model == 'lstm':
model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
bidirectional=args.bi,
embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
use_bigram=args.use_bigram)


loss = LossInForward()

f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmeso')
f1_metric_yj = SpanFPreRecMetric_YJ(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type='bmesoyj')
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,f1_metric_yj,acc_metric]

if args.optim == 'adam':
optimizer = optim.Adam(model.parameters(),lr=args.lr)
elif args.optim == 'sgd':
optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum)




callbacks = [
FitlogCallback({'test':datasets['test'],'train':datasets['train']}),
LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep))
]

trainer = Trainer(datasets['train'],model,
optimizer=optimizer,
loss=loss,
metrics=metrics,
dev_data=datasets['dev'],
device=device,
batch_size=args.batch,
n_epochs=args.epoch,
dev_batch_size=args.test_batch,
callbacks=callbacks)

trainer.train()

+299 -0   reproduction/seqence_labelling/chinese_ner/LatticeLSTM/models.py

@@ -0,0 +1,299 @@
import torch.nn as nn
from fastNLP.embeddings import StaticEmbedding
from fastNLP.modules import LSTM, ConditionalRandomField
import torch
from fastNLP import seq_len_to_mask
from utils import better_init_rnn


class LatticeLSTM_SeqLabel(nn.Module):
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False,
device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False,
skip_before_head=False,use_bigram=True,vocabs=None):
if device is None:
self.device = torch.device('cpu')
else:
self.device = torch.device(device)
from modules import LatticeLSTMLayer_sup_back_V0
super().__init__()
self.debug = debug
self.skip_batch_first = skip_batch_first
self.char_embed_size = char_embed.embedding.weight.size(1)
self.bigram_embed_size = bigram_embed.embedding.weight.size(1)
self.word_embed_size = word_embed.embedding.weight.size(1)
self.hidden_size = hidden_size
self.label_size = label_size
self.bidirectional = bidirectional
self.use_bigram = use_bigram
self.vocabs = vocabs

if self.use_bigram:
self.input_size = self.char_embed_size + self.bigram_embed_size
else:
self.input_size = self.char_embed_size

self.char_embed = char_embed
self.bigram_embed = bigram_embed
self.word_embed = word_embed
self.encoder = LatticeLSTMLayer_sup_back_V0(self.input_size,self.word_embed_size,
self.hidden_size,
left2right=True,
bias=bias,
device=self.device,
debug=self.debug,
skip_before_head=skip_before_head)
if self.bidirectional:
self.encoder_back = LatticeLSTMLayer_sup_back_V0(self.input_size,
self.word_embed_size, self.hidden_size,
left2right=False,
bias=bias,
device=self.device,
debug=self.debug,
skip_before_head=skip_before_head)

self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)
self.crf = ConditionalRandomField(label_size, True)

self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True))
if self.crf.include_start_end_trans:
self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))
self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))

self.loss_func = nn.CrossEntropyLoss()
self.embed_dropout = nn.Dropout(embed_dropout)
self.output_dropout = nn.Dropout(output_dropout)

def forward(self, chars, bigrams, seq_len, target,
skips_l2r_source, skips_l2r_word, lexicon_count,
skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None):
# print('skips_l2r_word_id:{}'.format(skips_l2r_word.size()))
batch = chars.size(0)
max_seq_len = chars.size(1)
# max_lexicon_count = skips_l2r_word.size(2)


embed_char = self.char_embed(chars)
if self.use_bigram:

embed_bigram = self.bigram_embed(bigrams)

embedding = torch.cat([embed_char, embed_bigram], dim=-1)
else:

embedding = embed_char


embed_nonword = self.embed_dropout(embedding)

# skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1])
embed_word = self.word_embed(skips_l2r_word)
embed_word = self.embed_dropout(embed_word)
# embed_word = torch.reshape(embed_word,shape=[batch,max_seq_len,max_lexicon_count,-1])


encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count)

if self.bidirectional:
embed_word_back = self.word_embed(skips_r2l_word)
embed_word_back = self.embed_dropout(embed_word_back)
encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source,
embed_word_back, lexicon_count_back)
encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1)

encoded_h = self.output_dropout(encoded_h)

pred = self.output(encoded_h)

mask = seq_len_to_mask(seq_len)

if self.training:
loss = self.crf(pred, target, mask)
return {'loss': loss}
else:
pred, path = self.crf.viterbi_decode(pred, mask)
return {'pred': pred}

# batch_size, sent_len = pred.shape[0], pred.shape[1]
# loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
# return {'pred':pred,'loss':loss}

class LatticeLSTM_SeqLabel_V1(nn.Module):
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False,
device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False,
skip_before_head=False,use_bigram=True,vocabs=None):
if device is None:
self.device = torch.device('cpu')
else:
self.device = torch.device(device)
from modules import LatticeLSTMLayer_sup_back_V1
super().__init__()
self.count = 0
self.debug = debug
self.skip_batch_first = skip_batch_first
self.char_embed_size = char_embed.embedding.weight.size(1)
self.bigram_embed_size = bigram_embed.embedding.weight.size(1)
self.word_embed_size = word_embed.embedding.weight.size(1)
self.hidden_size = hidden_size
self.label_size = label_size
self.bidirectional = bidirectional
self.use_bigram = use_bigram
self.vocabs = vocabs

if self.use_bigram:
self.input_size = self.char_embed_size + self.bigram_embed_size
else:
self.input_size = self.char_embed_size

self.char_embed = char_embed
self.bigram_embed = bigram_embed
self.word_embed = word_embed
self.encoder = LatticeLSTMLayer_sup_back_V1(self.input_size,self.word_embed_size,
self.hidden_size,
left2right=True,
bias=bias,
device=self.device,
debug=self.debug,
skip_before_head=skip_before_head)
if self.bidirectional:
self.encoder_back = LatticeLSTMLayer_sup_back_V1(self.input_size,
self.word_embed_size, self.hidden_size,
left2right=False,
bias=bias,
device=self.device,
debug=self.debug,
skip_before_head=skip_before_head)

self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)
self.crf = ConditionalRandomField(label_size, True)

self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True))
if self.crf.include_start_end_trans:
self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))
self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))

self.loss_func = nn.CrossEntropyLoss()
self.embed_dropout = nn.Dropout(embed_dropout)
self.output_dropout = nn.Dropout(output_dropout)

def forward(self, chars, bigrams, seq_len, target,
skips_l2r_source, skips_l2r_word, lexicon_count,
skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None):

batch = chars.size(0)
max_seq_len = chars.size(1)



embed_char = self.char_embed(chars)
if self.use_bigram:

embed_bigram = self.bigram_embed(bigrams)

embedding = torch.cat([embed_char, embed_bigram], dim=-1)
else:

embedding = embed_char


embed_nonword = self.embed_dropout(embedding)

# skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1])
embed_word = self.word_embed(skips_l2r_word)
embed_word = self.embed_dropout(embed_word)



encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count)

if self.bidirectional:
embed_word_back = self.word_embed(skips_r2l_word)
embed_word_back = self.embed_dropout(embed_word_back)
encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source,
embed_word_back, lexicon_count_back)
encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1)

encoded_h = self.output_dropout(encoded_h)

pred = self.output(encoded_h)

mask = seq_len_to_mask(seq_len)

if self.training:
loss = self.crf(pred, target, mask)
return {'loss': loss}
else:
pred, path = self.crf.viterbi_decode(pred, mask)
return {'pred': pred}


class LSTM_SeqLabel(nn.Module):
def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True,
bidirectional=False, device=None, embed_dropout=0, output_dropout=0,use_bigram=True):

if device is None:
self.device = torch.device('cpu')
else:
self.device = torch.device(device)
super().__init__()
self.char_embed_size = char_embed.embedding.weight.size(1)
self.bigram_embed_size = bigram_embed.embedding.weight.size(1)
self.word_embed_size = word_embed.embedding.weight.size(1)
self.hidden_size = hidden_size
self.label_size = label_size
self.bidirectional = bidirectional
self.use_bigram = use_bigram

self.char_embed = char_embed
self.bigram_embed = bigram_embed
self.word_embed = word_embed

if self.use_bigram:
self.input_size = self.char_embed_size + self.bigram_embed_size
else:
self.input_size = self.char_embed_size

self.encoder = LSTM(self.input_size, self.hidden_size,
bidirectional=self.bidirectional)

better_init_rnn(self.encoder.lstm)

self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)

self.debug = False
self.loss_func = nn.CrossEntropyLoss()
self.embed_dropout = nn.Dropout(embed_dropout)
self.output_dropout = nn.Dropout(output_dropout)
self.crf = ConditionalRandomField(label_size, True)

def forward(self, chars, bigrams, seq_len, target):
embed_char = self.char_embed(chars)

if self.use_bigram:

embed_bigram = self.bigram_embed(bigrams)

embedding = torch.cat([embed_char, embed_bigram], dim=-1)
else:

embedding = embed_char

embedding = self.embed_dropout(embedding)

encoded_h, encoded_c = self.encoder(embedding, seq_len)

encoded_h = self.output_dropout(encoded_h)

pred = self.output(encoded_h)

mask = seq_len_to_mask(seq_len)

# pred = self.crf(pred)

# batch_size, sent_len = pred.shape[0], pred.shape[1]
# loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
if self.training:
loss = self.crf(pred, target, mask)
return {'loss': loss}
else:
pred, path = self.crf.viterbi_decode(pred, mask)
return {'pred': pred}
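
All three models above share the same forward contract: in training mode they return {'loss': ...} (picked up by LossInForward), and in eval mode {'pred': ...} (consumed by SpanFPreRecMetric / AccuracyMetric). A tiny runnable toy with the same interface shape, not the real model:

import torch
import torch.nn as nn

class _DictOutToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 3)
    def forward(self, chars, target):
        logits = self.lin(chars)
        if self.training:
            return {'loss': nn.functional.cross_entropy(logits, target)}
        return {'pred': logits.argmax(dim=-1)}

toy = _DictOutToy()
x, y = torch.randn(5, 4), torch.randint(0, 3, (5,))
toy.train()
print(toy(x, y))   # {'loss': tensor(...)}
toy.eval()
print(toy(x, y))   # {'pred': tensor([...])}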

+638 -0   reproduction/seqence_labelling/chinese_ner/LatticeLSTM/modules.py

@@ -0,0 +1,638 @@
import torch.nn as nn
import torch
from fastNLP.core.utils import seq_len_to_mask
from utils import better_init_rnn
import numpy as np


class WordLSTMCell_yangjie(nn.Module):

"""A basic LSTM cell."""

def __init__(self, input_size, hidden_size, use_bias=True,debug=False, left2right=True):
"""
Most parts are copied from torch.nn.LSTMCell.
"""

super().__init__()
self.left2right = left2right
self.debug = debug
self.input_size = input_size
self.hidden_size = hidden_size
self.use_bias = use_bias
self.weight_ih = nn.Parameter(
torch.FloatTensor(input_size, 3 * hidden_size))
self.weight_hh = nn.Parameter(
torch.FloatTensor(hidden_size, 3 * hidden_size))
if use_bias:
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
else:
self.register_parameter('bias', None)
self.reset_parameters()

def reset_parameters(self):
"""
Initialize parameters following the way proposed in the paper.
"""
nn.init.orthogonal_(self.weight_ih.data)
weight_hh_data = torch.eye(self.hidden_size)
weight_hh_data = weight_hh_data.repeat(1, 3)
with torch.no_grad():
self.weight_hh.set_(weight_hh_data)
# The bias is just set to zero vectors.
if self.use_bias:
nn.init.constant_(self.bias.data, val=0)

def forward(self, input_, hx):
"""
Args:
input_: A (batch, input_size) tensor containing input
features.
hx: A tuple (h_0, c_0), which contains the initial hidden
and cell state, where the size of both states is
(batch, hidden_size).
Returns:
c_1: Tensor containing the next cell state. This word cell returns no hidden state;
the character cell later gates and merges the word cell states.
"""

h_0, c_0 = hx



batch_size = h_0.size(0)
bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
wi = torch.mm(input_, self.weight_ih)
f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)

return c_1

def __repr__(self):
s = '{name}({input_size}, {hidden_size})'
return s.format(name=self.__class__.__name__, **self.__dict__)
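
Usage sketch for the word cell above (random tensors, hypothetical sizes): it consumes a word embedding together with the hidden/cell state of the character at the word's start position, and returns only a candidate cell state for that word:

import torch
cell = WordLSTMCell_yangjie(input_size=50, hidden_size=100)
word_emb = torch.randn(4, 50)                     # 4 word embeddings
h0, c0 = torch.randn(4, 100), torch.randn(4, 100)
c_word = cell(word_emb, (h0, c0))                 # only a cell state comes back
print(c_word.shape)                               # torch.Size([4, 100])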


class MultiInputLSTMCell_V0(nn.Module):
def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False):
super().__init__()
self.char_input_size = char_input_size
self.hidden_size = hidden_size
self.use_bias = use_bias

self.weight_ih = nn.Parameter(
torch.FloatTensor(char_input_size, 3 * hidden_size)
)

self.weight_hh = nn.Parameter(
torch.FloatTensor(hidden_size, 3 * hidden_size)
)

self.alpha_weight_ih = nn.Parameter(
torch.FloatTensor(char_input_size, hidden_size)
)

self.alpha_weight_hh = nn.Parameter(
torch.FloatTensor(hidden_size, hidden_size)
)

if self.use_bias:
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
else:
self.register_parameter('bias', None)
self.register_parameter('alpha_bias', None)

self.debug = debug
self.reset_parameters()

def reset_parameters(self):
"""
Initialize parameters following the way proposed in the paper.
"""
nn.init.orthogonal_(self.weight_ih.data)
nn.init.orthogonal_(self.alpha_weight_ih.data)

weight_hh_data = torch.eye(self.hidden_size)
weight_hh_data = weight_hh_data.repeat(1, 3)
with torch.no_grad():
self.weight_hh.set_(weight_hh_data)

alpha_weight_hh_data = torch.eye(self.hidden_size)
alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1)
with torch.no_grad():
self.alpha_weight_hh.set_(alpha_weight_hh_data)

# The bias is just set to zero vectors.
if self.use_bias:
nn.init.constant_(self.bias.data, val=0)
nn.init.constant_(self.alpha_bias.data, val=0)

def forward(self, inp, skip_c, skip_count, hx):
'''

:param inp: chars B * hidden
:param skip_c: cell states obtained from the skip (word) edges, B * X * hidden
:param skip_count: number of skip edges at the current position for each example in this batch, used for masking
:param hx:
:return:
'''
max_skip_count = torch.max(skip_count).item()



if True:
h_0, c_0 = hx
batch_size = h_0.size(0)

bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))

wi = torch.matmul(inp, self.weight_ih)
wh = torch.matmul(h_0, self.weight_hh)



i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)

i = torch.sigmoid(i).unsqueeze(1)
o = torch.sigmoid(o).unsqueeze(1)
g = torch.tanh(g).unsqueeze(1)



alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
alpha_wi.unsqueeze_(1)

# alpha_wi = alpha_wi.expand(1,skip_count,self.hidden_size)
alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)

alpha_bias_batch = self.alpha_bias.unsqueeze(0)

alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)

skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1])

skip_mask = 1 - skip_mask


skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)

skip_mask = (skip_mask).float()*1e20

alpha = alpha - skip_mask

alpha = torch.exp(torch.cat([i, alpha], dim=1))



alpha_sum = torch.sum(alpha, dim=1, keepdim=True)

alpha = torch.div(alpha, alpha_sum)

merge_i_c = torch.cat([g, skip_c], dim=1)

c_1 = merge_i_c * alpha

c_1 = c_1.sum(1, keepdim=True)
# h_1 = o * c_1
h_1 = o * torch.tanh(c_1)

return h_1.squeeze(1), c_1.squeeze(1)

else:

h_0, c_0 = hx
batch_size = h_0.size(0)

bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))

wi = torch.matmul(inp, self.weight_ih)
wh = torch.matmul(h_0, self.weight_hh)

i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)

i = torch.sigmoid(i).unsqueeze(1)
o = torch.sigmoid(o).unsqueeze(1)
g = torch.tanh(g).unsqueeze(1)

c_1 = g
h_1 = o * c_1

return h_1,c_1
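
A standalone sketch of the merge that MultiInputLSTMCell_V0 performs above (shapes and values are made up; this is not the cell itself): the exp/sum over [input gate, word alphas] is a softmax, and the new cell state is the resulting convex combination of the char candidate g and the word cell states skip_c:

import torch
B, K, H = 2, 3, 4                                  # batch, max word count, hidden size
i = torch.rand(B, 1, H)                            # stands in for sigmoid(i)
g = torch.tanh(torch.randn(B, 1, H))               # char candidate cell
skip_c = torch.randn(B, K, H)                      # word cell states
alpha = torch.rand(B, K, H)                        # stands in for sigmoid(alpha_wi + alpha_wh + b)
skip_count = torch.tensor([3, 1])                  # valid words per example
mask = torch.arange(K).unsqueeze(0) < skip_count.unsqueeze(1)   # B x K validity mask
alpha = alpha.masked_fill(~mask.unsqueeze(-1), -1e20)           # same effect as the 1e20 trick above
weights = torch.softmax(torch.cat([i, alpha], dim=1), dim=1)    # B x (K+1) x H
c_1 = (torch.cat([g, skip_c], dim=1) * weights).sum(dim=1)      # B x H
print(c_1.shape)   # torch.Size([2, 4])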

class MultiInputLSTMCell_V1(nn.Module):
def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False):
super().__init__()
self.char_input_size = char_input_size
self.hidden_size = hidden_size
self.use_bias = use_bias

self.weight_ih = nn.Parameter(
torch.FloatTensor(char_input_size, 3 * hidden_size)
)

self.weight_hh = nn.Parameter(
torch.FloatTensor(hidden_size, 3 * hidden_size)
)

self.alpha_weight_ih = nn.Parameter(
torch.FloatTensor(char_input_size, hidden_size)
)

self.alpha_weight_hh = nn.Parameter(
torch.FloatTensor(hidden_size, hidden_size)
)

if self.use_bias:
self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
else:
self.register_parameter('bias', None)
self.register_parameter('alpha_bias', None)

self.debug = debug
self.reset_parameters()

def reset_parameters(self):
"""
Initialize parameters following the way proposed in the paper.
"""
nn.init.orthogonal_(self.weight_ih.data)
nn.init.orthogonal_(self.alpha_weight_ih.data)

weight_hh_data = torch.eye(self.hidden_size)
weight_hh_data = weight_hh_data.repeat(1, 3)
with torch.no_grad():
self.weight_hh.set_(weight_hh_data)

alpha_weight_hh_data = torch.eye(self.hidden_size)
alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1)
with torch.no_grad():
self.alpha_weight_hh.set_(alpha_weight_hh_data)

# The bias is just set to zero vectors.
if self.use_bias:
nn.init.constant_(self.bias.data, val=0)
nn.init.constant_(self.alpha_bias.data, val=0)

def forward(self, inp, skip_c, skip_count, hx):
'''

:param inp: char input, batch_size * hidden
:param skip_c: cell states produced by the word (skip-edge) cells ending at this position, batch_size * X * hidden
:param skip_count: number of skip edges (matched lexicon words) at this position for each example in the batch, used for masking
:param hx: (h_0, c_0) of the character-level recurrence
:return:
'''
max_skip_count = torch.max(skip_count).item()



if True:
h_0, c_0 = hx
batch_size = h_0.size(0)

bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))

wi = torch.matmul(inp, self.weight_ih)
wh = torch.matmul(h_0, self.weight_hh)


i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)

i = torch.sigmoid(i).unsqueeze(1)
o = torch.sigmoid(o).unsqueeze(1)
g = torch.tanh(g).unsqueeze(1)



##basic lstm start

f = 1 - i
c_1_basic = f*c_0.unsqueeze(1) + i*g
c_1_basic = c_1_basic.squeeze(1)





alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
alpha_wi.unsqueeze_(1)


alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)

alpha_bias_batch = self.alpha_bias.unsqueeze(0)

alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)

skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1])

skip_mask = 1 - skip_mask


skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)

skip_mask = (skip_mask).float()*1e20

alpha = alpha - skip_mask

alpha = torch.exp(torch.cat([i, alpha], dim=1))



alpha_sum = torch.sum(alpha, dim=1, keepdim=True)

alpha = torch.div(alpha, alpha_sum)

merge_i_c = torch.cat([g, skip_c], dim=1)

c_1 = merge_i_c * alpha

c_1 = c_1.sum(1, keepdim=True)
# h_1 = o * c_1
c_1 = c_1.squeeze(1)
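# The lines below blend the two cell states: positions with at least one matched lexicon word keep the
# lattice cell state c_1, while positions with no match fall back to the plain coupled-gate LSTM state
# c_1_basic computed above.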
count_select = (skip_count != 0).float().unsqueeze(-1)




c_1 = c_1*count_select + c_1_basic*(1-count_select)


o = o.squeeze(1)
h_1 = o * torch.tanh(c_1)

return h_1, c_1

class LatticeLSTMLayer_sup_back_V0(nn.Module):
def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
bias=True,device=None,debug=False,skip_before_head=False):
super().__init__()

self.skip_before_head = skip_before_head

self.hidden_size = hidden_size

self.debug = debug

self.char_cell = MultiInputLSTMCell_V0(char_input_size, hidden_size, bias, debug)

self.word_cell = WordLSTMCell_yangjie(word_input_size, hidden_size, bias, debug=self.debug)

self.word_input_size = word_input_size
self.left2right = left2right
self.bias = bias
self.device = device

def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
'''

:param inp: batch * seq_len * embedding, chars
:param seq_len: batch, length of chars
:param skip_sources: batch * seq_len * X, start positions of the skip edges (matched lexicon words)
:param skip_words: batch * seq_len * X * embedding, embeddings of the skip-edge words
:param skip_count: batch * seq_len, number of matched lexicon words per example per position
:param init_state: the hx of the rnn
:return:
'''
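# Note on indexing in the left-to-right branch (inferred from the code, stated here as a reading aid):
# h_ and c_ keep an all-zero initial state in column 0, so after step j the state of character j is stored
# in column j + 1. With skip_before_head=False the word cell therefore reads the state of the word's first
# character (index_1 + 1); with skip_before_head=True it reads the state of the character just before it.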


if self.left2right:

max_seq_len = max(seq_len)
batch_size = inp.size(0)
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

for i in range(max_seq_len):
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
h_0, c_0 = h_[:, i, :], c_[:, i, :]

skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
index_1 = skip_source_flat

if not self.skip_before_head:
c_x = c_[[index_0, index_1+1]]
h_x = h_[[index_0, index_1+1]]
else:
c_x = c_[[index_0,index_1]]
h_x = h_[[index_0,index_1]]

c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)




c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))

c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1)
c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)

return h_[:,1:],c_[:,1:]
else:
mask_for_seq_len = seq_len_to_mask(seq_len)

max_seq_len = max(seq_len)
batch_size = inp.size(0)
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

for i in reversed(range(max_seq_len)):
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)



h_0, c_0 = h_[:, 0, :], c_[:, 0, :]

skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
index_1 = skip_source_flat-i

if not self.skip_before_head:
c_x = c_[[index_0, index_1-1]]
h_x = h_[[index_0, index_1-1]]
else:
c_x = c_[[index_0,index_1]]
h_x = h_[[index_0,index_1]]

c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)




c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))

c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


h_1_mask = h_1.masked_fill(1-mask_for_seq_len[:,i].unsqueeze(-1),0)
c_1_mask = c_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0)


h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1)

return h_[:,:-1],c_[:,:-1]

class LatticeLSTMLayer_sup_back_V1(nn.Module):
# V1 differs from V0 in that, at positions where no lexicon word is matched at all, V1 falls back to the
# plain LSTM update; the plain LSTM update differs from Yang Jie's lattice-LSTM implementation when the
# lexicon count is 0.
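# Concretely (a sketch of the two behaviours, read off the cell code rather than the paper): when
# lexicon_count == 0 at position j, V1 uses the coupled-gate update c_j = (1 - i_j) * c_{j-1} + i_j * g_j
# and h_j = o_j * tanh(c_j), whereas V0's lattice formula reduces in that case to c_j = g_j, i.e. the
# previous cell state is not carried over.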
def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
bias=True,device=None,debug=False,skip_before_head=False):
super().__init__()

self.debug = debug

self.skip_before_head = skip_before_head

self.hidden_size = hidden_size

self.char_cell = MultiInputLSTMCell_V1(char_input_size, hidden_size, bias,debug)

self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=self.debug)

self.word_input_size = word_input_size
self.left2right = left2right
self.bias = bias
self.device = device

def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
'''

:param inp: batch * seq_len * embedding, chars
:param seq_len: batch, length of chars
:param skip_sources: batch * seq_len * X, start positions of the skip edges (matched lexicon words)
:param skip_words: batch * seq_len * X * embedding_size, embeddings of the skip-edge words
:param skip_count: batch * seq_len,
    skip_count[i, j] is the number of lexicon words matched in example i that end at position j
:param init_state: the hx of the rnn
:return:
'''


if self.left2right:

max_seq_len = max(seq_len)
batch_size = inp.size(0)
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

for i in range(max_seq_len):
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
h_0, c_0 = h_[:, i, :], c_[:, i, :]

# reshape to 2-D so that the rnn cell can process a B * lexicon_count * embedding_size tensor
# the same reshape to 2-D is also needed to match PyTorch's [] (advanced) indexing

skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
index_1 = skip_source_flat


if not self.skip_before_head:
c_x = c_[[index_0, index_1+1]]
h_x = h_[[index_0, index_1+1]]
else:
c_x = c_[[index_0,index_1]]
h_x = h_[[index_0,index_1]]

c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)



c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))

c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1)
c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)

return h_[:,1:],c_[:,1:]
else:
mask_for_seq_len = seq_len_to_mask(seq_len)

max_seq_len = max(seq_len)
batch_size = inp.size(0)
c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

for i in reversed(range(max_seq_len)):
max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)


h_0, c_0 = h_[:, 0, :], c_[:, 0, :]

skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
index_1 = skip_source_flat-i

if not self.skip_before_head:
c_x = c_[[index_0, index_1-1]]
h_x = h_[[index_0, index_1-1]]
else:
c_x = c_[[index_0,index_1]]
h_x = h_[[index_0,index_1]]

c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)




c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))



c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


h_1_mask = h_1.masked_fill(1-mask_for_seq_len[:,i].unsqueeze(-1),0)
c_1_mask = c_1.masked_fill(1 - mask_for_seq_len[:, i].unsqueeze(-1), 0)


h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1)



return h_[:,:-1],c_[:,:-1]





+ 23
- 0
reproduction/seqence_labelling/chinese_ner/LatticeLSTM/pathes.py View File

@@ -0,0 +1,23 @@


glove_100_path = 'en-glove-6b-100d'
glove_50_path = 'en-glove-6b-50d'
glove_200_path = ''
glove_300_path = 'en-glove-840b-300'
fasttext_path = 'en-fasttext' #300
tencent_chinese_word_path = 'cn' # tencent 200
fasttext_cn_path = 'cn-fasttext' # 300
yangjie_rich_pretrain_unigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.uni.ite50.vec'
yangjie_rich_pretrain_bigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.bi.ite50.vec'
yangjie_rich_pretrain_word_path = '/remote-home/xnli/data/pretrain/chinese/ctb.50d.vec'


conll_2003_path = '/remote-home/xnli/data/corpus/multi_task/conll_2013/data_mine.pkl'
conllized_ontonote_path = '/remote-home/txsun/data/OntoNotes-5.0-NER-master/v12/english'
conllized_ontonote_pkl_path = '/remote-home/txsun/data/ontonotes5.pkl'
sst2_path = '/remote-home/xnli/data/corpus/text_classification/SST-2/'
# weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo'
ontonote4ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/OntoNote4NER'
msra_ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/MSRANER'
resume_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/ResumeNER'
weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER'

+ 126
- 0
reproduction/seqence_labelling/chinese_ner/LatticeLSTM/small.py View File

@@ -0,0 +1,126 @@
from utils_ import get_skip_path_trivial, Trie, get_skip_path
from load_data import load_yangjie_rich_pretrain_word_list, load_ontonotes4ner, equip_chinese_ner_with_skip
from pathes import *
from functools import partial
from fastNLP import cache_results
from fastNLP.embeddings.static_embedding import StaticEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastNLP.core.metrics import _bmes_tag_to_spans,_bmeso_tag_to_spans
from load_data import load_resume_ner


# embed = StaticEmbedding(None,embedding_dim=2)
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
# _refresh=True,index_token=False)
#
# w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
# _refresh=False)
#
# datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
# _refresh=True)
#

def reverse_style(input_string):
target_position = input_string.index('[')
input_len = len(input_string)
output_string = input_string[target_position:input_len] + input_string[0:target_position]
# print('in:{}.out:{}'.format(input_string, output_string))
return output_string





def get_yangjie_bmeso(label_list):
def get_ner_BMESO_yj(label_list):
# list_len = len(word_list)
# assert(list_len == len(label_list)), "word list size unmatch with label list"
list_len = len(label_list)
begin_label = 'b-'
end_label = 'e-'
single_label = 's-'
whole_tag = ''
index_tag = ''
tag_list = []
stand_matrix = []
for i in range(0, list_len):
# wordlabel = word_list[i]
current_label = label_list[i].lower()
if begin_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i - 1))
whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i)
index_tag = current_label.replace(begin_label, "", 1)

elif single_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i - 1))
whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i)
tag_list.append(whole_tag)
whole_tag = ""
index_tag = ""
elif end_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i))
whole_tag = ''
index_tag = ''
else:
continue
if (whole_tag != '') & (index_tag != ''):
tag_list.append(whole_tag)
tag_list_len = len(tag_list)

for i in range(0, tag_list_len):
if len(tag_list[i]) > 0:
tag_list[i] = tag_list[i] + ']'
insert_list = reverse_style(tag_list[i])
stand_matrix.append(insert_list)
# print stand_matrix
return stand_matrix

def transform_YJ_to_fastNLP(span):
span = span[1:]
span_split = span.split(']')
# print('span_list:{}'.format(span_split))
span_type = span_split[1]
# print('span_split[0].split(','):{}'.format(span_split[0].split(',')))
if ',' in span_split[0]:
b, e = span_split[0].split(',')
else:
b = span_split[0]
e = b

b = int(b)
e = int(e)

e += 1

return (span_type, (b, e))
yj_form = get_ner_BMESO_yj(label_list)
# print('label_list:{}'.format(label_list))
# print('yj_from:{}'.format(yj_form))
fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form))
return fastNLP_form


# tag_list = ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']
# span_list = get_ner_BMES(tag_list)
# print(span_list)
# yangjie_label_list = ['B-NAME', 'E-NAME', 'O', 'B-CONT', 'M-CONT', 'E-CONT', 'B-RACE', 'E-RACE', 'B-TITLE', 'M-TITLE', 'E-TITLE', 'B-EDU', 'M-EDU', 'E-EDU', 'B-ORG', 'M-ORG', 'E-ORG', 'M-NAME', 'B-PRO', 'M-PRO', 'E-PRO', 'S-RACE', 'S-NAME', 'B-LOC', 'M-LOC', 'E-LOC', 'M-RACE', 'S-ORG']
# my_label_list = ['O', 'M-ORG', 'M-TITLE', 'B-TITLE', 'E-TITLE', 'B-ORG', 'E-ORG', 'M-EDU', 'B-NAME', 'E-NAME', 'B-EDU', 'E-EDU', 'M-NAME', 'M-PRO', 'M-CONT', 'B-PRO', 'E-PRO', 'B-CONT', 'E-CONT', 'M-LOC', 'B-RACE', 'E-RACE', 'S-NAME', 'B-LOC', 'E-LOC', 'M-RACE', 'S-RACE', 'S-ORG']
# yangjie_label = set(yangjie_label_list)
# my_label = set(my_label_list)

a = torch.tensor([0,2,0,3])
b = (a==0)
print(b)
print(b.float())
from fastNLP import RandomSampler

# f = open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb')
# weight_dict = torch.load(f)
# print(weight_dict.keys())
# for k,v in weight_dict.items():
# print("{}:{}".format(k,v.size()))

+ 361
- 0
reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils.py View File

@@ -0,0 +1,361 @@
import torch.nn.functional as F
import torch
import random
import numpy as np
from fastNLP import Const
from fastNLP import CrossEntropyLoss
from fastNLP import AccuracyMetric
from fastNLP import Tester
import os
from fastNLP import logger
def should_mask(name, t=''):
if 'bias' in name:
return False
if 'embedding' in name:
splited = name.split('.')
if splited[-1]!='weight':
return False
if 'embedding' in splited[-2]:
return False
if 'c0' in name:
return False
if 'h0' in name:
return False

if 'output' in name and t not in name:
return False

return True
def get_init_mask(model):
init_masks = {}
for name, param in model.named_parameters():
if should_mask(name):
init_masks[name+'.mask'] = torch.ones_like(param)
# logger.info(init_masks[name+'.mask'].requires_grad)

return init_masks

def set_seed(seed):
random.seed(seed)
np.random.seed(seed+100)
torch.manual_seed(seed+200)
torch.cuda.manual_seed_all(seed+300)

def get_parameters_size(model):
result = {}
for name,p in model.state_dict().items():
result[name] = p.size()

return result

def prune_by_proportion_model(model,proportion,task):
# print('this time prune to ',proportion*100,'%')
for name, p in model.named_parameters():
# print(name)
if not should_mask(name,task):
continue

tensor = p.data.cpu().numpy()
index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy())
# print(name,'alive count',len(index[0]))
alive = tensor[index]
# print('p and mask size:',p.size(),print(model.mask[task][name+'.mask'].size()))
percentile_value = np.percentile(abs(alive), (1 - proportion) * 100)
# tensor = p
# index = torch.nonzero(model.mask[task][name+'.mask'])
# # print('nonzero len',index)
# alive = tensor[index]
# print('alive size:',alive.shape)
# prune_by_proportion_model()

# percentile_value = torch.topk(abs(alive), int((1-proportion)*len(index[0]))).values
# print('the',(1-proportion)*len(index[0]),'th big')
# print('threshold:',percentile_value)

prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value)
# for

def prune_by_proportion_model_global(model,proportion,task):
# print('this time prune to ',proportion*100,'%')
alive = None
for name, p in model.named_parameters():
# print(name)
if not should_mask(name,task):
continue

tensor = p.data.cpu().numpy()
index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy())
# print(name,'alive count',len(index[0]))
if alive is None:
alive = tensor[index]
else:
alive = np.concatenate([alive,tensor[index]],axis=0)

percentile_value = np.percentile(abs(alive), (1 - proportion) * 100)

for name, p in model.named_parameters():
if should_mask(name,task):
prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value)


def prune_by_threshold_parameter(p, mask, threshold):
p_abs = torch.abs(p)

new_mask = (p_abs > threshold).float()
# print(mask)
mask[:]*=new_mask


def one_time_train_and_prune_single_task(trainer,PRUNE_PER,
optimizer_init_state_dict=None,
model_init_state_dict=None,
is_global=None,
):


from fastNLP import Trainer


trainer.optimizer.load_state_dict(optimizer_init_state_dict)
trainer.model.load_state_dict(model_init_state_dict)
# print('metrics:',metrics.__dict__)
# print('loss:',loss.__dict__)
# print('trainer input:',task.train_set.get_input_name())
# trainer = Trainer(model=model, train_data=task.train_set, dev_data=task.dev_set, loss=loss, metrics=metrics,
# optimizer=optimizer, n_epochs=EPOCH, batch_size=BATCH, device=device,callbacks=callbacks)


trainer.train(load_best_model=True)
# tester = Tester(task.train_set, model, metrics, BATCH, device=device, verbose=1,use_tqdm=False)
# print('FOR DEBUG: test train_set:',tester.test())
# print('**'*20)
# if task.test_set:
# tester = Tester(task.test_set, model, metrics, BATCH, device=device, verbose=1)
# tester.test()
if is_global:

prune_by_proportion_model_global(trainer.model, PRUNE_PER, trainer.model.now_task)

else:
prune_by_proportion_model(trainer.model, PRUNE_PER, trainer.model.now_task)



# def iterative_train_and_prune_single_task(get_trainer,ITER,PRUNE,is_global=False,save_path=None):
def iterative_train_and_prune_single_task(get_trainer,args,model,train_set,dev_set,test_set,device,save_path=None):

'''

:param get_trainer: factory that builds a fresh Trainer for every pruning iteration
:param args: needs args.iter (number of iterations) and args.prune (overall proportion of weights to keep)
:param model, train_set, dev_set, test_set, device:
:param save_path: should be a directory, which will be filled with the masks and state dicts
:return:
'''



from fastNLP import Trainer
import torch
import math
import copy
PRUNE = args.prune
ITER = args.iter
trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict())
model_init_state_dict = copy.deepcopy(trainer.model.state_dict())
if save_path is not None:
if not os.path.exists(save_path):
os.makedirs(save_path)
# if not os.path.exists(os.path.join(save_path, 'model_init.pkl')):
# f = open(os.path.join(save_path, 'model_init.pkl'), 'wb')
# torch.save(trainer.model.state_dict(),f)


mask_count = 0
model = trainer.model
task = trainer.model.now_task
for name, p in model.mask[task].items():
mask_count += torch.sum(p).item()
init_mask_count = mask_count
logger.info('init mask count:{}'.format(mask_count))
# logger.info('{}th traning mask count: {} / {} = {}%'.format(i, mask_count, init_mask_count,
# mask_count / init_mask_count * 100))

prune_per_iter = math.pow(PRUNE, 1 / ITER)
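# e.g. with PRUNE = 0.2 and ITER = 5, prune_per_iter = 0.2 ** (1 / 5) ≈ 0.725, so each round keeps
# about 72.5% of the surviving weights and roughly 20% of them remain after all 5 rounds.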


for i in range(ITER):
trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
one_time_train_and_prune_single_task(trainer,prune_per_iter,optimizer_init_state_dict,model_init_state_dict)
if save_path is not None:
f = open(os.path.join(save_path,task+'_mask_'+str(i)+'.pkl'),'wb')
torch.save(model.mask[task],f)

mask_count = 0
for name, p in model.mask[task].items():
mask_count += torch.sum(p).item()
logger.info('{}th training mask count: {} / {} = {}%'.format(i,mask_count,init_mask_count,mask_count/init_mask_count*100))


def get_appropriate_cuda(task_scale='s'):
if task_scale not in {'s','m','l'}:
logger.info('task scale wrong!')
exit(2)
import pynvml
pynvml.nvmlInit()
total_cuda_num = pynvml.nvmlDeviceGetCount()
for i in range(total_cuda_num):
logger.info(i)
handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id
memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
logger.info('{} mem: {} util: {}'.format(i, memInfo.used / memInfo.total, utilizationInfo.gpu))
if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu < 0.2:
logger.info('{} {}'.format(i, memInfo.used / memInfo.total))
return 'cuda:'+str(i)

if task_scale=='s':
max_memory=2000
elif task_scale=='m':
max_memory=6000
else:
max_memory = 9000

max_id = -1
for i in range(total_cuda_num):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id
memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
if max_memory < memInfo.free:
max_memory = memInfo.free
max_id = i

if max_id == -1:
logger.info('no appropriate gpu, wait!')
exit(2)

return 'cuda:'+str(max_id)

# if memInfo.used / memInfo.total < 0.5:
# return

def print_mask(mask_dict):
def seq_mul(*X):
res = 1
for x in X:
res*=x
return res

for name,p in mask_dict.items():
total_size = seq_mul(*p.size())
unmasked_size = len(torch.nonzero(p))

print(name,':',unmasked_size,'/',total_size,'=',unmasked_size/total_size*100,'%')


print()


def check_words_same(dataset_1,dataset_2,field_1,field_2):
if len(dataset_1[field_1]) != len(dataset_2[field_2]):
logger.info('CHECK: example num not same!')
return False

for i, words in enumerate(dataset_1[field_1]):
if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]):
logger.info('CHECK {} th example length not same'.format(i))
logger.info('1:{}'.format(dataset_1[field_1][i]))
logger.info('2:{}'.format(dataset_2[field_2][i]))
return False

# for j,w in enumerate(words):
# if dataset_1[field_1][i][j] != dataset_2[field_2][i][j]:
# print('CHECK', i, 'th example has words different!')
# print('1:',dataset_1[field_1][i])
# print('2:',dataset_2[field_2][i])
# return False

logger.info('CHECK: totally same!')

return True

def get_now_time():
import time
from datetime import datetime, timezone, timedelta
dt = datetime.now(timezone.utc)
# print(dt)
tzutc_8 = timezone(timedelta(hours=8))
local_dt = dt.astimezone(tzutc_8)
result = ("_{}_{}_{}__{}_{}_{}".format(local_dt.year, local_dt.month, local_dt.day, local_dt.hour, local_dt.minute,
local_dt.second))

return result


def get_bigrams(words):
result = []
for i,w in enumerate(words):
if i!=len(words)-1:
result.append(words[i]+words[i+1])
else:
result.append(words[i]+'<end>')

return result
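# e.g. get_bigrams(['中', '国', '人']) -> ['中国', '国人', '人<end>']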

def print_info(*inp,islog=False,sep=' '):
from fastNLP import logger
if islog:
print(*inp,sep=sep)
else:
inp = sep.join(map(str,inp))
logger.info(inp)

def better_init_rnn(rnn,coupled=False):
import torch.nn as nn
if coupled:
repeat_size = 3
else:
repeat_size = 4
# print(list(rnn.named_parameters()))
if hasattr(rnn,'num_layers'):
for i in range(rnn.num_layers):
nn.init.orthogonal_(getattr(rnn,'weight_ih_l'+str(i)).data)
weight_hh_data = torch.eye(rnn.hidden_size)
weight_hh_data = weight_hh_data.repeat(1, repeat_size)
with torch.no_grad():
getattr(rnn,'weight_hh_l'+str(i)).set_(weight_hh_data)
nn.init.constant_(getattr(rnn,'bias_ih_l'+str(i)).data, val=0)
nn.init.constant_(getattr(rnn,'bias_hh_l'+str(i)).data, val=0)

if rnn.bidirectional:
for i in range(rnn.num_layers):
nn.init.orthogonal_(getattr(rnn, 'weight_ih_l' + str(i)+'_reverse').data)
weight_hh_data = torch.eye(rnn.hidden_size)
weight_hh_data = weight_hh_data.repeat(1, repeat_size)
with torch.no_grad():
getattr(rnn, 'weight_hh_l' + str(i)+'_reverse').set_(weight_hh_data)
nn.init.constant_(getattr(rnn, 'bias_ih_l' + str(i)+'_reverse').data, val=0)
nn.init.constant_(getattr(rnn, 'bias_hh_l' + str(i)+'_reverse').data, val=0)


else:
nn.init.orthogonal_(rnn.weight_ih.data)
weight_hh_data = torch.eye(rnn.hidden_size)
weight_hh_data = weight_hh_data.repeat(repeat_size,1)
with torch.no_grad():
rnn.weight_hh.set_(weight_hh_data)
# The bias is just set to zero vectors.
print('rnn param size:{},{}'.format(rnn.weight_hh.size(),type(rnn)))
if rnn.bias:
nn.init.constant_(rnn.bias_ih.data, val=0)
nn.init.constant_(rnn.bias_hh.data, val=0)

# print(list(rnn.named_parameters()))







+ 405
- 0
reproduction/seqence_labelling/chinese_ner/LatticeLSTM/utils_.py View File

@@ -0,0 +1,405 @@
import collections
from fastNLP import cache_results
def get_skip_path(chars,w_trie):
sentence = ''.join(chars)
result = w_trie.get_lexicon(sentence)

return result

# @cache_results(_cache_fp='cache/get_skip_path_trivial',_refresh=True)
def get_skip_path_trivial(chars,w_list):
chars = ''.join(chars)
w_set = set(w_list)
result = []
# for i in range(len(chars)):
# result.append([])
for i in range(len(chars)-1):
for j in range(i+2,len(chars)+1):
if chars[i:j] in w_set:
result.append([i,j-1,chars[i:j]])

return result


class TrieNode:
def __init__(self):
self.children = collections.defaultdict(TrieNode)
self.is_w = False

class Trie:
def __init__(self):
self.root = TrieNode()

def insert(self,w):

current = self.root
for c in w:
current = current.children[c]

current.is_w = True

def search(self,w):
'''

:param w:
:return:
-1: w is not a prefix of any word in the trie
 0: w is a prefix of some word but is not itself a word
 1: w is a word in the trie
'''
current = self.root

for c in w:
current = current.children.get(c)

if current is None:
return -1

if current.is_w:
return 1
else:
return 0

def get_lexicon(self,sentence):
result = []
for i in range(len(sentence)):
current = self.root
for j in range(i, len(sentence)):
current = current.children.get(sentence[j])
if current is None:
break

if current.is_w:
result.append([i,j,sentence[i:j+1]])

return result
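# A minimal usage sketch (hypothetical words, just to illustrate the return format):
# trie = Trie(); trie.insert('中国'); trie.insert('中国人')
# trie.get_lexicon('中国人民') -> [[0, 1, '中国'], [0, 2, '中国人']]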

from fastNLP.core.field import Padder
import numpy as np
import torch
from collections import defaultdict
class LatticeLexiconPadder(Padder):

def __init__(self, pad_val=0, pad_val_dynamic=False,dynamic_offset=0, **kwargs):
'''

:param pad_val: the value used for padding
:param pad_val_dynamic: if True, ignore pad_val and pad with max_len - 1 + dynamic_offset,
    where max_len is the padded (char) sequence length
:param kwargs:
'''
self.pad_val = pad_val
self.pad_val_dynamic = pad_val_dynamic
self.dynamic_offset = dynamic_offset

def __call__(self, contents, field_name, field_ele_dtype, dim: int):
# same as the dim == 2 case in AutoPadder
max_len = max(map(len, contents))

max_len = max(max_len, 1)  # avoid a zero-sized dim, which causes CUDA errors

max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in contents])

max_word_len = max(max_word_len,1)
if self.pad_val_dynamic:
# print('pad_val_dynamic:{}'.format(max_len-1))

array = np.full((len(contents), max_len, max_word_len), max_len-1+self.dynamic_offset,
dtype=field_ele_dtype)

else:
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
for j, content_ii in enumerate(content_i):
array[i, j, :len(content_ii)] = content_ii
array = torch.tensor(array)

return array
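# Intended usage (a sketch; the field names below are illustrative, not taken from this repo's load_data):
# after the lexicon fields have been added to a fastNLP DataSet, register this padder on the 3-D fields, e.g.
#   dataset.set_padder('skips_l2r_word', LatticeLexiconPadder())
#   dataset.set_padder('skips_l2r_source', LatticeLexiconPadder(pad_val_dynamic=True))
# so that each batch is padded to [batch, max_seq_len, max_lexicon_count].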

from fastNLP.core.metrics import MetricBase

def get_yangjie_bmeso(label_list,ignore_labels=None):
def get_ner_BMESO_yj(label_list):
def reverse_style(input_string):
target_position = input_string.index('[')
input_len = len(input_string)
output_string = input_string[target_position:input_len] + input_string[0:target_position]
# print('in:{}.out:{}'.format(input_string, output_string))
return output_string

# list_len = len(word_list)
# assert(list_len == len(label_list)), "word list size unmatch with label list"
list_len = len(label_list)
begin_label = 'b-'
end_label = 'e-'
single_label = 's-'
whole_tag = ''
index_tag = ''
tag_list = []
stand_matrix = []
for i in range(0, list_len):
# wordlabel = word_list[i]
current_label = label_list[i].lower()
if begin_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i - 1))
whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i)
index_tag = current_label.replace(begin_label, "", 1)

elif single_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i - 1))
whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i)
tag_list.append(whole_tag)
whole_tag = ""
index_tag = ""
elif end_label in current_label:
if index_tag != '':
tag_list.append(whole_tag + ',' + str(i))
whole_tag = ''
index_tag = ''
else:
continue
if (whole_tag != '') & (index_tag != ''):
tag_list.append(whole_tag)
tag_list_len = len(tag_list)

for i in range(0, tag_list_len):
if len(tag_list[i]) > 0:
tag_list[i] = tag_list[i] + ']'
insert_list = reverse_style(tag_list[i])
stand_matrix.append(insert_list)
# print stand_matrix
return stand_matrix

def transform_YJ_to_fastNLP(span):
span = span[1:]
span_split = span.split(']')
# print('span_list:{}'.format(span_split))
span_type = span_split[1]
# print('span_split[0].split(','):{}'.format(span_split[0].split(',')))
if ',' in span_split[0]:
b, e = span_split[0].split(',')
else:
b = span_split[0]
e = b

b = int(b)
e = int(e)

e += 1

return (span_type, (b, e))
yj_form = get_ner_BMESO_yj(label_list)
# print('label_list:{}'.format(label_list))
# print('yj_from:{}'.format(yj_form))
fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form))
return fastNLP_form
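# e.g. get_yangjie_bmeso(['B-NAME', 'E-NAME', 'O']) -> [('name', (0, 2))]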
class SpanFPreRecMetric_YJ(MetricBase):
r"""
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric`

在序列标注问题中,以span的方式计算F, pre, rec.
比如中文Part of speech中,会以character的方式进行标注,句子 `中国在亚洲` 对应的POS可能为(以BMES为例)
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。
最后得到的metric结果为::

{
'f': xxx, # 这里使用f考虑以后可以计算f_beta值
'pre': xxx,
'rec':xxx
}

若only_gross=False, 即还会返回各个label的metric统计值::

{
'f': xxx,
'pre': xxx,
'rec':xxx,
'f-label': xxx,
'pre-label': xxx,
'rec-label':xxx,
...
}

:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN),
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'.
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这
个label
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个
label的f1, pre, rec
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` :
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同)
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
"""
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
only_gross=True, f_type='micro', beta=1):
from fastNLP.core import Vocabulary
from fastNLP.core.metrics import _bmes_tag_to_spans,_bio_tag_to_spans,\
_bioes_tag_to_spans,_bmeso_tag_to_spans
from collections import defaultdict

encoding_type = encoding_type.lower()

if not isinstance(tag_vocab, Vocabulary):
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))

self.encoding_type = encoding_type
# print('encoding_type:{}'self.encoding_type)
if self.encoding_type == 'bmes':
self.tag_to_span_func = _bmes_tag_to_spans
elif self.encoding_type == 'bio':
self.tag_to_span_func = _bio_tag_to_spans
elif self.encoding_type == 'bmeso':
self.tag_to_span_func = _bmeso_tag_to_spans
elif self.encoding_type == 'bioes':
self.tag_to_span_func = _bioes_tag_to_spans
elif self.encoding_type == 'bmesoyj':
self.tag_to_span_func = get_yangjie_bmeso
# self.tag_to_span_func =
else:
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")

self.ignore_labels = ignore_labels
self.f_type = f_type
self.beta = beta
self.beta_square = self.beta ** 2
self.only_gross = only_gross

super().__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)

self.tag_vocab = tag_vocab

self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int)

def evaluate(self, pred, target, seq_len):
"""evaluate accumulates the metric statistics over one batch of predictions

:param pred: [batch, seq_len] or [batch, seq_len, len(tag_vocab)], the predictions
:param target: [batch, seq_len], the ground-truth labels
:param seq_len: [batch], the text lengths
:return:
"""
from fastNLP.core.utils import _get_func_signature
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.")

if pred.size() == target.size() and len(target.size()) == 2:
pass
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
num_classes = pred.size(-1)
pred = pred.argmax(dim=-1)
if (target >= num_classes).any():
raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
"id >= {}, the number of classes.".format(num_classes))
else:
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")

batch_size = pred.size(0)
pred = pred.tolist()
target = target.tolist()
for i in range(batch_size):
pred_tags = pred[i][:int(seq_len[i])]
gold_tags = target[i][:int(seq_len[i])]

pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

for span in pred_spans:
if span in gold_spans:
self._true_positives[span[0]] += 1
gold_spans.remove(span)
else:
self._false_positives[span[0]] += 1
for span in gold_spans:
self._false_negatives[span[0]] += 1

def get_metric(self, reset=True):
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
evaluate_result = {}
if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(set(self._false_positives.keys()))
tags.update(set(self._true_positives.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
tp = self._true_positives[tag]
fn = self._false_negatives[tag]
fp = self._false_positives[tag]
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
f_sum += f
pre_sum += pre
rec_sum += rec
if not self.only_gross and tag != '':  # tag != '' guards against spans without a label
f_key = 'f-{}'.format(tag)
pre_key = 'pre-{}'.format(tag)
rec_key = 'rec-{}'.format(tag)
evaluate_result[f_key] = f
evaluate_result[pre_key] = pre
evaluate_result[rec_key] = rec

if self.f_type == 'macro':
evaluate_result['f'] = f_sum / len(tags)
evaluate_result['pre'] = pre_sum / len(tags)
evaluate_result['rec'] = rec_sum / len(tags)

if self.f_type == 'micro':
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
sum(self._false_negatives.values()),
sum(self._false_positives.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec

if reset:
self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int)

for key, value in evaluate_result.items():
evaluate_result[key] = round(value, 6)

return evaluate_result

def _compute_f_pre_rec(self, tp, fn, fp):
"""

:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)

return f, pre, rec
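# e.g. tp=8, fn=4, fp=2 -> pre = 8/10 = 0.8, rec = 8/12 ≈ 0.667, f ≈ 0.727 (with beta = 1)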




